python.闭包¶
cond_closed_over_variable¶
原始源代码:
import torch
from functorch.experimental.control_flow import cond
class CondClosedOverVariable(torch.nn.Module):
"""
torch.cond() 支持分支闭包任意变量。
"""
def forward(self, pred, x):
def true_fn(val):
return x * 2
def false_fn(val):
return x - 2
return cond(pred, true_fn, false_fn, [x + 1])
结果:
导出的程序:
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: "b8[]", arg1_1: "f32[3, 2]"):
真图_0 = self.真图_0
假图_0 = self.假图_0
条件 = torch.ops.higher_order.cond(arg0_1, 真图_0, 假图_0, [arg1_1]); arg0_1 = 真图_0 = 假图_0 = arg1_1 = None
获取项: "f32[3, 2]" = 条件[0]; 条件 = None
return (获取项,)
class (torch.nn.Module):
def forward(self, arg0_1: "f32[3, 2]"):
乘法: "f32[3, 2]" = torch.ops.aten.mul.Tensor(arg0_1, 2); arg0_1 = None
return (乘法,)
class (torch.nn.Module):
def forward(self, arg0_1: "f32[3, 2]"):
减法: "f32[3, 2]" = torch.ops.aten.sub.Tensor(arg0_1, 2); arg0_1 = None
return (减法,)
图签名: ExportGraphSignature(输入规格=[输入规格(种类=<输入种类.用户输入: 1>, 参数=张量参数(名称='arg0_1'), 目标=None, 持久=None), 输入规格(种类=<输入种类.用户输入: 1>, 参数=张量参数(名称='arg1_1'), 目标=None, 持久=None)], 输出规格=[输出规格(种类=<输出种类.用户输出: 1>, 参数=张量参数(名称='获取项'), 目标=None)])
范围约束: {}
嵌套函数¶
原始源代码:
import torch
结果:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: "f32[3, 2]", arg1_1: "f32[2]"):
add: "f32[3, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1)
sub: "f32[3, 2]" = torch.ops.aten.sub.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
add_1: "f32[3, 2]" = torch.ops.aten.add.Tensor(add, 1); add = None
mul: "f32[3, 2]" = torch.ops.aten.mul.Tensor(add_1, add_1); add_1 = None
add_2: "f32[3, 2]" = torch.ops.aten.add.Tensor(mul, sub); mul = sub = None
return (add_2,)
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg1_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_2'), target=None)])
Range constraints: {}