模块#
- 类 模块#
用于使用MLX构建神经网络的基类。
在
mlx.nn.layers中提供的所有层都继承自这个类,你的模型也应该这样做。一个
Module可以包含其他Module实例或mlx.core.array实例,这些实例可以任意嵌套在Python列表或字典中。然后,Module允许使用mlx.nn.Module.parameters()递归提取所有mlx.core.array实例。此外,
Module具有可训练和不可训练参数(称为“冻结”)的概念。当使用mlx.nn.value_and_grad()时,梯度仅针对可训练参数返回。模块中的所有数组都是可训练的,除非它们通过调用freeze()被添加到“冻结”集合中。import mlx.core as mx import mlx.nn as nn class MyMLP(nn.Module): def __init__(self, in_dims: int, out_dims: int, hidden_dims: int = 16): super().__init__() self.in_proj = nn.Linear(in_dims, hidden_dims) self.out_proj = nn.Linear(hidden_dims, out_dims) def __call__(self, x): x = self.in_proj(x) x = mx.maximum(x, 0) return self.out_proj(x) model = MyMLP(2, 1) # All the model parameters are created but since MLX is lazy by # default, they are not evaluated yet. Calling `mx.eval` actually # allocates memory and initializes the parameters. mx.eval(model.parameters()) # Setting a parameter to a new value is as simply as accessing that # parameter and assigning a new array to it. model.in_proj.weight = model.in_proj.weight * 2 mx.eval(model.parameters())
属性
布尔值,指示模型是否处于训练模式。
模块的状态字典
方法
Module.apply(map_fn[, filter_fn])使用提供的
map_fn映射所有参数,并立即使用映射后的参数更新模块。Module.apply_to_modules(apply_fn)将函数应用于此实例中的所有模块(包括此实例)。
返回此Module实例的直接子代。
将模型设置为评估模式。
Module.filter_and_map(filter_fn[, map_fn, ...])使用
filter_fn递归过滤模块的内容,即仅选择filter_fn返回为真的键和值。Module.freeze(*[, recurse, keys, strict])冻结模块的参数或其中一部分。
返回不包含其他模块的子模块。
Module.load_weights(file_or_weights[, strict])从
.npz、.safetensors文件或列表中更新模型的权重。返回此实例中所有模块的列表。
返回一个列表,包含此实例中的所有模块及其使用点符号表示的名称。
递归地返回此模块的所有
mlx.core.array成员,作为字典和列表的字典。Module.save_weights(file)将模型的权重保存到文件中。
Module.set_dtype(dtype[, predicate])设置模块参数的dtype。
Module.train([mode])设置模型进入或退出训练模式。
递归地返回此模块中所有未冻结的
mlx.core.array成员,作为字典和列表的字典。Module.unfreeze(*[, recurse, keys, strict])解冻模块的参数或其中一部分。
Module.update(parameters)用提供的字典和列表中的参数替换此模块的参数。
Module.update_modules(modules)用提供的字典和列表中的子模块替换此
Module实例的子模块。