torch_geometric.nn.conv.CuGraphSAGEConv

class CuGraphSAGEConv(in_channels: int, out_channels: int, aggr: str = 'mean', normalize: bool = False, root_weight: bool = True, project: bool = False, bias: bool = True)[source]

基础类:CuGraphModule

来自“大型图上的归纳表示学习”论文的GraphSAGE操作符。

CuGraphSAGEConvSAGEConv 的优化版本,基于 cugraph-ops 包,该包融合了消息传递计算以加速执行并减少内存占用。

forward(x: Tensor, edge_index: EdgeIndex, max_num_neighbors: Optional[int] = None) Tensor[source]

运行模块的前向传播。

Parameters:
  • x (torch.Tensor) – 节点特征。

  • edge_index (EdgeIndex) – 边的索引。

  • max_num_neighbors (int, optional) – 目标节点的最大邻居数。仅在二分图中操作时有效。如果未提供该值,将在运行时计算,这会导致性能略有下降。 (default: None)

Return type:

Tensor

reset_parameters()[source]

重置模块的所有可学习参数。