jax.numpy.bincount

目录

jax.numpy.bincount#

jax.numpy.bincount(x, weights=None, minlength=0, *, length=None)[源代码][源代码]#

计算整数数组中每个值的出现次数。

JAX 实现的 numpy.bincount()

对于一个正整数数组 x,该函数返回一个大小为 x.max() + 1 的数组 counts,使得 counts[i] 包含数组 x 中值 i 的出现次数。

JAX 版本与 NumPy 版本有一些不同之处:

  • 在NumPy中,传递包含负值的数组 x 会导致错误。在JAX中,负值会被裁剪为零。

  • JAX 添加了一个可选的 length 参数,该参数可以用于静态指定输出数组的长度,以便此函数可以与 jax.jit() 等转换一起使用。在这种情况下,大于 length + 1 的项将被丢弃。

参数:
  • x (ArrayLike) – 正整数的N维数组

  • weights (ArrayLike | None) – 与 x 相关联的权重可选数组。如果未指定,每个条目的权重将为 1

  • minlength (int) – 输出计数数组的最小长度。

  • length (int | None) – 输出计数数组的长度。必须为 bincount 指定静态长度,以便与 jax.jit() 和其他 JAX 变换一起使用。

返回:

反映 x 中值出现次数的计数或加权和的数组。

返回类型:

Array

示例

基本 bincount:

>>> x = jnp.array([1, 1, 2, 3, 3, 3])
>>> jnp.bincount(x)
Array([0, 2, 1, 3], dtype=int32)

加权计数:

>>> weights = jnp.array([1, 2, 3, 4, 5, 6])
>>> jnp.bincount(x, weights)
Array([ 0,  3,  3, 15], dtype=int32)

指定一个静态的 length 使得这个与即时编译兼容:

>>> jit_bincount = jax.jit(jnp.bincount, static_argnames=['length'])
>>> jit_bincount(x, length=5)
Array([0, 2, 1, 3, 0], dtype=int32)

任何负数都会被截断到第一个区间,超出指定 length 的数会被丢弃:

>>> x = jnp.array([-1, -1, 1, 3, 10])
>>> jnp.bincount(x, length=5)
Array([2, 1, 0, 1, 0], dtype=int32)