Source code for torch.ao.nn.quantized.modules.normalization

```python
import torch

__all__ = ['LayerNorm', 'GroupNorm', 'InstanceNorm1d', 'InstanceNorm2d', 'InstanceNorm3d']

class LayerNorm(torch.nn.LayerNorm):
    r"""This is the quantized version of :class:`~torch.nn.LayerNorm`.

    Additional args:
        * **scale** - quantization scale of the output, type: double.
        * **zero_point** - quantization zero point of the output, type: long.
    """

    def __init__(self, normalized_shape, weight, bias, scale, zero_point,
                 eps=1e-5, elementwise_affine=True,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__(normalized_shape, eps=eps,
                         elementwise_affine=elementwise_affine, **factory_kwargs)
        self.weight = weight
        self.bias = bias
        self.register_buffer('scale', torch.tensor(scale, **factory_kwargs))
        self.register_buffer('zero_point',
                             torch.tensor(zero_point, **factory_kwargs))

    def forward(self, input):
        return torch.ops.quantized.layer_norm(
            input, self.normalized_shape, weight=self.weight, bias=self.bias,
            eps=self.eps, output_scale=self.scale,
            output_zero_point=self.zero_point)

    def _get_name(self):
        return 'QuantizedLayerNorm'

    @classmethod
    def from_float(cls, mod):
        scale, zero_point = mod.activation_post_process.calculate_qparams()
        new_mod = cls(
            mod.normalized_shape, mod.weight, mod.bias, float(scale),
            int(zero_point), mod.eps, mod.elementwise_affine)
        return new_mod

    @classmethod
    def from_reference(cls, mod, scale, zero_point):
        return cls(
            mod.normalized_shape, mod.weight, mod.bias, float(scale),
            int(zero_point), mod.eps, mod.elementwise_affine)
class GroupNorm(torch.nn.GroupNorm):
    r"""This is the quantized version of :class:`~torch.nn.GroupNorm`.

    Additional args:
        * **scale** - quantization scale of the output, type: double.
        * **zero_point** - quantization zero point of the output, type: long.
    """

    __constants__ = ['num_groups', 'num_channels', 'eps', 'affine']

    def __init__(self, num_groups, num_channels, weight, bias, scale, zero_point,
                 eps=1e-5, affine=True, device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__(num_groups, num_channels, eps, affine, **factory_kwargs)
        self.weight = weight
        self.bias = bias
        self.register_buffer('scale', torch.tensor(scale, **factory_kwargs))
        self.register_buffer('zero_point',
                             torch.tensor(zero_point, **factory_kwargs))

    def forward(self, input):
        return torch.ops.quantized.group_norm(
            input, self.num_groups, self.weight, self.bias, self.eps,
            self.scale, self.zero_point)

    def _get_name(self):
        return 'QuantizedGroupNorm'

    @classmethod
    def from_float(cls, mod):
        scale, zero_point = mod.activation_post_process.calculate_qparams()
        new_mod = cls(
            mod.num_groups, mod.num_channels, mod.weight, mod.bias,
            float(scale), int(zero_point), mod.eps, mod.affine)
        return new_mod
```
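
Since the listing carries no usage notes, here is a minimal sketch of driving the quantized `LayerNorm` directly. The input shape and the quantization parameters (`scale=0.1`, `zero_point=128`) are illustrative choices, not values taken from the source:

```python
import torch
from torch.ao.nn.quantized import LayerNorm as QuantizedLayerNorm

# Float reference module; its affine parameters are reused below.
float_ln = torch.nn.LayerNorm((4,))

# Quantize an example input to quint8 (scale/zero_point are illustrative).
x = torch.randn(2, 4)
qx = torch.quantize_per_tensor(x, scale=0.1, zero_point=128, dtype=torch.quint8)

# Construct the quantized module directly, supplying the float weight/bias
# together with the desired output quantization parameters.
qln = QuantizedLayerNorm((4,), float_ln.weight, float_ln.bias,
                         scale=0.1, zero_point=128)
qy = qln(qx)     # dispatches to torch.ops.quantized.layer_norm
print(qy.dtype)  # torch.quint8
```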
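The `from_float` path is normally driven by eager-mode `prepare`/`convert`, which attach the `activation_post_process` observer automatically. As a rough sketch of what that machinery does, assuming a default `MinMaxObserver` and attaching it by hand purely for illustration:

```python
import torch
from torch.ao.nn.quantized import GroupNorm as QuantizedGroupNorm

float_gn = torch.nn.GroupNorm(num_groups=2, num_channels=4)

# Attach an output observer by hand, mimicking what eager-mode prepare() does.
float_gn.activation_post_process = torch.ao.quantization.MinMaxObserver(
    dtype=torch.quint8)

# Calibrate: run representative data through the module and its observer.
x = torch.randn(8, 4, 16)
float_gn.activation_post_process(float_gn(x))

# Convert: from_float reads the observed output scale/zero_point.
qgn = QuantizedGroupNorm.from_float(float_gn)
print(qgn._get_name())  # 'QuantizedGroupNorm'
```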