Shortcuts

python.闭包

cond_closed_over_variable

注意

标签: torch.cond, python.closure

支持级别:支持

原始源代码:

import torch

from functorch.experimental.control_flow import cond


class CondClosedOverVariable(torch.nn.Module):
    """
    torch.cond() 支持分支闭包任意变量。
    """

    def forward(self, pred, x):
        def true_fn(val):
            return x * 2

        def false_fn(val):
            return x - 2

        return cond(pred, true_fn, false_fn, [x + 1])

结果:

导出的程序:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: "b8[]", arg1_1: "f32[3, 2]"):
                真图_0 = self.真图_0
            假图_0 = self.假图_0
            条件 = torch.ops.higher_order.cond(arg0_1, 真图_0, 假图_0, [arg1_1]);  arg0_1 = 真图_0 = 假图_0 = arg1_1 = None
            获取项: "f32[3, 2]" = 条件[0];  条件 = None
            return (获取项,)

        class (torch.nn.Module):
            def forward(self, arg0_1: "f32[3, 2]"):
                        乘法: "f32[3, 2]" = torch.ops.aten.mul.Tensor(arg0_1, 2);  arg0_1 = None
                return (乘法,)

        class (torch.nn.Module):
            def forward(self, arg0_1: "f32[3, 2]"):
                        减法: "f32[3, 2]" = torch.ops.aten.sub.Tensor(arg0_1, 2);  arg0_1 = None
                return (减法,)

图签名: ExportGraphSignature(输入规格=[输入规格(种类=<输入种类.用户输入: 1>, 参数=张量参数(名称='arg0_1'), 目标=None, 持久=None), 输入规格(种类=<输入种类.用户输入: 1>, 参数=张量参数(名称='arg1_1'), 目标=None, 持久=None)], 输出规格=[输出规格(种类=<输出种类.用户输出: 1>, 参数=张量参数(名称='获取项'), 目标=None)])
范围约束: {}

嵌套函数

注意

标签: python.closure

支持级别:支持

原始源代码:

import torch

结果:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: "f32[3, 2]", arg1_1: "f32[2]"):
                add: "f32[3, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1)

                sub: "f32[3, 2]" = torch.ops.aten.sub.Tensor(arg0_1, arg1_1);  arg0_1 = arg1_1 = None

                add_1: "f32[3, 2]" = torch.ops.aten.add.Tensor(add, 1);  add = None

                mul: "f32[3, 2]" = torch.ops.aten.mul.Tensor(add_1, add_1);  add_1 = None
            add_2: "f32[3, 2]" = torch.ops.aten.add.Tensor(mul, sub);  mul = sub = None
            return (add_2,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg1_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_2'), target=None)])
Range constraints: {}