torch_geometric.nn.pool.ClusterPooling

class ClusterPooling(in_channels: int, edge_score_method: str = 'tanh', dropout: float = 0.0, threshold: Optional[float] = None)[source]

Bases: Module

来自“基于边的图组件池化”论文的集群池化操作符。

ClusterPooling 为每条边计算一个分数。 根据选定的边,计算图簇并使用内射的 "sum" 聚合函数将其压缩为一个节点。 边根据每个簇创建的节点和原始边重新映射。

Parameters:
  • in_channels (int) – Size of each input sample.

  • edge_score_method (str, optional) – 用于从原始边分数计算边分数的函数("tanh", "sigmoid", "log_softmax")。(默认值:"tanh"

  • dropout (float, optional) – The probability with which to drop edge scores during training. (default: 0.0)

  • threshold (float, optional) – 边缘分数的阈值。如果设置为 None,将根据 edge_score_method自动推断。(默认值: None)

reset_parameters()[source]

重置模块的所有可学习参数。

forward(x: Tensor, edge_index: Tensor, batch: Tensor) Tuple[Tensor, Tensor, Tensor, UnpoolInfo][source]

前向传播。

Parameters:
  • x (torch.Tensor) – The node features.

  • edge_index (torch.Tensor) – The edge indices.

  • batch (torch.Tensor) – 批次向量 \(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N\), 它将 每个节点分配给一个特定的示例。

Return types:
  • x (torch.Tensor) - The pooled node features.

  • edge_index (torch.Tensor) - The coarsened edge indices.

  • batch (torch.Tensor) - The coarsened batch vector.

  • unpool_info (UnpoolInfo) - 可以用于反池化的信息。

Return type:

Tuple[Tensor, Tensor, Tensor, UnpoolInfo]