torch_geometric.nn.conv.MessagePassing
- class MessagePassing(aggr: Optional[Union[str, List[str], Aggregation]] = 'sum', *, aggr_kwargs: Optional[Dict[str, Any]] = None, flow: str = 'source_to_target', node_dim: int = -2, decomposed_layers: int = 1)[source]
Bases:
Module用于创建消息传递层的基类。
消息传递层遵循以下形式
\[\mathbf{x}_i^{\prime} = \gamma_{\mathbf{\Theta}} \left( \mathbf{x}_i, \bigoplus_{j \in \mathcal{N}(i)} \, \phi_{\mathbf{\Theta}} \left(\mathbf{x}_i, \mathbf{x}_j,\mathbf{e}_{j,i}\right) \right),\]其中 \(\bigoplus\) 表示一个可微的、排列不变的函数,例如,求和、均值、最小值、最大值或乘积,而 \(\gamma_{\mathbf{\Theta}}\) 和 \(\phi_{\mathbf{\Theta}}\) 表示可微函数,例如多层感知机(MLPs)。请参阅 这里 获取相关教程。
- Parameters:
aggr (str 或 [str] 或 Aggregation, 可选) – 使用的聚合方案,例如,
"sum""mean","min","max"或"mul"。 此外,可以是任何Aggregation模块(或任何自动解析为它的字符串)。 如果以列表形式给出,将使用多个聚合,其中不同的输出将在最后一个维度上连接。 如果设置为None,则MessagePassing实例化应通过aggregate()实现其自己的聚合逻辑。(默认值:"add")aggr_kwargs (Dict[str, Any], optional) – Arguments passed to the respective aggregation function in case it gets automatically resolved. (default:
None)flow (str, 可选) – 消息传递的流向 (
"source_to_target"或"target_to_source"). (默认:"source_to_target")node_dim (int, optional) – 沿此轴传播。 (default:
-2)decomposed_layers (int, optional) – 特征分解层数,如“优化边缘计算平台上图神经网络的内存效率”论文中所述。特征分解通过在GNN聚合期间将特征维度切片为独立的特征分解层来减少峰值内存使用。这种方法可以加速基于CPU平台上的GNN执行(例如,在
Reddit数据集上实现2-3倍的加速),适用于常见的GNN模型,如GCN、GraphSAGE、GIN等。然而,这种方法并不适用于所有可用的GNN操作符,特别是那些消息计算不易分解的操作符,例如基于注意力的GNN。选择decomposed_layers的最佳值取决于具体的图数据集和可用的硬件资源。在大多数情况下,值为2是合适的。尽管峰值内存使用与特征分解的粒度直接相关,但执行速度的提升并不一定如此。(默认值:1)
- propagate(edge_index: Union[Tensor, SparseTensor], size: Optional[Tuple[int, int]] = None, **kwargs: Any) Tensor[source]
初始调用以开始传播消息。
- Parameters:
edge_index (torch.Tensor 或 SparseTensor) – 一个
torch.Tensor, 一个torch_sparse.SparseTensor或一个torch.sparse.Tensor,用于定义底层的 图连接/消息传递流程。edge_index保存了一个形状为[N, M]的通用(稀疏) 分配矩阵的索引。 如果edge_index是一个torch.Tensor,它的dtype应该是torch.long,并且它的形状需要定义为[2, num_messages],其中来自edge_index[0]的节点的消息 被发送到edge_index[1]中的节点 (在flow="source_to_target"的情况下)。 如果edge_index是一个torch_sparse.SparseTensor或 一个torch.sparse.Tensor,它的稀疏索引(row, col)应该与row = edge_index[1]和col = edge_index[0]相关。 两种格式之间的主要区别在于我们需要将 转置 的稀疏邻接矩阵输入到propagate()中。size ((int, int), optional) – 如果
edge_index是一个torch.Tensor,则分配矩阵的大小为(N, M)。 如果设置为None,大小将自动推断并假定为二次方。 如果edge_index是torch_sparse.SparseTensor或torch.sparse.Tensor,则忽略此参数。(默认值:None)**kwargs – 任何用于构建和聚合消息以及更新节点嵌入的额外数据。
- Return type:
- message(x_j: Tensor) Tensor[source]
从节点 \(j\) 向节点 \(i\) 构造消息,类似于 \(\phi_{\mathbf{\Theta}}\) 对于
edge_index中的每条边。 此函数可以接受最初传递给propagate()的任何参数作为输入。 此外,传递给propagate()的张量可以通过在变量名后附加_i或_j来映射到相应的节点 \(i\) 和 \(j\),例如x_i和x_j。- Return type:
- aggregate(inputs: Tensor, index: Tensor, ptr: Optional[Tensor] = None, dim_size: Optional[int] = None) Tensor[source]
从邻居聚合消息为 \(\bigoplus_{j \in \mathcal{N}(i)}\)。
将消息计算的输出作为第一个参数,并将最初传递给
propagate()的任何参数作为其他参数。默认情况下,此函数会将其调用委托给底层的
Aggregation模块,以按照__init__()中由aggr参数指定的方式减少消息。- Return type:
- abstract message_and_aggregate(edge_index: Union[Tensor, SparseTensor]) Tensor[source]
将
message()和aggregate()的计算融合到一个函数中。 如果适用,这将节省时间和内存,因为消息不需要显式地具体化。 只有在实现此函数并且传播基于torch_sparse.SparseTensor或torch.sparse.Tensor时,才会调用此函数。- Return type:
- update(inputs: Tensor) Tensor[source]
更新节点嵌入,类似于每个节点 \(i \in \mathcal{V}\) 的 \(\gamma_{\mathbf{\Theta}}\)。 将聚合的输出作为第一个参数,并将最初传递给
propagate()的任何参数作为其他参数。- Return type:
- edge_updater(edge_index: Union[Tensor, SparseTensor], size: Optional[Tuple[int, int]] = None, **kwargs: Any) Tensor[source]
对图中每条边进行特征计算或更新的初始调用。
- Parameters:
edge_index (torch.Tensor 或 SparseTensor) – 一个
torch.Tensor,一个torch_sparse.SparseTensor或一个torch.sparse.Tensor,用于定义底层图的连接性/消息传递流程。 更多信息请参见propagate()。size ((int, int), optional) – 如果
edge_index是一个torch.Tensor,则指定分配矩阵的大小(N, M)。 如果设置为None,大小将自动推断并假定为二次方。 如果edge_index是一个torch_sparse.SparseTensor或 一个torch.sparse.Tensor,则忽略此参数。(默认值:None)**kwargs – 任何用于计算或更新图中每条边特征所需的额外数据。
- Return type:
- abstract edge_update() Tensor[source]
计算或更新图中每条边的特征。 此函数可以接受最初传递给
edge_updater()的任何参数作为输入。 此外,传递给edge_updater()的张量可以通过在变量名后附加_i或_j映射到相应的节点\(i\)和\(j\),例如x_i和x_j。- Return type:
- register_propagate_forward_pre_hook(hook: Callable) RemovableHandle[source]
在模块上注册一个前向预钩子。
每次在调用
propagate()之前,钩子都会被调用。 它应具有以下签名:hook(module, inputs) -> None or modified input
钩子可以修改输入。 输入关键字参数作为字典传递给钩子,位于
inputs[-1]。返回一个
torch.utils.hooks.RemovableHandle,可以通过调用handle.remove()来移除添加的钩子。- Return type:
RemovableHandle
- register_propagate_forward_hook(hook: Callable) RemovableHandle[source]
在模块上注册一个前向钩子。
每次在
propagate()计算输出后,钩子将被调用。 它应具有以下签名:hook(module, inputs, output) -> None or modified output
钩子可以修改输出。 输入的关键字参数作为字典传递给钩子,位于
inputs[-1]。返回一个
torch.utils.hooks.RemovableHandle,可以通过调用handle.remove()来移除添加的钩子。- Return type:
RemovableHandle
- register_message_forward_pre_hook(hook: Callable) RemovableHandle[source]
在模块上注册一个前向预钩子。 每次在调用
message()之前,都会调用这个钩子。 更多信息请参见register_propagate_forward_pre_hook()。- Return type:
RemovableHandle
- register_message_forward_hook(hook: Callable) RemovableHandle[source]
在模块上注册一个前向钩子。 每次在
message()计算输出后,钩子将被调用。 有关更多信息,请参见register_propagate_forward_hook()。- Return type:
RemovableHandle
- register_aggregate_forward_pre_hook(hook: Callable) RemovableHandle[source]
在模块上注册一个前向预钩子。 每次在调用
aggregate()之前,都会调用这个钩子。 更多信息请参见register_propagate_forward_pre_hook()。- Return type:
RemovableHandle
- register_aggregate_forward_hook(hook: Callable) RemovableHandle[source]
在模块上注册一个前向钩子。 每次在
aggregate()计算输出后,钩子将被调用。 有关更多信息,请参见register_propagate_forward_hook()。- Return type:
RemovableHandle
- register_message_and_aggregate_forward_pre_hook(hook: Callable) RemovableHandle[source]
在模块上注册一个前向预钩子。 每次在调用
message_and_aggregate()之前,都会调用这个钩子。 更多信息请参见register_propagate_forward_pre_hook()。- Return type:
RemovableHandle
- register_message_and_aggregate_forward_hook(hook: Callable) RemovableHandle[source]
在模块上注册一个前向钩子。 每次在
message_and_aggregate()计算输出后,钩子将被调用。 有关更多信息,请参见register_propagate_forward_hook()。- Return type:
RemovableHandle
- register_edge_update_forward_pre_hook(hook: Callable) RemovableHandle[source]
在模块上注册一个前向预钩子。 每次在调用
edge_update()之前,都会调用这个钩子。有关更多信息,请参见register_propagate_forward_pre_hook()。- Return type:
RemovableHandle
- register_edge_update_forward_hook(hook: Callable) RemovableHandle[source]
在模块上注册一个前向钩子。 每次在
edge_update()计算输出后,钩子将被调用。 有关更多信息,请参见register_propagate_forward_hook()。- Return type:
RemovableHandle
- jittable(typing: Optional[str] = None) MessagePassing[source]
分析
MessagePassing实例并生成一个新的可JIT编译的模块,该模块可以与torch.jit.script()结合使用。注意
jittable()已弃用,并且从 PyG 2.5 开始不再有效。- Return type: