Shortcuts

torch.testing._creation 的源代码

"""
此模块包含张量创建工具。
"""

import collections.abc
import math
import warnings
from typing import cast, List, Optional, Tuple, Union

import torch

_INTEGRAL_TYPES = [
    torch.uint8,
    torch.int8,
    torch.int16,
    torch.int32,
    torch.int64,
    torch.uint16,
    torch.uint32,
    torch.uint64,
]
_FLOATING_TYPES = [torch.float16, torch.bfloat16, torch.float32, torch.float64]
_FLOATING_8BIT_TYPES = [
    torch.float8_e4m3fn,
    torch.float8_e5m2,
    torch.float8_e4m3fnuz,
    torch.float8_e5m2fnuz,
]
_COMPLEX_TYPES = [torch.complex32, torch.complex64, torch.complex128]
_BOOLEAN_OR_INTEGRAL_TYPES = [torch.bool, *_INTEGRAL_TYPES]
_FLOATING_OR_COMPLEX_TYPES = [*_FLOATING_TYPES, *_COMPLEX_TYPES]


def _uniform_random_(t: torch.Tensor, low: float, high: float) -> torch.Tensor:
    # uniform_ 要求 to-from <= std::numeric_limits::max()
    # 通过在PRNG前后缩放范围来解决这个问题
    if high - low >= torch.finfo(t.dtype).max:
        return t.uniform_(low / 2, high / 2).mul_(2)
    else:
        return t.uniform_(low, high)


[docs]def make_tensor( *shape: Union[int, torch.Size, List[int], Tuple[int, ...]], dtype: torch.dtype, device: Union[str, torch.device], low: Optional[float] = None, high: Optional[float] = None, requires_grad: bool = False, noncontiguous: bool = False, exclude_zero: bool = False, memory_format: Optional[torch.memory_format] = None, ) -> torch.Tensor: r"""创建一个具有给定 :attr:`shape`、:attr:`device` 和 :attr:`dtype` 的张量,并填充从 ``[low, high)`` 均匀绘制的值。 如果 :attr:`low` 或 :attr:`high` 被指定并且超出了 :attr:`dtype` 的可表示有限值范围,则它们将被限制为最低或最高的可表示有限值。 如果为 ``None``,则下表描述了 :attr:`low` 和 :attr:`high` 的默认值,这些值取决于 :attr:`dtype`。 +---------------------------+------------+----------+ | ``dtype`` | ``low`` | ``high`` | +===========================+============+==========+ | 布尔类型 | ``0`` | ``2`` | +---------------------------+------------+----------+ | 无符号整数类型 | ``0`` | ``10`` | +---------------------------+------------+----------+ | 有符号整数类型 | ``-9`` | ``10`` | +---------------------------+------------+----------+ | 浮点类型 | ``-9`` | ``9`` | +---------------------------+------------+----------+ | 复数类型 | ``-9`` | ``9`` | +---------------------------+------------+----------+ 参数: shape (Tuple[int, ...]): 单个整数或定义输出张量形状的整数序列。 dtype (:class:`torch.dtype`): 返回张量的数据类型。 device (Union[str, torch.device]): 返回张量的设备。 low (Optional[Number]): 设置给定范围的下限(包含)。如果提供了数字,则将其限制为给定 dtype 的最小可表示有限值。当 ``None``(默认)时,此值根据 :attr:`dtype` 确定(见上表)。默认: ``None``。 high (Optional[Number]): 设置给定范围的上限(不包含)。如果提供了数字,则将其限制为给定 dtype 的最大可表示有限值。当 ``None``(默认)时,此值根据 :attr:`dtype` 确定(见上表)。默认: ``None``。 .. deprecated:: 2.1 自 2.1 起,为浮点或复数类型传递 ``low==high`` 给 :func:`~torch.testing.make_tensor` 已被弃用,并将在 2.3 中移除。请改用 :func:`torch.full`。 requires_grad (Optional[bool]): 如果自动求导应该记录返回张量上的操作。默认: ``False``。 noncontiguous (Optional[bool]): 如果为 `True`,则返回的张量将是不连续的。如果构造的张量少于两个元素,则忽略此参数。与 ``memory_format`` 互斥。 exclude_zero (Optional[bool]): 如果为 ``True``,则零将被替换为根据 :attr:`dtype` 的小正值。对于布尔和整数类型,零被替换为一。对于浮点类型,它被替换为 dtype 的最小正正规数(:attr:`dtype` 的 :func:`~torch.finfo` 对象的“tiny”值),对于复数类型,它被替换为实部和虚部均为最小正正规数的复数。默认 ``False``。 memory_format (Optional[torch.memory_format]): 返回张量的内存格式。与 ``noncontiguous`` 互斥。 引发: ValueError: 如果为整数 `dtype` 传递 ``requires_grad=True`` ValueError: 如果 ``low >= high``。 ValueError: 如果 :attr:`low` 或 :attr:`high` 为 ``nan``。 ValueError: 如果同时传递了 :attr:`noncontiguous` 和 :attr:`memory_format`。 TypeError: 如果 :attr:`dtype` 不受此函数支持。 示例: