Shortcuts

torch.动态形状

cond_branch_class_method

注意

标签: torch.cond, torch.dynamic-shape

支持级别:支持

原始源代码:

import torch

from functorch.experimental.control_flow import cond


class MySubModule(torch.nn.Module):
    def foo(self, x):
        return x.cos()

    def forward(self, x):
        return self.foo(x)


class CondBranchClassMethod(torch.nn.Module):
    """
    传递给 cond() 的分支函数 (`true_fn` 和 `false_fn`) 必须遵循以下规则:
      - 两个分支必须接受相同的参数,这些参数也必须与传递给 cond 的分支参数匹配。
      - 两个分支都必须返回一个张量
      - 返回的张量必须具有相同的张量元数据,例如形状和数据类型
      - 分支函数可以是自由函数、嵌套函数、lambda、类方法
      - 分支函数不能有闭包变量
      - 不能对输入或全局变量进行原地修改


    这个示例演示了在 cond() 中使用类方法。

    注意:如果 `pred` 在批次大小 < 2 的维度上进行测试,它将被特化。
    """

    def __init__(self):
        super().__init__()
        self.subm = MySubModule()

    def bar(self, x):
        return x.sin()

    def forward(self, x):
        return cond(x.shape[0] <= 2, self.subm.forward, self.bar, [x])

结果:

导出的程序:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: "f32[3]"):
                true_graph_0 = self.true_graph_0
            false_graph_0 = self.false_graph_0
            conditional = torch.ops.higher_order.cond(False, true_graph_0, false_graph_0, [arg0_1]);  true_graph_0 = false_graph_0 = arg0_1 = None
            getitem: "f32[3]" = conditional[0];  conditional = None
            return (getitem,)

        class (torch.nn.Module):
            def forward(self, arg0_1: "f32[3]"):
                        cos: "f32[3]" = torch.ops.aten.cos.default(arg0_1);  arg0_1 = None
                return (cos,)

        class (torch.nn.Module):
            def forward(self, arg0_1: "f32[3]"):
                        sin: "f32[3]" = torch.ops.aten.sin.default(arg0_1);  arg0_1 = None
                return (sin,)

图的签名: ExportGraphSignature(input_specs=[InputSpec(kind=, arg=TensorArgument(name='arg0_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=, arg=TensorArgument(name='getitem'), target=None)])
范围约束: {}

cond_branch_nested_function

注意

标签: torch.cond, torch.dynamic-shape

支持级别:支持

原始源代码:

import torch

from functorch.experimental.control_flow import cond


class CondBranchNestedFunction(torch.nn.Module):
    """
    传递给 cond() 的分支函数 (`true_fn` 和 `false_fn`) 必须遵循以下规则:
      - 两个分支必须接受相同的参数,这些参数也必须与传递给 cond 的分支参数匹配。
      - 两个分支都必须返回一个张量
      - 返回的张量必须具有相同的张量元数据,例如形状和数据类型
      - 分支函数可以是自由函数、嵌套函数、lambda、类方法
      - 分支函数不能有闭包变量
      - 不能对输入或全局变量进行原地修改

    这个示例演示了在 cond() 中使用嵌套函数。

    注意:如果 `pred` 在批次大小 < 2 的维度上进行测试,它将被特化。
    """
    def __init__(self):
        super().__init__()

    def forward(self, x):
        def true_fn(x):
            def inner_true_fn(y):
                return x + y

            return inner_true_fn(x)

        def false_fn(x):
            def inner_false_fn(y):
                return x - y

            return inner_false_fn(x)

        return cond(x.shape[0] < 10, true_fn, false_fn, [x])

结果:

导出的程序:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: "f32[3]"):
                真图_0 = self.真图_0
            假图_0 = self.假图_0
            条件 = torch.ops.higher_order.cond(True, 真图_0, 假图_0, [arg0_1]);  真图_0 = 假图_0 = arg0_1 = None
            获取项: "f32[3]" = 条件[0];  条件 = None
            return (获取项,)

        class (torch.nn.Module):
            def forward(self, arg0_1: "f32[3]"):
                        加: "f32[3]" = torch.ops.aten.add.Tensor(arg0_1, arg0_1);  arg0_1 = None
                return (加,)

        class (torch.nn.Module):
            def forward(self, arg0_1: "f32[3]"):
                        减: "f32[3]" = torch.ops.aten.sub.Tensor(arg0_1, arg0_1);  arg0_1 = None
                return (减,)

图签名: ExportGraphSignature(输入规格=[输入规格(种类=<输入种类.用户输入: 1>, 参数=张量参数(名称='arg0_1'), 目标=None, 持久=None)], 输出规格=[输出规格(种类=<输出种类.用户输出: 1>, 参数=张量参数(名称='获取项'), 目标=None)])
范围约束: {}

cond_branch_nonlocal_variables

注意

标签: torch.cond, torch.dynamic-shape

支持级别:支持

原始源代码:

import torch

from functorch.experimental.control_flow import cond


class CondBranchNonlocalVariables(torch.nn.Module):
    """
    传递给 cond() 的分支函数 (`true_fn` 和 `false_fn`) 必须遵循以下规则:
    - 两个分支必须接受相同的参数,这些参数也必须与传递给 cond 的分支参数匹配。
    - 两个分支都必须返回一个张量
    - 返回的张量必须具有相同的张量元数据,例如形状和数据类型
    - 分支函数可以是自由函数、嵌套函数、lambda、类方法
    - 分支函数不能有闭包变量
    - 不能对输入或全局变量进行原地修改

    此示例演示了如何重写代码以避免在分支函数中捕获闭包变量。

    下面的代码将无法工作,因为不支持捕获闭包变量。
    ```
    my_tensor_var = x + 100
    my_primitive_var = 3.14

    def true_fn(y):
        nonlocal my_tensor_var, my_primitive_var
        return y + my_tensor_var + my_primitive_var

    def false_fn(y):
        nonlocal my_tensor_var, my_primitive_var
        return y - my_tensor_var - my_primitive_var

    return cond(x.shape[0] > 5, true_fn, false_fn, [x])
    ```

    注意:如果 `pred` 是在批次大小 < 2 的维度上进行测试,它将被特化。
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        my_tensor_var = x + 100
        my_primitive_var = 3.14

        def true_fn(x, y, z):
            return x + y + z

        def false_fn(x, y, z):
            return x - y - z

        return cond(
            x.shape[0] > 5,
            true_fn,
            false_fn,
            [x, my_tensor_var, torch.tensor(my_primitive_var)],
        )

结果:

导出的程序:
    class GraphModule(torch.nn.Module):
        def forward(self, _lifted_tensor_constant0: "f32[]", arg0_1: "f32[6]"):
                add: "f32[6]" = torch.ops.aten.add.Tensor(arg0_1, 100)

                lift_fresh_copy: "f32[]" = torch.ops.aten.lift_fresh_copy.default(_lifted_tensor_constant0);  _lifted_tensor_constant0 = None

                true_graph_0 = self.true_graph_0
            false_graph_0 = self.false_graph_0
            conditional = torch.ops.higher_order.cond(True, true_graph_0, false_graph_0, [arg0_1, add, lift_fresh_copy]);  true_graph_0 = false_graph_0 = arg0_1 = add = lift_fresh_copy = None
            getitem: "f32[6]" = conditional[0];  conditional = None
            return (getitem,)

        class (torch.nn.Module):
            def forward(self, arg0_1: "f32[6]", arg1_1: "f32[6]", arg2_1: "f32[]"):
                        add: "f32[6]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1);  arg0_1 = arg1_1 = None
                add_1: "f32[6]" = torch.ops.aten.add.Tensor(add, arg2_1);  add = arg2_1 = None
                return (add_1,)

        class (torch.nn.Module):
            def forward(self, arg0_1: "f32[6]", arg1_1: "f32[6]", arg2_1: "f32[]"):
                        sub: "f32[6]" = torch.ops.aten.sub.Tensor(arg0_1, arg1_1);  arg0_1 = arg1_1 = None
                sub_1: "f32[6]" = torch.ops.aten.sub.Tensor(sub, arg2_1);  sub = arg2_1 = None
                return (sub_1,)

图的签名: ExportGraphSignature(input_specs=[InputSpec(kind=, arg=TensorArgument(name='_lifted_tensor_constant0'), target='_lifted_tensor_constant0', persistent=None), InputSpec(kind=, arg=TensorArgument(name='arg0_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=, arg=TensorArgument(name='getitem'), target=None)])
范围约束: {}

条件操作数

注意

标签: torch.cond, torch.dynamic-shape

支持级别:支持

原始源代码:

import torch

from torch.export import Dim
from functorch.experimental.control_flow import cond

x = torch.randn(3, 2)
y = torch.ones(2)
dim0_x = Dim("dim0_x")

class CondOperands(torch.nn.Module):
    """
    传递给 cond() 的操作数必须是:
    - 张量列表
    - 匹配 `true_fn` 和 `false_fn` 的参数

    注意:如果 `pred` 在批量大小 < 2 的维度上进行测试,它将被专门化。
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        def true_fn(x, y):
            return x + y

        def false_fn(x, y):
            return x - y

        return cond(x.shape[0] > 2, true_fn, false_fn, [x, y])

结果:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: "f32[s0, 2]", arg1_1: "f32[2]"):
                sym_size_int: "Sym(s0)" = torch.ops.aten.sym_size.int(arg0_1, 0)
            gt: "Sym(s0 > 2)" = sym_size_int > 2;  sym_size_int = None
            true_graph_0 = self.true_graph_0
            false_graph_0 = self.false_graph_0
            conditional = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [arg0_1, arg1_1]);  gt = true_graph_0 = false_graph_0 = arg0_1 = arg1_1 = None
            getitem: "f32[s0, 2]" = conditional[0];  conditional = None
            return (getitem,)

        class (torch.nn.Module):
            def forward(self, arg0_1: "f32[s0, 2]", arg1_1: "f32[2]"):
                        add: "f32[s0, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1);  arg0_1 = arg1_1 = None
                return (add,)

        class (torch.nn.Module):
            def forward(self, arg0_1: "f32[s0, 2]", arg1_1: "f32[2]"):
                        sub: "f32[s0, 2]" = torch.ops.aten.sub.Tensor(arg0_1, arg1_1);  arg0_1 = arg1_1 = None
                return (sub,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=, arg=TensorArgument(name='arg0_1'), target=None, persistent=None), InputSpec(kind=, arg=TensorArgument(name='arg1_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=, arg=TensorArgument(name='getitem'), target=None)])
范围约束: {s0: ValueRanges(下限=2, 上限=oo, is_bool=False)}

条件谓词

注意

标签: torch.cond, torch.dynamic-shape

支持级别:支持

原始源代码:

import torch

from functorch.experimental.control_flow import cond


class CondPredicate(torch.nn.Module):
    """
    传递给 cond() 的条件语句(也称为谓词)必须是以下之一:
      - 包含单个元素的 torch.Tensor
      - 布尔表达式

    注意:如果 `pred` 在批量大小 < 2 的维度上进行测试,它将被专门化。
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        pred = x.dim() > 2 and x.shape[2] > 10

        return cond(pred, lambda x: x.cos(), lambda y: y.sin(), [x])

结果:

导出的程序:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: "f32[6, 4, 3]"):
                真图_0 = self.true_graph_0
            假图_0 = self.false_graph_0
            条件 = torch.ops.higher_order.cond(False, true_graph_0, false_graph_0, [arg0_1]);  真图_0 = 假图_0 = arg0_1 = None
            获取项: "f32[6, 4, 3]" = 条件[0];  条件 = None
            return (获取项,)

        class (torch.nn.Module):
            def forward(self, arg0_1: "f32[6, 4, 3]"):
                        余弦: "f32[6, 4, 3]" = torch.ops.aten.cos.default(arg0_1);  arg0_1 = None
                return (余弦,)

        class (torch.nn.Module):
            def forward(self, arg0_1: "f32[6, 4, 3]"):
                        正弦: "f32[6, 4, 3]" = torch.ops.aten.sin.default(arg0_1);  arg0_1 = None
                return (正弦,)

图签名: ExportGraphSignature(输入规格=[输入规格(种类=<输入种类.用户输入: 1>, 参数=张量参数(名称='arg0_1'), 目标=None, 持久=None)], 输出规格=[输出规格(种类=<输出种类.用户输出: 1>, 参数=张量参数(名称='获取项'), 目标=None)])
范围约束: {}

动态形状构造器

注意

标签: torch.dynamic-shape

支持级别:支持

原始源代码:

```html
import torch

结果:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: "f32[3, 2]"):
                ones: "f32[6]" = torch.ops.aten.ones.default([6], device = device(type='cpu'), pin_memory = False)
            return (ones,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='ones'), target=None)])
Range constraints: {}

dynamic_shape_if_guard

注意

标签: torch.dynamic-shape, python.control-flow

支持级别:支持

原始源代码:

import torch

结果:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: "f32[3, 2, 2]"):
                cos: "f32[3, 2, 2]" = torch.ops.aten.cos.default(arg0_1);  arg0_1 = None
            return (cos,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='cos'), target=None)])
范围约束: {}

dynamic_shape_map

注意

标签: torch.dynamic-shape, torch.map

支持级别:支持

原始源代码:

import torch

from functorch.experimental.control_flow import map


class DynamicShapeMap(torch.nn.Module):
    """
    functorch map() 将一个函数映射到第一个张量维度上。
    """

    def __init__(self):
        super().__init__()

    def forward(self, xs, y):
        def body(x, y):
            return x + y

        return map(body, xs, y)

结果:

导出的程序:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: "f32[3, 2]", arg1_1: "f32[2]"):
                body_graph_0 = self.body_graph_0
            map_impl = torch.ops.higher_order.map_impl(body_graph_0, [arg0_1], [arg1_1]);  body_graph_0 = arg0_1 = arg1_1 = None
            getitem: "f32[3, 2]" = map_impl[0];  map_impl = None
            return (getitem,)

        class (torch.nn.Module):
            def forward(self, arg0_1: "f32[2]", arg1_1: "f32[2]"):
                        add: "f32[2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1);  arg0_1 = arg1_1 = None
                return (add,)

图的签名: ExportGraphSignature(input_specs=[InputSpec(kind=, arg=TensorArgument(name='arg0_1'), target=None, persistent=None), InputSpec(kind=, arg=TensorArgument(name='arg1_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=, arg=TensorArgument(name='getitem'), target=None)])
范围约束: {}

dynamic_shape_round

注意

标签: torch.dynamic-shape, python.builtin

支持级别: 尚未支持

原始源代码:

import torch

from torch.export import Dim

x = torch.ones(3, 2)
dim0_x = Dim("dim0_x")

class DynamicShapeRound(torch.nn.Module):
    """
    不支持对动态形状调用round。
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x[: round(x.shape[0] / 2)]

结果:

AssertionError:

动态形状切片

注意

标签: torch.dynamic-shape

支持级别:支持

原始源代码:

```html
import torch

结果:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: "f32[3, 2]"):
                slice_1: "f32[1, 2]" = torch.ops.aten.slice.Tensor(arg0_1, 0, 0, 1);  arg0_1 = None
            slice_2: "f32[1, 1]" = torch.ops.aten.slice.Tensor(slice_1, 1, 1, 9223372036854775807, 2);  slice_1 = None
            return (slice_2,)

Graph 签名: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='slice_2'), target=None)])
范围 约束: {}

dynamic_shape_view

注意

标签: torch.dynamic-shape

支持级别:支持

原始源代码:

import torch

结果:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: "f32[10, 10]"):
                view: "f32[10, 2, 5]" = torch.ops.aten.view.default(arg0_1, [10, 2, 5]);  arg0_1 = None

                permute: "f32[10, 5, 2]" = torch.ops.aten.permute.default(view, [0, 2, 1]);  view = None
            return (permute,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='permute'), target=None)])
Range constraints: {}

list_contains

注意

标签: torch.dynamic-shape, python.data-structure, python.assert

支持级别:支持

原始源代码:

import torch

结果:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: "f32[3, 2]"):
                add: "f32[3, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg0_1);  arg0_1 = None
            return (add,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None)])
Range constraints: {}

标量输出

注意

标签: torch.dynamic-shape

支持级别:支持

原始源代码:

import torch

from torch.export import Dim

x = torch.ones(3, 2)
dim1_x = Dim("dim1_x")

class ScalarOutput(torch.nn.Module):
    """
    除了张量输出外,还支持从图中返回标量值。
    捕获符号形状并专门化秩。
    """
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x.shape[1] + 1

结果:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: "f32[3, s0]"):
            # 没有找到以下节点的堆栈跟踪
            sym_size_int: "Sym(s0)" = torch.ops.aten.sym_size.int(arg0_1, 1);  arg0_1 = None
            add: "Sym(s0 + 1)" = sym_size_int + 1;  sym_size_int = None
            return (add,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=SymIntArgument(name='add'), target=None)])
Range constraints: {s0: ValueRanges(lower=2, upper=oo, is_bool=False)}