导出函数#

MLX 提供了一个 API,用于将函数导出到文件和从文件导入函数。这使您能够在一个 MLX 前端(例如 Python)中编写的计算在另一个 MLX 前端(例如 C++)中运行。

本指南通过一些示例介绍了MLX导出API的基础知识。 要查看完整的功能列表,请查看API文档

导出基础#

让我们从一个简单的例子开始:

def fun(x, y):
  return x + y

x = mx.array(1.0)
y = mx.array(1.0)
mx.export_function("add.mlxfn", fun, x, y)

要导出一个函数,请提供可以调用该函数的示例输入数组。数据不重要,但数组的形状和类型很重要。在上面的示例中,我们导出了fun,其中包含两个float32标量数组。然后我们可以导入该函数并运行它:

add_fun = mx.import_function("add.mlxfn")

out, = add_fun(mx.array(1.0), mx.array(2.0))
# Prints: array(3, dtype=float32)
print(out)

out, = add_fun(mx.array(1.0), mx.array(3.0))
# Prints: array(4, dtype=float32)
print(out)

# Raises an exception
add_fun(mx.array(1), mx.array(3.0))

# Raises an exception
add_fun(mx.array([1.0, 2.0]), mx.array(3.0))

请注意,对add_fun的第三次和第四次调用会引发异常,因为输入的形状和类型与我们导出函数时的示例输入的形状和类型不同。

还要注意,即使原始的 fun 返回单个输出数组,导入的函数总是返回一个或多个数组的元组。

export_function() 的输入和导入函数的输入可以指定为可变位置参数或数组元组:

def fun(x, y):
  return x + y

x = mx.array(1.0)
y = mx.array(1.0)

# Both arguments to fun are positional
mx.export_function("add.mlxfn", fun, x, y)

# Same as above
mx.export_function("add.mlxfn", fun, (x, y))

imported_fun = mx.import_function("add.mlxfn")

# Ok
out, = imported_fun(x, y)

# Also ok
out, = imported_fun((x, y))

你可以将示例输入作为位置参数或关键字参数传递给函数。如果你使用关键字参数导出函数,那么在调用导入的函数时也必须使用相同的关键字参数。

def fun(x, y):
  return x + y

# One argument to fun is positional, the other is a kwarg
mx.export_function("add.mlxfn", fun, x, y=y)

imported_fun = mx.import_function("add.mlxfn")

# Ok
out, = imported_fun(x, y=y)

# Also ok
out, = imported_fun((x,), {"y": y})

# Raises since the keyword argument is missing
out, = imported_fun(x, y)

# Raises since the keyword argument has the wrong key
out, = imported_fun(x, z=y)

导出模块#

一个 mlx.nn.Module 可以在导出的函数中包含或不包含参数。以下是一个示例:

model = nn.Linear(4, 4)
mx.eval(model.parameters())

def call(x):
   return model(x)

mx.export_function("model.mlxfn", call, mx.zeros(4))

在上面的例子中,mlx.nn.Linear 模块被导出。它的参数也被保存到 model.mlxfn 文件中。

注意

对于导出函数内部的封闭数组,要格外小心确保它们被评估。导出的计算图将包括生成封闭输入的计算。

如果上面的例子缺少mx.eval(model.parameters(),导出的函数将包括mlx.nn.Module参数的随机初始化。

如果你只想导出Module.__call__函数而不带参数,将它们作为输入传递给call包装器:

model = nn.Linear(4, 4)
mx.eval(model.parameters())

def call(x, **params):
  # Set the model's parameters to the input parameters
  model.update(tree_unflatten(list(params.items())))
  return model(x)

params = dict(tree_flatten(model.parameters()))
mx.export_function("model.mlxfn", call, (mx.zeros(4),), params)

无形状导出#

就像compile()一样,函数也可以导出用于动态形状的输入。将shapeless=True传递给export_function()exporter()以导出一个可以用于可变形状输入的函数:

mx.export_function("fun.mlxfn", mx.abs, mx.array(0.0), shapeless=True)
imported_abs = mx.import_function("fun.mlxfn")

# Ok
out, = imported_abs(mx.array(-1.0))

# Also ok
out, = imported_abs(mx.array([-1.0, -2.0]))

使用 shapeless=False(这是默认设置),第二次调用 imported_abs 会因形状不匹配而引发异常。

无形状导出与无形状编译的工作方式相同,应谨慎使用。有关更多信息,请参阅无形状编译的文档

导出多个轨迹#

在某些情况下,函数会为不同的输入参数构建不同的计算图。管理这种情况的一个简单方法是使用每组输入导出到一个新文件。在许多情况下,这是一个不错的选择。但如果导出的函数有大量重复的常量数据(例如mlx.nn.Module的参数),这可能不是最优的选择。

MLX中的导出API允许您通过使用exporter()创建一个导出上下文管理器,将同一函数的多个跟踪导出到单个文件中:

def fun(x, y=None):
    constant = mx.array(3.0)
    if y is not None:
      x += y
    return x + constant

with mx.exporter("fun.mlxfn", fun) as exporter:
    exporter(mx.array(1.0))
    exporter(mx.array(1.0), y=mx.array(0.0))

imported_function = mx.import_function("fun.mlxfn")

# Call the function with y=None
out, = imported_function(mx.array(1.0))
print(out)

# Call the function with y specified
out, = imported_function(mx.array(1.0), y=mx.array(1.0))
print(out)

在上述示例中,函数常量数据(即 constant)仅保存一次。

使用导入函数的转换#

函数转换如 grad(), vmap(), 和 compile() 在导入的函数上工作,就像普通的 Python 函数一样:

def fun(x):
    return mx.sin(x)

x = mx.array(0.0)
mx.export_function("sine.mlxfn", fun, x)

imported_fun = mx.import_function("sine.mlxfn")

# Take the derivative of the imported function
dfdx = mx.grad(lambda x: imported_fun(x)[0])
# Prints: array(1, dtype=float32)
print(dfdx(x))

# Compile the imported function
mx.compile(imported_fun)
# Prints: array(0, dtype=float32)
print(compiled_fun(x)[0])

在C++中导入函数#

在C++中导入和运行函数与在Python中基本上是一样的。首先,按照说明来设置一个使用MLX作为库的简单C++项目。

接下来,从Python导出一个简单的函数:

def fun(x, y):
    return mx.exp(x + y)

x = mx.array(1.0)
y = mx.array(1.0)
mx.export_function("fun.mlxfn", fun, x, y)

在C++中导入并运行该函数只需几行代码:

auto fun = mx::import_function("fun.mlxfn");

auto inputs = {mx::array(1.0), mx::array(1.0)};
auto outputs = fun(inputs);

// Prints: array(2, dtype=float32)
std::cout << outputs[0] << std::endl;

导入的函数可以在C++中像在Python中一样进行转换。在C++中调用导入的函数时,使用std::vector表示位置参数,使用std::map mx::array>表示关键字参数。

更多示例#

以下是更多完整的示例,展示了如何从Python导出更复杂的函数并在C++中导入和运行它们: