保存和加载数组#
MLX 支持多种数组序列化格式。
格式 |
扩展名 |
功能 |
备注 |
|---|---|---|---|
NumPy |
|
仅限单个数组 |
|
NumPy 存档 |
|
多个数组 |
|
Safetensors |
|
多个数组 |
|
GGUF |
|
多个数组 |
load() 函数将加载任何支持的序列化格式。它根据扩展名确定格式。load() 的输出取决于格式。
以下是将单个数组保存到文件的示例:
>>> a = mx.array([1.0])
>>> mx.save("array", a)
数组 a 将被保存到文件 array.npy 中(注意扩展名会自动添加)。包含扩展名是可选的;如果缺少扩展名,它将被自动添加。你可以通过以下方式加载数组:
>>> mx.load("array.npy")
array([1], dtype=float32)
以下是将多个数组保存到单个文件的示例:
>>> a = mx.array([1.0])
>>> b = mx.array([2.0])
>>> mx.savez("arrays", a, b=b)
为了与numpy.savez()兼容,MLX的savez()将数组作为参数。如果缺少关键字,则会提供默认名称。可以使用以下方式加载:
>>> mx.load("arrays.npz")
{'b': array([2], dtype=float32), 'arr_0': array([1], dtype=float32)}
在这种情况下,load() 返回一个名称到数组的字典。
函数 save_safetensors() 和 save_gguf() 类似于
savez(),但它们接受一个字符串名称到数组的 dict 作为输入:
>>> a = mx.array([1.0])
>>> b = mx.array([2.0])
>>> mx.save_safetensors("arrays", {"a": a, "b": b})