函数变换

函数转换#

MLX 使用可组合的函数变换来实现自动微分、向量化和计算图优化。要查看完整的函数变换列表,请查阅 API 文档

可组合函数变换的关键思想在于,每一个变换都会返回一个可以进一步变换的函数。

这是一个简单的例子:

>>> dfdx = mx.grad(mx.sin)
>>> dfdx(mx.array(mx.pi))
array(-1, dtype=float32)
>>> mx.cos(mx.array(mx.pi))
array(-1, dtype=float32)

grad()sin() 上的输出只是另一个函数。在这种情况下,它是正弦函数的梯度,正好是余弦函数。要获得二阶导数,你可以这样做:

>>> d2fdx2 = mx.grad(mx.grad(mx.sin))
>>> d2fdx2(mx.array(mx.pi / 2))
array(-1, dtype=float32)
>>> mx.sin(mx.array(mx.pi / 2))
array(1, dtype=float32)

grad()的输出上使用grad()总是可以的。你会继续得到更高阶的导数。

任何MLX函数变换都可以以任何顺序组合到任何深度。有关自动微分自动向量化的更多信息,请参见以下部分。有关compile()的更多信息,请参见编译文档

自动微分#

MLX中的自动微分是在函数上工作,而不是在隐式图上。

注意

如果你是从PyTorch转到MLX,你不再需要像backwardzero_graddetach这样的函数,或者像requires_grad这样的属性。

最基本的例子是如上所示,取标量值函数的梯度。你可以使用grad()value_and_grad()函数来计算更复杂函数的梯度。默认情况下,这些函数计算相对于第一个参数的梯度:

def loss_fn(w, x, y):
   return mx.mean(mx.square(w * x - y))

w = mx.array(1.0)
x = mx.array([0.5, -0.5])
y = mx.array([1.5, -1.5])

# Computes the gradient of loss_fn with respect to w:
grad_fn = mx.grad(loss_fn)
dloss_dw = grad_fn(w, x, y)
# Prints array(-1, dtype=float32)
print(dloss_dw)

# To get the gradient with respect to x we can do:
grad_fn = mx.grad(loss_fn, argnums=1)
dloss_dx = grad_fn(w, x, y)
# Prints array([-1, 1], dtype=float32)
print(dloss_dx)

获取损失和梯度的一种方法是调用loss_fn,然后调用grad_fn,但这可能会导致大量冗余工作。相反,你应该使用value_and_grad()。继续上面的例子:

# Computes the gradient of loss_fn with respect to w:
loss_and_grad_fn = mx.value_and_grad(loss_fn)
loss, dloss_dw = loss_and_grad_fn(w, x, y)

# Prints array(1, dtype=float32)
print(loss)

# Prints array(-1, dtype=float32)
print(dloss_dw)

你也可以对任意嵌套的Python数组容器(特别是listtupledict中的任意一种)进行梯度计算。

假设我们想要在上面的例子中添加一个权重和一个偏置参数。一个很好的方法如下:

def loss_fn(params, x, y):
   w, b = params["weight"], params["bias"]
   h = w * x + b
   return mx.mean(mx.square(h - y))

params = {"weight": mx.array(1.0), "bias": mx.array(0.0)}
x = mx.array([0.5, -0.5])
y = mx.array([1.5, -1.5])

# Computes the gradient of loss_fn with respect to both the
# weight and bias:
grad_fn = mx.grad(loss_fn)
grads = grad_fn(params, x, y)

# Prints
# {'weight': array(-1, dtype=float32), 'bias': array(0, dtype=float32)}
print(grads)

请注意,参数的树状结构在梯度中得以保留。

在某些情况下,您可能希望阻止梯度通过函数的一部分传播。您可以使用stop_gradient()来实现这一点。

自动向量化#

使用vmap()来自动化向量化复杂函数。这里我们将通过一个基本且人为的例子来说明,但vmap()对于更复杂且难以手动优化的函数来说非常强大。

警告

一些操作尚未支持vmap()。如果您遇到类似ValueError: Primitive's vmap not implemented.的错误,请提交一个issue并包含您的函数。我们将优先考虑将其包含在内。

一种简单的方法来添加两组向量的元素是使用循环:

xs = mx.random.uniform(shape=(4096, 100))
ys = mx.random.uniform(shape=(100, 4096))

def naive_add(xs, ys):
    return [xs[i] + ys[:, i] for i in range(xs.shape[0])]

相反,你可以使用 vmap() 来自动向量化加法:

# Vectorize over the second dimension of x and the
# first dimension of y
vmap_add = mx.vmap(lambda x, y: x + y, in_axes=(0, 1))

in_axes 参数可用于指定要对相应输入的哪些维度进行向量化。同样地,使用 out_axes 来指定向量化轴在输出中的位置。

让我们来计时这两个不同的版本:

import timeit

print(timeit.timeit(lambda: mx.eval(naive_add(xs, ys)), number=100))
print(timeit.timeit(lambda: mx.eval(vmap_add(xs, ys)), number=100))

在M1 Max上,朴素版本总共需要5.639秒,而向量化版本仅需0.024秒,快了200多倍。

当然,这个操作相当人为。更好的方法是简单地执行 xs + ys.T,但对于更复杂的函数,vmap() 可以非常方便。