Shortcuts

torch.jit.script

torch.jit.script(obj, optimize=None, _frames_up=0, _rcb=None, example_inputs=None)[源代码]

编写函数脚本。

脚本化一个函数或 nn.Module 将检查源代码,使用 TorchScript 编译器将其编译为 TorchScript 代码,并返回一个 ScriptModuleScriptFunction。TorchScript 本身是 Python 语言的一个子集,因此并非所有 Python 中的功能都适用,但我们提供了足够多的功能来计算张量并执行依赖于控制的操作。有关完整指南,请参阅 TorchScript 语言参考

脚本化一个字典或列表会将其中数据复制到一个TorchScript实例中,该实例随后可以在Python和TorchScript之间通过引用传递,而无需复制开销。

torch.jit.script can be used as a function for modules, functions, dictionaries and lists

并作为装饰器 @torch.jit.script 用于 TorchScript 类 和函数。

Parameters
  • obj (可调用对象, , 或 nn.Module) – 要编译的nn.Module、函数、类类型、字典或列表。

  • example_inputs (Union[List[Tuple], Dict[Callable, List[Tuple]], None]) – 提供示例输入以注释函数的参数或 nn.Module

Returns

如果 objnn.Modulescript 返回一个 ScriptModule 对象。返回的 ScriptModule 将具有与原始 nn.Module 相同的子模块和参数集。如果 obj 是一个独立函数,将返回一个 ScriptFunction。如果 obj 是一个 dict,则 script 返回一个 torch._C.ScriptDict 的实例。如果 obj 是一个 list,则 script 返回一个 torch._C.ScriptList 的实例。

Scripting a function

The @torch.jit.script 装饰器将通过编译函数体来构造一个 ScriptFunction

示例(脚本化函数):

import torch

@torch.jit.script
def foo(x, y):
    if x.max() > y.max():
        r = x
    else:
        r = y
    return r

print(type(foo))  # torch.jit.ScriptFunction

# 查看编译后的图作为Python代码
print(foo.code)

# 使用TorchScript解释器调用函数
foo(torch.ones(2, 2), torch.ones(2, 2))
**Scripting a function using example_inputs

示例输入可以用于注释函数参数。

示例(在脚本编写前注释函数):

import torch

def test_sum(a, b):
    return a + b

# 将参数注释为整数
scripted_fn = torch.jit.script(test_sum, example_inputs=[(3, 4)])

print(type(scripted_fn))  # torch.jit.ScriptFunction

# 查看编译后的图作为Python代码
print(scripted_fn.code)

# 使用TorchScript解释器调用函数
scripted_fn(20, 100)
Scripting an nn.Module

默认情况下,脚本化一个 nn.Module 将编译 forward 方法并递归编译任何由 forward 调用的方法、子模块和函数。如果一个 nn.Module 仅使用 TorchScript 支持的功能,则无需对原始模块代码进行任何更改。script 将构建一个 ScriptModule,该模块包含原始模块的属性、参数和方法的副本。

示例(使用参数脚本编写简单模块):

import torch

class MyModule(torch.nn.Module):
    def __init__(self, N, M):
        super().__init__()
        # 这个参数将被复制到新的 ScriptModule
        self.weight = torch.nn.Parameter(torch.rand(N, M))

        # 当这个子模块被使用时,它将被编译
        self.linear = torch.nn.Linear(N, M)

    def forward(self, input):
        output = self.weight.mv(input)

        # 这将调用 `nn.Linear` 模块的 `forward` 方法,这将导致 `self.linear` 子模块在这里被编译为 `ScriptModule`
        output = self.linear(output)
        return output

scripted_module = torch.jit.script(MyModule(2, 3))

示例(使用跟踪子模块编写模块脚本):

import torch
import torch.nn as nn
import torch.nn.functional as F

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        # torch.jit.trace 生成 ScriptModule 的 conv1 和 conv2
        self.conv1 = torch.jit.trace(nn.Conv2d(1, 20, 5), torch.rand(1, 1, 16, 16))
        self.conv2 = torch.jit.trace(nn.Conv2d(20, 20, 5), torch.rand(1, 20, 16, 16))

    def forward(self, input):
        input = F.relu(self.conv1(input))
        input = F.relu(self.conv2(input))
        return input

scripted_module = torch.jit.script(MyModule())

要编译除forward之外的方法(并递归编译它调用的任何内容),请在方法上添加@torch.jit.export装饰器。要选择退出编译,请使用@torch.jit.ignore@torch.jit.unused

示例(模块中导出并忽略的方法):

import torch
import torch.nn as nn

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()

    @torch.jit.export
    def some_entry_point(self, input):
        return input + 10

    @torch.jit.ignore
    def python_only_fn(self, input):
        # 这个函数不会被编译,所以可以使用任何
        # Python API
        import pdb
        pdb.set_trace()

    def forward(self, input):
        if self.training:
            self.python_only_fn(input)
        return input * 99

scripted_module = torch.jit.script(MyModule())
print(scripted_module.some_entry_point(torch.randn(2, 2)))
print(scripted_module(torch.randn(2, 2)))

示例(使用 example_inputs 注释 nn.Module 的前向传播):

import torch
import torch.nn as nn
from typing import NamedTuple

class MyModule(NamedTuple):
result: List[int]

class TestNNModule(torch.nn.Module):
    def forward(self, a) -> MyModule:
        result = MyModule(result=a)
        return result

pdt_model = TestNNModule()

# 在急切模式下运行pdt_model,并使用提供的输入注释forward的参数
scripted_model = torch.jit.script(pdt_model, example_inputs={pdt_model: [([10, 20, ], ), ], })

# 使用实际输入运行scripted_model
print(scripted_model([20]))