jax.numpy.searchsorted

jax.numpy.searchsorted#

jax.numpy.searchsorted(a, v, side='left', sorter=None, *, method='scan')[源代码][源代码]#

在一个已排序的数组中执行二分查找。

JAX 实现的 numpy.searchsorted()

这将返回在已排序数组 a 中,v 中的值可以插入以保持其排序顺序的索引。

参数:
  • a (ArrayLike) – 一维数组,除非指定 sorter ,否则假定为已排序。

  • v (ArrayLike) – 查询值的N维数组

  • side (str) – 'left' (默认) 或 'right'; 指定在出现平局时插入索引是在左边还是右边。

  • sorter (ArrayLike | None) – 指定排序顺序的索引可选数组 a 。如果指定,则算法假设 a[sorter] 已按排序顺序排列。

  • method (str) – 'scan' (默认), 'scan_unrolled', 'sort''compare_all' 之一。见下文 注释

返回:

形状为 v.shape 的插入索引数组。

返回类型:

Array

备注

method 参数控制用于计算插入索引的算法。

  • 'scan'``(默认)在CPU上往往性能更高,特别是在 ``a 非常大时。

  • 'scan_unrolled' 在牺牲额外编译时间的情况下,在GPU上性能更佳。

  • 'sort' 在 GPU 和 TPU 等加速器后端上通常更具性能,尤其是在 v 非常大的情况下。

  • 'compare_all'a 非常小时,往往表现最佳。

示例

搜索单个值:

>>> a = jnp.array([1, 2, 2, 3, 4, 5, 5])
>>> jnp.searchsorted(a, 2)
Array(1, dtype=int32)
>>> jnp.searchsorted(a, 2, side='right')
Array(3, dtype=int32)

搜索一批值:

>>> vals = jnp.array([0, 3, 8, 1.5, 2])
>>> jnp.searchsorted(a, vals)
Array([0, 3, 7, 1, 1], dtype=int32)

可选地,可以使用 sorter 参数来找到插入到通过 jax.numpy.argsort() 排序的数组中的索引:

>>> a = jnp.array([4, 3, 5, 1, 2])
>>> sorter = jnp.argsort(a)
>>> jnp.searchsorted(a, vals, sorter=sorter)
Array([0, 2, 5, 1, 1], dtype=int32)

结果等同于传递排序后的数组:

>>> jnp.searchsorted(jnp.sort(a), vals)
Array([0, 2, 5, 1, 1], dtype=int32)