torch.dynamic-shape¶
cond_branch_class_method¶
Original source code:
import torch
from functorch.experimental.control_flow import cond

class MySubModule(torch.nn.Module):
    def foo(self, x):
        return x.cos()

    def forward(self, x):
        return self.foo(x)
class CondBranchClassMethod(torch.nn.Module):
    """
    The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules:
    - Both branches must take the same args, which must also match the branch args passed to cond.
    - Both branches must return a single tensor.
    - The returned tensor must have the same tensor metadata, e.g. shape and dtype.
    - Branch functions can be free functions, nested functions, lambdas, or class methods.
    - Branch functions can not have closure variables.
    - No in-place mutations on inputs or global variables.

    This example demonstrates using a class method in cond().

    NOTE: If `pred` is tested on a dim with batch size < 2, it will be specialized.
    """

    def __init__(self):
        super().__init__()
        self.subm = MySubModule()

    def bar(self, x):
        return x.sin()

    def forward(self, x):
        return cond(x.shape[0] <= 2, self.subm.forward, self.bar, [x])
Result:
ExportedProgram:
class GraphModule(torch.nn.Module):
    def forward(self, arg0_1: "f32[3]"):
        true_graph_0 = self.true_graph_0
        false_graph_0 = self.false_graph_0
        conditional = torch.ops.higher_order.cond(False, true_graph_0, false_graph_0, [arg0_1]); true_graph_0 = false_graph_0 = arg0_1 = None
        getitem: "f32[3]" = conditional[0]; conditional = None
        return (getitem,)

class <lambda>(torch.nn.Module):
    def forward(self, arg0_1: "f32[3]"):
        cos: "f32[3]" = torch.ops.aten.cos.default(arg0_1); arg0_1 = None
        return (cos,)

class <lambda>(torch.nn.Module):
    def forward(self, arg0_1: "f32[3]"):
        sin: "f32[3]" = torch.ops.aten.sin.default(arg0_1); arg0_1 = None
        return (sin,)
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='getitem'), target=None)])
Range constraints: {}
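Because the example input has a static shape of [3], the predicate `x.shape[0] <= 2` is evaluated at trace time and baked into the graph as the literal `False` seen above. A minimal sketch of producing and running such an ExportedProgram (the export call and the `torch.randn(3)` input are assumptions inferred from the `f32[3]` signature, not shown in the original listing):

import torch

# Hypothetical export call; the f32[3] input shape is inferred from the
# graph printout above.
ep = torch.export.export(CondBranchClassMethod(), (torch.randn(3),))
print(ep)                          # prints a GraphModule like the one above
out = ep.module()(torch.randn(3))  # always takes the false branch: x.sin()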
cond_branch_nested_function¶
Original source code:
import torch
from functorch.experimental.control_flow import cond

class CondBranchNestedFunction(torch.nn.Module):
    """
    The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules:
    - Both branches must take the same args, which must also match the branch args passed to cond.
    - Both branches must return a single tensor.
    - The returned tensor must have the same tensor metadata, e.g. shape and dtype.
    - Branch functions can be free functions, nested functions, lambdas, or class methods.
    - Branch functions can not have closure variables.
    - No in-place mutations on inputs or global variables.

    This example demonstrates using nested functions in cond().

    NOTE: If `pred` is tested on a dim with batch size < 2, it will be specialized.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        def true_fn(x):
            def inner_true_fn(y):
                return x + y

            return inner_true_fn(x)

        def false_fn(x):
            def inner_false_fn(y):
                return x - y

            return inner_false_fn(x)

        return cond(x.shape[0] < 10, true_fn, false_fn, [x])
Result:
ExportedProgram:
class GraphModule(torch.nn.Module):
    def forward(self, arg0_1: "f32[3]"):
        true_graph_0 = self.true_graph_0
        false_graph_0 = self.false_graph_0
        conditional = torch.ops.higher_order.cond(True, true_graph_0, false_graph_0, [arg0_1]); true_graph_0 = false_graph_0 = arg0_1 = None
        getitem: "f32[3]" = conditional[0]; conditional = None
        return (getitem,)

class <lambda>(torch.nn.Module):
    def forward(self, arg0_1: "f32[3]"):
        add: "f32[3]" = torch.ops.aten.add.Tensor(arg0_1, arg0_1); arg0_1 = None
        return (add,)

class <lambda>(torch.nn.Module):
    def forward(self, arg0_1: "f32[3]"):
        sub: "f32[3]" = torch.ops.aten.sub.Tensor(arg0_1, arg0_1); arg0_1 = None
        return (sub,)
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='getitem'), target=None)])
Range constraints: {}
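Here the predicate `x.shape[0] < 10` likewise folded to the literal `True`, since dim 0 was static at trace time. As a hedged sketch (not part of the original example), marking that dim dynamic with `torch.export.Dim` should keep the predicate symbolic, as in the cond_operands example below:

import torch
from torch.export import Dim

dim0 = Dim("dim0")  # hypothetical dynamic dim, not in the original listing
ep = torch.export.export(
    CondBranchNestedFunction(),
    (torch.randn(3),),
    dynamic_shapes={"x": {0: dim0}},
)
# Expected: the graph's cond() now receives Sym(s0 < 10) instead of True.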
cond_branch_nonlocal_variables¶
Original source code:
import torch
from functorch.experimental.control_flow import cond

class CondBranchNonlocalVariables(torch.nn.Module):
    """
    The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules:
    - Both branches must take the same args, which must also match the branch args passed to cond.
    - Both branches must return a single tensor.
    - The returned tensor must have the same tensor metadata, e.g. shape and dtype.
    - Branch functions can be free functions, nested functions, lambdas, or class methods.
    - Branch functions can not have closure variables.
    - No in-place mutations on inputs or global variables.

    This example demonstrates how to rewrite code to avoid capturing closure variables in branch functions.

    The code below will not work because capturing closure variables is not supported:
    ```
    my_tensor_var = x + 100
    my_primitive_var = 3.14

    def true_fn(y):
        nonlocal my_tensor_var, my_primitive_var
        return y + my_tensor_var + my_primitive_var

    def false_fn(y):
        nonlocal my_tensor_var, my_primitive_var
        return y - my_tensor_var - my_primitive_var

    return cond(x.shape[0] > 5, true_fn, false_fn, [x])
    ```

    NOTE: If `pred` is tested on a dim with batch size < 2, it will be specialized.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        my_tensor_var = x + 100
        my_primitive_var = 3.14

        def true_fn(x, y, z):
            return x + y + z

        def false_fn(x, y, z):
            return x - y - z

        return cond(
            x.shape[0] > 5,
            true_fn,
            false_fn,
            [x, my_tensor_var, torch.tensor(my_primitive_var)],
        )
Result:
ExportedProgram:
class GraphModule(torch.nn.Module):
    def forward(self, _lifted_tensor_constant0: "f32[]", arg0_1: "f32[6]"):
        add: "f32[6]" = torch.ops.aten.add.Tensor(arg0_1, 100)
        lift_fresh_copy: "f32[]" = torch.ops.aten.lift_fresh_copy.default(_lifted_tensor_constant0); _lifted_tensor_constant0 = None
        true_graph_0 = self.true_graph_0
        false_graph_0 = self.false_graph_0
        conditional = torch.ops.higher_order.cond(True, true_graph_0, false_graph_0, [arg0_1, add, lift_fresh_copy]); true_graph_0 = false_graph_0 = arg0_1 = add = lift_fresh_copy = None
        getitem: "f32[6]" = conditional[0]; conditional = None
        return (getitem,)

class <lambda>(torch.nn.Module):
    def forward(self, arg0_1: "f32[6]", arg1_1: "f32[6]", arg2_1: "f32[]"):
        add: "f32[6]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
        add_1: "f32[6]" = torch.ops.aten.add.Tensor(add, arg2_1); add = arg2_1 = None
        return (add_1,)

class <lambda>(torch.nn.Module):
    def forward(self, arg0_1: "f32[6]", arg1_1: "f32[6]", arg2_1: "f32[]"):
        sub: "f32[6]" = torch.ops.aten.sub.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
        sub_1: "f32[6]" = torch.ops.aten.sub.Tensor(sub, arg2_1); sub = arg2_1 = None
        return (sub_1,)
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.CONSTANT_TENSOR: 4>, arg=TensorArgument(name='_lifted_tensor_constant0'), target='_lifted_tensor_constant0', persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='getitem'), target=None)])
Range constraints: {}
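To make the restriction concrete, here is a sketch of the unsupported closure pattern from the docstring, written as a runnable counterexample (the module name is hypothetical, and the exact exception raised depends on the PyTorch version):

import torch
from functorch.experimental.control_flow import cond

class BadClosure(torch.nn.Module):  # hypothetical counterexample
    def forward(self, x):
        my_tensor_var = x + 100

        def true_fn(y):
            return y + my_tensor_var  # closes over my_tensor_var: unsupported

        def false_fn(y):
            return y - my_tensor_var

        return cond(x.shape[0] > 5, true_fn, false_fn, [x])

try:
    torch.export.export(BadClosure(), (torch.ones(6, 2),))
except Exception as e:  # exact error type/message varies across versions
    print(type(e).__name__)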
cond_operands¶
Original source code:
import torch
from torch.export import Dim
from functorch.experimental.control_flow import cond

x = torch.randn(3, 2)
y = torch.ones(2)
dim0_x = Dim("dim0_x")

class CondOperands(torch.nn.Module):
    """
    The operands passed to cond() must be:
    - a list of tensors
    - matching the arguments of `true_fn` and `false_fn`

    NOTE: If `pred` is tested on a dim with batch size < 2, it will be specialized.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        def true_fn(x, y):
            return x + y

        def false_fn(x, y):
            return x - y

        return cond(x.shape[0] > 2, true_fn, false_fn, [x, y])
Result:
ExportedProgram:
class GraphModule(torch.nn.Module):
    def forward(self, arg0_1: "f32[s0, 2]", arg1_1: "f32[2]"):
        sym_size_int: "Sym(s0)" = torch.ops.aten.sym_size.int(arg0_1, 0)
        gt: "Sym(s0 > 2)" = sym_size_int > 2; sym_size_int = None
        true_graph_0 = self.true_graph_0
        false_graph_0 = self.false_graph_0
        conditional = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [arg0_1, arg1_1]); gt = true_graph_0 = false_graph_0 = arg0_1 = arg1_1 = None
        getitem: "f32[s0, 2]" = conditional[0]; conditional = None
        return (getitem,)

class <lambda>(torch.nn.Module):
    def forward(self, arg0_1: "f32[s0, 2]", arg1_1: "f32[2]"):
        add: "f32[s0, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
        return (add,)

class <lambda>(torch.nn.Module):
    def forward(self, arg0_1: "f32[s0, 2]", arg1_1: "f32[2]"):
        sub: "f32[s0, 2]" = torch.ops.aten.sub.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
        return (sub,)
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg1_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='getitem'), target=None)])
Range constraints: {s0: ValueRanges(lower=2, upper=oo, is_bool=False)}
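The listing defines `dim0_x` but omits the export call; below is a plausible invocation, continuing the listing above, that is consistent with the symbolic `s0` in the graph and the `s0 >= 2` range constraint (the exact call is an assumption):

import torch

# Hypothetical export call: dim 0 of x is dynamic, so the predicate stays
# symbolic (Sym(s0 > 2)) instead of folding to a literal bool.
ep = torch.export.export(
    CondOperands(),
    (x, y),
    dynamic_shapes={"x": {0: dim0_x}, "y": None},
)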
cond_predicate¶
Original source code:
import torch
from functorch.experimental.control_flow import cond

class CondPredicate(torch.nn.Module):
    """
    The conditional statement (aka predicate) passed to cond() must be one of the following:
    - a torch.Tensor with a single element
    - a boolean expression

    NOTE: If `pred` is tested on a dim with batch size < 2, it will be specialized.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        pred = x.dim() > 2 and x.shape[2] > 10
        return cond(pred, lambda x: x.cos(), lambda y: y.sin(), [x])
Result:
ExportedProgram:
class GraphModule(torch.nn.Module):
    def forward(self, arg0_1: "f32[6, 4, 3]"):
        true_graph_0 = self.true_graph_0
        false_graph_0 = self.false_graph_0
        conditional = torch.ops.higher_order.cond(False, true_graph_0, false_graph_0, [arg0_1]); true_graph_0 = false_graph_0 = arg0_1 = None
        getitem: "f32[6, 4, 3]" = conditional[0]; conditional = None
        return (getitem,)

class <lambda>(torch.nn.Module):
    def forward(self, arg0_1: "f32[6, 4, 3]"):
        cos: "f32[6, 4, 3]" = torch.ops.aten.cos.default(arg0_1); arg0_1 = None
        return (cos,)

class <lambda>(torch.nn.Module):
    def forward(self, arg0_1: "f32[6, 4, 3]"):
        sin: "f32[6, 4, 3]" = torch.ops.aten.sin.default(arg0_1); arg0_1 = None
        return (sin,)
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='getitem'), target=None)])
Range constraints: {}
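For the `f32[6, 4, 3]` example input, both sides of the predicate are static, so it folds at trace time, which is why the graph above calls cond() with a literal `False`:

import torch

x = torch.randn(6, 4, 3)  # shape taken from the graph printout above
print(x.dim() > 2)                      # True
print(x.shape[2] > 10)                  # False (3 > 10)
print(x.dim() > 2 and x.shape[2] > 10)  # False, baked into the graph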
dynamic_shape_constructor¶
Original source code:
import torch
Result:
ExportedProgram:
class GraphModule(torch.nn.Module):
    def forward(self, arg0_1: "f32[3, 2]"):
        ones: "f32[6]" = torch.ops.aten.ones.default([6], device = device(type='cpu'), pin_memory = False)
        return (ones,)
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='ones'), target=None)])
Range constraints: {}
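The source listing above is truncated to its import. Judging from the exported graph, which constructs a fresh `ones` of shape [6] from an `f32[3, 2]` input, the module was presumably of roughly this form (a reconstruction; the class name and exact body are assumptions):

import torch

class DynamicShapeConstructor(torch.nn.Module):  # reconstructed, not verbatim
    def forward(self, x):
        # 2 * 3 = 6 matches the ones([6]) call in the graph above
        return torch.ones(x.shape[0] * 2)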
dynamic_shape_if_guard¶
Original source code:
import torch
Result:
ExportedProgram:
class GraphModule(torch.nn.Module):
    def forward(self, arg0_1: "f32[3, 2, 2]"):
        cos: "f32[3, 2, 2]" = torch.ops.aten.cos.default(arg0_1); arg0_1 = None
        return (cos,)
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='cos'), target=None)])
Range constraints: {}
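This listing is likewise truncated. The example name and the graph, which retains only a `cos` on the `f32[3, 2, 2]` input, suggest an `if` statement guarded on a static shape that specialized to one branch at trace time; a hedged reconstruction:

import torch

class DynamicShapeIfGuard(torch.nn.Module):  # reconstructed sketch
    def forward(self, x):
        # The guard is evaluated at trace time for a static [3, 2, 2] input,
        # so only the taken (cos) branch appears in the exported graph.
        if x.shape[0] == 3:
            return x.cos()
        return x.sin()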
dynamic_shape_map¶
Original source code:
import torch
from functorch.experimental.control_flow import map

class DynamicShapeMap(torch.nn.Module):
    """
    functorch map() maps a function over the first tensor dimension.
    """

    def __init__(self):
        super().__init__()

    def forward(self, xs, y):
        def body(x, y):
            return x + y

        return map(body, xs, y)
Result:
ExportedProgram:
class GraphModule(torch.nn.Module):
    def forward(self, arg0_1: "f32[3, 2]", arg1_1: "f32[2]"):
        body_graph_0 = self.body_graph_0
        map_impl = torch.ops.higher_order.map_impl(body_graph_0, [arg0_1], [arg1_1]); body_graph_0 = arg0_1 = arg1_1 = None
        getitem: "f32[3, 2]" = map_impl[0]; map_impl = None
        return (getitem,)

class <lambda>(torch.nn.Module):
    def forward(self, arg0_1: "f32[2]", arg1_1: "f32[2]"):
        add: "f32[2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
        return (add,)
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg1_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='getitem'), target=None)])
Range constraints: {}
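Semantically, `map(body, xs, y)` applies `body` to each slice of `xs` along dim 0 and stacks the results; an eager-mode equivalent of the exported graph above (a sketch, not the library implementation):

import torch

def map_reference(body, xs, y):
    # Eager reference: apply body to each dim-0 slice of xs, then restack.
    return torch.stack([body(x, y) for x in xs.unbind(0)])

xs, y = torch.randn(3, 2), torch.ones(2)
out = map_reference(lambda x, y: x + y, xs, y)  # shape [3, 2], matching f32[3, 2]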
dynamic_shape_round¶
Original source code:
import torch
from torch.export import Dim

x = torch.ones(3, 2)
dim0_x = Dim("dim0_x")

class DynamicShapeRound(torch.nn.Module):
    """
    Calling round() on dynamic shapes is not supported.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x[: round(x.shape[0] / 2)]
Result:
AssertionError:
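The listing defines `dim0_x` but omits the failing call; presumably the export marked dim 0 dynamic, making `x.shape[0] / 2` a symbolic value that `round()` cannot handle. A hedged sketch, continuing the listing above, of a call that would reproduce the error:

import torch

# Hypothetical failing export: with dim 0 dynamic, round(x.shape[0] / 2)
# operates on a symbolic value, which this version does not support.
try:
    torch.export.export(
        DynamicShapeRound(), (x,), dynamic_shapes={"x": {0: dim0_x}}
    )
except AssertionError as e:
    print(e)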
dynamic_shape_slicing¶
Original source code:
import torch
Result:
ExportedProgram:
class GraphModule(torch.nn.Module):
    def forward(self, arg0_1: "f32[3, 2]"):
        slice_1: "f32[1, 2]" = torch.ops.aten.slice.Tensor(arg0_1, 0, 0, 1); arg0_1 = None
        slice_2: "f32[1, 1]" = torch.ops.aten.slice.Tensor(slice_1, 1, 1, 9223372036854775807, 2); slice_1 = None
        return (slice_2,)
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='slice_2'), target=None)])
Range constraints: {}
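The source here is truncated as well. The two `aten.slice` calls (dim 0 over [0:1], dim 1 over [1::2], with 9223372036854775807 standing for an open end) correspond to indexing like the following reconstruction (names and exact spelling are assumptions; the bounds may also have been computed from `x.shape` and specialized):

import torch

class DynamicShapeSlicing(torch.nn.Module):  # reconstructed, not verbatim
    def forward(self, x):
        # x[:1] -> slice.Tensor(x, 0, 0, 1); x[..., 1::2] -> slice dim 1,
        # start 1, open end, step 2, matching the graph above.
        return x[:1, 1::2]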
dynamic_shape_view¶
Original source code:
import torch
Result:
ExportedProgram:
class GraphModule(torch.nn.Module):
    def forward(self, arg0_1: "f32[10, 10]"):
        view: "f32[10, 2, 5]" = torch.ops.aten.view.default(arg0_1, [10, 2, 5]); arg0_1 = None
        permute: "f32[10, 5, 2]" = torch.ops.aten.permute.default(view, [0, 2, 1]); view = None
        return (permute,)
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='permute'), target=None)])
Range constraints: {}
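This source is also truncated. The graph reshapes the `f32[10, 10]` input to `[10, 2, 5]` and then permutes dims 1 and 2, which corresponds to something like the following (a reconstruction; names are assumptions):

import torch

class DynamicShapeView(torch.nn.Module):  # reconstructed sketch
    def forward(self, x):
        # Split the last dim (10) into (2, 5), then swap the two new dims,
        # matching the view + permute pair in the graph above.
        new_shape = x.shape[:-1] + (2, 5)
        return x.view(new_shape).permute(0, 2, 1)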
list_contains¶
Original source code:
import torch
Result:
ExportedProgram:
class GraphModule(torch.nn.Module):
    def forward(self, arg0_1: "f32[3, 2]"):
        add: "f32[3, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg0_1); arg0_1 = None
        return (add,)
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None)])
Range constraints: {}
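Only `x + x` survives in this graph; the example name suggests list-membership tests on (static) shapes, which are evaluated at trace time and leave no nodes behind. A hedged reconstruction:

import torch

class ListContains(torch.nn.Module):  # reconstructed, not verbatim
    def forward(self, x):
        # Membership checks on static shapes fold away at trace time;
        # the assumed check below is True for the f32[3, 2] example input.
        assert x.shape[-1] in [2, 6]
        return x + x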
scalar_output¶
Original source code:
import torch
from torch.export import Dim

x = torch.ones(3, 2)
dim1_x = Dim("dim1_x")

class ScalarOutput(torch.nn.Module):
    """
    Returning scalar values from the graph is supported, in addition to Tensor outputs.
    Symbolic shapes are captured and rank is specialized.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x.shape[1] + 1
Result:
ExportedProgram:
class GraphModule(torch.nn.Module):
    def forward(self, arg0_1: "f32[3, s0]"):
        # No stacktrace found for following nodes
        sym_size_int: "Sym(s0)" = torch.ops.aten.sym_size.int(arg0_1, 1); arg0_1 = None
        add: "Sym(s0 + 1)" = sym_size_int + 1; sym_size_int = None
        return (add,)
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=SymIntArgument(name='add'), target=None)])
Range constraints: {s0: ValueRanges(lower=2, upper=oo, is_bool=False)}
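The listing defines `dim1_x` but omits the export call; below is a plausible invocation, continuing the listing above, consistent with the `f32[3, s0]` signature, followed by a call showing the scalar output (the calls are assumptions):

import torch

# Hypothetical export call: dim 1 of x is dynamic, producing the Sym(s0 + 1)
# scalar output seen in the graph above.
ep = torch.export.export(ScalarOutput(), (x,), dynamic_shapes={"x": {1: dim1_x}})
print(ep.module()(torch.ones(3, 5)))  # 6, a plain int rather than a Tensor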