索引数组#
在大多数情况下,索引一个MLX array 与索引一个NumPy numpy.ndarray 的工作方式相同。有关其工作原理的更多详细信息,请参阅NumPy文档。
例如,你可以使用常规整数和切片(slice)来索引数组:
>>> arr = mx.arange(10)
>>> arr[3]
array(3, dtype=int32)
>>> arr[-2] # negative indexing works
array(8, dtype=int32)
>>> arr[2:8:2] # start, stop, stride
array([2, 4, 6], dtype=int32)
对于多维数组,... 或 Ellipsis 语法在 NumPy 中同样适用:
>>> arr = mx.arange(8).reshape(2, 2, 2)
>>> arr[:, :, 0]
array(3, dtype=int32)
array([[0, 2],
[4, 6]], dtype=int32
>>> arr[..., 0]
array([[0, 2],
[4, 6]], dtype=int32
你可以使用None进行索引以创建一个新轴:
>>> arr = mx.arange(8)
>>> arr.shape
[8]
>>> arr[None].shape
[1, 8]
>>> arr = mx.arange(10)
>>> idx = mx.array([5, 7])
>>> arr[idx]
array([5, 7], dtype=int32)
混合和匹配整数、slice、...和array索引
在NumPy中同样适用。
其他可能对数组索引有用的函数是 take() 和
take_along_axis()。
与NumPy的区别#
注意
MLX 索引在两个方面与 NumPy 索引不同:
索引不执行边界检查。索引越界是未定义行为。
基于布尔掩码的索引尚未支持。
缺乏边界检查的原因是异常无法从GPU传播。在内核启动之前对数组索引执行边界检查将非常低效。
使用布尔掩码进行索引是MLX未来可能支持的功能。一般来说,MLX对输出形状依赖于输入数据的操作支持有限。MLX目前还不支持的其他此类操作包括numpy.nonzero()和numpy.where()的单输入版本。
就地更新#
在MLX中可以对索引数组进行原地更新。例如:
>>> a = mx.array([1, 2, 3])
>>> a[2] = 0
>>> a
array([1, 2, 0], dtype=int32)
就像在NumPy中一样,就地更新将反映在对同一数组的所有引用中:
>>> a = mx.array([1, 2, 3])
>>> b = a
>>> b[2] = 0
>>> b
array([1, 2, 0], dtype=int32)
>>> a
array([1, 2, 0], dtype=int32)
允许使用就地更新的函数转换,并按预期工作。例如:
def fun(x, idx):
x[idx] = 2.0
return x.sum()
dfdx = mx.grad(fun)(mx.array([1.0, 2.0, 3.0]), mx.array([1]))
print(dfdx) # Prints: array([1, 0, 1], dtype=float32)
在上面的dfdx中,将会有正确的梯度,即在idx处为零,其他地方为一。