torch_geometric.nn.pool.MemPooling

class MemPooling(in_channels: int, out_channels: int, heads: int, num_clusters: int, tau: float = 1.0)[source]

Bases: Module

基于内存的池化层来自“基于内存的图网络”论文,该层基于软聚类分配学习粗化的图表示。

\[ \begin{align}\begin{aligned}S_{i,j}^{(h)} &= \frac{ (1+{\| \mathbf{x}_i-\mathbf{k}^{(h)}_j \|}^2 / \tau)^{ -\frac{1+\tau}{2}}}{ \sum_{k=1}^K (1 + {\| \mathbf{x}_i-\mathbf{k}^{(h)}_k \|}^2 / \tau)^{ -\frac{1+\tau}{2}}}\\\mathbf{S} &= \textrm{softmax}(\textrm{Conv2d} (\Vert_{h=1}^H \mathbf{S}^{(h)})) \in \mathbb{R}^{N \times K}\\\mathbf{X}^{\prime} &= \mathbf{S}^{\top} \mathbf{X} \mathbf{W} \in \mathbb{R}^{K \times F^{\prime}}\end{aligned}\end{align} \]

其中 \(H\) 表示头的数量,\(K\) 表示 簇的数量。

Parameters:
  • in_channels (int) – 每个输入样本的大小 \(F\)

  • out_channels (int) – 每个输出样本的大小 \(F^{\prime}\)

  • heads (int) – 头的数量 \(H\)

  • num_clusters (int) – 每个头的聚类数量 \(K\)

  • tau (int, optional) – 温度 \(\tau\)。 (默认值: 1.)

reset_parameters()[source]

重置模块的所有可学习参数。

static kl_loss(S: Tensor) Tensor[source]

额外的基于KL散度的损失。

\[ \begin{align}\begin{aligned}P_{i,j} &= \frac{S_{i,j}^2 / \sum_{n=1}^N S_{n,j}}{\sum_{k=1}^K S_{i,k}^2 / \sum_{n=1}^N S_{n,k}}\\\mathcal{L}_{\textrm{KL}} &= \textrm{KLDiv}(\mathbf{P} \Vert \mathbf{S})\end{aligned}\end{align} \]
Return type:

Tensor

forward(x: Tensor, batch: Optional[Tensor] = None, mask: Optional[Tensor] = None, max_num_nodes: Optional[int] = None, batch_size: Optional[int] = None) Tuple[Tensor, Tensor][source]

前向传播。

Parameters:
  • x (torch.Tensor) – 节点特征张量,形状为 \(\mathbf{X} \in \mathbb{R}^{N \times F}\)\(\mathbf{X} \in \mathbb{R}^{B \times N \times F}\)

  • batch (torch.Tensor, optional) – 批次向量 \(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N\), 它将 每个节点分配给一个特定的示例。 如果节点特征已经具有形状 \(\mathbf{X} \in \mathbb{R}^{B \times N \times F}\),则不应提供。 (默认: None)

  • mask (torch.Tensor, optional) – 一个掩码矩阵 \(\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}\),当使用形状为 \(\mathbf{X} \in \mathbb{R}^{B \times N \times F}\) 的节点特征时,它指示每个图中的有效节点。 (default: None)

  • max_num_nodes (int, optional) – \(B\) 节点维度的大小。如果未给出,则自动计算。 (默认值: None)

  • batch_size (int, optional) – The number of examples \(B\). Automatically calculated if not given. (default: None)

Return type:

Tuple[Tensor, Tensor]