mlx.nn.Module.apply#
- 模块。应用(映射函数: Callable[[array], array], 过滤函数: Callable[[Module, str, Any], bool] | None = 无) → 模块#
使用提供的
map_fn映射所有参数,并立即使用映射后的参数更新模块。例如,运行
model.apply(lambda x: x.astype(mx.float16))将所有参数转换为16位浮点数。- Parameters:
map_fn (Callable) – 将一个数组映射到另一个数组
filter_fn (Callable, optional) – 用于选择要映射的数组的过滤器(默认:
Module.valid_parameter_filter())。
- Returns:
更新参数后的模块实例。