消息传递表示
- class MessagePassingRepresentation(triples_factory: CoreTriplesFactory, layers: str | None | type[None] | Sequence[str | None | type[None]], layers_kwargs: Mapping[str, Any] | None | Sequence[Mapping[str, Any] | None] = None, base: str | Representation | type[Representation] | None = None, base_kwargs: Mapping[str, Any] | None = None, max_id: int | None = None, shape: int | Sequence[int] | None = None, activations: str | Module | type[Module] | None | Sequence[str | Module | type[Module] | None] = None, activations_kwargs: Mapping[str, Any] | None | Sequence[Mapping[str, Any] | None] = None, restrict_k_hop: bool = False, **kwargs)[源代码]
基础类:
Representation,ABC一个利用PyTorch Geometric消息传递层的抽象表示类。
- It comprises:
基础(实体)表示,也可以作为提示传递
一系列消息传递层。它们在一个抽象的
MessagePassingRepresentation._message_passing()中被使用,以通过邻居信息丰富基础表示。在消息传递层之间的一系列激活层。
一个edge_index缓冲区,用于存储边索引,并与模块一起移动到设备上。
初始化表示。
- Parameters:
triples_factory (CoreTriplesFactory) – 包含用于消息传递的训练三元组的工厂
layers (Sequence[None]) – 消息传递层或其提示
layers_kwargs (Mapping[str, Any] | None | Sequence[Mapping[str, Any] | None]) – 在实例化时传递给层的额外基于关键字的参数
base (str | Representation | type[Representation] | None) – 实体的基础表示,或其提示
base_kwargs (Mapping[str, Any] | None) – 在实例化时传递给基础表示的额外基于关键字的参数
shape (tuple[int, ...]) – 输出的形状。默认为基础表示形状。必须与最后的消息传递层的输出形状匹配。
max_id (int) – 表示的数量。如果提供,必须与基础表示的max_id匹配
activations (str | Module | type[Module] | None | Sequence[str | Module | type[Module] | None]) – 激活函数或其提示
activations_kwargs (Mapping[str, Any] | None | Sequence[Mapping[str, Any] | None]) – 在实例化时传递给激活函数的额外基于关键字的参数
restrict_k_hop (bool) – 是否仅在请求某些索引时将消息传递限制在k跳邻域内。这利用了
torch_geometric.utils.k_hop_subgraph()。kwargs – 传递给
Representation.__init__()的额外基于关键字的参数
- Raises:
ImportError – 如果未安装PyTorch Geometric
ValueError – 如果激活层和消息传递层的数量不匹配(在输入归一化之后)
方法总结
pass_messages(x, edge_index[, edge_mask])执行消息传递步骤。
方法文档