函数转换#
MLX 使用可组合的函数变换来实现自动微分、向量化和计算图优化。要查看完整的函数变换列表,请查阅 API 文档。
可组合函数变换的关键思想在于,每一个变换都会返回一个可以进一步变换的函数。
这是一个简单的例子:
>>> dfdx = mx.grad(mx.sin)
>>> dfdx(mx.array(mx.pi))
array(-1, dtype=float32)
>>> mx.cos(mx.array(mx.pi))
array(-1, dtype=float32)
grad() 在 sin() 上的输出只是另一个函数。在这种情况下,它是正弦函数的梯度,正好是余弦函数。要获得二阶导数,你可以这样做:
>>> d2fdx2 = mx.grad(mx.grad(mx.sin))
>>> d2fdx2(mx.array(mx.pi / 2))
array(-1, dtype=float32)
>>> mx.sin(mx.array(mx.pi / 2))
array(1, dtype=float32)
在grad()的输出上使用grad()总是可以的。你会继续得到更高阶的导数。
任何MLX函数变换都可以以任何顺序组合到任何深度。有关自动微分和自动向量化的更多信息,请参见以下部分。有关compile()的更多信息,请参见编译文档。
自动微分#
MLX中的自动微分是在函数上工作,而不是在隐式图上。
注意
如果你是从PyTorch转到MLX,你不再需要像backward、zero_grad和detach这样的函数,或者像requires_grad这样的属性。
最基本的例子是如上所示,取标量值函数的梯度。你可以使用grad()和value_and_grad()函数来计算更复杂函数的梯度。默认情况下,这些函数计算相对于第一个参数的梯度:
def loss_fn(w, x, y):
return mx.mean(mx.square(w * x - y))
w = mx.array(1.0)
x = mx.array([0.5, -0.5])
y = mx.array([1.5, -1.5])
# Computes the gradient of loss_fn with respect to w:
grad_fn = mx.grad(loss_fn)
dloss_dw = grad_fn(w, x, y)
# Prints array(-1, dtype=float32)
print(dloss_dw)
# To get the gradient with respect to x we can do:
grad_fn = mx.grad(loss_fn, argnums=1)
dloss_dx = grad_fn(w, x, y)
# Prints array([-1, 1], dtype=float32)
print(dloss_dx)
获取损失和梯度的一种方法是调用loss_fn,然后调用grad_fn,但这可能会导致大量冗余工作。相反,你应该使用value_and_grad()。继续上面的例子:
# Computes the gradient of loss_fn with respect to w:
loss_and_grad_fn = mx.value_and_grad(loss_fn)
loss, dloss_dw = loss_and_grad_fn(w, x, y)
# Prints array(1, dtype=float32)
print(loss)
# Prints array(-1, dtype=float32)
print(dloss_dw)
你也可以对任意嵌套的Python数组容器(特别是list、tuple或dict中的任意一种)进行梯度计算。
假设我们想要在上面的例子中添加一个权重和一个偏置参数。一个很好的方法如下:
def loss_fn(params, x, y):
w, b = params["weight"], params["bias"]
h = w * x + b
return mx.mean(mx.square(h - y))
params = {"weight": mx.array(1.0), "bias": mx.array(0.0)}
x = mx.array([0.5, -0.5])
y = mx.array([1.5, -1.5])
# Computes the gradient of loss_fn with respect to both the
# weight and bias:
grad_fn = mx.grad(loss_fn)
grads = grad_fn(params, x, y)
# Prints
# {'weight': array(-1, dtype=float32), 'bias': array(0, dtype=float32)}
print(grads)
请注意,参数的树状结构在梯度中得以保留。
在某些情况下,您可能希望阻止梯度通过函数的一部分传播。您可以使用stop_gradient()来实现这一点。
自动向量化#
使用vmap()来自动化向量化复杂函数。这里我们将通过一个基本且人为的例子来说明,但vmap()对于更复杂且难以手动优化的函数来说非常强大。
警告
一些操作尚未支持vmap()。如果您遇到类似ValueError: Primitive's vmap not implemented.的错误,请提交一个issue并包含您的函数。我们将优先考虑将其包含在内。
一种简单的方法来添加两组向量的元素是使用循环:
xs = mx.random.uniform(shape=(4096, 100))
ys = mx.random.uniform(shape=(100, 4096))
def naive_add(xs, ys):
return [xs[i] + ys[:, i] for i in range(xs.shape[0])]
相反,你可以使用 vmap() 来自动向量化加法:
# Vectorize over the second dimension of x and the
# first dimension of y
vmap_add = mx.vmap(lambda x, y: x + y, in_axes=(0, 1))
in_axes 参数可用于指定要对相应输入的哪些维度进行向量化。同样地,使用 out_axes 来指定向量化轴在输出中的位置。
让我们来计时这两个不同的版本:
import timeit
print(timeit.timeit(lambda: mx.eval(naive_add(xs, ys)), number=100))
print(timeit.timeit(lambda: mx.eval(vmap_add(xs, ys)), number=100))
在M1 Max上,朴素版本总共需要5.639秒,而向量化版本仅需0.024秒,快了200多倍。
当然,这个操作相当人为。更好的方法是简单地执行
xs + ys.T,但对于更复杂的函数,vmap() 可以非常方便。