工具#

一般#

scale_gradient(inputs, scale)

为反向传播缩放梯度。

value_and_grad_from_state(value_fn)

从状态中获取值和梯度的替代方案 jax.value_and_grad

比例梯度#

optax.scale_gradient(inputs: chex.ArrayTree, scale: float) chex.ArrayTree[来源]#

为反向传播缩放梯度。

Parameters:
  • inputs – 一个嵌套数组。

  • scale – 反向传播中梯度的缩放因子。

Returns:

一个与 inputs 具有相同结构的数组,具有缩放的向后梯度。

从状态获取值和梯度#

optax.value_and_grad_from_state(value_fn: Callable[[...], 数组 | float]) Callable[[...], tuple[float | 数组, TypeAliasForwardRef('optax.Updates')]][来源]#

替代 jax.value_and_grad 的方法,从状态中获取值和梯度。

诸如 optax.scale_by_backtracking_linesearch() 的线搜索方法需要在候选迭代点计算梯度和目标函数。这个目标值和梯度可以在下一个迭代中重新使用,以利用这个实用函数节省一些计算。

Parameters:

value_fn – 返回一个标量(float或维度为1的数组)的函数,可以在jax中使用 jax.value_and_grad() 进行微分。

Returns:

一个类似于 jax.value_and_grad() 的可调用函数,如果状态中存在,则获取值和梯度。如果未找到值或梯度,或者找到多个值和梯度,该函数会引发错误。如果找到的值是无限大或nan,则使用 jax.value_and_grad() 计算值和梯度。如果在状态中找到的梯度为None,则引发错误。

示例

>>> import optax
>>> import jax.numpy as jnp
>>> def fn(x): return jnp.sum(x ** 2)
>>> solver = optax.chain(
...     optax.sgd(learning_rate=1.),
...     optax.scale_by_backtracking_linesearch(
...         max_backtracking_steps=15, store_grad=True
...     )
... )
>>> value_and_grad = optax.value_and_grad_from_state(fn)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: {:.2E}'.format(fn(params)))
Objective function: 1.40E+01
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...   value, grad = value_and_grad(params, state=opt_state)
...   updates, opt_state = solver.update(
...       grad, opt_state, params, value=value, grad=grad, value_fn=fn
...   )
...   params = optax.apply_updates(params, updates)
...   print('Objective function: {:.2E}'.format(fn(params)))
Objective function: 5.04E+00
Objective function: 1.81E+00
Objective function: 6.53E-01
Objective function: 2.35E-01
Objective function: 8.47E-02

数值稳定性#

safe_increment(count)

在避免溢出的情况下将计数器增加一个。

safe_norm(x, min_norm[, ord, axis, keepdims])

返回 jnp.maximum(jnp.linalg.norm(x), min_norm),并带有正确的梯度。

safe_root_mean_squares(x, min_rms)

返回 最大值(sqrt(mean(abs_sq(x))), min_norm),并带有正确的梯度。

安全递增#

optax.safe_increment(count: chex.Numeric) chex.Numeric[来源]#

在避免溢出的情况下将计数器增加一个。

Denote max_val, min_val as the maximum, minimum, possible values for the dtype of count. Normally max_val + 1 would overflow to min_val. This functions ensures that when max_val is reached the counter stays at max_val.

Parameters:

count – 一个要递增的计数器。

Returns:

一个每次递增1的计数器,或 max_val 如果达到最大值时。

示例

>>> import jax.numpy as jnp
>>> import optax
>>> optax.safe_increment(jnp.asarray(1, dtype=jnp.int32))
Array(2, dtype=int32)
>>> optax.safe_increment(jnp.asarray(2147483647, dtype=jnp.int32))
Array(2147483647, dtype=int32)

在版本 0.2.4 中添加。

安全范数#

optax.safe_norm(x: chex.Array, min_norm: chex.Numeric, ord: int | float | str | None = None, axis: None | tuple[int, ...] | int = None, keepdims: bool = False) chex.Array[来源]#

返回 jnp.maximum(jnp.linalg.norm(x), min_norm) 并具有正确的梯度。

在0.0处,jnp.maximum(jnp.linalg.norm(x), min_norm)的梯度为NaN,因为jax会评估jnp.maximum的两个分支。该函数在这种情况下也会返回正确的0.0梯度。

Parameters:
  • x – jax 数组。

  • min_norm – 返回的范数的下限。

  • ord – {非零整数, inf, -inf, ‘fro’, ‘nuc’}, 可选。范数的阶。 inf表示numpy的inf对象。默认值为None。

  • axis – {无,整数,整数的2元组},可选。如果轴是一个整数,它指定沿着计算向量范数的 x 的轴。如果轴是一个2元组,它指定保持二维矩阵的轴,这些矩阵的矩阵范数将被计算。如果轴是 None,则返回向量范数(当 x 是1维时)或矩阵范数(当 x 是2维时)。默认值是 None。

  • keepdims – 布尔值, 可选。如果设置为 True,那么被归一化的轴将以大小为一的维度保留在结果中。使用此选项时,结果将能够正确地与原始 x 进行广播。

Returns:

输入向量的安全范数,考虑了正确的梯度。

安全的均方根#

optax.safe_root_mean_squares(x: chex.Array, min_rms: chex.Numeric) chex.Array[来源]#

返回 最大值(sqrt(mean(abs_sq(x))), min_norm),并具有正确的梯度。

在0.0处,maximum(sqrt(mean(abs_sq(x))), min_norm) 的梯度是NaN,因为jax将评估jnp.maximum的两个分支。这个函数在这种情况下也将返回正确的0.0的梯度。

Parameters:
  • x – jax 数组。

  • min_rms – 返回的范数的下限。

Returns:

输入向量的安全 RMS,考虑到正确的梯度。

线性代数运算符#

matrix_inverse_pth_root(matrix, p[, ...])

计算 matrix^(-1/p),其中 p 是一个正整数。

power_iteration(matrix, *[, v0, num_iters, ...])

幂迭代算法。

nnls(A, b, iters[, unroll, L])

求解非负最小二乘问题。

矩阵逆 p 阶根#

optax.matrix_inverse_pth_root(matrix: chex.Array, p: int, num_iters: int = 100, ridge_epsilon: float = 1e-06, error_tolerance: float = 1e-06, precision: 精准度 = Precision.HIGHEST)[来源]#

计算 matrix^(-1/p),其中 p 是一个正整数。

此函数使用耦合牛顿迭代算法来计算矩阵的逆pth根。

Parameters:
  • matrix – 要计算其幂的对称PSD矩阵

  • p – 指数,对于p为正整数。

  • num_iters – 最大迭代次数。

  • ridge_epsilon – 加入的Ridge epsilon以使矩阵为正定的。

  • error_tolerance – 错误指示器,适用于提前终止。

  • precision – 精度 XLA 相关标志, 可选项有: a) lax.Precision.DEFAULT (更好的步骤时间,但不精确); b) lax.Precision.HIGH (增加精度,速度较慢); c) lax.Precision.HIGHEST (可能的最佳精度,速度最慢)。

Returns:

矩阵^(-1/p)

参考文献

[Functions of Matrices, Theory and Computation,

Nicholas J Higham, 第184页, 方程 7.18]( https://epubs.siam.org/doi/book/10.1137/1.9780898717778)

幂迭代#

optax.power_iteration(matrix: TypeAliasForwardRef('chex.Array') | Callable[[TypeAliasForwardRef('chex.ArrayTree')], TypeAliasForwardRef('chex.ArrayTree')], *, v0: TypeAliasForwardRef('chex.ArrayTree') | None = None, num_iters: int = 100, error_tolerance: float = 1e-06, precision: 精度 = Precision.HIGHEST, key: 数组 | None = None) tuple[TypeAliasForwardRef('chex.Numeric'), TypeAliasForwardRef('chex.ArrayTree')][来源]#

幂迭代算法。

此算法计算可对角化矩阵的主特征值及其相关特征向量。该矩阵可以作为数组提供,也可以作为实现矩阵-向量乘法的可调用对象提供。

Parameters:
  • matrix – 方阵,可以作为数组或实现矩阵-向量乘法的可调用对象。

  • v0 – 初始向量,近似于主特征向量。如果 matrix 是 大小为 (n, n) 的数组,则 v0 必须是大小为 (n,) 的向量。如果 matrix 是一个可调用对象,则 v0 必须是具有与该可调用对象输入相同结构的树。如果此参数为 None 而 matrix 是 一个数组,则将使用从均匀分布 [-1, 1] 中采样的随机向量作为初始向量。

  • num_iters – 权力迭代的次数。

  • error_tolerance – 迭代退出条件。当主特征值估计的相对误差低于该阈值时,程序停止。

  • precision – 与精度相关的XLA标志, 可用选项为: a) lax.Precision.DEFAULT(更好的步长时间,但不精确); b) lax.Precision.HIGH(增加精度,较慢); c) lax.Precision.HIGHEST(可能的最佳精度,最慢)。

  • key – 随机键用于v0的初始化,当未明确给出时。 当此参数为None时,使用jax.random.PRNGKey(0)

Returns:

一对(特征值,特征向量),其中特征值是 matrix 的主特征值,特征向量是其相关的特征向量。

参考文献

维基百科贡献者。 Power iteration

在版本 0.2.2 中更改:matrix 可以是一个可调用对象。返回参数的顺序已更改,从 (特征向量, 特征值) 改为 (特征值, 特征向量)。

非负最小二乘#

optax.nnls(A: 数组, b: 数组, iters: int, unroll: int | bool = 1, L: 数组 | float | None = None) 数组[来源]#

解决非负最小二乘问题。

最小化 \(\|A x - b\|_2\) 使 \(x \geq 0\)

使用Polyak 2015的快速投影梯度(FPG)算法。

Parameters:
  • A – 输入矩阵。

  • b – 输入向量。

  • iters – 要运行算法的迭代次数。

  • unroll – 传递给 lax.scan 的展开参数。

  • LA.T @ A 的谱半径的上界(可选)。

Returns:

解向量。

示例

>>> from jax import numpy as jnp
>>> import optax
>>> A = jnp.array([[1., 2.], [3., 4.]])
>>> b = jnp.array([5., 6.])
>>> x = optax.nnls(A, b, 10**3)
>>> print(f"{x[0]:.2f}")
0.00
>>> print(f"{x[1]:.2f}")
1.70

参考文献

罗曼·A·波利亚克,非负最小二乘的投影梯度方法,2015

二阶优化#

fisher_diag(negative_log_likelihood, params, ...)

计算(观察到的)Fisher信息矩阵的对角线。

hessian_diag(loss, params, inputs, targets)

Computes the diagonal hessian of loss at (inputs, targets).

hvp(loss, v, params, inputs, targets)

执行高效的向量-海森矩阵(loss)乘法。

费舍尔对角线#

optax.second_order.fisher_diag(negative_log_likelihood: LossFn, params: Any, inputs: 数组, targets: 数组) 数组[来源]#

计算(观察到的)费舍尔信息矩阵的对角线。

Parameters:
  • negative_log_likelihood – 负对数似然函数,预期签名 loss = fn(params, inputs, targets)

  • params – 模型参数。

  • inputs – 计算negative_log_likelihood时的输入。

  • targets – 评估negative_log_likelihood的目标。

  • None
Returns:

一个数组对应于评估在(params, inputs, targets)上的Hessian的产品negative_log_likelihood

海森对角线#

optax.second_order.hessian_diag(loss: LossFn, params: Any, inputs: 数组, targets: 数组) 数组[来源]#

计算在 (inputs, targets) 处的 loss 的对角哈希矩阵。

Parameters:
  • loss – 损失函数。

  • params – 模型参数。

  • 输入 – 评估损失的输入。

  • targets – 评估loss的目标。

Returns:

一个与产品对应的 DeviceArray 至于 loss 的 Hessian 在 (params, inputs, targets) 处评估。

海森矩阵向量积#

optax.second_order.hvp(loss: LossFn, v: 数组, params: Any, inputs: 数组, targets: 数组) 数组[来源]#

执行高效的向量-海森矩阵(损失)乘法。

Parameters:
  • loss – 损失函数。

  • v – 一个大小为 ravel(params) 的向量。

  • params – 模型参数。

  • 输入 – 评估损失的输入。

  • targets – 评估 loss 的目标。

Returns:

一个数组,对应于v和在(params, inputs, targets)处评估的loss的Hessian的乘积。

#

NamedTupleKey(tuple_name, name)

树中命名元组的键类型。

tree_add(tree_x, tree_y, *other_trees)

添加两个(或多个)pytrees。

tree_add_scalar_mul(tree_x, scalar, tree_y)

将两棵树相加,其中第二棵树按标量缩放。

tree_batch_shape(tree[, shape])

向pytree的每个叶子添加前导批量维度。

tree_cast(tree, dtype)

将树转换为给定的dtype,如果为None则跳过。

tree_div(tree_x, tree_y)

分割两个pytrees。

tree_dtype(tree[, mixed_dtype_handler])

获取树的dtype。

tree_get(tree, key[, default, filtering])

从一个pytree中提取与给定键匹配的值。

tree_get_all_with_path(tree, key[, filtering])

提取与给定键匹配的pytree值。

tree_l1_norm(tree)

计算 pytree 的 l1 范数。

tree_l2_norm(tree[, squared])

计算pytree的l2范数。

tree_linf_norm(tree)

计算一个 pytree 的 l-无限范数。

tree_map_params(initable, f, state, /, *rest)

在给定的优化器状态下应用一个可调用对象到所有参数上。

tree_max(tree)

计算pytree中所有元素的最大值。

tree_mul(tree_x, tree_y)

对两个pytrees进行乘法运算。

tree_ones_like(tree[, dtype])

创建一个结构相同的全一树。

tree_random_like(rng_key, target_tree[, ...])

创建一个与目标树形状相同的随机条目树。

tree_split_key_like(rng_key, target_tree)

将键拆分以匹配目标树的结构。

tree_scalar_mul(scalar, tree)

将树乘以标量。

tree_set(tree[, filtering])

创建一个树的副本,某些值根据指定的kwargs进行替换。

tree_sub(tree_x, tree_y)

减去两个pytrees。

tree_sum(tree)

计算pytree中所有元素的总和。

tree_vdot(tree_x, tree_y)

计算两个pytrees之间的内积。

tree_where(condition, tree_x, tree_y)

如果条件为真,则选择 tree_x 值,否则选择 tree_y 值。

tree_zeros_like(tree[, dtype])

创建一个结构相同的全零树。

命名元组键#

class optax.tree_utils.NamedTupleKey(tuple_name: str, name: str)[来源]#

树中命名元组的关键类型。

When using a function filtering(path: KeyPath, value: Any) -> bool: ... in a tree in optax.tree_utils.tree_get_all_with_path(), optax.tree_utils.tree_get(), or optax.tree_utils.tree_set(), can filter the path to check if of the KeyEntry is a NamedTupleKey and then check if the name of named tuple is the one intended to be searched.

tuple_name#

包含键的元组的名称。

Type:

str

name#

键的名称。

Type:

字符串

另请参见

jax.tree_util.DictKey, jax.tree_util.FlattenedIndexKey, jax.tree_util.GetAttrKey, jax.tree_util.SequenceKey, optax.tree_utils.tree_get_all_with_path(), optax.tree_utils.tree_get(), optax.tree_utils.tree_set(),

在版本 0.2.2 中添加。

树添加#

optax.tree_utils.tree_add(tree_x: Any, tree_y: Any, *other_trees: Any) Any[来源]#

添加两个(或更多)pytrees。

Parameters:
  • tree_x – 第一个 pytree。

  • tree_y – 第二个 pytree。

  • *other_trees – 可选的其他树以添加

Returns:

两个(或多个)pytrees 的和。

在版本 0.2.1 中更改: 添加了可选的 *other_trees 参数。

树的添加和数乘#

optax.tree_utils.tree_add_scalar_mul(tree_x: Any, scalar: float | 数组, tree_y: Any) Any[来源]#

添加两棵树,其中第二棵树通过一个标量进行缩放。

在中缀表示法中,函数执行 out = tree_x + scalar * tree_y

Parameters:
  • tree_x – 第一个pytree。

  • scalar – 标量值。

  • tree_y – 第二个 pytree。

Returns:

一个与 tree_xtree_y 结构相同的 pytree。

树批量重塑#

optax.tree_utils.tree_batch_shape(tree: Any, shape: tuple[int, ...] = ())[来源]#

为每个pytree的叶子添加前导批次维度。

Parameters:
  • tree – 一个 pytree。

  • shape – 一个形状,指示要添加的前导批次维度。

Returns:

添加了前导批次维度的pytree。

树的类型转换#

optax.tree_utils.tree_cast(tree: chex.ArrayTree, dtype: str | type[Any] | 数据类型 | SupportsDType | None) chex.ArrayTree[来源]#

将树转换为给定的数据类型,如果为None则跳过。

Parameters:
  • tree – 要转换的树。

  • dtype – 要转换的dtype,或None以跳过。

Returns:

树,叶子转换为数据类型。

示例

>>> import jax.numpy as jnp
>>> import optax
>>> tree = {'a': {'b': jnp.array(1.0, dtype=jnp.float32)},
...         'c': jnp.array(2.0, dtype=jnp.float32)}
>>> optax.tree_utils.tree_cast(tree, dtype=jnp.bfloat16)
{'a': {'b': Array(1, dtype=bfloat16)}, 'c': Array(2, dtype=bfloat16)}

树数据类型#

optax.tree_utils.tree_dtype(tree: chex.ArrayTree, mixed_dtype_handler: str | None = None) str | type[Any] | 数据类型 | SupportsDType[来源]#

获取树的dtype。

如果树是空的,返回JAX数组的默认数据类型。

Parameters:
  • tree – 要提取dtype的树。

  • mixed_dtype_handler – how to handle mixed dtypes in the tree. - If mixed_dtype_handler=None, returns the common dtype of the leaves of the tree if it exists, otherwise raises an error. - If mixed_dtype_handler='promote', promotes the dtypes of the leaves of the tree to a common promoted dtype using jax.numpy.promote_types(). - If mixed_dtype_handler='highest' or mixed_dtype_handler='lowest', returns the highest/lowest dtype of the leaves of the tree. We consider a partial ordering of dtypes as dtype1 <= dtype2 if dtype1 is promoted to dtype2, that is, if jax.numpy.promote_types(dtype1, dtype2) == dtype2. Since some dtypes cannot be promoted to one another, this is not a total ordering, and the ‘highest’ or ‘lowest’ options may not be applicable. These options will throw an error if the dtypes of the leaves of the tree cannot be promoted to one another.

Returns:

树的 dtype。

Raises:
  • 值错误 – 如果 mixed_dtype_handler 设置为 None 并且在树中发现多个 数据类型。

  • 值错误 – 如果 mixed_dtype_handler 设置为 'highest''lowest' ,而树中某些叶子节点的 dtypes 不能互相提升。

示例

>>> import jax.numpy as jnp
>>> import optax
>>> tree = {'a': {'b': jnp.array(1.0, dtype=jnp.float32)},
...         'c': jnp.array(2.0, dtype=jnp.float32)}
>>> optax.tree_utils.tree_dtype(tree)
dtype('float32')
>>> tree = {'a': {'b': jnp.array(1.0, dtype=jnp.float16)},
...         'c': jnp.array(2.0, dtype=jnp.float32)}
>>> optax.tree_utils.tree_dtype(tree, 'lowest')
dtype('float16')
>>> optax.tree_utils.tree_dtype(tree, 'highest')
dtype('float32')
>>> tree = {'a': {'b': jnp.array(1.0, dtype=jnp.int32)},
...         'c': jnp.array(2.0, dtype=jnp.uint32)}
>>> # optax.tree_utils.tree_dtype(tree, 'highest')
>>> # -> will throw an error because int32 and uint32
>>> # cannot be promoted to one another.
>>> optax.tree_utils.tree_dtype(tree, 'promote')
dtype('int64')

在版本 0.2.4 中添加。

树形划分#

optax.tree_utils.tree_div(tree_x: Any, tree_y: Any) Any[来源]#

将两个pytrees相除。

Parameters:
  • tree_x – 第一个pytree。

  • tree_y – 第二个 pytree。

Returns:

两个pytrees的商。

获取与给定键匹配的单个值#

optax.tree_utils.tree_get(tree: optax.PyTree, key: Any, default: Any | None = None, filtering: Callable[[Tuple[DictKey | FlattenedIndexKey | GetAttrKey | SequenceKey | 命名元组键, ...], Any], bool] | None = None) Any[来源]#

从匹配给定键的pytree中提取一个值。

tree中搜索特定的key(可以是字典中的键、NamedTuple中的字段或者NamedTuple的名称)。

如果 tree 不包含 key,则返回 default

如果在 tree 中找到多个 key 的值,将引发一个 KeyError

Generally, you may first get all pairs (path_to_value, value) for a given key using optax.tree_utils.tree_get_all_with_path(). You may then define a filtering operation filtering(path: Key_Path, value: Any) -> bool: ... that enables you to select the specific values you wanted to fetch by looking at the type of the value, or looking at the path to that value. Note that contrarily to the paths returned by jax.tree_util.tree_leaves_with_path() the paths analyzed by the filtering operation in optax.tree_utils.tree_get_all_with_path(), optax.tree_utils.tree_get(), or optax.tree_utils.tree_set() detail the names of the named tuples considered in the path. Concretely, if the value considered is in the attribute key of a named tuple called MyNamedTuple the last element of the path will be a optax.tree_utils.NamedTupleKey containing both name=key and tuple_name='MyNamedTuple'. That way you may distinguish between identical values in different named tuples (arising for example when chaining transformations in optax). See the last example below.

Parameters:
  • tree – 要搜索的树。

  • key – 要在 tree 中搜索的关键字或字段。

  • 默认值 – 如果在 tree 中找不到 key,则返回的默认值。

  • filtering – optional callable to further filter values in tree that match the key. filtering(path: Key_Path, value: Any) -> bool: ... takes as arguments both the path to the value (as returned by optax.tree_utils.tree_get_all_with_path()) and the value that match the given key.

Returns:

value

value in tree matching the given key. If none are found return default value. If multiple are found raises an error.

Raises:

KeyError – 如果在 tree 中找到多个 key 值。

示例

基本用法

>>> import jax.numpy as jnp
>>> import optax
>>> params = jnp.array([1., 2., 3.])
>>> opt = optax.adam(learning_rate=1.)
>>> state = opt.init(params)
>>> count = optax.tree_utils.tree_get(state, 'count')
>>> print(count)
0

与过滤操作的使用

>>> import jax.numpy as jnp
>>> import optax
>>> params = jnp.array([1., 2., 3.])
>>> opt = optax.inject_hyperparams(optax.sgd)(
...   learning_rate=lambda count: 1/(count+1)
... )
>>> state = opt.init(params)
>>> filtering = lambda path, value: isinstance(value, jnp.ndarray)
>>> lr = optax.tree_utils.tree_get(
...   state, 'learning_rate', filtering=filtering
... )
>>> print(lr)
1.0

通过名称提取命名元组

>>> params = jnp.array([1., 2., 3.])
>>> opt = optax.chain(
...     optax.add_noise(1.0, 0.9, 0), optax.scale_by_adam()
... )
>>> state = opt.init(params)
>>> noise_state = optax.tree_utils.tree_get(state, 'AddNoiseState')
>>> print(noise_state)
AddNoiseState(count=Array(0, dtype=int32), rng_key=Array([0, 0], dtype=uint32))

通过命名元组的名称区分两个值。

>>> import jax.numpy as jnp
>>> import optax
>>> params = jnp.array([1., 2., 3.])
>>> opt = optax.chain(
...   optax.add_noise(1.0, 0.9, 0), optax.scale_by_adam()
... )
>>> state = opt.init(params)
>>> filtering = (
...      lambda p, v: isinstance(p[-1], optax.tree_utils.NamedTupleKey)
...      and p[-1].tuple_name == 'ScaleByAdamState'
... )
>>> count = optax.tree_utils.tree_get(state, 'count', filtering=filtering)
>>> print(count)
0

在版本 0.2.2 中添加。

获取所有匹配给定键的值#

optax.tree_utils.tree_get_all_with_path(tree: optax.PyTree, key: Any, filtering: Callable[[Tuple[DictKey | FlattenedIndexKey | GetAttrKey | SequenceKey | 命名元组键, ...], Any], bool] | None = None) list[tuple[Tuple[DictKey | FlattenedIndexKey | GetAttrKey | SequenceKey | 命名元组键, ...], Any]][来源]#

提取与给定键匹配的py树的值。

tree中搜索特定的 key(可以是字典中的一个键、NamedTuple中的一个字段或NamedTuple的名称)。

That key/field key may appear more than once in tree. So this function returns a list of all values corresponding to key with the path to that value. The path is a sequence of KeyEntry that can be transformed in readable format using jax.tree_util.keystr(), see the example below.

Parameters:
  • tree – 要搜索的树。

  • key – 在树中搜索的关键字或字段。

  • filtering – optional callable to further filter values in tree that match the key. filtering(path: Key_Path, value: Any) -> bool: ... takes as arguments both the path to the value (as returned by optax.tree_utils.tree_get_all_with_path()) and the value that match the given key.

Returns:

values_with_path

list of tuples where each tuple is of the form (path_to_value, value). Here value is one entry of the tree that corresponds to the key, and path_to_value is a tuple of KeyEntry that is a tuple of jax.tree_util.DictKey, jax.tree_util.FlattenedIndexKey, jax.tree_util.GetAttrKey, jax.tree_util.SequenceKey, or optax.tree_utils.NamedTupleKey.

示例

基本用法

>>> import jax.numpy as jnp
>>> import optax
>>> params = jnp.array([1., 2., 3.])
>>> solver = optax.inject_hyperparams(optax.sgd)(
...   learning_rate=lambda count: 1/(count+1)
... )
>>> state = solver.init(params)
>>> found_values_with_path = optax.tree_utils.tree_get_all_with_path(
...   state, 'learning_rate'
... )
>>> print(
... *[(jax.tree_util.keystr(p), v) for p, v in found_values_with_path],
... sep="\n",
... )
("InjectStatefulHyperparamsState.hyperparams['learning_rate']", Array(1., dtype=float32))
("InjectStatefulHyperparamsState.hyperparams_states['learning_rate']", WrappedScheduleState(count=Array(0, dtype=int32)))

与过滤操作的使用

>>> import jax.numpy as jnp
>>> import optax
>>> params = jnp.array([1., 2., 3.])
>>> solver = optax.inject_hyperparams(optax.sgd)(
...   learning_rate=lambda count: 1/(count+1)
... )
>>> state = solver.init(params)
>>> filtering = lambda path, value: isinstance(value, tuple)
>>> found_values_with_path = optax.tree_utils.tree_get_all_with_path(
...   state, 'learning_rate', filtering
... )
>>> print(
... *[(jax.tree_util.keystr(p), v) for p, v in found_values_with_path],
... sep="\n",
... )
("InjectStatefulHyperparamsState.hyperparams_states['learning_rate']", WrappedScheduleState(count=Array(0, dtype=int32)))

在版本 0.2.2 中添加。

树的 l1 范数#

optax.tree_utils.tree_l1_norm(tree: Any) chex.Numeric[来源]#

计算pytree的l1范数。

Parameters:

– pytree.

Returns:

一个标量值。

树的l2范数#

optax.tree_utils.tree_l2_norm(tree: Any, squared: bool = False) chex.Numeric[来源]#

计算pytree的l2范数。

Parameters:
  • tree – pytree.

  • squared – 是否应该返回平方的范数。

Returns:

一个标量值。

树 l-无穷范数#

optax.tree_utils.tree_linf_norm(tree: Any) chex.Numeric[来源]#

计算pytree的l-infinity范数。

Parameters:

– pytree.

Returns:

一个标量值。

树图参数#

optax.tree_utils.tree_map_params(initable: Callable[[TypeAliasForwardRef('optax.Params')], TypeAliasForwardRef('optax.OptState')] | Initable, f: Callable[[...], Any], state: optax.OptState, /, *rest: Any, transform_non_params: Callable[[...], Any] | None = None, is_leaf: Callable[[TypeAliasForwardRef('optax.Params')], bool] | None = None) optax.OptState[来源]#

在给定的优化器状态下,对所有参数应用可调用对象。

这个函数旨在帮助构建优化器状态的分区规格,以便在已知参数的分区规格的情况下。

例如,下面的内容将用给定的分区规范的副本替换所有优化器状态参数树。可以使用参数 transform_non_params 根据需要替换任何剩余字段,在这种情况下,我们将这些字段替换为 None。

>>> params, specs = jnp.array(0.), jnp.array(0.)  # Trees with the same shape
>>> opt = optax.sgd(1e-3)
>>> state = opt.init(params)
>>> opt_specs = optax.tree_map_params(
...     opt,
...     lambda _, spec: spec,
...     state,
...     specs,
...     transform_non_params=lambda _: None,
...     )
Parameters:
  • initable – 一个可调用的,接受参数并返回优化器状态,或者一个具有相同功能的init属性的对象。

  • f – 一个可调用的对象,将应用于该优化器状态下参数树的所有副本。

  • state – 要映射的优化器状态。

  • *rest – 额外的参数,与参数树的形状相同,将被传递给 f。

  • transform_non_params – 一个可选的函数,将在优化器状态中的所有非参数字段上调用。

  • is_leaf – 传递给 jax.tree.map。这使得可以忽略参数树的某些部分,例如当梯度变换修改原始 pytree 的形状时,如对于 optax.masked

Returns:

对优化器状态中所有与参数树具有相同形状的树应用函数 f 的结果,以及给定的可选额外参数。

树的最大值#

optax.tree_utils.tree_max(tree: Any) chex.Numeric[来源]#

计算pytree中所有元素的最大值。

Parameters:

– pytree.

Returns:

一个标量值。

树的乘法#

optax.tree_utils.tree_mul(tree_x: Any, tree_y: Any) Any[来源]#

乘法两个pytrees。

Parameters:
  • tree_x – 第一个pytree。

  • tree_y – 第二个 pytree。

Returns:

这两个pytrees的乘积。

像这样的树#

optax.tree_utils.tree_ones_like(tree: Any, dtype: str | type[Any] | 数据类型 | SupportsDType | None = None) Any[来源]#

创建一个具有相同结构的全一树。

Parameters:
  • tree – pytree.

  • dtype – 可选的用于一树的dtype。

Returns:

一个与 tree 具有相同结构的全一树。

根据树的结构拆分键#

optax.tree_utils.tree_split_key_like(rng_key: 数组, target_tree: chex.ArrayTree) chex.ArrayTree[来源]#

将键拆分以匹配目标树的结构。

Parameters:
  • rng_key – 用于分割的键。

  • target_tree – 要匹配其结构的树。

Returns:

一个rng密钥的树。

具有随机值的树#

optax.tree_utils.tree_random_like(rng_key: ~jax.Array, target_tree: chex.ArrayTree, sampler: ~collections.abc.Callable[[~jax.Array, ~typing.Sequence[int | ~typing.Any], str | type[~typing.Any] | ~numpy.dtype | ~jax._src.typing.SupportsDType], TypeAliasForwardRef('chex.Array')] = <function normal>, dtype: str | type[~typing.Any] | ~numpy.dtype | ~jax._src.typing.SupportsDType | None = None) chex.ArrayTree[来源]#

创建与目标树形状相同的随机条目树。

Parameters:
  • rng_key – 随机数生成器的钥匙。

  • target_tree – 要匹配结构的树。叶子必须是数组。

  • sampler – 噪声采样函数,默认是 jax.random.normal

  • dtype – 随机数的期望数据类型,传递给 sampler。如果为 None,则如果可能,使用目标树的 dtype。

Returns:

一个具有与 target_tree 相同结构的随机树,其叶子具有 sampler 的分布。

警告

可能的数据类型可能受到采样器的限制,例如 jax.random.rademacher 仅支持整数数据类型,如果目标树的 dtype 不是整数,或者 dtype 不是整数类型,则会引发错误。

在版本 0.2.1 中新增。

树的标量乘法#

optax.tree_utils.tree_scalar_mul(scalar: float | 数组, tree: Any) Any[来源]#

将树乘以标量。

在中缀表示法中,函数执行 out = scalar * tree

Parameters:
  • scalar – 标量值。

  • tree – pytree.

Returns:

一个与 tree 结构相同的 pytree。

在树中设置值#

optax.tree_utils.tree_set(tree: optax.PyTree, filtering: Callable[[Tuple[DictKey | FlattenedIndexKey | GetAttrKey | SequenceKey | 命名元组键, ...], Any], bool] | None = None, /, **kwargs: Any) optax.PyTree[来源]#

创建一个树的副本,替换指定kwargs的一些值。

tree中搜索keys,在**kwargs中(它可以是字典中的一个键,NamedTuple中的一个字段或NamedTuple的名称)。如果找到这样的键,将相应的值替换为在**kwargs中给定的值。

如果 **kwargs 中的一些键在树中不存在,则会引发一个 KeyError

Parameters:
  • tree – 要被替换的pytree值。

  • filtering – optional callable to further filter values in tree that match the keys to replace. filtering(path: Key_Path, value: Any) -> bool: ... takes as arguments both the path to the value (as returned by optax.tree_utils.tree_get_all_with_path()) and the value that match a given key.

  • **kwargs – 用于替换tree中的值的键字典。

Returns:

new_tree

new pytree with the same structure as tree. For each element in tree whose key/field matches a key in **kwargs, its value is set by the corresponding value in **kwargs.

Raises:

KeyError – 如果在 **kwargs 中找不到某个键的值,或者没有值满足过滤操作。

示例

基本用法

>>> import jax.numpy as jnp
>>> import optax
>>> params = jnp.array([1., 2., 3.])
>>> opt = optax.adam(learning_rate=1.)
>>> state = opt.init(params)
>>> print(state)
(ScaleByAdamState(count=Array(0, dtype=int32), mu=Array([0., 0., 0.], dtype=float32), nu=Array([0., 0., 0.], dtype=float32)), EmptyState())
>>> new_state = optax.tree_utils.tree_set(state, count=2.)
>>> print(new_state)
(ScaleByAdamState(count=2.0, mu=Array([0., 0., 0.], dtype=float32), nu=Array([0., 0., 0.], dtype=float32)), EmptyState())

与过滤操作的使用

>>> import jax.numpy as jnp
>>> import optax
>>> params = jnp.array([1., 2., 3.])
>>> opt = optax.inject_hyperparams(optax.sgd)(
...     learning_rate=lambda count: 1/(count+1)
...  )
>>> state = opt.init(params)
>>> print(state)
InjectStatefulHyperparamsState(count=Array(0, dtype=int32), hyperparams={'learning_rate': Array(1., dtype=float32)}, hyperparams_states={'learning_rate': WrappedScheduleState(count=Array(0, dtype=int32))}, inner_state=(EmptyState(), EmptyState()))
>>> filtering = lambda path, value: isinstance(value, jnp.ndarray)
>>> new_state = optax.tree_utils.tree_set(
...   state, filtering, learning_rate=jnp.asarray(0.1)
... )
>>> print(new_state)
InjectStatefulHyperparamsState(count=Array(0, dtype=int32), hyperparams={'learning_rate': Array(0.1, dtype=float32, weak_type=True)}, hyperparams_states={'learning_rate': WrappedScheduleState(count=Array(0, dtype=int32))}, inner_state=(EmptyState(), EmptyState()))

注意

推荐使用注入超参数调度的方法是通过 optax.inject_hyperparams()。这个函数是其他用途的辅助工具。

在版本 0.2.2 中添加。

树的减法#

optax.tree_utils.tree_sub(tree_x: Any, tree_y: Any) Any[来源]#

减去两个pytrees。

Parameters:
  • tree_x – 第一个pytree。

  • tree_y – 第二个 pytree。

Returns:

两个pytrees的区别。

树的总和#

optax.tree_utils.tree_sum(tree: Any) chex.Numeric[来源]#

计算pytree中所有元素的总和。

Parameters:

– pytree.

Returns:

一个标量值。

树内积#

optax.tree_utils.tree_vdot(tree_x: Any, tree_y: Any) chex.Numeric[来源]#

计算两个pytrees之间的内积。

Parameters:
  • tree_x – 第一个要使用的pytree。

  • tree_y – 第二个 pytree 以供使用。

Returns:

tree_xtree_y 之间的内积,标量值。

示例

>>> optax.tree_utils.tree_vdot(
...   {'a': jnp.array([1, 2]), 'b': jnp.array([1, 2])},
...   {'a': jnp.array([-1, -1]), 'b': jnp.array([1, 1])},
... )
Array(0, dtype=int32)

注意

我们将值提升到最高精度以避免数值问题。

树在哪里#

optax.tree_utils.tree_where(condition, tree_x, tree_y)[来源]#

如果条件为真,则选择 tree_x 值,否则选择 tree_y 值。

Parameters:
  • condition – 布尔值,用于指定从树 x 或树_y 中选择哪些值

  • tree_x – 如果条件为真,选择pytree

  • tree_y – 如果条件为假,选择pytree

Returns:

根据条件选择 tree_x 或 tree_y。

树的零像#

optax.tree_utils.tree_zeros_like(tree: Any, dtype: str | type[Any] | 数据类型 | SupportsDType | None = None) Any[来源]#

创建一个结构相同的全零树。

Parameters:
  • tree – pytree.

  • dtype – 用于零树的可选数据类型。

Returns:

一个与 tree 结构相同的全零树。