torch.jit 的源代码
import warnings
from contextlib import contextmanager
from typing import Any, Iterator
import torch._C
# 这些是导入的,以便用户可以从 `torch.jit` 模块访问它们
from torch._jit_internal import (
_Await,
_drop,
_IgnoreContextManager,
_isinstance,
_overload,
_overload_method,
export,
Final,
Future,
ignore,
is_scripting,
unused,
)
from torch.jit._async import fork, wait
from torch.jit._await import _awaitable, _awaitable_nowait, _awaitable_wait
from torch.jit._decomposition_utils import _register_decomposition
from torch.jit._freeze import freeze, optimize_for_inference, run_frozen_optimizations
from torch.jit._fuser import (
fuser,
last_executed_optimized_graph,
optimized_execution,
set_fusion_strategy,
)
from torch.jit._ir_utils import _InsertPoint
from torch.jit._script import (
_ScriptProfile,
_unwrap_optional,
Attribute,
CompilationUnit,
interface,
RecursiveScriptClass,
RecursiveScriptModule,
script,
script_method,
ScriptFunction,
ScriptModule,
ScriptWarning,
)
from torch.jit._serialization import (
jit_module_from_flatbuffer,
load,
save,
save_jit_module_to_flatbuffer,
)
from torch.jit._trace import (
_flatten,
_get_trace_graph,
_script_if_tracing,
_unique_state_dict,
is_tracing,
ONNXTracedModule,
TopLevelTracedModule,
trace,
trace_module,
TracedModule,
TracerWarning,
TracingCheckError,
)
from torch.utils import set_module
__all__ = [
"Attribute",
"CompilationUnit",
"Error",
"Future",
"ScriptFunction",
"ScriptModule",
"annotate",
"enable_onednn_fusion",
"export",
"export_opnames",
"fork",
"freeze",
"ignore",
"isinstance",
"load",
"onednn_fusion_enabled",
"optimize_for_inference",
"save",
"script",
"script_if_tracing",
"set_fusion_strategy",
"strict_fusion",
"trace",
"trace_module",
"unused",
"wait",
]
# 为了向后兼容
_fork = fork
_wait = wait
_set_fusion_strategy = set_fusion_strategy
def export_opnames(m):
r"""
为 Script 模块生成新的字节码。
返回基于当前代码库的 Script 模块的操作列表。
如果你有一个 LiteScriptModule 并想获取当前存在的操作列表,请调用 _export_operator_list 代替。
"""
return torch._C._export_opnames(m._c)
# torch.jit.Error
Error = torch._C.JITException
set_module(Error, "torch.jit")
# 这并不完美,但在常见情况下有效
Error.__name__ = "Error"
Error.__qualname__ = "Error"
# 用于在 Python 中使用 annotate
[docs]def annotate(the_type, the_value):
"""用于在 TorchScript 编译器中给出 `the_value` 的类型。
此方法是一个透传函数,返回 `the_value`,用于提示 TorchScript 编译器 `the_value` 的类型。在 TorchScript 之外运行时,它是一个空操作。
虽然 TorchScript 可以为大多数 Python 表达式推断出正确的类型,但在某些情况下类型推断可能会出错,包括:
- 空容器,如 `[]` 和 `{}`,TorchScript 假设它们是 `Tensor` 的容器
- 可选类型,如 `Optional[T]`,但分配了一个类型为 `T` 的有效值,TorchScript 会假设它是类型 `T` 而不是 `Optional[T]`
请注意,`annotate()` 在 `torch.nn.Module` 子类的 `__init__` 方法中不起作用,因为它在急切模式下执行。要注释 `torch.nn.Module` 属性的类型,请使用 :meth:`~torch.jit.Annotate`。
示例:
.. testcode::
import torch
from typing import Dict
@torch.jit.script
def fn():
# 告诉 TorchScript 这个空字典是一个 (str -> int) 字典
# 而不是默认的 (str -> Tensor) 字典类型。
d = torch.jit.annotate(Dict[str, int], {})
# 如果没有上面的 `torch.jit.annotate`,下面的语句会因为类型不匹配而失败。
d["name"] = 20
.. testcleanup::
del fn
参数:
the_type: 应传递给 TorchScript 编译器的 Python 类型,作为 `the_value` 的类型提示
the_value: 要提示类型的值或表达式
返回:
`the_value` 作为返回值传递回来。
"""
return the_value
[docs]def script_if_tracing(fn):
"""
在第一次调用时编译 ``fn``。
``torch.jit.script`` 在第一次调用时由于许多编译器内置函数的延迟初始化而有一个不可忽略的启动时间。因此,你不应该在库代码中使用它。然而,你可能希望你的库中的某些部分在追踪时工作,即使它们使用控制流。在这些情况下,你应该使用 ``@torch.jit.