torch_geometric.nn.pool.max_pool
- max_pool(cluster: Tensor, data: Data, transform: Optional[Callable] = None) Data[source]
根据
cluster中定义的聚类,对由torch_geometric.data.Data对象给出的图进行池化和粗化。 同一聚类中的所有节点将被表示为一个节点。 最终节点特征由同一聚类中所有节点的最大特征定义,节点位置取平均值,边索引定义为同一聚类中所有节点的边索引的并集。- Parameters:
cluster (torch.Tensor) – The cluster vector \(\mathbf{c} \in \{ 0, \ldots, N - 1 \}^N\), which assigns each node to a specific cluster.
data (Data) – Graph data object.
transform (callable, optional) – A function/transform that takes in the coarsened and pooled
torch_geometric.data.Dataobject and returns a transformed version. (default:None)
- Return type: