torch.testing._comparison 的源代码
import abc
import cmath
import collections.abc
import contextlib
import warnings
from typing import (
Any,
Callable,
Collection,
Dict,
List,
NoReturn,
Optional,
Sequence,
Tuple,
Type,
Union,
)
import torch
try:
import numpy as np
NUMPY_AVAILABLE = True
except ModuleNotFoundError:
NUMPY_AVAILABLE = False
class ErrorMeta(Exception):
"""内部测试异常,携带错误元数据。"""
def __init__(
self, type: Type[Exception], msg: str, *, id: Tuple[Any, ...] = ()
) -> None:
super().__init__(
"如果你是用户,并且在正常操作中看到此消息,请在 https://github.com/pytorch/pytorch/issues 提交问题。 "
"如果你是开发人员,并且在比较函数上工作,请使用 `raise ErrorMeta().to_error()` 引发面向用户的错误。"
)
self.type = type
self.msg = msg
self.id = id
def to_error(
self, msg: Optional[Union[str, Callable[[str], str]]] = None
) -> Exception:
if not isinstance(msg, str):
generated_msg = self.msg
if self.id:
generated_msg += f"\n\n失败发生在项目 {''.join(str([item]) for item in self.id)}"
msg = msg(generated_msg) if callable(msg) else generated_msg
return self.type(msg