RMSNorm¶
- class torchtune.modules.RMSNorm(dim: int, eps: float = 1e-06)[source]¶
fp32中的均方根归一化。
参见:https://pytorch.org/docs/stable/generated/torch.nn.RMSNorm.html
- forward(x: Tensor) Tensor[source]¶
- Parameters:
x (torch.Tensor) – 要标准化的输入张量
- Returns:
归一化和缩放的张量,其形状与
x相同。- Return type: