惰性求值#

为什么需要惰性求值#

当你在MLX中执行操作时,实际上并没有发生任何计算。相反,会记录一个计算图。只有在执行eval()时,才会进行实际的计算。

MLX 使用惰性求值,因为它具有一些很好的特性,其中一些我们将在下面描述。

转换计算图#

惰性求值允许我们记录计算图而不实际进行任何计算。这对于像grad()vmap()这样的函数变换以及图优化非常有用。

目前,MLX 不会编译和重新运行计算图。它们都是动态生成的。然而,惰性求值使得未来集成编译以提升性能变得更加容易。

只计算你使用的内容#

在MLX中,您不需要过多担心计算从未使用的输出。例如:

def fun(x):
    a = fun1(x)
    b = expensive_fun(a)
    return a, b

y, _ = fun(x)

在这里,我们实际上从未计算过expensive_fun的输出。不过,使用这种模式时要小心,因为expensive_fun的图仍然会被构建,这会带来一些相关的成本。

同样,惰性评估在保持代码简单的同时,也有助于节省内存。假设你有一个非常大的模型 Model,它继承自 mlx.nn.Module。你可以通过 model = Model() 来实例化这个模型。通常,这会将所有权重初始化为 float32,但在你执行 eval() 之前,初始化实际上不会计算任何内容。如果你用 float16 权重更新模型,你的最大内存消耗将比使用急切计算时所需的一半还要少。

由于惰性计算,这种模式在MLX中很容易实现:

model = Model() # no memory used yet
model.load_weights("weights_fp16.safetensors")

何时评估#

一个常见的问题是何时使用eval()。权衡在于让图变得太大和不批量处理足够有用的工作之间。

例如:

for _ in range(100):
     a = a + b
     mx.eval(a)
     b = b * 2
     mx.eval(b)

这是一个不好的主意,因为每次图形评估都会有一些固定的开销。另一方面,随着计算图大小的增加,会有一些轻微的开销,所以非常大的图(虽然在计算上是正确的)可能会很昂贵。

幸运的是,MLX 可以很好地处理各种大小的计算图:从每次评估几十个操作到数千个操作都可以。

大多数数值计算都有一个迭代的外部循环(例如,随机梯度下降中的迭代)。使用eval()的自然且通常高效的地方是在这个外部循环的每次迭代中。

这是一个具体的例子:

for batch in dataset:

    # Nothing has been evaluated yet
    loss, grad = value_and_grad_fn(model, batch)

    # Still nothing has been evaluated
    optimizer.update(model, grad)

    # Evaluate the loss and the new parameters which will
    # run the full gradient computation and optimizer update
    mx.eval(loss, model.parameters())

一个需要注意的重要行为是图形何时会被隐式评估。每当你print一个数组,将其转换为numpy.ndarray,或者通过memoryview访问其内存时,图形都会被评估。通过save()(或任何其他MLX保存函数)保存数组也会评估数组。

在标量数组上调用array.item()也会对其进行评估。在上面的例子中,打印损失(print(loss))或将损失标量添加到列表(losses.append(loss.item()))会导致图形评估。如果这些行在mx.eval(loss, model.parameters())之前,那么这将是一个部分评估,仅计算前向传递。

此外,多次对数组或数组集调用eval()是完全没问题的。这实际上是一个无操作。

警告

使用标量数组进行控制流将导致评估。

这是一个示例:

def fun(x):
    h, y = first_layer(x)
    if y > 0:  # An evaluation is done here!
        z  = second_layer_a(h)
    else:
        z  = second_layer_b(h)
    return z

使用数组进行控制流时应谨慎。上述示例有效,甚至可以用于梯度变换。然而,如果评估过于频繁,这可能会非常低效。