
torch.cond

cond_branch_class_method

Note

Tags: torch.cond, torch.dynamic-shape

Support Level: SUPPORTED

Original source code:

import torch

from functorch.experimental.control_flow import cond


class MySubModule(torch.nn.Module):
    def foo(self, x):
        return x.cos()

    def forward(self, x):
        return self.foo(x)


class CondBranchClassMethod(torch.nn.Module):
    """
    传递给 cond() 的分支函数 (`true_fn` 和 `false_fn`) 必须遵循以下规则:
      - 两个分支必须接受相同的参数,这些参数也必须与传递给 cond 的分支参数匹配。
      - 两个分支都必须返回一个张量
      - 返回的张量必须具有相同的张量元数据,例如形状和数据类型
      - 分支函数可以是自由函数、嵌套函数、lambda、类方法
      - 分支函数不能有闭包变量
      - 不能对输入或全局变量进行原地修改


    这个示例演示了在 cond() 中使用类方法。

    注意:如果 `pred` 是在批次大小 < 2 的维度上进行测试,它将被特化。
    """

    def __init__(self):
        super().__init__()
        self.subm = MySubModule()

    def bar(self, x):
        return x.sin()

    def forward(self, x):
        return cond(x.shape[0] <= 2, self.subm.forward, self.bar, [x])

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: "f32[3]"):
            true_graph_0 = self.true_graph_0
            false_graph_0 = self.false_graph_0
            conditional = torch.ops.higher_order.cond(False, true_graph_0, false_graph_0, [arg0_1]);  true_graph_0 = false_graph_0 = arg0_1 = None
            getitem: "f32[3]" = conditional[0];  conditional = None
            return (getitem,)

        class <lambda>(torch.nn.Module):
            def forward(self, arg0_1: "f32[3]"):
                cos: "f32[3]" = torch.ops.aten.cos.default(arg0_1);  arg0_1 = None
                return (cos,)

        class <lambda>(torch.nn.Module):
            def forward(self, arg0_1: "f32[3]"):
                sin: "f32[3]" = torch.ops.aten.sin.default(arg0_1);  arg0_1 = None
                return (sin,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='getitem'), target=None)])
Range constraints: {}
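
A minimal usage sketch, assuming a recent torch.export API and the CondBranchClassMethod module defined above:

import torch
from torch.export import export

# Export with a static rank-1 input of size 3. Since 3 > 2, the
# predicate x.shape[0] <= 2 is constant-folded, which is why the graph
# above calls cond(False, ...).
m = CondBranchClassMethod()
x = torch.randn(3)
ep = export(m, (x,))

# Only the false branch (self.bar, i.e. sin) can run for this shape.
torch.testing.assert_close(ep.module()(x), x.sin())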

cond_branch_nested_function

Note

Tags: torch.cond, torch.dynamic-shape

Support Level: SUPPORTED

Original source code:

import torch

from functorch.experimental.control_flow import cond


class CondBranchNestedFunction(torch.nn.Module):
    """
    传递给 cond() 的分支函数 (`true_fn` 和 `false_fn`) 必须遵循以下规则:
      - 两个分支必须接受相同的参数,这些参数也必须与传递给 cond 的分支参数匹配。
      - 两个分支都必须返回一个张量
      - 返回的张量必须具有相同的张量元数据,例如形状和数据类型
      - 分支函数可以是自由函数、嵌套函数、lambda、类方法
      - 分支函数不能有闭包变量
      - 不能对输入或全局变量进行原地修改

    这个示例演示了在 cond() 中使用嵌套函数。

    注意:如果 `pred` 在批次大小 < 2 的维度上进行测试,它将被特化。
    """
    def __init__(self):
        super().__init__()

    def forward(self, x):
        def true_fn(x):
            def inner_true_fn(y):
                return x + y

            return inner_true_fn(x)

        def false_fn(x):
            def inner_false_fn(y):
                return x - y

            return inner_false_fn(x)

        return cond(x.shape[0] < 10, true_fn, false_fn, [x])

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: "f32[3]"):
            true_graph_0 = self.true_graph_0
            false_graph_0 = self.false_graph_0
            conditional = torch.ops.higher_order.cond(True, true_graph_0, false_graph_0, [arg0_1]);  true_graph_0 = false_graph_0 = arg0_1 = None
            getitem: "f32[3]" = conditional[0];  conditional = None
            return (getitem,)

        class <lambda>(torch.nn.Module):
            def forward(self, arg0_1: "f32[3]"):
                add: "f32[3]" = torch.ops.aten.add.Tensor(arg0_1, arg0_1);  arg0_1 = None
                return (add,)

        class <lambda>(torch.nn.Module):
            def forward(self, arg0_1: "f32[3]"):
                sub: "f32[3]" = torch.ops.aten.sub.Tensor(arg0_1, arg0_1);  arg0_1 = None
                return (sub,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='getitem'), target=None)])
Range constraints: {}
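
A minimal sketch of the specialization, assuming torch.export and the CondBranchNestedFunction module defined above:

import torch
from torch.export import export

# With a static input of size 3, the predicate x.shape[0] < 10 is
# constant-folded to True, matching cond(True, ...) in the graph above,
# so only the true branch (x + x) ever runs.
m = CondBranchNestedFunction()
x = torch.randn(3)
ep = export(m, (x,))
torch.testing.assert_close(ep.module()(x), x + x)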

cond_branch_nonlocal_variables

Note

Tags: torch.cond, torch.dynamic-shape

Support Level: SUPPORTED

Original source code:

import torch

from functorch.experimental.control_flow import cond


class CondBranchNonlocalVariables(torch.nn.Module):
    """
    传递给 cond() 的分支函数 (`true_fn` 和 `false_fn`) 必须遵循以下规则:
    - 两个分支必须接受相同的参数,这些参数也必须与传递给 cond 的分支参数匹配。
    - 两个分支都必须返回一个张量
    - 返回的张量必须具有相同的张量元数据,例如形状和数据类型
    - 分支函数可以是自由函数、嵌套函数、lambda 表达式、类方法
    - 分支函数不能有闭包变量
    - 不能对输入或全局变量进行原地修改

    此示例演示了如何重写代码以避免在分支函数中捕获闭包变量。

    下面的代码将无法工作,因为不支持捕获闭包变量。
    ```
    my_tensor_var = x + 100
    my_primitive_var = 3.14

    def true_fn(y):
        nonlocal my_tensor_var, my_primitive_var
        return y + my_tensor_var + my_primitive_var

    def false_fn(y):
        nonlocal my_tensor_var, my_primitive_var
        return y - my_tensor_var - my_primitive_var

    return cond(x.shape[0] > 5, true_fn, false_fn, [x])
    ```

    注意:如果 `pred` 是在批次大小 < 2 的维度上进行测试,它将被特化。
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        my_tensor_var = x + 100
        my_primitive_var = 3.14

        def true_fn(x, y, z):
            return x + y + z

        def false_fn(x, y, z):
            return x - y - z

        return cond(
            x.shape[0] > 5,
            true_fn,
            false_fn,
            [x, my_tensor_var, torch.tensor(my_primitive_var)],
        )

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, _lifted_tensor_constant0: "f32[]", arg0_1: "f32[6]"):
            add: "f32[6]" = torch.ops.aten.add.Tensor(arg0_1, 100)
            lift_fresh_copy: "f32[]" = torch.ops.aten.lift_fresh_copy.default(_lifted_tensor_constant0);  _lifted_tensor_constant0 = None
            true_graph_0 = self.true_graph_0
            false_graph_0 = self.false_graph_0
            conditional = torch.ops.higher_order.cond(True, true_graph_0, false_graph_0, [arg0_1, add, lift_fresh_copy]);  true_graph_0 = false_graph_0 = arg0_1 = add = lift_fresh_copy = None
            getitem: "f32[6]" = conditional[0];  conditional = None
            return (getitem,)

        class <lambda>(torch.nn.Module):
            def forward(self, arg0_1: "f32[6]", arg1_1: "f32[6]", arg2_1: "f32[]"):
                add: "f32[6]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1);  arg0_1 = arg1_1 = None
                add_1: "f32[6]" = torch.ops.aten.add.Tensor(add, arg2_1);  add = arg2_1 = None
                return (add_1,)

        class <lambda>(torch.nn.Module):
            def forward(self, arg0_1: "f32[6]", arg1_1: "f32[6]", arg2_1: "f32[]"):
                sub: "f32[6]" = torch.ops.aten.sub.Tensor(arg0_1, arg1_1);  arg0_1 = arg1_1 = None
                sub_1: "f32[6]" = torch.ops.aten.sub.Tensor(sub, arg2_1);  sub = arg2_1 = None
                return (sub_1,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.CONSTANT_TENSOR: 4>, arg=TensorArgument(name='_lifted_tensor_constant0'), target='_lifted_tensor_constant0', persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='getitem'), target=None)])
Range constraints: {}
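
A small eager-mode sketch of the same rewrite pattern (a re-expression of the forward() above with lambdas, not part of the original example): values that would otherwise be captured become explicit operands, so the branches stay closure-free:

import torch
from functorch.experimental.control_flow import cond

x = torch.randn(6)
my_tensor_var = x + 100
my_primitive_var = 3.14

# Pass my_tensor_var / my_primitive_var as explicit operands instead of
# closing over them; the Python float is wrapped in a tensor first.
out = cond(
    x.shape[0] > 5,
    lambda a, b, c: a + b + c,
    lambda a, b, c: a - b - c,
    [x, my_tensor_var, torch.tensor(my_primitive_var)],
)
torch.testing.assert_close(out, x + my_tensor_var + 3.14)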

cond_closed_over_variable

Note

Tags: torch.cond, python.closure

Support Level: SUPPORTED

Original source code:

import torch

from functorch.experimental.control_flow import cond


class CondClosedOverVariable(torch.nn.Module):
    """
    torch.cond() 支持分支闭包任意变量。
    """

    def forward(self, pred, x):
        def true_fn(val):
            return x * 2

        def false_fn(val):
            return x - 2

        return cond(pred, true_fn, false_fn, [x + 1])

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: "b8[]", arg1_1: "f32[3, 2]"):
            true_graph_0 = self.true_graph_0
            false_graph_0 = self.false_graph_0
            conditional = torch.ops.higher_order.cond(arg0_1, true_graph_0, false_graph_0, [arg1_1]);  arg0_1 = true_graph_0 = false_graph_0 = arg1_1 = None
            getitem: "f32[3, 2]" = conditional[0];  conditional = None
            return (getitem,)

        class <lambda>(torch.nn.Module):
            def forward(self, arg0_1: "f32[3, 2]"):
                mul: "f32[3, 2]" = torch.ops.aten.mul.Tensor(arg0_1, 2);  arg0_1 = None
                return (mul,)

        class <lambda>(torch.nn.Module):
            def forward(self, arg0_1: "f32[3, 2]"):
                sub: "f32[3, 2]" = torch.ops.aten.sub.Tensor(arg0_1, 2);  arg0_1 = None
                return (sub,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg1_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='getitem'), target=None)])
Range constraints: {}
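
A usage sketch, assuming torch.export and the CondClosedOverVariable module defined above. Because `pred` is a runtime boolean tensor here, both branches stay live in the exported graph:

import torch
from torch.export import export

m = CondClosedOverVariable()
x = torch.randn(3, 2)
ep = export(m, (torch.tensor(True), x))

# The branch is selected at run time from the tensor predicate; the
# closed-over `x` was lifted into an explicit branch input.
torch.testing.assert_close(ep.module()(torch.tensor(True), x), x * 2)
torch.testing.assert_close(ep.module()(torch.tensor(False), x), x - 2)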

cond_operands

Note

Tags: torch.cond, torch.dynamic-shape

Support Level: SUPPORTED

Original source code:

import torch

from torch.export import Dim
from functorch.experimental.control_flow import cond

x = torch.randn(3, 2)
y = torch.ones(2)
dim0_x = Dim("dim0_x")

class CondOperands(torch.nn.Module):
    """
    传递给 cond() 的操作数必须是:
    - 一个张量列表
    - 匹配 `true_fn` 和 `false_fn` 的参数

    注意:如果 `pred` 在批次大小 < 2 的维度上进行测试,它将被专门化。
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        def true_fn(x, y):
            return x + y

        def false_fn(x, y):
            return x - y

        return cond(x.shape[0] > 2, true_fn, false_fn, [x, y])

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: "f32[s0, 2]", arg1_1: "f32[2]"):
            sym_size_int: "Sym(s0)" = torch.ops.aten.sym_size.int(arg0_1, 0)
            gt: "Sym(s0 > 2)" = sym_size_int > 2;  sym_size_int = None
            true_graph_0 = self.true_graph_0
            false_graph_0 = self.false_graph_0
            conditional = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [arg0_1, arg1_1]);  gt = true_graph_0 = false_graph_0 = arg0_1 = arg1_1 = None
            getitem: "f32[s0, 2]" = conditional[0];  conditional = None
            return (getitem,)

        class <lambda>(torch.nn.Module):
            def forward(self, arg0_1: "f32[s0, 2]", arg1_1: "f32[2]"):
                add: "f32[s0, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1);  arg0_1 = arg1_1 = None
                return (add,)

        class <lambda>(torch.nn.Module):
            def forward(self, arg0_1: "f32[s0, 2]", arg1_1: "f32[2]"):
                sub: "f32[s0, 2]" = torch.ops.aten.sub.Tensor(arg0_1, arg1_1);  arg0_1 = arg1_1 = None
                return (sub,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg1_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='getitem'), target=None)])
Range constraints: {s0: ValueRanges(lower=2, upper=oo, is_bool=False)}
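
A sketch of the export call this output corresponds to, assuming the dynamic_shapes API of a recent torch.export; `dim0_x` is the Dim declared in the source above:

import torch
from torch.export import Dim, export

x = torch.randn(3, 2)
y = torch.ones(2)
dim0_x = Dim("dim0_x")

# Marking dim 0 of `x` as dynamic keeps x.shape[0] > 2 symbolic (the
# Sym(s0 > 2) node above) instead of specializing the predicate.
ep = export(CondOperands(), (x, y), dynamic_shapes={"x": {0: dim0_x}, "y": None})

# Which branch runs now depends on the runtime size of dim 0.
big, small = torch.randn(5, 2), torch.randn(2, 2)
torch.testing.assert_close(ep.module()(big, y), big + y)
torch.testing.assert_close(ep.module()(small, y), small - y)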

cond_predicate

Note

Tags: torch.cond, torch.dynamic-shape

Support Level: SUPPORTED

Original source code:

import torch

from functorch.experimental.control_flow import cond


class CondPredicate(torch.nn.Module):
    """
    传递给 cond() 的条件语句(也称为谓词)必须是以下之一:
      - 包含单个元素的 torch.Tensor
      - 布尔表达式

    注意:如果 `pred` 在批量大小 < 2 的维度上进行测试,它将被专门化。
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        pred = x.dim() > 2 and x.shape[2] > 10

        return cond(pred, lambda x: x.cos(), lambda y: y.sin(), [x])

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: "f32[6, 4, 3]"):
            true_graph_0 = self.true_graph_0
            false_graph_0 = self.false_graph_0
            conditional = torch.ops.higher_order.cond(False, true_graph_0, false_graph_0, [arg0_1]);  true_graph_0 = false_graph_0 = arg0_1 = None
            getitem: "f32[6, 4, 3]" = conditional[0];  conditional = None
            return (getitem,)

        class <lambda>(torch.nn.Module):
            def forward(self, arg0_1: "f32[6, 4, 3]"):
                cos: "f32[6, 4, 3]" = torch.ops.aten.cos.default(arg0_1);  arg0_1 = None
                return (cos,)

        class <lambda>(torch.nn.Module):
            def forward(self, arg0_1: "f32[6, 4, 3]"):
                sin: "f32[6, 4, 3]" = torch.ops.aten.sin.default(arg0_1);  arg0_1 = None
                return (sin,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='getitem'), target=None)])
Range constraints: {}
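
A usage sketch, assuming torch.export and the CondPredicate module defined above:

import torch
from torch.export import export

m = CondPredicate()
x = torch.randn(6, 4, 3)
ep = export(m, (x,))

# For a static f32[6, 4, 3] input, x.dim() > 2 is True but
# x.shape[2] > 10 is False, so `pred` folds to False (cond(False, ...)
# above) and the exported graph always takes the sin branch.
torch.testing.assert_close(ep.module()(x), x.sin())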