Shortcuts

torch.ao.quantization.pt2e.export_utils 的源代码

```html
import types

import torch
import torch.nn.functional as F


__all__ = [
    "model_is_exported",
    "_WrapperModule",
]


class _WrapperModule(torch.nn.Module):
    """类用于将可调用对象包装在 :class:`torch.nn.Module` 中。如果你
    正在尝试导出一个可调用对象,请使用此方法。
    """

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, *args, **kwargs):
        """简单的 forward 方法,只是调用传递给 :meth:`WrapperModule.__init__` 的 ``fn``。"""
        return self.fn(*args, **kwargs)


[docs]def model_is_exported(m: torch.nn.Module) -> bool: """ 如果 `torch.nn.Module` 已导出,则返回 True,否则返回 False (例如,如果模型是 FX 符号化追踪的或根本没有追踪)。 """ return isinstance(m, torch.fx.GraphModule) and any( "val" in n.meta for n in m.graph.nodes )
def _replace_dropout(m: torch.fx.GraphModule, train_to_eval: bool): """ 在模型的训练和评估模式之间切换 dropout 模式。 Dropout 在训练和评估模式下的行为不同。对于导出的模型, 然而,调用 `model.train()` 或 `model.eval()` 不会自动在两种模式之间切换 dropout 行为,因此我们需要手动重写 aten dropout 模式以达到相同的效果。 参见 https://github.com/pytorch/pytorch/issues/103681。 """ # 避免循环依赖 from .utils import get_aten_graph_module # 确保子图匹配是自包含的 m.graph.eliminate_dead_code() m.recompile() for inplace in [False, True]: def dropout_train(x): return F.dropout(x, p=0.5, training=True, inplace=inplace) def dropout_eval(x): return F.dropout(x, p=0.5, training=False, inplace=inplace) example_inputs = (torch.randn(1),) if train_to_eval: match_pattern = get_aten_graph_module( _WrapperModule(dropout_train), example_inputs ) replacement_pattern = get_aten_graph_module( _WrapperModule(dropout_eval), example_inputs ) else: match_pattern = get_aten_graph_module( _WrapperModule(dropout_eval), example_inputs ) replacement_pattern = get_aten_graph_module( _WrapperModule(dropout_train), example_inputs ) from torch.fx.subgraph_rewriter import replace_pattern_with_filters replace_pattern_with_filters( m, match_pattern, replacement_pattern, match_filters=[], ignore_literals=True, ) m.recompile() def _replace_batchnorm(m: torch.fx.GraphModule, train_to_eval: bool): """ 在模型的训练和评估模式之间切换 batchnorm 模式。 Batchnorm 在训练和评估模式下的行为不同。对于导出的模型, 然而,调用 `model.train()` 或 `model.eval()` 不会自动在两种模式之间切换 batchnorm 行为,因此我们需要手动重写 aten batchnorm 模式以达到相同的效果。 """ # TODO(Leslie): 此函数仍无法支持自定义动量和 eps 值。 # 在未来的更新中启用此支持。 # 避免循环依赖 from .utils import get_aten_graph_module # 确保子图匹配是自包含的 m.graph.eliminate_dead_code() m.recompile() def bn_train( x: torch.Tensor, bn_weight: torch.Tensor, bn_bias: torch.Tensor, bn_running_mean: torch.Tensor, bn_running_var: torch.Tensor, ): return F.batch_norm( x, bn_running_mean, <span class="n
优云智算