torch_geometric.nn.conv.CuGraphGATConv
- class CuGraphGATConv(in_channels: int, out_channels: int, heads: int = 1, concat: bool = True, negative_slope: float = 0.2, bias: bool = True)
Bases: CuGraphModule
The graph attentional operator from the “Graph Attention Networks” paper.
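To make concrete what the graph attentional operator computes, here is a minimal, unfused single-head sketch in plain Python. It is not CuGraphGATConv's implementation or any PyG API. The helper `gat_single_head` and its parameter names are hypothetical. The sketch also makes several simplifying assumptions: messages flow from `edge_index[0]` (source) to `edge_index[1]` (target), the bias is omitted, and self-loops take part only if they are already present in `edge_index`. Each target node i scores every incoming edge j → i as LeakyReLU(a_src·Θx_j + a_dst·Θx_i) using the layer's `negative_slope`, normalizes those scores with a softmax over its incoming edges, and returns the attention-weighted sum of the projected neighbor features.

```python
import math
from typing import List, Sequence, Tuple

Vector = List[float]
Matrix = List[List[float]]


def matvec(w: Matrix, x: Sequence[float]) -> Vector:
    """Return w @ x for a weight matrix w of shape [out, in]."""
    return [sum(wij * xj for wij, xj in zip(row, x)) for row in w]


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(ai * bi for ai, bi in zip(a, b))


def gat_single_head(
    x: Sequence[Sequence[float]],
    edge_index: Tuple[Sequence[int], Sequence[int]],
    weight: Matrix,
    att_src: Sequence[float],
    att_dst: Sequence[float],
    negative_slope: float = 0.2,
) -> List[Vector]:
    """Unfused single-head GAT layer (hypothetical helper, not a PyG API)."""
    src, dst = edge_index
    h = [matvec(weight, xi) for xi in x]  # linear projection Theta x
    out_channels = len(weight)

    # Raw score for each edge j -> i: LeakyReLU(a_src . h_j + a_dst . h_i).
    scores = []
    for j, i in zip(src, dst):
        e = dot(att_src, h[j]) + dot(att_dst, h[i])
        scores.append(e if e > 0 else negative_slope * e)

    # Softmax over the incoming edges of each target (shifted by the max for stability).
    max_score = {}
    for i, e in zip(dst, scores):
        max_score[i] = max(e, max_score.get(i, -math.inf))
    exp_scores = [math.exp(e - max_score[i]) for i, e in zip(dst, scores)]
    denom = {}
    for i, ex in zip(dst, exp_scores):
        denom[i] = denom.get(i, 0.0) + ex

    # Aggregate out_i = sum_j alpha_ij * h_j; nodes without incoming edges stay zero.
    out = [[0.0] * out_channels for _ in x]
    for (j, i), ex in zip(zip(src, dst), exp_scores):
        alpha = ex / denom[i]
        for c in range(out_channels):
            out[i][c] += alpha * h[j][c]
    return out


# Usage: node 2 receives edges from nodes 0, 1 and itself; zero attention
# vectors give equal scores, so node 2 gets the mean of the projected features.
out = gat_single_head([[1.0], [2.0], [3.0]], ([0, 1, 2], [2, 2, 2]), [[1.0]], [0.0], [0.0])
print(out)
```

With `heads > 1`, a GAT layer repeats this computation once per head with separate parameters and either concatenates the head outputs (`concat=True`) or averages them.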
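CuGraphGATConv is an optimized version of GATConv based on the cugraph-ops package; it fuses the message-passing computation to accelerate execution and reduce the memory footprint.
- forward(x: Tensor, edge_index: EdgeIndex, max_num_neighbors: Optional[int] = None) → Tensor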
Runs the forward pass of the module.
- Parameters:
x (torch.Tensor) – The node features.
edge_index (EdgeIndex) – The edge indices.
max_num_neighbors (int, optional) – The maximum number of neighbors of a target node. It is only effective when operating in a bipartite graph. When not given, the value will be computed on-the-fly, leading to slightly worse performance. (default: None)
- Return type:
Tensor
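Because `max_num_neighbors` is otherwise computed on the fly, a caller reusing the same bipartite graph may prefer to compute it once and pass it in. The plain-Python sketch below shows the quantity involved: the largest number of edges pointing into any single target node. The helper name is hypothetical, and the sketch assumes that `edge_index[1]` holds the target node of each edge, as in PyG's default source-to-target flow. For tensors, an equivalent would be along the lines of `int(torch_geometric.utils.degree(edge_index[1]).max())`.

```python
from collections import Counter
from typing import Sequence, Tuple


def max_num_neighbors(edge_index: Tuple[Sequence[int], Sequence[int]]) -> int:
    """Largest in-degree over target nodes (hypothetical helper).

    Assumes edge_index = (sources, targets), so edge_index[1] lists the
    target node of every edge.
    """
    targets = edge_index[1]
    if len(targets) == 0:
        return 0
    # Count incoming edges per target node and take the maximum.
    return max(Counter(targets).values())


# Usage: target 0 receives edges from sources 0, 1 and 2; target 1 from source 0.
print(max_num_neighbors(([0, 1, 2, 0], [0, 0, 0, 1])))
```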