torch.export¶
警告
此功能是处于积极开发中的原型,未来将会有重大变更。
概述¶
torch.export.export()
接受任意 Python 可调用对象(一个
torch.nn.Module
、一个函数或一个方法),并生成一个仅表示函数张量计算的前向(AOT)方式的跟踪图,该图随后可以以不同的输出执行或序列化。
import torch
from torch.export import export
class Mod(torch.nn.Module):
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
a = torch.sin(x)
b = torch.cos(y)
return a + b
example_args = (torch.randn(10, 10), torch.randn(10, 10))
exported_program: torch.export.ExportedProgram = export(
Mod(), args=example_args
)
print(exported_program)
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: f32[10, 10], arg1_1: f32[10, 10]):
# 代码: a = torch.sin(x)
sin: f32[10, 10] = torch.ops.aten.sin.default(arg0_1);
# 代码: b = torch.cos(y)
cos: f32[10, 10] = torch.ops.aten.cos.default(arg1_1);
# 代码: return a + b
add: f32[10, 10] = torch.ops.aten.add.Tensor(sin, cos);
return (add,)
Graph signature: ExportGraphSignature(
parameters=[],
buffers=[],
user_inputs=['arg0_1', 'arg1_1'],
user_outputs=['add'],
inputs_to_parameters={},
inputs_to_buffers={},
buffers_to_mutate={},
backward_signature=None,
assertion_dep_token=None,
)
Range constraints: {}
torch.export
生成一个干净的中间表示(IR),具有以下不变性。有关IR的更多规格可以在这里找到。
健全性:它保证是原始程序的健全表示,并保持与原始程序相同的调用约定。
规范化:图中没有Python语义。原始程序中的子模块被内联以形成一个完全扁平化的计算图。
图属性:该图是纯函数式的,意味着它不包含具有副作用的操作,如突变或别名。它不会突变任何中间值、参数或缓冲区。
元数据:图中包含了在追踪过程中捕获的元数据,例如用户代码中的堆栈跟踪。
在内部,torch.export
利用了以下最新技术:
TorchDynamo (torch._dynamo) 是一个内部 API,它使用 CPython 的一个特性 称为帧评估 API 来安全地跟踪 PyTorch 图。这提供了大幅改进的图捕获体验,需要重写的部分大大减少,以便完全跟踪 PyTorch 代码。
AOT Autograd 提供了一个功能化的 PyTorch 图,并确保该图被分解/降低到 ATen 操作符集。
Torch FX (torch.fx) 是图的底层表示,允许基于 Python 的灵活转换。
现有框架¶
torch.compile()
也使用了与 torch.export
相同的 PT2 栈,但略有不同:
JIT 与 AOT:
torch.compile()
是一个 JIT 编译器,而 它并不旨在用于在部署之外生成编译后的工件。部分图捕获与全图捕获:当
torch.compile()
遇到模型中不可追踪的部分时,它将“图断开”并回退到在急切的 Python 运行时中运行程序。相比之下,torch.export
旨在获取 PyTorch 模型的全图表示,因此当遇到不可追踪的内容时,它将报错。由于torch.export
生成的全图与任何 Python 特性或运行时无关,因此该图可以被保存、加载并在不同的环境和语言中运行。可用性权衡:由于
torch.compile()
能够在遇到无法追踪的内容时回退到 Python 运行时,因此它更加灵活。torch.export
则要求用户提供更多信息或重写代码以使其可追踪。
与torch.fx.symbolic_trace()
相比,torch.export
使用 TorchDynamo 进行追踪,TorchDynamo 在 Python 字节码级别上操作,使其能够追踪不受 Python 运算符重载限制的任意 Python 构造。此外,torch.export
能够精细地跟踪张量元数据,因此对张量形状等条件的判断不会导致追踪失败。通常情况下,torch.export
预计能够在更多用户程序上工作,并生成更低级别的图(在 torch.ops.aten
运算符级别)。请注意,用户仍然可以在 torch.export
之前使用 torch.fx.symbolic_trace()
作为预处理步骤。
与torch.jit.script()
相比,torch.export
不捕获Python控制流或数据结构,但它支持更多的Python语言特性(因为它更容易全面覆盖Python字节码)。生成的图更简单,并且只包含直线控制流(除了显式控制流操作符)。
与torch.jit.trace()
相比,torch.export
是可靠的:它能够追踪执行整数计算的代码,并记录所有必要的条件,以证明特定追踪对于其他输入是有效的。
导出一个 PyTorch 模型¶
一个示例¶
主要的入口是通过 torch.export.export()
,它接受一个可调用对象(torch.nn.Module
、函数或方法)和样本输入,并将计算图捕获到一个 torch.export.ExportedProgram
中。例如:
import torch
from torch.export import export
# 简单的模块用于演示
class M(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.conv = torch.nn.Conv2d(
in_channels=3, out_channels=16, kernel_size=3, padding=1
)
self.relu = torch.nn.ReLU()
self.maxpool = torch.nn.MaxPool2d(kernel_size=3)
def forward(self, x: torch.Tensor, *, constant=None) -> torch.Tensor:
a = self.conv(x)
a.add_(constant)
return self.maxpool(self.relu(a))
example_args = (torch.randn(1, 3, 256, 256),)
example_kwargs = {"constant": torch.ones(1, 16, 256, 256)}
exported_program: torch.export.ExportedProgram = export(
M(), args=example_args, kwargs=example_kwargs
)
print(exported_program)
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: f32[16, 3, 3, 3], arg1_1: f32[16], arg2_1: f32[1, 3, 256, 256], arg3_1: f32[1, 16, 256, 256]):
# 代码: a = self.conv(x)
convolution: f32[1, 16, 256, 256] = torch.ops.aten.convolution.default(
arg2_1, arg0_1, arg1_1, [1, 1], [1, 1], [1, 1], False, [0, 0], 1
);
# 代码: a.add_(constant)
add: f32[1, 16, 256, 256] = torch.ops.aten.add.Tensor(convolution, arg3_1);
# 代码: return self.maxpool(self.relu(a))
relu: f32[1, 16, 256, 256] = torch.ops.aten.relu.default(add);
max_pool2d_with_indices = torch.ops.aten.max_pool2d_with_indices.default(
relu, [3, 3], [3, 3]
);
getitem: f32[1, 16, 85, 85] = max_pool2d_with_indices[0];
return (getitem,)
Graph signature: ExportGraphSignature(
parameters=['L__self___conv.weight', 'L__self___conv.bias'],
buffers=[],
user_inputs=['arg2_1', 'arg3_1'],
user_outputs=['getitem'],
inputs_to_parameters={
'arg0_1': 'L__self___conv.weight',
'arg1_1': 'L__self___conv.bias',
},
inputs_to_buffers={},
buffers_to_mutate={},
backward_signature=None,
assertion_dep_token=None,
)
Range constraints: {}
检查 ExportedProgram
,我们可以注意到以下内容:
The
torch.fx.Graph
包含了原始程序的计算图,以及原始代码的记录,以便于调试。该图仅包含
torch.ops.aten
操作符,可在 此处 找到 以及自定义操作符,并且完全功能化,没有任何就地操作符 例如torch.add_
。参数(权重和偏置到卷积)被提升为图的输入,导致图中没有
get_attr
节点,而这些节点在之前torch.fx.symbolic_trace()
的结果中是存在的。The
torch.export.ExportGraphSignature
模型化了输入和输出签名,并指定了哪些输入是参数。图中每个节点生成的张量的形状和数据类型已注明。例如,
convolution
节点将生成一个数据类型为torch.float32
且形状为 (1, 16, 256, 256) 的张量。
非严格导出¶
在 PyTorch 2.3 中,我们引入了一种新的追踪模式,称为 非严格模式。 它仍在进行强化,因此如果您遇到任何问题,请使用“oncall: export”标签将其提交到 Github。
在非严格模式下,我们使用Python解释器跟踪程序。 您的代码将完全按照在急切模式下执行;唯一的区别是 所有Tensor对象将被ProxyTensors替换,这些ProxyTensors将记录所有 它们的操作到一个图中。
在严格模式下,这是当前的默认模式,我们首先使用TorchDynamo对程序进行追踪,TorchDynamo是一个字节码分析引擎。TorchDynamo并不会实际执行你的Python代码。相反,它符号化地分析代码并基于结果构建一个图。这种分析使得torch.export能够提供更强的安全性保证,但并非所有Python代码都受支持。
一个可能需要使用非严格模式的例子是,当你遇到一个不支持的TorchDynamo特性,而这个特性可能不容易解决,并且你知道Python代码并不完全需要用于计算。例如:
import contextlib
import torch
class ContextManager():
def __init__(self):
self.count = 0
def __enter__(self):
self.count += 1
def __exit__(self, exc_type, exc_value, traceback):
self.count -= 1
class M(torch.nn.Module):
def forward(self, x):
with ContextManager():
return x.sin() + x.cos()
export(M(), (torch.ones(3, 3),), strict=False) # 非严格模式成功追踪
export(M(), (torch.ones(3, 3),)) # 严格模式失败,出现 torch._dynamo.exc.Unsupported: ContextManager 错误
在这个例子中,第一次调用使用非严格模式(通过strict=False
标志)成功追踪,而第二次调用使用严格模式(默认)导致失败,其中TorchDynamo无法支持上下文管理器。一个选项是重写代码(参见torch.expot的限制),但由于上下文管理器不影响模型中的张量计算,我们可以采用非严格模式的结果。
表达动态性¶
默认情况下,torch.export
会假设所有输入形状都是 静态 的,并针对这些维度专门化导出的程序。然而,某些维度,例如批次维度,可以是动态的,并且在每次运行时可能会有所不同。必须使用 torch.export.Dim()
API 创建这些维度,并通过 dynamic_shapes
参数将它们传递给 torch.export.export()
。示例如下:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: f32[32, 64], arg1_1: f32[32], arg2_1: f32[64, 128], arg3_1: f32[64], arg4_1: f32[32], arg5_1: f32[s0, 64], arg6_1: f32[s0, 128]):
# 代码: out1 = self.branch1(x1)
permute: f32[64, 32] = torch.ops.aten.permute.default(arg0_1, [1, 0]);
addmm: f32[s0, 32] = torch.ops.aten.addmm.default(arg1_1, arg5_1, permute);
relu: f32[s0, 32] = torch.ops.aten.relu.default(addmm);
# 代码: out2 = self.branch2(x2)
permute_1: f32[128, 64] = torch.ops.aten.permute.default(arg2_1, [1, 0]);
addmm_1: f32[s0, 64] = torch.ops.aten.addmm.default(arg3_1, arg6_1, permute_1);
relu_1: f32[s0, 64] = torch.ops.aten.relu.default(addmm_1); addmm_1 = None
# 代码: return (out1 + self.buffer, out2)
add: f32[s0, 32] = torch.ops.aten.add.Tensor(relu, arg4_1);
return (add, relu_1)
Graph signature: ExportGraphSignature(
parameters=[
'branch1.0.weight',
'branch1.0.bias',
'branch2.0.weight',
'branch2.0.bias',
],
buffers=['L__self___buffer'],
user_inputs=['arg5_1', 'arg6_1'],
user_outputs=['add', 'relu_1'],
inputs_to_parameters={
'arg0_1': 'branch1.0.weight',
'arg1_1': 'branch1.0.bias',
'arg2_1': 'branch2.0.weight',
'arg3_1': 'branch2.0.bias',
},
inputs_to_buffers={'arg4_1': 'L__self___buffer'},
buffers_to_mutate={},
backward_signature=None,
assertion_dep_token=None,
)
Range constraints: {s0: RangeConstraint(min_val=2, max_val=9223372036854775806)}
需要注意的一些额外事项:
通过
torch.export.Dim()
API 和dynamic_shapes
参数,我们指定每个输入的第一个维度为动态的。查看输入arg5_1
和arg6_1
,它们具有符号形状 (s0, 64) 和 (s0, 128),而不是我们作为示例输入传递的 (32, 64) 和 (32, 128) 形状的张量。s0
是一个符号,表示该维度可以是一定范围内的值。exported_program.range_constraints
描述了图中每个符号的范围。在这种情况下,我们看到s0
的范围是 [2, 无穷大]。由于技术原因,这些原因在这里很难解释,它们被假设为不等于 0 或 1。这并不是一个错误,也不一定意味着导出的程序不能用于维度 0 或 1。有关此主题的深入讨论,请参阅 The 0/1 Specialization Problem。
我们还可以指定输入形状之间更富有表现力的关系,例如一对形状可能相差一个,一个形状可能是另一个形状的两倍,或者一个形状是偶数。例如:
class M(torch.nn.Module):
def forward(self, x, y):
return x + y[1:]
x, y = torch.randn(5), torch.randn(6)
dimx = torch.export.Dim("dimx", min=3, max=6)
dimy = dimx + 1
exported_program = torch.export.export(
M(), (x, y), dynamic_shapes=({0: dimx}, {0: dimy}),
)
print(exported_program)
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: "f32[s0]", arg1_1: "f32[s0 + 1]"):
# 代码: return x + y[1:]
slice_1: "f32[s0]" = torch.ops.aten.slice.Tensor(arg1_1, 0, 1, 9223372036854775807); arg1_1 = None
add: "f32[s0]" = torch.ops.aten.add.Tensor(arg0_1, slice_1); arg0_1 = slice_1 = None
return (add,)
Graph signature: ExportGraphSignature(
input_specs=[
InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None),
InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg1_1'), target=None, persistent=None)
],
output_specs=[
OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None)]
)
Range constraints: {s0: ValueRanges(lower=3, upper=6, is_bool=False), s0 + 1: ValueRanges(lower=4, upper=7, is_bool=False)}
需要注意的几点:
通过为第一个输入指定
{0: dimx}
,我们看到第一个输入的结果形状现在是动态的,为[s0]
。现在通过为第二个输入指定{0: dimy}
,我们看到第二个输入的结果形状也是动态的。然而,因为我们表达了dimy = dimx + 1
,而不是arg1_1
的形状包含一个新的符号,我们看到它现在正用与arg0_1
中相同的符号s0
表示。我们可以看到dimy = dimx + 1
的关系通过s0 + 1
显示出来。查看范围约束,我们可以看到
s0
的范围是 [3, 6], 这是最初指定的,我们可以看到s0 + 1
的求解范围是 [4, 7]。
序列化¶
要保存 ExportedProgram
,用户可以使用 torch.export.save()
和
torch.export.load()
API。一种惯例是使用 .pt2
文件扩展名保存 ExportedProgram
。
一个例子:
import torch
import io
class MyModule(torch.nn.Module):
def forward(self, x):
return x + 10
exported_program = torch.export.export(MyModule(), torch.randn(5))
torch.export.save(exported_program, 'exported_program.pt2')
saved_exported_program = torch.export.load('exported_program.pt2')
专业领域¶
理解 torch.export
行为的一个关键概念是 静态 和 动态 值之间的区别。
一个动态值是指在每次运行时可以改变的值。这些行为类似于Python函数中的普通参数——你可以为参数传递不同的值,并期望你的函数能够正确处理。张量数据被视为动态的。
一个静态值是在导出时固定的值,不能在导出的程序执行之间改变。当在跟踪过程中遇到该值时,导出器会将其视为常量并将其硬编码到图中。
当执行一个操作时(例如 x + y
),并且所有输入都是静态的,那么操作的输出将直接硬编码到图中,并且该操作不会显示(即它将被常量折叠)。
当一个值被硬编码到图中时,我们说这个图已经针对该值进行了特化。
以下值是静态的:
输入张量形状¶
默认情况下,torch.export
会追踪程序并根据输入张量的形状进行特化,除非通过 dynamic_shapes
参数将某个维度指定为动态。这意味着如果存在依赖于形状的控制流,torch.export
将特化于给定样本输入所采用的分支。例如:
import torch
from torch.export import export
class Mod(torch.nn.Module):
def forward(self, x):
if x.shape[0] > 5:
return x + 1
else:
return x - 1
example_inputs = (torch.rand(10, 2),)
exported_program = export(Mod(), example_inputs)
print(exported_program)
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: f32[10, 2]):
add: f32[10, 2] = torch.ops.aten.add.Tensor(arg0_1, 1);
return (add,)
条件 (x.shape[0] > 5
) 没有出现在
ExportedProgram
中,因为示例输入具有静态形状 (10, 2)。由于 torch.export
专门处理输入的静态形状,因此 else 分支 (x - 1
) 将永远不会被访问。为了在跟踪图中保留基于张量形状的动态分支行为,
需要使用 torch.export.dynamic_dim()
来指定输入张量 (x.shape[0]
) 的维度为动态,并且源代码需要
重写。
请注意,作为模块状态一部分的张量(例如参数和缓冲区)总是具有静态形状。
Python 基本类型¶
torch.export
也针对 Python 原始类型进行了专门化处理,
例如 int
、float
、bool
和 str
。然而,它们也有动态变体,例如 SymInt
、SymFloat
和 SymBool
。
例如:
import torch
from torch.export import export
class Mod(torch.nn.Module):
def forward(self, x: torch.Tensor, const: int, times: int):
for i in range(times):
x = x + const
return x
example_inputs = (torch.rand(2, 2), 1, 3)
exported_program = export(Mod(), example_inputs)
print(exported_program)
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: f32[2, 2], arg1_1, arg2_1):
add: f32[2, 2] = torch.ops.aten.add.Tensor(arg0_1, 1);
add_1: f32[2, 2] = torch.ops.aten.add.Tensor(add, 1);
add_2: f32[2, 2] = torch.ops.aten.add.Tensor(add_1, 1);
return (add_2,)
因为整数是专门化的,torch.ops.aten.add.Tensor
操作都是使用硬编码的常量 1
进行计算,而不是 arg1_1
。如果用户在运行时为 arg1_1
传递了不同的值,例如 2,而不是导出时使用的 1,这将导致错误。
此外,for
循环中使用的 times
迭代器也通过 3 次重复的 torch.ops.aten.add.Tensor
调用在图中“内联”,并且输入 arg2_1
从未被使用。
Python 容器¶
Python 容器(List
、Dict
、NamedTuple
等)被认为是具有静态结构的。
torch.export 的局限性¶
图形断点¶
由于 torch.export
是从 PyTorch 程序中捕获计算图的一次性过程,它最终可能会遇到无法追踪的程序部分,因为几乎不可能支持追踪所有 PyTorch 和 Python 特性。在 torch.compile
的情况下,不支持的操作将导致“图中断”,并且不支持的操作将使用默认的 Python 评估运行。相比之下,torch.export
将要求用户提供额外的信息或重写部分代码以使其可追踪。由于追踪基于 TorchDynamo,它在 Python 字节码级别进行评估,因此与之前的追踪框架相比,所需的代码重写将显著减少。
当遇到图中断时,ExportDB 是一个很好的资源,可以了解支持和不支持的程序类型,以及如何重写程序以使其可追踪。
解决处理此图表中断问题的一个选项是使用 非严格导出
数据/形状依赖的控制流¶
在形状未被专门化的情况下,数据依赖的控制流(if
x.shape[0] > 2
)也可能遇到图断裂,因为跟踪编译器无法在不生成代码的情况下处理组合爆炸的路径数量。在这种情况下,用户需要使用特殊的控制流操作符重写他们的代码。目前,我们支持 torch.cond
来表达类似if-else的控制流(更多功能即将推出!)。
API参考¶
- torch.export.export(mod, args, kwargs=None, *, dynamic_shapes=None, strict=True, preserve_module_call_signature=())[源代码]¶
export()
接受任意 Python 可调用对象(一个 nn.Module、一个函数或一个方法)以及示例输入,并生成一个仅表示函数张量计算的前向(AOT)方式的跟踪图,该图随后可以与不同输入一起执行或序列化。跟踪图(1)在功能性 ATen 操作符集中生成标准化操作符(以及任何用户指定的自定义操作符),(2)消除了所有 Python 控制流和数据结构(在某些例外情况下),并且(3)记录了显示此标准化和控制流消除对未来输入有效的形状约束集。健全性保证
在追踪过程中,
export()
会记录用户程序和底层 PyTorch 操作符内核所做的与形状相关的假设。 只有当这些假设成立时,输出的ExportedProgram
才被认为是有效的。追踪对输入张量的形状(而非值)做出假设。 这些假设必须在图捕获时进行验证,以确保
export()
成功。具体来说:输入张量的静态形状假设会自动验证,无需额外努力。
对输入张量动态形状的假设需要通过使用
Dim()
API来显式指定,以构建动态维度,并通过dynamic_shapes
参数将它们与示例输入关联。
如果任何假设无法验证,将会引发致命错误。当这种情况发生时,错误消息将包含验证假设所需的规范修复建议。例如,
export()
可能会建议对动态维度dim0_x
的定义进行以下修复,该维度出现在与输入x
关联的形状中,之前定义为Dim("dim0_x")
:dim = Dim("dim0_x", max=5)
此示例意味着生成的代码要求输入
x
的维度 0 必须小于或等于 5 才有效。您可以检查动态维度定义的建议修复,然后将其逐字复制到您的代码中,而无需更改dynamic_shapes
参数到您的export()
调用中。- Parameters
mod (模块) – 我们将跟踪此模块的前向方法。
dynamic_shapes (可选[联合[字典[字符串, 任意], 元组[任意], 列表[任意]]]) –
一个可选参数,类型应为以下之一: 1) 一个从
f
的参数名称到其动态形状规范的字典, 2) 一个元组,按原始顺序指定每个输入的动态形状规范。 如果你在关键字参数上指定动态性,你需要按照原始函数签名中定义的顺序传递它们。张量参数的动态形状可以指定为以下两种方式之一: (1) 一个从动态维度索引到
Dim()
类型的字典,其中不需要在此字典中包含静态维度索引,但如果包含,则应映射到None;或 (2) 一个Dim()
类型或None的元组/列表,其中Dim()
类型对应于动态维度,静态维度由None表示。作为字典或张量元组/列表的参数通过使用包含的规范的映射或序列递归指定。严格(布尔值)——启用时(默认),导出函数将通过TorchDynamo跟踪程序,以确保生成的图的正确性。否则,导出的程序将不会验证图中所隐含的假设,并可能导致原始模型与导出模型之间的行为差异。当用户需要绕过跟踪器中的错误,或者只是希望逐步在模型中启用安全性时,这很有用。请注意,这不会影响生成的IR规范的不同,无论此处传递的值是什么,模型都将以相同的方式序列化。 警告:此选项是实验性的,使用此选项需自行承担风险。
- Returns
包含跟踪的可调用对象的
ExportedProgram
。- Return type
可接受的输入/输出类型
可接受的输入类型(对于
args
和kwargs
)和输出包括:原始类型,即
torch.Tensor
,int
,float
,bool
和str
。数据类,但必须先通过调用
register_dataclass()
进行注册。(嵌套) 数据结构,包括
dict
、list
、tuple
、namedtuple
和OrderedDict
,包含上述所有类型。
- torch.export.dynamic_shapes.dynamic_dim(t, index, debug_name=None)[源代码]¶
警告
(此功能已弃用。请参阅
Dim()
。)dynamic_dim()
构造一个_Constraint
对象,该对象描述了张量t
的维度index
的动态性。_Constraint
对象应传递给export()
的constraints
参数。- Parameters
t (torch.Tensor) – 具有动态维度大小的示例输入张量
索引 (int) – 动态维度的索引
- Returns
一个描述形状动态性的
_Constraint
对象。它可以传递给export()
,以便export()
不假设指定张量的静态大小,即保持其动态性作为符号大小,而不是根据示例追踪输入的大小进行特化。
具体来说,
dynamic_dim()
可以用来表达以下类型的动态性。维度的尺寸是动态且无界的:
t0 = torch.rand(2, 3) t1 = torch.rand(3, 4) # t0的第一个维度可以是动态大小,而不是总是静态大小2 constraints = [dynamic_dim(t0, 0)] ep = export(fn, (t0, t1), constraints=constraints)
维度的尺寸是动态的,具有下限:
t0 = torch.rand(10, 3) t1 = torch.rand(3, 4) # t0的第一个维度可以是动态大小,下限为5(包含) # t1的第二个维度可以是动态大小,下限为2(不包含) constraints = [ dynamic_dim(t0, 0) >= 5, dynamic_dim(t1, 1) > 2, ] ep = export(fn, (t0, t1), constraints=constraints)
维度的尺寸是动态的,具有上限:
t0 = torch.rand(10, 3) t1 = torch.rand(3, 4) # t0的第一个维度可以是动态大小,上限为16(包含16) # t1的第二个维度可以是动态大小,上限为8(不包含8) constraints = [ dynamic_dim(t0, 0) <= 16, dynamic_dim(t1, 1) < 8, ] ep = export(fn, (t0, t1), constraints=constraints)
维度的尺寸是动态的,它总是等于另一个动态维度的尺寸:
t0 = torch.rand(10, 3) t1 = torch.rand(3, 4) # t0的第二维和t1的第一维的大小总是相等的 constraints = [ dynamic_dim(t0, 1) == dynamic_dim(t1, 0), ] ep = export(fn, (t0, t1), constraints=constraints)
混合搭配上述所有类型,只要它们不表达冲突的要求
- torch.export.save(ep, f, *, extra_files=None, opset_version=None)[源代码]¶
警告
在积极开发中,保存的文件可能无法在较新版本的 PyTorch 中使用。
将一个
ExportedProgram
保存到一个类文件对象中。然后可以使用Python APItorch.export.load
加载。- Parameters
ep (ExportedProgram) – 要保存的导出程序。
f (Union[str, os.PathLike, io.BytesIO) – 一个类文件对象(必须实现 write 和 flush)或包含文件名的字符串。
extra_files (可选[字典[str, 任意]]) – 从文件名到内容的映射,这些内容将作为f的一部分存储。
示例:
import torch import io class MyModule(torch.nn.Module): def forward(self, x): return x + 10 ep = torch.export.export(MyModule(), (torch.randn(5),)) # 保存到文件 torch.export.save(ep, 'exported_program.pt2') # 保存到 io.BytesIO 缓冲区 buffer = io.BytesIO() torch.export.save(ep, buffer) # 保存带有额外文件 extra_files = {'foo.txt': b'bar'.decode('utf-8')} torch.export.save(ep, 'exported_program.pt2', extra_files=extra_files)
- torch.export.load(f, *, extra_files=None, expected_opset_version=None)[源代码]¶
警告
在积极开发中,保存的文件可能无法在较新版本的 PyTorch 中使用。
加载一个之前使用
torch.export.save
保存的ExportedProgram
。- Parameters
ep (ExportedProgram) – 要保存的导出程序。
f (Union[str, os.PathLike, io.BytesIO) – 一个类文件对象(必须实现 write 和 flush 方法)或包含文件名的字符串。
extra_files (可选[字典[str, 任意]]) – 在此映射中给出的额外文件名将被加载,并且它们的内容将被存储在提供的映射中。
expected_opset_version (可选[字典[str, int]]) – 一个将操作集名称映射到预期操作集版本的字典
- Returns
一个
ExportedProgram
对象- Return type
示例:
import torch import io # 从文件加载 ExportedProgram ep = torch.export.load('exported_program.pt2') # 从 io.BytesIO 对象加载 ExportedProgram with open('exported_program.pt2', 'rb') as f: buffer = io.BytesIO(f.read()) buffer.seek(0) ep = torch.export.load(buffer) # 加载带有额外文件的 ExportedProgram extra_files = {'foo.txt': ''} # 值将被数据替换 ep = torch.export.load('exported_program.pt2', extra_files=extra_files) print(extra_files['foo.txt']) print(ep(torch.randn(5)))
- torch.export.register_dataclass(cls, *, serialized_type_name=None)[源代码]¶
将一个数据类注册为
torch.export.export()
的有效输入/输出类型。- Parameters
示例:
@dataclass class InputDataClass: feature: torch.Tensor bias: int class OutputDataClass: res: torch.Tensor torch.export.register_dataclass(InputDataClass) torch.export.register_dataclass(OutputDataClass) def fn(o: InputDataClass) -> torch.Tensor: res = res=o.feature + o.bias return OutputDataClass(res=res) ep = torch.export.export(fn, (InputDataClass(torch.ones(2, 2), 1), )) print(ep)
- torch.export.dynamic_shapes.Dim(name, *, min=None, max=None)[源代码]¶
Dim()
构造了一个类似于具有范围的命名符号整数的类型。 它可以用于描述动态张量维度的多个可能值。 请注意,同一张量或不同张量的不同动态维度可以由相同的类型描述。
- class torch.export.ExportedProgram(root, graph, graph_signature, state_dict, range_constraints, module_call_graph, example_inputs=None, verifier=None, tensor_constants=None, constants=None)[源代码]¶
来自
export()
的程序包。它包含一个表示张量计算的torch.fx.Graph
,一个包含所有提升参数和缓冲区张量值的 state_dict,以及各种元数据。你可以像原始的可调用对象一样调用一个导出的程序,使用与
export()
相同的调用约定。要对图进行转换,请使用
.module
属性来访问 一个torch.fx.GraphModule
。然后,您可以使用 FX 转换 来重写图。之后,您可以再次使用export()
来构建一个正确的 ExportedProgram。- run_decompositions(decomp_table=None)[源代码]¶
对导出的程序运行一系列分解,并返回一个新的导出程序。默认情况下,我们将运行Core ATen分解,以获取Core ATen操作集中的操作符。
目前,我们不分解联合图。
- Return type
- class torch.export.ExportBackwardSignature(gradients_to_parameters: Dict[str, str], gradients_to_user_inputs: Dict[str, str], loss_output: str)[源代码]¶
- class torch.export.ExportGraphSignature(input_specs, output_specs)[源代码]¶
ExportGraphSignature
模型化了导出图的输入/输出签名, 这是一个具有更强不变性保证的 fx.Graph。导出图是功能性的,不会通过
getattr
节点访问图中的“状态”,如参数或缓冲区。相反,export()
保证参数、缓冲区和常量张量作为输入从图中提取出来。同样,对缓冲区的任何修改也不会包含在图中,而是将修改后的缓冲区的更新值建模为导出图的额外输出。所有输入和输出的顺序如下:
输入 = [*parameters_buffers_constant_tensors, *flattened_user_inputs] 输出 = [*mutated_inputs, *flattened_user_outputs]
例如,如果导出以下模块:
class CustomModule(nn.Module): def __init__(self): super(CustomModule, self).__init__() # 定义一个参数 self.my_parameter = nn.Parameter(torch.tensor(2.0)) # 定义两个缓冲区 self.register_buffer('my_buffer1', torch.tensor(3.0)) self.register_buffer('my_buffer2', torch.tensor(4.0)) def forward(self, x1, x2): # 在forward方法中使用参数、缓冲区和两个输入 output = (x1 + self.my_parameter) * self.my_buffer1 + x2 * self.my_buffer2 # 改变其中一个缓冲区(例如,将其增加1) self.my_buffer2.add_(1.0) # 原地加法 return output
生成的图表将是:
graph(): %arg0_1 := placeholder[target=arg0_1] %arg1_1 := placeholder[target=arg1_1] %arg2_1 := placeholder[target=arg2_1] %arg3_1 := placeholder[target=arg3_1] %arg4_1 := placeholder[target=arg4_1] %add_tensor := call_function[target=torch.ops.aten.add.Tensor](args = (%arg3_1, %arg0_1), kwargs = {}) %mul_tensor := call_function[target=torch.ops.aten.mul.Tensor](args = (%add_tensor, %arg1_1), kwargs = {}) %mul_tensor_1 := call_function[target=torch.ops.aten.mul.Tensor](args = (%arg4_1, %arg2_1), kwargs = {}) %add_tensor_1 := call_function[target=torch.ops.aten.add.Tensor](args = (%mul_tensor, %mul_tensor_1), kwargs = {}) %add_tensor_2 := call_function[target=torch.ops.aten.add.Tensor](args = (%arg2_1, 1.0), kwargs = {}) return (add_tensor_2, add_tensor_1)
生成的 ExportGraphSignature 将是:
ExportGraphSignature( input_specs=[ InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='arg0_1'), target='我的参数'), InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='arg1_1'), target='我的缓冲区1'), InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='arg2_1'), target='我的缓冲区2'), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg3_1'), target=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg4_1'), target=None) ], output_specs=[ OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='add_2'), target='我的缓冲区2'), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_1'), target=None) ] )
- class torch.export.ModuleCallSignature(inputs: List[Union[torch.export.graph_signature.TensorArgument, torch.export.graph_signature.SymIntArgument, torch.export.graph_signature.ConstantArgument, torch.export.graph_signature.CustomObjArgument]], outputs: List[Union[torch.export.graph_signature.TensorArgument, torch.export.graph_signature.SymIntArgument, torch.export.graph_signature.ConstantArgument, torch.export.graph_signature.CustomObjArgument]], in_spec: torch.utils._pytree.TreeSpec, out_spec: torch.utils._pytree.TreeSpec)[源代码]¶
- class torch.export.ModuleCallEntry(fqn: str, signature: Union[torch.export.exported_program.ModuleCallSignature, NoneType] = None)[源代码]¶
- class torch.export.graph_signature.InputSpec(kind: torch.export.graph_signature.InputKind, arg: Union[torch.export.graph_signature.TensorArgument, torch.export.graph_signature.SymIntArgument, torch.export.graph_signature.ConstantArgument, torch.export.graph_signature.CustomObjArgument], target: Union[str, NoneType], persistent: Union[bool, NoneType] = None)[源代码]¶
- class torch.export.graph_signature.OutputSpec(kind: torch.export.graph_signature.OutputKind, arg: Union[torch.export.graph_signature.TensorArgument, torch.export.graph_signature.SymIntArgument, torch.export.graph_signature.ConstantArgument, torch.export.graph_signature.CustomObjArgument], target: Union[str, NoneType])[源代码]¶
- class torch.export.graph_signature.ExportGraphSignature(input_specs, output_specs)[源代码]¶
ExportGraphSignature
模型化了导出图的输入/输出签名, 这是一个具有更强不变性保证的 fx.Graph。导出图是功能性的,不会通过
getattr
节点访问图中的“状态”,如参数或缓冲区。相反,export()
保证将参数、缓冲区和常量张量作为输入从图中提取出来。同样,对缓冲区的任何突变也不会包含在图中,而是将突变缓冲区的更新值建模为导出图的额外输出。所有输入和输出的顺序如下:
输入 = [*parameters_buffers_constant_tensors, *flattened_user_inputs] 输出 = [*mutated_inputs, *flattened_user_outputs]
例如,如果导出以下模块:
class CustomModule(nn.Module): def __init__(self): super(CustomModule, self).__init__() # 定义一个参数 self.my_parameter = nn.Parameter(torch.tensor(2.0)) # 定义两个缓冲区 self.register_buffer('my_buffer1', torch.tensor(3.0)) self.register_buffer('my_buffer2', torch.tensor(4.0)) def forward(self, x1, x2): # 在forward方法中使用参数、缓冲区和两个输入 output = (x1 + self.my_parameter) * self.my_buffer1 + x2 * self.my_buffer2 # 改变其中一个缓冲区(例如,将其增加1) self.my_buffer2.add_(1.0) # 原地加法 return output
生成的图表将是:
graph(): %arg0_1 := placeholder[target=arg0_1] %arg1_1 := placeholder[target=arg1_1] %arg2_1 := placeholder[target=arg2_1] %arg3_1 := placeholder[target=arg3_1] %arg4_1 := placeholder[target=arg4_1] %add_tensor := call_function[target=torch.ops.aten.add.Tensor](args = (%arg3_1, %arg0_1), kwargs = {}) %mul_tensor := call_function[target=torch.ops.aten.mul.Tensor](args = (%add_tensor, %arg1_1), kwargs = {}) %mul_tensor_1 := call_function[target=torch.ops.aten.mul.Tensor](args = (%arg4_1, %arg2_1), kwargs = {}) %add_tensor_1 := call_function[target=torch.ops.aten.add.Tensor](args = (%mul_tensor, %mul_tensor_1), kwargs = {}) %add_tensor_2 := call_function[target=torch.ops.aten.add.Tensor](args = (%arg2_1, 1.0), kwargs = {}) return (add_tensor_2, add_tensor_1)
生成的 ExportGraphSignature 将是:
ExportGraphSignature( input_specs=[ InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='arg0_1'), target='my_parameter'), InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='arg1_1'), target='my_buffer1'), InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='arg2_1'), target='my_buffer2'), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg3_1'), target=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg4_1'), target=None) ], output_specs=[ OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='add_2'), target='my_buffer2'), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_1'), target=None) ] )
- class torch.export.unflatten.InterpreterModule(graph)[源代码]¶
一个使用 torch.fx.Interpreter 来执行的模块,而不是通常 GraphModule 使用的代码生成。这提供了更好的堆栈跟踪信息,并使调试执行变得更加容易。
- torch.export.unflatten.unflatten(module, flat_args_adapter=None)[源代码]¶
解开一个导出的 ExportedProgram,生成一个与原始 eager 模块具有相同模块层次结构的模块。如果你尝试将
torch.export
与其他期望模块层次结构而不是torch.export
通常生成的扁平图的系统一起使用,这可能会很有用。注意
未展平模块的参数/关键字参数不一定与急切模块匹配,因此进行模块交换(例如
self.submod = new_mod
)不一定有效。如果您需要替换模块,您需要设置preserve_module_call_signature
参数为torch.export.export()
。- Parameters
模块 (导出的程序) – 要解平的导出程序。
flat_args_adapter (可选[FlatArgsAdapter]) – 如果输入的TreeSpec与导出的模块不匹配,则适配扁平参数。
- Returns
一个
UnflattenedModule
的实例,它具有与导出前原始 eager 模块相同的模块层次结构。- Return type
未展平模块