转换#
|
Clips updates to be at most |
alias of |
|
|
添加参数并按weight_decay缩放。 |
alias of |
|
|
添加梯度噪声。 |
|
用于添加梯度噪声的状态。 |
|
每 k 步累积梯度并应用它们。 |
|
包含一个计数器和一个梯度累加器。 |
|
执行偏差校正。 |
|
仅在特定步骤调用内部更新函数。 |
|
仅在特定步骤调用内部更新函数。 |
|
|
|
维护内部转换状态并添加步骤计数器。 |
集中梯度。 |
|
|
Clips updates element-wise, to be in |
|
将每个参数向量或矩阵的梯度更新限制为最大均方根。 |
alias of |
|
|
通过它们的全局范数来截断更新。 |
alias of |
|
|
计算过去更新的指数移动平均值。 |
|
保存过去更新的指数移动平均。 |
最简单无状态转换的空状态。 |
|
|
计算嵌套张量结构的全局范数。 |
|
一对实现梯度变换的纯函数。 |
|
GradientTransformation的一个特殊化,支持额外的参数。 |
|
无状态的身份转换,输入梯度保持不变。 |
修改更新以保持参数非负,即 >= 0。 |
|
alias of |
|
|
通过更新规范的逆来缩放。 |
alias of |
|
alias of |
|
|
根据全局范数对每个示例应用梯度裁剪。 |
|
每个示例使用每层规范应用梯度裁剪。 |
|
通过某个固定的标量 step_size 更新规模。 |
alias of |
|
|
根据Adadelta算法重新调整更新。 |
|
Adadelta算法的重新缩放状态。 |
|
根据Adan算法重新调整比例。 |
|
|
|
根据Adam算法重新调整更新。 |
|
根据Adamax算法重新缩放更新。 |
|
Adam算法的状态。 |
|
根据AMSGrad算法重新缩放更新。 |
|
AMSGrad算法的状态。 |
|
回溯线搜索确保充分下降(Armijo准则)。 |
State for |
|
|
根据AdaBelief算法重新调整更新。 |
|
用于AdaBelief算法的重新缩放状态。 |
|
通过梯度均方根的因素估计进行缩放(如在Adafactor中)。 |
|
梯度变换的整体状态。 |
|
通过L-BFGS缩放更新。 |
|
LBFGS 求解器的状态。 |
|
通过(负)学习率进行缩放(可以是标量或计划)。 |
|
根据Lion算法调整缩放。 |
|
狮子算法的状态。 |
|
计算NovoGrad更新。 |
|
Novograd的状态。 |
|
计算广义乐观梯度。 |
|
通过该参数块的参数的范数来缩放每个参数块的更新。 |
|
根据每个参数向量或矩阵的梯度的均方根来缩放更新。 |
|
通过Polyak的步长缩放更新。 |
|
根据修正Adam算法重新缩放更新。 |
|
通过指数的平方根重新调整更新。 |
|
用于指数根均方(RMS)标准化更新的状态。 |
|
使用Rprop优化器进行缩放。 |
|
|
|
通过迄今为止所有平方梯度的和的平方根重新调整更新。 |
|
截至目前保存梯度平方和的状态。 |
|
使用自定义计划对步长进行缩放更新。 |
|
维护用于缩放调度的计数。 |
计算梯度元素的符号。 |
|
|
通过 sm3 缩放更新。 |
|
SM3算法的状态。 |
|
通过中心化指数的平方根重新调整更新。 |
|
用于中心化指数移动平方更新的状态。 |
|
通过 信任比例 缩放更新。 |
alias of |
|
|
根据Yogi算法更新重缩放。 |
|
线搜索确保足够的下降和小的曲率。 |
|
scale_by_zoom_linesearch的状态。 |
无状态转换,将输入梯度映射为零。 |
|
|
从类似更新的函数创建无状态转换。 |
从一个类似于更新的函数为数组创建一个无状态的转换。 |
|
|
计算过去更新的迹。 |
|
保存过去更新的聚合。 |
|
用于GradientTransformation的init步骤的可调用类型。 |
|
用于GradientTransformation的update步骤的可调用类型。 |
|
计算无限范数的指数移动平均。 |
|
计算第 order 阶动量的指数移动平均。 |
|
计算元素范数的order次动量的指数移动平均。 |
alias of |
|
包装一个梯度转换,使其忽略额外参数。 |
|
一种将NaN替换为0的变换。 |
|
|
包含一棵树。 |
|
关于缩放线搜索步骤的信息,供调试使用。 |
类型#
- class optax.GradientTransformation(init: 转换初始化函数, update: 变换更新函数)[源]#
一对实现梯度变换的纯函数。
Optax 优化器都被实现为 梯度变换。梯度变换被定义为一对纯函数,这些函数被组合在一个 命名元组 中,以便可以按名称引用。
请注意,为希望在更新步骤中传递额外参数的用户提供了扩展API。有关更多详细信息,请参见
optax.GradientTransformationExtraArgs()。由于梯度变换不包含任何内部状态,因此所有有状态的优化器属性(例如在使用优化器调度或动量值时的当前步数)通过使用优化器 state pytree 通过 optax 梯度变换传递。每次应用梯度变换时,都会计算并返回一个新状态,准备传递给下一个梯度变换的调用。
由于梯度变换是纯粹的、幂等的函数,改变梯度变换在步骤之间的行为的唯一方法是更改优化器状态中的值。要查看在控制optax梯度变换的行为时变更优化器状态的示例,请参见optax文档中的元学习示例。
- init#
一个纯函数,当使用参数的示例实例调用时,它将返回一个pytree,包含优化器状态的初始值。
- update#
一个纯函数,输入的是一个更新的 pytree(具有与传递给 init 的原始 params pytree 相同的树结构)、之前的优化器状态(可能使用 init 函数进行初始化),以及可选的当前 params。更新函数随后返回计算出的梯度更新和新的优化器状态。
- class optax.GradientTransformationExtraArgs(init: 转换初始化函数, update: 变换更新函数)[来源]#
一种支持额外参数的GradientTransformation的特化。
通过向更新函数传递额外参数来扩展现有的GradientTransformation接口。
注意,如果没有提供额外的参数,那么该函数的API与
TransformUpdateFn的情况是相同的。这意味着我们可以安全地将任何(不支持额外参数的)梯度转换包装为支持额外参数的转换。新的梯度转换将接受(并忽略)用户可能传递给它的任何额外参数。这是由optax.with_extra_args_support()实现的行为。- update#
覆盖基类型中更新的类型签名以接受额外的参数。
- Type:
optax._src.base.TransformUpdateExtraArgsFn
- class optax.TransformInitFn(*args, **kwargs)[源]#
一个可调用的类型,用于GradientTransformation的init步骤。
init 步骤接受一个 params 的树,并使用这些来构造一个任意结构的初始 state 用于梯度转换。这可能包含过去更新的统计数据或任何其他非静态信息。
- class optax.TransformUpdateFn(*args, **kwargs)[源]#
一个可调用的类型,用于GradientTransformation的update步骤。
这个 update 步骤接受一个候选参数 updates 的树(例如,它们相对于某些损失的梯度)、一个任意结构的 state 和正在优化的模型的当前 params。 params 参数是可选的,但在使用需要访问参数当前值的变换时,必须提供该参数。
对于需要附加参数的情况,可以使用替代接口,详见
TransformUpdateExtraArgsFn。
- optax.OptState#
alias of
Array|ndarray|bool|number|Iterable[ArrayTree] |Mapping[Any,ArrayTree]
转换和状态#
- optax.adaptive_grad_clip(clipping: float, eps: float = 0.001) optax.GradientTransformation[源]#
剪辑更新最多为
clipping * parameter_norm, 单位-wise。- Parameters:
clipping – 更新范数与参数范数的最大允许比率。
eps – 一个 epsilon 项,用于防止零初始化参数的截断。
- Returns:
参考文献
Brock 等, 高性能大规模图像识别无需归一化, 2021
- optax.AdaptiveGradClipState[源]#
的别名
EmptyState
- optax.add_decayed_weights(weight_decay: float | 数组 = 0.0, mask: Any | Callable[[TypeAliasForwardRef('optax.Params')], Any] | None = None) optax.GradientTransformation[来源]#
添加参数,通过 weight_decay 缩放。
- Parameters:
weight_decay – 一个标量权重衰减率。
mask – 与 params PyTree 具有相同结构(或是前缀)的树,或者是一个可调用的函数,该函数根据 params/updates 返回这样的 pytree。叶子应为布尔值,True 表示要应用变换的叶子/子树,False 表示要跳过的叶子/子树。
- Returns:
一个
optax.GradientTransformation对象。
- optax.AddDecayedWeightsState[来源]#
的别名
EmptyState
- optax.add_noise(eta: float, gamma: float, seed: int) optax.GradientTransformation[来源]#
添加梯度噪声。
- Parameters:
eta – 添加到梯度的高斯噪声的基本方差。
gamma – 方差退火的衰减指数。
seed – 随机数生成的种子。
- Returns:
参考文献
Neelakantan等人,添加梯度噪声改善非常深层网络的学习,2015
- optax.apply_every(k: int = 1) optax.GradientTransformation[源]#
累积梯度并每隔 k 步应用一次。
请注意,如果此转换是链的一部分,则其他转换的状态将在每一步仍然更新。特别是,使用 apply_every 以及批处理大小为 N/2 和 k=2 并不一定等同于不使用 apply_every 且批处理大小为 N。如果这个等价性对您很重要,请考虑使用 optax.MultiSteps。
- Parameters:
k – 每 k 步发出非零梯度,否则进行累积。
- Returns:
- optax.centralize() optax.GradientTransformation[源]#
集中化梯度。
- Returns:
一个
optax.GradientTransformation对象。
参考文献
Yong等人, 梯度中心化:一种用于深度神经网络的新优化技术,2020年。
- optax.conditionally_mask(inner: 基础.渐变转换, should_transform_fn: ConditionFn, forward_extra_args: bool = False) 基础.渐变变换额外参数[源]#
仅在特定步骤调用内部更新函数。
创建一个转换包装器,该包装器在条件满足的情况下应用内部梯度转换,如果条件不满足,则更新设置为0,同时内部状态不变地传递。行为由用户指定的函数
should_transform_fn控制,该函数由conditionally_transform调用,作为输入传递一个计数器,计数器记录了update函数之前被调用的次数,用户指定的函数必须返回一个布尔值,以控制是否应该调用内部转换。- Parameters:
inner – 内部转换。
should_transform_fn – 函数接受一个步数计数器(形状为 [] 的数组,数据类型为
int32),并返回一个形状为 [] 的布尔数组。如果forward_extra_args设置为 True,任何额外的参数也会被转发到should_transform_fn。forward_extra_args – 将额外参数传递给
should_transform_fn。
- Returns:
警告
如果你想在条件不满足时保持
updates不变,你可以使用conditionally_transform包装器。在版本 0.2.3 中添加。
- optax.conditionally_transform(inner: 基本.渐变转换, should_transform_fn: ConditionFn, forward_extra_args: bool = False) base.GradientTransformationExtraArgs[源]#
仅在特定步骤调用内部更新函数。
创建一个转换包装器,该包装器有条件地应用内部梯度 转换,如果条件不满足,则仅将更新和 内部状态原样传递。行为由用户指定的 函数
should_transform_fn控制,该函数由conditionally_transform调用,输入为update函数 之前被调用的次数计数,用户指定的函数必须返回一个布尔值 以控制是否应该调用内部转换。- Parameters:
inner – 内部转换。
should_transform_fn – function takes in a
stepcounter (array of shape [] and dtypeint32), and returns a boolean array of shape []. Ifforward_extra_argsis set to True, any extra arguments are also forwarded to theshould_transform_fn.forward_extra_args - 将额外参数传递给
should_transform_fn。
- Returns:
警告
如果你想在条件不满足时将
updates设置为零,你可以使用conditionally_mask包装器。在版本 0.2.3 中添加。
- optax.clip(max_delta: chex.Numeric) optax.GradientTransformation[源]#
剪辑逐元素更新,保持在
[-max_delta, +max_delta]之间。- Parameters:
max_delta – 更新中每个元素的最大绝对值。
- Returns:
一个
optax.GradientTransformation对象。
- optax.clip_by_block_rms(threshold: float) optax.GradientTransformation[源]#
剪辑每个参数向量或矩阵的梯度的最大均方根值。
一个 block 在这里是一个权重向量(例如,在线性层中)或一个权重矩阵(例如,在卷积层中),作为 grads/param pytree 中的一个叶子出现。
- Parameters:
threshold – 每个参数向量或矩阵的梯度最大均方根。
- Returns:
一个
optax.GradientTransformation对象。
- optax.ClipState[源]#
的别名
EmptyState
- optax.clip_by_global_norm(max_norm: float) optax.GradientTransformation[源]#
使用它们的全局范数更新剪辑。
- Parameters:
max_norm - 更新的最大全局范数。
- Returns:
一个
optax.GradientTransformation对象。
参考文献
Pascanu 等,关于训练递归神经网络的难度,2012
- optax.ClipByGlobalNormState[源]#
的别名
EmptyState
- optax.ema(decay: float, debias: bool = True, accumulator_dtype: Any | None = None) optax.GradientTransformation[源]#
计算过去更新的指数移动平均值。
- Parameters:
decay – 指数移动平均的衰减率。
debias – 是否对转换后的梯度进行去偏置。
accumulator_dtype – 可选的 dtype 用于累加器;如果 None,则 dtype 从 params 和 updates 中推断。
- Returns:
一个
optax.GradientTransformation对象。
注意
optax.trace()andoptax.ema()have very similar but distinct updates;trace = decay * trace + t, whileema = decay * ema + (1-decay) * t. Both are frequently found in the optimization literature.
- optax.identity() 渐变转换[源]#
无状态身份变换,保持输入梯度不变。
这个函数将通过 梯度更新 不变地传递。
注意,这不应与 set_to_zero 混淆,该函数将输入更新映射为零——这是在对 模型参数 应用更新时,所需的转换,以使其保持不变。
- Returns:
一个
optax.GradientTransformation对象。
- optax.keep_params_nonnegative() optax.GradientTransformation[源]#
修改更新以保持参数非负,即 >= 0。
此转换确保更新后的参数将大于或等于零。 在一系列转换中,这应该是最后一个。
- Returns:
一个
optax.GradientTransformation对象。
警告
转换要求输入参数为非负数。当参数为负时,转换后的更新将其移动到0。
- optax.NonNegativeParamsState[源]#
的别名
EmptyState
- optax.normalize_by_update_norm(scale_factor: float = 1.0, eps: float = 1e-06) optax.GradientTransformation[源]#
通过更新规范的逆向进行缩放。
示例
>>> import optax >>> import jax >>> import jax.numpy as jnp >>> def f(x): return jnp.sum(x ** 2) # simple quadratic function >>> solver = optax.normalize_by_update_norm(scale_factor=-1.0) >>> params = jnp.array([1., 2., 3.]) >>> print('Objective function:', f(params)) Objective function: 14.0 >>> opt_state = solver.init(params) >>> for _ in range(5): ... grad = jax.grad(f)(params) ... updates, opt_state = solver.update(grad, opt_state, params) ... params = optax.apply_updates(params, updates) ... print('Objective function: {:.2E}'.format(f(params))) Objective function: 7.52E+00 Objective function: 3.03E+00 Objective function: 5.50E-01 Objective function: 6.67E-02 Objective function: 5.50E-01
- Parameters:
scale_factor – 更新将乘以的因子(默认为1)。
eps – 抖动项,以避免除以 0
- Returns:
一个
optax.GradientTransformation对象。
- optax.per_example_global_norm_clip(grads: chex.ArrayTree, l2_norm_clip: float) tuple[TypeAliasForwardRef('chex.ArrayTree'), 数组][源]#
根据每个示例的全局范数应用梯度裁剪。
- Parameters:
grads – 展平的更新;该函数期望此列表中的每个数组在第0轴上具有批次维度。
l2_norm_clip – 每个样本梯度的最大L2范数。
- Returns:
一个包含裁剪后的每个示例梯度的总和以及被裁剪的每个示例梯度数量的元组。
示例
>>> import optax >>> import jax.numpy as jnp >>> grads = [jnp.array([[0, 0, 0], [0, 3, 4], [4, 0, 3], [3, 4, 0]])] >>> optax.per_example_global_norm_clip(grads, jnp.inf) ([Array([7., 7., 7.], dtype=float32)], Array(0, dtype=int32)) >>> optax.per_example_global_norm_clip(grads, 0.0) ([Array([0., 0., 0.], dtype=float32)], Array(3, dtype=int32)) >>> optax.per_example_global_norm_clip(grads, 1.25) ([Array([1.75, 1.75, 1.75], dtype=float32)], Array(3, dtype=int32))
参考文献
Abadi et al., 深度学习与差分隐私, 2016
另请参阅
optax.contrib.differentially_private_aggregate()以获得更现实的 示例用法。
- optax.per_example_layer_norm_clip(grads: chex.ArrayTree, global_l2_norm_clip: float, uniform: bool = True) tuple[TypeAliasForwardRef('chex.ArrayTree'), TypeAliasForwardRef('chex.ArrayTree')][源]#
使用每层的规范进行每个示例的梯度裁剪。
如果 len(grads) == 1,则该函数等同于 optax.per_example_global_norm_clip。如果 len(grads) > 1,则 grads 中的每个数组将被独立地剪切为下面文档中记录的值
C_i。设
C = global_l2_norm_clip value。然后每层的裁剪操作如下:1. If
uniformisTrue, each of theKlayers has an individual clip norm ofC / sqrt(K).2. If
uniformisFalse, each of theKlayers has an individual clip norm ofC * sqrt(D_i / D)whereD_iis the number of parameters in layeri, andDis the total number of parameters in the model.- Parameters:
grads – 扁平化更新;即每一项是一个层的梯度的梯度列表;该函数期望这些在第0轴上具有批量维度。
global_l2_norm_clip – 使用的整体 L2 裁剪范数。
uniform – If True, per-layer clip norm is
global_l2_norm_clip/sqrt(L), whereLis the number of layers. Otherwise, per-layer clip norm isglobal_l2_norm_clip * sqrt(f), wherefis the fraction of total model parameters that are in this layer.
- Returns:
一个元组,包含每一层被裁剪的每个示例的梯度和裁剪的每个示例梯度的数量。
示例
>>> import optax >>> import jax.numpy as jnp >>> grads = [jnp.array([[0, 0, 0], [0, 3, 4], [4, 0, 3], [3, 4, 0]])] >>> optax.per_example_layer_norm_clip(grads, jnp.inf) ([Array([7., 7., 7.], dtype=float32)], [Array(0, dtype=int32)]) >>> optax.per_example_layer_norm_clip(grads, 0.0) ([Array([0., 0., 0.], dtype=float32)], [Array(3, dtype=int32)]) >>> optax.per_example_layer_norm_clip(grads, 1.25) ([Array([1.75, 1.75, 1.75], dtype=float32)], [Array(3, dtype=int32)])
参考文献
McMahan 等, 学习差分隐私递归语言模型,2017
- optax.scale(step_size: float) optax.GradientTransformation[源]#
按某个固定的标量step_size更新缩放。
- Parameters:
step_size – 一个标量,对应于更新的固定缩放因子。
- Returns:
一个
optax.GradientTransformation对象。
- optax.ScaleState[源]#
的别名
EmptyState
- optax.scale_by_adadelta(rho: float = 0.9, eps: float = 1e-06) optax.GradientTransformation[源]#
根据Adadelta算法进行重缩放更新。
有关更多详细信息,请参见
optax.adadelta()。- Parameters:
rho – 用于计算平方梯度的移动平均的系数。
eps – 添加到分母中的项以提高数值稳定性。
- Returns:
一个
optax.GradientTransformation对象。
- optax.scale_by_adan(b1: float = 0.98, b2: float = 0.92, b3: float = 0.99, eps: float = 1e-08, eps_root: float = 0.0) optax.GradientTransformation[源]#
根据Adan算法进行重缩放更新。
有关更多细节,请参见
optax.adan()。- Parameters:
b1 – 渐变的EWMA衰减率。
b2 – 差值梯度的EWMA的衰减率。
b3 – 算法平方项的EMWA衰减率。
eps – 添加到分母中的项以提高数值稳定性。
eps_root – 添加到平方根内部的分母中的一个项,以提高在反向传播梯度通过重缩放时的数值稳定性。
- Returns:
一个
optax.GradientTransformation对象。
- optax.scale_by_adam(b1: float = 0.9, b2: float = 0.999, eps: float = 1e-08, eps_root: float = 0.0, mu_dtype: str | type[Any] | 数据类型 | SupportsDType | None = None, *, nesterov: bool = False) optax.GradientTransformation[源]#
根据Adam算法更新缩放。
有关更多详细信息,请参见
optax.adam()。- Parameters:
b1 – 指数加权平均梯度的衰减率。
b2 – 指数加权平方梯度的衰减率。
eps – 添加到分母中的项,以提高数值稳定性。
eps_root – 在平方根内部添加到分母的项,以改善在通过重新缩放反向传播梯度时的数值稳定性。
mu_dtype – 可选的 dtype,用于一阶累加器;如果 None,则 dtype 从 params 和 updates 中推断。
nesterov – 是否使用Nesterov动量。带有Nesterov动量的Adam变体在[Dozat 2016]中描述。
- Returns:
一个
optax.GradientTransformation对象。
- optax.scale_by_adamax(b1: float = 0.9, b2: float = 0.999, eps: float = 1e-08) optax.GradientTransformation[源]#
根据Adamax算法进行重缩放更新。
有关更多详细信息,请参见
optax.adamax()。- Parameters:
b1 – 指数加权平均梯度的衰减率。
b2 – 用于指数加权最大梯度的衰减率。
eps – 添加到分母中的项以提高数值稳定性。
- Returns:
一个
optax.GradientTransformation对象。
- class optax.ScaleByAdamState(count: chex.Array, mu: optax.Updates, nu: optax.Updates)[源]#
Adam算法的状态。
- optax.scale_by_amsgrad(b1: float = 0.9, b2: float = 0.999, eps: float = 1e-08, eps_root: float = 0.0, mu_dtype: str | type[Any] | 数据类型 | SupportsDType | None = None) optax.GradientTransformation[源]#
根据AMSGrad算法更新重标定。
有关更多详细信息,请参见
optax.amsgrad()。- Parameters:
b1 – 指数加权平均梯度的衰减率。
b2 – 平方梯度的指数加权平均的衰减率。
eps – 添加到分母中的项以提高数值稳定性。
eps_root – 在平方根内部添加的项,以提高在通过重新缩放反向传播梯度时的数值稳定性。
mu_dtype – 可选的 dtype,用于一阶累加器;如果 None,则 dtype 从 params 和 updates 中推断。
- Returns:
一个
optax.GradientTransformation对象。
- class optax.ScaleByAmsgradState(count: chex.Array, mu: optax.Updates, nu: optax.Updates, nu_max: optax.Updates)[源]#
AMSGrad算法的状态。
- optax.scale_by_backtracking_linesearch(max_backtracking_steps: int, slope_rtol: float = 0.0001, decrease_factor: float = 0.8, increase_factor: float = 1.5, max_learning_rate: float = 1.0, atol: float = 0.0, rtol: float = 0.0, store_grad: bool = False, verbose: bool = False) 基础.渐变转换额外参数[源]#
回溯线搜索确保充分减少(阿米霍标准)。
选择学习率 \(\eta\) 以验证充分下降标准
\[f(w + \eta u) \leq (1+\delta)f(w) + \eta c \langle u, \nabla f(w) \rangle + \epsilon \,, \]哪里
\(f\) is the function to minimize, \(w\) are the current parameters, \(\eta\) is the learning rate to find, \(u\) is the update direction, \(c\) is a coefficient (
slope_rtol) measuring the relative decrease of the function in terms of the slope (scalar product between the gradient and the updates), \(\delta\) is a relative tolerance (rtol), \(\epsilon\) is an absolute tolerance (atol).该算法从给定的学习率猜测开始,并通过
decrease_factor递减,直到满足上述标准。- Parameters:
max_backtracking_steps – 线性搜索的最大迭代次数。
slope_rtol – 关于斜率的相对容忍度。充分减少必须是 slope_rtol * lr *
,见上面的公式。 decrease_factor – 减小学习率的因子。
increase_factor – 增加学习率猜测的增加因子。将其设置为 1 意味着保持当前猜测,将其设置为
math.inf意味着在每轮开始时使用max_learning_rate。max_learning_rate – 最大学习率(学习率猜测被限制在此)。
atol – 满足条件所需的绝对容忍度。
rtol – 需要满足的相对容忍度。
store_grad – 是否在行搜索结束时计算并存储梯度。由于该函数被调用来计算接受学习率的值,因此我们也可以在此过程中访问梯度。通过这样做,我们可以直接重用在行搜索结束时计算的值和梯度,用于下一次迭代,使用
optax.value_and_grad_from_state()。请参见上面的示例。verbose – 是否打印调试信息。
- Returns:
A
GradientTransformationExtraArgs, where theupdatefunction takes the following additional keyword arguments:value: value of the function at the current params.grad: gradient of the function at the current params.value_fn: function returning the value of the function we seek to optimize.**extra_args: additional keyword arguments, if the function needs additional arguments such as input data, they should be put there ( see example in this docstring).
示例
一个使用回溯线搜索与随机梯度下降(SGD)的示例:
>>> import optax >>> import jax >>> import jax.numpy as jnp >>> solver = optax.chain( ... optax.sgd(learning_rate=1.), ... optax.scale_by_backtracking_linesearch(max_backtracking_steps=15) ... ) >>> # Function with additional inputs other than params >>> def fn(params, x, y): return optax.l2_loss(x.dot(params), y) >>> params = jnp.array([1., 2., 3.]) >>> opt_state = solver.init(params) >>> x, y = jnp.array([3., 2., 1.]), jnp.array(0.) >>> xs, ys = jnp.tile(x, (5, 1)), jnp.tile(y, (5,)) >>> opt_state = solver.init(params) >>> print('Objective function: {:.2E}'.format(fn(params, x, y))) Objective function: 5.00E+01 >>> for x, y in zip(xs, ys): ... value, grad = jax.value_and_grad(fn)(params, x, y) ... updates, opt_state = solver.update( ... grad, ... opt_state, ... params, ... value=value, ... grad=grad, ... value_fn=fn, ... x=x, ... y=y ... ) ... params = optax.apply_updates(params, updates) ... print('Objective function: {:.2E}'.format(fn(params, x, y))) Objective function: 3.86E+01 Objective function: 2.50E+01 Objective function: 1.34E+01 Objective function: 5.87E+00 Objective function: 5.81E+00
一个类似的例子,但使用非随机函数,我们可以重用在搜索行末尾计算的值和梯度:
>>> import optax >>> import jax >>> import jax.numpy as jnp >>> # Function without extra arguments >>> def fn(params): return jnp.sum(params ** 2) >>> params = jnp.array([1., 2., 3.]) >>> # In this case we can store value and grad with the store_grad field >>> # and reuse them using optax.value_and_grad_state_from_state >>> solver = optax.chain( ... optax.sgd(learning_rate=1.), ... optax.scale_by_backtracking_linesearch( ... max_backtracking_steps=15, store_grad=True ... ) ... ) >>> opt_state = solver.init(params) >>> print('Objective function: {:.2E}'.format(fn(params))) Objective function: 1.40E+01 >>> value_and_grad = optax.value_and_grad_from_state(fn) >>> for _ in range(5): ... value, grad = value_and_grad(params, state=opt_state) ... updates, opt_state = solver.update( ... grad, opt_state, params, value=value, grad=grad, value_fn=fn ... ) ... params = optax.apply_updates(params, updates) ... print('Objective function: {:.2E}'.format(fn(params))) Objective function: 5.04E+00 Objective function: 1.81E+00 Objective function: 6.53E-01 Objective function: 2.35E-01 Objective function: 8.47E-02
参考文献
Vaswani et al., 无痛随机梯度, 2019
Nocedal & Wright, 数值优化, 1999
警告
The sufficient decrease criterion might be impossible to satisfy for some update directions. To guarantee a non-trivial solution for the sufficient decrease criterion, a descent direction for updates (\(u\)) is required. An update (\(u\)) is considered a descent direction if the derivative of \(f(w + \eta u)\) at \(\eta = 0\) (i.e., \(\langle u, \nabla f(w)\rangle\)) is negative. This condition is automatically satisfied when using
optax.sgd()(without momentum), but may not hold true for other optimizers likeoptax.adam().More generally, when chained with other transforms as
optax.chain(opt_1, ..., opt_k, scale_by_backtraking_linesearch(max_backtracking_steps=...), opt_kplusone, ..., opt_n), the updates returned by chainingopt_1, ..., opt_kmust be a descent direction. However, any transform after the backtracking line-search doesn’t necessarily need to satisfy the descent direction property (one could for example use momentum).注意
该算法可以支持复杂的输入。
另请参阅
optax.value_and_grad_from_state()使此方法在非随机目标下更高效。在版本 0.2.0 中新增。
- class optax.ScaleByBacktrackingLinesearchState(learning_rate: float | jax.数组, value: float | jax.数组, grad: base.Updates | None, info: BacktrackingLinesearchInfo)[源]#
用于
optax.scale_by_backtracking_linesearch()的状态。- value#
在行搜索的最后计算出的目标值。可以使用
optax.value_and_grad_from_state()重用。- Type:
联合[浮点数, jax.Array]
- grad#
在一轮线搜索结束时计算的目标的梯度,如果线搜索实例化时设置了store_grad = True。否则它是None。可以使用
optax.value_and_grad_from_state()进行重用。- Type:
可选的[base.Updates]
- info#
关于回溯线搜索步骤的信息,供调试使用。
- Type:
回溯线搜索信息
- optax.scale_by_belief(b1: float = 0.9, b2: float = 0.999, eps: float = 1e-16, eps_root: float = 1e-16, *, nesterov: bool = False) optax.GradientTransformation[源]#
根据AdaBelief算法更新缩放。
有关更多细节,请参见
optax.adabelief()。- Parameters:
b1 – 指数加权平均梯度的衰减率。
b2 – 指数加权平均梯度方差的衰减率。
eps – 添加到分母中的项,以提高数值稳定性。
eps_root – 添加到预测误差的二阶矩中的项,以改善数值稳定性。如果在梯度变换中反向传播梯度(例如,用于元学习),则该值必须为非零。
nesterov – 是否使用Nesterov动量。
- Returns:
一个
optax.GradientTransformation对象。
- class optax.ScaleByBeliefState(count: chex.Array, mu: optax.Updates, nu: optax.Updates)[源]#
AdaBelief算法的重缩放状态。
- optax.scale_by_factored_rms(factored: bool = True, decay_rate: float = 0.8, step_offset: int = 0, min_dim_size_to_factor: int = 128, epsilon: float = 1e-30, decay_rate_fn: ~collections.abc.Callable[[int, float], TypeAliasForwardRef('chex.Array')] = <function _decay_rate_pow>)[源]#
通过梯度均方根的分解估计进行缩放(如在Adafactor中所示)。
这是一种所谓的“1+epsilon”缩放算法,与RMSProp/Adam相比,它在内存效率上极为优越,并且在基于注意力的模型的大规模训练中取得了广泛成功。
- Parameters:
factored – 布尔值:是否使用分解的二阶矩估计。
decay_rate – 浮动数:控制第二矩指数衰减计划。
step_offset – 对于微调,可以将其设置为微调阶段的起始步骤号码。
min_dim_size_to_factor – 只有在两个数组维度至少达到这个大小时才进行因子累加。
epsilon – 用于平方梯度的正则化常数。
decay_rate_fn – 一个接受当前步骤、衰减率参数并控制第二动量调度的函数。默认为原始adafactor的幂衰减调度。原始调度的一个潜在缺点是第二动量收敛到1,这实际上冻结了第二动量。为了防止这种情况,用户可以选择一个自定义调度,设置第二动量的上限,如Zhai et al., 2021中所示。
- Returns:
参考文献
Shazeer 等人, Adafactor: Adaptive Learning Rates with Sublinear Memory Cost, 2018
Zhai等人,Scaling Vision Transformers,2021
- class optax.FactoredState(count: chex.Array, v_row: chex.ArrayTree, v_col: chex.ArrayTree, v: chex.ArrayTree)[源]#
梯度变换的总体状态。
- optax.scale_by_lbfgs(memory_size: int = 10, scale_init_precond: bool = True) optax.GradientTransformation[源]#
通过L-BFGS更新尺度。
L-BFGS是一种拟牛顿法,它将更新(梯度)与逆Hessian的近似值相乘。这个算法不需要访问Hessian,因为这个近似是根据在优化过程中看到的梯度评估构建的。L-BFGS是Broyden-Fletcher-Goldfarb-Shanno (BFGS)算法的有限内存变体。BFGS算法需要存储一个大小为 \(p \times p\) 的矩阵,其中 \(p\) 是参数的维度。有限的变体通过仅使用 \(m\) (
memory_size) 过去的参数/梯度差异来计算逆的近似值,从而解决了这个问题。也就是说,Hessian逆的近似值被表示为 \(P_k = P_{k, k}\),其中\[\begin{align*} P_{k, j+1} & = V_j^\top P_{k, j} V_j + \rho_j \delta w_j \delta w_j^\top \quad \text{对于} \ j \in \{k-m, \ldots, k-1\}\\ P_{k, k-m} & = \gamma_k I \\ V_k & = I - \rho_k \delta u_k \delta w_k^\top \\ \rho_k & = 1/(\delta u_k^\top \delta w_k) \\ \delta w_k & = w_{k+1} - w_k \\ \delta u_k & = u_{k+1} - u_k \\ \gamma_k & = \begin{cases} (\delta w_{k-1}^\top \delta u_{k-1}) / (\delta u_{k-1}^\top \delta u_{k-1}) & \text{如果} \ \texttt{scale\_init\_hess} \\ 1 & \text{否则} \end{cases}, \end{align*}\]对于 \(u_k\) 在迭代 \(k\) 处的梯度/更新, \(w_k\) 在迭代 \(k\) 处的参数。
更新\(P_k\)的公式是通过计算满足某些割线条件的最优预处理矩阵获得的,详细信息请参阅参考文献。计算\(P_k u_k\)可以通过使用存储在内存缓冲区中的参数和梯度的过去差异进行一系列向量操作。
当前函数仅输出 LBFGS 方向 \(P_k u_k\)。它可以与确保充分减少和低曲率的线搜索链式调用,比如缩放线搜索。线搜索计算步长 \(\eta_k\),使得更新后的参数(使用
optax.apply_updates())的形式为 \(w_{k+1} = w_k - \eta_k P_k u_k\)。- Parameters:
memory_size – 过去参数的数量,保持的梯度/更新,以便近似海森逆。
scale_init_precond – 是否使用缩放的单位作为初始预处理器,见上面\(\gamma_k\)的公式。
- Returns:
一个
optax.GradientTransformation对象。
参考文献
Nocedal et al 的《数值优化
Liu et al., 关于大规模优化的有限内存BFGS方法, 1989.
注意
我们将恒等式的缩放初始化为梯度范数的上限倒数。这避免了在第一步中浪费线搜索迭代,考虑到梯度的大小。换句话说,我们将第一步的信任区域限制在第一次迭代时半径为1的欧几里得球体内。对\(\gamma_0\)的选择在上面的参考文献中没有详细说明,因此这是一个启发式选择。
- class optax.ScaleByLBFGSState(count: chex.Numeric, params: optax.Params, updates: optax.Params, diff_params_memory: chex.ArrayTree, diff_updates_memory: chex.ArrayTree, weights_memory: chex.Array)[源]#
LBFGS求解器的状态。
- count#
算法的迭代。
- Type:
chex.Numeric
- params#
当前参数。
- Type:
base.Params
- updates#
当前更新。
- Type:
基础.参数
- diff_params_memory#
表示过去参数差异的列表,直到某个预定的
memory_size固定在optax.scale_by_lbfgs()中。- Type:
chex.ArrayTree
- diff_updates_memory#
表示一个过去的梯度/更新的列表,直到某个预定的
memory_size固定在optax.scale_by_lbfgs()。- Type:
chex.ArrayTree
- weights_memory#
过去权重的列表,乘以定义逆Hessian近似的秩一矩阵,参见
optax.scale_by_lbfgs()获取更多详细信息。- Type:
chex.Array
- optax.scale_by_learning_rate(learning_rate: base.ScalarOrSchedule, *, flip_sign: bool = True) 基础.渐变转换[源]#
通过(负)学习率缩放(可以是标量或调度)。
- Parameters:
learning_rate – 可以是一个标量或一个调度(即一个可调用的,将一个(整数)步映射到一个浮点数)。
flip_sign – 当设置为 True(默认值)时,这对应于通过负学习率进行缩放。
- Returns:
一个optax.GradientTransformation,它对应于将梯度乘以 -learning_rate(如果flip_sign为True)或乘以 learning_rate(如果flip_sign为False)。
- optax.scale_by_lion(b1: float = 0.9, b2: float = 0.99, mu_dtype: str | type[Any] | 数据类型 | SupportsDType | None = None) optax.GradientTransformation[源]#
根据Lion算法更新缩放。
有关更多细节,请参见
optax.lion()。- Parameters:
b1 – 合并动量和当前梯度的速率。
b2 – 指数加权平均梯度的衰减率。
mu_dtype – 可选的 dtype 用于动量;如果 None 则 dtype 从 `params 和 updates 中推断。
- Returns:
一个
optax.GradientTransformation对象。
- optax.scale_by_novograd(b1: float = 0.9, b2: float = 0.25, eps: float = 1e-08, eps_root: float = 0.0, weight_decay: float = 0.0, mu_dtype: str | type[Any] | 数据类型 | SupportsDType | None = None) optax.GradientTransformation[源]#
计算NovoGrad更新。
有关更多详细信息,请参见
optax.novograd()。- Parameters:
b1 – 指数加权平均梯度的衰减率。
b2 – 指数加权平方梯度的衰减率。
eps – 添加到分母中的一个术语,以提高数值稳定性。
eps_root – 在平方根内部的分母中添加的一个术语,以提高在通过重新缩放反向传播梯度时的数值稳定性。
weight_decay – 一个标量权重衰减率。
mu_dtype – 一个可选的 dtype,用于一阶累加器;如果 None,则 dtype 会从 params 和 updates 中推断。
- Returns:
- class optax.ScaleByNovogradState(count: chex.Array, mu: optax.Updates, nu: optax.Updates)[源]#
诺沃格拉德的状态。
- optax.scale_by_optimistic_gradient(alpha: float = 1.0, beta: float = 1.0) optax.GradientTransformation[源]#
计算广义乐观梯度。
参见
optax.optimistic_adam(),optax.optimistic_gradient_descent()获取更多详细信息。- Parameters:
alpha – 一般化乐观梯度下降的系数。
beta – 负动量的系数。
- Returns:
一个
optax.GradientTransformation对象。
- optax.scale_by_param_block_norm(min_scale: float = 0.001) optax.GradientTransformation[源]#
每个参数块的规模更新由该块参数的范数决定。
一个 block 在这里是一个权重向量(例如,在线性层中)或一个权重矩阵(例如,在卷积层中),作为 grads/param pytree 中的一个叶子出现。
- Parameters:
min_scale – 最小缩放因子。
- Returns:
一个
optax.GradientTransformation对象。
- optax.scale_by_param_block_rms(min_scale: float = 0.001) optax.GradientTransformation[源]#
按每个参数向量或矩阵的梯度均方根更新规模。
一个 block 在这里是一个权重向量(例如,在线性层中)或一个权重矩阵(例如,在卷积层中),作为 grads/param pytree 中的一个叶子出现。
- Parameters:
min_scale – 最小缩放因子。
- Returns:
一个
optax.GradientTransformation对象。
- optax.scale_by_radam(b1: float = 0.9, b2: float = 0.999, eps: float = 1e-08, eps_root: float = 0.0, threshold: float = 5.0, *, nesterov: bool = False) optax.GradientTransformation[源]#
根据修正Adam算法更新重缩放。
有关更多详细信息,请参见
optax.radam()。- Parameters:
b1 – 指数加权平均梯度的衰减率。
b2 – 平方梯度的指数加权平均的衰减率。
eps – 添加到分母中的项,以提高数值稳定性。
eps_root – 在平方根内部添加的项,以提高在通过重新缩放反向传播梯度时的数值稳定性。
threshold – 方差可跟踪性的阈值。
nesterov – 是否使用Nesterov动量。
- Returns:
一个
optax.GradientTransformation对象。
- optax.scale_by_polyak(f_min: float = 0.0, max_learning_rate: float = 1.0, eps: float = 0.0, variant: str = 'sps') 基础.渐变变换额外参数[源]#
通过Polyak的步长缩放更新。
查看
optax.polyak_sgd()以获取更多详细信息。- Parameters:
f_min – 目标函数的下限(默认为0)。对应于\(f^\star\)在上面的公式中。
max_learning_rate – 使用的最大步长(默认为1)。
eps – 在更新的分母中添加的值(默认为0)。
变体 – 可以是
'sps'或'sps+'(默认为'sps')。
- Returns:
一个
optax.GradientTransformationExtraArgs,其中update函数接受一个额外的关键字参数value,包含当前目标函数的值。
- optax.scale_by_rms(decay: float = 0.9, eps: float = 1e-08, initial_scale: float = 0.0, eps_in_sqrt: bool = True, bias_correction: bool = False) optax.GradientTransformation[源]#
通过平方的指数移动平均的根来重新缩放更新。
详情请参见
optax.rmsprop()。- Parameters:
decay – 指数加权平均平方梯度的衰减率。
eps – 添加到分母中的项,以提高数值稳定性。
initial_scale – 二次矩的初始值。
eps_in_sqrt – 是否在分母的平方根中或平方根外添加
eps。bias_correction – 是否对平方梯度的指数加权平均应用偏差校正。
- Returns:
一个
optax.GradientTransformation对象。
注意
使用 scale_by_rms(decay=b2, eps_in_sqrt=False, bias_correction=True) 将与 scale_by_adam(b1=0, b2=b2) 的行为相匹配,同时节省存储第一动量的内存开销。
- optax.scale_by_rprop(learning_rate: float, eta_minus: float = 0.5, eta_plus: float = 1.2, min_step_size: float = 1e-06, max_step_size: float = 50.0) optax.GradientTransformation[源]#
使用 Rprop 优化器进行缩放。
有关更多详细信息,请参见
optax.rprop()。- Parameters:
learning_rate – 初始步长。
eta_minus – 减小步长的乘数因子。这在梯度在一步到下一步时符号发生变化时应用。
eta_plus – 增加步长的乘法因子。当梯度在一个步骤到下一个步骤之间具有相同的符号时,应用此因子。
min_step_size – 最小允许步长。较小的步长将被限制为此值。
max_step_size – 最大允许的步长。更大的步长将被限制为此值。
- Returns:
- optax.scale_by_rss(initial_accumulator_value: float = 0.1, eps: float = 1e-07) optax.GradientTransformation[源]#
通过迄今为止所有平方梯度的总和的平方根来更新缩放。
有关更多详细信息,请参见
optax.adagrad()。- Parameters:
initial_accumulator_value – 累加器的起始值,必须大于或等于 0。
eps – 一个小的浮点值,以避免零分母。
- Returns:
一个
optax.GradientTransformation对象。
- optax.scale_by_schedule(step_size_fn: base.Schedule) 基本.渐变转换[源]#
使用自定义计划更新step_size的比例。
- Parameters:
step_size_fn – 一个接受更新计数作为输入并建议要乘以更新的步长的函数。
- Returns:
一个
optax.GradientTransformation对象。
- optax.scale_by_sign() optax.GradientTransformation[源]#
计算梯度元素的符号。
- Returns:
一个包含输入梯度符号的optax.GradientTransformation。
- optax.scale_by_sm3(b1: float = 0.9, b2: float = 1.0, eps: float = 1e-08) optax.GradientTransformation[源]#
通过 sm3 更新规模。
有关更多详细信息,请参见
optax.sm3()。- Parameters:
b1 – 指数加权平均梯度的衰减率。
b2 – 平方梯度的指数加权平均的衰减率。
eps – 添加到分母中的项,以提高数值稳定性。
- Returns:
一个
optax.GradientTransformation对象。
- optax.scale_by_stddev(decay: float = 0.9, eps: float = 1e-08, initial_scale: float = 0.0, eps_in_sqrt: bool = True, bias_correction: bool = False) optax.GradientTransformation[源]#
通过平方的中心指数移动平均的根来重新调整更新。
详情请参见
optax.rmsprop()。- Parameters:
decay – 指数加权平均平方梯度的衰减率。
eps – 添加到分母中的项,以提高数值稳定性。
initial_scale – 第二时刻的初始值。
eps_in_sqrt – 是否在分母的平方根中或平方根外添加
eps。bias_correction – 是否对第一和第二时刻应用偏差校正。
- Returns:
一个
optax.GradientTransformation对象。
- optax.scale_by_trust_ratio(min_norm: float = 0.0, trust_coefficient: float = 1.0, eps: float = 0.0) optax.GradientTransformation[源]#
根据信任比率进行规模更新。
用于
optax.fromage(),optax.lars(),optax.lamb()。- Parameters:
min_norm – 参数和梯度范数的最小范数;默认值为零。
trust_coefficient – 信任比率的乘数。
eps – 为了数值稳定性而添加到分母的加性常数。
- Returns:
一个
optax.GradientTransformation对象。
- optax.ScaleByTrustRatioState[源]#
的别名
EmptyState
- optax.scale_by_yogi(b1: float = 0.9, b2: float = 0.999, eps: float = 0.001, eps_root: float = 0.0, initial_accumulator_value: float = 1e-06) optax.GradientTransformation[源]#
根据Yogi算法进行重缩放更新。
有关更多详细信息,请参阅
optax.yogi()。支持复数,参见 https://gist.github.com/wdphy16/118aef6fb5f82c49790d7678cf87da29
- Parameters:
b1 – 指数加权平均梯度的衰减率。
b2 – 指数加权平均梯度方差的衰减率。
eps – 添加到分母中的项,以提高数值稳定性。
eps_root – 在平方根内部添加的项,以提高在通过重新缩放反向传播梯度时的数值稳定性。
initial_accumulator_value – 累加器的起始值。仅允许正值。
- Returns:
一个
optax.GradientTransformation对象。
- optax.scale_by_zoom_linesearch(max_linesearch_steps: int, max_learning_rate: float | None = None, tol: float = 0.0, increase_factor: float = 2.0, slope_rtol: float = 0.0001, curv_rtol: float = 0.9, approx_dec_rtol: float | None = 1e-06, stepsize_precision: float = 1e-05, initial_guess_strategy: str = 'keep', verbose: bool = False) 基础.渐变变换额外参数[源]#
线搜索确保充分下降和小曲率。
该算法搜索一个学习率,也称为步长,满足 既定的充分减少标准,也称为Armijo-Goldstein标准,
\[f(w + \eta u) \leq f(w) + \eta c_1 \langle u, \nabla f(w) \rangle + \epsilon \,, \]以及一个小的曲率(沿着更新方向)标准,即沃尔夫或第二沃尔夫标准,
\[|\langle \nabla f(w + \eta u), u \rangle| \leq c_2 |\langle \nabla f(w), \rangle| + \epsilon\,, \]哪里
\(f\) 是要最小化的函数,
\(w\) 是当前的参数,
\(\eta\) 是要找到的学习率,
\(u\) 是更新方向,
\(c_1\) 是一个系数 (
slope_rtol),用于测量函数在斜率方面的相对降低(梯度与更新之间的标量积),\(c_2\) 是一个系数 (
curv_rtol),用来测量曲率的相对减少。\(\epsilon\) 是一个绝对容差 (
tol).
为了处理非常平坦的函数,这个线搜索将从上述提出的充分减小准则切换到Hager和Zhang引入的近似充分减小准则(见[Hager和Zhang,2006])。
\[|\langle \nabla f(w+\eta u), u \rangle| \leq (2 c_1 - 1) |\langle \nabla f(w), \rangle| + \epsilon\,. \]只有当线搜索尝试的值低于初始函数的相对减少时,才采用近似曲率标准,即:
\[f(w + \eta u) \leq f(w) + c_3 |f(w)| \]其中 \(c_3\) 是一个系数
approx_dec_rtol,用于测量目标的相对减少(请参阅下面的参考文献和代码中的注释以获取更多详细信息)。原始的充分减少准则只能捕捉到最多 \(\sqrt{\varepsilon_{machine}}\) 的差异,而近似充分减少准则可以捕捉到最多 \(\varepsilon_{machine}\) 的差异(见 [Hager and Zhang, 2006])。请注意,此附加功能不是原始实现 [Nocedal and Wright, 1999] 的一部分,可以通过将
approx_dec_rtol设置为None来移除。- Parameters:
max_linesearch_steps – 最大的线搜索迭代次数。
max_learning_rate – 允许的最大学习率。可以设置为
None表示没有上限。非None值可能会阻止线搜索找到满足小曲率标准的学习率,因为后者可能需要足够大的步长。tol – 标准的容忍度。
increase_factor – 增加因子,用于在搜索包含满足两个标准的学习速率的有效区间时增强学习速率。
slope_rtol – 足够减少准则中斜率的相对容忍度。
curv_rtol – 小曲率准则中曲率的相对容差。
approx_dec_rtol – 近似充分减少标准中初始值的相对容忍度。可以设置为
None以仅使用原始的Armijo-Goldstein减少标准。stepsize_precision – 在搜索满足两个条件的步长时的精度。该算法通过二分法进行,细化包含满足两个条件的步长的区间。如果该区间缩小到低于
stepsize_precision且已找到满足充分降低的步长,则算法选择该步长,即使曲率条件未满足。initial_guess_strategy – initial guess for the learning rate used to start the linesearch. Can be either
oneorkeep. Ifone, the initial guess is set to 1. Ifkeep, the initial guess is set to the learning rate of the previous step. We recommend to usekeepif this linesearch is used in combination with SGD. We recommend to useoneif this linesearch is used in combination with Newton methods or quasi-Newton methods such as L-BFGS.verbose – 如果行搜索失败,是否打印额外的调试信息。
- Returns:
一个
optax.GradientTransformationExtraArgs对象,包括一个初始化函数和一个更新函数。
示例
一个关于使用SGD的缩放线搜索的例子:
>>> import optax >>> import jax >>> import jax.numpy as jnp >>> solver = optax.chain( ... optax.sgd(learning_rate=1.), ... optax.scale_by_zoom_linesearch(max_linesearch_steps=15) ... ) >>> # Function with additional inputs other than params >>> def fn(params, x, y): return optax.l2_loss(x.dot(params), y) >>> params = jnp.array([1., 2., 3.]) >>> opt_state = solver.init(params) >>> x, y = jnp.array([3., 2., 1.]), jnp.array(0.) >>> xs, ys = jnp.tile(x, (5, 1)), jnp.tile(y, (5,)) >>> opt_state = solver.init(params) >>> print('Objective function: {:.2E}'.format(fn(params, x, y))) Objective function: 5.00E+01 >>> for x, y in zip(xs, ys): ... value, grad = jax.value_and_grad(fn)(params, x, y) ... updates, opt_state = solver.update( ... grad, ... opt_state, ... params, ... value=value, ... grad=grad, ... value_fn=fn, ... x=x, ... y=y ... ) ... params = optax.apply_updates(params, updates) ... print('Objective function: {:.2E}'.format(fn(params, x, y))) Objective function: 2.56E-13 Objective function: 2.84E-14 Objective function: 0.00E+00 Objective function: 0.00E+00 Objective function: 0.00E+00
一个类似的例子,但使用非随机函数,我们可以重用在搜索行末尾计算的值和梯度:
>>> import optax >>> import jax >>> import jax.numpy as jnp >>> # Function without extra arguments >>> def fn(params): return jnp.sum(params ** 2) >>> params = jnp.array([1., 2., 3.]) >>> solver = optax.chain( ... optax.sgd(learning_rate=1.), ... optax.scale_by_zoom_linesearch(max_linesearch_steps=15) ... ) >>> opt_state = solver.init(params) >>> print('Objective function: {:.2E}'.format(fn(params))) Objective function: 1.40E+01 >>> value_and_grad = optax.value_and_grad_from_state(fn) >>> for _ in range(5): ... value, grad = value_and_grad(params, state=opt_state) ... updates, opt_state = solver.update( ... grad, opt_state, params, value=value, grad=grad, value_fn=fn ... ) ... params = optax.apply_updates(params, updates) ... print('Objective function: {:.2E}'.format(fn(params))) Objective function: 0.00E+00 Objective function: 0.00E+00 Objective function: 0.00E+00 Objective function: 0.00E+00 Objective function: 0.00E+00
参考文献
Nocedal 和 Wright 的算法 3.5 3.6,数值优化,1999
Hager 和 Zhang 算法 851: CG_DESCENT,一个具有保证下降的共轭梯度方法,2006
注意
通过设置
curv_rtol=jnp.inf可以避免曲率标准。生成的算法将相当于一个回退线搜索,其中通过最小化目标的二次或三次近似来搜索满足充分减少的点。这在实践中可能是足够的,避免了线搜索花费许多迭代尝试满足小曲率标准。注意
该算法可以支持复杂的输入。
另请参阅
optax.value_and_grad_from_state()使此方法在非随机目标下更高效。
- class optax.ScaleByZoomLinesearchState(learning_rate: chex.Numeric, value: chex.Numeric, grad: base.Updates, info: 缩放线搜索信息)[源]#
用于 scale_by_zoom_linesearch 的状态。
- learning_rate#
在线搜索的一轮结束时计算的学习率, 用于缩放更新。
- Type:
chex.Numeric
- value#
在行搜索的最后计算出的目标值。可以使用
optax.value_and_grad_from_state()重用。- Type:
chex.Numeric
- grad#
在一轮线性搜索结束时计算的目标的梯度。可以使用
optax.value_and_grad_from_state()进行重用。- Type:
base.Updates
- info#
有关线搜索状态的更多信息,请参见
otpax.ZoomLinesearchInfo.- Type:
- optax.set_to_zero() 梯度变换[源]#
无状态转换,将输入梯度映射为零。
当调用生成的更新函数时,将返回一个与输入梯度形状匹配的零树。这意味着,当从此转换返回的更新应用于模型参数时,模型参数将保持不变。
这可以与 multi_transform 或 masked 结合使用,以冻结 (即保持固定)模型参数树的某些部分,同时对树的其他部分应用 梯度更新。
当更新设置为零时,在与梯度计算、optax 转换和对参数应用更新相同的 jit 编译函数内部,通常会省略不必要的计算。
- Returns:
一个
optax.GradientTransformation对象。
- optax.stateless(f: Callable[[Updates, Params | None], Updates]) 梯度变换[源]#
从类似更新的函数创建一个无状态转换。
这个包装器消除了创建一个不需要在迭代之间保存状态的转换所需的样板代码。
- Parameters:
f – 更新函数,它接受更新(例如,梯度)和参数并返回更新。参数可以是None。
- Returns:
- optax.stateless_with_tree_map(f: Callable[[TypeAliasForwardRef('chex.Array'), TypeAliasForwardRef('chex.Array') | None], TypeAliasForwardRef('chex.Array')]) 渐变转换[源]#
从类似于更新的函数为数组创建无状态转换。
这个包装器消除了创建不需要在迭代之间保存状态的转换所需的样板代码,就像optax.stateless一样。此外,这个函数将为您在更新/参数上应用树映射。
- Parameters:
f – 更新函数,接受一个更新数组(例如:梯度)和参数数组,并返回一个更新数组。参数数组可以是 None。
- Returns:
- optax.trace(decay: float, nesterov: bool = False, accumulator_dtype: Any | None = None) optax.GradientTransformation[源]#
计算过去更新的迹线。
- Parameters:
decay – 过去更新的痕迹衰减率。
nesterov – 是否使用Nesterov动量。
accumulator_dtype – 可选的 dtype 用于累加器;如果 None,则 dtype 从 params 和 updates 中推断。
- Returns:
一个
optax.GradientTransformation对象。
注意
optax.trace()andoptax.ema()have very similar but distinct updates;trace = decay * trace + t, whileema = decay * ema + (1-decay) * t. Both are frequently found in the optimization literature.
- optax.zero_nans() optax.GradientTransformation[源]#
一个将NaN替换为0的转换。
转换的状态与参数的树结构相同。每个叶子是一个布尔值,当在上一次调用
update时检测到相应参数数组中的 NaN 时,它的值为 True。这个状态在转换内部不被使用,但让用户知道 NaN 已被归零。- Returns:
- class optax.ZeroNansState(found_nan: Any)[源]#
包含一棵树。
条目 found_nan 与参数的树结构相同。每个叶子都是一个布尔值,仅在最后一次调用 update 时检测到相应参数数组中的 NaN 时为 True。
- class optax.ZoomLinesearchInfo(num_linesearch_steps: int | chex.Numeric, decrease_error: float | chex.Numeric, curvature_error: float | chex.Numeric)[源]#
有关缩放线搜索步骤的信息,公开用于调试。
正曲率误差并不严格。它可能是由于最大学习率过小导致的。充足曲率误差的正值更有问题,因为这意味着算法可能无法保证产生单调递减的值。如果线搜索失败,考虑在
scale_by_zoom_linesearch()中使用verbose=True进行额外的故障诊断。- num_linesearch_steps#
线搜索步数
- Type:
整数 | jax.Array | numpy.ndarray | numpy.bool | numpy.number | 浮点数
- decrease_error#
充足减少错误。正值表示线搜索未能找到确保充足减少的步长。空值表示成功找到这样的步长。
- Type:
浮点数 | jax.Array | numpy.ndarray | numpy.bool | numpy.number | 整数
- curvature_error#
小弯曲误差。正值表示线搜索未能找到确保小弯曲的步长。空值表示成功找到这样的步长。
- Type:
浮点数 | jax.Array | numpy.ndarray | numpy.bool | numpy.number | 整数