工具#
一般#
|
为反向传播缩放梯度。 |
|
从状态中获取值和梯度的替代方案 |
比例梯度#
从状态获取值和梯度#
- optax.value_and_grad_from_state(value_fn: Callable[[...], 数组 | float]) Callable[[...], tuple[float | 数组, TypeAliasForwardRef('optax.Updates')]][来源]#
替代
jax.value_and_grad的方法,从状态中获取值和梯度。诸如
optax.scale_by_backtracking_linesearch()的线搜索方法需要在候选迭代点计算梯度和目标函数。这个目标值和梯度可以在下一个迭代中重新使用,以利用这个实用函数节省一些计算。- Parameters:
value_fn – 返回一个标量(float或维度为1的数组)的函数,可以在jax中使用
jax.value_and_grad()进行微分。- Returns:
一个类似于
jax.value_and_grad()的可调用函数,如果状态中存在,则获取值和梯度。如果未找到值或梯度,或者找到多个值和梯度,该函数会引发错误。如果找到的值是无限大或nan,则使用jax.value_and_grad()计算值和梯度。如果在状态中找到的梯度为None,则引发错误。
示例
>>> import optax >>> import jax.numpy as jnp >>> def fn(x): return jnp.sum(x ** 2) >>> solver = optax.chain( ... optax.sgd(learning_rate=1.), ... optax.scale_by_backtracking_linesearch( ... max_backtracking_steps=15, store_grad=True ... ) ... ) >>> value_and_grad = optax.value_and_grad_from_state(fn) >>> params = jnp.array([1., 2., 3.]) >>> print('Objective function: {:.2E}'.format(fn(params))) Objective function: 1.40E+01 >>> opt_state = solver.init(params) >>> for _ in range(5): ... value, grad = value_and_grad(params, state=opt_state) ... updates, opt_state = solver.update( ... grad, opt_state, params, value=value, grad=grad, value_fn=fn ... ) ... params = optax.apply_updates(params, updates) ... print('Objective function: {:.2E}'.format(fn(params))) Objective function: 5.04E+00 Objective function: 1.81E+00 Objective function: 6.53E-01 Objective function: 2.35E-01 Objective function: 8.47E-02
数值稳定性#
|
在避免溢出的情况下将计数器增加一个。 |
|
返回 jnp.maximum(jnp.linalg.norm(x), min_norm),并带有正确的梯度。 |
|
返回 最大值(sqrt(mean(abs_sq(x))), min_norm),并带有正确的梯度。 |
安全递增#
- optax.safe_increment(count: chex.Numeric) chex.Numeric[来源]#
在避免溢出的情况下将计数器增加一个。
Denote
max_val,min_valas the maximum, minimum, possible values for thedtypeofcount. Normallymax_val + 1would overflow tomin_val. This functions ensures that whenmax_valis reached the counter stays atmax_val.- Parameters:
count – 一个要递增的计数器。
- Returns:
一个每次递增1的计数器,或
max_val如果达到最大值时。
示例
>>> import jax.numpy as jnp >>> import optax >>> optax.safe_increment(jnp.asarray(1, dtype=jnp.int32)) Array(2, dtype=int32) >>> optax.safe_increment(jnp.asarray(2147483647, dtype=jnp.int32)) Array(2147483647, dtype=int32)
在版本 0.2.4 中添加。
安全范数#
- optax.safe_norm(x: chex.Array, min_norm: chex.Numeric, ord: int | float | str | None = None, axis: None | tuple[int, ...] | int = None, keepdims: bool = False) chex.Array[来源]#
返回 jnp.maximum(jnp.linalg.norm(x), min_norm) 并具有正确的梯度。
在0.0处,jnp.maximum(jnp.linalg.norm(x), min_norm)的梯度为NaN,因为jax会评估jnp.maximum的两个分支。该函数在这种情况下也会返回正确的0.0梯度。
- Parameters:
x – jax 数组。
min_norm – 返回的范数的下限。
ord – {非零整数, inf, -inf, ‘fro’, ‘nuc’}, 可选。范数的阶。 inf表示numpy的inf对象。默认值为None。
axis – {无,整数,整数的2元组},可选。如果轴是一个整数,它指定沿着计算向量范数的 x 的轴。如果轴是一个2元组,它指定保持二维矩阵的轴,这些矩阵的矩阵范数将被计算。如果轴是 None,则返回向量范数(当 x 是1维时)或矩阵范数(当 x 是2维时)。默认值是 None。
keepdims – 布尔值, 可选。如果设置为 True,那么被归一化的轴将以大小为一的维度保留在结果中。使用此选项时,结果将能够正确地与原始 x 进行广播。
- Returns:
输入向量的安全范数,考虑了正确的梯度。
安全的均方根#
- optax.safe_root_mean_squares(x: chex.Array, min_rms: chex.Numeric) chex.Array[来源]#
返回 最大值(sqrt(mean(abs_sq(x))), min_norm),并具有正确的梯度。
在0.0处,maximum(sqrt(mean(abs_sq(x))), min_norm) 的梯度是NaN,因为jax将评估jnp.maximum的两个分支。这个函数在这种情况下也将返回正确的0.0的梯度。
- Parameters:
x – jax 数组。
min_rms – 返回的范数的下限。
- Returns:
输入向量的安全 RMS,考虑到正确的梯度。
线性代数运算符#
|
计算 matrix^(-1/p),其中 p 是一个正整数。 |
|
幂迭代算法。 |
|
求解非负最小二乘问题。 |
矩阵逆 p 阶根#
- optax.matrix_inverse_pth_root(matrix: chex.Array, p: int, num_iters: int = 100, ridge_epsilon: float = 1e-06, error_tolerance: float = 1e-06, precision: 精准度 = Precision.HIGHEST)[来源]#
计算 matrix^(-1/p),其中 p 是一个正整数。
此函数使用耦合牛顿迭代算法来计算矩阵的逆pth根。
- Parameters:
matrix – 要计算其幂的对称PSD矩阵
p – 指数,对于p为正整数。
num_iters – 最大迭代次数。
ridge_epsilon – 加入的Ridge epsilon以使矩阵为正定的。
error_tolerance – 错误指示器,适用于提前终止。
precision – 精度 XLA 相关标志, 可选项有: a) lax.Precision.DEFAULT (更好的步骤时间,但不精确); b) lax.Precision.HIGH (增加精度,速度较慢); c) lax.Precision.HIGHEST (可能的最佳精度,速度最慢)。
- Returns:
矩阵^(-1/p)
参考文献
- [Functions of Matrices, Theory and Computation,
Nicholas J Higham, 第184页, 方程 7.18]( https://epubs.siam.org/doi/book/10.1137/1.9780898717778)
幂迭代#
- optax.power_iteration(matrix: TypeAliasForwardRef('chex.Array') | Callable[[TypeAliasForwardRef('chex.ArrayTree')], TypeAliasForwardRef('chex.ArrayTree')], *, v0: TypeAliasForwardRef('chex.ArrayTree') | None = None, num_iters: int = 100, error_tolerance: float = 1e-06, precision: 精度 = Precision.HIGHEST, key: 数组 | None = None) tuple[TypeAliasForwardRef('chex.Numeric'), TypeAliasForwardRef('chex.ArrayTree')][来源]#
幂迭代算法。
此算法计算可对角化矩阵的主特征值及其相关特征向量。该矩阵可以作为数组提供,也可以作为实现矩阵-向量乘法的可调用对象提供。
- Parameters:
matrix – 方阵,可以作为数组或实现矩阵-向量乘法的可调用对象。
v0 – 初始向量,近似于主特征向量。如果
matrix是 大小为 (n, n) 的数组,则 v0 必须是大小为 (n,) 的向量。如果matrix是一个可调用对象,则 v0 必须是具有与该可调用对象输入相同结构的树。如果此参数为 None 而matrix是 一个数组,则将使用从均匀分布 [-1, 1] 中采样的随机向量作为初始向量。num_iters – 权力迭代的次数。
error_tolerance – 迭代退出条件。当主特征值估计的相对误差低于该阈值时,程序停止。
precision – 与精度相关的XLA标志, 可用选项为: a) lax.Precision.DEFAULT(更好的步长时间,但不精确); b) lax.Precision.HIGH(增加精度,较慢); c) lax.Precision.HIGHEST(可能的最佳精度,最慢)。
key – 随机键用于
v0的初始化,当未明确给出时。 当此参数为None时,使用jax.random.PRNGKey(0)。
- Returns:
一对(特征值,特征向量),其中特征值是
matrix的主特征值,特征向量是其相关的特征向量。
参考文献
维基百科贡献者。 Power iteration。
在版本 0.2.2 中更改:
matrix可以是一个可调用对象。返回参数的顺序已更改,从 (特征向量, 特征值) 改为 (特征值, 特征向量)。
非负最小二乘#
- optax.nnls(A: 数组, b: 数组, iters: int, unroll: int | bool = 1, L: 数组 | float | None = None) 数组[来源]#
解决非负最小二乘问题。
最小化 \(\|A x - b\|_2\) 使 \(x \geq 0\)。
使用Polyak 2015的快速投影梯度(FPG)算法。
- Parameters:
A – 输入矩阵。
b – 输入向量。
iters – 要运行算法的迭代次数。
unroll – 传递给 lax.scan 的展开参数。
L – A.T @ A 的谱半径的上界(可选)。
- Returns:
解向量。
示例
>>> from jax import numpy as jnp >>> import optax >>> A = jnp.array([[1., 2.], [3., 4.]]) >>> b = jnp.array([5., 6.]) >>> x = optax.nnls(A, b, 10**3) >>> print(f"{x[0]:.2f}") 0.00 >>> print(f"{x[1]:.2f}") 1.70
参考文献
罗曼·A·波利亚克,非负最小二乘的投影梯度方法,2015
二阶优化#
|
计算(观察到的)Fisher信息矩阵的对角线。 |
|
Computes the diagonal hessian of loss at (inputs, targets). |
|
执行高效的向量-海森矩阵(loss)乘法。 |
费舍尔对角线#
- optax.second_order.fisher_diag(negative_log_likelihood: LossFn, params: Any, inputs: 数组, targets: 数组) 数组[来源]#
计算(观察到的)费舍尔信息矩阵的对角线。
- Parameters:
negative_log_likelihood – 负对数似然函数,预期签名 loss = fn(params, inputs, targets)。
params – 模型参数。
inputs – 计算negative_log_likelihood时的输入。
targets – 评估negative_log_likelihood的目标。
None
- Returns:
一个数组对应于评估在(params, inputs, targets)上的Hessian的产品negative_log_likelihood。
海森对角线#
海森矩阵向量积#
树#
|
树中命名元组的键类型。 |
|
添加两个(或多个)pytrees。 |
|
将两棵树相加,其中第二棵树按标量缩放。 |
|
向pytree的每个叶子添加前导批量维度。 |
|
将树转换为给定的dtype,如果为None则跳过。 |
|
分割两个pytrees。 |
|
获取树的dtype。 |
|
从一个pytree中提取与给定键匹配的值。 |
|
提取与给定键匹配的pytree值。 |
|
计算 pytree 的 l1 范数。 |
|
计算pytree的l2范数。 |
|
计算一个 pytree 的 l-无限范数。 |
|
在给定的优化器状态下应用一个可调用对象到所有参数上。 |
|
计算pytree中所有元素的最大值。 |
|
对两个pytrees进行乘法运算。 |
|
创建一个结构相同的全一树。 |
|
创建一个与目标树形状相同的随机条目树。 |
|
将键拆分以匹配目标树的结构。 |
|
将树乘以标量。 |
|
创建一个树的副本,某些值根据指定的kwargs进行替换。 |
|
减去两个pytrees。 |
|
计算pytree中所有元素的总和。 |
|
计算两个pytrees之间的内积。 |
|
如果条件为真,则选择 tree_x 值,否则选择 tree_y 值。 |
|
创建一个结构相同的全零树。 |
命名元组键#
- class optax.tree_utils.NamedTupleKey(tuple_name: str, name: str)[来源]#
树中命名元组的关键类型。
When using a function
filtering(path: KeyPath, value: Any) -> bool: ...in a tree inoptax.tree_utils.tree_get_all_with_path(),optax.tree_utils.tree_get(), oroptax.tree_utils.tree_set(), can filter the path to check if of the KeyEntry is a NamedTupleKey and then check if the name of named tuple is the one intended to be searched.- tuple_name#
包含键的元组的名称。
- Type:
str
- name#
键的名称。
- Type:
字符串
另请参见
jax.tree_util.DictKey,jax.tree_util.FlattenedIndexKey,jax.tree_util.GetAttrKey,jax.tree_util.SequenceKey,optax.tree_utils.tree_get_all_with_path(),optax.tree_utils.tree_get(),optax.tree_utils.tree_set(),在版本 0.2.2 中添加。
树添加#
树的添加和数乘#
树批量重塑#
树的类型转换#
- optax.tree_utils.tree_cast(tree: chex.ArrayTree, dtype: str | type[Any] | 数据类型 | SupportsDType | None) chex.ArrayTree[来源]#
将树转换为给定的数据类型,如果为None则跳过。
- Parameters:
tree – 要转换的树。
dtype – 要转换的dtype,或None以跳过。
- Returns:
树,叶子转换为数据类型。
示例
>>> import jax.numpy as jnp >>> import optax >>> tree = {'a': {'b': jnp.array(1.0, dtype=jnp.float32)}, ... 'c': jnp.array(2.0, dtype=jnp.float32)} >>> optax.tree_utils.tree_cast(tree, dtype=jnp.bfloat16) {'a': {'b': Array(1, dtype=bfloat16)}, 'c': Array(2, dtype=bfloat16)}
树数据类型#
- optax.tree_utils.tree_dtype(tree: chex.ArrayTree, mixed_dtype_handler: str | None = None) str | type[Any] | 数据类型 | SupportsDType[来源]#
获取树的dtype。
如果树是空的,返回JAX数组的默认数据类型。
- Parameters:
tree – 要提取dtype的树。
mixed_dtype_handler – how to handle mixed dtypes in the tree. - If
mixed_dtype_handler=None, returns the common dtype of the leaves of the tree if it exists, otherwise raises an error. - Ifmixed_dtype_handler='promote', promotes the dtypes of the leaves of the tree to a common promoted dtype usingjax.numpy.promote_types(). - Ifmixed_dtype_handler='highest'ormixed_dtype_handler='lowest', returns the highest/lowest dtype of the leaves of the tree. We consider a partial ordering of dtypes asdtype1 <= dtype2ifdtype1is promoted todtype2, that is, ifjax.numpy.promote_types(dtype1, dtype2) == dtype2. Since some dtypes cannot be promoted to one another, this is not a total ordering, and the ‘highest’ or ‘lowest’ options may not be applicable. These options will throw an error if the dtypes of the leaves of the tree cannot be promoted to one another.
- Returns:
树的 dtype。
- Raises:
值错误 – 如果
mixed_dtype_handler设置为None并且在树中发现多个 数据类型。值错误 – 如果
mixed_dtype_handler设置为'highest'或'lowest',而树中某些叶子节点的 dtypes 不能互相提升。
示例
>>> import jax.numpy as jnp >>> import optax >>> tree = {'a': {'b': jnp.array(1.0, dtype=jnp.float32)}, ... 'c': jnp.array(2.0, dtype=jnp.float32)} >>> optax.tree_utils.tree_dtype(tree) dtype('float32') >>> tree = {'a': {'b': jnp.array(1.0, dtype=jnp.float16)}, ... 'c': jnp.array(2.0, dtype=jnp.float32)} >>> optax.tree_utils.tree_dtype(tree, 'lowest') dtype('float16') >>> optax.tree_utils.tree_dtype(tree, 'highest') dtype('float32') >>> tree = {'a': {'b': jnp.array(1.0, dtype=jnp.int32)}, ... 'c': jnp.array(2.0, dtype=jnp.uint32)} >>> # optax.tree_utils.tree_dtype(tree, 'highest') >>> # -> will throw an error because int32 and uint32 >>> # cannot be promoted to one another. >>> optax.tree_utils.tree_dtype(tree, 'promote') dtype('int64')
在版本 0.2.4 中添加。
树形划分#
获取与给定键匹配的单个值#
- optax.tree_utils.tree_get(tree: optax.PyTree, key: Any, default: Any | None = None, filtering: Callable[[Tuple[DictKey | FlattenedIndexKey | GetAttrKey | SequenceKey | 命名元组键, ...], Any], bool] | None = None) Any[来源]#
从匹配给定键的pytree中提取一个值。
在
tree中搜索特定的key(可以是字典中的键、NamedTuple中的字段或者NamedTuple的名称)。如果
tree不包含key,则返回default。如果在
tree中找到多个key的值,将引发一个KeyError。Generally, you may first get all pairs
(path_to_value, value)for a givenkeyusingoptax.tree_utils.tree_get_all_with_path(). You may then define a filtering operationfiltering(path: Key_Path, value: Any) -> bool: ...that enables you to select the specific values you wanted to fetch by looking at the type of the value, or looking at the path to that value. Note that contrarily to the paths returned byjax.tree_util.tree_leaves_with_path()the paths analyzed by the filtering operation inoptax.tree_utils.tree_get_all_with_path(),optax.tree_utils.tree_get(), oroptax.tree_utils.tree_set()detail the names of the named tuples considered in the path. Concretely, if the value considered is in the attributekeyof a named tuple calledMyNamedTuplethe last element of the path will be aoptax.tree_utils.NamedTupleKeycontaining bothname=keyandtuple_name='MyNamedTuple'. That way you may distinguish between identical values in different named tuples (arising for example when chaining transformations in optax). See the last example below.- Parameters:
tree – 要搜索的树。
key – 要在
tree中搜索的关键字或字段。默认值 – 如果在
tree中找不到key,则返回的默认值。filtering – optional callable to further filter values in
treethat match thekey.filtering(path: Key_Path, value: Any) -> bool: ...takes as arguments both the path to the value (as returned byoptax.tree_utils.tree_get_all_with_path()) and the value that match the given key.
- Returns:
- value
value in
treematching the givenkey. If none are found returndefaultvalue. If multiple are found raises an error.
- Raises:
KeyError – 如果在
tree中找到多个key值。
示例
基本用法
>>> import jax.numpy as jnp >>> import optax >>> params = jnp.array([1., 2., 3.]) >>> opt = optax.adam(learning_rate=1.) >>> state = opt.init(params) >>> count = optax.tree_utils.tree_get(state, 'count') >>> print(count) 0
与过滤操作的使用
>>> import jax.numpy as jnp >>> import optax >>> params = jnp.array([1., 2., 3.]) >>> opt = optax.inject_hyperparams(optax.sgd)( ... learning_rate=lambda count: 1/(count+1) ... ) >>> state = opt.init(params) >>> filtering = lambda path, value: isinstance(value, jnp.ndarray) >>> lr = optax.tree_utils.tree_get( ... state, 'learning_rate', filtering=filtering ... ) >>> print(lr) 1.0
通过名称提取命名元组
>>> params = jnp.array([1., 2., 3.]) >>> opt = optax.chain( ... optax.add_noise(1.0, 0.9, 0), optax.scale_by_adam() ... ) >>> state = opt.init(params) >>> noise_state = optax.tree_utils.tree_get(state, 'AddNoiseState') >>> print(noise_state) AddNoiseState(count=Array(0, dtype=int32), rng_key=Array([0, 0], dtype=uint32))
通过命名元组的名称区分两个值。
>>> import jax.numpy as jnp >>> import optax >>> params = jnp.array([1., 2., 3.]) >>> opt = optax.chain( ... optax.add_noise(1.0, 0.9, 0), optax.scale_by_adam() ... ) >>> state = opt.init(params) >>> filtering = ( ... lambda p, v: isinstance(p[-1], optax.tree_utils.NamedTupleKey) ... and p[-1].tuple_name == 'ScaleByAdamState' ... ) >>> count = optax.tree_utils.tree_get(state, 'count', filtering=filtering) >>> print(count) 0
在版本 0.2.2 中添加。
获取所有匹配给定键的值#
- optax.tree_utils.tree_get_all_with_path(tree: optax.PyTree, key: Any, filtering: Callable[[Tuple[DictKey | FlattenedIndexKey | GetAttrKey | SequenceKey | 命名元组键, ...], Any], bool] | None = None) list[tuple[Tuple[DictKey | FlattenedIndexKey | GetAttrKey | SequenceKey | 命名元组键, ...], Any]][来源]#
提取与给定键匹配的py树的值。
在
tree中搜索特定的key(可以是字典中的一个键、NamedTuple中的一个字段或NamedTuple的名称)。That key/field
keymay appear more than once intree. So this function returns a list of all values corresponding tokeywith the path to that value. The path is a sequence ofKeyEntrythat can be transformed in readable format usingjax.tree_util.keystr(), see the example below.- Parameters:
tree – 要搜索的树。
key – 在树中搜索的关键字或字段。
filtering – optional callable to further filter values in tree that match the key.
filtering(path: Key_Path, value: Any) -> bool: ...takes as arguments both the path to the value (as returned byoptax.tree_utils.tree_get_all_with_path()) and the value that match the given key.
- Returns:
- values_with_path
list of tuples where each tuple is of the form (
path_to_value,value). Herevalueis one entry of the tree that corresponds to thekey, andpath_to_valueis a tuple of KeyEntry that is a tuple ofjax.tree_util.DictKey,jax.tree_util.FlattenedIndexKey,jax.tree_util.GetAttrKey,jax.tree_util.SequenceKey, oroptax.tree_utils.NamedTupleKey.
示例
基本用法
>>> import jax.numpy as jnp >>> import optax >>> params = jnp.array([1., 2., 3.]) >>> solver = optax.inject_hyperparams(optax.sgd)( ... learning_rate=lambda count: 1/(count+1) ... ) >>> state = solver.init(params) >>> found_values_with_path = optax.tree_utils.tree_get_all_with_path( ... state, 'learning_rate' ... ) >>> print( ... *[(jax.tree_util.keystr(p), v) for p, v in found_values_with_path], ... sep="\n", ... ) ("InjectStatefulHyperparamsState.hyperparams['learning_rate']", Array(1., dtype=float32)) ("InjectStatefulHyperparamsState.hyperparams_states['learning_rate']", WrappedScheduleState(count=Array(0, dtype=int32)))
与过滤操作的使用
>>> import jax.numpy as jnp >>> import optax >>> params = jnp.array([1., 2., 3.]) >>> solver = optax.inject_hyperparams(optax.sgd)( ... learning_rate=lambda count: 1/(count+1) ... ) >>> state = solver.init(params) >>> filtering = lambda path, value: isinstance(value, tuple) >>> found_values_with_path = optax.tree_utils.tree_get_all_with_path( ... state, 'learning_rate', filtering ... ) >>> print( ... *[(jax.tree_util.keystr(p), v) for p, v in found_values_with_path], ... sep="\n", ... ) ("InjectStatefulHyperparamsState.hyperparams_states['learning_rate']", WrappedScheduleState(count=Array(0, dtype=int32)))
在版本 0.2.2 中添加。
树的 l1 范数#
树的l2范数#
树 l-无穷范数#
树图参数#
- optax.tree_utils.tree_map_params(initable: Callable[[TypeAliasForwardRef('optax.Params')], TypeAliasForwardRef('optax.OptState')] | Initable, f: Callable[[...], Any], state: optax.OptState, /, *rest: Any, transform_non_params: Callable[[...], Any] | None = None, is_leaf: Callable[[TypeAliasForwardRef('optax.Params')], bool] | None = None) optax.OptState[来源]#
在给定的优化器状态下,对所有参数应用可调用对象。
这个函数旨在帮助构建优化器状态的分区规格,以便在已知参数的分区规格的情况下。
例如,下面的内容将用给定的分区规范的副本替换所有优化器状态参数树。可以使用参数 transform_non_params 根据需要替换任何剩余字段,在这种情况下,我们将这些字段替换为 None。
>>> params, specs = jnp.array(0.), jnp.array(0.) # Trees with the same shape >>> opt = optax.sgd(1e-3) >>> state = opt.init(params) >>> opt_specs = optax.tree_map_params( ... opt, ... lambda _, spec: spec, ... state, ... specs, ... transform_non_params=lambda _: None, ... )
- Parameters:
initable – 一个可调用的,接受参数并返回优化器状态,或者一个具有相同功能的init属性的对象。
f – 一个可调用的对象,将应用于该优化器状态下参数树的所有副本。
state – 要映射的优化器状态。
*rest – 额外的参数,与参数树的形状相同,将被传递给 f。
transform_non_params – 一个可选的函数,将在优化器状态中的所有非参数字段上调用。
is_leaf – 传递给 jax.tree.map。这使得可以忽略参数树的某些部分,例如当梯度变换修改原始 pytree 的形状时,如对于
optax.masked。
- Returns:
对优化器状态中所有与参数树具有相同形状的树应用函数 f 的结果,以及给定的可选额外参数。
树的最大值#
树的乘法#
像这样的树#
根据树的结构拆分键#
具有随机值的树#
- optax.tree_utils.tree_random_like(rng_key: ~jax.Array, target_tree: chex.ArrayTree, sampler: ~collections.abc.Callable[[~jax.Array, ~typing.Sequence[int | ~typing.Any], str | type[~typing.Any] | ~numpy.dtype | ~jax._src.typing.SupportsDType], TypeAliasForwardRef('chex.Array')] = <function normal>, dtype: str | type[~typing.Any] | ~numpy.dtype | ~jax._src.typing.SupportsDType | None = None) chex.ArrayTree[来源]#
创建与目标树形状相同的随机条目树。
- Parameters:
rng_key – 随机数生成器的钥匙。
target_tree – 要匹配结构的树。叶子必须是数组。
sampler – 噪声采样函数,默认是
jax.random.normal。dtype – 随机数的期望数据类型,传递给
sampler。如果为 None,则如果可能,使用目标树的 dtype。
- Returns:
一个具有与
target_tree相同结构的随机树,其叶子具有sampler的分布。
警告
可能的数据类型可能受到采样器的限制,例如
jax.random.rademacher仅支持整数数据类型,如果目标树的 dtype 不是整数,或者 dtype 不是整数类型,则会引发错误。在版本 0.2.1 中新增。
树的标量乘法#
在树中设置值#
- optax.tree_utils.tree_set(tree: optax.PyTree, filtering: Callable[[Tuple[DictKey | FlattenedIndexKey | GetAttrKey | SequenceKey | 命名元组键, ...], Any], bool] | None = None, /, **kwargs: Any) optax.PyTree[来源]#
创建一个树的副本,替换指定kwargs的一些值。
在
tree中搜索keys,在**kwargs中(它可以是字典中的一个键,NamedTuple中的一个字段或NamedTuple的名称)。如果找到这样的键,将相应的值替换为在**kwargs中给定的值。如果
**kwargs中的一些键在树中不存在,则会引发一个KeyError。- Parameters:
tree – 要被替换的pytree值。
filtering – optional callable to further filter values in
treethat match the keys to replace.filtering(path: Key_Path, value: Any) -> bool: ...takes as arguments both the path to the value (as returned byoptax.tree_utils.tree_get_all_with_path()) and the value that match a given key.**kwargs – 用于替换
tree中的值的键字典。
- Returns:
- new_tree
new pytree with the same structure as
tree. For each element intreewhose key/field matches a key in**kwargs, its value is set by the corresponding value in**kwargs.
- Raises:
KeyError – 如果在
**kwargs中找不到某个键的值,或者没有值满足过滤操作。
示例
基本用法
>>> import jax.numpy as jnp >>> import optax >>> params = jnp.array([1., 2., 3.]) >>> opt = optax.adam(learning_rate=1.) >>> state = opt.init(params) >>> print(state) (ScaleByAdamState(count=Array(0, dtype=int32), mu=Array([0., 0., 0.], dtype=float32), nu=Array([0., 0., 0.], dtype=float32)), EmptyState()) >>> new_state = optax.tree_utils.tree_set(state, count=2.) >>> print(new_state) (ScaleByAdamState(count=2.0, mu=Array([0., 0., 0.], dtype=float32), nu=Array([0., 0., 0.], dtype=float32)), EmptyState())
与过滤操作的使用
>>> import jax.numpy as jnp >>> import optax >>> params = jnp.array([1., 2., 3.]) >>> opt = optax.inject_hyperparams(optax.sgd)( ... learning_rate=lambda count: 1/(count+1) ... ) >>> state = opt.init(params) >>> print(state) InjectStatefulHyperparamsState(count=Array(0, dtype=int32), hyperparams={'learning_rate': Array(1., dtype=float32)}, hyperparams_states={'learning_rate': WrappedScheduleState(count=Array(0, dtype=int32))}, inner_state=(EmptyState(), EmptyState())) >>> filtering = lambda path, value: isinstance(value, jnp.ndarray) >>> new_state = optax.tree_utils.tree_set( ... state, filtering, learning_rate=jnp.asarray(0.1) ... ) >>> print(new_state) InjectStatefulHyperparamsState(count=Array(0, dtype=int32), hyperparams={'learning_rate': Array(0.1, dtype=float32, weak_type=True)}, hyperparams_states={'learning_rate': WrappedScheduleState(count=Array(0, dtype=int32))}, inner_state=(EmptyState(), EmptyState()))
注意
推荐使用注入超参数调度的方法是通过
optax.inject_hyperparams()。这个函数是其他用途的辅助工具。在版本 0.2.2 中添加。
树的减法#
树的总和#
树内积#
- optax.tree_utils.tree_vdot(tree_x: Any, tree_y: Any) chex.Numeric[来源]#
计算两个pytrees之间的内积。
- Parameters:
tree_x – 第一个要使用的pytree。
tree_y – 第二个 pytree 以供使用。
- Returns:
tree_x和tree_y之间的内积,标量值。
示例
>>> optax.tree_utils.tree_vdot( ... {'a': jnp.array([1, 2]), 'b': jnp.array([1, 2])}, ... {'a': jnp.array([-1, -1]), 'b': jnp.array([1, 1])}, ... ) Array(0, dtype=int32)
注意
我们将值提升到最高精度以避免数值问题。