torch_geometric.nn.pool.ClusterPooling
- class ClusterPooling(in_channels: int, edge_score_method: str = 'tanh', dropout: float = 0.0, threshold: Optional[float] = None)[source]
Bases:
Module来自“基于边的图组件池化”论文的集群池化操作符。
ClusterPooling为每条边计算一个分数。 根据选定的边,计算图簇并使用内射的"sum"聚合函数将其压缩为一个节点。 边根据每个簇创建的节点和原始边重新映射。- Parameters:
in_channels (int) – Size of each input sample.
edge_score_method (str, optional) – 用于从原始边分数计算边分数的函数(
"tanh","sigmoid","log_softmax")。(默认值:"tanh")dropout (float, optional) – The probability with which to drop edge scores during training. (default:
0.0)threshold (float, optional) – 边缘分数的阈值。如果设置为
None,将根据edge_score_method自动推断。(默认值:None)
- forward(x: Tensor, edge_index: Tensor, batch: Tensor) Tuple[Tensor, Tensor, Tensor, UnpoolInfo][source]
前向传播。
- Parameters:
x (torch.Tensor) – The node features.
edge_index (torch.Tensor) – The edge indices.
batch (torch.Tensor) – 批次向量 \(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N\), 它将 每个节点分配给一个特定的示例。
- Return types:
x (torch.Tensor) - The pooled node features.
edge_index (torch.Tensor) - The coarsened edge indices.
batch (torch.Tensor) - The coarsened batch vector.
unpool_info (UnpoolInfo) - 可以用于反池化的信息。