torch_geometric.nn.conv.HeteroConv
- class HeteroConv(convs: Dict[Tuple[str, str, str], MessagePassing], aggr: Optional[str] = 'sum')[source]
Bases:
Module用于在异质图上计算图卷积的通用包装器。 该层将根据为特定边类型提供的二分图神经网络层,从源节点向目标节点传递消息。 如果多个关系指向同一个目标,它们的结果将根据
aggr进行聚合。 与torch_geometric.nn.to_hetero()相比,如果您想为不同的边类型应用不同的消息传递模块,该层特别有用。hetero_conv = HeteroConv({ ('paper', 'cites', 'paper'): GCNConv(-1, 64), ('author', 'writes', 'paper'): SAGEConv((-1, -1), 64), ('paper', 'written_by', 'author'): GATConv((-1, -1), 64), }, aggr='sum') out_dict = hetero_conv(x_dict, edge_index_dict) print(list(out_dict.keys())) >>> ['paper', 'author']
- Parameters:
convs (Dict[Tuple[str, str, str], MessagePassing]) – 一个字典 包含一个二分图
MessagePassing层,用于每个 单独的边类型。aggr (str, 可选) – 用于分组节点嵌入的聚合方案,这些嵌入由不同关系生成 (
"sum","mean","min","max","cat",None). (默认:"sum")
- forward(*args_dict, **kwargs_dict) Dict[str, Tensor][source]
运行模块的前向传播。
- Parameters:
x_dict (Dict[str, torch.Tensor]) – A dictionary holding node feature information for each individual node type.
edge_index_dict (Dict[Tuple[str, str, str], torch.Tensor]) – A dictionary holding graph connectivity information for each individual edge type, either as a
torch.Tensorof shape[2, num_edges]or atorch_sparse.SparseTensor.*args_dict (可选) – 各个
torch_geometric.nn.conv.MessagePassing层的额外前向参数。**kwargs_dict (可选) – 各个
torch_geometric.nn.conv.MessagePassing层的额外前向参数。 例如,如果特定GNN层在边类型edge_type处期望边属性edge_attr作为前向参数,那么你可以通过edge_attr_dict = { edge_type: edge_attr }将它们传递给forward()。
- Return type: