Shortcuts

TorchScript 语言参考

TorchScript 是一种静态类型的 Python 子集,可以直接编写(使用 @torch.jit.script 装饰器)或通过跟踪从 Python 代码自动生成。当使用跟踪时,代码会通过仅记录张量上的实际操作符并简单地执行和丢弃其他周围的 Python 代码,自动转换为这个 Python 子集。

当直接使用 @torch.jit.script 装饰器编写 TorchScript 时,程序员必须仅使用 TorchScript 支持的 Python 子集。本节文档记录了 TorchScript 中支持的内容,就好像它是一个独立语言的语言参考。本参考中未提及的任何 Python 特性都不属于 TorchScript。请参阅 内置函数 以获取可用 PyTorch 张量方法、模块和函数的完整参考。

作为 Python 的一个子集,任何有效的 TorchScript 函数也是一个有效的 Python 函数。这使得可以 禁用 TorchScript 并使用标准的 Python 工具(如 pdb)进行调试。反之则不然:有许多有效的 Python 程序不是有效的 TorchScript 程序。相反,TorchScript 专注于 Python 中用于表示 PyTorch 神经网络模型的特性。

类型

TorchScript 与完整 Python 语言之间最大的区别在于,TorchScript 仅支持表达神经网络模型所需的一小部分类型。特别是,TorchScript 支持:

类型

描述

张量

任意数据类型、维度或后端的PyTorch张量

Tuple[T0, T1, ..., TN]

包含子类型 T0, T1 等的元组(例如 Tuple[Tensor, Tensor]

布尔

一个布尔值

int

一个标量整数

浮点数

一个标量浮点数

字符串

一个字符串

列表[T]

所有成员都是类型 T 的列表

可选[T]

一个值,可以是 None 或类型 T

Dict[K, V]

一个键类型为K,值类型为V的字典。只有strintfloat可以作为键类型。

T

一个 TorchScript 类

E

一个 TorchScript 枚举

NamedTuple[T0, T1, ...]

一个 collections.namedtuple 元组类型

Union[T0, T1, ...]

其中一个子类型 T0, T1 等。

与Python不同,TorchScript 函数中的每个变量必须具有单一的静态类型。 这使得优化 TorchScript 函数变得更加容易。

示例(类型不匹配)

import torch

@torch.jit.script
def an_error(x):
    if x:
        r = torch.rand(1)
    else:
        r = 4
    return r
Traceback (most recent call last):
  ...
RuntimeError: ...

类型不匹配:r 在真分支中被设置为类型 Tensor,在假分支中被设置为类型 int:
@torch.jit.script
def an_error(x):
    if x:
    ~~~~~
        r = torch.rand(1)
        ~~~~~~~~~~~~~~~~~
    else:
    ~~~~~
        r = 4
        ~~~~~ <--- 此处
    return r
并且在以下位置使用:
    else:
        r = 4
    return r
           ~ <--- 此处...

不支持的类型构造

TorchScript 不支持 typing 模块的所有功能和类型。其中一些是更基础的东西,不太可能在将来添加,而其他一些可能会在有足够用户需求的情况下优先添加。

来自 typing 模块的这些类型和特性在 TorchScript 中不可用。

项目

描述

typing.Any

typing.Any 目前正在开发中,但尚未发布

typing.NoReturn

未实现

typing.Sequence

未实现

typing.Callable

未实现

typing.Literal

未实现

typing.ClassVar

未实现

typing.Final

这是支持的 模块属性 类属性注解,但不支持函数

typing.AnyStr

TorchScript 不支持 bytes,因此不使用此类型

typing.overload

typing.overload 目前正在开发中,但尚未发布

类型别名

未实现

名义类型与结构子类型

名义类型正在开发中,但结构类型尚未实现

新类型

不太可能实现

泛型

不太可能实现

此文档中未明确列出的 typing 模块的任何其他功能均不受支持。

默认类型

默认情况下,TorchScript 函数的所有参数都被假定为 Tensor。 要指定 TorchScript 函数的参数为其他类型,可以使用上述列出的类型进行 MyPy 风格的类型注解。

import torch

@torch.jit.script
def foo(x, tup):
    # 类型: (int, Tuple[Tensor, Tensor]) -> Tensor
    t0, t1 = tup
    return t0 + t1 + x

print(foo(3, (torch.rand(3), torch.rand(3))))

注意

也可以使用来自 typing 模块的 Python 3 类型提示来注释类型。

import torch
from typing import Tuple

@torch.jit.script
def foo(x: int, tup: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
    t0, t1 = tup
    return t0 + t1 + x

print(foo(3, (torch.rand(3), torch.rand(3))))

假设空列表为 List[Tensor] 和空字典为 Dict[str, Tensor]。要实例化其他类型的空列表或字典, 请使用 Python 3 类型提示

示例(Python 3 的类型注解):

import torch
import torch.nn as nn
from typing import Dict, List, Tuple

class EmptyDataStructures(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor) -> Tuple[List[Tuple[int, float]], Dict[str, int]]:
        # 这注释了列表为 `List[Tuple[int, float]]`
        my_list: List[Tuple[int, float]] = []
        for i in range(10):
            my_list.append((i, x.item()))

        my_dict: Dict[str, int] = {}
        return my_list, my_dict

x = torch.jit.script(EmptyDataStructures())

可选类型细化

TorchScript 会在 if 语句的条件中或通过 assert 检查时,对类型为 Optional[T] 的变量进行类型细化,当与 None 进行比较时。 编译器可以推断出与 andornot 结合的多个 None 检查。对于未明确编写的 if 语句的 else 块,也会进行细化。

必须在if语句的条件中进行None检查;将None检查赋值给一个变量并在if语句的条件中使用它,不会细化检查中变量的类型。 只有局部变量会被细化,像self.x这样的属性不会被细化,必须将其赋值给一个局部变量才能被细化。

示例(在参数和局部变量上细化类型):

import torch
import torch.nn as nn
from typing import Optional

class M(nn.Module):
    z: Optional[int]

    def __init__(self, z):
        super().__init__()
        # 如果 `z` 是 None,则无法推断其类型,因此必须
        # 指定(如上所示)
        self.z = z

    def forward(self, x, y, z):
        # type: (Optional[int], Optional[int], Optional[int]) -> int
        if x is None:
            x = 1
            x = x + 1

        # 通过赋值给局部变量来细化属性
        z = self.z
        if y is not None and z is not None:
            x = y + z

        # 通过 `assert` 进行细化
        assert z is not None
        x += z
        return x

module = torch.jit.script(M(2))
module = torch.jit.script(M(None))

TorchScript 类

警告

TorchScript 类支持是实验性的。目前它最适合用于简单的记录类类型(可以理解为带有方法的 NamedTuple)。

如果 Python 类用 @torch.jit.script 注解,它们可以在 TorchScript 中使用,类似于声明一个 TorchScript 函数的方式:

@torch.jit.script
class Foo:
  def __init__(self, x, y):
    self.x = x

  def aug_add_x(self, inc):
    self.x += inc

此子集受到限制:

  • 所有函数必须是有效的 TorchScript 函数(包括 __init__())。

  • 类必须是新式类,因为我们使用 __new__() 通过 pybind11 来构造它们。

  • TorchScript 类是静态类型的。成员只能在 __init__() 方法中通过赋值给 self 来声明。

    例如,在 __init__() 方法之外对 self 进行赋值:

    @torch.jit.script
    class Foo:
      def assign_x(self):
        self.x = torch.rand(2, 3)
    

    结果将是:

    运行时错误:
    尝试设置不存在的属性:x。你是否忘记在 __init__() 中初始化它?:
    def assign_x(self):
      self.x = torch.rand(2, 3)
      ~~~~~~~~~~~~~~~~~~~~~~~~ <--- 此处
    
  • 类体中只允许方法定义,不允许其他表达式。

  • 不支持继承或其他任何多态策略,除了从object继承以指定一个新式类。

定义类之后,可以在 TorchScript 和 Python 中像任何其他 TorchScript 类型一样互换使用:

# 声明一个 TorchScript 类
@torch.jit.script
class Pair:
  def __init__(self, first, second):
    self.first = first
    self.second = second

@torch.jit.script
def sum_pair(p):
  # 类型: (Pair) -> Tensor
  return p.first + p.second

p = Pair(torch.rand(2, 3), torch.rand(2, 3))
print(sum_pair(p))

TorchScript 枚举

Python枚举可以在TorchScript中使用,无需任何额外的注释或代码:

from enum import Enum


class Color(Enum):
    RED = 1
    GREEN = 2

@torch.jit.script
def enum_fn(x: Color, y: Color) -> bool:
    if x == Color.RED:
        return True

    return x == y

定义枚举后,它可以在 TorchScript 和 Python 中像任何其他 TorchScript 类型一样互换使用。枚举值的类型必须是 intfloatstr。所有值必须具有相同的类型;不支持枚举值的异构类型。

命名元组

collections.namedtuple 生成的类型可以在 TorchScript 中使用。

import torch
import collections

Point = collections.namedtuple('Point', ['x', 'y'])

@torch.jit.script
def total(point):
    # 类型: (Point) -> Tensor
    return point.x + point.y

p = Point(x=torch.rand(3), y=torch.rand(3))
print(total(p))

可迭代对象

一些函数(例如,zipenumerate)只能操作可迭代类型。 TorchScript 中的可迭代类型包括 Tensor、列表、元组、字典、字符串、 torch.nn.ModuleListtorch.nn.ModuleDict

表达式

支持以下 Python 表达式。

字面量

True
False
None
'string literals'
"string literals"
3  # 解释为整数
3.4  # 解释为浮点数

列表构造

假设空列表的类型为 List[Tensor]。 其他列表字面量的类型是从成员的类型派生而来的。 有关更多详细信息,请参阅 默认类型

[3, 4]
[]
[torch.rand(3), torch.rand(4)]

元组构造

(3, 4)
(3,)

字典构造

假设一个空的字典具有类型 Dict[str, Tensor]。 其他字典字面量的类型是从成员的类型派生而来的。 有关更多详细信息,请参见默认类型

{'hello': 3}
{}
{'a': torch.rand(3), 'b': torch.rand(4)}

变量

参见变量解析以了解变量是如何解析的。

my_variable_name

算术运算符

a + b
a - b
a * b
a / b
a ^ b
a @ b

比较运算符

a == b
a != b
a < b
a > b
a <= b
a >= b

逻辑运算符

a  b
a  b
 b

下标和切片

t[0]
t[-1]
t[0:2]
t[1:]
t[:1]
t[:]
t[0, 1]
t[0, 1:2]
t[0, :1]
t[-1, 1:, 0]
t[1:, -1, 0]
t[i:j, i]

函数调用

调用 内置函数

torch.rand(3, dtype=torch.int)

调用其他脚本函数:

import torch

@torch.jit.script
def foo(x):
    return x + 1

@torch.jit.script
def bar(x):
    return foo(x)

方法调用

调用内置类型(如张量)的方法:x.mm(y)

在模块上,方法必须先编译后才能被调用。TorchScript 编译器在编译其他方法时会递归地编译它看到的方法。默认情况下,编译从 forward 方法开始。任何被 forward 调用的方法都会被编译,以及这些方法调用的其他方法,依此类推。要在 forward 以外的其他方法开始编译,请使用 @torch.jit.export 装饰器(forward 隐式地被标记为 @torch.jit.export)。

直接调用子模块(例如 self.resnet(input))等同于调用其 forward 方法(例如 self.resnet.forward(input))。

import torch
import torch.nn as nn
import torchvision

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        means = torch.tensor([103.939, 116.779, 123.68])
        self.means = torch.nn.Parameter(means.resize_(1, 3, 1, 1))
        resnet = torchvision.models.resnet18()
        self.resnet = torch.jit.trace(resnet, torch.rand(1, 3, 224, 224))

    def helper(self, input):
        return self.resnet(input - self.means)

    def forward(self, input):
        return self.helper(input)

    # 由于模型中没有任何地方调用 `top_level_method`,编译器
    # 必须显式地被告知编译此方法
    @torch.jit.export
    def top_level_method(self, input):
        return self.other_helper(input)

    def other_helper(self, input):
        return input + 10

# `my_script_module` 将拥有编译后的方法 `forward`、`helper`、
# `top_level_method` 和 `other_helper`
my_script_module = torch.jit.script(MyModule())

三元表达式

x 如果 x > y 否则 y

类型转换

float(ten)
int(3.5)
bool(ten)
str(2)``

访问模块参数

self.my_parameter
self.my_submodule.my_parameter

语句

TorchScript 支持以下类型的语句:

简单赋值

a = b
a += b # 简写形式,表示 a = a + b,不会在 a 上就地操作
a -= b

模式匹配赋值

a, b = tuple_or_list
a, b, *c = a_tuple

多重赋值

a = b, c = tup

If 语句

```python
if a < 4:
    r = -a
elif a < 3:
    r = a + a
else:
    r = 3 * a
```

除了布尔值、浮点数、整数之外,张量也可以在条件语句中使用,并将被隐式转换为布尔值。

While 循环

a = 0
while a < 4:
    print(a)
    a += 1

使用range的for循环

x = 0
for i in range(10):
    x *= i

遍历元组的for循环

这些展开循环,为元组的每个成员生成一个主体。主体必须对每个成员正确进行类型检查。

tup = (3, torch.rand(4))
for x in tup:
    print(x)

遍历常量 nn.ModuleList

要在编译方法中使用nn.ModuleList,必须通过将属性名称添加到类型的__constants__列表中来将其标记为常量。对于遍历nn.ModuleList的循环,将在编译时展开循环体,并使用常量模块列表中的每个成员。

class SubModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(2))

    def forward(self, input):
        return self.weight + input

class MyModule(torch.nn.Module):
    __constants__ = ['mods']

    def __init__(self):
        super().__init__()
        self.mods = torch.nn.ModuleList([SubModule() for i in range(10)])

    def forward(self, v):
        for module in self.mods:
            v = module(v)
        return v


m = torch.jit.script(MyModule())

中断和继续

for i in range(5):
    if i == 1:
        continue
    if i == 3:
        break
    print(i)

返回

返回 a, b

变量解析

TorchScript 支持 Python 变量解析(即作用域)规则的一个子集。局部变量的行为与 Python 中相同,但有一个限制,即变量在函数中的所有路径上必须具有相同的类型。如果在 if 语句的不同分支上变量的类型不同,则在 if 语句结束后使用它是错误的。

同样地,如果一个变量只在函数中的某些路径上定义,则不允许使用该变量。

示例:

@torch.jit.script
def foo(x):
    if x < 0:
        y = 4
    print(y)
回溯(最近一次调用最后一次):
  ...
RuntimeError: ...

y 未在 false 分支中定义...
@torch.jit.script...
def foo(x):
    if x < 0:
    ~~~~~~~~~
        y = 4
        ~~~~~ <--- 此处
    print(y)
并且在使用时:
    if x < 0:
        y = 4
    print(y)
          ~ <--- 此处...

非局部变量在函数定义时在编译时解析为Python值。然后,这些值根据使用Python值中描述的规则转换为TorchScript值。

使用Python值

为了使编写 TorchScript 更加方便,我们允许脚本代码引用周围作用域中的 Python 值。例如,任何时候引用 torch,TorchScript 编译器实际上在函数声明时将其解析为 torch Python 模块。这些 Python 值并不是 TorchScript 的一等公民。相反,它们在编译时被去糖化(de-sugared)为 TorchScript 支持的基本类型。这取决于编译时引用的 Python 值的动态类型。本节描述了在 TorchScript 中访问 Python 值时使用的规则。

函数

TorchScript 可以调用 Python 函数。当逐步将模型转换为 TorchScript 时,这一功能非常有用。可以将模型逐个函数地迁移到 TorchScript,同时保留对 Python 函数的调用。这样,您可以在转换过程中逐步检查模型的正确性。

torch.jit.is_scripting()[源代码]

函数在编译时返回True,否则返回False。这在使用@unused装饰器时特别有用,可以在模型中保留尚未与TorchScript兼容的代码。 .. testcode:

import torch

@torch.jit.unused
def unsupported_linear_op(x):
    return x

def linear(x):
   if torch.jit.is_scripting():
      return torch.linear(x)
   else:
      return unsupported_linear_op(x)
Return type

bool

torch.jit.is_tracing()[源代码]

返回一个布尔值。

在跟踪中返回True(如果在使用torch.jit.trace跟踪代码期间调用了函数),否则返回False

Python模块上的属性查找

TorchScript 可以在模块上查找属性。内置函数torch.add 是通过这种方式访问的。这使得 TorchScript 能够调用在其他模块中定义的函数。

Python定义的常量

TorchScript 还提供了一种使用在 Python 中定义的常量的方法。 这些常量可以用于将超参数硬编码到函数中,或者定义通用常量。有两种方法可以指定一个 Python 值应被视为常量。

  1. 作为模块属性的值被假定为常量:

import math
import torch

@torch.jit.script
def fn():
    return math.pi
  1. ScriptModule 的属性可以通过使用 Final[T] 注释来标记为常量。

import torch
import torch.nn as nn

class Foo(nn.Module):
    # 来自 `typing_extensions` 模块的 `Final` 也可以使用
    a : torch.jit.Final[int]

    def __init__(self):
        super().__init__()
        self.a = 1 + 4

    def forward(self, input):
        return self.a + input

f = torch.jit.script(Foo())

支持的常量Python类型有

  • int

  • 浮点数

  • 布尔

  • torch.device

  • torch.layout

  • torch.dtype

  • 包含支持类型的元组

  • torch.nn.ModuleList 可以在 TorchScript 循环中使用

模块属性

torch.nn.Parameter 包装器和 register_buffer 可以用于将张量分配给模块。分配给编译模块的其他值,如果可以推断其类型,则将被添加到编译模块中。TorchScript 中所有可用的 类型 都可以用作模块属性。张量属性在语义上与缓冲区相同。空列表和字典的类型以及 None 值无法推断,必须通过 PEP 526 风格的 类注释来指定。如果无法推断类型且未明确注释,则不会将其作为属性添加到生成的 ScriptModule 中。

示例:

from typing import List, Dict

class Foo(nn.Module):
    # `words` 被初始化为一个空列表,因此必须指定其类型
    words: List[str]

    # 如果 `a_dict`(如下)不为空,类型可能会被推断出来,但此注释确保 `some_dict` 将被转换为正确的类型
    some_dict: Dict[str, int]

    def __init__(self, a_dict):
        super().__init__()
        self.words = []
        self.some_dict = a_dict

        # `int` 类型可以被推断出来
        self.my_int = 10

    def forward(self, input):
        # type: (str) -> int
        self.words.append(input)
        return self.some_dict[input] + self.my_int

f = torch.jit.script(Foo({'hi': 2}))