Shortcuts

torch.xpu.random 的源代码

from typing import Iterable, List, Union

import torch
from .. import Tensor
from . import _lazy_call, _lazy_init, current_device, device_count


[docs]def get_rng_state(device: Union[int, str, torch.device] = "xpu") -> Tensor: r"""返回指定GPU的随机数生成器状态,以ByteTensor形式表示。 参数: device (torch.device 或 int, 可选): 要返回RNG状态的设备。 默认值: ``'xpu'`` (即 ``torch.device('xpu')``, 当前的XPU设备)。 .. 警告:: 此函数会急切地初始化XPU。 """ _lazy_init() if isinstance(device, str): device = torch.device(device) elif isinstance(device, int): device = torch.device("xpu", device) idx = device.index if idx is None: idx = current_device() default_generator = torch.xpu.default_generators[idx] return default_generator.get_state()
[docs]def get_rng_state_all() -> List[Tensor]: r"""返回一个ByteTensor列表,表示所有设备的随机数状态。""" results = [] for i in range(device_count()): results.append(get_rng_state(i)) return results
[docs]def set_rng_state( new_state: Tensor, device: Union[int, str, torch.device] = "xpu" ) -> None: r"""设置指定GPU的随机数生成器状态。 参数: new_state (torch.ByteTensor): 所需的状态 device (torch.device 或 int, 可选): 要设置RNG状态的设备。 默认值: ``'xpu'`` (即 ``torch.device('xpu')``, 当前的XPU设备)。 """ with torch._C._DisableFuncTorch(): new_state_copy = new_state.clone(memory_format=torch.contiguous_format) if isinstance(device, str): device = torch.device(device) elif isinstance(device, int): device = torch.device("xpu", device) def cb(): idx = device.index if idx is None: idx = current_device() default_generator = torch.xpu.default_generators[idx] default_generator.set_state(new_state_copy) _lazy_call(cb)
[docs]def set_rng_state_all(new_states: Iterable[Tensor]) -> None: r"""设置所有设备的随机数生成器状态。 参数: new_states (Iterable of torch.ByteTensor): 每个设备的所需状态。 """ for i, state in enumerate(new_states): set_rng_state(state, i)
[docs]def manual_seed(seed: int) -> None: r"""为当前GPU设置生成随机数的种子。 如果XPU不可用,调用此函数是安全的;在这种情况下,它会被静默忽略。 参数: seed (int): 所需的种子。 .. 警告:: 如果您正在使用多GPU模型,此函数不足以获得确定性。要为所有GPU设置种子,请使用 :func:`manual_seed_all`。 """ seed = int(seed) def cb(): idx = current_device() default_generator = torch.xpu.default_generators[idx] default_generator.manual_seed(seed) _lazy_call(cb, seed=True)
[docs]def manual_seed_all(seed<span class="p
优云智算