torch_geometric.nn.norm.HeteroBatchNorm
- class HeteroBatchNorm(in_channels: int, num_types: int, eps: float = 1e-05, momentum: Optional[float] = 0.1, affine: bool = True, track_running_stats: bool = True)
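Bases: Module
Applies batch normalization over a batch of heterogeneous features as described in the "Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift" paper. Compared to BatchNorm, HeteroBatchNorm applies normalization individually for each node or edge type.
- Parameters: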
in_channels (int) – Size of each input sample.
num_types (int) – The number of types.
eps (float, optional) – A value added to the denominator for numerical stability. (default: 1e-5)
momentum (float, optional) – The value used for the running mean and running variance computation. (default: 0.1)
affine (bool, optional) – If set to True, this module has learnable affine parameters \(\gamma\) and \(\beta\). (default: True)
track_running_stats (bool, optional) – If set to True, this module tracks the running mean and variance, and when set to False, this module does not track such statistics and always uses batch statistics in both training and eval modes. (default: True)
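As a rough illustration of normalizing "individually for each node or edge type", the NumPy sketch below computes training-mode statistics separately per type and optionally applies a per-type affine transform. It assumes NumPy is available, uses the biased batch variance as standard batch normalization does during training, and applies eps inside the square root. The function `hetero_batch_norm` and its arguments are hypothetical; this is a sketch of the idea, not the PyG implementation.

```python
import numpy as np


def hetero_batch_norm(x, type_vec, num_types, weight=None, bias=None, eps=1e-5):
    """Hypothetical sketch: normalize each row of `x` with the statistics of its own type.

    x:        [N, C] feature matrix
    type_vec: [N] integer type index in [0, num_types) for each row
    weight:   optional [num_types, C] per-type scale (gamma)
    bias:     optional [num_types, C] per-type shift (beta)
    """
    x = np.asarray(x, dtype=np.float64)
    type_vec = np.asarray(type_vec)
    num_channels = x.shape[1]
    mean = np.zeros((num_types, num_channels))
    var = np.ones((num_types, num_channels))  # types absent from the batch keep placeholder stats
    for t in range(num_types):
        rows = x[type_vec == t]
        if len(rows) > 0:
            mean[t] = rows.mean(axis=0)
            var[t] = rows.var(axis=0)  # biased (population) variance of this type only
    # Gather each row's per-type statistics; eps guards against division by zero.
    out = (x - mean[type_vec]) / np.sqrt(var[type_vec] + eps)
    if weight is not None and bias is not None:
        # Affine transform with separate gamma/beta for each type.
        out = out * np.asarray(weight)[type_vec] + np.asarray(bias)[type_vec]
    return out


rng = np.random.default_rng(0)
type_vec = np.repeat([0, 1], 50)  # 50 rows of type 0, then 50 rows of type 1
x = rng.normal(size=(100, 4))
x[type_vec == 1] += 10.0  # type 1 features live on a very different scale
out = hetero_batch_norm(x, type_vec, num_types=2)
```

Because the statistics are gathered per type, the shifted type-1 rows do not distort the normalization of the type-0 rows, which is what a single shared BatchNorm would do.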
- forward(x: Tensor, type_vec: Tensor) → Tensor
Forward pass.
- Parameters:
x (torch.Tensor) – The input features.
type_vec (torch.Tensor) – A vector that maps each entry to a type.
- Return type:
Tensor
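The `momentum` and `track_running_stats` options determine which statistics `forward(x, type_vec)` uses. The hypothetical NumPy class `HeteroBatchNormSketch` below illustrates this, assuming the conventions of `torch.nn.BatchNorm1d`: in training mode each type's batch statistics normalize the input and are blended into per-type running estimates as `running = (1 - momentum) * running + momentum * batch`; in eval mode the running estimates are used instead, unless statistics are not tracked. Updating only the types present in the batch, using the biased batch variance for the running update, and omitting the affine step and `momentum=None` are simplifications of this sketch, not documented PyG behavior.

```python
import numpy as np


class HeteroBatchNormSketch:
    """Hypothetical NumPy sketch of per-type batch norm with running statistics (no affine)."""

    def __init__(self, in_channels, num_types, eps=1e-5, momentum=0.1, track_running_stats=True):
        self.eps = eps
        self.momentum = momentum
        self.track_running_stats = track_running_stats
        self.training = True
        # Per-type running estimates, initialized like BatchNorm (mean 0, variance 1).
        self.running_mean = np.zeros((num_types, in_channels))
        self.running_var = np.ones((num_types, in_channels))

    def forward(self, x, type_vec):
        x = np.asarray(x, dtype=np.float64)
        type_vec = np.asarray(type_vec)
        if self.training or not self.track_running_stats:
            # Statistics of the current batch, computed separately for each type present.
            mean = np.zeros_like(self.running_mean)
            var = np.ones_like(self.running_var)
            present = np.unique(type_vec)
            for t in present:
                rows = x[type_vec == t]
                mean[t] = rows.mean(axis=0)
                var[t] = rows.var(axis=0)
            if self.training and self.track_running_stats:
                # Exponential moving average, updated only for types seen in this batch.
                m = self.momentum
                self.running_mean[present] = (1 - m) * self.running_mean[present] + m * mean[present]
                self.running_var[present] = (1 - m) * self.running_var[present] + m * var[present]
        else:
            # Eval mode with tracked statistics: reuse the running estimates.
            mean, var = self.running_mean, self.running_var
        return (x - mean[type_vec]) / np.sqrt(var[type_vec] + self.eps)


norm = HeteroBatchNormSketch(in_channels=2, num_types=2, momentum=0.1)
x = np.array([[1.0, 2.0], [3.0, 4.0], [10.0, 10.0], [14.0, 10.0]])
type_vec = np.array([0, 0, 1, 1])
out_train = norm.forward(x, type_vec)  # normalized with batch statistics
norm.training = False
out_eval = norm.forward(x, type_vec)  # normalized with running statistics
```

After a single training step the running estimates have moved only 10% of the way toward the batch statistics, so the eval-mode output differs noticeably from the training-mode output.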