mlx.utils.tree_map

目录

mlx.utils.tree_map#

树形图(函数 可调用的, 任何, *休息 任何, 是否为叶子节点 可调用 | = ) 任何#

fn应用于Python树tree的叶子节点,并返回包含结果的新集合。

如果提供了rest,则假定每个项目都是tree的超集, 并且相应的叶子作为额外的位置参数提供给fn。在这方面,tree_map()更接近于itertools.starmap(), 而不是map()

关键字参数 is_leaf 决定了什么构成了 tree 中的叶子,类似于 tree_flatten()

import mlx.nn as nn
from mlx.utils import tree_map

model = nn.Linear(10, 10)
print(model.parameters().keys())
# dict_keys(['weight', 'bias'])

# square the parameters
model.update(tree_map(lambda x: x*x, model.parameters()))
Parameters
  • fn (可调用) – 处理树叶子节点的函数。

  • tree (Any) – 将被迭代的主要Python树。

  • rest (tuple[Any]) – 要与 tree 一起迭代的额外树。

  • is_leaf (可调用的, 可选的) – 一个可选的调用对象,如果传递的对象被认为是叶子节点,则返回 True,否则返回 False

Returns

一个带有fn返回的新值的Python树。