torch_geometric.nn.conv.HEATConv

class HEATConv(in_channels: int, out_channels: int, num_node_types: int, num_edge_types: int, edge_type_emb_dim: int, edge_dim: int, edge_attr_emb_dim: int, heads: int = 1, concat: bool = True, negative_slope: float = 0.2, dropout: float = 0.0, root_weight: bool = True, bias: bool = True, **kwargs)[source]

Bases: MessagePassing

来自“Heterogeneous Edge-Enhanced Graph Attention Network For Multi-Agent Trajectory Prediction”论文的异构边缘增强图注意力操作符。

HEATConv 通过以下方式增强了 GATConv

  1. 不同类型节点的特定类型转换

  2. 边类型和边特征结合,其中假设边具有不同的类型但包含相同类型的属性

Parameters:
  • in_channels (int) – Size of each input sample, or -1 to derive the size from the first input(s) to the forward method.

  • out_channels (int) – Size of each output sample.

  • num_node_types (int) – 节点类型的数量。

  • num_edge_types (int) – 边的类型数量。

  • edge_type_emb_dim (int) – 边类型的嵌入大小。

  • edge_dim (int) – Edge feature dimensionality.

  • edge_attr_emb_dim (int) – 边特征的嵌入大小。

  • heads (int, optional) – Number of multi-head-attentions. (default: 1)

  • concat (bool, 可选) – 如果设置为 False,多头注意力机制将被平均而不是连接。 (默认: True)

  • negative_slope (float, optional) – LeakyReLU angle of the negative slope. (default: 0.2)

  • dropout (float, optional) – Dropout probability of the normalized attention coefficients which exposes each node to a stochastically sampled neighborhood during training. (default: 0)

  • root_weight (bool, 可选) – 如果设置为 False,该层将 不会将转换后的根节点特征添加到输出中。 (默认: True)

  • bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

Shapes:
  • 输入: 节点特征 \((|\mathcal{V}|, F_{in})\), 边索引 \((2, |\mathcal{E}|)\), 节点类型 \((|\mathcal{V}|)\), 边类型 \((|\mathcal{E}|)\), 边特征 \((|\mathcal{E}|, D)\) (可选)

  • output: node features \((|\mathcal{V}|, F_{out})\)

forward(x: Tensor, edge_index: Union[Tensor, SparseTensor], node_type: Tensor, edge_type: Tensor, edge_attr: Optional[Tensor] = None) Tensor[source]

运行模块的前向传播。

Return type:

Tensor

reset_parameters()[source]

重置模块的所有可学习参数。