神经网络#
在MLX中编写任意复杂的神经网络仅需使用mlx.core.array和mlx.core.value_and_grad()即可完成。然而,这要求用户反复编写相同的简单神经网络操作,并手动且显式地处理所有参数状态和初始化。
模块 mlx.nn 通过提供一种直观的方式来组合神经网络层、初始化它们的参数、冻结它们以进行微调等,解决了这个问题。
神经网络快速入门#
import mlx.core as mx
import mlx.nn as nn
class MLP(nn.Module):
def __init__(self, in_dims: int, out_dims: int):
super().__init__()
self.layers = [
nn.Linear(in_dims, 128),
nn.Linear(128, 128),
nn.Linear(128, out_dims),
]
def __call__(self, x):
for i, l in enumerate(self.layers):
x = mx.maximum(x, 0) if i > 0 else x
x = l(x)
return x
# The model is created with all its parameters but nothing is initialized
# yet because MLX is lazily evaluated
mlp = MLP(2, 10)
# We can access its parameters by calling mlp.parameters()
params = mlp.parameters()
print(params["layers"][0]["weight"].shape)
# Printing a parameter will cause it to be evaluated and thus initialized
print(params["layers"][0])
# We can also force evaluate all parameters to initialize the model
mx.eval(mlp.parameters())
# A simple loss function.
# NOTE: It doesn't matter how it uses the mlp model. It currently captures
# it from the local scope. It could be a positional argument or a
# keyword argument.
def l2_loss(x, y):
y_hat = mlp(x)
return (y_hat - y).square().mean()
# Calling `nn.value_and_grad` instead of `mx.value_and_grad` returns the
# gradient with respect to `mlp.trainable_parameters()`
loss_and_grad = nn.value_and_grad(mlp, l2_loss)
模块类#
任何神经网络库的核心是Module类。在MLX中,Module类是mlx.core.array或Module实例的容器。它的主要功能是提供一种递归访问和更新其参数及其子模块参数的方式。
参数#
模块的参数是类型为mlx.core.array的任何公共成员(其名称不应以_开头)。它可以任意嵌套在其他Module实例或列表和字典中。
Module.parameters() 可用于提取包含模块及其子模块所有参数的嵌套字典。
一个Module也可以跟踪“冻结”的参数。更多详情请参见
Module.freeze()方法。mlx.nn.value_and_grad()
返回的梯度将针对这些可训练的参数。
更新参数#
MLX 模块允许访问和更新单个参数。然而,大多数时候我们需要更新模块参数的大量子集。这个操作由 Module.update() 执行。
检查模块#
查看模型架构的最简单方法是打印它。按照上面的示例,您可以使用以下代码打印MLP:
print(mlp)
这将显示:
MLP(
(layers.0): Linear(input_dims=2, output_dims=128, bias=True)
(layers.1): Linear(input_dims=128, output_dims=128, bias=True)
(layers.2): Linear(input_dims=128, output_dims=10, bias=True)
)
要获取有关Module中数组的更多详细信息,您可以在参数上使用mlx.utils.tree_map()。例如,要查看Module中所有参数的形状,请执行以下操作:
from mlx.utils import tree_map
shapes = tree_map(lambda p: p.shape, mlp.parameters())
再举一个例子,你可以使用以下方法计算Module中的参数数量:
from mlx.utils import tree_flatten
num_params = sum(v.size for _, v in tree_flatten(mlp.parameters()))
值和梯度#
使用Module并不妨碍使用MLX的高阶函数转换(mlx.core.value_and_grad(), mlx.core.grad()等)。然而,这些函数转换假设是纯函数,即参数应作为被转换函数的参数传递。
使用MLX模块实现这一目标有一个简单的模式
model = ...
def f(params, other_inputs):
model.update(params) # <---- Necessary to make the model use the passed parameters
return model(other_inputs)
f(model.trainable_parameters(), mx.zeros((10,)))
然而,mlx.nn.value_and_grad() 提供了这种模式,并且仅计算模型可训练参数的梯度。
详细内容:
它用一个调用
Module.update()的函数包装传递的函数,以确保模型使用提供的参数。它调用
mlx.core.value_and_grad()将函数转换为一个也能计算传递参数梯度的函数。它用一个函数包装返回的函数,该函数将可训练参数作为第一个参数传递给由
mlx.core.value_and_grad()返回的函数。
|
将传递的函数 |
|
根据谓词对模块的子模块进行量化。 |
- Module
模块- mlx.nn.Module.training
- mlx.nn.Module.state
- mlx.nn.Module.apply
- mlx.nn.Module.apply_to_modules
- mlx.nn.Module.children
- mlx.nn.Module.eval
- mlx.nn.Module.filter_and_map
- mlx.nn.Module.freeze
- mlx.nn.Module.leaf_modules
- mlx.nn.Module.load_weights
- mlx.nn.Module.modules
- mlx.nn.Module.named_modules
- mlx.nn.Module.parameters
- mlx.nn.Module.save_weights
- mlx.nn.Module.set_dtype
- mlx.nn.Module.train
- mlx.nn.Module.trainable_parameters
- mlx.nn.Module.unfreeze
- mlx.nn.Module.update
- mlx.nn.Module.update_modules
- Layers
- mlx.nn.ALiBi
- mlx.nn.AvgPool1d
- mlx.nn.AvgPool2d
- mlx.nn.AvgPool3d
- mlx.nn.BatchNorm
- mlx.nn.CELU
- mlx.nn.Conv1d
- mlx.nn.Conv2d
- mlx.nn.Conv3d
- mlx.nn.ConvTranspose1d
- mlx.nn.ConvTranspose2d
- mlx.nn.ConvTranspose3d
- mlx.nn.Dropout
- mlx.nn.Dropout2d
- mlx.nn.Dropout3d
- mlx.nn.Embedding
- mlx.nn.ELU
- mlx.nn.GELU
- mlx.nn.GLU
- mlx.nn.GroupNorm
- mlx.nn.GRU
- mlx.nn.HardShrink
- mlx.nn.HardTanh
- mlx.nn.Hardswish
- mlx.nn.InstanceNorm
- mlx.nn.LayerNorm
- mlx.nn.LeakyReLU
- mlx.nn.Linear
- mlx.nn.LogSigmoid
- mlx.nn.LogSoftmax
- mlx.nn.LSTM
- mlx.nn.MaxPool1d
- mlx.nn.MaxPool2d
- mlx.nn.MaxPool3d
- mlx.nn.Mish
- mlx.nn.MultiHeadAttention
- mlx.nn.PReLU
- mlx.nn.QuantizedEmbedding
- mlx.nn.QuantizedLinear
- mlx.nn.RMSNorm
- mlx.nn.ReLU
- mlx.nn.ReLU6
- mlx.nn.RNN
- mlx.nn.RoPE
- mlx.nn.SELU
- mlx.nn.Sequential
- mlx.nn.Sigmoid
- mlx.nn.SiLU
- mlx.nn.SinusoidalPositionalEncoding
- mlx.nn.Softmin
- mlx.nn.Softshrink
- mlx.nn.Softsign
- mlx.nn.Softmax
- mlx.nn.Softplus
- mlx.nn.Step
- mlx.nn.Tanh
- mlx.nn.Transformer
- mlx.nn.Upsample
- Functions
- mlx.nn.elu
- mlx.nn.celu
- mlx.nn.gelu
- mlx.nn.gelu_approx
- mlx.nn.gelu_fast_approx
- mlx.nn.glu
- mlx.nn.hard_shrink
- mlx.nn.hard_tanh
- mlx.nn.hardswish
- mlx.nn.leaky_relu
- mlx.nn.log_sigmoid
- mlx.nn.log_softmax
- mlx.nn.mish
- mlx.nn.prelu
- mlx.nn.relu
- mlx.nn.relu6
- mlx.nn.selu
- mlx.nn.sigmoid
- mlx.nn.silu
- mlx.nn.softmax
- mlx.nn.softmin
- mlx.nn.softplus
- mlx.nn.softshrink
- mlx.nn.step
- mlx.nn.tanh
- Loss Functions
- mlx.nn.losses.binary_cross_entropy
- mlx.nn.losses.cosine_similarity_loss
- mlx.nn.losses.cross_entropy
- mlx.nn.losses.gaussian_nll_loss
- mlx.nn.losses.hinge_loss
- mlx.nn.losses.huber_loss
- mlx.nn.losses.kl_div_loss
- mlx.nn.losses.l1_loss
- mlx.nn.losses.log_cosh_loss
- mlx.nn.losses.margin_ranking_loss
- mlx.nn.losses.mse_loss
- mlx.nn.losses.nll_loss
- mlx.nn.losses.smooth_l1_loss
- mlx.nn.losses.triplet_loss
- Initializers