torch_geometric.nn.conv.GPSConv
- class GPSConv(channels: int, conv: Optional[MessagePassing], heads: int = 1, dropout: float = 0.0, act: str = 'relu', act_kwargs: Optional[Dict[str, Any]] = None, norm: Optional[str] = 'batch_norm', norm_kwargs: Optional[Dict[str, Any]] = None, attn_type: str = 'multihead', attn_kwargs: Optional[Dict[str, Any]] = None)[source]
Bases:
Module来自“通用、强大、可扩展的图变换器配方”论文的通用、强大、可扩展(GPS)图变换器层。
GPS 层基于三部分配方:
将位置编码(PE)和结构编码(SE)包含到输入特征中(通过
torch_geometric.transforms在预处理步骤中完成)。一个在输入图上操作的本地消息传递层(MPNN)。
一个在整个图上操作的全局注意力层。
注意
有关使用
GPSConv的示例,请参见 examples/graph_gps.py。- Parameters:
channels (int) – Size of each input sample.
conv (MessagePassing, optional) – 本地消息传递层。
heads (int, optional) – Number of multi-head-attentions. (default:
1)dropout (float, optional) – 中间嵌入的丢弃概率。(默认:
0.)act (str or Callable, optional) – The non-linear activation function to use. (default:
"relu")act_kwargs (Dict[str, Any], optional) – Arguments passed to the respective activation function defined by
act. (default:None)norm (str 或 Callable, 可选) – 使用的归一化函数。(默认值:
"batch_norm")norm_kwargs (Dict[str, Any], optional) – Arguments passed to the respective normalization function defined by
norm. (default:None)attn_type (str) – 全局注意力类型,
multihead或performer。(默认值:multihead)attn_kwargs (Dict[str, Any], optional) – 传递给注意力层的参数。(默认值:
None)