编译#
MLX 有一个 compile() 函数转换,用于编译计算图。通过合并常见的工作和融合某些操作,函数编译可以生成更小的图。在许多情况下,这可以显著提高运行时间和内存使用效率。
开始使用compile()很简单,但对于更复杂的图和高级用法,了解一些边缘情况是有益的。
编译基础#
让我们从一个简单的例子开始:
def fun(x, y):
return mx.exp(-x) + y
x = mx.array(1.0)
y = mx.array(2.0)
# Regular call, no compilation
# Prints: array(2.36788, dtype=float32)
print(fun(x, y))
# Compile the function
compiled_fun = mx.compile(fun)
# Prints: array(2.36788, dtype=float32)
print(compiled_fun(x, y))
常规函数和编译函数的输出在数值精度上是相同的。
第一次调用编译函数时,MLX 将构建计算图,优化它,并生成和编译代码。这可能相对较慢。然而,MLX 会缓存编译函数,因此多次调用编译函数不会启动新的编译。这意味着您通常应该编译计划多次使用的函数。
def fun(x, y):
return mx.exp(-x) + y
x = mx.array(1.0)
y = mx.array(2.0)
compiled_fun = mx.compile(fun)
# Compiled here
compiled_fun(x, y)
# Not compiled again
compiled_fun(x, y)
# Not compiled again
mx.compile(fun)(x, y)
有一些重要的情况需要注意,这些情况可能导致函数被重新编译:
改变形状或维度的数量
更改任何输入的类型
更改函数的输入数量
在某些情况下,只会重新运行部分编译堆栈(例如更改形状时),而在其他情况下,将重新运行完整的编译堆栈(例如更改类型时)。通常,您应避免过于频繁地编译函数。
另一个需要注意的习惯用法是频繁创建和销毁的编译函数。例如,在循环中编译匿名函数时可能会发生这种情况:
a = mx.array(1.0)
# Don't do this, compiles lambda at each iteration
for _ in range(5):
mx.compile(lambda x: mx.exp(mx.abs(x)))(a)
示例加速#
mlx.nn.gelu() 是一种常用于基于Transformer模型的非线性激活函数。其实现涉及多个一元和二元逐元素操作:
def gelu(x):
return x * (1 + mx.erf(x / math.sqrt(2))) / 2
如果你在小数组上使用这个函数,它将会受到开销的限制。如果你在大数组上使用它,它将会受到内存带宽的限制。然而,gelu中的所有操作都可以通过compile()融合到一个单一的内核中。这可以显著加快这两种情况的速度。
import time
def timeit(fun, x):
# warm up
for _ in range(10):
mx.eval(fun(x))
tic = time.perf_counter()
for _ in range(100):
mx.eval(fun(x))
toc = time.perf_counter()
tpi = 1e3 * (toc - tic) / 100
print(f"Time per iteration {tpi:.3f} (ms)")
现在创建一个数组,并对两个函数进行基准测试:
x = mx.random.uniform(shape=(32, 1000, 4096))
timeit(nn.gelu, x)
timeit(mx.compile(nn.gelu), x)
在M1 Max上,时间分别为15.5和3.1毫秒。编译后的gelu快了五倍。
调试#
@mx.compile
def fun(x):
z = -x
print(z) # Crash
return mx.exp(z)
fun(mx.array(5.0))
为了调试,检查数组可能会有所帮助。一种方法是使用disable_compile()函数或MLX_DISABLE_COMPILE标志全局禁用编译。例如,即使fun被编译,以下代码也是可以的:
@mx.compile
def fun(x):
z = -x
print(z) # Okay
return mx.exp(z)
mx.disable_compile()
fun(mx.array(5.0))
纯函数#
编译的函数旨在是纯的;也就是说它们不应该有副作用。例如:
state = []
@mx.compile
def fun(x, y):
z = x + y
state.append(z)
return mx.exp(z)
fun(mx.array(1.0), mx.array(2.0))
# Crash!
print(state)
在第一次调用fun之后,state列表将持有一个占位符数组。该占位符没有任何数据;它仅用于构建计算图。打印这样的数组会导致崩溃。
你有两种选择来处理这个问题。第一种选择是简单地返回
state 作为输出:
state = []
@mx.compile
def fun(x, y):
z = x + y
state.append(z)
return mx.exp(z), state
_, state = fun(mx.array(1.0), mx.array(2.0))
# Prints [array(3, dtype=float32)]
print(state)
在某些情况下,返回更新后的状态可能非常不方便。因此,
compile() 有一个参数来捕获隐式输出:
from functools import partial
state = []
# Tell compile to capture state as an output
@partial(mx.compile, outputs=state)
def fun(x, y):
z = x + y
state.append(z)
return mx.exp(z), state
fun(mx.array(1.0), mx.array(2.0))
# Prints [array(3, dtype=float32)]
print(state)
这对于编译一个包含数组容器更新的函数特别有用,这在训练mlx.nn.Module的参数时是常见的做法。
编译后的函数也会将任何不在参数列表中的输入视为常量。例如:
state = [mx.array(1.0)]
@mx.compile
def fun(x):
return x + state[0]
# Prints array(2, dtype=float32)
print(fun(mx.array(1.0)))
# Update state
state[0] = mx.array(5.0)
# Still prints array(2, dtype=float32)
print(fun(mx.array(1.0)))
为了使状态的变化反映在fun的输出中,你再次有两个选择。第一个选择是简单地将state作为输入传递给函数。在某些情况下,这可能相当不方便。因此,compile()也有一个参数来捕获隐式输入:
from functools import partial
state = [mx.array(1.0)]
# Tell compile to capture state as an input
@partial(mx.compile, inputs=state)
def fun(x):
return x + state[0]
# Prints array(2, dtype=float32)
print(fun(mx.array(1.0)))
# Update state
state[0] = mx.array(5.0)
# Prints array(6, dtype=float32)
print(fun(mx.array(1.0)))
编译训练图#
本节将通过一个简单的示例逐步介绍如何使用compile(),这是一个常见的设置:使用mlx.nn.Module训练模型,并使用带有状态的mlx.optimizers.Optimizer。我们将展示如何使用compile()编译完整的前向、反向和更新过程。
首先,这里是一个没有任何编译的简单示例:
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
# 4 examples with 10 features each
x = mx.random.uniform(shape=(4, 10))
# 0, 1 targets
y = mx.array([0, 1, 0, 1])
# Simple linear model
model = nn.Linear(10, 1)
# SGD with momentum
optimizer = optim.SGD(learning_rate=0.1, momentum=0.8)
def loss_fn(model, x, y):
logits = model(x).squeeze()
return nn.losses.binary_cross_entropy(logits, y)
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
# Perform 10 steps of gradient descent
for it in range(10):
loss, grads = loss_and_grad_fn(model, x, y)
optimizer.update(model, grads)
mx.eval(model.parameters(), optimizer.state)
为了编译更新,我们可以将其全部放入一个函数中,并使用适当的输入和输出捕获进行编译。以下是相同的示例,但已编译:
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from functools import partial
# 4 examples with 10 features each
x = mx.random.uniform(shape=(4, 10))
# 0, 1 targets
y = mx.array([0, 1, 0, 1])
# Simple linear model
model = nn.Linear(10, 1)
# SGD with momentum
optimizer = optim.SGD(learning_rate=0.1, momentum=0.8)
def loss_fn(model, x, y):
logits = model(x).squeeze()
return nn.losses.binary_cross_entropy(logits, y)
# The state that will be captured as input and output
state = [model.state, optimizer.state]
@partial(mx.compile, inputs=state, outputs=state)
def step(x, y):
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
loss, grads = loss_and_grad_fn(model, x, y)
optimizer.update(model, grads)
return loss
# Perform 10 steps of gradient descent
for it in range(10):
loss = step(x, y)
# Evaluate the model and optimizer state
mx.eval(state)
print(loss)
注意
如果你正在使用一个执行随机采样的模块,例如
mlx.nn.Dropout(),请确保你也将 mx.random.state 包含在
compile() 捕获的 state 中,即 state = [model.state,
optimizer.state, mx.random.state]。
注意
有关编译完整训练图的更多示例,请查看MLX 示例 GitHub 仓库。
使用编译进行转换#
在MLX中,函数变换是可组合的。你可以将任何函数变换应用于任何其他函数变换的输出。有关更多信息,请参阅函数变换的文档。
编译转换后的函数按预期工作:
grad_fn = mx.grad(mx.exp)
compiled_grad_fn = mx.compile(grad_fn)
# Prints: array(2.71828, dtype=float32)
print(grad_fn(mx.array(1.0)))
# Also prints: array(2.71828, dtype=float32)
print(compiled_grad_fn(mx.array(1.0)))
注意
为了尽可能多地编译,默认情况下不会编译已编译函数的转换。要编译转换后的函数,只需将其通过compile()传递即可。
你也可以编译那些本身调用已编译函数的函数。一个好的做法是编译最外层的函数,以便给compile()提供最大的机会来优化计算图:
@mx.compile
def inner(x):
return mx.exp(-mx.abs(x))
def outer(x):
inner(inner(x))
# Compiling the outer function is good to do as it will likely
# be faster even though the inner functions are compiled
fun = mx.compile(outer)
无形状编译#
当编译函数的输入形状发生变化时,函数会被重新编译。你可以通过指定shapeless=True给compile()来编译一次函数并在具有可变形状的输入上运行它。在这种情况下,输入形状的变化不会导致函数被重新编译。
def fun(x, y):
return mx.abs(x + y)
compiled_fun = mx.compile(fun, shapeless=True)
x = mx.array(1.0)
y = mx.array(-2.0)
# Firt call compiles the function
print(compiled_fun(x, y))
# Second call with different shapes
# does not recompile the function
x = mx.array([1.0, -6.0])
y = mx.array([-2.0, 3.0])
print(compiled_fun(x, y))
请谨慎使用无形状编译。由于形状改变时不会触发编译,任何依赖于输入形状的图形都不会按预期工作。依赖于形状的计算很常见,有时难以察觉。例如:
def fun(x):
return x.reshape(x.shape[0] * x.shape[1], -1)
compiled_fun = mx.compile(fun, shapeless=True)
x = mx.random.uniform(shape=(2, 3, 4))
out = compiled_fun(x)
x = mx.random.uniform(shape=(5, 5, 3))
# Error, can't reshape (5, 5, 3) to (6, -1)
out = compiled_fun(x)
第二次调用compiled_fun失败是因为调用了reshape(),它在第一次调用中使用了x的静态形状。我们可以通过使用flatten()来避免硬编码x的形状来解决这个问题:
def fun(x):
return x.flatten(0, 1)
compiled_fun = mx.compile(fun, shapeless=True)
x = mx.random.uniform(shape=(2, 3, 4))
out = compiled_fun(x)
x = mx.random.uniform(shape=(5, 5, 3))
# Ok
out = compiled_fun(x)