模块

目录

模块#

模块#

用于使用MLX构建神经网络的基类。

mlx.nn.layers中提供的所有层都继承自这个类,你的模型也应该这样做。

一个Module可以包含其他Module实例或mlx.core.array实例,这些实例可以任意嵌套在Python列表或字典中。然后,Module允许使用mlx.nn.Module.parameters()递归提取所有mlx.core.array实例。

此外,Module 具有可训练和不可训练参数(称为“冻结”)的概念。当使用 mlx.nn.value_and_grad() 时,梯度仅针对可训练参数返回。模块中的所有数组都是可训练的,除非它们通过调用 freeze() 被添加到“冻结”集合中。

import mlx.core as mx
import mlx.nn as nn

class MyMLP(nn.Module):
    def __init__(self, in_dims: int, out_dims: int, hidden_dims: int = 16):
        super().__init__()

        self.in_proj = nn.Linear(in_dims, hidden_dims)
        self.out_proj = nn.Linear(hidden_dims, out_dims)

    def __call__(self, x):
        x = self.in_proj(x)
        x = mx.maximum(x, 0)
        return self.out_proj(x)

model = MyMLP(2, 1)

# All the model parameters are created but since MLX is lazy by
# default, they are not evaluated yet. Calling `mx.eval` actually
# allocates memory and initializes the parameters.
mx.eval(model.parameters())

# Setting a parameter to a new value is as simply as accessing that
# parameter and assigning a new array to it.
model.in_proj.weight = model.in_proj.weight * 2
mx.eval(model.parameters())

属性

Module.training

布尔值,指示模型是否处于训练模式。

Module.state

模块的状态字典

方法

Module.apply(map_fn[, filter_fn])

使用提供的map_fn映射所有参数,并立即使用映射后的参数更新模块。

Module.apply_to_modules(apply_fn)

将函数应用于此实例中的所有模块(包括此实例)。

Module.children()

返回此Module实例的直接子代。

Module.eval()

将模型设置为评估模式。

Module.filter_and_map(filter_fn[, map_fn, ...])

使用filter_fn递归过滤模块的内容,即仅选择filter_fn返回为真的键和值。

Module.freeze(*[, recurse, keys, strict])

冻结模块的参数或其中一部分。

Module.leaf_modules()

返回不包含其他模块的子模块。

Module.load_weights(file_or_weights[, strict])

.npz.safetensors文件或列表中更新模型的权重。

Module.modules()

返回此实例中所有模块的列表。

Module.named_modules()

返回一个列表,包含此实例中的所有模块及其使用点符号表示的名称。

Module.parameters()

递归地返回此模块的所有mlx.core.array成员,作为字典和列表的字典。

Module.save_weights(file)

将模型的权重保存到文件中。

Module.set_dtype(dtype[, predicate])

设置模块参数的dtype。

Module.train([mode])

设置模型进入或退出训练模式。

Module.trainable_parameters()

递归地返回此模块中所有未冻结的mlx.core.array成员,作为字典和列表的字典。

Module.unfreeze(*[, recurse, keys, strict])

解冻模块的参数或其中一部分。

Module.update(parameters)

用提供的字典和列表中的参数替换此模块的参数。

Module.update_modules(modules)

用提供的字典和列表中的子模块替换此Module实例的子模块。