torch_geometric.nn.conv.HANConv

class HANConv(in_channels: Union[int, Dict[str, int]], out_channels: int, metadata: Tuple[List[str], List[Tuple[str, str, str]]], heads: int = 1, negative_slope: float = 0.2, dropout: float = 0.0, **kwargs)[source]

Bases: MessagePassing

The heterogeneous graph attention operator from the "Heterogeneous Graph Attention Network" paper.

Note

For an example of using HANConv, see examples/hetero/han_imdb.py.

Parameters:
  • in_channels (int or Dict[str, int]) – Size of each input sample for every node type, or -1 to derive the size from the first input(s) to the forward method.

  • out_channels (int) – Size of each output sample.

  • metadata (Tuple[List[str], List[Tuple[str, str, str]]]) – The metadata of the heterogeneous graph, i.e. its node and edge types given by a list of strings and a list of string triplets, respectively. See torch_geometric.data.HeteroData.metadata() for more information.

  • heads (int, optional) – Number of multi-head-attentions. (default: 1)

  • negative_slope (float, optional) – LeakyReLU angle of the negative slope. (default: 0.2)

  • dropout (float, optional) – Dropout probability of the normalized attention coefficients which exposes each node to a stochastically sampled neighborhood during training. (default: 0)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

forward(x_dict: Dict[str, Tensor], edge_index_dict: Dict[Tuple[str, str, str], Union[Tensor, SparseTensor]], return_semantic_attention_weights: bool = False) Union[Dict[str, Optional[Tensor]], Tuple[Dict[str, Optional[Tensor]], Dict[str, Optional[Tensor]]]][source]

Runs the forward pass of the module.

Parameters:
  • x_dict (Dict[str, torch.Tensor]) – A dictionary holding node feature information for each individual node type.

  • edge_index_dict (Dict[Tuple[str, str, str], torch.Tensor]) – A dictionary holding graph connectivity information for each individual edge type, either as a torch.Tensor of shape [2, num_edges] or a torch_sparse.SparseTensor.

  • return_semantic_attention_weights (bool, optional) – If set to True, will additionally return the semantic-level attention weights for each destination node type. (default: False)

Return type:

Union[Dict[str, Optional[Tensor]], Tuple[Dict[str, Optional[Tensor]], Dict[str, Optional[Tensor]]]]

reset_parameters()[source]

Resets all learnable parameters of the module.