Shortcuts

torch.ao.nn.quantized.modules.functional_modules 的源代码

```html
from typing import List

import torch
from torch import Tensor
from torch._ops import ops

__all__ = ['FloatFunctional', 'FXFloatFunctional', 'QFunctional']

[docs]class FloatFunctional(torch.nn.Module): r"""用于浮点操作的状态收集器类。 该类的实例可以用来替代某些操作中的``torch.``前缀。请参见下面的示例用法。 .. 注意:: 该类不提供``forward``钩子。相反,您必须使用其中一个底层函数(例如``add``)。 示例:: >>> f_add = FloatFunctional() >>> a = torch.tensor(3.0) >>> b = torch.tensor(4.0) >>> f_add.add(a, b) # 等同于``torch.add(a, b)`` 有效的操作名称: - add - cat - mul - add_relu - add_scalar - mul_scalar """ def __init__(self): super().__init__() self.activation_post_process = torch.nn.Identity() def forward(self, x): raise RuntimeError("FloatFunctional 不打算使用" + "'forward'。请使用底层操作") r"""等同于``torch.add(Tensor, Tensor)``的操作""" def add(self, x: Tensor, y: Tensor) -> Tensor: r = torch.add(x, y) r = self.activation_post_process(r) return r r"""等同于``torch.add(Tensor, float)``的操作""" def add_scalar(self, x: Tensor, y: float) -> Tensor: r = torch.add(x, y) # 注意:此操作未被观察,因为量化操作不需要观察。 return r r"""等同于``torch.mul(Tensor, Tensor)``的操作""" def mul(self, x: Tensor, y: Tensor) -> Tensor: r = torch.mul(x, y) r = self.activation_post_process(r) return r r"""等同于``torch.mul(Tensor, float)``的操作""" def mul_scalar(self, x: Tensor, y: float) -> Tensor: r = torch.mul(x, y) # 注意:此操作未被观察,因为量化操作不需要观察。 return r r"""等同于``torch.cat``的操作""" def cat(self, x: List[Tensor], dim: int = 0) -> Tensor: r = torch.cat(x, dim=dim) r = self.activation_post_process(r) return r r"""等同于``relu(torch.add(x,y))``的操作""" def add_relu(self, x: Tensor, y: Tensor) -> Tensor: r = torch.add(x, y) r = torch.nn.functional.relu(r) r = self.activation_post_process(r) return r r"""等同于``torch.matmul(Tensor, Tensor)``的操作""" def matmul(self, x: Tensor, y: Tensor) -> Tensor: r = torch.matmul(x, y) r = self.activation_post_process(r) return r
[docs]class FXFloatFunctional(torch.nn.Module): r"""在FX图模式量化之前替换FloatFunctional模块的模块, 因为activation_post_process将直接插入到顶级模块中 有效的操作名称: - add - cat - mul - add_relu - add_scalar - mul_scalar """ def forward(self, x