编译#

MLX 有一个 compile() 函数转换,用于编译计算图。通过合并常见的工作和融合某些操作,函数编译可以生成更小的图。在许多情况下,这可以显著提高运行时间和内存使用效率。

开始使用compile()很简单,但对于更复杂的图和高级用法,了解一些边缘情况是有益的。

编译基础#

让我们从一个简单的例子开始:

def fun(x, y):
    return mx.exp(-x) + y

x = mx.array(1.0)
y = mx.array(2.0)

# Regular call, no compilation
# Prints: array(2.36788, dtype=float32)
print(fun(x, y))

# Compile the function
compiled_fun = mx.compile(fun)

# Prints: array(2.36788, dtype=float32)
print(compiled_fun(x, y))

常规函数和编译函数的输出在数值精度上是相同的。

第一次调用编译函数时,MLX 将构建计算图,优化它,并生成和编译代码。这可能相对较慢。然而,MLX 会缓存编译函数,因此多次调用编译函数不会启动新的编译。这意味着您通常应该编译计划多次使用的函数。

def fun(x, y):
    return mx.exp(-x) + y

x = mx.array(1.0)
y = mx.array(2.0)

compiled_fun = mx.compile(fun)

# Compiled here
compiled_fun(x, y)

# Not compiled again
compiled_fun(x, y)

# Not compiled again
mx.compile(fun)(x, y)

有一些重要的情况需要注意,这些情况可能导致函数被重新编译:

  • 改变形状或维度的数量

  • 更改任何输入的类型

  • 更改函数的输入数量

在某些情况下,只会重新运行部分编译堆栈(例如更改形状时),而在其他情况下,将重新运行完整的编译堆栈(例如更改类型时)。通常,您应避免过于频繁地编译函数。

另一个需要注意的习惯用法是频繁创建和销毁的编译函数。例如,在循环中编译匿名函数时可能会发生这种情况:

a = mx.array(1.0)
# Don't do this, compiles lambda at each iteration
for _ in range(5):
    mx.compile(lambda x: mx.exp(mx.abs(x)))(a)

示例加速#

mlx.nn.gelu() 是一种常用于基于Transformer模型的非线性激活函数。其实现涉及多个一元和二元逐元素操作:

def gelu(x):
    return x * (1 + mx.erf(x / math.sqrt(2))) / 2

如果你在小数组上使用这个函数,它将会受到开销的限制。如果你在大数组上使用它,它将会受到内存带宽的限制。然而,gelu中的所有操作都可以通过compile()融合到一个单一的内核中。这可以显著加快这两种情况的速度。

import time

def timeit(fun, x):
    # warm up
    for _ in range(10):
        mx.eval(fun(x))

    tic = time.perf_counter()
    for _ in range(100):
        mx.eval(fun(x))
    toc = time.perf_counter()
    tpi = 1e3 * (toc - tic) / 100
    print(f"Time per iteration {tpi:.3f} (ms)")

现在创建一个数组,并对两个函数进行基准测试:

x = mx.random.uniform(shape=(32, 1000, 4096))
timeit(nn.gelu, x)
timeit(mx.compile(nn.gelu), x)

在M1 Max上,时间分别为15.5和3.1毫秒。编译后的gelu快了五倍。

调试#

@mx.compile
def fun(x):
    z = -x
    print(z)  # Crash
    return mx.exp(z)

fun(mx.array(5.0))

为了调试,检查数组可能会有所帮助。一种方法是使用disable_compile()函数或MLX_DISABLE_COMPILE标志全局禁用编译。例如,即使fun被编译,以下代码也是可以的:

@mx.compile
def fun(x):
    z = -x
    print(z) # Okay
    return mx.exp(z)

mx.disable_compile()
fun(mx.array(5.0))

纯函数#

编译的函数旨在是的;也就是说它们不应该有副作用。例如:

state = []

@mx.compile
def fun(x, y):
    z = x + y
    state.append(z)
    return mx.exp(z)

fun(mx.array(1.0), mx.array(2.0))
# Crash!
print(state)

在第一次调用fun之后,state列表将持有一个占位符数组。该占位符没有任何数据;它仅用于构建计算图。打印这样的数组会导致崩溃。

你有两种选择来处理这个问题。第一种选择是简单地返回 state 作为输出:

state = []

@mx.compile
def fun(x, y):
   z = x + y
   state.append(z)
   return mx.exp(z), state

 _, state = fun(mx.array(1.0), mx.array(2.0))
 # Prints [array(3, dtype=float32)]
 print(state)

在某些情况下,返回更新后的状态可能非常不方便。因此, compile() 有一个参数来捕获隐式输出:

from functools import partial

state = []

# Tell compile to capture state as an output
@partial(mx.compile, outputs=state)
def fun(x, y):
    z = x + y
    state.append(z)
    return mx.exp(z), state

fun(mx.array(1.0), mx.array(2.0))
# Prints [array(3, dtype=float32)]
print(state)

这对于编译一个包含数组容器更新的函数特别有用,这在训练mlx.nn.Module的参数时是常见的做法。

编译后的函数也会将任何不在参数列表中的输入视为常量。例如:

state = [mx.array(1.0)]

@mx.compile
def fun(x):
    return x + state[0]

# Prints array(2, dtype=float32)
print(fun(mx.array(1.0)))

# Update state
state[0] = mx.array(5.0)

# Still prints array(2, dtype=float32)
print(fun(mx.array(1.0)))

为了使状态的变化反映在fun的输出中,你再次有两个选择。第一个选择是简单地将state作为输入传递给函数。在某些情况下,这可能相当不方便。因此,compile()也有一个参数来捕获隐式输入:

from functools import partial
state = [mx.array(1.0)]

# Tell compile to capture state as an input
@partial(mx.compile, inputs=state)
def fun(x):
    return x + state[0]

# Prints array(2, dtype=float32)
print(fun(mx.array(1.0)))

# Update state
state[0] = mx.array(5.0)

# Prints array(6, dtype=float32)
print(fun(mx.array(1.0)))

编译训练图#

本节将通过一个简单的示例逐步介绍如何使用compile(),这是一个常见的设置:使用mlx.nn.Module训练模型,并使用带有状态的mlx.optimizers.Optimizer。我们将展示如何使用compile()编译完整的前向、反向和更新过程。

首先,这里是一个没有任何编译的简单示例:

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

# 4 examples with 10 features each
x = mx.random.uniform(shape=(4, 10))

# 0, 1 targets
y = mx.array([0, 1, 0, 1])

# Simple linear model
model = nn.Linear(10, 1)

# SGD with momentum
optimizer = optim.SGD(learning_rate=0.1, momentum=0.8)

def loss_fn(model, x, y):
    logits = model(x).squeeze()
    return nn.losses.binary_cross_entropy(logits, y)

loss_and_grad_fn = nn.value_and_grad(model, loss_fn)

# Perform 10 steps of gradient descent
for it in range(10):
    loss, grads = loss_and_grad_fn(model, x, y)
    optimizer.update(model, grads)
    mx.eval(model.parameters(), optimizer.state)

为了编译更新,我们可以将其全部放入一个函数中,并使用适当的输入和输出捕获进行编译。以下是相同的示例,但已编译:

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from functools import partial

# 4 examples with 10 features each
x = mx.random.uniform(shape=(4, 10))

# 0, 1 targets
y = mx.array([0, 1, 0, 1])

# Simple linear model
model = nn.Linear(10, 1)

# SGD with momentum
optimizer = optim.SGD(learning_rate=0.1, momentum=0.8)

def loss_fn(model, x, y):
    logits = model(x).squeeze()
    return nn.losses.binary_cross_entropy(logits, y)

# The state that will be captured as input and output
state = [model.state, optimizer.state]

@partial(mx.compile, inputs=state, outputs=state)
def step(x, y):
    loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
    loss, grads = loss_and_grad_fn(model, x, y)
    optimizer.update(model, grads)
    return loss

# Perform 10 steps of gradient descent
for it in range(10):
    loss = step(x, y)
    # Evaluate the model and optimizer state
    mx.eval(state)
    print(loss)

注意

如果你正在使用一个执行随机采样的模块,例如 mlx.nn.Dropout(),请确保你也将 mx.random.state 包含在 compile() 捕获的 state 中,即 state = [model.state, optimizer.state, mx.random.state]

注意

有关编译完整训练图的更多示例,请查看MLX 示例 GitHub 仓库。

使用编译进行转换#

在MLX中,函数变换是可组合的。你可以将任何函数变换应用于任何其他函数变换的输出。有关更多信息,请参阅函数变换的文档。

编译转换后的函数按预期工作:

grad_fn = mx.grad(mx.exp)

compiled_grad_fn = mx.compile(grad_fn)

# Prints: array(2.71828, dtype=float32)
print(grad_fn(mx.array(1.0)))

# Also prints: array(2.71828, dtype=float32)
print(compiled_grad_fn(mx.array(1.0)))

注意

为了尽可能多地编译,默认情况下不会编译已编译函数的转换。要编译转换后的函数,只需将其通过compile()传递即可。

你也可以编译那些本身调用已编译函数的函数。一个好的做法是编译最外层的函数,以便给compile()提供最大的机会来优化计算图:

@mx.compile
def inner(x):
    return mx.exp(-mx.abs(x))

def outer(x):
    inner(inner(x))

# Compiling the outer function is good to do as it will likely
# be faster even though the inner functions are compiled
fun = mx.compile(outer)

无形状编译#

当编译函数的输入形状发生变化时,函数会被重新编译。你可以通过指定shapeless=Truecompile()来编译一次函数并在具有可变形状的输入上运行它。在这种情况下,输入形状的变化不会导致函数被重新编译。

def fun(x, y):
    return mx.abs(x + y)

compiled_fun = mx.compile(fun, shapeless=True)

x = mx.array(1.0)
y = mx.array(-2.0)

# Firt call compiles the function
print(compiled_fun(x, y))

# Second call with different shapes
# does not recompile the function
x = mx.array([1.0, -6.0])
y = mx.array([-2.0, 3.0])
print(compiled_fun(x, y))

请谨慎使用无形状编译。由于形状改变时不会触发编译,任何依赖于输入形状的图形都不会按预期工作。依赖于形状的计算很常见,有时难以察觉。例如:

def fun(x):
    return x.reshape(x.shape[0] * x.shape[1], -1)

compiled_fun = mx.compile(fun, shapeless=True)

x = mx.random.uniform(shape=(2, 3, 4))

out = compiled_fun(x)

x = mx.random.uniform(shape=(5, 5, 3))

# Error, can't reshape (5, 5, 3) to (6, -1)
out = compiled_fun(x)

第二次调用compiled_fun失败是因为调用了reshape(),它在第一次调用中使用了x的静态形状。我们可以通过使用flatten()来避免硬编码x的形状来解决这个问题:

def fun(x):
    return x.flatten(0, 1)

compiled_fun = mx.compile(fun, shapeless=True)

x = mx.random.uniform(shape=(2, 3, 4))

out = compiled_fun(x)

x = mx.random.uniform(shape=(5, 5, 3))

# Ok
out = compiled_fun(x)