特征化消息传递表示

class FeaturizedMessagePassingRepresentation(triples_factory: CoreTriplesFactory, relation_representation: str | Representation | type[Representation] | None = None, relation_representation_kwargs: Mapping[str, Any] | None = None, relation_transformation: Module | None = None, **kwargs)[source]

基础类:TypedMessagePassingRepresentation

一种使用从关系表示中获得的边缘特征进行消息传递的表示。

它(重新)使用关系表示层来获取边缘特征,然后通过适当的消息传递层(例如,torch_geometric.nn.conv.GMMConv,或 torch_geometric.nn.conv.GATConv)来利用这些特征。我们进一步允许在层之间对边缘特征进行(共享)转换。

以下示例在基础表示之上创建了一个两层的GAT:

from pykeen.datasets import get_dataset

embedding_dim = 64
dataset = get_dataset(dataset="nations")
r = FeaturizedMessagePassingRepresentation(
    triples_factory=dataset.training,
    base_kwargs=dict(shape=embedding_dim),
    relation_representation_kwargs=dict(
        shape=embedding_dim,
    ),
    layers="gat",
    layers_kwargs=dict(
        in_channels=embedding_dim,
        out_channels=embedding_dim,
        edge_dim=embedding_dim,  # should match relation dim
    ),
)

初始化表示。

Parameters:
  • triples_factory (CoreTriplesFactory) – 包含用于消息传递的训练三元组的工厂

  • relation_representation (Representation) – 关系的基础表示,或其提示

  • relation_representation_kwargs (Mapping[str, Any] | None) – 在实例化时传递给基础表示的额外基于关键字的参数

  • relation_transformation (Module | None) – 一个可选的转换,用于在每次消息传递步骤后应用于关系表示。 如果为None,则不修改表示。

  • kwargs – 传递给 TypedMessagePassingRepresentation.__init__()的额外基于关键字的参数,除了triples_factory

方法总结

pass_messages(x, edge_index[, edge_mask])

执行消息传递步骤。

方法文档

pass_messages(x: Tensor, edge_index: Tensor, edge_mask: Tensor | None = None) Tensor[源代码]

执行消息传递步骤。

Parameters:
  • x (Tensor) – 形状: (n, d_in) 基础实体表示

  • edge_index (Tensor) – 形状: (num_selected_edges,) 边缘索引(可能已经是完整边缘索引的一个选择)

  • edge_mask (Tensor | None) – 形状: (num_edges,) 如果消息传递受到限制,则为边掩码

Returns:

形状: (n, d_out) 丰富的实体表示

Return type:

Tensor