简单消息传递表示
- class SimpleMessagePassingRepresentation(triples_factory: CoreTriplesFactory, layers: str | None | type[None] | Sequence[str | None | type[None]], layers_kwargs: Mapping[str, Any] | None | Sequence[Mapping[str, Any] | None] = None, base: str | Representation | type[Representation] | None = None, base_kwargs: Mapping[str, Any] | None = None, max_id: int | None = None, shape: int | Sequence[int] | None = None, activations: str | Module | type[Module] | None | Sequence[str | Module | type[Module] | None] = None, activations_kwargs: Mapping[str, Any] | None | Sequence[Mapping[str, Any] | None] = None, restrict_k_hop: bool = False, **kwargs)[source]
基础类:
MessagePassingRepresentation不使用关系类型的消息传递表示。
仅使用连接信息,而不使用关系类型信息,该模块可以利用在单关系图上定义的消息传递层,这些层是PyTorch Geometric库中大多数可用层。
在这里,我们在
pykeen.nn.representation.Embedding之上创建了一个两层的torch_geometric.nn.conv.GCNConv:from pykeen.datasets import get_dataset embedding_dim = 64 dataset = get_dataset(dataset="nations") r = SimpleMessagePassingRepresentation( triples_factory=dataset.training, base_kwargs=dict(shape=embedding_dim), layers=["gcn"] * 2, layers_kwargs=dict(in_channels=embedding_dim, out_channels=embedding_dim), )
初始化表示。
- Parameters:
triples_factory (CoreTriplesFactory) – 包含用于消息传递的训练三元组的工厂
layers (Sequence[None]) – 消息传递层或其提示
layers_kwargs (Mapping[str, Any] | None | Sequence[Mapping[str, Any] | None]) – 在实例化时传递给层的额外基于关键字的参数
base (str | Representation | type[Representation] | None) – 实体的基础表示,或其提示
base_kwargs (Mapping[str, Any] | None) – 在实例化时传递给基础表示的额外基于关键字的参数
shape (tuple[int, ...]) – 输出的形状。默认为基础表示形状。必须与最后的消息传递层的输出形状匹配。
max_id (int) – 表示的数量。如果提供,必须与基础表示的max_id匹配
activations (str | Module | type[Module] | None | Sequence[str | Module | type[Module] | None]) – 激活函数或其提示
activations_kwargs (Mapping[str, Any] | None | Sequence[Mapping[str, Any] | None]) – 在实例化时传递给激活函数的额外基于关键字的参数
restrict_k_hop (bool) – 是否仅在请求某些索引时将消息传递限制在k跳邻域内。这利用了
torch_geometric.utils.k_hop_subgraph()。kwargs – 传递给
Representation.__init__()的额外基于关键字的参数
- Raises:
ImportError – 如果未安装PyTorch Geometric
ValueError – 如果激活层和消息传递层的数量不匹配(在输入归一化之后)
方法总结
pass_messages(x, edge_index[, edge_mask])执行消息传递步骤。
方法文档