torch_geometric.nn.conv.FusedGATConv
- class FusedGATConv(*args, **kwargs)[source]
基础类:
GATConv来自“理解GNN计算图:从计算、IO和内存的协调视角”论文的融合图注意力操作符。
FusedGATConv是GATConv的优化版本,基于dgNN包,它融合了消息传递计算以加速执行并减少内存占用。注意
此实现基于
dgNN包。 有关如何安装的说明,请参见此处。- forward(x: Tensor, csr: Tuple[Tensor, Tensor], csc: Tuple[Tensor, Tensor], perm: Tensor) Tensor[source]
运行模块的前向传播。
- Parameters:
x (torch.Tensor) – The node features.
csr ((torch.Tensor, torch.Tensor)) – 一个包含图的CSR表示的元组,表示为
(rowptr, col)的元组。csc ((torch.Tensor, torch.Tensor)) – 一个包含图CSC表示的元组,表示为
(row, colptr)的元组。perm (torch.Tensor) – 用于将CSR表示映射到CSC表示的置换张量。
注意
使用
to_graph_format()方法从现有的edge_index表示中获取(csr, csc, perm)图格式。- Return type:
- reset_parameters()
重置模块的所有可学习参数。