简单消息传递表示

class SimpleMessagePassingRepresentation(triples_factory: CoreTriplesFactory, layers: str | None | type[None] | Sequence[str | None | type[None]], layers_kwargs: Mapping[str, Any] | None | Sequence[Mapping[str, Any] | None] = None, base: str | Representation | type[Representation] | None = None, base_kwargs: Mapping[str, Any] | None = None, max_id: int | None = None, shape: int | Sequence[int] | None = None, activations: str | Module | type[Module] | None | Sequence[str | Module | type[Module] | None] = None, activations_kwargs: Mapping[str, Any] | None | Sequence[Mapping[str, Any] | None] = None, restrict_k_hop: bool = False, **kwargs)[source]

基础类:MessagePassingRepresentation

不使用关系类型的消息传递表示。

仅使用连接信息,而不使用关系类型信息,该模块可以利用在单关系图上定义的消息传递层,这些层是PyTorch Geometric库中大多数可用层。

在这里,我们在pykeen.nn.representation.Embedding之上创建了一个两层的torch_geometric.nn.conv.GCNConv

from pykeen.datasets import get_dataset

embedding_dim = 64
dataset = get_dataset(dataset="nations")
r = SimpleMessagePassingRepresentation(
    triples_factory=dataset.training,
    base_kwargs=dict(shape=embedding_dim),
    layers=["gcn"] * 2,
    layers_kwargs=dict(in_channels=embedding_dim, out_channels=embedding_dim),
)

初始化表示。

Parameters:
  • triples_factory (CoreTriplesFactory) – 包含用于消息传递的训练三元组的工厂

  • layers (Sequence[None]) – 消息传递层或其提示

  • layers_kwargs (Mapping[str, Any] | None | Sequence[Mapping[str, Any] | None]) – 在实例化时传递给层的额外基于关键字的参数

  • base (str | Representation | type[Representation] | None) – 实体的基础表示,或其提示

  • base_kwargs (Mapping[str, Any] | None) – 在实例化时传递给基础表示的额外基于关键字的参数

  • shape (tuple[int, ...]) – 输出的形状。默认为基础表示形状。必须与最后的消息传递层的输出形状匹配。

  • max_id (int) – 表示的数量。如果提供,必须与基础表示的max_id匹配

  • activations (str | Module | type[Module] | None | Sequence[str | Module | type[Module] | None]) – 激活函数或其提示

  • activations_kwargs (Mapping[str, Any] | None | Sequence[Mapping[str, Any] | None]) – 在实例化时传递给激活函数的额外基于关键字的参数

  • restrict_k_hop (bool) – 是否仅在请求某些索引时将消息传递限制在k跳邻域内。这利用了torch_geometric.utils.k_hop_subgraph()

  • kwargs – 传递给 Representation.__init__() 的额外基于关键字的参数

Raises:
  • ImportError – 如果未安装PyTorch Geometric

  • ValueError – 如果激活层和消息传递层的数量不匹配(在输入归一化之后)

方法总结

pass_messages(x, edge_index[, edge_mask])

执行消息传递步骤。

方法文档

pass_messages(x: Tensor, edge_index: Tensor, edge_mask: Tensor | None = None) Tensor[source]

执行消息传递步骤。

Parameters:
  • x (Tensor) – 形状: (n, d_in) 基础实体表示

  • edge_index (Tensor) – 形状: (num_selected_edges,) 边缘索引(可能已经是完整边缘索引的一个选择)

  • edge_mask (Tensor | None) – 形状: (num_edges,) 如果消息传递受到限制,则为边掩码

Returns:

形状: (n, d_out) 丰富的实体表示

Return type:

Tensor