torch_geometric.nn.pool.ASAPooling
- class ASAPooling(in_channels: int, ratio: Union[float, int] = 0.5, GNN: Optional[Callable] = None, dropout: float = 0.0, negative_slope: float = 0.2, add_self_loops: bool = False, **kwargs)[source]
Bases:
Module来自“ASAP: Adaptive Structure Aware Pooling for Learning Hierarchical Graph Representations”论文的自适应结构感知池化操作符。
- Parameters:
in_channels (int) – Size of each input sample.
ratio (float 或 int) – 图池化比率,用于计算 \(k = \lceil \mathrm{ratio} \cdot N \rceil\),或者直接作为 \(k\) 的值,具体取决于
ratio的类型是float还是int。(默认值:0.5)GNN (torch.nn.Module, optional) – 用于利用集群内属性的图神经网络层。 对于具有较高邻域度的图特别有帮助 (其中之一是
torch_geometric.nn.conv.GraphConv,torch_geometric.nn.conv.GCNConv或 任何支持edge_weight参数的 GNN)。 (默认值:None)dropout (float, optional) – Dropout probability of the normalized attention coefficients which exposes each node to a stochastically sampled neighborhood during training. (default:
0)negative_slope (float, optional) – LeakyReLU angle of the negative slope. (default:
0.2)add_self_loops (bool, 可选) – 如果设置为
True,将会在新的图连接中添加自环。(默认:False)**kwargs (可选) – 用于初始化图神经网络层的额外参数。
- forward(x: Tensor, edge_index: Tensor, edge_weight: Optional[Tensor] = None, batch: Optional[Tensor] = None) Tuple[Tensor, Tensor, Optional[Tensor], Tensor, Tensor][source]
前向传播。
- Parameters:
x (torch.Tensor) – 节点特征矩阵。
edge_index (torch.Tensor) – The edge indices.
edge_weight (torch.Tensor, optional) – The edge weights. (default:
None)batch (torch.Tensor, optional) – 批次向量 \(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N\), 它将 每个节点分配到一个特定的示例中。(默认:
None)
- Return types:
x (torch.Tensor): 池化后的节点嵌入。
edge_index (torch.Tensor): 粗化的边索引。
edge_weight (torch.Tensor, optional): 粗化的边权重。
batch (torch.Tensor): 粗化的批次向量。
index (torch.Tensor): 在池化后保留的节点的前\(k\)个节点索引。