torch_geometric.nn.pool.global_mean_pool

global_mean_pool(x: Tensor, batch: Optional[Tensor], size: Optional[int] = None) Tensor[source]

通过平均节点维度上的节点特征,返回批处理级别的图级输出。

对于单个图 \(\mathcal{G}_i\),其输出通过以下方式计算:

\[\mathbf{r}_i = \frac{1}{N_i} \sum_{n=1}^{N_i} \mathbf{x}_n.\]

MeanAggregation 模块的功能方法。

Parameters:
  • x (torch.Tensor) – 节点特征矩阵 \(\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times F}\).

  • batch (torch.Tensor, optional) – 批次向量 \(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N\), 它将 每个节点分配给一个特定的示例。

  • size (int, optional) – 示例的数量 \(B\)。 如果未给出,则自动计算。(默认值:None

Return type:

Tensor