随机

随机#

MLX中的随机抽样函数默认使用隐式的全局PRNG状态。然而,所有函数都接受一个可选的key关键字参数,以便在需要更细粒度的控制或显式状态管理时使用。

例如,您可以使用以下方式生成随机数:

for _ in range(3):
  print(mx.random.uniform())

这将打印一系列唯一的伪随机数。或者,您可以显式设置密钥:

key = mx.random.key(0)
for _ in range(3):
  print(mx.random.uniform(key=key))

在每次迭代时都会产生相同的伪随机数。

遵循JAX的PRNG设计,我们使用了Threefry的可分割版本,这是一种基于计数器的PRNG。

bernoulli([p, shape, key, stream])

生成伯努利随机值。

categorical(logits[, axis, shape, ...])

从分类分布中采样。

gumbel([shape, dtype, key, stream])

从标准Gumbel分布中采样。

key(seed)

从种子中获取一个伪随机数生成器(PRNG)密钥。

normal([shape, dtype, loc, scale, key, stream])

生成正态分布的随机数。

multivariate_normal(mean, cov[, shape, ...])

给定均值和协方差生成联合正态随机样本。

randint(low, high[, shape, dtype, key, stream])

从给定区间生成随机整数。

seed(seed)

为全局伪随机数生成器设置种子。

split(key[, num, stream])

将PRNG密钥拆分为子密钥。

truncated_normal(lower, upper[, shape, ...])

从截断正态分布生成值。

uniform([low, high, shape, dtype, key, stream])

生成均匀分布的随机数。

laplace([shape, dtype, loc, scale, key, stream])

从拉普拉斯分布中抽取样本。

permutation(x[, axis, key, stream])

生成一个随机排列或对数组的条目进行排列。