保存和加载数组

保存和加载数组#

MLX 支持多种数组序列化格式。

序列化格式#

格式

扩展名

功能

备注

NumPy

.npy

save()

仅限单个数组

NumPy 存档

.npz

savez()savez_compressed()

多个数组

Safetensors

.safetensors

save_safetensors()

多个数组

GGUF

.gguf

save_gguf()

多个数组

load() 函数将加载任何支持的序列化格式。它根据扩展名确定格式。load() 的输出取决于格式。

以下是将单个数组保存到文件的示例:

>>> a = mx.array([1.0])
>>> mx.save("array", a)

数组 a 将被保存到文件 array.npy 中(注意扩展名会自动添加)。包含扩展名是可选的;如果缺少扩展名,它将被自动添加。你可以通过以下方式加载数组:

>>> mx.load("array.npy")
array([1], dtype=float32)

以下是将多个数组保存到单个文件的示例:

>>> a = mx.array([1.0])
>>> b = mx.array([2.0])
>>> mx.savez("arrays", a, b=b)

为了与numpy.savez()兼容,MLX的savez()将数组作为参数。如果缺少关键字,则会提供默认名称。可以使用以下方式加载:

>>> mx.load("arrays.npz")
{'b': array([2], dtype=float32), 'arr_0': array([1], dtype=float32)}

在这种情况下,load() 返回一个名称到数组的字典。

函数 save_safetensors()save_gguf() 类似于 savez(),但它们接受一个字符串名称到数组的 dict 作为输入:

>>> a = mx.array([1.0])
>>> b = mx.array([2.0])
>>> mx.save_safetensors("arrays", {"a": a, "b": b})