torch_geometric.nn.conv.HGTConv
- class HGTConv(in_channels: Union[int, Dict[str, int]], out_channels: int, metadata: Tuple[List[str], List[Tuple[str, str, str]]], heads: int = 1, **kwargs)
Bases: MessagePassing
The Heterogeneous Graph Transformer (HGT) operator from the "Heterogeneous Graph Transformer" paper.
Note
For an example of using HGT, see examples/hetero/hgt_dblp.py.
- Parameters:
in_channels (int or Dict[str, int]) – Size of each input sample of every node type, or -1 to derive the size from the first input(s) to the forward method.
out_channels (int) – Size of each output sample.
metadata (Tuple[List[str], List[Tuple[str, str, str]]]) – The metadata of the heterogeneous graph, i.e. its node and edge types given by a list of strings and a list of string triplets, respectively. See torch_geometric.data.HeteroData.metadata() for more information.
heads (int, optional) – Number of attention heads. (default: 1)
**kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.
- forward(x_dict: Dict[str, Tensor], edge_index_dict: Dict[Tuple[str, str, str], Union[Tensor, SparseTensor]]) → Dict[str, Optional[Tensor]]
Runs the forward pass of the module.
- Parameters:
x_dict (Dict[str, torch.Tensor]) – A dictionary holding input node features for each individual node type.
edge_index_dict (Dict[Tuple[str, str, str], torch.Tensor]) – A dictionary holding graph connectivity information for each individual edge type, either as a torch.Tensor of shape [2, num_edges] or a torch_sparse.SparseTensor.
- Return type:
Dict[str, Optional[torch.Tensor]] – The output node embeddings for each node type. In case a node type does not receive any message, its output will be set to None.
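Which outputs can be None follows from the direction of message passing: messages flow from the source type (first element of an edge-type triplet) to the destination type (last element). The sketch below uses hypothetical helpers (`receiving_node_types`, `expected_outputs`, not part of PyG) and assumes a node type receives messages exactly when it is the destination of at least one edge type in use (e.g. the keys of edge_index_dict). It does not account for edge types that happen to contain zero edges.

```python
from typing import Dict, List, Set, Tuple

EdgeType = Tuple[str, str, str]


def receiving_node_types(edge_types: List[EdgeType]) -> Set[str]:
    """Return the node types that appear as the destination of an edge type."""
    return {dst for _, _, dst in edge_types}


def expected_outputs(node_types: List[str],
                     edge_types: List[EdgeType]) -> Dict[str, bool]:
    """Map each node type to True if it can receive messages.

    False marks node types whose output is documented to be None.
    """
    receivers = receiving_node_types(edge_types)
    return {node_type: node_type in receivers for node_type in node_types}


if __name__ == '__main__':
    # Hypothetical graph: nothing points back to 'author'.
    node_types = ['author', 'paper', 'venue']
    edge_types = [('author', 'writes', 'paper'),
                  ('paper', 'published_in', 'venue')]
    print(expected_outputs(node_types, edge_types))
```

Depending on the installed PyG version, a node type that receives no messages may also be left out of the returned dictionary rather than mapped to None; reading results with `out_dict.get(node_type)` handles both cases. Adding reverse edge types (for example with `torch_geometric.transforms.ToUndirected`) is a common way to make every node type a destination.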