torch_geometric.nn.conv.GPSConv

class GPSConv(channels: int, conv: Optional[MessagePassing], heads: int = 1, dropout: float = 0.0, act: str = 'relu', act_kwargs: Optional[Dict[str, Any]] = None, norm: Optional[str] = 'batch_norm', norm_kwargs: Optional[Dict[str, Any]] = None, attn_type: str = 'multihead', attn_kwargs: Optional[Dict[str, Any]] = None)[source]

Bases: Module

来自“通用、强大、可扩展的图变换器配方”论文的通用、强大、可扩展(GPS)图变换器层。

GPS 层基于三部分配方:

  1. 将位置编码(PE)和结构编码(SE)包含到输入特征中(通过torch_geometric.transforms在预处理步骤中完成)。

  2. 一个在输入图上操作的本地消息传递层(MPNN)。

  3. 一个在整个图上操作的全局注意力层。

注意

有关使用 GPSConv 的示例,请参见 examples/graph_gps.py

Parameters:
  • channels (int) – Size of each input sample.

  • conv (MessagePassing, optional) – 本地消息传递层。

  • heads (int, optional) – Number of multi-head-attentions. (default: 1)

  • dropout (float, optional) – 中间嵌入的丢弃概率。(默认: 0.)

  • act (str or Callable, optional) – The non-linear activation function to use. (default: "relu")

  • act_kwargs (Dict[str, Any], optional) – Arguments passed to the respective activation function defined by act. (default: None)

  • norm (strCallable, 可选) – 使用的归一化函数。(默认值:"batch_norm"

  • norm_kwargs (Dict[str, Any], optional) – Arguments passed to the respective normalization function defined by norm. (default: None)

  • attn_type (str) – 全局注意力类型,multiheadperformer。(默认值:multihead

  • attn_kwargs (Dict[str, Any], optional) – 传递给注意力层的参数。(默认值:None

forward(x: Tensor, edge_index: Union[Tensor, SparseTensor], batch: Optional[Tensor] = None, **kwargs) Tensor[source]

运行模块的前向传播。

Return type:

Tensor

reset_parameters()[source]

重置模块的所有可学习参数。