TorchScript¶
TorchScript 是一种从 PyTorch 代码创建可序列化和可优化模型的方法。 任何 TorchScript 程序都可以从 Python 进程中保存,并在没有 Python 依赖的进程中加载。
我们提供了工具,可以逐步将模型从一个纯Python程序过渡到一个可以在Python之外独立运行的TorchScript程序,例如在独立的C++程序中。这使得可以在PyTorch中使用熟悉的Python工具训练模型,然后通过TorchScript将模型导出到生产环境中,因为在性能和多线程方面,Python程序可能存在不利因素。
有关TorchScript的温和介绍,请参阅TorchScript简介教程。
有关将 PyTorch 模型转换为 TorchScript 并在 C++ 中运行的端到端示例,请参阅 在 C++ 中加载 PyTorch 模型教程。
创建 TorchScript 代码¶
脚本 |
编写函数脚本。 |
trace |
跟踪一个函数并返回一个可执行的或 |
script_if_tracing |
在跟踪期间首次调用时编译 |
trace_module |
跟踪一个模块并返回一个可执行的 |
fork |
创建一个异步任务执行func,并引用该执行结果的值。 |
wait |
强制完成一个 torch.jit.Future[T] 异步任务,返回任务的结果。 |
ScriptModule |
用于C++ torch::jit::Module的包装器,包含方法、属性和参数。 |
ScriptFunction |
功能上等同于一个 |
freeze |
冻结 ScriptModule、内联子模块和属性为常量。 |
optimize_for_inference |
执行一系列优化过程,以优化模型以用于推理目的。 |
enable_onednn_fusion |
启用或禁用基于参数enabled的onednn JIT融合。 |
onednn_fusion_enabled |
返回是否启用了onednn JIT融合。 |
set_fusion_strategy |
设置融合过程中可以发生的专业化和数量。 |
strict_fusion |
如果在推理中没有融合所有节点,或在训练中没有进行符号化微分,则给出错误。 |
保存 |
保存此模块的离线版本,以便在单独的进程中使用。 |
加载 |
加载之前使用 |
忽略 |
这个装饰器向编译器指示应忽略某个函数或方法,并将其保留为Python函数。 |
未使用 |
这个装饰器向编译器指示应忽略一个函数或方法,并将其替换为引发异常。 |
接口 |
装饰以注释不同类型的类或模块。 |
isinstance |
在TorchScript中提供容器类型细化。 |
属性 |
此方法是一个透传函数,返回值,主要用于向TorchScript编译器指示左侧表达式是类型为类型的类实例属性。 |
annotate |
用于在TorchScript编译器中指定the_value的类型。 |
混合追踪和脚本¶
在许多情况下,跟踪或脚本化是转换模型为TorchScript的更简单方法。 跟踪和脚本化可以组合使用,以适应模型部分的特定需求。
脚本化函数可以调用跟踪的函数。这在需要围绕简单的前馈模型使用控制流时特别有用。例如,序列到序列模型的束搜索通常会以脚本形式编写,但可以调用使用跟踪生成的编码器模块。
示例(在脚本中调用跟踪函数):
import torch
def foo(x, y):
return 2 * x + y
traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3)))
@torch.jit.script
def bar(x):
return traced_foo(x, x)
跟踪的函数可以调用脚本函数。这在模型的大部分只是一个前馈网络,但其中一小部分需要一些控制流时非常有用。在跟踪函数调用的脚本函数内部,控制流被正确地保留。
示例(在跟踪函数中调用脚本函数):
import torch
@torch.jit.script
def foo(x, y):
if x.max() > y.max():
r = x
else:
r = y
return r
def bar(x, y, z):
return foo(x, y) + z
traced_bar = torch.jit.trace(bar, (torch.rand(3), torch.rand(3), torch.rand(3)))
这种组合也适用于 nn.Module,它可用于通过追踪生成一个子模块,该子模块可以从脚本模块的方法中调用。
示例(使用跟踪模块):
import torch
import torchvision
class MyScriptModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.means = torch.nn.Parameter(torch.tensor([103.939, 116.779, 123.68])
.resize_(1, 3, 1, 1))
self.resnet = torch.jit.trace(torchvision.models.resnet18(),
torch.rand(1, 3, 224, 224))
def forward(self, input):
return self.resnet(input - self.means)
my_script_module = torch.jit.script(MyScriptModule())
TorchScript 语言¶
TorchScript 是 Python 的一个静态类型子集,因此许多 Python 特性可以直接应用于 TorchScript。详情请参阅完整的 TorchScript 语言参考。
内置函数和模块¶
TorchScript 支持使用大多数 PyTorch 函数和许多 Python 内置函数。 请参阅 TorchScript 内置函数 以获取支持的函数的完整参考。
PyTorch 函数和模块¶
TorchScript 支持 PyTorch 提供的一部分张量和神经网络函数。大多数 Tensor 上的方法以及 torch 命名空间中的函数,torch.nn.functional 中的所有函数,以及 torch.nn 中的大多数模块在 TorchScript 中都受支持。
请参阅TorchScript 不支持的 PyTorch 构造以获取不支持的 PyTorch 函数和模块列表。
Python 函数和模块¶
Python 的许多 内置函数 在 TorchScript 中都得到了支持。
math 模块也得到了支持(详见 math 模块),但其他 Python 模块(内置或第三方)均不受支持。
Python 语言参考比较¶
有关支持的 Python 功能的完整列表,请参阅 Python 语言参考覆盖范围。
调试¶
禁用JIT以进行调试¶
- PYTORCH_JIT¶
设置环境变量 PYTORCH_JIT=0 将禁用所有脚本和追踪注解。如果在您的 TorchScript 模型中出现难以调试的错误,您可以使用此标志强制所有内容使用原生 Python 运行。由于使用此标志禁用了 TorchScript(脚本和追踪),您可以使用诸如 pdb 之类的工具来调试模型代码。例如:
@torch.jit.script
def scripted_fn(x : torch.Tensor):
for i in range(12):
x = x + x
return x
def fn(x):
x = torch.neg(x)
import pdb; pdb.set_trace()
return scripted_fn(x)
traced_fn = torch.jit.trace(fn, (torch.rand(4, 5),))
traced_fn(torch.rand(3, 4))
使用 pdb 调试此脚本时,除了调用
@torch.jit.script 函数时,其他情况都可以正常工作。我们可以全局禁用
JIT,这样我们就可以将 @torch.jit.script
函数作为普通 Python 函数调用,而不会对其进行编译。如果上述脚本
被命名为 disable_jit_example.py,我们可以这样调用它:
$ PYTORCH_JIT=0 python disable_jit_example.py
并且我们将能够像普通Python函数一样进入@torch.jit.script函数。要禁用特定函数的TorchScript编译器,请参阅@torch.jit.ignore。
检查代码¶
TorchScript 提供了一个代码美化器,适用于所有 ScriptModule 实例。这个美化器将脚本方法的代码解释为有效的 Python 语法。例如:
@torch.jit.script
def foo(len):
# 类型: (int) -> torch.Tensor
rv = torch.zeros(3, 4)
for i in range(len):
if i < 10:
rv = rv - 1.0
else:
rv = rv + 1.0
return rv
print(foo.code)
一个具有单个 forward 方法的 ScriptModule 将具有一个 code 属性,您可以使用它来检查 ScriptModule 的代码。
如果 ScriptModule 有多个方法,您需要访问方法本身的 .code,而不是模块。我们可以通过访问 .foo.code 来检查 ScriptModule 上名为 foo 的方法的代码。
上面的示例产生以下输出:
def foo(len: int) -> Tensor:
rv = torch.zeros([3, 4], dtype=None, layout=None, device=None, pin_memory=None)
rv0 = rv
for i in range(len):
if torch.lt(i, 10):
rv1 = torch.sub(rv0, 1., 1)
else:
rv1 = torch.add(rv0, 1., 1)
rv0 = rv1
return rv0
这是TorchScript对forward方法代码的编译。
您可以使用它来确保TorchScript(跟踪或脚本化)已正确捕获您的模型代码。
解释图表¶
TorchScript 还具有比代码美化器更底层的表示形式,即 IR 图。
TorchScript 使用静态单赋值(SSA)中间表示(IR)来表示计算。这种格式的指令由 ATen(PyTorch 的 C++ 后端)运算符和其他基本运算符组成,包括用于循环和条件语句的控制流运算符。例如:
@torch.jit.script
def foo(len):
# 类型: (int) -> torch.Tensor
rv = torch.zeros(3, 4)
for i in range(len):
if i < 10:
rv = rv - 1.0
else:
rv = rv + 1.0
return rv
print(foo.graph)
graph 遵循与 检查代码 部分中描述的关于 forward 方法查找相同的规则。
上面的示例脚本生成图表:
graph(%len.1 : int):
%24 : int = prim::Constant[value=1]()
%17 : bool = prim::Constant[value=1]() # test.py:10:5
%12 : bool? = prim::Constant()
%10 : Device? = prim::Constant()
%6 : int? = prim::Constant()
%1 : int = prim::Constant[value=3]() # test.py:9:22
%2 : int = prim::Constant[value=4]() # test.py:9:25
%20 : int = prim::Constant[value=10]() # test.py:11:16
%23 : float = prim::Constant[value=1]() # test.py:12:23
%4 : int[] = prim::ListConstruct(%1, %2)
%rv.1 : Tensor = aten::zeros(%4, %6, %6, %10, %12) # test.py:9:10
%rv : Tensor = prim::Loop(%len.1, %17, %rv.1) # test.py:10:5
block0(%i.1 : int, %rv.14 : Tensor):
%21 : bool = aten::lt(%i.1, %20) # test.py:11:12
%rv.13 : Tensor = prim::If(%21) # test.py:11:9
block0():
%rv.3 : Tensor = aten::sub(%rv.14, %23, %24) # test.py:12:18
-> (%rv.3)
block1():
%rv.6 : Tensor = aten::add(%rv.14, %23, %24) # test.py:14:18
-> (%rv.6)
-> (%17, %rv.13)
return (%rv)
以指令 %rv.1 : Tensor = aten::zeros(%4, %6, %6, %10, %12) # test.py:9:10 为例。
%rv.1 : Tensor意味着我们将输出赋值给一个(唯一的)名为rv.1的值,该值的类型为Tensor,并且我们不知道它的具体形状。aten::zeros是操作符(等同于torch.zeros),输入列表(%4, %6, %6, %10, %12)指定了作用域中哪些值应作为输入传递。内置函数如aten::zeros的模式可以在 内置函数 中找到。# test.py:9:10是生成此指令的原始源文件中的位置。在这种情况下,它是一个名为 test.py 的文件,在第9行,第10个字符处。
请注意,运算符也可以有相关的块,即prim::Loop和prim::If运算符。在图的打印输出中,这些运算符的格式是为了反映它们等效的源代码形式,以便于轻松调试。
可以按照所示的方式检查图表,以确认由ScriptModule描述的计算是正确的,无论是自动方式还是手动方式,如下所述。
追踪器¶
追踪边缘案例¶
存在一些边缘情况,其中给定的Python函数/模块的跟踪可能无法代表底层代码。这些情况可能包括:
依赖于输入的控制流跟踪(例如,张量形状)
张量视图的就地操作跟踪(例如,赋值左侧的索引)
请注意,这些情况在未来实际上可能是可追踪的。
自动跟踪检查¶
自动捕获跟踪中的许多错误的一种方法是使用 check_inputs
在 torch.jit.trace() API 上。check_inputs 接受一个输入元组列表,
这些输入将用于重新跟踪计算并验证结果。例如:
def loop_in_traced_fn(x):
result = x[0]
for i in range(x.size(0)):
result = result * x[i]
return result
inputs = (torch.rand(3, 4, 5),)
check_inputs = [(torch.rand(4, 5, 6),), (torch.rand(2, 3, 4),)]
traced = torch.jit.trace(loop_in_traced_fn, inputs, check_inputs=check_inputs)
提供以下诊断信息:
错误:在不同的调用中图表不一致!
图表差异:
graph(%x : Tensor) {
%1 : int = prim::Constant[value=0]()
%2 : int = prim::Constant[value=0]()
%result.1 : Tensor = aten::select(%x, %1, %2)
%4 : int = prim::Constant[value=0]()
%5 : int = prim::Constant[value=0]()
%6 : Tensor = aten::select(%x, %4, %5)
%result.2 : Tensor = aten::mul(%result.1, %6)
%8 : int = prim::Constant[value=0]()
%9 : int = prim::Constant[value=1]()
%10 : Tensor = aten::select(%x, %8, %9)
- %result : Tensor = aten::mul(%result.2, %10)
+ %result.3 : Tensor = aten::mul(%result.2, %10)
? ++
%12 : int = prim::Constant[value=0]()
%13 : int = prim::Constant[value=2]()
%14 : Tensor = aten::select(%x, %12, %13)
+ %result : Tensor = aten::mul(%result.3, %14)
+ %16 : int = prim::Constant[value=0]()
+ %17 : int = prim::Constant[value=3]()
+ %18 : Tensor = aten::select(%x, %16, %17)
- %15 : Tensor = aten::mul(%result, %14)
? ^ ^
+ %19 : Tensor = aten::mul(%result, %18)
? ^ ^
- return (%15);
? ^
+ return (%19);
? ^
}
这条消息告诉我们,在第一次追踪和使用check_inputs追踪时,计算结果有所不同。实际上,loop_in_traced_fn函数体内的循环依赖于输入x的形状,因此当我们尝试使用不同形状的另一个x时,追踪结果会有所不同。
在这种情况下,可以使用
torch.jit.script() 来捕获这种数据依赖的控制流:
def fn(x):
result = x[0]
for i in range(x.size(0)):
result = result * x[i]
return result
inputs = (torch.rand(3, 4, 5),)
check_inputs = [(torch.rand(4, 5, 6),), (torch.rand(2, 3, 4),)]
scripted_fn = torch.jit.script(fn)
print(scripted_fn.graph)
# 打印(str(scripted_fn.graph).strip())
for input_tuple in [inputs] + check_inputs:
torch.testing.assert_close(fn(*input_tuple), scripted_fn(*input_tuple))
生成如下内容:
graph(%x : Tensor) {
%5 : bool = prim::Constant[value=1]()
%1 : int = prim::Constant[value=0]()
%result.1 : Tensor = aten::select(%x, %1, %1)
%4 : int = aten::size(%x, %1)
%result : Tensor = prim::Loop(%4, %5, %result.1)
block0(%i : int, %7 : Tensor) {
%10 : Tensor = aten::select(%x, %1, %i)
%result.2 : Tensor = aten::mul(%7, %10)
-> (%5, %result.2)
}
return (%result);
}
追踪器警告¶
跟踪器会对跟踪计算中的几种问题模式产生警告。例如,考虑一个包含对张量切片(视图)进行就地赋值的函数的跟踪:
def fill_row_zero(x):
x[0] = torch.rand(*x.shape[1:2])
return x
traced = torch.jit.trace(fill_row_zero, (torch.rand(3, 4),))
print(traced.graph)
生成几个警告和一个简单返回输入的图形:
fill_row_zero.py:4: TracerWarning: 在跟踪就地操作符 copy_ 时,有 2 个对正在修改的数据区域的实时引用(可能由于赋值)。这可能会导致跟踪不正确,因为所有其他也引用此数据的视图都不会在跟踪中反映此更改!另一方面,如果所有其他视图使用相同的内存块,但它们是分离的(例如,是 torch.split 的输出),这可能仍然是安全的。
x[0] = torch.rand(*x.shape[1:2])
fill_row_zero.py:6: TracerWarning: 跟踪函数的第 1 个输出与 Python 函数的相应输出不匹配。详细错误:
不在容差范围内 rtol=1e-05 atol=1e-05 在输入[0, 1] 处 (0.09115803241729736 vs. 0.6782537698745728) 并且在其他 3 个位置 (33.00%)
traced = torch.jit.trace(fill_row_zero, (torch.rand(3, 4),))
graph(%0 : Float(3, 4)) {
return (%0);
}
我们可以通过修改代码,不再使用就地更新,而是通过torch.cat来构建结果张量,从而解决这个问题:
def fill_row_zero(x):
x = torch.cat((torch.rand(1, *x.shape[1:2]), x[1:2]), dim=0)
return x
traced = torch.jit.trace(fill_row_zero, (torch.rand(3, 4),))
print(traced.graph)
常见问题¶
问:我想在GPU上训练模型并在CPU上进行推理。有哪些最佳实践?
首先将您的模型从GPU转换为CPU,然后保存它,如下所示:
cpu_model = gpu_model.cpu() sample_input_cpu = sample_input_gpu.cpu() traced_cpu = torch.jit.trace(cpu_model, sample_input_cpu) torch.jit.save(traced_cpu, "cpu.pt") traced_gpu = torch.jit.trace(gpu_model, sample_input_gpu) torch.jit.save(traced_gpu, "gpu.pt") # ... 稍后,在使用模型时: if use_gpu: model = torch.jit.load("gpu.pt") else: model = torch.jit.load("cpu.pt") model(input)这是推荐的,因为跟踪器可能会在特定设备上见证张量创建,因此对已经加载的模型进行转换可能会产生意外效果。在保存模型之前进行转换可以确保跟踪器具有正确的设备信息。
问:如何在 ScriptModule 上存储属性?
假设我们有一个模型如下:
import torch class Model(torch.nn.Module): def __init__(self): super().__init__() self.x = 2 def forward(self): return self.x m = torch.jit.script(Model())如果
Model被实例化,将会导致编译错误,因为编译器不知道x。有4种方法可以告知编译器ScriptModule上的属性:1.
nn.Parameter- 包装在nn.Parameter中的值将像在nn.Module中一样工作2.
register_buffer- 被register_buffer包装的值将像在nn.Module上一样工作。这相当于一个类型为Tensor的属性(见4)。3. 常量 - 将类成员注释为
Final(或在类定义级别将其添加到名为__constants__的列表中)将标记包含的名称 为常量。常量直接保存在模型的代码中。详情请参见 内置常量。4. 属性 - 可以添加为可变属性的值是支持的类型。大多数类型可以被推断,但有些可能需要指定,详情请参见模块属性。
问:我想追踪模块的方法,但我一直遇到这个错误:
RuntimeError: 无法 插入 一个 需要 梯度 的 张量 作为 常量。考虑 将其 设为 参数 或 输入,或者 分离 梯度
这个错误通常意味着您正在追踪的方法使用了模块的参数,而您传递的是模块的方法而不是模块实例(例如
my_module_instance.forward对比my_module_instance)。
调用
trace时使用模块的方法会捕获模块参数(可能需要梯度)作为 常量。另一方面,使用模块的实例(例如
my_module)调用trace会创建一个新模块,并正确地将参数复制到新模块中,因此如果需要,它们可以累积梯度。要在模块上跟踪特定方法,请参见
torch.jit.trace_module
已知问题¶
如果你在使用 Sequential 与 TorchScript,某些 Sequential 子模块的输入可能会被错误地推断为 Tensor,即使它们被注释为其他类型。标准的解决方案是继承 nn.Sequential 并重新声明 forward 方法,以正确地输入类型。
附录¶
迁移到 PyTorch 1.2 递归脚本 API¶
本节详细介绍了PyTorch 1.2中TorchScript的更改。如果您是TorchScript的新手,可以跳过本节。PyTorch 1.2中TorchScript API有两个主要变化。
1. torch.jit.script 现在将尝试递归编译它遇到的函数、方法和类。一旦你调用 torch.jit.script,编译是“选择退出”,而不是“选择加入”。
2. torch.jit.script(nn_module_instance) 现在是创建
ScriptModule的首选方式,而不是继承自 torch.jit.ScriptModule。
这些变化结合起来,为将您的 nn.Module 转换为准备好在非 Python 环境中进行优化和执行的 ScriptModule 提供了更简单、更易于使用的 API。
新的用法如下所示:
import torch
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(1, 20, 5)
self.conv2 = nn.Conv2d(20, 20, 5)
def forward(self, x):
x = F.relu(self.conv1(x))
return F.relu(self.conv2(x))
my_model = Model()
my_scripted_model = torch.jit.script(my_model)
模块的
forward默认情况下会被编译。从forward调用的方法会按其在forward中使用的顺序进行延迟编译。要编译一个不是
forward的方法,并且该方法不是从forward调用的,请添加@torch.jit.export。要阻止编译器编译一个方法,添加
@torch.jit.ignore或@torch.jit.unused。@ignore保留了将方法作为对python的调用,并且
@unused将其替换为异常。@ignored不能被导出;@unused可以。大多数属性类型可以被推断,因此
torch.jit.Attribute不是必需的。对于空容器类型,使用PEP 526风格的类注解来标注它们的类型。常量可以用
Final类注解标记,而不是将成员名称添加到__constants__中。Python 3 类型提示可以用来替代
torch.jit.annotate
- As a result of these changes, the following items are considered deprecated and should not appear in new code:
The
@torch.jit.script_method装饰器继承自
torch.jit.ScriptModule的类The
torch.jit.Attribute包装类The
__constants__数组The
torch.jit.annotate函数
模块¶
警告
@torch.jit.ignore 注解的行为在 PyTorch 1.2 中发生了变化。在 PyTorch 1.2 之前,@ignore 装饰器用于使函数或方法可以从导出的代码中调用。要恢复此功能,请使用 @torch.jit.unused()。@torch.jit.ignore 现在等同于 @torch.jit.ignore(drop=False)。详情请参见 @torch.jit.ignore 和 @torch.jit.unused。
当传递给 torch.jit.script 函数时,torch.nn.Module 的数据会被复制到一个 ScriptModule,并且 TorchScript 编译器会编译该模块。
模块的 forward 默认会被编译。从 forward 调用的方法是按它们在 forward 中使用的顺序进行延迟编译的,以及任何 @torch.jit.export 方法。
- torch.jit.export(fn)[源代码]¶
这个装饰器表示一个在
nn.Module上的方法被用作进入一个ScriptModule的入口点,并且应该被编译。forward隐式地被认为是入口点,因此不需要这个装饰器。 从forward调用的函数和方法在编译器看到时会被编译,因此它们也不需要这个装饰器。示例(在方法上使用
@torch.jit.export):import torch import torch.nn as nn class MyModule(nn.Module): def implicitly_compiled_method(self, x): return x + 99 # `forward` 被隐式地用 `@torch.jit.export` 装饰, # 所以在这里添加它不会有任何效果 def forward(self, x): return x + 10 @torch.jit.export def another_forward(self, x): # 当编译器看到这个调用时,它将编译 # `implicitly_compiled_method` return self.implicitly_compiled_method(x) def unused_method(self, x): return x - 20 # `m` 将包含编译的方法: # `forward` # `another_forward` # `implicitly_compiled_method` # `unused_method` 不会被编译,因为它没有从 # 任何编译的方法中调用,也没有用 `@torch.jit.export` 装饰 m = torch.jit.script(MyModule())
函数¶
函数变化不大,如果需要,它们可以用@torch.jit.ignore或torch.jit.unused进行装饰。
# 与PyTorch 1.2之前的行为相同
@torch.jit.script
def some_fn():
return 2
# 将函数标记为忽略,如果没有
# 任何调用它,则此操作无效
@torch.jit.ignore
def some_fn2():
return 2
# 与ignore一样,如果没有调用它,则它没有效果。
# 如果在脚本中调用它,则会被替换为一个异常。
@torch.jit.unused
def some_fn3():
import pdb; pdb.set_trace()
return 4
# 没有任何作用,这个函数已经是
# 主要的入口点
@torch.jit.export
def some_fn4():
return 2
TorchScript 类¶
警告
TorchScript 类支持是实验性的。目前它最适合用于简单的记录类类型(可以理解为带有方法的 NamedTuple)。
用户定义的TorchScript 类中的所有内容默认都会被导出,如果需要,可以使用@torch.jit.ignore装饰函数。
属性¶
TorchScript 编译器需要知道 模块属性 的类型。大多数类型可以从成员的值推断出来。空列表和字典无法推断其类型,必须使用 PEP 526 风格的 类注解来标注其类型。如果一个类型无法推断且未明确注解,它将不会被添加为生成的 ScriptModule 的属性。
旧版 API:
from typing import Dict
import torch
class MyModule(torch.jit.ScriptModule):
def __init__(self):
super().__init__()
self.my_dict = torch.jit.Attribute({}, Dict[str, int])
self.my_int = torch.jit.Attribute(20, int)
m = MyModule()
新API:
from typing import Dict
class MyModule(torch.nn.Module):
my_dict: Dict[str, int]
def __init__(self):
super().__init__()
# 这种类型无法推断,必须指定
self.my_dict = {}
# 这里的属性类型被推断为 `int`
self.my_int = 20
def forward(self):
pass
m = torch.jit.script(MyModule())
常量¶
Final 类型构造函数可以用来将成员标记为常量。如果成员没有被标记为常量,它们将被复制到生成的 ScriptModule 作为属性。使用 Final 可以在值已知为固定的情况下提供优化的机会,并提供额外的类型安全性。
旧版 API:
class MyModule(torch.jit.ScriptModule):
__constants__ = ['my_constant']
def __init__(self):
super().__init__()
self.my_constant = 2
def forward(self):
pass
m = MyModule()
新API:
from typing import Final
class MyModule(torch.nn.Module):
my_constant: Final[int]
def __init__(self):
super().__init__()
self.my_constant = 2
def forward(self):
pass
m = torch.jit.script(MyModule())
变量¶
容器被假定为具有类型 Tensor 并且是非可选的(更多信息请参见 默认类型)。以前,torch.jit.annotate 用于告诉 TorchScript 编译器类型应该是什么。现在支持 Python 3 风格的类型提示。
import torch
from typing import Dict, Optional
@torch.jit.script
def make_dict(flag: bool):
x: Dict[str, int] = {}
x['hi'] = 2
b: Optional[int] = None
if flag:
b = 2
return x, b
融合后端¶
有几种融合后端可用于优化 TorchScript 执行。CPU 上的默认融合器是 NNC,它可以为 CPU 和 GPU 执行融合。GPU 上的默认融合器是 NVFuser,它支持更广泛的操作符,并展示了生成内核的吞吐量改进。有关使用和调试的更多详细信息,请参阅 NVFuser 文档。