torch_geometric.nn.pool.MemPooling
- class MemPooling(in_channels: int, out_channels: int, heads: int, num_clusters: int, tau: float = 1.0)[source]
Bases:
Module基于内存的池化层来自“基于内存的图网络”论文,该层基于软聚类分配学习粗化的图表示。
\[ \begin{align}\begin{aligned}S_{i,j}^{(h)} &= \frac{ (1+{\| \mathbf{x}_i-\mathbf{k}^{(h)}_j \|}^2 / \tau)^{ -\frac{1+\tau}{2}}}{ \sum_{k=1}^K (1 + {\| \mathbf{x}_i-\mathbf{k}^{(h)}_k \|}^2 / \tau)^{ -\frac{1+\tau}{2}}}\\\mathbf{S} &= \textrm{softmax}(\textrm{Conv2d} (\Vert_{h=1}^H \mathbf{S}^{(h)})) \in \mathbb{R}^{N \times K}\\\mathbf{X}^{\prime} &= \mathbf{S}^{\top} \mathbf{X} \mathbf{W} \in \mathbb{R}^{K \times F^{\prime}}\end{aligned}\end{align} \]其中 \(H\) 表示头的数量,\(K\) 表示 簇的数量。
- Parameters:
- static kl_loss(S: Tensor) Tensor[source]
额外的基于KL散度的损失。
\[ \begin{align}\begin{aligned}P_{i,j} &= \frac{S_{i,j}^2 / \sum_{n=1}^N S_{n,j}}{\sum_{k=1}^K S_{i,k}^2 / \sum_{n=1}^N S_{n,k}}\\\mathcal{L}_{\textrm{KL}} &= \textrm{KLDiv}(\mathbf{P} \Vert \mathbf{S})\end{aligned}\end{align} \]- Return type:
- forward(x: Tensor, batch: Optional[Tensor] = None, mask: Optional[Tensor] = None, max_num_nodes: Optional[int] = None, batch_size: Optional[int] = None) Tuple[Tensor, Tensor][source]
前向传播。
- Parameters:
x (torch.Tensor) – 节点特征张量,形状为 \(\mathbf{X} \in \mathbb{R}^{N \times F}\) 或 \(\mathbf{X} \in \mathbb{R}^{B \times N \times F}\)。
batch (torch.Tensor, optional) – 批次向量 \(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N\), 它将 每个节点分配给一个特定的示例。 如果节点特征已经具有形状 \(\mathbf{X} \in \mathbb{R}^{B \times N \times F}\),则不应提供。 (默认:
None)mask (torch.Tensor, optional) – 一个掩码矩阵 \(\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}\),当使用形状为 \(\mathbf{X} \in \mathbb{R}^{B \times N \times F}\) 的节点特征时,它指示每个图中的有效节点。 (default:
None)max_num_nodes (int, optional) – \(B\) 节点维度的大小。如果未给出,则自动计算。 (默认值:
None)batch_size (int, optional) – The number of examples \(B\). Automatically calculated if not given. (default:
None)
- Return type: