快速入门指南#
基础#
导入 mlx.core 并创建一个 array:
>> import mlx.core as mx
>> a = mx.array([1, 2, 3, 4])
>> a.shape
[4]
>> a.dtype
int32
>> b = mx.array([1.0, 2.0, 3.0, 4.0])
>> b.dtype
float32
MLX中的操作是惰性的。MLX操作的输出在需要时才会被计算。要强制评估一个数组,请使用eval()。在某些情况下,数组会自动被评估。例如,使用array.item()检查标量、打印数组或将数组从array转换为numpy.ndarray时,都会自动评估数组。
>> c = a + b # c not yet evaluated
>> mx.eval(c) # evaluates c
>> c = a + b
>> print(c) # Also evaluates c
array([2, 4, 6, 8], dtype=float32)
>> c = a + b
>> import numpy as np
>> np.array(c) # Also evaluates c
array([2., 4., 6., 8.], dtype=float32)
有关更多详细信息,请参阅惰性求值页面。
函数与图形变换#
MLX 具有标准的函数转换,如 grad() 和 vmap()。
转换可以任意组合。例如
grad(vmap(grad(fn)))(或任何其他组合)是允许的。
>> x = mx.array(0.0)
>> mx.sin(x)
array(0, dtype=float32)
>> mx.grad(mx.sin)(x)
array(1, dtype=float32)
>> mx.grad(mx.grad(mx.sin))(x)
array(-0, dtype=float32)
其他梯度变换包括vjp()用于向量-雅可比积
和jvp()用于雅可比-向量积。
使用 value_and_grad() 来高效计算函数的输出及其相对于输入的梯度。