jax.tree_util.tree_map_with_path#
- jax.tree_util.tree_map_with_path(f, tree, *rest, is_leaf=None)[源代码]#
将一个多输入函数映射到 pytree 键路径和参数上,以生成一个新的 pytree。
这是
tree_map的一个更强大的替代方案,它可以同时将每个叶子的键路径作为输入参数。- 参数:
f (Callable[..., Any]) – 接受
2 + len(rest)参数的函数,即键路径和每个对应的 pytrees 的叶子。tree (Any) – 一个要映射的 pytree,每个叶子的键路径作为第一个位置参数,叶子本身作为第二个参数传递给
f。*rest (Any) – 一个pytrees的元组,每个pytree与``tree``具有相同的结构,或者作为``tree``的前缀。
is_leaf (Callable[[Any], bool] | None)
- 返回:
一个新的 pytree ,其结构与
tree相同,但每个叶子的值由f(kp, x, *xs)给出,其中kp是tree中相应叶子的键路径,x是叶子值,xs是rest中相应节点的值的元组。- 返回类型:
Any