mlx.nn.Module.apply

mlx.nn.Module.apply#

模块。应用(映射函数 Callable[[array], array], 过滤函数 Callable[[Module, str, Any], bool] | None = ) 模块#

使用提供的map_fn映射所有参数,并立即使用映射后的参数更新模块。

例如,运行 model.apply(lambda x: x.astype(mx.float16)) 将所有参数转换为16位浮点数。

Parameters
  • map_fn (Callable) – 将一个数组映射到另一个数组

  • filter_fn (Callable, optional) – 用于选择要映射的数组的过滤器(默认:Module.valid_parameter_filter())。

Returns

更新参数后的模块实例。