mlx.utils.tree_map#
- 树形图(函数: 可调用的, 树: 任何, *休息: 任何, 是否为叶子节点: 可调用 | 无 = 无) → 任何#
将
fn应用于Python树tree的叶子节点,并返回包含结果的新集合。如果提供了
rest,则假定每个项目都是tree的超集, 并且相应的叶子作为额外的位置参数提供给fn。在这方面,tree_map()更接近于itertools.starmap(), 而不是map()。关键字参数
is_leaf决定了什么构成了tree中的叶子,类似于tree_flatten()。import mlx.nn as nn from mlx.utils import tree_map model = nn.Linear(10, 10) print(model.parameters().keys()) # dict_keys(['weight', 'bias']) # square the parameters model.update(tree_map(lambda x: x*x, model.parameters()))
- Parameters:
fn (可调用) – 处理树叶子节点的函数。
tree (Any) – 将被迭代的主要Python树。
rest (tuple[Any]) – 要与
tree一起迭代的额外树。is_leaf (可调用的, 可选的) – 一个可选的调用对象,如果传递的对象被认为是叶子节点,则返回
True,否则返回False。
- Returns:
一个带有
fn返回的新值的Python树。