torch_geometric.nn.pool.avg_pool
- avg_pool(cluster: Tensor, data: Data, transform: Optional[Callable] = None) Data[source]
根据
cluster中定义的聚类,对由torch_geometric.data.Data对象给出的图进行池化和粗化。 最终节点特征由同一聚类内所有节点的平均特征定义。 更多详情请参见torch_geometric.nn.pool.max_pool()。- Parameters:
cluster (torch.Tensor) – The cluster vector \(\mathbf{c} \in \{ 0, \ldots, N - 1 \}^N\), which assigns each node to a specific cluster.
data (Data) – 图数据对象。
transform (callable, optional) – 一个函数/转换,它接收粗化和池化的
torch_geometric.data.Data对象并返回转换后的版本。(默认值:None)
- Return type: