Shortcuts

Source code for torch.export.unflatten

```python
import abc
import copy
import operator
from copy import deepcopy
from enum import Enum
from itertools import chain
from typing import Any, cast, Dict, List, Optional, Union

import torch
import torch.fx._pytree as fx_pytree
import torch.utils._pytree as pytree
from torch.export._tree_utils import reorder_kwargs
from torch.export.exported_program import (
    ConstantArgument,
    ExportedProgram,
    ModuleCallSignature,
    SymIntArgument,
    TensorArgument,
)
from torch.fx._symbolic_trace import is_fx_tracing
from torch.utils._pytree import GetAttrKey, SequenceKey

__all__ = ["InterpreterModule", "UnflattenedModule", "unflatten", "FlatArgsAdapter"]


class _AttrKind(Enum):
    PARAMETER = "parameter"
    BUFFER = "buffer"
    CONSTANT = "constant"


def _assign_attr(
    from_obj: Union[torch.Tensor, torch.ScriptObject],
    to_module: torch.nn.Module,
    target: str,
    attr_kind: _AttrKind,
    persistent: bool = True,
):
    """Assign the attribute ``from_obj`` to the qualified name ``target`` on ``to_module``.

    Empty ``torch.nn.Module`` placeholders are installed along the dotted path
    whenever an intermediate submodule does not already exist.

    Args:
        from_obj: the object to install (Parameter, Tensor, or ScriptObject,
            depending on ``attr_kind``).
        to_module: root module under which ``target`` is resolved.
        target: dot-separated qualified name, e.g. ``"sub.block.weight"``.
        attr_kind: how the object is registered (parameter/buffer/constant).
        persistent: forwarded to ``register_buffer`` for BUFFER kind only.
    """
    path = target.split(".")
    parents, leaf = path[:-1], path[-1]

    # Walk (and lazily create) the chain of submodules leading to the leaf name.
    for name in parents:
        child = getattr(to_module, name, None)
        if child is None:
            child = torch.nn.Module()
            setattr(to_module, name, child)
        to_module = child

    if attr_kind == _AttrKind.PARAMETER:
        assert isinstance(from_obj, torch.nn.Parameter)
        to_module.register_parameter(leaf, from_obj)
    elif attr_kind == _AttrKind.BUFFER:
        assert isinstance(from_obj, torch.Tensor)
        to_module.register_buffer(leaf, from_obj, persistent=persistent)
    elif attr_kind == _AttrKind.CONSTANT:
        assert isinstance(from_obj, (torch.Tensor, torch.ScriptObject))
        setattr(to_module, leaf, from_obj)

[docs]class InterpreterModule(torch.nn.Module): """一个使用 torch.fx.Interpreter 执行的模块,而不是通常的 GraphModule 使用的代码生成。这提供了更好的堆栈跟踪信息,并使调试执行更容易。""" def __init__( self, graph: torch.fx.Graph, ): super().__init__() self.graph = graph self.graph.owning_module = self def forward(self, *args, **kwargs): assert self.graph_module is not None, "Didn't finalize this InterpreterModule" if torch.compiler.is_dynamo_compiling(): # Dynamo 无法通过 torch.fx.Interpreter 进行跟踪,因此在这种情况下回退到 GraphModule 代码生成。 return self.graph_module(*args, **kwargs) else: if kwargs: # 处理 **kwargs。FX 仅原生支持位置参数(通过占位符)。因此,为了传递 kwargs,我们必须将占位符的名称与 kwarg 字典中的键对应起来。 arg_list = list(args) kwarg_names = self.arg_names[len(arg_list) :] for kwarg_name in kwarg_names: if kwarg_name in kwargs: arg_list.append(kwargs[kwarg_name]) # 断言传入的 kwargs 正好与 GraphModule 指定的位置参数匹配。这应该由展平过程保证。 assert len(kwarg_names) == len(kwargs</span
优云智算