Shortcuts

torch.export.dynamic_shapes 的源代码

import builtins
import dataclasses
import inspect
import math
import sys
import weakref
from collections import defaultdict
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TYPE_CHECKING, Union

import torch
from torch._subclasses.fake_tensor import FakeTensor
from torch.utils._pytree import SUPPORTED_NODES

from .exported_program import ExportedProgram

if TYPE_CHECKING:
    from sympy import Symbol

    from torch._guards import Source

    from ..fx.experimental.symbolic_shapes import ShapeEnv, StrictMinMaxConstraint

# Public names exported by ``from <this module> import *``; the metaclass
# ``_Dim`` defined below is not listed here.
__all__ = ["Constraint", "Dim", "dims", "dynamic_dim"]


class _Dim(type):
    """
    用于 :func:`Dim` 类型的元类。
    """

    @staticmethod
    def readable(name, min_, max_):
        if min_ == 2:
            min_ = None
        if max_ == sys.maxsize - 1:
            max_ = None
        if min_ is None and max_ is None:
            return f"Dim('{name}')"
        if min_ is None:
            return f"Dim('{name}', max={max_})"
        if max_ is None:
            return f"Dim('{name}', min={min_})"
        return f"Dim('{name}', min={min_}, max={max_})"

    def __add__(cls, other):
        # 例如,dim + 1
        if type(other) is not int:
            raise NotImplementedError(
                f"尝试将 {other} 添加到 {cls.__name__},其中预期为整数。"
                "(仅支持具有整数系数的线性递增操作。)"
            )
        return cls._derive(lambda x: x + other)

    def __radd__(cls, other):
        return cls + other

    def __sub__(cls, other):
        # 例如,dim - 1
        if type(other) is not int:
            raise NotImplementedError(
                f"尝试从 {cls.__name__} 中减去 {other},其中预期为整数。"
                "(仅支持具有整数系数的线性递增操作。)"
            )
        return cls._derive(lambda x: x - other)

    def __rsub__(cls, other):
        raise NotImplementedError(
            f"尝试否定 {cls.__name__}。"
            "(仅支持具有整数系数的线性递增操作。)"
        )

    def __mul__(cls, other):
        # 例如,dim * 2
        if type(other) is not int or other <= 0:
            raise NotImplementedError(
                f"尝试将 {other}{cls.__name__} 相乘,其中预期为正整数。"
                "(仅支持具有整数系数的线性递增操作。)"
            )
        return cls._derive(lambda x: x * other)

    def __rmul__(cls, other):
        return cls * other

    def _derived_name(cls, fn):
        from sympy import sympify

        return str(fn(sympify(cls.__name__)))

    def _derive(cls, fn):
        return _DerivedDim(cls._derived_name(fn), (int,), {
优云智算