torch.cuda.random 的源代码
from typing import Iterable, List, Union
import torch
from .. import Tensor
from . import _lazy_call, _lazy_init, current_device, device_count
__all__ = [
"get_rng_state",
"get_rng_state_all",
"set_rng_state",
"set_rng_state_all",
"manual_seed",
"manual_seed_all",
"seed",
"seed_all",
"initial_seed",
]
[docs]def get_rng_state(device: Union[int, str, torch.device] = "cuda") -> Tensor:
r"""返回指定GPU的随机数生成器状态,以ByteTensor形式表示。
参数:
device (torch.device 或 int, 可选): 要返回RNG状态的设备。
默认值: ``'cuda'`` (即 ``torch.device('cuda')``, 当前CUDA设备)。
.. 警告::
此函数会急切地初始化CUDA。
"""
_lazy_init()
if isinstance(device, str):
device = torch.device(device)
elif isinstance(device, int):
device = torch.device("cuda", device)
idx = device.index
if idx is None:
idx = current_device()
default_generator = torch.cuda.default_generators[idx]
return default_generator.get_state()
[docs]def get_rng_state_all() -> List[Tensor]:
r"""返回一个ByteTensor列表,表示所有设备的随机数状态。"""
results = []
for i in range(device_count()):
results.append(get_rng_state(i))
return results
[docs]def set_rng_state(
new_state: Tensor, device: Union[int, str, torch.device] = "cuda"
) -> None:
r"""设置指定GPU的随机数生成器状态。
参数:
new_state (torch.ByteTensor): 所需的状态
device (torch.device 或 int, 可选): 要设置RNG状态的设备。
默认值: ``'cuda'`` (即 ``torch.device('cuda')``, 当前CUDA设备)。
"""
with torch._C._DisableFuncTorch():
new_state_copy = new_state.clone(memory_format=torch.contiguous_format)
if isinstance(device, str):
device = torch.device(device)
elif isinstance(device, int):
device = torch.device("cuda", device)
def cb():
idx = device.index
if idx is None:
idx = current_device()
default_generator = torch.cuda.default_generators[idx]
default_generator.set_state(new_state_copy)
_lazy_call(cb)
[docs]def set_rng_state_all(new_states: Iterable[Tensor]) -> None:
r"""设置所有设备的随机数生成器状态。
参数:
new_states (Iterable of torch.ByteTensor): 每个设备的所需状态。
"""
for i, state in enumerate(new_states):
set_rng_state(state, i)
[docs]def manual_seed(seed: int) -> None:
r"""为当前GPU设置生成随机数的种子。
如果CUDA不可用,调用此函数是安全的;在这种情况下,它会被静默忽略。
参数:
seed (int): 所需的种子。
.. 警告::
如果您正在使用多GPU模型,此函数不足以获得确定性。要为所有GPU设置种子,请使用 :func:`manual_seed_all`。
"""
seed = int(seed)
def cb():
idx = current_device()
default_generator = torch.cuda.default_generators[idx]
<span