torch.testing._creation 的源代码
"""
此模块包含张量创建工具。
"""
import collections.abc
import math
import warnings
from typing import cast, List, Optional, Tuple, Union
import torch
_INTEGRAL_TYPES = [
torch.uint8,
torch.int8,
torch.int16,
torch.int32,
torch.int64,
torch.uint16,
torch.uint32,
torch.uint64,
]
_FLOATING_TYPES = [torch.float16, torch.bfloat16, torch.float32, torch.float64]
_FLOATING_8BIT_TYPES = [
torch.float8_e4m3fn,
torch.float8_e5m2,
torch.float8_e4m3fnuz,
torch.float8_e5m2fnuz,
]
_COMPLEX_TYPES = [torch.complex32, torch.complex64, torch.complex128]
_BOOLEAN_OR_INTEGRAL_TYPES = [torch.bool, *_INTEGRAL_TYPES]
_FLOATING_OR_COMPLEX_TYPES = [*_FLOATING_TYPES, *_COMPLEX_TYPES]
def _uniform_random_(t: torch.Tensor, low: float, high: float) -> torch.Tensor:
# uniform_ 要求 to-from <= std::numeric_limits::max()
# 通过在PRNG前后缩放范围来解决这个问题
if high - low >= torch.finfo(t.dtype).max:
return t.uniform_(low / 2, high / 2).mul_(2)
else:
return t.uniform_(low, high)
[docs]def make_tensor(
*shape: Union[int, torch.Size, List[int], Tuple[int, ...]],
dtype: torch.dtype,
device: Union[str, torch.device],
low: Optional[float] = None,
high: Optional[float] = None,
requires_grad: bool = False,
noncontiguous: bool = False,
exclude_zero: bool = False,
memory_format: Optional[torch.memory_format] = None,
) -> torch.Tensor:
r"""创建一个具有给定 :attr:`shape`、:attr:`device` 和 :attr:`dtype` 的张量,并填充从 ``[low, high)`` 均匀绘制的值。
如果 :attr:`low` 或 :attr:`high` 被指定并且超出了 :attr:`dtype` 的可表示有限值范围,则它们将被限制为最低或最高的可表示有限值。
如果为 ``None``,则下表描述了 :attr:`low` 和 :attr:`high` 的默认值,这些值取决于 :attr:`dtype`。
+---------------------------+------------+----------+
| ``dtype`` | ``low`` | ``high`` |
+===========================+============+==========+
| 布尔类型 | ``0`` | ``2`` |
+---------------------------+------------+----------+
| 无符号整数类型 | ``0`` | ``10`` |
+---------------------------+------------+----------+
| 有符号整数类型 | ``-9`` | ``10`` |
+---------------------------+------------+----------+
| 浮点类型 | ``-9`` | ``9`` |
+---------------------------+------------+----------+
| 复数类型 | ``-9`` | ``9`` |
+---------------------------+------------+----------+
参数:
shape (Tuple[int, ...]): 单个整数或定义输出张量形状的整数序列。
dtype (:class:`torch.dtype`): 返回张量的数据类型。
device (Union[str, torch.device]): 返回张量的设备。
low (Optional[Number]): 设置给定范围的下限(包含)。如果提供了数字,则将其限制为给定 dtype 的最小可表示有限值。当 ``None``(默认)时,此值根据 :attr:`dtype` 确定(见上表)。默认: ``None``。
high (Optional[Number]): 设置给定范围的上限(不包含)。如果提供了数字,则将其限制为给定 dtype 的最大可表示有限值。当 ``None``(默认)时,此值根据 :attr:`dtype` 确定(见上表)。默认: ``None``。
.. deprecated:: 2.1
自 2.1 起,为浮点或复数类型传递 ``low==high`` 给 :func:`~torch.testing.make_tensor` 已被弃用,并将在 2.3 中移除。请改用 :func:`torch.full`。
requires_grad (Optional[bool]): 如果自动求导应该记录返回张量上的操作。默认: ``False``。
noncontiguous (Optional[bool]): 如果为 `True`,则返回的张量将是不连续的。如果构造的张量少于两个元素,则忽略此参数。与 ``memory_format`` 互斥。
exclude_zero (Optional[bool]): 如果为 ``True``,则零将被替换为根据 :attr:`dtype` 的小正值。对于布尔和整数类型,零被替换为一。对于浮点类型,它被替换为 dtype 的最小正正规数(:attr:`dtype` 的 :func:`~torch.finfo` 对象的“tiny”值),对于复数类型,它被替换为实部和虚部均为最小正正规数的复数。默认 ``False``。
memory_format (Optional[torch.memory_format]): 返回张量的内存格式。与 ``noncontiguous`` 互斥。
引发:
ValueError: 如果为整数 `dtype` 传递 ``requires_grad=True``
ValueError: 如果 ``low >= high``。
ValueError: 如果 :attr:`low` 或 :attr:`high` 为 ``nan``。
ValueError: 如果同时传递了 :attr:`noncontiguous` 和 :attr:`memory_format`。
TypeError: 如果 :attr:`dtype` 不受此函数支持。
示例: