jax.numpy.searchsorted#
- jax.numpy.searchsorted(a, v, side='left', sorter=None, *, method='scan')[源代码][源代码]#
在一个已排序的数组中执行二分查找。
JAX 实现的
numpy.searchsorted()
。这将返回在已排序数组
a
中,v
中的值可以插入以保持其排序顺序的索引。- 参数:
- 返回:
形状为
v.shape
的插入索引数组。- 返回类型:
备注
method
参数控制用于计算插入索引的算法。'scan'``(默认)在CPU上往往性能更高,特别是在 ``a
非常大时。'scan_unrolled'
在牺牲额外编译时间的情况下,在GPU上性能更佳。'sort'
在 GPU 和 TPU 等加速器后端上通常更具性能,尤其是在v
非常大的情况下。'compare_all'
在a
非常小时,往往表现最佳。
示例
搜索单个值:
>>> a = jnp.array([1, 2, 2, 3, 4, 5, 5]) >>> jnp.searchsorted(a, 2) Array(1, dtype=int32) >>> jnp.searchsorted(a, 2, side='right') Array(3, dtype=int32)
搜索一批值:
>>> vals = jnp.array([0, 3, 8, 1.5, 2]) >>> jnp.searchsorted(a, vals) Array([0, 3, 7, 1, 1], dtype=int32)
可选地,可以使用
sorter
参数来找到插入到通过jax.numpy.argsort()
排序的数组中的索引:>>> a = jnp.array([4, 3, 5, 1, 2]) >>> sorter = jnp.argsort(a) >>> jnp.searchsorted(a, vals, sorter=sorter) Array([0, 2, 5, 1, 1], dtype=int32)
结果等同于传递排序后的数组:
>>> jnp.searchsorted(jnp.sort(a), vals) Array([0, 2, 5, 1, 1], dtype=int32)