Shortcuts

torch.fx.graph_module 的源代码

```python import contextlib import copy import itertools import linecache import os import sys import traceback import warnings from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Set, Type, Union import torch import torch.nn as nn import torch.overrides from torch.nn.modules.module import _addindent from torch.package import Importer, PackageExporter, PackageImporter, sys_importer from ._compatibility import compatibility from .graph import _custom_builtins, _is_from_torch, _PyTreeCodeGen, Graph, PythonCode __all__ = [ "reduce_graph_module", "reduce_package_graph_module", "reduce_deploy_graph_module", "GraphModule", ] _USER_PRESERVED_ATTRIBUTES_KEY = "_user_preserved_attributes" # Normal exec loses the source code, however we can work with # the linecache module to recover it. # Using _exec_with_source will add it to our local cache # and then tools like TorchScript will be able to get source info. class _EvalCacheLoader: def __init__(self): self.eval_cache = {} self.next_id = 0 def cache(self, src: str, globals: Dict[str, Any], co_fields=None): """Store the source in a private cache, and add a lazy entry in linecache that allows the source to be retrieved by 'filename'. Args: src (str): The module source to cache globals (dict): The module globals Returns: str: The cache key (and dummy filename) generated for src. """ key = self._get_key() if co_fields: key += f" from {co_fields['co_filename']}:{co_fields['co_firstlineno']} in {co_fields['co_name']}" self.eval_cache[key] = src # Don't mutate globals so that this loader is only used # to populate linecache, and doesn't interact with other modules # that might check `__loader__` globals_copy = globals.copy() globals_copy["__file__"] = key globals_copy["__name__"] = key globals_copy["__loader__"] = self linecache.lazycache(key, globals_copy) return key # Part of the loader protocol (PEP 302) # linecache will use this method when trying to find source code def get_source(self, module_name) -> Optional[str]: if module_name in self.eval_cache: return self.eval_cache[module_name] return None def _get_key(self): key = f".{self.next_id}" self.next_id += 1 return key _loader = _EvalCacheLoader() def _exec_with_source(src: str, globals: Dict[str, Any], co_fields=None): key = _loader.cache(src, globals, co_fields) exec(compile(src, key, "exec"), globals) def _forward_from_src(src: str, globals: Dict[str, Any], co_fields=None): return _method_from_src( method_name="forward", src=src, globals=globals, co_fields=co_fields ) def _method_from_src( method_name: str, src: str, globals: Dict[str, Any], co_fields=None ) -> Callable: # avoid mutating the passed in dict globals_copy = globals.copy() _exec_with_source(src, globals_copy, co_fields) fn = globals_copy[method_name] del globals_copy[method_name] return fn def _format_import_statement(name: str, obj: Any, importer: Importer) -> str: if name in _custom_builtins: return _custom_builtins[name].import_str if _is_from_torch(name): return "import torch" module_name, attr_name = importer.get_name(obj) return f"from {module_name} import {attr_name} as {name}" def _format_import_block(globals: Dict[str, Any], importer: Importer): import_strs: Set[str] = set() for name, obj in globals.items(): import_strs.add(_format_import_statement(name, obj, importer)) # Sort the imports so we have a stable import block that allows us to # hash the graph module and get a consistent key for use in a cache. return "\n".join(sorted(import_strs)) @compatibility(is_backward_compatible=True) def reduce_graph_module(body: Dict[Any, Any], import_block: str) -> torch.nn.Module: # BC: attribute name was changed from `code` to `_code` to facilitate # making `code` into a property and adding a docstring to it fn_src = body.get("_code") or body["code"] forward = _forward_from_src(import_block + fn_src, {}) return _deserialize_graph_module(forward, body) @compatibility(is_backward_compatible=True) def reduce_package_graph_module( importer: PackageImporter, body: Dict[Any, Any], generated_module_name: str ) -> torch.nn.Module: forward = importer.import_module(generated_module_name).forward return _deserialize_graph_module(forward, body) @compatibility(is_backward_compatible=True) def reduce_deploy_graph_module( importer: PackageImporter, body: Dict[Any, Any], import_block: str ) -> torch.nn.Module: ns = {} ns["__builtins__"] = importer.patched_builtins fn_src = body.get("_code") assert fn_src is not None forward = _forward_from_src(import_block + fn_src, ns) return _deserialize_graph_module(forward, body) # We create a dummy class here because symbolic_trace pulls the forward() # function off of the class, rather than the instance. This class is used # in _deserialize_graph_module() below. class _CodeOnlyModule(torch.nn.Module): def __init__(self, body): super().__init__() self.__dict__ = body def _deserialize_graph_module(forward, body: Dict[Any, Any], graph_module_cls=None) -> torch.nn.Module: """ Deserialize a GraphModule given the dictionary of the original module, using the code to reconstruct the graph. We delete the actual graph before saving the dictionary so that changes to the in-memory graph format do not get serialized. """ # Try to retrieve the forward source in a backward-compatible way _CodeOnlyModule.forward = forward tracer_cls = body.get("_tracer_cls") if tracer_cls is None: from ._symbolic_trace import Tracer tracer_cls = Tracer graphmodule_cls_name = body.get("_graphmodule_cls_name", "GraphModule") # This is a workaround for a mypy linter issue related to # passing base class as an argument - https://github.com/python/mypy/issues/5865. cls_tracer: Any = tracer_cls class KeepModules(cls_tracer): # we shouldn't trace into any of the submodules, # because they were not traced in the original GraphModule def is_leaf_module(self, _: torch.nn.Module, __: str) -> bool: return True com = _CodeOnlyModule(body) tracer_extras = body.get("_tracer_extras", {}) graph = KeepModules().trace(com, **tracer_extras) # Manually set Tracer class on the reconstructed Graph, to avoid # referencing the private local subclass KeepModules. graph._tracer_cls = tracer_cls from ._lazy_graph_module import _make_graph_module gm = _make_graph_module(com, graph, class_name=graphmodule_cls_name, graph_module_cls=graph_module_cls) # The GraphModule constructor only retains attributes referenced by the graph. # In this case, our goal is return a GraphModule as close to identical as the one # put into the package. If any additional attributes were present in body, # we should keep them. for k, v in body.items(): if not hasattr(gm, k): setattr(gm, k, v) return gm # copy an attribute value with qualified name 'target' from 'from_module' to 'to_module' # This installs empty Modules where none exist yet if they are subpaths of target def _copy_attr(from_module: torch.nn.Module, to_module: torch.nn.Module, target: str): *prefix, field = target.split(".") for item in prefix: f = getattr(from_module, item) t = getattr(to_module, item, None) if f is t: # we have already installed one of its parents # (e.g. target = root.linear.weight, but we have already installed root.linear) # once we install a parent, we no longer need to copy the children # since all the needed properties will already be present return if t is None: t = torch.nn.Module() setattr(to_module, item, t) from_module, to_module = f, t orig = getattr(from_module, field) # If it is a tensor and not a parameter attribute of a module, it should be a named buffer. # So, we register it as a named buffer in the target module. if isinstance(orig, torch.Tensor) and not isinstance(orig, torch.nn.Parameter): to_module.register_buffer(field, orig) else: setattr(to_module, field, orig) # Assign attribute 'from_obj' to the qualified name 'target' on 'to_module # This installs empty Modules where none exist yet if they are subpaths of target def _assign_attr(from_obj: Any, to_module: torch.nn.Module, target: str): *prefix, field = target.split(".") for item in prefix: t = getattr(to_module, item, None) if t is None: t = torch.nn.Module() setattr(to_module, item, t) to_module = t # If it is a tensor and not a parameter attribute of a module, it should be a named buffer. # So, we register it as a named buffer in the target module. if isinstance(from_obj, torch.Tensor) and not isinstance( from_obj, torch.nn.Parameter ): to_module.register_buffer(field, from_obj) else: setattr(to_module, field, from_obj) class _WrappedCall: def __init__(self, cls, cls_call): self.cls = cls self.cls_call = cls_call # Previously, if an error occurred when valid # symbolically-traced code was run with an invalid input, the # user would see the source of the error as coming from # `File "`, where N is some number. We use # this function to generate a more informative error message. We # return the traceback itself, a message explaining that the # error occurred in a traced Module's generated forward # function, and five lines of context surrounding the faulty # line @staticmethod def _generate_error_message(frame_summary: traceback.FrameSummary) -> str: # auxiliary variables (for readability) err_lineno = frame_summary.lineno assert err_lineno is not None line = frame_summary.line assert line is not None err_line_len = len(line) all_src_lines = linecache.getlines(frame_summary.filename) # constituent substrings of the error message tb_repr = traceback.format_exc() custom_msg = ( "Call using an FX-traced Module, " f"line {err_lineno} of the traced Module's " "generated forward function:" ) before_err = "".join(all_src_lines[err_lineno - 2 : err_lineno]) marker = "~" * err_line_len + "~~~ <--- HERE" err_and_after_err = "\n".join(all_src_lines[err_lineno : err_lineno + 2]) # joined message return "\n".join([tb_repr, custom_msg, before_err, marker, err_and_after_err]) def __call__(self, obj, *args, **kwargs): try: if self.cls_call is not None: return self.cls_call(obj, *args, **kwargs) else: return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc] except Exception as e: assert e.__traceback__ topmost_framesummary: traceback.FrameSummary = ( traceback.StackSummary.extract(traceback.walk_tb(e.__traceback__))[-1] ) # type: ignore[arg-type] if "eval_with_key" in topmost_framesummary.filename: print( _WrappedCall._generate_error_message(topmost_framesummary), file=sys.stderr, ) raise e.with_traceback(None) # noqa: TRY200 else: raise e @compatibility(is_backward_compatible=True) class GraphModule(torch.nn.Module): """ GraphModule is an nn.Module generated from an fx.Graph. Graphmodule has a ``graph`` attribute, as well as ``code`` and ``forward`` attributes generated from that ``graph``. .. warning:: When ``graph`` is reassigned, ``code`` and ``forward`` will be automatically regenerated. However, if you edit the contents of the ``graph`` without reassigning the ``graph`` attribute itself, you must call ``recompile()`` to update the generated code. """ def __new__(cls: "Type[GraphModule]", *args, **kwargs): # each instance of a graph module needs its own forward method # so create a new singleton class for each instance. # it is a subclass of the user-defined class, the only difference # is an extra layer to install the forward method # address issue described at https://github.com/pytorch/pytorch/issues/63883 # in other words, traverse class hierarchy to fix the redundant class definition problem for t in cls.__mro__: c = t.__qualname__.split(".")[-1] if c != "GraphModuleImpl": cls = t break class GraphModuleImpl(cls): # type: ignore[misc, valid-type] pass return super().__new__(GraphModuleImpl) @compatibility(is_backward_compatible=True) def __init__( self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph, class_name: str = "GraphModule", ): """ Construct a GraphModule. Args: root (Union[torch.nn.Module, Dict[str, Any]): ``root`` can either be an nn.Module instance or a Dict mapping strings to any attribute type. In the case that ``root`` is a Module, any references to Module-based objects (via qualified name) in the Graph's Nodes' ``target`` field will be copied over from the respective place within ``root``'s Module hierarchy into the GraphModule's module hierarchy. In the case that ``root`` is a dict, the qualified name found in a Node's ``target`` will be looked up directly in the dict's keys. The object mapped to by the Dict will be copied over into the appropriate place within the GraphModule's module hierarchy. graph (Graph): ``graph`` contains the nodes this GraphModule should use for code generation class_