torch.export.dynamic_shapes 的源代码
import builtins
import dataclasses
import inspect
import math
import sys
import weakref
from collections import defaultdict
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TYPE_CHECKING, Union
import torch
from torch._subclasses.fake_tensor import FakeTensor
from torch.utils._pytree import SUPPORTED_NODES
from .exported_program import ExportedProgram
if TYPE_CHECKING:
from sympy import Symbol
from torch._guards import Source
from ..fx.experimental.symbolic_shapes import ShapeEnv, StrictMinMaxConstraint
# Public names of this module: the ones a star-import of it will bring in.
__all__ = ["Constraint", "Dim", "dims", "dynamic_dim"]
class _Dim(type):
"""
用于 :func:`Dim` 类型的元类。
"""
@staticmethod
def readable(name, min_, max_):
if min_ == 2:
min_ = None
if max_ == sys.maxsize - 1:
max_ = None
if min_ is None and max_ is None:
return f"Dim('{name}')"
if min_ is None:
return f"Dim('{name}', max={max_})"
if max_ is None:
return f"Dim('{name}', min={min_})"
return f"Dim('{name}', min={min_}, max={max_})"
def __add__(cls, other):
# 例如,dim + 1
if type(other) is not int:
raise NotImplementedError(
f"尝试将 {other} 添加到 {cls.__name__},其中预期为整数。"
"(仅支持具有整数系数的线性递增操作。)"
)
return cls._derive(lambda x: x + other)
def __radd__(cls, other):
return cls + other
def __sub__(cls, other):
# 例如,dim - 1
if type(other) is not int:
raise NotImplementedError(
f"尝试从 {cls.__name__} 中减去 {other},其中预期为整数。"
"(仅支持具有整数系数的线性递增操作。)"
)
return cls._derive(lambda x: x - other)
def __rsub__(cls, other):
raise NotImplementedError(
f"尝试否定 {cls.__name__}。"
"(仅支持具有整数系数的线性递增操作。)"
)
def __mul__(cls, other):
# 例如,dim * 2
if type(other) is not int or other <= 0:
raise NotImplementedError(
f"尝试将 {other} 与 {cls.__name__} 相乘,其中预期为正整数。"
"(仅支持具有整数系数的线性递增操作。)"
)
return cls._derive(lambda x: x * other)
def __rmul__(cls, other):
return cls * other
def _derived_name(cls, fn):
from sympy import sympify
return str(fn(sympify(cls.__name__)))
def _derive(cls, fn):
return _DerivedDim(cls._derived_name(fn), (int,), {