Shortcuts

torch.export.graph_signature 的源代码

```python
import dataclasses
from enum import auto, Enum
from typing import Collection, Dict, List, Mapping, Optional, Set, Tuple, Union


# Public API of this module: the names bound by `from <module> import *`.
# Note that the `ArgumentSpec` alias is not listed here.
__all__ = [
    "ConstantArgument",
    "CustomObjArgument",
    "ExportBackwardSignature",
    "ExportGraphSignature",
    "InputKind",
    "InputSpec",
    "OutputKind",
    "OutputSpec",
    "SymIntArgument",
    "TensorArgument",
]


@dataclasses.dataclass
class TensorArgument:
    """Argument spec that carries only a name.

    NOTE(review): presumably ``name`` is the name of the graph value that
    holds a tensor -- confirm against the code that builds these specs.
    """

    name: str


@dataclasses.dataclass
class SymIntArgument:
    """Argument spec that carries only a name.

    NOTE(review): presumably ``name`` is the name of the graph value that
    holds a symbolic integer -- confirm against the code that builds these
    specs.
    """

    name: str


@dataclasses.dataclass
class CustomObjArgument:
    """Argument spec described by a name plus a class identifier string.

    Attributes:
        name: Identifier of the argument.
        class_fqn: Class identifier string. NOTE(review): presumably the
            fully qualified name of the custom object's class -- confirm.
    """

    name: str
    class_fqn: str
@dataclasses.dataclass
class ConstantArgument:
    """Argument spec that carries a literal value instead of a name.

    Attributes:
        value: The constant itself; annotated as an int, float, bool or None.
    """

    value: Union[int, float, bool, None]


# Union of the argument dataclasses; used as the annotation for the `arg`
# field of the spec dataclasses in this module.
ArgumentSpec = Union[
    TensorArgument,
    SymIntArgument,
    ConstantArgument,
    CustomObjArgument,
]
class InputKind(Enum):
    """Closed set of categories used for the ``kind`` field of an input spec.

    Member values come from ``auto()``, so they follow declaration order;
    do not reorder members if the numeric values are relied upon anywhere.
    """

    USER_INPUT = auto()
    PARAMETER = auto()
    BUFFER = auto()
    CONSTANT_TENSOR = auto()
    CUSTOM_OBJ = auto()
    TOKEN = auto()
@dataclasses.dataclass
class InputSpec:
    """Describes one input: its kind, its argument, and an optional target.

    Attributes:
        kind: Category of the input (an ``InputKind`` member).
        arg: The argument spec for this input.
        target: Optional string associated with the input. NOTE(review):
            presumably the qualified name of the parameter/buffer this input
            maps to -- confirm against the code that builds these specs.
        persistent: Must be set (not ``None``) when ``kind`` is
            ``InputKind.BUFFER``; defaults to ``None`` otherwise.

    Raises:
        AssertionError: From ``__post_init__`` if a BUFFER input has no
            ``persistent`` flag, or if ``arg`` is not one of the supported
            argument dataclasses.
    """

    kind: InputKind
    arg: ArgumentSpec
    target: Optional[str]
    persistent: Optional[bool] = None

    def __post_init__(self):
        """Validate field combinations right after dataclass construction."""
        # These checks are plain asserts, so they are skipped when Python
        # runs with -O; callers currently observe AssertionError on failure.
        if self.kind == InputKind.BUFFER:
            assert (
                self.persistent is not None
            ), "Failed to specify persistent flag on BUFFER."
        assert isinstance(
            self.arg,
            (TensorArgument, SymIntArgument, ConstantArgument, CustomObjArgument),
        ), f"got {type(self.arg)}"
class OutputKind(Enum):
    """Closed set of categories used for the ``kind`` field of an output spec.

    Member values come from ``auto()``, so they follow declaration order;
    do not reorder members if the numeric values are relied upon anywhere.
    """

    USER_OUTPUT = auto()
    LOSS_OUTPUT = auto()
    BUFFER_MUTATION = auto()
    GRADIENT_TO_PARAMETER = auto()
    GRADIENT_TO_USER_INPUT = auto()
    USER_INPUT_MUTATION = auto()
    TOKEN = auto()
@dataclasses.dataclass
class OutputSpec:
    """Describes one output: its kind, its argument, and an optional target.

    Attributes:
        kind: Category of the output (an ``OutputKind`` member).
        arg: The argument spec for this output. Only ``TensorArgument``,
            ``SymIntArgument`` and ``ConstantArgument`` pass validation;
            ``CustomObjArgument`` is rejected even though the annotation
            admits it.
        target: Optional string associated with the output. NOTE(review):
            presumably the name of the buffer/parameter/input the output
            refers to -- confirm against the code that builds these specs.

    Raises:
        AssertionError: From ``__post_init__`` if ``arg`` is not one of the
            accepted argument dataclasses.
    """

    kind: OutputKind
    arg: ArgumentSpec
    target: Optional[str]

    def __post_init__(self):
        """Validate the argument type right after dataclass construction."""
        # Plain assert: skipped under -O. The message mirrors the one used by
        # InputSpec so a failure reports the offending type.
        assert isinstance(
            self.arg,
            (TensorArgument, SymIntArgument, ConstantArgument),
        ), f"got {type(self.arg)}"
def _sig_to_specs( *, user_inputs: Set[str], inputs_to_parameters: Mapping[str, str], inputs_to_buffers: Mapping[str, str], user_outputs: Set[str], buffer_mutations: Mapping[str, str], user_input_mutations: Mapping[str, str], grad_params: Mapping[str, str], grad_user_inputs: Mapping[str, str], loss_output: Optional[str], inputs: List[ArgumentSpec], outputs: List[ArgumentSpec], input_tokens: List[str], output_tokens: List[str], ) -> Tuple[List[InputSpec], List[OutputSpec]]: def to_input_spec(inp: ArgumentSpec</span