numpy.take_along_axis#
- numpy.take_along_axis(arr, indices, axis)[源代码]#
通过匹配一维索引和数据切片从输入数组中取值.
这会遍历索引和数据数组中沿指定轴方向匹配的1维切片,并使用前者在后者中查找值.这些切片可以有不同的长度.
像
argsort和argpartition这样返回沿轴索引的函数,为这个函数生成合适的索引.在 1.15.0 版本加入.
- 参数:
- arrndarray (Ni…, M, Nk…)
源数组
- indicesndarray (Ni…, J, Nk…)
沿 arr 每个一维切片的索引.这必须与 arr 的维度匹配,但维度 Ni 和 Nj 只需要与 arr 进行广播.
- axisint
要沿其进行1d切片的方向.如果axis为None,则输入数组被视为首先被展平为1d,以与`sort`和`argsort`保持一致.
- 返回:
- out: ndarray (Ni…, J, Nk…)
索引结果.
参见
take沿轴进行操作,对每个1维切片使用相同的索引
put_along_axis通过匹配一维索引和数据切片将值放入目标数组
备注
这相当于(但比以下使用
ndindex和s_更快),后者将ii和kk分别设置为索引元组:Ni, M, Nk = a.shape[:axis], a.shape[axis], a.shape[axis+1:] J = indices.shape[axis] # Need not equal M out = np.empty(Ni + (J,) + Nk) for ii in ndindex(Ni): for kk in ndindex(Nk): a_1d = a [ii + s_[:,] + kk] indices_1d = indices[ii + s_[:,] + kk] out_1d = out [ii + s_[:,] + kk] for j in range(J): out_1d[j] = a_1d[indices_1d[j]]
等效地,消除内部循环,最后两行将是:
out_1d[:] = a_1d[indices_1d]
示例
>>> import numpy as np
对于这个示例数组
>>> a = np.array([[10, 30, 20], [60, 40, 50]])
我们可以通过直接使用排序,或者使用 argsort 和这个函数来进行排序.
>>> np.sort(a, axis=1) array([[10, 20, 30], [40, 50, 60]]) >>> ai = np.argsort(a, axis=1) >>> ai array([[0, 2, 1], [1, 2, 0]]) >>> np.take_along_axis(a, ai, axis=1) array([[10, 20, 30], [40, 50, 60]])
对于 max 和 min 也是同样的工作原理,如果你使用
keepdims保持平凡的维度:>>> np.max(a, axis=1, keepdims=True) array([[30], [60]]) >>> ai = np.argmax(a, axis=1, keepdims=True) >>> ai array([[1], [0]]) >>> np.take_along_axis(a, ai, axis=1) array([[30], [60]])
如果我们想同时获取最大值和最小值,我们可以先堆叠索引
>>> ai_min = np.argmin(a, axis=1, keepdims=True) >>> ai_max = np.argmax(a, axis=1, keepdims=True) >>> ai = np.concatenate([ai_min, ai_max], axis=1) >>> ai array([[0, 1], [1, 0]]) >>> np.take_along_axis(a, ai, axis=1) array([[10, 30], [40, 60]])