jax.flatten_util.ravel_pytree

jax.flatten_util.ravel_pytree#

jax.flatten_util.ravel_pytree(pytree)[源代码][源代码]#

将一个数组的 pytree 展平为 1D 数组。

参数:

pytree – 展平数组和标量的pytree。

返回:

一对元素,其中第一个元素是一个一维数组,表示扁平化和连接的叶子值,其dtype由提升叶子值的dtype确定,第二个元素是一个可调用对象,用于将相同长度的一维向量反扁平化为与输入 pytree 相同结构的pytree。如果输入的pytree为空(即没有叶子),则按照惯例,输出第一个组件中返回一个dtype为float32的一维空数组。

有关 dtype 提升的详细信息,请参阅 https://jax.readthedocs.io/en/latest/type_promotion.html