Shortcuts

torch.random 的源代码

import contextlib
from typing import Generator
import warnings

from torch._C import default_generator
import torch


[docs]def set_rng_state(new_state: torch.Tensor) -> None: r"""设置随机数生成器状态。 .. 注意: 此函数仅适用于CPU。对于CUDA,请使用 torch.manual_seed(seed),它适用于CPU和CUDA。 参数: new_state (torch.ByteTensor): 所需的状态 """ default_generator.set_state(new_state)
[docs]def get_rng_state() -> torch.Tensor: r"""返回随机数生成器状态作为 `torch.ByteTensor`。""" return default_generator.get_state()
[docs]def manual_seed(seed) -> torch._C.Generator: r"""设置生成随机数的种子。返回一个 `torch.Generator` 对象。 参数: seed (int): 所需的种子。值必须在包含范围内 `[-0x8000_0000_0000_0000, 0xffff_ffff_ffff_ffff]`。否则,会引发RuntimeError。 负输入通过公式 `0xffff_ffff_ffff_ffff + seed` 重新映射为正数。 """ seed = int(seed) import torch.cuda if not torch.cuda._is_in_bad_fork(): torch.cuda.manual_seed_all(seed) import torch.mps if not torch.mps._is_in_bad_fork(): torch.mps.manual_seed(seed) import torch.xpu if not torch.xpu._is_in_bad_fork(): torch.xpu.manual_seed_all(seed) _seed_custom_device(seed) return default_generator.manual_seed(seed)
[docs]def seed() -> int: r"""将生成随机数的种子设置为一个非确定性的 随机数。返回用于种子RNG的64位数字。 """ seed = default_generator.seed() import torch.cuda if not torch.cuda._is_in_bad_fork(): torch.cuda.manual_seed_all(seed) import torch.mps if not torch.mps._is_in_bad_fork(): torch.mps.manual_seed(seed) import torch.xpu if not torch.xpu._is_in_bad_fork(): torch.xpu.manual_seed_all(seed) _seed_custom_device(seed) return seed
def _seed_custom_device(seed) -> None: r"""为自定义设备设置生成随机数的种子。 参数: seed (int): 所需的种子。 参见 [注意: 支持带有privateuse1的自定义设备] """ seed = int(seed) custom_backend_name = torch._C._get_privateuse1_backend_name() if hasattr(torch, custom_backend_name): custom_device_mod = getattr(torch, custom_backend_name) _bad_fork_name = "_is_in_bad_fork" _seed_all_name = "manual_seed_all" if hasattr(custom_device_mod, _bad_fork_name) and hasattr(custom_device_mod, _seed_all_name): if not getattr(custom_device_mod, _bad_fork_name)(): getattr(custom_device_mod, _seed_all_name)(seed) else: message = f"为 `{custom_backend_name}` 设备设置种子无效,请添加API的 " message += f"`{_bad_fork_name}` 和 `{_seed_all_name}` 到 `{custom_backend_name}` 设备模块。" warnings.warn(message, UserWarning, stacklevel=3)
[docs]def initial_seed() -> int: r"""返回生成随机数的初始种子作为 Python `long`。 """ return default_generator.initial_seed()
_fork_rng_warned_already
优云智算