torch_geometric.nn.conv.FusedGATConv

class FusedGATConv(*args, **kwargs)[source]

基础类:GATConv

来自“理解GNN计算图:从计算、IO和内存的协调视角”论文的融合图注意力操作符。

FusedGATConvGATConv 的优化版本,基于 dgNN 包,它融合了消息传递计算以加速执行并减少内存占用。

注意

此实现基于dgNN包。 有关如何安装的说明,请参见此处

forward(x: Tensor, csr: Tuple[Tensor, Tensor], csc: Tuple[Tensor, Tensor], perm: Tensor) Tensor[source]

运行模块的前向传播。

Parameters:

注意

使用 to_graph_format() 方法从现有的 edge_index 表示中获取 (csr, csc, perm) 图格式。

Return type:

Tensor

reset_parameters()

重置模块的所有可学习参数。

static to_graph_format(edge_index: Tensor, size: Optional[Tuple[int, int]] = None) Tuple[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor], Tensor][source]

将图的edge_index表示转换为FusedGATConv所需的输入格式。

Parameters:
  • edge_index (torch.Tensor) – The edge indices.

  • size ((int, int), optional) – 每个维度中edge_index的形状。(默认值:None

Return type:

Tuple[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor], Tensor]