PyG 消息传递
基于PyTorch Geometric的表示模块。
这些模块使得实体表示能够与其图邻居的表示相连接。类似的表示方法包括CompGCN或R-GCN。然而,该模块提供了通用模块,可以将PyTorch Geometric中的许多消息传递层与基础表示结合起来。可用的消息传递层的摘要可以在torch_geometric.nn.conv中找到。
这三个类在使用关系类型信息的方式上有所不同:
SimpleMessagePassingRepresentation仅使用来自训练三元组的连接信息, 但忽略了关系类型,例如torch_geometric.nn.conv.GCNConv。TypedMessagePassingRepresentation用于消息传递层,它内部通过 edge_type 输入处理分类关系类型信息,例如torch_geometric.nn.conv.RGCNConv。FeaturizedMessagePassingRepresentation用于可以通过参数 edge_attr 使用边属性的消息传递层,例如torch_geometric.nn.conv.GMMConv。
我们也可以轻松地利用这些表示与pykeen.models.ERModel。在这里,我们展示了如何将基于静态标签的实体特征与可训练的GCN编码器结合用于实体表示,以及用于关系表示的学习嵌入和DistMult交互函数。
from pykeen.datasets import get_dataset
from pykeen.models import ERModel
from pykeen.nn.init import LabelBasedInitializer
from pykeen.pipeline import pipeline
dataset = get_dataset(dataset="nations", dataset_kwargs=dict(create_inverse_triples=True))
entity_initializer = LabelBasedInitializer.from_triples_factory(
triples_factory=dataset.training,
for_entities=True,
)
(embedding_dim,) = entity_initializer.tensor.shape[1:]
r = pipeline(
dataset=dataset,
model=ERModel,
model_kwargs=dict(
interaction="distmult",
entity_representations="SimpleMessagePassing",
entity_representations_kwargs=dict(
triples_factory=dataset.training,
base_kwargs=dict(
shape=embedding_dim,
initializer=entity_initializer,
trainable=False,
),
layers=["GCN"] * 2,
layers_kwargs=dict(in_channels=embedding_dim, out_channels=embedding_dim),
),
relation_representations_kwargs=dict(
shape=embedding_dim,
),
),
)
类
|
一个利用PyTorch Geometric消息传递层的抽象表示类。 |
|
不使用关系类型的消息传递表示。 |
一种使用从关系表示中获得的边缘特征进行消息传递的表示。 |
|
一种使用分类关系类型信息的消息传递表示。 |
类继承图
