类型化消息传递表示

class TypedMessagePassingRepresentation(triples_factory: CoreTriplesFactory, **kwargs)[源代码]

基础类:MessagePassingRepresentation

一种使用分类关系类型信息的消息传递表示。

该模块的消息传递层通过edge_type输入在内部处理分类关系类型信息,例如,torch_geometric.nn.conv.RGCNConv,或 torch_geometric.nn.conv.RGATConv

以下示例使用基础分解创建了一个单层RGCN:

from pykeen.datasets import get_dataset

embedding_dim = 64
dataset = get_dataset(dataset="nations")
r = TypedMessagePassingRepresentation(
    triples_factory=dataset.training,
    base_kwargs=dict(shape=embedding_dim),
    layers="rgcn",
    layers_kwargs=dict(
        in_channels=embedding_dim,
        out_channels=embedding_dim,
        num_bases=2,
        num_relations=dataset.num_relations,
    ),
)

初始化表示。

Parameters:
  • triples_factory (CoreTriplesFactory) – 包含用于消息传递的训练三元组的工厂

  • kwargs – 传递给 MessagePassingRepresentation.__init__() 的额外基于关键字的参数

方法总结

pass_messages(x, edge_index[, edge_mask])

执行消息传递步骤。

方法文档

pass_messages(x: Tensor, edge_index: Tensor, edge_mask: Tensor | None = None) Tensor[source]

执行消息传递步骤。

Parameters:
  • x (Tensor) – 形状: (n, d_in) 基础实体表示

  • edge_index (Tensor) – 形状: (num_selected_edges,) 边缘索引(可能已经是完整边缘索引的一个选择)

  • edge_mask (Tensor | None) – 形状: (num_edges,) 如果消息传递受到限制,则为边掩码

Returns:

形状: (n, d_out) 丰富的实体表示

Return type:

Tensor