# Source code for torch.jit._script
"""TorchScript.

This module contains functionality to support the JIT's scripting frontend, notably:
    - torch.jit.script

This is not intended to be imported directly; please use the exposed
functionalities in `torch.jit`.
"""
import collections
import copy
import enum
import functools
import inspect
import pickle
import warnings
from typing import Any, Callable, Dict, List, Set, Tuple, Union
import torch
import torch._jit_internal as _jit_internal
from torch._classes import classes
from torch._jit_internal import _qualified_name
from torch.jit._builtins import _register_builtin
from torch.jit._fuser import _graph_for, _script_method_graph_for
from torch.jit._monkeytype_config import (
JitTypeTraceConfig,
JitTypeTraceStore,
monkeytype_trace,
)
from torch.jit._recursive import (
_compile_and_register_class,
infer_methods_to_compile,
ScriptMethodStub,
wrap_cpp_module,
)
from torch.jit._state import (
_enabled,
_set_jit_function_cache,
_set_jit_overload_cache,
_try_get_jit_cached_function,
_try_get_jit_cached_overloads,
)
from torch.jit.frontend import get_default_args, get_jit_class_def, get_jit_def
from torch.nn import Module
from torch.overrides import (
has_torch_function,
has_torch_function_unary,
has_torch_function_variadic,
)
from torch.package import PackageExporter, PackageImporter
from torch.utils import set_module
from ._serialization import validate_map_location
# Module-level store shared by this module; exposed through `_get_type_trace_db()`.
type_trace_db = JitTypeTraceStore() # DB to hold all call traces from MonkeyType
# Attach the `graph_for` helpers imported from torch.jit._fuser as methods on the
# C++-backed ScriptMethod / ScriptFunction classes.
torch._C.ScriptMethod.graph_for = _script_method_graph_for # type: ignore[attr-defined]
torch._C.ScriptFunction.graph_for = _graph_for # type: ignore[attr-defined]
# Re-export the C++ class under this module's namespace and give it a docstring.
ScriptFunction = torch._C.ScriptFunction
ScriptFunction.__doc__ = """
Functionally equivalent to a :class:`ScriptModule`, but represents a single
function and does not have any attributes or Parameters.
"""
# NOTE(review): presumably makes ScriptFunction report "torch.jit" as its module —
# confirm against torch.utils.set_module.
set_module(ScriptFunction, "torch.jit")
# Throws an error if a jit function is pickled.
# Helps to avoid Python crashes for Python versions 3.9.5 + when protocol 0 or 1 is given as an argument.
def _reduce(cls):
    """Refuse to pickle a ``ScriptFunction``.

    This function is installed as ``ScriptFunction.__reduce__``, so ``cls``
    receives the object that pickle is trying to reduce. The argument is
    never inspected.

    Args:
        cls: The object being reduced (unused).

    Raises:
        pickle.PickleError: Always.
    """
    raise pickle.PickleError("ScriptFunction cannot be pickled")
# Install the guard as the pickle reduction hook, so reducing a ScriptFunction
# raises pickle.PickleError.
ScriptFunction.__reduce__ = _reduce # type: ignore[assignment]
# `Attribute` takes one of two forms depending on `_enabled` (imported from
# torch.jit._state; presumably reflects whether TorchScript is enabled — confirm there):
#   * truthy -> a namedtuple type carrying the value together with its declared type;
#   * falsy  -> a plain function that returns `value` unchanged.
if _enabled:
    Attribute = collections.namedtuple("Attribute", ["value", "type"])
else:

    def Attribute(value, type):  # type: ignore[no-redef]
        # Pass-through: `type` is accepted for signature parity and ignored.
        return value


# Both forms share one docstring; it is assigned after the branch so that it
# applies to whichever object `Attribute` ended up bound to.
Attribute.__doc__ = """
    This method is a pass-through function that returns `value`, mostly
    used to indicate to the TorchScript compiler that the left-hand side
    expression is a class instance attribute with type of `type`. Note that
    `torch.jit.Attribute` should only be used in `__init__` method of `jit.ScriptModule`
    subclasses.

    Though TorchScript can infer correct type for most Python expressions, there are some cases where
    type inference can be wrong, including:

    - Empty containers like `[]` and `{}`, which TorchScript assumes to be container of `Tensor`
    - Optional types like `Optional[T]` but assigned a valid value of type `T`, TorchScript would assume
      it is type `T` rather than `Optional[T]`

    In eager mode, it is simply a pass-through function that returns `value`
    without other implications.

    Example:

    .. testcode::

        import torch
        from typing import Dict

        class AttributeModule(torch.jit.ScriptModule):
            def __init__(self):
                super().__init__()
                self.foo = torch.jit.Attribute(0.1, float)

                # we should be able to use self.foo as a float here
                assert 0.0 < self.foo

                self.names_ages = torch.jit.Attribute({}, Dict[str, int])
                self.names_ages["someone"] = 20
                assert isinstance(self.names_ages["someone"], int)

        m = AttributeModule()
        # m will contain two attributes
        # 1. foo of type float
        # 2. names_ages of type Dict[str, int]

    .. testcleanup::

        del AttributeModule
        del m

    Note: it's now preferred to instead use type annotations instead of `torch.jit.Attribute`:

    .. testcode::

        import torch
        from typing import Dict

        class AttributeModule(torch.nn.Module):
            names: Dict[str, int]

            def __init__(self):
                super().__init__()
                self.names = {}

        m = AttributeModule()

    .. testcleanup::

        del AttributeModule
        del m

    Args:
        value: An initial value to be assigned to attribute.
        type: A Python type

    Returns:
        Returns `value`
"""
def _get_type_trace_db():
    """Return the module-level ``type_trace_db`` object.

    Returns:
        The same ``type_trace_db`` instance on every call; no copy is made.
    """
    # This is a private API. Use of this for external purposes is discouraged.
    return type_trace_db
# Gets a function from the name of a method on a type
def _get_function_from_type(cls, name):
    """Look up the attribute called ``name`` on ``cls``.

    Args:
        cls: The object (normally a class) to look the attribute up on.
        name: The attribute name as a string.

    Returns:
        ``getattr(cls, name)`` when the attribute exists, otherwise ``None``.
    """
    return getattr(cls, name, None)
# ScriptClasses must be new-style classes because we construct them using their
# __new__ method.