转换为NumPy和其他框架#
MLX 数组支持与其他框架之间的转换,可以通过以下方式实现:
让我们将一个数组转换为NumPy并返回。
import mlx.core as mx
import numpy as np
a = mx.arange(3)
b = np.array(a) # copy of a
c = mx.array(b) # copy of b
注意
由于NumPy不支持bfloat16数组,您需要先转换为float16或float32:np.array(a.astype(mx.float32))。否则,您将收到类似以下的错误:Item size 2 for PEP 3118 buffer format string does not match the dtype V item size 0.
默认情况下,NumPy 会将数据复制到一个新数组中。可以通过创建数组视图来防止这种情况:
a = mx.arange(3)
a_view = np.array(a, copy=False)
print(a_view.flags.owndata) # False
a_view[0] = 1
print(a[0].item()) # 1
注意
类型为 float64 的 NumPy 数组将默认转换为类型为 float32 的 MLX 数组。
NumPy 数组视图是一个普通的 NumPy 数组,但它不拥有其内存。这意味着对视图的写入会反映在原始数组中。
虽然这在防止复制数组方面非常强大,但需要注意的是,对数组内存的外部更改无法反映在梯度中。
让我们通过一个例子来演示这一点:
def f(x):
x_view = np.array(x, copy=False)
x_view[:] *= x_view # modify memory without telling mx
return x.sum()
x = mx.array([3.0])
y, df = mx.value_and_grad(f)(x)
print("f(x) = x² =", y.item()) # 9.0
print("f'(x) = 2x !=", df.item()) # 1.0
函数 f 通过内存视图间接修改了数组 x。
然而,这种修改并未反映在梯度中,如最后一行输出的 1.0 所示,它仅表示求和操作的梯度。
x 的平方操作发生在 MLX 外部,这意味着没有包含梯度。需要注意的是,在数组转换和复制过程中也会出现类似的问题。
例如,定义为 mx.array(np.array(x)**2).sum() 的函数也会导致错误的梯度,即使没有在 MLX 内存上执行原地操作。
PyTorch#
警告
PyTorch 对 memoryview 的支持是实验性的,可能会在多维数组上出现问题。目前建议先转换为 NumPy。
PyTorch 支持缓冲区协议,但它需要一个显式的
memoryview。
import mlx.core as mx
import torch
a = mx.arange(3)
b = torch.tensor(memoryview(a))
c = mx.array(b.numpy())
从PyTorch张量转换回数组必须通过中间NumPy数组使用numpy()来完成。
JAX#
JAX 完全支持缓冲区协议。
import mlx.core as mx
import jax.numpy as jnp
a = mx.arange(3)
b = jnp.array(a)
c = mx.array(b)
张量流#
TensorFlow 支持缓冲区协议,但它需要一个显式的
memoryview。
import mlx.core as mx
import tensorflow as tf
a = mx.arange(3)
b = tf.constant(memoryview(a))
c = mx.array(b)