jax.numpy.copy

目录

jax.numpy.copy#

jax.numpy.copy(a, order=None)[源代码][源代码]#

返回数组的副本。

JAX 实现的 numpy.copy()

参数:
  • a (ArrayLike) – 要复制的类数组对象

  • order (str | None) – 在 JAX 中未实现

返回:

输入数组 a 的一个副本。

返回类型:

Array

参见

示例

由于 JAX 数组是不可变的,在大多数情况下不需要显式复制数组。一个例外是在使用带有捐赠参数的函数时(参见 jax.jit()donate_argnums 参数)。

>>> f = jax.jit(lambda x: 2 * x, donate_argnums=0)
>>> x = jnp.arange(4)
>>> y = f(x)
>>> print(y)
[0 2 4 6]

因为我们标记了 x 为捐赠状态,原始数组不再可用:

>>> print(x)  
Traceback (most recent call last):
RuntimeError: Array has been deleted with shape=int32[4].

在这种情况下,显式复制将允许您保持对原始缓冲区的访问:

>>> x = jnp.arange(4)
>>> y = f(x.copy())
>>> print(y)
[0 2 4 6]
>>> print(x)
[0 1 2 3]