Source code for torch.jit._trace
"""Tracing."""
This module contains functionality to support the JIT's tracing frontend, notably:
* torch.jit.trace
* torch.jit.trace_module
This is not intended to be imported directly; please use the exposed
functionalities in `torch.jit`.
"""
import contextlib
import copy
import functools
import inspect
import os
import re
import warnings
from typing import Any, Callable, Dict, List, Optional, Set, TypeVar
from typing_extensions import ParamSpec
import torch
from torch._jit_internal import (
    _qualified_name,
    get_callable_argument_names,
    is_scripting,
)
from torch.autograd import function
from torch.jit._script import _CachedForward, script, ScriptModule
from torch.jit._state import _enabled, _python_cu
from torch.nn import Module
from torch.testing._comparison import default_tolerances
_flatten = torch._C._jit_flatten
_unflatten = torch._C._jit_unflatten
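# Descriptive note (inferred from their use in ONNXTracedModule.forward
# below): _flatten maps a possibly-nested tuple of inputs to a flat list of
# its tensors plus a descriptor of the container structure, and _unflatten
# rebuilds the original structure from such a pair.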
R = TypeVar("R", covariant=True) # return type (always covariant)
P = ParamSpec("P")
def _create_interpreter_name_lookup_fn(frames_up=1):
    def _get_interpreter_name_for_var(var):
        frame = inspect.currentframe()
        if not frame:
            raise RuntimeError("failed to inspect frame")
        i = 0
        while i < frames_up + 1:
            frame = frame.f_back
            if not frame:
                raise RuntimeError("failed to get frame")
            i += 1
        f_locals = frame.f_locals
        f_globals = frame.f_globals
        for k, v in f_locals.items():
            if isinstance(v, torch.Tensor) and var is v:
                return k if k != "self" else ""
        return ""
    return _get_interpreter_name_for_var
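# Illustrative sketch (not part of the original file): with frames_up=0 the
# returned lookup inspects its direct caller's locals and reports the name
# bound to a given tensor; larger values walk further up the call stack.
#
#     def f():
#         x = torch.randn(2)
#         lookup = _create_interpreter_name_lookup_fn(frames_up=0)
#         return lookup(x)  # -> "x"; unmatched tensors and `self` yield ""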
def _unique_state_dict(module, keep_vars=False):
    # Since Parameter.detach() always creates a new torch.Tensor instance,
    # id() cannot deduplicate detached values. So we fetch the live Parameter
    # and Buffer objects as values, deduplicate by their ids, and only
    # detach afterwards when keep_vars is False.
    state_dict = module.state_dict(keep_vars=True)
    filtered_dict = type(state_dict)()
    seen_ids: Set[int] = set()
    for k, v in state_dict.items():
        if id(v) in seen_ids:
            continue
        seen_ids.add(id(v))
        if keep_vars:
            filtered_dict[k] = v
        else:
            filtered_dict[k] = v.detach()
    return filtered_dict
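# Illustrative sketch (not part of the original file): with weight tying,
# two state_dict keys can refer to the same Parameter. keep_vars=True
# preserves object identity, so the id()-based dedup above collapses them:
#
#     a = torch.nn.Linear(4, 4, bias=False)
#     b = torch.nn.Linear(4, 4, bias=False)
#     b.weight = a.weight                    # tie the weights
#     m = torch.nn.Sequential(a, b)
#     len(m.state_dict(keep_vars=True))      # 2 keys, one shared tensor
#     len(_unique_state_dict(m))             # 1 entry after dedup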
class ONNXTracedModule(torch.nn.Module):
    def __init__(
        self,
        inner,
        strict=True,
        force_outplace=False,
        return_inputs=False,
        return_inputs_states=False,
    ):
        super().__init__()
        # inner may be a Module, or it may be an arbitrary callable.
        # If it's a Module, we get its parameters automatically, which lets
        # us avoid special-casing functions versus modules.
        self.inner = inner
        self.strict = strict
        self._force_outplace = force_outplace
        self._return_inputs = return_inputs
        self._return_inputs_states = return_inputs_states
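    # Descriptive note (behavior inferred from how these flags are used
    # elsewhere in torch.jit, an assumption rather than part of this file):
    # `strict` is forwarded to the tracer and errors on constructs tracing
    # cannot capture faithfully; `_force_outplace` asks the tracer to rewrite
    # in-place ops as out-of-place variants; the `return_inputs*` flags make
    # forward() additionally hand back the traced inputs.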
    def forward(self, *args: torch.Tensor):
        in_vars, in_desc = _flatten(args)
        # NOTE: use full state, because we need it for BatchNorm export
        # This differs from the compiler path, which doesn't support it at the moment.
        module_state = list(_unique_state_dict(self, keep_vars=True).values())