jax.numpy.arange

目录

jax.numpy.arange#

jax.numpy.arange(start, stop=None, step=None, dtype=None, *, device=None)[源代码][源代码]#

创建一个等间距值的数组。

JAX 实现 numpy.arange(),基于 jax.lax.iota() 实现。

类似于Python的 range() 函数,这个函数可以用几种不同的位置签名调用:

  • jnp.arange(stop): 生成从 0 到 stop 的值,步长为 1。

  • jnp.arange(start, stop): 从 startstop 生成值,步长为 1。

  • jnp.arange(start, stop, step): 从 startstop 生成值,步长为 step

与Python的 range() 函数类似,起始值是包含的,而停止值是排除的。

参数:
  • start (ArrayLike | DimSize) – 区间的开始,包含在内。

  • stop (ArrayLike | DimSize | None) – 区间的可选结束点,不包括在内。如果未指定,则 (start, stop) = (0, start)

  • step (ArrayLike | None) – 间隔的可选步长。默认 = 1。

  • dtype (DTypeLike | None) – 返回数组的可选 dtype;如果未指定,将通过 startstopstep 的类型提升来确定。

  • device (xc.Device | Sharding | None) – (可选) DeviceSharding ,创建的数组将被提交到该设备或分片。

返回:

startstop 的等间距值数组,由 step 分隔。

返回类型:

Array

备注

使用 arange 并带有浮点数 step 参数可能会导致由于浮点误差累积而产生意外结果,特别是在使用 float8_*bfloat16 等低精度数据类型时。为了避免精度错误,可以考虑生成一个整数范围,然后将其缩放到所需的范围。例如,不要这样做:

jnp.arange(-1, 1, 0.01, dtype='bfloat16')

生成一系列整数并对其进行缩放可能更为准确:

(jnp.arange(-100, 100) * 0.01).astype('bfloat16')

示例

单参数版本仅指定 stop 值:

>>> jnp.arange(4)
Array([0, 1, 2, 3], dtype=int32)

传递一个浮点数 stop 值会导致一个浮点数结果:

>>> jnp.arange(4.0)
Array([0., 1., 2., 3.], dtype=float32)

双参数版本指定 startstop,默认 step=1

>>> jnp.arange(1, 6)
Array([1, 2, 3, 4, 5], dtype=int32)

三参数版本指定 startstopstep

>>> jnp.arange(0, 2, 0.5)
Array([0. , 0.5, 1. , 1.5], dtype=float32)

参见