torch.random 的源代码
import contextlib
from typing import Generator
import warnings
from torch._C import default_generator
import torch
[docs]def set_rng_state(new_state: torch.Tensor) -> None:
r"""设置随机数生成器状态。
.. 注意: 此函数仅适用于CPU。对于CUDA,请使用
torch.manual_seed(seed),它适用于CPU和CUDA。
参数:
new_state (torch.ByteTensor): 所需的状态
"""
default_generator.set_state(new_state)
[docs]def get_rng_state() -> torch.Tensor:
r"""返回随机数生成器状态作为 `torch.ByteTensor`。"""
return default_generator.get_state()
[docs]def manual_seed(seed) -> torch._C.Generator:
r"""设置生成随机数的种子。返回一个
`torch.Generator` 对象。
参数:
seed (int): 所需的种子。值必须在包含范围内
`[-0x8000_0000_0000_0000, 0xffff_ffff_ffff_ffff]`。否则,会引发RuntimeError。
负输入通过公式 `0xffff_ffff_ffff_ffff + seed` 重新映射为正数。
"""
seed = int(seed)
import torch.cuda
if not torch.cuda._is_in_bad_fork():
torch.cuda.manual_seed_all(seed)
import torch.mps
if not torch.mps._is_in_bad_fork():
torch.mps.manual_seed(seed)
import torch.xpu
if not torch.xpu._is_in_bad_fork():
torch.xpu.manual_seed_all(seed)
_seed_custom_device(seed)
return default_generator.manual_seed(seed)
[docs]def seed() -> int:
r"""将生成随机数的种子设置为一个非确定性的
随机数。返回用于种子RNG的64位数字。
"""
seed = default_generator.seed()
import torch.cuda
if not torch.cuda._is_in_bad_fork():
torch.cuda.manual_seed_all(seed)
import torch.mps
if not torch.mps._is_in_bad_fork():
torch.mps.manual_seed(seed)
import torch.xpu
if not torch.xpu._is_in_bad_fork():
torch.xpu.manual_seed_all(seed)
_seed_custom_device(seed)
return seed
def _seed_custom_device(seed) -> None:
r"""为自定义设备设置生成随机数的种子。
参数:
seed (int): 所需的种子。
参见 [注意: 支持带有privateuse1的自定义设备]
"""
seed = int(seed)
custom_backend_name = torch._C._get_privateuse1_backend_name()
if hasattr(torch, custom_backend_name):
custom_device_mod = getattr(torch, custom_backend_name)
_bad_fork_name = "_is_in_bad_fork"
_seed_all_name = "manual_seed_all"
if hasattr(custom_device_mod, _bad_fork_name) and hasattr(custom_device_mod, _seed_all_name):
if not getattr(custom_device_mod, _bad_fork_name)():
getattr(custom_device_mod, _seed_all_name)(seed)
else:
message = f"为 `{custom_backend_name}` 设备设置种子无效,请添加API的 "
message += f"`{_bad_fork_name}` 和 `{_seed_all_name}` 到 `{custom_backend_name}` 设备模块。"
warnings.warn(message, UserWarning, stacklevel=3)
[docs]def initial_seed() -> int:
r"""返回生成随机数的初始种子作为
Python `long`。
"""
return default_generator.initial_seed()
_fork_rng_warned_already