torch_geometric.nn.norm.HeteroBatchNorm

class HeteroBatchNorm(in_channels: int, num_types: int, eps: float = 1e-05, momentum: Optional[float] = 0.1, affine: bool = True, track_running_stats: bool = True)[source]

Bases: Module

在一批异构特征上应用批量归一化,如“批量归一化:通过减少内部协变量偏移加速深度网络训练”论文中所述。 与BatchNorm相比,HeteroBatchNorm对每种节点或边类型分别应用归一化。

Parameters:
  • in_channels (int) – Size of each input sample.

  • num_types (int) – The number of types.

  • eps (float, optional) – A value added to the denominator for numerical stability. (default: 1e-5)

  • momentum (float, optional) – The value used for the running mean and running variance computation. (default: 0.1)

  • affine (bool, optional) – If set to True, this module has learnable affine parameters \(\gamma\) and \(\beta\). (default: True)

  • track_running_stats (bool, optional) – If set to True, this module tracks the running mean and variance, and when set to False, this module does not track such statistics and always uses batch statistics in both training and eval modes. (default: True)

reset_running_stats()[source]

重置模块的所有运行统计信息。

reset_parameters()[source]

重置模块的所有可学习参数。

forward(x: Tensor, type_vec: Tensor) Tensor[source]

前向传播。

Parameters:
Return type:

Tensor