# Shortcuts

# Source code for torch.jit._script

"""TorchScript.

This module contains functionality to support the JIT's scripting frontend, notably:
    - torch.jit.script

This is not intended to be imported directly; please use the exposed
functionalities in `torch.jit`.
"""
import collections
import copy
import enum
import functools
import inspect
import pickle
import warnings
from typing import Any, Callable, Dict, List, Set, Tuple, Union

import torch
import torch._jit_internal as _jit_internal
from torch._classes import classes
from torch._jit_internal import _qualified_name
from torch.jit._builtins import _register_builtin
from torch.jit._fuser import _graph_for, _script_method_graph_for

from torch.jit._monkeytype_config import (
    JitTypeTraceConfig,
    JitTypeTraceStore,
    monkeytype_trace,
)
from torch.jit._recursive import (
    _compile_and_register_class,
    infer_methods_to_compile,
    ScriptMethodStub,
    wrap_cpp_module,
)
from torch.jit._state import (
    _enabled,
    _set_jit_function_cache,
    _set_jit_overload_cache,
    _try_get_jit_cached_function,
    _try_get_jit_cached_overloads,
)
from torch.jit.frontend import get_default_args, get_jit_class_def, get_jit_def
from torch.nn import Module
from torch.overrides import (
    has_torch_function,
    has_torch_function_unary,
    has_torch_function_variadic,
)
from torch.package import PackageExporter, PackageImporter
from torch.utils import set_module
from ._serialization import validate_map_location

# DB to hold all call traces from MonkeyType.
type_trace_db = JitTypeTraceStore()

# Expose torch._C.ScriptFunction under a module-level name and register it
# with the "torch.jit" module name via set_module.
ScriptFunction = torch._C.ScriptFunction
set_module(ScriptFunction, "torch.jit")
ScriptFunction.__doc__ = """
Functionally equivalent to a :class:`ScriptModule`, but represents a single
function and does not have any attributes or Parameters.
"""

# Attach the graph_for helpers to the script function / script method types.
ScriptFunction.graph_for = _graph_for  # type: ignore[attr-defined]
torch._C.ScriptMethod.graph_for = _script_method_graph_for  # type: ignore[attr-defined]


# Pickling a jit function is rejected by raising an error.
# This helps to avoid Python crashes on Python 3.9.5+ when protocol 0 or 1 is
# given as an argument.
def _reduce(cls):
    """Reject the pickling attempt unconditionally.

    Args:
        cls: The object being reduced; it is not inspected.

    Raises:
        pickle.PickleError: Always.
    """
    message = "ScriptFunction cannot be pickled"
    raise pickle.PickleError(message)


# Route the reduce hook of ScriptFunction through _reduce, which always raises
# pickle.PickleError, so an attempt to pickle a ScriptFunction fails with an
# explicit error.
ScriptFunction.__reduce__ = _reduce  # type: ignore[assignment]


# When the JIT is enabled, Attribute is a (value, type) record; otherwise it is
# a plain function that hands back the value it was given.
if _enabled:
    Attribute = collections.namedtuple("Attribute", ["value", "type"])
else:

    def Attribute(value, type):  # type: ignore[no-redef]
        return value


# The same documentation is attached to whichever definition was selected above.
Attribute.__doc__ = """
    This method is a pass-through function that returns `value`, mostly
    used to indicate to the TorchScript compiler that the left-hand side
    expression is a class instance attribute with type of `type`. Note that
    `torch.jit.Attribute` should only be used in `__init__` method of `jit.ScriptModule`
    subclasses.

    Though TorchScript can infer correct type for most Python expressions, there are some cases where
    type inference can be wrong, including:

    - Empty containers like `[]` and `{}`, which TorchScript assumes to be container of `Tensor`
    - Optional types like `Optional[T]` but assigned a valid value of type `T`, TorchScript would assume
      it is type `T` rather than `Optional[T]`

    In eager mode, it is simply a pass-through function that returns `value`
    without other implications.

    Example:

    .. testcode::

        import torch
        from typing import Dict

        class AttributeModule(torch.jit.ScriptModule):
            def __init__(self):
                super().__init__()
                self.foo = torch.jit.Attribute(0.1, float)

                # we should be able to use self.foo as a float here
                assert 0.0 < self.foo

                self.names_ages = torch.jit.Attribute({}, Dict[str, int])
                self.names_ages["someone"] = 20
                assert isinstance(self.names_ages["someone"], int)

        m = AttributeModule()
        # m will contain two attributes
        # 1. foo of type float
        # 2. names_ages of type Dict[str, int]

    .. testcleanup::

        del AttributeModule
        del m

    Note: it's now preferred to instead use type annotations instead of `torch.jit.Attribute`:

    .. testcode::

        import torch
        from typing import Dict

        class AttributeModule(torch.nn.Module):
            names: Dict[str, int]

            def __init__(self):
                super().__init__()
                self.names = {}

        m = AttributeModule()

    .. testcleanup::

        del AttributeModule
        del m

    Args:
        value: An initial value to be assigned to attribute.
        type: A Python type

    Returns:
        Returns `value`
"""


def _get_type_trace_db():
    """Return the module-level ``type_trace_db`` object."""
    # This is a private API. Use of this for external purposes is discouraged.
    return type_trace_db


# Gets a function from the name of a method on a type
def _get_function_from_type(cls, name):
    """Look up attribute ``name`` on ``cls``.

    Args:
        cls: The type (or any object) to look the attribute up on.
        name: The attribute name to fetch.

    Returns:
        The attribute if present, otherwise ``None``.
    """
    return getattr(cls, name, None)


# ScriptClasses must be new-style classes because we construct them using their
# __new__ method.