Shortcuts

torch.ao.quantization.fx.custom_config 的源代码

```python
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type

from torch.ao.quantization import QConfigMapping
from torch.ao.quantization.backend_config import BackendConfig
from torch.ao.quantization.quant_type import QuantType, _quant_type_from_str, _get_quant_type_to_str


# Public API of this module: the names exported by ``from ... import *``.
__all__ = [
    "ConvertCustomConfig",
    "FuseCustomConfig",
    "PrepareCustomConfig",
    "StandaloneModuleConfigEntry",
]


# TODO: replace all usages with these constants
# NOTE(review): string keys whose names mirror the attributes of the custom
# config classes (e.g. PrepareCustomConfig.standalone_module_names,
# .float_to_observed_mapping, .preserved_attributes); presumably used by a
# dict-based (de)serialization path elsewhere in this module — not visible
# here, confirm before relying on it.

# Standalone modules, keyed by module name / module class.
STANDALONE_MODULE_NAME_DICT_KEY = "standalone_module_name"
STANDALONE_MODULE_CLASS_DICT_KEY = "standalone_module_class"
# Custom module class mappings (float -> observed, observed -> quantized).
FLOAT_TO_OBSERVED_DICT_KEY = "float_to_observed_custom_module_class"
OBSERVED_TO_QUANTIZED_DICT_KEY = "observed_to_quantized_custom_module_class"
# Modules excluded from symbolic tracing, by name / by class.
NON_TRACEABLE_MODULE_NAME_DICT_KEY = "non_traceable_module_name"
NON_TRACEABLE_MODULE_CLASS_DICT_KEY = "non_traceable_module_class"
# Indexes of graph inputs / outputs that are already quantized.
INPUT_QUANTIZED_INDEXES_DICT_KEY = "input_quantized_idxs"
OUTPUT_QUANTIZED_INDEXES_DICT_KEY = "output_quantized_idxs"
# Attribute names to keep on the transformed module.
PRESERVED_ATTRIBUTES_DICT_KEY = "preserved_attributes"


@dataclass
class StandaloneModuleConfigEntry:
    """Settings for one standalone module.

    Instances are created by ``PrepareCustomConfig.set_standalone_module_name``.

    Fields:
        qconfig_mapping: qconfig_mapping for the prepare function called on the
            submodule; ``None`` means the parent qconfig_mapping is used.
        example_inputs: tuple of example inputs for the standalone module.
        prepare_custom_config: custom config for the submodule's prepare call;
            ``None`` means an empty ``PrepareCustomConfig`` is used.
        backend_config: backend config for the submodule; ``None`` means the
            parent backend_config is used.
    """

    # qconfig_mapping for the prepare function called in the submodule,
    # None means use qconfig from parent qconfig_mapping
    qconfig_mapping: Optional[QConfigMapping]
    example_inputs: Tuple[Any, ...]
    prepare_custom_config: Optional[PrepareCustomConfig]
    backend_config: Optional[BackendConfig]
class PrepareCustomConfig:
    """
    Custom configuration for :func:`~torch.ao.quantization.quantize_fx.prepare_fx` and
    :func:`~torch.ao.quantization.quantize_fx.prepare_qat_fx`.

    Example usage::

        prepare_custom_config = (
            PrepareCustomConfig()
            .set_standalone_module_name("module1", qconfig_mapping, example_inputs,
                                        child_prepare_custom_config, backend_config)
            .set_standalone_module_class(MyStandaloneModule, qconfig_mapping, example_inputs,
                                         child_prepare_custom_config, backend_config)
            .set_float_to_observed_mapping(FloatCustomModule, ObservedCustomModule)
            .set_non_traceable_module_names(["module2", "module3"])
            .set_non_traceable_module_classes([NonTraceableModule1, NonTraceableModule2])
            .set_input_quantized_indexes([0])
            .set_output_quantized_indexes([0])
            .set_preserved_attributes(["attr1", "attr2"])
        )
    """

    def __init__(self):
        # Standalone-module settings, keyed by module name and by module class.
        self.standalone_module_names: Dict[str, StandaloneModuleConfigEntry] = {}
        self.standalone_module_classes: Dict[Type, StandaloneModuleConfigEntry] = {}
        # Per QuantType: float custom module class -> observed custom module class.
        self.float_to_observed_mapping: Dict[QuantType, Dict[Type, Type]] = {}
        self.non_traceable_module_names: List[str] = []
        self.non_traceable_module_classes: List[Type] = []
        self.input_quantized_indexes: List[int] = []
        self.output_quantized_indexes: List[int] = []
        self.preserved_attributes: List[str] = []

    def __repr__(self):
        """Return a repr that lists only the attributes that are non-empty."""
        # Every attribute set in __init__ is a container, so len() is valid.
        dict_nonempty = {k: v for k, v in self.__dict__.items() if len(v) > 0}
        return f"PrepareCustomConfig({dict_nonempty})"

    def set_standalone_module_name(
        self,
        module_name: str,
        qconfig_mapping: Optional[QConfigMapping],
        example_inputs: Tuple[Any, ...],
        prepare_custom_config: Optional[PrepareCustomConfig],
        backend_config: Optional[BackendConfig],
    ) -> PrepareCustomConfig:
        """
        Set the configuration for running a standalone module identified by ``module_name``.

        If ``qconfig_mapping`` is None, the parent ``qconfig_mapping`` will be used instead.
        If ``prepare_custom_config`` is None, an empty ``PrepareCustomConfig`` will be used.
        If ``backend_config`` is None, the parent ``backend_config`` will be used instead.

        Returns ``self`` so that calls can be chained.
        """
        # An existing entry for the same name is overwritten.
        self.standalone_module_names[module_name] = StandaloneModuleConfigEntry(
            qconfig_mapping, example_inputs, prepare_custom_config, backend_config
        )
        return self
<a class="viewcode-back" href="../../../../../generated/torch.ao.quantization.fx.custom_config.PrepareCustomConfig.html#torch.ao.quantization.
优云智算