类型化消息传递表示
- class TypedMessagePassingRepresentation(triples_factory: CoreTriplesFactory, **kwargs)[源代码]
基础类:
MessagePassingRepresentation一种使用分类关系类型信息的消息传递表示。
该模块的消息传递层通过edge_type输入在内部处理分类关系类型信息,例如,
torch_geometric.nn.conv.RGCNConv,或torch_geometric.nn.conv.RGATConv。以下示例使用基础分解创建了一个单层RGCN:
from pykeen.datasets import get_dataset embedding_dim = 64 dataset = get_dataset(dataset="nations") r = TypedMessagePassingRepresentation( triples_factory=dataset.training, base_kwargs=dict(shape=embedding_dim), layers="rgcn", layers_kwargs=dict( in_channels=embedding_dim, out_channels=embedding_dim, num_bases=2, num_relations=dataset.num_relations, ), )
初始化表示。
- Parameters:
triples_factory (CoreTriplesFactory) – 包含用于消息传递的训练三元组的工厂
kwargs – 传递给
MessagePassingRepresentation.__init__()的额外基于关键字的参数
方法总结
pass_messages(x, edge_index[, edge_mask])执行消息传递步骤。
方法文档