Shortcuts

torch.testing

torch.testing.assert_close(actual, expected, *, allow_subclasses=True, rtol=None, atol=None, equal_nan=False, check_device=True, check_dtype=True, check_layout=True, check_stride=False, msg=None)[源代码]

断言 actualexpected 接近。

如果 actualexpected 是跨步的、非量化的、实值的且有限的,则当它们满足以下条件时被认为是接近的:

actualexpectedatol+rtolexpected\lvert \text{actual} - \text{expected} \rvert \le \texttt{atol} + \texttt{rtol} \cdot \lvert \text{expected} \rvert

非有限值(-infinf)只有在相等时才被认为是接近的。NaN 只有在 equal_nanTrue 时才被认为是相等的。

此外,只有当它们具有相同的

  • device(如果check_deviceTrue),

  • dtype(如果check_dtypeTrue),

  • 布局(如果检查布局True),并且

  • 步幅(如果 check_strideTrue)。

如果 actualexpected 是元张量,则只会执行属性检查。

如果 actualexpected 是稀疏的(具有COO、CSR、CSC、BSR或BSC布局中的任何一个),它们的跨步成员将分别进行检查。索引,即COO的 indices,CSR和BSR的 crow_indicescol_indices, 或CSC和BSC布局的 ccol_indicesrow_indices, 分别总是检查是否相等,而值则根据上述定义检查是否接近。

如果 actualexpected 是量化的,当它们具有相同的 qscheme() 并且根据上述定义,dequantize() 的结果接近时,它们被认为是接近的。

actualexpected 可以是 Tensor 或任何可以通过 torch.as_tensor() 构造 torch.Tensor 的张量或标量。除了 Python 标量外,输入类型必须直接相关。此外,actualexpected 可以是 SequenceMapping,在这种情况下,如果它们的结构匹配并且根据上述定义它们的元素都被认为是接近的,则它们被认为是接近的。

注意

Python标量是类型关系要求的例外,因为它们的type(),即 intfloatcomplex,等同于类张量的dtype。因此, 可以检查不同类型的Python标量,但需要check_dtype=False

Parameters
  • 实际 (任意) – 实际输入。

  • 预期 (任意) – 预期的输入。

  • allow_subclasses (bool) – 如果为True(默认)且除了Python标量外,允许直接相关类型的输入。否则需要类型完全相同。

  • rtol (可选[浮点数]) – 相对容差。如果指定,atol也必须指定。如果省略,则根据dtype选择默认值,如下表所示。

  • atol (可选[float]) – 绝对容差。如果指定,rtol也必须指定。如果省略,则根据dtype选择默认值,如下表所示。

  • equal_nan (Union[bool, str]) – 如果True,两个NaN值将被视为相等。

  • check_device (bool) – 如果为True(默认),则断言相应的张量位于相同的 device。如果禁用此检查,则不同 device上的张量在比较之前会被移动到CPU上。

  • check_dtype (bool) – 如果 True(默认),断言相应的张量具有相同的 dtype。如果禁用此检查,则具有不同 dtype 的张量将被提升为共同的 dtype(根据 torch.promote_types())在进行比较之前。

  • check_layout (bool) – 如果 True(默认),断言相应的张量具有相同的 layout。如果禁用此检查,则在比较之前,具有不同 layout 的张量将转换为跨步张量。

  • check_stride (bool) – 如果 True 并且对应的张量是分步的,断言它们具有相同的步幅。

  • msg (可选[联合[str, 可调用[[str], str]]]) – 在比较过程中发生失败时使用的可选错误消息。也可以作为可调用对象传递,在这种情况下,它将被调用并生成消息,并应返回新的消息。

Raises

下表显示了不同dtype的默认rtolatol。如果dtype不匹配,则使用两者容差的最大值。

dtype

rtol

atol

float16

1e-3

1e-5

bfloat16

1.6e-2

1e-5

float32

1.3e-6

1e-5

float64

1e-7

1e-7

complex32

1e-3

1e-5

complex64

1.3e-6

1e-5

complex128

1e-7

1e-7

quint8

1.3e-6

1e-5

quint2x4

1.3e-6

1e-5

quint4x2

1.3e-6

1e-5

qint8

1.3e-6

1e-5

qint32

1.3e-6

1e-5

其他

0.0

0.0

注意

assert_close() 具有高度可配置性,默认设置严格。鼓励用户使用 partial() 来适应他们的使用场景。例如,如果需要进行相等性检查,可以定义一个 assert_equal,默认情况下对每个 dtype 使用零容差:

>>> import functools
>>> assert_equal = functools.partial(torch.testing.assert_close, rtol=0, atol=0)
>>> assert_equal(1e-9, 1e-10)
Traceback (most recent call last):
...
AssertionError: 标量不相等!

预期值为1e-10,但得到的是1e-09。
绝对差值:9.000000000000001e-10
相对差值:9.0

示例

>>> # 张量到张量的比较
>>> expected = torch.tensor([1e0, 1e-1, 1e-2])
>>> actual = torch.acos(torch.cos(expected))
>>> torch.testing.assert_close(actual, expected)
>>> # 标量到标量的比较
>>> import math
>>> expected = math.sqrt(2.0)
>>> actual = 2.0 / math.sqrt(2.0)
>>> torch.testing.assert_close(actual, expected)
>>> # numpy数组与numpy数组比较
>>> import numpy as np
>>> expected = np.array([1e0, 1e-1, 1e-2])
>>> actual = np.arccos(np.cos(expected))
>>> torch.testing.assert_close(actual, expected)
>>> # 序列到序列的比较
>>> import numpy as np
>>> # 序列的类型不必匹配。它们只需要具有相同的长度并且它们的元素需要匹配。
>>> expected = [torch.tensor([1.0]), 2.0, np.array(3.0)]
>>> actual = tuple(expected)
>>> torch.testing.assert_close(actual, expected)
>>> # 映射到映射比较
>>> from collections import OrderedDict
>>> import numpy as np
>>> foo = torch.tensor(1.0)
>>> bar = 2.0
>>> baz = np.array(3.0)
>>> # 映射的类型和可能的顺序不必匹配。它们只需要
>>> # 具有相同的键集,并且它们的元素必须匹配。
>>> expected = OrderedDict([("foo", foo), ("bar", bar), ("baz", baz)])
>>> actual = {"baz": baz, "bar": bar, "foo": foo}
>>> torch.testing.assert_close(actual, expected)
>>> expected = torch.tensor([1.0, 2.0, 3.0])
>>> actual = expected.clone()
>>> # 默认情况下,可以直接比较相关的实例
>>> torch.testing.assert_close(torch.nn.Parameter(actual), expected)
>>> # 可以通过设置 allow_subclasses=False 使检查更加严格
>>> torch.testing.assert_close(
...     torch.nn.Parameter(actual), expected, allow_subclasses=False
... )
Traceback (most recent call last):
...
TypeError: No comparison pair was able to handle inputs of type
 and .
>>> # 如果输入不是直接相关的,它们永远不会被认为是接近的
>>> torch.testing.assert_close(actual.numpy(), expected)
Traceback (most recent call last):
...
TypeError: No comparison pair was able to handle inputs of type 
and .
>>> # 这些规则的例外是 Python 标量。如果 check_dtype=False,可以检查它们的类型
>>> torch.testing.assert_close(1.0, 1, check_dtype=False)
>>> # 默认情况下,NaN != NaN。
>>> expected = torch.tensor(float("Nan"))
>>> actual = expected.clone()
>>> torch.testing.assert_close(actual, expected)
Traceback (most recent call last):
...
AssertionError: 标量不相等!

预期 nan 但得到 nan。
绝对差值: nan (允许的最大值为 1e-05)
相对差值: nan (允许的最大值为 1.3e-06)
>>> torch.testing.assert_close(actual, expected, equal_nan=True)
>>> expected = torch.tensor([1.0, 2.0, 3.0])
>>> actual = torch.tensor([1.0, 4.0, 5.0])
>>> # 默认的错误消息可以被覆盖。
>>> torch.testing.assert_close(actual, expected, msg="Argh, the tensors are not close!")
Traceback (most recent call last):
...
AssertionError: Argh, the tensors are not close!
>>> # 如果 msg 是一个可调用对象,它可以用来增强生成的消息,
>>> # 添加额外的信息
>>> torch.testing.assert_close(
...     actual, expected, msg=lambda msg: f"Header\n\n{msg}\n\nFooter"
... )
Traceback (most recent call last):
...
AssertionError: Header

Tensor-likes are not close!

Mismatched elements: 2 / 3 (66.7%)
Greatest absolute difference: 2.0 at index (1,) (up to 1e-05 allowed)
Greatest relative difference: 1.0 at index (1,) (up to 1.3e-06 allowed)

Footer
torch.testing.make_tensor(*shape, dtype, device, low=None, high=None, requires_grad=False, noncontiguous=False, exclude_zero=False, memory_format=None)[源代码]

创建一个具有给定形状设备数据类型的张量,并用从[低, 高)均匀抽取的值填充。

如果指定了lowhigh并且它们超出了dtype可表示的有限值范围,则它们将被限制为最低或最高可表示的有限值。 如果None,则下表描述了lowhigh的默认值, 这些值取决于dtype

dtype

布尔类型

0

2

无符号整数类型

0

10

有符号整数类型

-9

10

浮点类型

-9

9

复杂类型

-9

9

Parameters
  • 形状 (元组[整数, ...]) – 单个整数或整数序列,定义输出张量的形状。

  • dtype (torch.dtype) – 返回张量的数据类型。

  • 设备 (联合[字符串, torch.device]) – 返回张量的设备。

  • (可选[数字]) – 设置给定范围的下限(包含)。如果提供了一个数字,它将被限制为给定数据类型的最小可表示有限值。当 None(默认)时,此值根据 dtype 确定(见上表)。默认值:None

  • (可选[数字]) –

    设置给定范围的上限(不包括)。如果提供了一个数字,它将被限制为给定数据类型可表示的最大有限值。当 None(默认)时,此值根据 dtype 确定(见上表)。默认值:None

    自版本 2.1 起已弃用:low==high 传递给 make_tensor() 以用于浮点或复数类型自 2.1 起已弃用,并将在 2.3 中移除。请改用 torch.full()

  • requires_grad (可选[布尔值]) – 如果 autograd 应该记录对返回张量的操作。默认值:False

  • 非连续 (可选[布尔值]) – 如果为True,返回的张量将是不连续的。如果构建的张量元素少于两个,则忽略此参数。与memory_format互斥。

  • exclude_zero (可选[布尔值]) – 如果True,则零将被替换为数据类型的最小正数值,具体取决于dtype。对于布尔值和整数类型,零被替换为一。对于浮点类型,它被替换为数据类型的最小正正规数(dtypefinfo()对象的“tiny”值),对于复数类型,它被替换为一个复数,其实部和虚部都是复数类型可表示的最小正正规数。默认值为False

  • memory_format (可选[torch.memory_format]) – 返回张量的内存格式。与noncontiguous互斥。

Raises
  • ValueError – 如果为整数dtype传递了requires_grad=True

  • ValueError – 如果 low >= high

  • ValueError – 如果 lowhighnan

  • ValueError – 如果同时传递了noncontiguousmemory_format

  • TypeError – 如果 dtype 不被此函数支持。

Return type

张量

示例

>>> from torch.testing import make_tensor
>>> # 创建一个值在[-1, 1)之间的浮点张量
>>> make_tensor((3,), device='cpu', dtype=torch.float32, low=-1, high=1)
tensor([ 0.1205, 0.2282, -0.6380])
>>> # 在CUDA上创建一个布尔张量
>>> make_tensor((2, 2), device='cuda', dtype=torch.bool)
tensor([[False, False],
        [False, True]], device='cuda:0')
torch.testing.assert_allclose(actual, expected, rtol=None, atol=None, equal_nan=True, msg='')[源代码]

警告

torch.testing.assert_allclose()1.12 起已被弃用,并将在未来的版本中移除。 请改用 torch.testing.assert_close()。您可以在 这里 找到详细的升级说明。