mars.tensor.take#

mars.tensor.take(a, indices, axis=None, out=None)[来源]#

沿某个轴从张量中提取元素。

当轴不为None时,此函数的作用与“花哨”的索引相同(使用张量索引数组);然而,如果您需要沿给定轴的元素,它可能更易于使用。类似于 mt.take(arr, indices, axis=3) 的调用相当于 arr[:,:,:,indices,...]

没有复杂索引的解释,这相当于以下使用ndindex,它将每个iijjkk设置为一个索引元组:

Ni, Nk = a.shape[:axis], a.shape[axis+1:]
Nj = indices.shape
for ii in ndindex(Ni):
    for jj in ndindex(Nj):
        for kk in ndindex(Nk):
            out[ii + jj + kk] = a[ii + (indices[jj],) + kk]
Parameters
  • a (array_like (Ni..., M, Nk...)) – 源张量。

  • indices (array_like (Nj...)) –

    要提取的值的索引。

    也允许标量作为索引。

  • axis (int, 可选) – 用于选择值的轴。默认情况下,使用展平的输入张量。

  • out (Tensor, optional (Ni..., Nj..., Nk...)) – 如果提供,结果将放置在这个张量中。它应该具有适当的形状和数据类型。

  • mode ({'raise', 'wrap', 'clip'}, optional) –

    指定越界索引的行为方式。

    • ’raise’ – 引发错误(默认)

    • ’wrap’ – 进行包裹

    • ’clip’ – 裁剪到范围

    ’clip’ 模式意味着所有过大的索引都被替换为指向该轴上最后一个元素的索引。请注意,这会禁用负数索引。

Returns

out – 返回的张量与 a 的类型相同。

Return type

张量 (Ni…, Nj…, Nk…)

另请参阅

compress

使用布尔掩码获取元素

Tensor.take

等效方法

备注

通过消除上述描述中的内部循环,并使用 s_ 构建简单的切片对象,take 可以通过对每个 1-d 切片应用花式索引来表示:

Ni, Nk = a.shape[:axis], a.shape[axis+1:]
for ii in ndindex(Ni):
    for kk in ndindex(Nj):
        out[ii + s_[...,] + kk] = a[ii + s_[:,] + kk][indices]

因此,它等价于(但比)以下apply_along_axis的使用更快:

out = mt.apply_along_axis(lambda a_1d: a_1d[indices], axis, a)

示例

>>> import mars.tensor as mt
>>> a = [4, 3, 5, 7, 6, 8]
>>> indices = [0, 1, 4]
>>> mt.take(a, indices).execute()
array([4, 3, 6])

在这个例子中,如果 a 是一个张量,可以使用“花哨”的索引。

>>> a = mt.array(a)
>>> a[indices].execute()
array([4, 3, 6])

如果 indices 不是一维的,输出也将具有这些维度。

>>> mt.take(a, [[0, 1], [2, 3]]).execute()
array([[4, 3],
       [5, 7]])