torch_geometric.nn.conv.HeteroConv

class HeteroConv(convs: Dict[Tuple[str, str, str], MessagePassing], aggr: Optional[str] = 'sum')[source]

Bases: Module

用于在异质图上计算图卷积的通用包装器。 该层将根据为特定边类型提供的二分图神经网络层,从源节点向目标节点传递消息。 如果多个关系指向同一个目标,它们的结果将根据aggr进行聚合。 与torch_geometric.nn.to_hetero()相比,如果您想为不同的边类型应用不同的消息传递模块,该层特别有用。

hetero_conv = HeteroConv({
    ('paper', 'cites', 'paper'): GCNConv(-1, 64),
    ('author', 'writes', 'paper'): SAGEConv((-1, -1), 64),
    ('paper', 'written_by', 'author'): GATConv((-1, -1), 64),
}, aggr='sum')

out_dict = hetero_conv(x_dict, edge_index_dict)

print(list(out_dict.keys()))
>>> ['paper', 'author']
Parameters:
  • convs (Dict[Tuple[str, str, str], MessagePassing]) – 一个字典 包含一个二分图 MessagePassing 层,用于每个 单独的边类型。

  • aggr (str, 可选) – 用于分组节点嵌入的聚合方案,这些嵌入由不同关系生成 ("sum", "mean", "min", "max", "cat", None). (默认: "sum")

forward(*args_dict, **kwargs_dict) Dict[str, Tensor][source]

运行模块的前向传播。

Parameters:
  • x_dict (Dict[str, torch.Tensor]) – A dictionary holding node feature information for each individual node type.

  • edge_index_dict (Dict[Tuple[str, str, str], torch.Tensor]) – A dictionary holding graph connectivity information for each individual edge type, either as a torch.Tensor of shape [2, num_edges] or a torch_sparse.SparseTensor.

  • *args_dict (可选) – 各个torch_geometric.nn.conv.MessagePassing层的额外前向参数。

  • **kwargs_dict (可选) – 各个torch_geometric.nn.conv.MessagePassing层的额外前向参数。 例如,如果特定GNN层在边类型edge_type处期望边属性edge_attr作为前向参数,那么你可以通过edge_attr_dict = { edge_type: edge_attr }将它们传递给forward()

Return type:

Dict[str, Tensor]

reset_parameters()[source]

重置模块的所有可学习参数。