torch_geometric.nn.conv.MessagePassing

class MessagePassing(aggr: Optional[Union[str, List[str], Aggregation]] = 'sum', *, aggr_kwargs: Optional[Dict[str, Any]] = None, flow: str = 'source_to_target', node_dim: int = -2, decomposed_layers: int = 1)[source]

Bases: Module

用于创建消息传递层的基类。

消息传递层遵循以下形式

\[\mathbf{x}_i^{\prime} = \gamma_{\mathbf{\Theta}} \left( \mathbf{x}_i, \bigoplus_{j \in \mathcal{N}(i)} \, \phi_{\mathbf{\Theta}} \left(\mathbf{x}_i, \mathbf{x}_j,\mathbf{e}_{j,i}\right) \right),\]

其中 \(\bigoplus\) 表示一个可微的、排列不变的函数,例如,求和、均值、最小值、最大值或乘积,而 \(\gamma_{\mathbf{\Theta}}\)\(\phi_{\mathbf{\Theta}}\) 表示可微函数,例如多层感知机(MLPs)。请参阅 这里 获取相关教程。

Parameters:
  • aggr (str[str] 或 Aggregation, 可选) – 使用的聚合方案,例如"sum" "mean", "min", "max""mul"。 此外,可以是任何 Aggregation 模块(或任何自动解析为它的字符串)。 如果以列表形式给出,将使用多个聚合,其中不同的输出将在最后一个维度上连接。 如果设置为 None,则 MessagePassing 实例化应通过 aggregate() 实现其自己的聚合逻辑。(默认值:"add"

  • aggr_kwargs (Dict[str, Any], optional) – Arguments passed to the respective aggregation function in case it gets automatically resolved. (default: None)

  • flow (str, 可选) – 消息传递的流向 ("source_to_target""target_to_source"). (默认: "source_to_target")

  • node_dim (int, optional) – 沿此轴传播。 (default: -2)

  • decomposed_layers (int, optional) – 特征分解层数,如“优化边缘计算平台上图神经网络的内存效率”论文中所述。特征分解通过在GNN聚合期间将特征维度切片为独立的特征分解层来减少峰值内存使用。这种方法可以加速基于CPU平台上的GNN执行(例如,在Reddit数据集上实现2-3倍的加速),适用于常见的GNN模型,如GCNGraphSAGEGIN等。然而,这种方法并不适用于所有可用的GNN操作符,特别是那些消息计算不易分解的操作符,例如基于注意力的GNN。选择decomposed_layers的最佳值取决于具体的图数据集和可用的硬件资源。在大多数情况下,值为2是合适的。尽管峰值内存使用与特征分解的粒度直接相关,但执行速度的提升并不一定如此。(默认值:1

reset_parameters() None[source]

重置模块的所有可学习参数。

Return type:

None

forward(*args: Any, **kwargs: Any) Any[source]

运行模块的前向传播。

Return type:

Any

propagate(edge_index: Union[Tensor, SparseTensor], size: Optional[Tuple[int, int]] = None, **kwargs: Any) Tensor[source]

初始调用以开始传播消息。

Parameters:
  • edge_index (torch.TensorSparseTensor) – 一个 torch.Tensor, 一个 torch_sparse.SparseTensor 或一个 torch.sparse.Tensor,用于定义底层的 图连接/消息传递流程。 edge_index 保存了一个形状为 [N, M] 的通用(稀疏) 分配矩阵的索引。 如果 edge_index 是一个 torch.Tensor,它的 dtype 应该是 torch.long,并且它的形状需要定义为 [2, num_messages],其中来自 edge_index[0] 的节点的消息 被发送到 edge_index[1] 中的节点 (在 flow="source_to_target" 的情况下)。 如果 edge_index 是一个 torch_sparse.SparseTensor 或 一个 torch.sparse.Tensor,它的稀疏索引 (row, col) 应该与 row = edge_index[1]col = edge_index[0] 相关。 两种格式之间的主要区别在于我们需要将 转置 的稀疏邻接矩阵输入到 propagate() 中。

  • size ((int, int), optional) – 如果 edge_index 是一个 torch.Tensor,则分配矩阵的大小为 (N, M)。 如果设置为 None,大小将自动推断并假定为二次方。 如果 edge_indextorch_sparse.SparseTensortorch.sparse.Tensor,则忽略此参数。(默认值: None)

  • **kwargs – 任何用于构建和聚合消息以及更新节点嵌入的额外数据。

Return type:

Tensor

message(x_j: Tensor) Tensor[source]

从节点 \(j\) 向节点 \(i\) 构造消息,类似于 \(\phi_{\mathbf{\Theta}}\) 对于 edge_index 中的每条边。 此函数可以接受最初传递给 propagate() 的任何参数作为输入。 此外,传递给 propagate() 的张量可以通过在变量名后附加 _i_j 来映射到相应的节点 \(i\)\(j\)例如 x_ix_j

Return type:

Tensor

aggregate(inputs: Tensor, index: Tensor, ptr: Optional[Tensor] = None, dim_size: Optional[int] = None) Tensor[source]

从邻居聚合消息为 \(\bigoplus_{j \in \mathcal{N}(i)}\)

将消息计算的输出作为第一个参数,并将最初传递给propagate()的任何参数作为其他参数。

默认情况下,此函数会将其调用委托给底层的 Aggregation 模块,以按照 __init__() 中由 aggr 参数指定的方式减少消息。

Return type:

Tensor

abstract message_and_aggregate(edge_index: Union[Tensor, SparseTensor]) Tensor[source]

message()aggregate()的计算融合到一个函数中。 如果适用,这将节省时间和内存,因为消息不需要显式地具体化。 只有在实现此函数并且传播基于torch_sparse.SparseTensortorch.sparse.Tensor时,才会调用此函数。

Return type:

Tensor

update(inputs: Tensor) Tensor[source]

更新节点嵌入,类似于每个节点 \(i \in \mathcal{V}\)\(\gamma_{\mathbf{\Theta}}\)。 将聚合的输出作为第一个参数,并将最初传递给 propagate() 的任何参数作为其他参数。

Return type:

Tensor

edge_updater(edge_index: Union[Tensor, SparseTensor], size: Optional[Tuple[int, int]] = None, **kwargs: Any) Tensor[source]

对图中每条边进行特征计算或更新的初始调用。

Parameters:
  • edge_index (torch.TensorSparseTensor) – 一个 torch.Tensor,一个 torch_sparse.SparseTensor 或一个 torch.sparse.Tensor,用于定义底层图的连接性/消息传递流程。 更多信息请参见 propagate()

  • size ((int, int), optional) – 如果 edge_index 是一个 torch.Tensor,则指定分配矩阵的大小 (N, M)。 如果设置为 None,大小将自动推断并假定为二次方。 如果 edge_index 是一个 torch_sparse.SparseTensor 或 一个 torch.sparse.Tensor,则忽略此参数。(默认值: None)

  • **kwargs – 任何用于计算或更新图中每条边特征所需的额外数据。

Return type:

Tensor

abstract edge_update() Tensor[source]

计算或更新图中每条边的特征。 此函数可以接受最初传递给edge_updater()的任何参数作为输入。 此外,传递给edge_updater()的张量可以通过在变量名后附加_i_j映射到相应的节点\(i\)\(j\)例如 x_ix_j

Return type:

Tensor

register_propagate_forward_pre_hook(hook: Callable) RemovableHandle[source]

在模块上注册一个前向预钩子。

每次在调用propagate()之前,钩子都会被调用。 它应具有以下签名:

hook(module, inputs) -> None or modified input

钩子可以修改输入。 输入关键字参数作为字典传递给钩子,位于 inputs[-1]

返回一个torch.utils.hooks.RemovableHandle,可以通过调用handle.remove()来移除添加的钩子。

Return type:

RemovableHandle

register_propagate_forward_hook(hook: Callable) RemovableHandle[source]

在模块上注册一个前向钩子。

每次在propagate()计算输出后,钩子将被调用。 它应具有以下签名:

hook(module, inputs, output) -> None or modified output

钩子可以修改输出。 输入的关键字参数作为字典传递给钩子,位于 inputs[-1]

返回一个torch.utils.hooks.RemovableHandle,可以通过调用handle.remove()来移除添加的钩子。

Return type:

RemovableHandle

register_message_forward_pre_hook(hook: Callable) RemovableHandle[source]

在模块上注册一个前向预钩子。 每次在调用message()之前,都会调用这个钩子。 更多信息请参见register_propagate_forward_pre_hook()

Return type:

RemovableHandle

register_message_forward_hook(hook: Callable) RemovableHandle[source]

在模块上注册一个前向钩子。 每次在message()计算输出后,钩子将被调用。 有关更多信息,请参见register_propagate_forward_hook()

Return type:

RemovableHandle

register_aggregate_forward_pre_hook(hook: Callable) RemovableHandle[source]

在模块上注册一个前向预钩子。 每次在调用aggregate()之前,都会调用这个钩子。 更多信息请参见register_propagate_forward_pre_hook()

Return type:

RemovableHandle

register_aggregate_forward_hook(hook: Callable) RemovableHandle[source]

在模块上注册一个前向钩子。 每次在aggregate()计算输出后,钩子将被调用。 有关更多信息,请参见register_propagate_forward_hook()

Return type:

RemovableHandle

register_message_and_aggregate_forward_pre_hook(hook: Callable) RemovableHandle[source]

在模块上注册一个前向预钩子。 每次在调用message_and_aggregate()之前,都会调用这个钩子。 更多信息请参见register_propagate_forward_pre_hook()

Return type:

RemovableHandle

register_message_and_aggregate_forward_hook(hook: Callable) RemovableHandle[source]

在模块上注册一个前向钩子。 每次在message_and_aggregate()计算输出后,钩子将被调用。 有关更多信息,请参见register_propagate_forward_hook()

Return type:

RemovableHandle

register_edge_update_forward_pre_hook(hook: Callable) RemovableHandle[source]

在模块上注册一个前向预钩子。 每次在调用edge_update()之前,都会调用这个钩子。有关更多信息,请参见register_propagate_forward_pre_hook()

Return type:

RemovableHandle

register_edge_update_forward_hook(hook: Callable) RemovableHandle[source]

在模块上注册一个前向钩子。 每次在edge_update()计算输出后,钩子将被调用。 有关更多信息,请参见register_propagate_forward_hook()

Return type:

RemovableHandle

jittable(typing: Optional[str] = None) MessagePassing[source]

分析MessagePassing实例并生成一个新的可JIT编译的模块,该模块可以与torch.jit.script()结合使用。

注意

jittable() 已弃用,并且从 2.5 开始不再有效。

Return type:

MessagePassing