Shortcuts

torch.quasirandom 的源代码

```html
import torch
from typing import Optional


[docs]class SobolEngine: r""" :class:`torch.quasirandom.SobolEngine` 是一个用于生成(打乱的)Sobol序列的引擎。Sobol序列是低差异准随机序列的一个例子。 这个Sobol序列引擎的实现能够生成高达21201维的序列。它使用来自 https://web.maths.unsw.edu.au/~fkuo/sobol/ 的方向数,使用搜索标准D(6)直到21201维。这是作者推荐的选项。 参考文献: - Art B. Owen. 打乱Sobol和Niederreiter-Xing点。 复杂性杂志,14(4):466-489,1998年12月。 - I. M. Sobol. 立方体中点的分布及积分的精确评估。 计算数学与数学物理杂志,7:784-802,1967年。 参数: dimension (Int): 要绘制的序列的维度 scramble (bool, 可选): 设置为 ``True`` 将生成打乱的Sobol序列。打乱能够生成更好的Sobol序列。默认值:``False``。 seed (Int, 可选): 这是打乱的种子。如果指定,随机数生成器的种子将设置为此值。否则,它使用随机种子。默认值:``None`` 示例:: >>> # xdoctest: +SKIP("unseeded random state") >>> soboleng = torch.quasirandom.SobolEngine(dimension=5) >>> soboleng.draw(3) tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.5000, 0.5000, 0.5000, 0.5000, 0.5000], [0.7500, 0.2500, 0.2500, 0.2500, 0.7500]]) """ MAXBIT = 30 MAXDIM = 21201 def __init__(self, dimension, scramble=False, seed=None): if dimension > self.MAXDIM or dimension < 1: raise ValueError("支持的SobolEngine维度范围是 [1, " f"{self.MAXDIM}]") self.seed = seed self.scramble = scramble self.dimension = dimension cpu = torch.device("cpu") self.sobolstate = torch.zeros(dimension, self.MAXBIT, device=cpu, dtype=torch.long) torch._sobol_engine_initialize_state_(self.sobolstate, self.dimension) if not self.scramble: self.shift = torch.zeros(self.dimension, device=cpu, dtype=torch.long) else: self._scramble() self.quasi = self.shift.clone(memory_format=torch.contiguous_format) self._first_point = (self.quasi / 2 ** self.MAXBIT).reshape(1, -1) self.num_generated = 0
[docs] def draw(self, n: int = 1, out: Optional[torch.Tensor] = None, dtype: torch.dtype = torch.float32) -> torch.Tensor: r""" 从Sobol序列中绘制一个长度为 :attr:`n` 的点序列的函数。请注意,样本依赖于之前的样本。结果的大小为 :math:`(n, dimension)`。 参数: n (Int, 可选): 要绘制的点序列的长度。默认值:1 out (Tensor, 可选): 输出张量 dtype (:class:`torch.dtype`, 可选): 返回张量的期望数据类型。默认值:``torch.float32`` """ if self.num_generated == 0: if n == 1: result = self._first_point.to(dtype) else: result, self.quasi = torch._sobol_engine_draw( self.quasi, n - 1, self.sobolstate, self.dimension, self.num_generated, dtype=dtype, ) result = torch.cat((self._first_point, result), dim=-2) else: result, self.quasi = torch._sobol_engine_draw( self.quasi, n, self.sobolstate, self.dimension, self.num_generated - 1, dtype=<span class="n
优云智算