jax.numpy.roll

目录

jax.numpy.roll#

jax.numpy.roll(a, shift, axis=None)[源代码][源代码]#

沿指定轴滚动数组的元素。

JAX 实现的 numpy.roll()

参数:
  • a (ArrayLike) – 输入数组。

  • shift (ArrayLike | Sequence[int]) – 指定轴的移动位置数。如果是一个整数,所有轴都按相同数量移动。如果是一个元组,每个轴的移动量分别指定。

  • axis (int | Sequence[int] | None) – 要滚动的轴或轴。如果为 None ,则数组被展平、移位,然后重新整形为其原始形状。

返回:

a 的一个副本,其元素沿着指定的轴或轴滚动。

返回类型:

Array

参见

示例

>>> a = jnp.array([0, 1, 2, 3, 4, 5])
>>> jnp.roll(a, 2)
Array([4, 5, 0, 1, 2, 3], dtype=int32)

沿特定轴滚动元素:

>>> a = jnp.array([[ 0,  1,  2,  3],
...                [ 4,  5,  6,  7],
...                [ 8,  9, 10, 11]])
>>> jnp.roll(a, 1, axis=0)
Array([[ 8,  9, 10, 11],
       [ 0,  1,  2,  3],
       [ 4,  5,  6,  7]], dtype=int32)
>>> jnp.roll(a, [2, 3], axis=[0, 1])
Array([[ 5,  6,  7,  4],
       [ 9, 10, 11,  8],
       [ 1,  2,  3,  0]], dtype=int32)