Shortcuts

torch.testing._comparison 的源代码

import abc
import cmath
import collections.abc
import contextlib
import warnings
from typing import (
    Any,
    Callable,
    Collection,
    Dict,
    List,
    NoReturn,
    Optional,
    Sequence,
    Tuple,
    Type,
    Union,
)

import torch

try:
    import numpy as np

    NUMPY_AVAILABLE = True
except ModuleNotFoundError:
    NUMPY_AVAILABLE = False


class ErrorMeta(Exception):
    """内部测试异常,携带错误元数据。"""

    def __init__(
        self, type: Type[Exception], msg: str, *, id: Tuple[Any, ...] = ()
    ) -> None:
        super().__init__(
            "如果你是用户,并且在正常操作中看到此消息,请在 https://github.com/pytorch/pytorch/issues 提交问题。 "
            "如果你是开发人员,并且在比较函数上工作,请使用 `raise ErrorMeta().to_error()` 引发面向用户的错误。"
        )
        self.type = type
        self.msg = msg
        self.id = id

    def to_error(
        self, msg: Optional[Union[str, Callable[[str], str]]] = None
    ) -> Exception:
        if not isinstance(msg, str):
            generated_msg = self.msg
            if self.id:
                generated_msg += f"\n\n失败发生在项目 {''.join(str([item]) for item in self.id)}"

            msg = msg(generated_msg) if callable(msg) else generated_msg

        return self.type(msg