jax.tree_util.tree_flatten_with_path# jax.tree_util.tree_flatten_with_path(tree, is_leaf=None)[源代码]# 类似于 tree_flatten,但还会返回每个叶子的键路径。 参数: tree (Any) – 一个要展平的 pytree。如果它包含自定义类型,则必须使用 register_pytree_with_keys 进行注册。 is_leaf (Callable[[Any], bool] | None) 返回: 一对,其中第一个元素是键-叶对列表,每个键-叶对包含一个叶及其键路径。第二个元素是一个表示扁平化树结构的treedef。 返回类型: tuple[list[tuple[KeyPath, Any]], PyTreeDef]