torch_geometric.nn.pool.avg_pool

avg_pool(cluster: Tensor, data: Data, transform: Optional[Callable] = None) Data[source]

根据cluster中定义的聚类,对由torch_geometric.data.Data对象给出的图进行池化和粗化。 最终节点特征由同一聚类内所有节点的平均特征定义。 更多详情请参见torch_geometric.nn.pool.max_pool()

Parameters:
  • cluster (torch.Tensor) – The cluster vector \(\mathbf{c} \in \{ 0, \ldots, N - 1 \}^N\), which assigns each node to a specific cluster.

  • data (Data) – 图数据对象。

  • transform (callable, optional) – 一个函数/转换,它接收粗化和池化的 torch_geometric.data.Data 对象并返回转换后的版本。(默认值:None

Return type:

torch_geometric.data.Data