消息传递表示

class MessagePassingRepresentation(triples_factory: CoreTriplesFactory, layers: str | None | type[None] | Sequence[str | None | type[None]], layers_kwargs: Mapping[str, Any] | None | Sequence[Mapping[str, Any] | None] = None, base: str | Representation | type[Representation] | None = None, base_kwargs: Mapping[str, Any] | None = None, max_id: int | None = None, shape: int | Sequence[int] | None = None, activations: str | Module | type[Module] | None | Sequence[str | Module | type[Module] | None] = None, activations_kwargs: Mapping[str, Any] | None | Sequence[Mapping[str, Any] | None] = None, restrict_k_hop: bool = False, **kwargs)[源代码]

基础类: Representation, ABC

一个利用PyTorch Geometric消息传递层的抽象表示类。

It comprises:
  • 基础(实体)表示,也可以作为提示传递

  • 一系列消息传递层。它们在一个抽象的MessagePassingRepresentation._message_passing()中被使用,以通过邻居信息丰富基础表示。

  • 在消息传递层之间的一系列激活层。

  • 一个edge_index缓冲区,用于存储边索引,并与模块一起移动到设备上。

初始化表示。

Parameters:
  • triples_factory (CoreTriplesFactory) – 包含用于消息传递的训练三元组的工厂

  • layers (Sequence[None]) – 消息传递层或其提示

  • layers_kwargs (Mapping[str, Any] | None | Sequence[Mapping[str, Any] | None]) – 在实例化时传递给层的额外基于关键字的参数

  • base (str | Representation | type[Representation] | None) – 实体的基础表示,或其提示

  • base_kwargs (Mapping[str, Any] | None) – 在实例化时传递给基础表示的额外基于关键字的参数

  • shape (tuple[int, ...]) – 输出的形状。默认为基础表示形状。必须与最后的消息传递层的输出形状匹配。

  • max_id (int) – 表示的数量。如果提供,必须与基础表示的max_id匹配

  • activations (str | Module | type[Module] | None | Sequence[str | Module | type[Module] | None]) – 激活函数或其提示

  • activations_kwargs (Mapping[str, Any] | None | Sequence[Mapping[str, Any] | None]) – 在实例化时传递给激活函数的额外基于关键字的参数

  • restrict_k_hop (bool) – 是否仅在请求某些索引时将消息传递限制在k跳邻域内。这利用了torch_geometric.utils.k_hop_subgraph()

  • kwargs – 传递给 Representation.__init__() 的额外基于关键字的参数

Raises:
  • ImportError – 如果未安装PyTorch Geometric

  • ValueError – 如果激活层和消息传递层的数量不匹配(在输入归一化之后)

方法总结

pass_messages(x, edge_index[, edge_mask])

执行消息传递步骤。

方法文档

abstract pass_messages(x: Tensor, edge_index: Tensor, edge_mask: Tensor | None = None) Tensor[source]

执行消息传递步骤。

Parameters:
  • x (Tensor) – 形状: (n, d_in) 基础实体表示

  • edge_index (Tensor) – 形状: (num_selected_edges,) 边缘索引(可能已经是完整边缘索引的一个选择)

  • edge_mask (Tensor | None) – 形状: (num_edges,) 如果消息传递受到限制,则为边掩码

Returns:

形状: (n, d_out) 丰富的实体表示

Return type:

Tensor