torch_geometric.nn.models.HeteroJumpingKnowledge

class HeteroJumpingKnowledge(types: List[str], mode: str, channels: Optional[int] = None, num_layers: Optional[int] = None)

Bases: Module

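A heterogeneous version of the JumpingKnowledge module: it aggregates the layer-wise representations of each type separately.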

Parameters:
  • types (List[str]) – The keys of the input dictionary.

  • mode (str) – The aggregation scheme to use ("cat", "max" or "lstm").

  • channels (int, optional) – The number of channels per representation. Only needs to be set for LSTM-style aggregation. (default: None)

  • num_layers (int, optional) – The number of layers to aggregate. Only needs to be set for LSTM-style aggregation. (default: None)
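To make the role of channels and num_layers concrete, the sketch below computes the feature width each mode produces for one type, assuming every layer outputs channels features per node. The helper jk_out_channels is hypothetical and not part of PyG. "cat" and "max" have no learnable parameters, whereas the LSTM variant builds learnable layers whose sizes depend on channels and num_layers, which is why those two arguments only need to be set for it.

```python
from typing import Optional


def jk_out_channels(mode: str, channels: int, num_layers: Optional[int] = None) -> int:
    """Hypothetical helper (not a PyG API): output width for one type.

    Assumes every aggregated layer emits `channels` features per node.
    """
    mode = mode.lower()
    if mode == "cat":
        # Concatenation places the per-layer features side by side.
        if num_layers is None:
            raise ValueError("num_layers is needed to compute the 'cat' width")
        return channels * num_layers
    if mode in ("max", "lstm"):
        # Element-wise max and the attention-weighted sum over layers
        # (LSTM variant) both keep the per-layer width.
        return channels
    raise ValueError(f"unknown mode: {mode!r}")


print(jk_out_channels("cat", 64, 3))  # 192
print(jk_out_channels("max", 64, 3))  # 64
```

In practice this means a downstream linear layer for a given type should expect num_layers * channels inputs under "cat", but only channels under "max" or "lstm".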

forward(xs_dict: Dict[str, List[Tensor]]) → Dict[str, Tensor]

Computes the forward pass.

Parameters:

xs_dict (Dict[str, List[torch.Tensor]]) – A dictionary holding, for each type, the list of layer-wise representations.

Return type:

Dict[str, Tensor]
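The per-type behavior of forward can be illustrated with plain Python lists standing in for tensors. This is a minimal sketch, assuming each layer's representation is a list of per-node feature vectors; the function hetero_jk and the type names "paper" and "author" are hypothetical. It covers only "cat" and "max", since the LSTM variant needs learnable weights, and it is not the library's implementation, which operates on torch.Tensor objects.

```python
from typing import Dict, List

Vector = List[float]
Layer = List[Vector]  # one feature vector per node


def hetero_jk(xs_dict: Dict[str, List[Layer]], mode: str) -> Dict[str, Layer]:
    """Hypothetical pure-Python stand-in: aggregate each type independently."""
    out: Dict[str, Layer] = {}
    for node_type, layers in xs_dict.items():
        num_nodes = len(layers[0])
        if mode == "cat":
            # Node i gets its layer-1 features, then layer-2 features, and so on.
            out[node_type] = [
                [v for layer in layers for v in layer[i]] for i in range(num_nodes)
            ]
        elif mode == "max":
            # Element-wise maximum over layers, per node and per channel.
            out[node_type] = [
                [max(vals) for vals in zip(*(layer[i] for layer in layers))]
                for i in range(num_nodes)
            ]
        else:
            raise ValueError(f"mode {mode!r} is not covered by this sketch")
    return out


# Two layers per type; "paper" has 2 nodes and "author" has 1 node, 2 channels each.
xs_dict = {
    "paper": [[[1.0, 5.0], [2.0, 0.0]], [[3.0, 4.0], [1.0, 2.0]]],
    "author": [[[0.5, 0.5]], [[1.5, -1.0]]],
}
print(hetero_jk(xs_dict, "cat")["paper"])   # [[1.0, 5.0, 3.0, 4.0], [2.0, 0.0, 1.0, 2.0]]
print(hetero_jk(xs_dict, "max")["author"])  # [[1.5, 0.5]]
```

Because each type is processed independently, types may have different numbers of nodes, as "paper" and "author" do here; only the layers within a single type must agree in shape.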

reset_parameters() → None

Resets all learnable parameters of the module.

Return type:

None