torch_geometric.nn.norm.DiffGroupNorm

class DiffGroupNorm(in_channels: int, groups: int, lamda: float = 0.01, eps: float = 1e-05, momentum: float = 0.1, affine: bool = True, track_running_stats: bool = True)[source]

Bases: Module

来自“Towards Deeper Graph Neural Networks with Differentiable Group Normalization”论文的可微分组归一化层,该层通过可学习的软聚类分配对节点特征进行分组归一化。

\[\mathbf{S} = \text{softmax} (\mathbf{X} \mathbf{W})\]

其中 \(\mathbf{W} \in \mathbb{R}^{F \times G}\) 表示一个可训练的权重矩阵,将每个节点映射到 \(G\) 个聚类中的一个。然后通过以下方式进行分组归一化:

\[\mathbf{X}^{\prime} = \mathbf{X} + \lambda \sum_{i = 1}^G \text{BatchNorm}(\mathbf{S}[:, i] \odot \mathbf{X})\]
Parameters:
  • in_channels (int) – Size of each input sample \(F\).

  • groups (int) – 组的数量 \(G\)

  • lamda (float, optional) – 输入嵌入和归一化嵌入之间的平衡因子 \(\lambda\)。(默认值:0.01

  • eps (float, optional) – A value added to the denominator for numerical stability. (default: 1e-5)

  • momentum (float, optional) – The value used for the running mean and running variance computation. (default: 0.1)

  • affine (bool, optional) – If set to True, this module has learnable affine parameters \(\gamma\) and \(\beta\). (default: True)

  • track_running_stats (bool, optional) – If set to True, this module tracks the running mean and variance, and when set to False, this module does not track such statistics and always uses batch statistics in both training and eval modes. (default: True)

reset_parameters()[source]

重置模块的所有可学习参数。

forward(x: Tensor) Tensor[source]

前向传播。

Parameters:

x (torch.Tensor) – The source tensor.

Return type:

Tensor

static group_distance_ratio(x: Tensor, y: Tensor, eps: float = 1e-05) float[source]

测量组间距离与组内距离的比率。

\[R_{\text{Group}} = \frac{\frac{1}{(C-1)^2} \sum_{i!=j} \frac{1}{|\mathbf{X}_i||\mathbf{X}_j|} \sum_{\mathbf{x}_{iv} \in \mathbf{X}_i } \sum_{\mathbf{x}_{jv^{\prime}} \in \mathbf{X}_j} {\| \mathbf{x}_{iv} - \mathbf{x}_{jv^{\prime}} \|}_2 }{ \frac{1}{C} \sum_{i} \frac{1}{{|\mathbf{X}_i|}^2} \sum_{\mathbf{x}_{iv}, \mathbf{x}_{iv^{\prime}} \in \mathbf{X}_i } {\| \mathbf{x}_{iv} - \mathbf{x}_{iv^{\prime}} \|}_2 }\]

其中 \(\mathbf{X}_i\) 表示属于类 \(i\) 的所有节点的集合,而 \(C\) 表示 y 中的总类别数。

Return type:

float