🔧 贡献#
不符合 纳入标准的实验性特征和算法。
|
ACProp优化器。 |
|
AdEMAMix. |
|
根据连续货币投注算法重新缩放更新。 |
|
连续硬币下注的状态。 |
|
通过 D-适应的无学习率 AdamW。 |
|
由dadapt_adamw返回的GradientTransformation的状态。 |
基于DPSGD算法聚合梯度。 |
|
|
包含用于differentially_private_aggregate的PRNGKey的状态。 |
|
基于梯度的距离优化器。 |
|
DoG优化器的状态。 |
|
加权梯度优化器的距离。 |
|
DoWG优化器的状态。 |
|
DPSGD优化器。 |
|
Mechanic - 一个黑箱学习率调整器/优化器。 |
|
由mechanize返回的GradientTransformation的状态。 |
|
带动量的SGD自适应学习率。 |
|
momo 返回的 GradientTransformation 的状态。 |
|
Adam(W)的自适应学习率。 |
|
State of the |
|
穆奥:由牛顿-舒尔茨正交化的动量。 |
|
亚当算法的状态。 |
|
使用Prodigy的无学习率AdamW。 |
|
prodigy 返回的 GradientTransformation 的状态。 |
|
SAM(锐度感知最小化)的实现。 |
|
sam 返回的 GradientTransformation 的状态。 |
|
使 base_optimizer 为 schedule_free。 |
|
AdamW的无调度包装器。 |
|
Params for evaluation of |
|
无调度的SGD包装器。 |
|
用于schedule_free的状态。 |
|
Sophia 优化器。 |
|
Sophia优化器的状态。 |
|
将复数更新的实部和虚部分成两个部分。 |
|
维护split_real_and_imaginary的内部转换状态。 |
AdEMAMix#
- optax.contrib.ademamix(learning_rate: base.ScalarOrSchedule, b1: float = 0.9, b2: float = 0.999, b3: base.ScalarOrSchedule = 0.9999, alpha: base.ScalarOrSchedule = 5.0, eps: float = 1e-08, eps_root: float = 0.0, mu_dtype: Any | None = None, weight_decay: float = 0.0, mask: Any | Callable[[base.Params], Any] | None = None) 基础.渐变变换[来源]#
AdEMAMix.
AdEMAMix (自适应EMA混合) 是AdamW与两种动量项的混合,以更好地利用历史梯度。
SGD带动量(SGD+M)和Adam都使用过去梯度的指数移动平均(EMA)来结合动量
Let \(\eta\) represent the learning rate and \(\beta_1, \beta_2\), \(\beta_3, \alpha, \varepsilon, \bar{\varepsilon}\), represent the arguments
b1,b2,b3,alpha,epsandeps_rootrespectively. Let \(\lambda\) be the weight decay and \(\theta_t\) the parameter vector at time \(t\).The
initfunction of this optimizer initializes an internal state \(S_0 := (m^{(1)}_0, m^{(2)}_0, \nu_0) = (0, 0, 0)\), representing initial estimates for the fast and slow EMAs of the first moment along with the second moment estimate. In practice, these values are stored as pytrees containing all zeros, with the same shape as the model updates. At step \(t\), theupdatefunction of this optimizer takes as arguments the incoming gradients \(g^t\), the optimizer state \(S^t\) and the parameters \(\theta^{(t)}\). It then computes updates \(\theta^{(t+1)}\) and the new state \(S^{(t+1)}\). Thus, for \(t > 0\), we have,\[\begin{align*} m_1^{(t)} &\leftarrow \beta_1 \cdot m_1^{(t-1)} + (1-\beta_1) \cdot g^{(t)} \\ m_2^{(t)} &\leftarrow \beta_3 \cdot m_2^{(t-1)} + (1-\beta_3) \cdot g^{(t)} \\ \nu^{(t)} &\leftarrow \beta_2 \cdot \nu^{(t-1)} + (1-\beta_2) \cdot {g^{(t)}}^2 \\ \hat{m_1}^{(t)} &\leftarrow m_1^{(t)} / {(1-\beta_1^{(t)})} \\ \hat{\nu}^{(t)} &\leftarrow \nu^{(t)} / {(1-\beta_2^{(t)})} \\ \theta^{(t)} &\leftarrow \theta^{(t-1)} - \eta \cdot \left( \frac{(\hat{m_1}^{(t)} + \alpha m_2^{(t)})}{\left(\sqrt{\hat{\nu}^{(t)} + \bar{\varepsilon}} + \varepsilon\right)} + \lambda \theta^{(t-1)} \right).\\ S^{(t)} &\leftarrow (m_1^{(t)}, m_2^{(t)}, v^{(t)}). \end{align*}\]注意
AdEMAMix 旨在利用非常古老的梯度。因此,该方法最适合于迭代次数重要的设置。论文在附录 C.1.5 中报告了这一效果,显示较小的
b3值(例如b3 = 0.999)在低迭代场景中可能更好。此外,在需要对突然的分布变化快速适应的领域中,或者在分布非平稳的一般情况下,保留数千步的梯度信息可能会出现问题。示例
>>> import optax >>> import jax >>> import jax.numpy as jnp >>> def f(x): return jnp.sum(jnp.square(x)) # simple quadratic function >>> solver = optax.contrib.ademamix(learning_rate=0.01) >>> params = jnp.array([1., 2., 3.]) >>> print('Objective function: ', f(params)) Objective function: 14.0 >>> opt_state = solver.init(params) >>> for _ in range(5): ... grad = jax.grad(f)(params) ... updates, opt_state = solver.update(grad, opt_state, params) ... params = optax.apply_updates(params, updates) ... print('Objective function: {:.2E}'.format(f(params))) Objective function: 1.39E+01 Objective function: 1.38E+01 Objective function: 1.36E+01 Objective function: 1.35E+01 Objective function: 1.34E+01
参考文献
Pagliardini et al, The AdEMAMix Optimizer: Better, Faster, Older, 2024
- Parameters:
learning_rate – 一个全局缩放因子,可以是固定的,也可以随着 迭代使用调度器而变化,见
optax.scale_by_learning_rate().b1 – 指数衰减率,用于跟踪快速 EMA。
b2 – 指数衰减率,用于跟踪过去梯度的第二矩。
b3 – 用于跟踪慢速EMA的指数衰减率。
alpha – 快速和慢速指数移动平均线的线性组合中的混合系数。
eps – 一个小常数,应用于平方根外的分母(如在Adam论文中),以避免在重新缩放时除以零。
eps_root – 应用在平方根内分母上的一个小常数(如RMSProp中),以避免在重新缩放时出现除以零的情况。这在通过Adam计算(元)梯度时是必需的。
mu_dtype – 可选的 dtype 用于一阶累加器;如果 None 则 dtype 将从 params 和 updates 中推断。
weight_decay – 权重衰减正则化的强度。注意,这个权重衰减与学习率相乘。这与其他框架如PyTorch是一致的,但与(Loshchilov et al, 2019)不同,在该框架中,权重衰减只是与“调度乘子”相乘,而不是基础学习率。
mask - 一棵与 params PyTree 具有相同结构(或其前缀)的树,或一个可调用对象,该对象在给定 params/updates 时返回这样的 pytree。叶子应该是布尔值,对于您想要应用权重衰减的叶子/子树,应为 True,对于您想要跳过的则应为 False。请注意,Adam 梯度变换应用于所有参数。
- Returns:
相应的 GradientTransformation。
另请参见
查看相关函数
optax.adam(),optax.nadamw(), 以及示例 从论文中重建 AdeMAMix 罗斯布罗克图 的一个使用案例。
- optax.contrib.scale_by_ademamix(b1: float = 0.9, b2: float = 0.999, b3: base.ScalarOrSchedule = 0.9999, alpha: base.ScalarOrSchedule = 6.0, eps: float = 1e-08, eps_root: float = 0.0, mu_dtype: chex.ArrayDType | None = None) 基础.渐变变换[来源]#
根据Ademamix算法进行缩放更新。
请参见
optax.contrib.ademamix.()以获取算法的完整描述。参考文献
Pagliardini et al, The AdEMAMix Optimizer: Better, Faster, Older, 2024
- Parameters:
b1 – 指数衰减率,用于跟踪快速 EMA。
b2 – 指数衰减率,用于跟踪过去梯度的第二矩。
b3 – 用于跟踪慢速EMA的指数衰减率。
alpha – 快速和慢速 EMA 的线性组合中的混合系数。
eps – 在平方根外应用于分母的小常数(如在Adam论文中)以避免在重新缩放时除以零。
eps_root – 应用在平方根内分母上的一个小常数(如RMSProp中),以避免在重新缩放时出现除以零的情况。这在通过Adam计算(元)梯度时是必需的。
mu_dtype – 可选的 dtype 用于一阶累加器;如果 None 则 dtype 将从 params 和 updates 中推断。
- Returns:
相应的 GradientTransformation。
- class optax.contrib.ScaleByAdemamixState(count: chex.Array, count_m2: chex.Array, m1: optax.Updates, m2: optax.Updates, nu: optax.Updates)[来源]#
Ademamix算法的状态。
- count#
用于更新快速EMA和第二时刻的算法迭代。
- Type:
chex.Array
- count_m2#
用于更新慢EMA和alpha的算法迭代。
- Type:
chex.Array
- m1#
第一个时刻的快速EMA
- Type:
base.更新
- m2#
第一时刻的慢EMA
- Type:
base.更新
- nu#
第二矩的估计
- Type:
base.更新
异步居中-属性#
- optax.contrib.acprop(learning_rate: base.ScalarOrSchedule, b1: float = 0.9, b2: float = 0.999, eps: float = 1e-16, eps_root: float = 1e-16, weight_decay: float = 0.0, mask: Any | Callable[[base.Params], Any] | None = None) 基础.渐变变换[来源]#
ACProp 优化器。
遵循原始repo在PyTorch中的实现: juntang-zhuang/ACProp-Optimizer。
ACProp 是一种自适应优化器,结合了第二动量的中心化和异步更新。在步骤 t 的更新中,分母使用了直到步骤 t-1 的信息,而分子使用了步骤 t 的梯度。
Let \(\alpha_t\) represent the learning rate and \(\beta_1, \beta_2\), \(\varepsilon\), \(\bar{\varepsilon}\) represent the arguments
b1,b2,epsandeps_rootrespectively. The learning rate is indexed by \(t\) since the learning rate may also be provided by a schedule function.The
initfunction of this optimizer initializes an internal state \(S_0 := (m_0, s_0) = (0, 0)\), representing initial estimates for the first and second moments. In practice these values are stored as pytrees containing all zeros, with the same shape as the model updates. At step \(t\), theupdatefunction of this optimizer takes as arguments the incoming gradients \(g_t\) and optimizer state \(S_t\) and computes updates \(u_t\) and new state \(S_{t+1}\). Thus, for \(t > 0\), we have,\[\begin{align*} m_t &\leftarrow \beta_1 \cdot m_{t-1} + (1-\beta_1) \cdot g_t \\ s_t &\leftarrow \beta_2 \cdot s_{t-1} + (1-\beta_2) \cdot (g_t - m_t)^2 + \bar{\varepsilon} \\ \hat{s}_t &\leftarrow s_t / {(1-\beta_2^t)} \\ u_t &\leftarrow -\alpha_t \cdot g_t / \left(\sqrt{\hat{s}_{t-1}} + \varepsilon \right) \\ S_t &\leftarrow (m_t, s_t). \end{align*}\]- Parameters:
learning_rate – 一个全局缩放因子,可以是固定的,也可以随着 迭代使用调度器而变化,见
optax.scale_by_learning_rate().b1 – 指数衰减率,用于跟踪过去梯度的第一个时刻。
b2 – 指数衰减率,用于跟踪过去梯度的第二矩。
eps – 添加到分母的术语,以提高数值稳定性。
eps_root – 添加到预测误差的第二矩的术语,以提高数值稳定性。如果通过梯度变换(例如用于元学习)反向传播梯度,则此值必须非零。
weight_decay – 权重衰减正则化的强度。注意,这个权重衰减与学习率相乘。这与其他框架如PyTorch是一致的,但与(Loshchilov et al, 2019)不同,在该框架中,权重衰减只是与“调度乘子”相乘,而不是基础学习率。
mask - 一棵与 params PyTree 具有相同结构(或其前缀)的树,或一个可调用对象,该对象在给定 params/updates 时返回这样的 pytree。叶子应该是布尔值,对于您想要应用权重衰减的叶子/子树,应为 True,对于您想要跳过的则应为 False。请注意,Adam 梯度变换应用于所有参数。
- Returns:
相应的 GradientTransformation。
参考文献
Zuhang 等,动量中心化与自适应梯度方法的异步更新,2021
- optax.contrib.scale_by_acprop(b1: float = 0.9, b2: float = 0.999, eps: float = 1e-16, eps_root: float = 1e-16) optax.GradientTransformation[来源]#
根据 ACProp(AdaBelief 的异步版本)更新缩放。
有关详细信息,请参见
optax.contrib.acprop()。- Parameters:
b1 – 指数加权平均梯度的衰减率。
b2 – 递减率,用于加权平均的方差的指数加权。
eps – 添加到分母的术语,以提高数值稳定性。
eps_root – 添加到预测误差的二阶矩上的项,以提高数值稳定性。如果通过梯度变换反向传播梯度(例如用于元学习),则必须为非零。
- Returns:
一个 GradientTransformation 对象。
复值优化#
- optax.contrib.split_real_and_imaginary(inner: optax.GradientTransformation) optax.GradientTransformation[来源]#
将复数更新的实部和虚部分成两个。
内部转换处理实际参数和更新,并将转化后的实际更新对合并为复杂更新。
在拆分之前真实的参数和更新会原封不动地传递。
- Parameters:
inner – 内部转换。
- Returns:
一个 optax.GradientTransformation。
持续投币赌博#
- optax.contrib.cocob(learning_rate: base.ScalarOrSchedule = 1.0, alpha: float = 100, eps: float = 1e-08, weight_decay: float = 0, mask: Any | Callable[[base.Params], Any] | None = None) 基础.渐变变换[来源]#
根据连续币投注算法更新重新缩放。
随机子梯度下降的算法。使用赌博算法通过访问非光滑目标函数的子梯度来找到最小化器。我们所需要的只是一个好的赌博策略。参见:
- Parameters:
learning_rate – 可选的学习率,例如用于注入某些调度器
alpha – COCOB优化器的投注参数的分数
eps – 抖动项,用于避免除以0
weight_decay – L2惩罚
mask – 权重衰减的掩码
- Returns:
一个 GradientTransformation 对象。
参考文献
Orabana 等, 通过下注游戏训练深度网络而不需要学习率, 2017
D-适应#
- optax.contrib.dadapt_adamw(learning_rate: base.ScalarOrSchedule = 1.0, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-08, estim_lr0: float = 1e-06, weight_decay: float = 0.0) 基础.渐变变换[来源]#
通过 D-自适应的学习率自由 AdamW。
通过估计无限范数下到解决方案的初始距离,自动调整AdamW的基线学习率。此方法在与将1.0视为基础(通常是最大)值的学习率调度结合使用时效果最佳。
- Parameters:
learning_rate – 学习率调度参数。推荐的调度是 linear_schedule,初始值为 1.0,结束值为 0,结合 0-20% 的学习率预热。
betas – AdamW优化器的贝塔值。
eps – 基础 AdamW 优化器的 eps。
estim_lr0 – 学习率的初始(低估)值。
weight_decay – AdamW 风格的权重衰减。要使用常规 Adam 衰减,请与 add_decayed_weights 链接。
- Returns:
参考文献
Defazio等人,无学习率学习通过D适应,2023
差分隐私聚合#
- optax.contrib.differentially_private_aggregate(l2_norm_clip: float, noise_multiplier: float, seed: int) optax.GradientTransformation[来源]#
基于DPSGD算法聚合梯度。
- Parameters:
l2_norm_clip – 每个示例梯度的最大 L2 范数。
noise_multiplier – 标准差与裁剪范数的比率。
seed – jax.random.PRNGKey使用的初始种子
- Returns:
参考文献
Abadi et al, 2016 差分隐私的深度学习, 2016
警告
与其他变换不同,differentially_private_aggregate 期望输入更新在第零轴上具有批处理维度。也就是说,这个函数期望每个示例的梯度作为输入(在JAX中使用jax.vmap很容易获得)。只要它是链中的第一个,它仍然可以与其他变换组合使用。
- class optax.contrib.DifferentiallyPrivateAggregateState(rng_key: Any)[来源]#
包含用于 differentially_private_aggregate 的 PRNGKey 的状态。
- optax.contrib.dpsgd(learning_rate: base.ScalarOrSchedule, l2_norm_clip: float, noise_multiplier: float, seed: int, momentum: float | None = None, nesterov: bool = False) 基础.渐变变换[来源]#
DPSGD优化器。
差分隐私是对从包含潜在敏感信息的汇总数据库中学习的算法的隐私保障标准。DPSGD为面对完全了解训练机制和访问模型参数的强大对手提供保护。
- Parameters:
learning_rate – 一个固定的全局缩放因子。
l2_norm_clip – 每个示例梯度的最大L2范数。
noise_multiplier – 标准差与裁剪规范的比率。
seed – 用于 jax.random.PRNGKey 的初始种子
动量 – 动量项使用的衰减率,当它被设置为 None 时,动量根本不被使用。
nesterov – 是否使用Nesterov动量。
- Returns:
参考文献
Abadi et al, 2016 差分隐私的深度学习, 2016
警告
这个
optax.GradientTransformation期望输入更新在第0轴上具有批次维度。也就是说,这个函数期望每个示例的梯度作为输入(在JAX中使用 jax.vmap 这很容易获得)。
沿梯度的距离#
- optax.contrib.dog(learning_rate: base.ScalarOrSchedule = 1.0, reps_rel: float = 1e-06, eps: float = 1e-08, init_learning_rate: float | None = None, weight_decay: float | None = None, mask: Any | Callable[[base.Params], Any] | None = None)[来源]#
基于梯度的优化器。
DoG 根据更新规则使用随机梯度 \(g_t\) 更新参数 \(w_t\):
\[\begin{align*} \eta_t &= \frac{\max_{i\le t}{\|x_i-x_0\|}}{ \sqrt{\sum_{i\le t}{\|g_i\|^2+eps}}}\\ x_{t+1} & = x_{t} - \eta_t\, g_t, \end{align*}\]- Parameters:
learning_rate – 可选的学习率(可能根据某些预定的调度器而变化)。
reps_rel – 用于计算初始距离的值 (论文中的 r_epsilon)。具体来说,第一步的大小由以下公式给出: (reps_rel * (1+|x_0|)) / (|g_0|^2 + eps)^{1/2} 其中 x_0 是 模型的初始权重(或参数组),而 g_0 是第一步的梯度。 正如论文中讨论的,这个值应该足够小,以确保 第一步更新将足够小,以避免模型发散。 建议值为 1e-6,除非模型使用批量归一化, 在这种情况下,建议值为 1e-4。
eps – 用于数值稳定性的epsilon - 添加到梯度的平方范数之和。
init_learning_rate – 如果指定,此值将用于初始学习率 (即第一步的大小),而不是上述与 reps_rel 描述的规则。
weight_decay – 权重衰减正则化的强度。
mask – 一棵与 params PyTree 结构相同(或是其前缀)的树,或者是一个返回这样的 pytree 的可调用对象,给定参数/更新。 树叶应该是布尔值,True 表示要应用权重衰减的叶子/子树,False 表示要跳过的那些。注意,梯度变换会应用于所有参数。
- Returns:
示例
>>> import optax >>> from optax import contrib >>> import jax >>> import jax.numpy as jnp >>> def f(x): return jnp.sum(x ** 2) # simple quadratic function >>> solver = contrib.dog() >>> params = jnp.array([1., 2., 3.]) >>> print('Objective function: ', f(params)) Objective function: 14.0 >>> opt_state = solver.init(params) >>> for _ in range(5): ... value, grad = jax.value_and_grad(f)(params) ... updates, opt_state = solver.update( ... grad, opt_state, params, value=value) ... params = optax.apply_updates(params, updates) ... print('Objective function: ', f(params)) Objective function: 13.99... Objective function: 13.99... Objective function: 13.99... Objective function: 13.99... Objective function: 13.99...
参考文献
Ivgi 等人, DoG is SGD’s Best Friend: A Parameter-Free Dynamic Step Size Schedule, 2023.
在版本 0.2.3 中添加。
- class optax.contrib.DoGState(first_step: jax.Array, init_params: chex.ArrayTree, estim_dist: jax.Array, sum_sq_norm_grads: jax.Array)[来源]#
DoG优化器的状态。
- optax.contrib.dowg(learning_rate: base.ScalarOrSchedule = 1.0, init_estim_sq_dist: float | None = None, eps: float = 0.0001, weight_decay: float | None = None, mask: Any | Callable[[base.Params], Any] | None = None)[来源]#
加权梯度优化器的距离。
示例
>>> import optax >>> from optax import contrib >>> import jax >>> import jax.numpy as jnp >>> def f(x): return jnp.sum(x ** 2) # simple quadratic function >>> solver = contrib.dowg() >>> params = jnp.array([1., 2., 3.]) >>> print('Objective function: ', f(params)) Objective function: 14.0 >>> opt_state = solver.init(params) >>> for _ in range(5): ... value, grad = jax.value_and_grad(f)(params) ... updates, opt_state = solver.update( ... grad, opt_state, params, value=value) ... params = optax.apply_updates(params, updates) ... print('Objective function: ', f(params)) Objective function: 13.925367 Objective function: 13.872763 Objective function: 13.775433 Objective function: 13.596172 Objective function: 13.268837
参考文献
Khaled 等, DoWG Unleashed: 一种高效的通用无参数梯度下降方法, 2023.
- Parameters:
learning_rate – 可选的学习率(可能根据某些预定的调度器而变化)。
init_estim_sq_dist – 到解的平方距离的初始猜测。
eps – 小值,用于防止分母中的除以零,定义学习率,且当
init_estim_sq_dist为 None 时也用作解决方案的初始猜测距离。weight_decay – 权重衰减正则化的强度。
mask – 一棵具有与 params PyTree 相同结构(或前缀)的树,或者一个在给定参数/更新时返回此类 pytree 的可调用对象。叶子应该是布尔值,True 表示要应用权重衰减的叶子/子树,而 False 表示要跳过的那些。请注意,梯度变换应用于所有参数。
- Returns:
在版本 0.2.3 中添加。
机械装置#
- optax.contrib.mechanize(base_optimizer: optax.GradientTransformation, weight_decay: float = 0.01, eps: float = 1e-08, s_init: float = 1e-06, num_betas: int = 6) optax.GradientTransformation[来源]#
机械师 - 一个黑箱学习率调整器/优化器。
累积由 base_optimizer 返回的更新,并学习每次迭代应用的更新规模(也称为学习率或步长)。
请注意,Mechanic 并不回避学习率调度的需求,您可以自由地应用学习率调度,基础学习率设置为 1.0(或任何其他常数),而 Mechanic 将自动学习正确的比例因子。
例如,将此更改为:
learning_rate_fn = optax.warmup_cosine_decay_schedule(peak_value=tuned_lr) optimizer = optax.adam(learning_rate_fn)
收件人:
learning_rate_fn = optax.warmup_cosine_decay_schedule(peak_value=1.0) optimizer = optax.adam(learning_rate_fn) optimizer = optax.contrib.mechanize(optimizer)
截至2023年6月,Mechanic已使用SGD、Momentum、Adam和Lion作为内部优化器进行测试,但我们预计它能与几乎所有一阶优化器(除了像LARS或LAMB这样的归一化梯度优化器)配合工作。
- Parameters:
base_optimizer – 计算更新的基础优化器。
weight_decay – 一个标量权重衰减率。请注意,这个权重衰减与您为base_optimizer使用的权重衰减不同。除了有时帮助更快收敛外,这还有助于Mechanic减少使用不同随机种子的训练运行之间的方差。您可能不需要调整这个,默认值在大多数情况下应该能够正常工作。
eps – 机械的epsilon。
s_init – 初始缩放因子。默认值几乎在所有情况下都能正常工作。
num_betas – 与传统的指数累加器(如adam的第一或第二矩)不同,在这些累加器中需要选择一个明确的beta,而mechanic则有一个巧妙的方法来自动学习所有累加器的正确beta。我们只提供可能的beta范围,而不是调优的值。例如,如果将num_betas设置为3,它将使用betas = [0.9, 0.99, 0.999]。
- Returns:
参考文献
Cutkosky 等人, Mechanic: A Learning Rate Tuner 2023
Momo#
- optax.contrib.momo(learning_rate: base.ScalarOrSchedule = 1.0, beta: float = 0.9, lower_bound: float = 0.0, weight_decay: float = 0.0, adapt_lower_bound: bool = False) 基础.渐变变换额外参数[来源]#
带动量的SGD自适应学习率。
MoMo 通常在
learning_rate的值上需要更少的调整,利用了已知损失的下限(或最优值)的事实。对于大多数任务,零是一个下限,也是最终损失的准确估计。MoMo执行带动量的随机梯度下降(SGD),使用聚类型学习率。有效步长为
min(learning_rate, <adaptive term>),其中自适应项是实时计算的。请注意,需要通过关键字参数
value将最新的(批次)损失值传递给更新函数。- Parameters:
learning_rate – 用户指定的学习率。建议选择较大的值,默认值为1.0。
beta – 动量系数(用于EMA)。
lower_bound – 损失的下界。对于许多任务,零应该是一个不错的选择。
weight_decay – 权重衰减参数。
adapt_lower_bound – 如果没有合适的下界猜测可用,请将其设置为 true,以便动态估计下界(有关详细信息,请参阅论文)。
- Returns:
一个
optax.GradientTransformation对象。
示例
>>> from optax import contrib >>> import jax >>> import jax.numpy as jnp >>> def f(x): return jnp.sum(x ** 2) # simple quadratic function >>> solver = contrib.momo() >>> params = jnp.array([1., 2., 3.]) >>> print('Objective function: ', f(params)) Objective function: 14.0 >>> opt_state = solver.init(params) >>> for _ in range(5): ... value, grad = jax.value_and_grad(f)(params) ... params, opt_state = solver.update(grad, opt_state, params, value=value) ... print('Objective function: ', f(params)) Objective function: 3.5 Objective function: 0.0 Objective function: 0.0 Objective function: 0.0 Objective function: 0.0
参考文献
Schaipp 等人,MoMo: 动量模型用于自适应学习率,2023
在版本 0.2.3 中添加。
- class optax.contrib.MomoState(exp_avg: optax.Updates, barf: chex.Array, gamma: chex.Array, lb: chex.Array, count: chex.Array)[来源]#
momo 返回的 GradientTransformation 的状态。
- optax.contrib.momo_adam(learning_rate: base.ScalarOrSchedule = 0.01, b1: float = 0.9, b2: float = 0.999, eps: float = 1e-08, lower_bound: float = 0.0, weight_decay: float = 0.0, adapt_lower_bound: bool = False) 基础.渐变变换额外参数[来源]#
自适应学习率用于Adam(W)。
MoMo-Adam 通常需要更少的调整对于
learning_rate的值,利用了已知损失的下界(或最优值)的事实。对于大多数任务,零是下界,并且是最终损失的准确估计。MoMo执行带有Polyak类型学习率的Adam(W)。有效步长为
min(learning_rate, <adaptive term>),自适应项是即时计算的。请注意,需要通过关键字参数
value将最新的(批次)损失值传递给更新函数。- Parameters:
learning_rate – 用户指定的学习率。建议选择较大的值,默认值为1.0。
b1 – 指数衰减率,用于跟踪过去梯度的第一个时刻。
b2 – 指数衰减率,用于跟踪过去梯度的第二矩。
eps – 基础Adam优化器的 eps。
lower_bound – 损失的下界。对于许多任务,零应该是一个不错的选择。
weight_decay – 权重衰减参数。Momo-Adam以类似于AdamW的方式执行权重衰减。
adapt_lower_bound – 如果没有合适的下界猜测可用,请将其设置为 true,以便动态估计下界(有关详细信息,请参阅论文)。
- Returns:
A
GradientTransformation对象。
示例
>>> from optax import contrib >>> import jax >>> import jax.numpy as jnp >>> def f(x): return jnp.sum(x ** 2) # simple quadratic function >>> solver = contrib.momo_adam() >>> params = jnp.array([1., 2., 3.]) >>> print('Objective function: ', f(params)) Objective function: 14.0 >>> opt_state = solver.init(params) >>> for _ in range(5): ... value, grad = jax.value_and_grad(f)(params) ... params, opt_state = solver.update(grad, opt_state, params, value=value) ... print('Objective function: ', f(params)) Objective function: 0.00029999594 Objective function: 0.0 Objective function: 0.0 Objective function: 0.0 Objective function: 0.0
参考文献
Schaipp 等人,MoMo: 动量模型用于自适应学习率,2023
在版本 0.2.3 中添加。
缪子#
- optax.contrib.muon(learning_rate: base.ScalarOrSchedule, ns_coeffs: tuple[float, float, float] | tuple[tuple[float, float, float], ...] = (3.4445, -4.775, 2.0315), ns_steps: int = 5, beta: float = 0.95, eps: float = 1e-08, mu_dtype: chex.ArrayDType | None = None, *, nesterov: bool = True, adaptive: bool = False, adam_b1: float = 0.9, adam_b2: float = 0.999, adam_eps_root: float = 0.0, adam_weight_decay: float = 0.0) 基础.渐变变换[来源]#
穆昂:由牛顿-舒尔茨正交化的动量。
Muon是Shampoo的一种变体,它使用Newton-schulz方法来正交化优化器积累的动量。从数学上讲,它在某个大p下执行最速下降。在p=infty的情况下,它等同于不进行积累的Shampoo,或者在谱范数下的最速下降。
请注意,Muon 当前仅针对 2D 参数定义,即矩阵。这是因为 Newton-Schulz 迭代器期望一个矩阵作为输入。而非 2D 参数则通过 Adam 优化器传递。
- Parameters:
learning_rate – 一个全局缩放因子,可以是固定的,也可以随着 迭代使用调度器而变化,见
optax.scale_by_learning_rate().ns_coeffs – 牛顿-舒尔茨方法的系数。
ns_steps - 牛顿-舒尔茨迭代次数。 如果 ns_coeffs 是一个元组的元组,则忽略此参数。
beta – 指数加权平均梯度的衰减率。
eps – 添加到分母的术语,以提高数值稳定性。
mu_dtype – 动量累加器的数据类型。
nesterov – 是否使用Nesterov动量。
自适应 – 是否通过原始更新的对偶范数来缩放更新。请参阅 <https://arxiv.org/abs/2409.20325>
adam_b1 – Adam 第一个时刻估计的指数衰减率。
adam_b2 – 亚当第二矩估计的指数衰减率。
adam_eps_root – 用于稳定Adam中除法的Epsilon,平方根版本。
adam_weight_decay – Adam的权重衰减因子。
- Returns:
相应的 GradientTransformation。
参考文献
乔丹, modded-nanogpt: Speedrunning the NanoGPT baseline,2024
伯恩斯坦等,旧优化器,新常态:选集,2024
- optax.contrib.scale_by_muon(ns_coeffs: tuple[float, float, float] | tuple[tuple[float, float, float], ...] = (3.4445, -4.775, 2.0315), ns_steps: int = 5, beta: float = 0.95, eps: float = 1e-08, mu_dtype: str | type[Any] | 数据类型 | SupportsDType | None = None, *, nesterov: bool = True, adaptive: bool = False) optax.GradientTransformation[来源]#
根据Muon算法更新重新缩放。
Muon是Shampoo的一种变体,它使用Newton-schulz方法来正交化优化器积累的动量。从数学上讲,它在某个大p下执行最速下降。在p=infty的情况下,它等同于不进行积累的Shampoo,或者在谱范数下的最速下降。
- Parameters:
ns_coeffs – 牛顿-舒尔茨方法的系数。
ns_steps - 牛顿-舒尔茨迭代次数。 如果 ns_coeffs 是一个元组的元组,则忽略此参数。
beta – 指数加权平均梯度的衰减率。
eps – 添加到分母以提高数值稳定性的术语。
mu_dtype – 动量累加器的数据类型。
nesterov – 是否使用Nesterov动量。
自适应 – 是否通过原始更新的对偶范数来缩放更新。请参阅 <https://arxiv.org/abs/2409.20325>
- Returns:
一个 GradientTransformation 对象。
参考文献
乔丹, modded-nanogpt: Speedrunning the NanoGPT baseline,2024
伯恩斯坦等,旧优化器,新常态:选集,2024
天才#
- optax.contrib.prodigy(learning_rate: base.ScalarOrSchedule = 1.0, betas: tuple[float, float] = (0.9, 0.999), beta3: float | None = None, eps: float = 1e-08, estim_lr0: float = 1e-06, estim_lr_coef: float = 1.0, weight_decay: float = 0.0, safeguard_warmup: bool = False) 基础.渐变变换[来源]#
使用Prodigy的无学习率AdamW。
来自“Prodigy: An Expeditiously Adaptive Parameter-Free Learner”的Prodigy方法的实现,这是D-Adapt AdamW的一个版本,通过使用对梯度的加权,使基线学习率更快地适应,赋予较新的梯度更高的权重。结合将1.0视为基准(通常为最大)值的学习率调度时,这种方法效果最好。
- Parameters:
learning_rate – 学习率调度参数。推荐的调度是 linear_schedule,初始值为 1.0,结束值为 0,结合 0-20% 的学习率预热。
betas – AdamW优化器的贝塔值。
beta3 – D的估计的可选动量参数。
eps – 基础 AdamW 优化器的 eps。
estim_lr0 – 学习率的初始(低估)值。
estim_lr_coef – LR估计值乘以此参数。
weight_decay – AdamW 风格的权重衰减。要使用常规 Adam 衰减,请与 add_decayed_weights 链接。
safeguard_warmup – 从D估计的分母中移除lr,以避免在热身阶段出现问题。默认关闭。
- Returns:
一个
optax.GradientTransformation对象。
参考文献
Mishchenko et al, Prodigy: An Expeditiously Adaptive Parameter-Free Learner, 2023
无需安排#
- optax.contrib.schedule_free(base_optimizer: 基础.渐变变换, learning_rate: base.ScalarOrSchedule, b1: float = 0.9, weight_lr_power: float = 2.0, state_dtype: jax.typing.DTypeLike | None = None) 基础.渐变变换额外参数[来源]#
将 base_optimizer 设置为 schedule_free。
累积由 base_optimizer 返回的更新,无需动量,并用插值和平均的组合替换底层优化器的动量。在梯度下降的情况下,更新为
\[\begin{align*} y_{t} & = (1-\beta_1)z_{t} + \beta_1 x_{t},\\ z_{t+1} & =z_{t}-\gamma\nabla f(y_{t}),\\ x_{t+1} & =\left(1-\frac{1}{t}\right)x_{t}+\frac{1}{t}z_{t+1}, \end{align*}\]这里 \(x\) 是测试/验证损失评估应该发生的序列,它与主要迭代 \(z\) 和梯度评估位置 \(y\) 不同。对 \(z\) 的更新对应于底层优化器,在这种情况下是一个简单的梯度步骤。请注意, \(\beta_1\) 对应于代码中的 b1。
顾名思义,无调度学习不需要逐渐减少的学习率调度,但通常表现优于或至多匹配SOTA调度,例如余弦衰减和线性衰减。一次只需存储两个序列(第三个可以实时从另外两个计算出来),因此此方法的内存需求与基础优化器相同(参数缓冲区 + 动量)。
在实践中,作者建议针对每个问题单独调整 \(\beta_1\)、warmup_steps 和 peak_lr。\(\beta_1\) 的默认值为 0.9,但 0.95 和 0.98 也可能表现良好。无调度可以在任何 optax 优化器之上包装。在测试时,应该使用
optax.contrib.schedule_free_eval_params()来评估参数,如下所示。例如,将此更改为:
learning_rate_fn = optax.warmup_cosine_decay_schedule(peak_value=tuned_lr) optimizer = optax.adam(learning_rate_fn, b1=b1)
收件人:
learning_rate_fn = optax.warmup_constant_schedule(peak_value=retuned_lr) optimizer = optax.adam(learning_rate_fn, b1=0.) optimizer = optax.contrib.schedule_free(optimizer, learning_rate_fn, b1=b1) .. params_for_eval = optax.contrib.schedule_free_eval_params(state, params)
尤其要注意,关闭基础优化器的动量是很重要的。截止到2024年4月,schedule_free已使用SGD和Adam进行测试。
- Parameters:
base_optimizer – 计算更新的基础优化器。
learning_rate – 不带衰减的学习率调度,但带有预热。
b1 – y 更新中的 beta_1 参数。
weight_lr_power – 我们使用这个来降低平均权重。这个在热身期间的早期迭代中特别有帮助。
state_dtype – 在调度自由方法中用于 z 序列的 dtype。
- Returns:
参考文献
Defazio 等, The Road Less Scheduled, 2024
Defazio 等人, Schedule-Free Learning - A New Way to Train, 2024
警告
当前的实现要求参数
b1必须为正数。
- optax.contrib.schedule_free_adamw(learning_rate: float = 0.0025, warmup_steps: int | None = None, b1: float = 0.9, b2: float = 0.999, eps: float = 1e-08, weight_decay: float = 0.0, weight_lr_power: float = 2.0, state_dtype: jax.typing.DTypeLike | None = None) 基础.渐变变换额外参数[来源]#
无调度的AdamW封装类。
使用AdamW的schedule_free的快捷示例,这是一个常见的使用案例。请注意,这只是一个示例,还有其他使用案例,例如使用权重衰减掩码、Nesterov等。还请注意,schedule free方法的EMA参数(b1)必须是严格正的。
- Parameters:
learning_rate – AdamW 学习率。
warmup_steps – 正整数,线性预热的长度。
b1 – y 更新中的 beta_1 参数。
b2 – 指数衰减率,用于跟踪过去梯度的第二矩。
eps – 在平方根外应用于分母的小常数(如在Adam论文中)以避免在重新缩放时除以零。
weight_decay – 权重衰减正则化的强度。
weight_lr_power – 我们使用这个来降低平均权重。这个在热身期间的早期迭代中特别有帮助。
state_dtype – 在调度自由方法中用于 z 序列的 dtype。
- Returns:
示例
>>> import optax >>> import jax >>> import jax.numpy as jnp >>> def f(x): return jnp.sum(x ** 2) # simple quadratic function >>> solver = optax.contrib.schedule_free_adamw(1.0) >>> params = jnp.array([1., 2., 3.]) >>> print('Objective function: ', f(params)) Objective function: 14.0 >>> opt_state = solver.init(params) >>> for _ in range(5): ... grad = jax.grad(f)(params) ... updates, opt_state = solver.update(grad, opt_state, params) ... params = optax.apply_updates(params, updates) ... eval_params = optax.contrib.schedule_free_eval_params( ... opt_state, params) ... print('Objective function: {:.2E}'.format(f(eval_params))) Objective function: 5.00E+00 Objective function: 3.05E+00 Objective function: 1.73E+00 Objective function: 8.94E-01 Objective function: 4.13E-01
注意
Note that
optax.scale_by_adam()withb1=0stores in its state an unused first moment always equal to zero. To avoid this waste of memory, we replaceoptax.scale_by_adam()withb1=0by the equivalentoptax.scale_by_rms()witheps_in_sqrt=False, bias_correction=True.
- optax.contrib.schedule_free_eval_params(state: optax.OptState, params: optax.Params)[来源]#
用于评估
optax.contrib.schedule_free()的参数。
- optax.contrib.schedule_free_sgd(learning_rate: float = 1.0, warmup_steps: int | None = None, b1: float = 0.9, weight_decay: float | None = None, weight_lr_power: float = 2.0, state_dtype: jax.typing.DTypeLike | None = None) 基础.渐变变换额外参数[来源]#
无调度的SGD包装器。
使用SGD的schedule_free的快捷示例,这是一个常见的用例。 注意,这只是一个示例,其他用例也是可能的,例如使用权重衰减掩码。 还要注意,schedule free方法的EMA参数(b1)必须严格为正。
- Parameters:
learning_rate – SGD 学习率。
warmup_steps – 正整数,线性预热的长度。
b1 – y 更新中的 beta_1 参数。
weight_decay – 权重衰减正则化的强度。请注意,这个权重衰减是与学习率相乘的。这与其他框架如PyTorch是一致的,但与(Loshchilov et al, 2019)不同,在该框架中,权重衰减仅与“调度乘子”相乘,而不是与基本学习率相乘。
weight_lr_power – 我们使用这个来降低平均权重。这个在热身期间的早期迭代中特别有帮助。
state_dtype – 在调度自由方法中用于 z 序列的 dtype。
- Returns:
示例
>>> import optax >>> import jax >>> import jax.numpy as jnp >>> def f(x): return jnp.sum(x ** 2) # simple quadratic function >>> solver = optax.contrib.schedule_free_sgd() >>> params = jnp.array([1., 2., 3.]) >>> print('Objective function: ', f(params)) Objective function: 14.0 >>> opt_state = solver.init(params) >>> for _ in range(5): ... grad = jax.grad(f)(params) ... updates, opt_state = solver.update(grad, opt_state, params) ... params = optax.apply_updates(params, updates) ... eval_params = optax.contrib.schedule_free_eval_params( ... opt_state, params) ... print('Objective function: {:.2E}'.format(f(eval_params))) Objective function: 1.40E+01 Objective function: 1.75E-14 Objective function: 9.96E-01 Objective function: 8.06E-01 Objective function: 2.41E-01
索非亚#
- optax.contrib.hutchinson_estimator_diag_hessian(random_seed: 数组 | None = None)[来源]#
返回一个计算Hessian对角线的GradientTransformation。
海森对角线是使用哈钦森估计量估计的,该估计量是无偏的,但具有高方差。
- Parameters:
random_seed – 生成随机向量的密钥。
- Returns:
梯度变换额外参数
- optax.contrib.sophia(learning_rate: base.ScalarOrSchedule, b1: float = 0.965, b2: float = 0.99, eps: float = 1e-08, weight_decay: float = 0.0001, weight_decay_mask: Optional[Union[Any, Callable[[base.Params], Any]]] = None, gamma: float = 0.01, clip_threshold: Optional[float] = 1.0, update_interval: int = 10, hessian_diagonal_fn: Union[base.GradientTransformation, base.GradientTransformationExtraArgs] = (<function hutchinson_estimator_diag_hessian.<locals>.init_fn>, <function hutchinson_estimator_diag_hessian.<locals>.update_fn>), mu_dtype: Optional[Any] = None, verbose: bool = False, print_win_rate_every_n_steps: int = 0) 基础.渐变变换额外参数[来源]#
Sophia 优化器。
通过参数 hessian_diagonal_fn 需要一个单独的 GradientTransformation 来计算 Hessian 的对角线。hessian_diagonal_fn 的更新函数所需的任何额外参数可以通过 sophia 的更新函数作为尾部关键字参数 (**kwargs) 传递。默认的 hessian_diagonal_fn 是 Hutchinson 的估计器,需要目标函数作为额外参数 obj_fn。obj_fn 必须接受 params 作为其唯一参数,并且仅返回一个标量(损失)。
例如,假设你的实验损失函数是 loss_fn(params, batch) -> loss, aux,它接受多个参数并且 返回多个输出,我们必须将其修改为 loss_fn(params) -> loss:
obj_fn = lambda params: loss_fn(params, batch)[0]
其中 batch 是当前步骤的批次。
然后它可以被传递给sophia的更新函数(这个函数会将它传递给hessian_diagonal_fn的更新函数):
更新,状态 = sophia.update(更新,状态,参数,obj_fn=sophia_obj_fn)
可选地,您可以编写自己的GradientTransformation来计算海森矩阵对角线。使用此文件的hutchinson_estimator_diag_hessian函数作为示例。如果您使用多个设备,请确保海森矩阵对角线函数正确地在设备之间平均海森矩阵对角线。默认的hessian_diagonal_fn不执行此操作,例如,如果使用pmap,会导致参数在设备之间彼此发散。
- Parameters:
learning_rate – 一个全局缩放因子,可以是固定的,也可以随着 迭代使用调度器而变化,见
optax.scale_by_learning_rate().b1 – 第一个时刻估计的指数衰减率。
b2 – hessian 对角线估计的指数衰减率。请记住,有效的 b2 是 1 - (1 - b2) / update_interval,例如,默认的 b2 为 0.99,实际上是 0.999,因为默认的 update_interval 是每 10 次。
eps – 避免除以零的小常数。
weight_decay – 权重衰减的速率。
weight_decay_mask – 一个与params PyTree结构相同(或其前缀)的树,或者一个返回这样的pytree的可调用对象,给定params/updates。叶子应该是布尔值,True表示您想对其应用转换的叶子/子树,而False表示您想跳过的那些。
gamma – 赫西恩对角线的归一化常数。
clip_threshold – 更新剪切的阈值。
update_interval – 更新海森对角线的间隔。
hessian_diagonal_fn – 计算Hessian对角线的GradientTransformation。默认是Hutchinson的估计器(sophia-h)。如果使用多个设备,请确保此函数正确地在设备之间平均Hessian对角线。
mu_dtype – 第一个矩估计的dtype。
verbose – 如果为真, 每n步打印一次胜率。
print_win_rate_every_n_steps – 每n步打印sophia的胜率,用于诊断目的。作者指出该值在训练期间应保持在0.1和0.5之间。如果胜率太低,请尝试增加gamma。0表示关闭。
- Returns:
optax.GradientTransformationExtraArgs
参考文献
Liu et al., Sophia: 一种可扩展的随机二阶优化器用于语言模型预训练, 2023
注意
我们使用一个Rademacher向量来估计Hessian的对角线, 与原始实现使用的正态随机向量相反。