优化器#

adabelief(learning_rate[, b1, b2, eps, ...])

AdaBelief 优化器。

adadelta([learning_rate, rho, eps, ...])

Adadelta 优化器。

adan(learning_rate[, b1, b2, b3, eps, ...])

自适应Nesterov动量算法(Adan)。

adafactor([学习率, ...])

Adafactor 优化器。

adagrad(学习率[, ...])

Adagrad 优化器。

adam(学习率[, b1, b2, eps, eps_root, ...])

Adam 优化器。

adamw(learning_rate[, b1, b2, eps, ...])

带权重衰减正则化的Adam。

adamax(学习率[, b1, b2, eps])

一种使用无穷范数的Adam优化器变体。

adamaxw(学习率[, b1, b2, eps, ...])

具有权重衰减正则化的Adamax。

amsgrad(学习率[, b1, b2, eps, ...])

AMSGrad优化器。

fromage(learning_rate[, min_norm])

Frobenius 匹配梯度下降(Fromage)优化器。

lamb(学习率[, b1, b2, eps, eps_root, ...])

LAMB 优化器。

lars(学习速率[, 权重衰减, ...])

LARS优化器。

lbfgs([学习率, 记忆大小, ...])

L-BFGS优化器。

lion(learning_rate[, b1, b2, mu_dtype, ...])

狮子优化器。

nadam(学习率[, b1, b2, eps, ...])

NAdam优化器。

nadamw(学习率[, b1, b2, eps, ...])

NAdamW 优化器,作为 AdamW 优化器的一部分实现。

noisy_sgd(learning_rate[, eta, gamma, seed])

一种带有附加噪声的SGD变体。

novograd(learning_rate[, b1, b2, eps, ...])

NovoGrad 优化器。

optimistic_gradient_descent(learning_rate[, ...])

一种乐观的梯度下降优化器。

optimistic_adam(学习率[, 乐观, ...])

乐观的Adam优化器。

polyak_sgd([max_learning_rate, scaling, ...])

带有Polyak步长的随机梯度下降。

radam(学习率[, b1, b2, eps, ...])

修正的Adam优化器。

rmsprop(学习率[, 衰减, 小量, ...])

一个灵活的 RMSProp 优化器。

sgd(学习率[, 动量, Nesterov, ...])

一个经典的随机梯度下降优化器。

sign_sgd(learning_rate)

一种仅使用梯度分量符号的SGD变体。

sm3(学习率[, 动量])

SM3优化器。

yogi(学习率[, b1, b2, eps])

Yogi优化器。

AdaBelief#

optax.adabelief(learning_rate: base.ScalarOrSchedule, b1: float = 0.9, b2: float = 0.999, eps: float = 1e-16, eps_root: float = 1e-16, *, nesterov: bool = False) 基础.渐变变换[来源]#

AdaBelief优化器。

AdaBelief 是一种自适应学习率优化器,关注快速收敛、泛化和稳定性。它根据对梯度方向的“信念”调整步长——优化器根据预测梯度和观察到的梯度之间的差异自适应地缩放步长。AdaBelief 是 optax.adam() 的一个修改版本,包含相同数量的参数。

Let \(\alpha_t\) represent the learning rate and \(\beta_1, \beta_2\), \(\varepsilon\), \(\bar{\varepsilon}\) represent the arguments b1, b2, eps and eps_root respectively. The learning rate is indexed by \(t\) since the learning rate may also be provided by a schedule function.

The init function of this optimizer initializes an internal state \(S_0 := (m_0, s_0) = (0, 0)\), representing initial estimates for the first and second moments. In practice these values are stored as pytrees containing all zeros, with the same shape as the model updates. At step \(t\), the update function of this optimizer takes as arguments the incoming gradients \(g_t\) and optimizer state \(S_t\) and computes updates \(u_t\) and new state \(S_{t+1}\). Thus, for \(t > 0\), we have,

\[\begin{align*} m_t &\leftarrow \beta_1 \cdot m_{t-1} + (1-\beta_1) \cdot g_t \\ s_t &\leftarrow \beta_2 \cdot s_{t-1} + (1-\beta_2) \cdot (g_t - m_t)^2 + \bar{\varepsilon} \\ \hat{m}_t &\leftarrow m_t / {(1-\beta_1^t)} \\ \hat{s}_t &\leftarrow s_t / {(1-\beta_2^t)} \\ u_t &\leftarrow -\alpha_t \cdot \hat{m}_t / \left(\sqrt{\hat{s}_t} + \varepsilon \right) \\ S_t &\leftarrow (m_t, s_t). \end{align*}\]

使用关键字参数 nesterov=True,优化器使用Nesterov动量,将上面的 \(\hat{m}_t\) 替换为

\[\hat{m}_t \leftarrow \beta_1 m_t / {(1-\beta_1^{t+1})} + (1 - \beta_1) g_t / {(1-\beta_1^t)}. \]
Parameters:
  • learning_rate – 一个全局缩放因子,可以是固定的,也可以随着 迭代使用调度器而变化,见 optax.scale_by_learning_rate().

  • b1 – 指数衰减率,用于跟踪过去梯度的第一个时刻。

  • b2 – 指数衰减率,用于跟踪过去梯度的第二矩。

  • eps – 添加到分母的术语,以提高数值稳定性。

  • eps_root – 添加到预测误差的第二矩的术语,以提高数值稳定性。如果通过梯度变换(例如用于元学习)反向传播梯度,则此值必须非零。

  • nesterov – 是否使用Nesterov动量。

Returns:

对应的 optax.GradientTransformation

示例

>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.adabelief(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...  grad = jax.grad(f)(params)
...  updates, opt_state = solver.update(grad, opt_state, params)
...  params = optax.apply_updates(params, updates)
...  print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.40E+01
Objective function: 1.39E+01
Objective function: 1.39E+01
Objective function: 1.38E+01
Objective function: 1.38E+01

参考文献

庄, AdaBelief Optimizer: Adapting Stepsizes by the Belief in Observed Gradients, 2020

AdaDelta#

optax.adadelta(learning_rate: base.ScalarOrSchedule | None = None, rho: float = 0.9, eps: float = 1e-06, weight_decay: float = 0.0, weight_decay_mask: MaskOrFn = None) 基础.渐变变换[来源]#

Adadelta 优化器。

Adadelta 是一种随机梯度下降方法,它根据梯度更新的移动窗口调整学习率。Adadelta 是对 Adagrad 的一种修改。

Parameters:
  • learning_rate – 一个全局缩放因子,可以是固定的,也可以随着 迭代使用调度器而变化,见 optax.scale_by_learning_rate().

  • rho – 用于计算平方梯度的滑动平均的系数。

  • eps – 添加到分母的术语,以提高数值稳定性。

  • weight_decay – 可选择的权重衰减速率。

  • weight_decay_mask – 一个与params PyTree结构相同(或其前缀)的树,或者一个返回这样的pytree的可调用对象,给定params/updates。叶子应该是布尔值,True表示您想对其应用转换的叶子/子树,而False表示您想跳过的那些。

Returns:

对应的 optax.GradientTransformation

示例

>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> f = lambda x: jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.adadelta(learning_rate=10.)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...  grad = jax.grad(f)(params)
...  updates, opt_state = solver.update(grad, opt_state, params)
...  params = optax.apply_updates(params, updates)
...  print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.36E+01
Objective function: 1.32E+01
Objective function: 1.29E+01
Objective function: 1.25E+01
Objective function: 1.21E+01

参考文献

Zeiler, Adadelta: An Adaptive Learning Rate Optimizer, 2012

阿丹#

optax.adan(learning_rate: base.ScalarOrSchedule, b1: float = 0.98, b2: float = 0.92, b3: float = 0.99, eps: float = 1e-08, eps_root: float = 1e-08, weight_decay: float = 0.0, mask: Any | Callable[[base.Params], Any] | None = None) 基础.渐变变换[来源]#

自适应Nesterov动量算法(Adan)。

Adan 首先重新定义了经典的 Nesterov 加速方法,开发了一种新的 Nesterov 动量估计 (NME) 方法,该方法避免了在外推点计算梯度的额外开销。然后,Adan 采用 NME 来估计自适应梯度算法中梯度的一阶和二阶矩,以加速收敛。

算法如下。首先,我们定义以下参数:

  • \(\eta > 0\): 步长。

  • \(\beta_1 \in [0, 1]\): 指数加权平均梯度的衰减率。

  • \(\beta_2 \in [0, 1]\): 指数加权平均梯度差的衰减率。

  • \(\beta_3 \in [0, 1]\): 指数加权平方项的衰减率。

  • \(\varepsilon > 0\): 一个用于数值稳定性的小常数。

  • \(\lambda > 0\): 权重衰减。

其次,我们定义以下变量:

  • \(\theta_t\): 参数。

  • \(g_t\): 进入的随机梯度。

  • \(m_t\): 梯度的指数加权平均。

  • \(v_t\): 梯度差异的指数加权平均。

  • \(n_t\): 平方项的指数加权平均。

  • \(u_t\): 外部更新向量。

  • \(S_t\): 优化器的保存状态。

第三,我们如下初始化这些变量:

  • \(m_0 = g_0\)

  • \(v_0 = 0\)

  • \(v_1 = g_1 - g_0\)

  • \(n_0 = g_0^2\)

最后,在每次迭代中,我们按如下方式更新变量:

\[\begin{align*} m_t &\gets (1 - \beta_1) m_{t-1} + \beta_1 g_t \\ v_t &\gets (1 - \beta_2) v_{t-1} + \beta_2 (g_t - g_{t-1}) \\ n_t &\gets (1 - \beta_3) n_{t-1} + \beta_3 (g_t + (1 - \beta_2) (g_t - g_{t-1}))^2 \\ \eta_t &\gets \eta / ({\sqrt{n_t + \bar{\varepsilon}} + \varepsilon}) \\ u_t &\gets (\theta_t - \eta_t \circ (m_t + (1 - \beta_2) v_t)) / (1 + \lambda \eta) \\ S_t &\leftarrow (m_t, v_t, n_t). \end{align*}\]
Parameters:
  • learning_rate – 这是一个固定的全局缩放因子。

  • b1 – 均值加权移动平均(EWMA)梯度的衰减率。

  • b2 – 较差的梯度的EWMA衰减率。

  • b3 – 算法平方项的EMWA的衰减率。

  • eps – 添加到分母的术语,以提高数值稳定性。

  • eps_root – 添加到平方根内分母的项,以提高在通过重缩放反向传播梯度时的数值稳定性。

  • weight_decay – 权重衰减正则化的强度。

  • mask – 一棵具有与参数 PyTree 相同结构(或其前缀)的树,或一个可以根据参数/更新返回该 pytree 的可调用对象。叶子节点应为布尔值,True 表示希望应用权重衰减的叶子/子树,而 False 表示希望跳过的那些。

Returns:

相应的 optax.GradientTransformation.

示例

>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> f = lambda x: x @ x  # simple quadratic function
>>> solver = optax.adan(learning_rate=1e-1)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...  grad = jax.grad(f)(params)
...  updates, opt_state = solver.update(grad, opt_state, params)
...  params = optax.apply_updates(params, updates)
...  print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.28E+01
Objective function: 1.17E+01
Objective function: 1.07E+01
Objective function: 9.68E+00
Objective function: 8.76E+00

参考文献

Xie et al, Adan: Adaptive Nesterov Momentum Algorithm for Faster Optimizing Deep Models, 2022

AdaGrad#

optax.adagrad(learning_rate: base.ScalarOrSchedule, initial_accumulator_value: float = 0.1, eps: float = 1e-07) 基础.渐变变换[来源]#

Adagrad 优化器。

AdaGrad 是一种用于随机优化的子梯度算法,它根据每个特征的梯度历史单独调整学习率。

更新后的参数采用如下形式:

\[w_{t+1}^{(i)} = w_{t}^{(i)} - \eta \frac{g_{t}^{(i)}} {\sqrt{\sum_{\tau=1}^{t} (g_{\tau}^{(i)})^2 + \epsilon}}\]
where:
  • \(w_t^{(i)}\) 是时间步 \(t\) 时的参数 \(i\)

  • \(\eta\) 是学习率,

  • \(g_t^{(i)}\) 是时间步 \(t\) 在参数 \(i\) 的梯度,

  • \(\epsilon\) 是一个小常数,用以确保数值稳定性。

定义 \(G = \sum_{t=1}^\tau g_t g_t^\top\),更新可以表示为

\[w_{t+1} = w_{t} - \eta \cdot \text{diag}(G + \epsilon I)^{-1/2} \cdot g_t\]

其中 \(\text{diag} (G) = (G_{ii})_{i=1}^p\)\(G \in \mathbb{R}^p\) 的对角元素的向量,\(I\)\(\mathbb{R}^p\) 中的单位矩阵。

Parameters:
  • learning_rate – 一个全局缩放因子,可以是固定的,也可以随着 迭代使用调度器而变化,见 optax.scale_by_learning_rate().

  • initial_accumulator_value – 积累器的初始值。

  • eps – 应用于平方根内部分母的小常数(如在RMSProp中)以避免在重新缩放时除以零。

Returns:

对应的 optax.GradientTransformation

示例

>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.adagrad(learning_rate=1.0)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...  grad = jax.grad(f)(params)
...  updates, opt_state = solver.update(grad, opt_state, params)
...  params = optax.apply_updates(params, updates)
...  print('Objective function: {:.2E}'.format(f(params)))
Objective function: 5.01E+00
Objective function: 2.40E+00
Objective function: 1.25E+00
Objective function: 6.86E-01
Objective function: 3.85E-01

参考文献

Duchi 等人,自适应子梯度方法用于在线学习和随机优化,2011

警告

Adagrad的主要限制是分母中平方梯度的单调累积:由于所有项都 >0,和在训练过程中不断增加,学习速率最终变得微乎其微。

AdaFactor#

optax.adafactor(learning_rate: Optional[base.ScalarOrSchedule] = None, min_dim_size_to_factor: int = 128, decay_rate: float = 0.8, decay_offset: int = 0, multiply_by_parameter_scale: float = True, clipping_threshold: Optional[float] = 1.0, momentum: Optional[float] = None, dtype_momentum: Any = <class 'jax.numpy.float32'>, weight_decay_rate: Optional[float] = None, eps: float = 1e-30, factored: bool = True, weight_decay_mask: MaskOrFn = None) 基础.渐变变换[来源]#

Adafactor优化器。

Adafactor 是一种自适应学习率优化器,专注于快速训练大规模神经网络。它通过使用一个分解的二阶矩估计来缩放梯度,从而节省内存。

Parameters:
  • learning_rate – 全局缩放因子,可以是固定的,也可以随着调度程序的迭代而变化,见 optax.scale_by_learning_rate()。请注意,Adafactor 的学习率的自然尺度与 Adam 明显不同,使用基于注意力的模型时,不需要为此优化使用 1/sqrt(hidden) 修正。

  • min_dim_size_to_factor – 仅在两个数组维度至少达到此大小时,才对统计数据进行因子化。

  • decay_rate – 控制二阶矩指数衰减计划。

  • decay_offset – 对于微调,可以将其设置为微调阶段的起始步骤编号。

  • multiply_by_parameter_scale – 如果为真,则按参数范数缩放学习率。 如果为假,提供的学习率是绝对步长。

  • clipping_threshold – 可选的裁剪阈值。必须大于等于 1。如果为 None,裁剪将被禁用。

  • 动量 – 介于 0 和 1 之间的可选值,启用动量,如果非 None,则使用额外的内存! 默认为 None。

  • dtype_momentum – 动量缓冲区的数据类型。

  • weight_decay_rate – 可选的权重衰减率。

  • eps – 均方根梯度的正则化常数。

  • factored – 是否使用分解的第二时刻估计。

  • weight_decay_mask – 与params PyTree结构相同(或为其前缀)的树,或者一个可调用对象,它在给定params/updates时返回这样的pytree。叶子应该是布尔值,True 用于您想要应用变换的叶子/子树,False 用于您想跳过的那些。

Returns:

对应的 optax.GradientTransformation

示例

>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.adafactor(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...  grad = jax.grad(f)(params)
...  updates, opt_state = solver.update(grad, opt_state, params)
...  params = optax.apply_updates(params, updates)
...  print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.39E+01
Objective function: 1.38E+01
Objective function: 1.38E+01
Objective function: 1.37E+01
Objective function: 1.36E+01

参考文献

Shazeer et al, Adafactor: Adaptive Learning Rates with Sublinear Memory Cost, 2018

亚当#

optax.adam(learning_rate: base.ScalarOrSchedule, b1: float = 0.9, b2: float = 0.999, eps: float = 1e-08, eps_root: float = 0.0, mu_dtype: Any | None = None, *, nesterov: bool = False) 基础.渐变变换[来源]#

Adam优化器。

Adam 是一种具有梯度缩放适应性的 SGD 变体。每个参数使用的缩放是根据梯度的一阶和二阶矩的估计计算的(使用合适的指数移动平均)。

Let \(\alpha_t\) represent the learning rate and \(\beta_1, \beta_2\), \(\varepsilon\), \(\bar{\varepsilon}\) represent the arguments b1, b2, eps and eps_root respectively. The learning rate is indexed by \(t\) since the learning rate may also be provided by a schedule function.

The init function of this optimizer initializes an internal state \(S_0 := (m_0, v_0) = (0, 0)\), representing initial estimates for the first and second moments. In practice these values are stored as pytrees containing all zeros, with the same shape as the model updates. At step \(t\), the update function of this optimizer takes as arguments the incoming gradients \(g_t\) and optimizer state \(S_t\) and computes updates \(u_t\) and new state \(S_{t+1}\). Thus, for \(t > 0\), we have,

\[\begin{align*} m_t &\leftarrow \beta_1 \cdot m_{t-1} + (1-\beta_1) \cdot g_t \\ v_t &\leftarrow \beta_2 \cdot v_{t-1} + (1-\beta_2) \cdot {g_t}^2 \\ \hat{m}_t &\leftarrow m_t / {(1-\beta_1^t)} \\ \hat{v}_t &\leftarrow v_t / {(1-\beta_2^t)} \\ u_t &\leftarrow -\alpha_t \cdot \hat{m}_t / \left({\sqrt{\hat{v}_t + \bar{\varepsilon}} + \varepsilon} \right)\\ S_t &\leftarrow (m_t, v_t). \end{align*}\]

使用关键字参数 nesterov=True,优化器使用Nesterov动量,将上面的 \(\hat{m}_t\) 替换为

\[\hat{m}_t \leftarrow \beta_1 m_t / {(1-\beta_1^{t+1})} + (1 - \beta_1) g_t / {(1-\beta_1^t)}. \]
Parameters:
  • learning_rate – 一个全局缩放因子,可以是固定的,也可以随着 迭代使用调度器而变化,见 optax.scale_by_learning_rate().

  • b1 – 指数衰减率,用于跟踪过去梯度的第一个时刻。

  • b2 – 指数衰减率,用于跟踪过去梯度的第二矩。

  • eps – 一个小常数,应用于平方根外的分母(如在Adam论文中),以避免在重新缩放时除以零。

  • eps_root – 一个应用于平方根内部分母的小常数(如在RMSProp中),用于避免在重新缩放时除以零。这在通过Adam计算(元)梯度时是必要的。

  • mu_dtype – 可选的 dtype 用于一阶累加器;如果 Nonedtype 将从 paramsupdates 中推断。

  • nesterov – 是否使用Nesterov动量。使用nesterov=True的求解器相当于optax.nadam()优化器,并在[Dozat 2016]中描述。

Returns:

对应的 optax.GradientTransformation

示例

>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.adam(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...  grad = jax.grad(f)(params)
...  updates, opt_state = solver.update(grad, opt_state, params)
...  params = optax.apply_updates(params, updates)
...  print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.40E+01
Objective function: 1.39E+01
Objective function: 1.39E+01
Objective function: 1.39E+01
Objective function: 1.38E+01

参考文献

金马等人, Adam:一种随机优化的方法,2014

Dozat, 将Nesterov动量纳入Adam, 2016

警告

PyTorch和optax的实现遵循[Kingma et al. 2014]的算法1。请注意,TensorFlow使用的是论文第2.1节之前的公式。有关更多细节,请参见 deepmind/optax#571

另请参见

optax.nadam(), optax.adamw().

亚当最大#

optax.adamax(learning_rate: base.ScalarOrSchedule, b1: float = 0.9, b2: float = 0.999, eps: float = 1e-08) 基础.渐变变换[来源]#

一种使用无限范数的Adam优化器变体。

AdaMax 是 optax.adam() 优化器的一个变体。通过将 Adam 的 \(L^2\) 范数推广到 \(L^p\) 范数,并在 \(p \rightarrow \infty\) 时取极限,我们得到一个简单而稳定的更新规则。

Let \(\alpha_t\) represent the learning rate and \(\beta_1, \beta_2\), \(\varepsilon\) represent the arguments b1, b2 and eps respectively. The learning rate is indexed by \(t\) since the learning rate may also be provided by a schedule function.

The init function of this optimizer initializes an internal state \(S_0 := (m_0, v_0) = (0, 0)\), representing initial estimates for the first and second moments. In practice these values are stored as pytrees containing all zeros, with the same shape as the model updates. At step \(t\), the update function of this optimizer takes as arguments the incoming gradients \(g_t\) and optimizer state \(S_t\) and computes updates \(u_t\) and new state \(S_{t+1}\). Thus, for \(t > 0\), we have,

\[\begin{align*} m_t &\leftarrow \beta_1 \cdot m_{t-1} + (1-\beta_1) \cdot g_t \\ v_t &\leftarrow \max(\left| g_t \right| + \varepsilon, \beta_2 \cdot v_{t-1}) \\ \hat{m}_t &\leftarrow m_t / (1-\beta_1^t) \\ u_t &\leftarrow -\alpha_t \cdot \hat{m}_t / v_t \\ S_t &\leftarrow (m_t, v_t). \end{align*}\]
Parameters:
  • learning_rate – 一个全局缩放因子,可以是固定的,也可以随着 迭代使用调度器而变化,见 optax.scale_by_learning_rate().

  • b1 – 指数衰减率,用于跟踪过去梯度的第一个时刻。

  • b2 – 用于跟踪过去梯度最大值的指数衰减率。

  • eps – 应用于分母的一个小常数,用于避免在重新缩放时除以零。

Returns:

对应的 optax.GradientTransformation

示例

>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.adamax(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...  grad = jax.grad(f)(params)
...  updates, opt_state = solver.update(grad, opt_state, params)
...  params = optax.apply_updates(params, updates)
...  print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.40E+01
Objective function: 1.39E+01
Objective function: 1.39E+01
Objective function: 1.39E+01
Objective function: 1.38E+01

参考文献

Kingma 等,2014: https://arxiv.org/abs/1412.6980

另请参见

optax.adam(), optax.adamaxw().

亚当最大化W#

optax.adamaxw(learning_rate: base.ScalarOrSchedule, b1: float = 0.9, b2: float = 0.999, eps: float = 1e-08, weight_decay: float = 0.0001, mask: Any | Callable[[base.Params], Any] | None = None) 基础.渐变变换[来源]#

具有权重衰减正则化的Adamax。

AdamaxW使用权重衰减来规范学习朝向小权重,因为这会导致更好的泛化。在SGD中,你也可以使用L2正则化将其作为附加损失项来实现,但是L2正则化对于自适应梯度算法如Adam并不表现如预期。

Parameters:
  • learning_rate – 一个全局缩放因子,可以是固定的,也可以随着 迭代使用调度器而变化,见 optax.scale_by_learning_rate().

  • b1 – 指数衰减率,用于跟踪过去梯度的第一个时刻。

  • b2 – 用于跟踪过去梯度最大值的指数衰减率。

  • eps – 一个小常数,应用于分母以避免在重新缩放时除以零。

  • weight_decay – 权重衰减正则化的强度。请注意,这个权重衰减是与学习率相乘的。这与其他框架如PyTorch是一致的,但与(Loshchilov et al, 2019)不同,在该框架中,权重衰减仅与“调度乘子”相乘,而不是与基本学习率相乘。

  • mask – 一个与 params PyTree 具有相同结构(或为其前缀)的树,或者是一个在给定 params/updates 时返回这样的 pytree 的可调用对象。叶子应该是布尔值,True 表示您想要应用权重衰减的叶子/子树,而 False 表示您想要跳过的树。请注意,Adamax 梯度变换会应用于所有参数。

Returns:

对应的 optax.GradientTransformation

示例

>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.adamaxw(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...  grad = jax.grad(f)(params)
...  updates, opt_state = solver.update(grad, opt_state, params)
...  params = optax.apply_updates(params, updates)
...  print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.40E+01
Objective function: 1.39E+01
Objective function: 1.39E+01
Objective function: 1.39E+01
Objective function: 1.38E+01

参考文献

Loshchilov 等, 2019: https://arxiv.org/abs/1711.05101

警告

有时您可能希望跳过 BatchNorm scale 或偏置参数的权重衰减。您可以使用 optax.masked 来制作您自己的 AdamaxW 变体,其中 additive_weight_decay 仅应用于一部分 params

另请参见

optax.adam(), optax.adamax().

亚当W#

optax.adamw(learning_rate: base.ScalarOrSchedule, b1: float = 0.9, b2: float = 0.999, eps: float = 1e-08, eps_root: float = 0.0, mu_dtype: Any | None = None, weight_decay: float = 0.0001, mask: Any | Callable[[base.Params], Any] | None = None, *, nesterov: bool = False) 基础.渐变变换[来源]#

带有权重衰减正则化的亚当。

AdamW使用权重衰减来规范学习以获得较小的权重,因为这会导致更好的泛化。在SGD中,您还可以使用L2正则化将其作为附加损失项,但L2正则化对于自适应梯度算法(例如Adam)的表现并不如预期,参见[Loshchilov et al,2019]。

Let \(\alpha_t\) represent the learning rate and \(\beta_1, \beta_2\), \(\varepsilon\), \(\bar{\varepsilon}\) represent the arguments b1, b2, eps and eps_root respectively. The learning rate is indexed by \(t\) since the learning rate may also be provided by a schedule function. Let \(\lambda\) be the weight decay and \(\theta_t\) the parameter vector at time \(t\).

The init function of this optimizer initializes an internal state \(S_0 := (m_0, v_0) = (0, 0)\), representing initial estimates for the first and second moments. In practice these values are stored as pytrees containing all zeros, with the same shape as the model updates. At step \(t\), the update function of this optimizer takes as arguments the incoming gradients \(g_t\), the optimizer state \(S_t\) and the parameters \(\theta_t\) and computes updates \(u_t\) and new state \(S_{t+1}\). Thus, for \(t > 0\), we have,

\[\begin{align*} m_t &\leftarrow \beta_1 \cdot m_{t-1} + (1-\beta_1) \cdot g_t \\ v_t &\leftarrow \beta_2 \cdot v_{t-1} + (1-\beta_2) \cdot {g_t}^2 \\ \hat{m}_t &\leftarrow m_t / {(1-\beta_1^t)} \\ \hat{v}_t &\leftarrow v_t / {(1-\beta_2^t)} \\ u_t &\leftarrow -\alpha_t \cdot \left( \hat{m}_t / \left({\sqrt{\hat{v}_t + \bar{\varepsilon}} + \varepsilon} \right) + \lambda \theta_{t} \right)\\ S_t &\leftarrow (m_t, v_t). \end{align*}\]

该实现可以结合由[Dozat 2016]引入的Nesterov的动量。生成的优化器通常称为NAdamW。使用关键字参数 nesterov=True 时,优化器使用Nesterov动量,用上面的 \(\hat{m}_t\) 替换

\[\hat{m}_t \leftarrow \beta_1 m_t / {(1-\beta_1^{t+1})} + (1 - \beta_1) g_t / {(1-\beta_1^t)}. \]
Parameters:
  • learning_rate – 一个全局缩放因子,可以是固定的,也可以随着 迭代使用调度器而变化,见 optax.scale_by_learning_rate().

  • b1 – 指数衰减率,用于跟踪过去梯度的第一个时刻。

  • b2 – 指数衰减率,用于跟踪过去梯度的第二矩。

  • eps – 一个小常数,应用于平方根外的分母(如在Adam论文中),以避免在重新缩放时除以零。

  • eps_root – 应用在平方根内分母上的一个小常数(如RMSProp中),以避免在重新缩放时出现除以零的情况。这在通过Adam计算(元)梯度时是必需的。

  • mu_dtype – 可选的 dtype 用于一阶累加器;如果 Nonedtype 将从 paramsupdates 中推断。

  • weight_decay – 权重衰减正则化的强度。注意,这个权重衰减与学习率相乘。这与其他框架如PyTorch是一致的,但与(Loshchilov et al, 2019)不同,在该框架中,权重衰减只是与“调度乘子”相乘,而不是基础学习率。

  • mask - 一棵与 params PyTree 具有相同结构(或其前缀)的树,或一个可调用对象,该对象在给定 params/updates 时返回这样的 pytree。叶子应该是布尔值,对于您想要应用权重衰减的叶子/子树,应为 True,对于您想要跳过的则应为 False。请注意,Adam 梯度变换应用于所有参数。

  • nesterov – 是否使用Nesterov动量。使用nesterov=True的求解器等同于optax.nadamw()优化器。此修改在[Dozat 2016]中有描述。

Returns:

对应的 optax.GradientTransformation

示例

>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.adamw(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...  grad = jax.grad(f)(params)
...  updates, opt_state = solver.update(grad, opt_state, params)
...  params = optax.apply_updates(params, updates)
...  print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.40E+01
Objective function: 1.39E+01
Objective function: 1.39E+01
Objective function: 1.39E+01
Objective function: 1.38E+01

参考文献

Loshchilov 等, 解耦权重衰退正则化,2019

Dozat, 将Nesterov动量纳入Adam, 2016

另请参见

查看相关函数 optax.adam(), optax.nadamw(), 以及示例 关于小莎士比亚的字符级变换器 用于用例。

AMSGrad#

optax.amsgrad(learning_rate: base.ScalarOrSchedule, b1: float = 0.9, b2: float = 0.999, eps: float = 1e-08, eps_root: float = 0.0, mu_dtype: Any | None = None) 基础.渐变变换[来源]#

AMSGrad 优化器。

在某些情况下,原始的Adam可能无法收敛到最优解。AMSGrad通过使用过去梯度的长期记忆来保证收敛。

Parameters:
  • learning_rate – 一个全局缩放因子,可以是固定的,也可以随着 迭代使用调度器而变化,见 optax.scale_by_learning_rate().

  • b1 – 指数衰减率,用于跟踪过去梯度的第一个时刻。

  • b2 – 指数衰减率,用于跟踪过去梯度的第二矩。

  • eps – 在平方根外应用于分母的小常数(如在Adam论文中)以避免在重新缩放时除以零。

  • eps_root – 应用在平方根内分母上的一个小常数(如RMSProp中),以避免在重新缩放时出现除以零的情况。这在通过Adam计算(元)梯度时是必需的。

  • mu_dtype – 可选的 dtype 用于一阶累加器;如果 Nonedtype 将从 paramsupdates 中推断。

Returns:

对应的 optax.GradientTransformation

示例

>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.amsgrad(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...  grad = jax.grad(f)(params)
...  updates, opt_state = solver.update(grad, opt_state, params)
...  params = optax.apply_updates(params, updates)
...  print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.40E+01
Objective function: 1.39E+01
Objective function: 1.39E+01
Objective function: 1.39E+01
Objective function: 1.38E+01

参考文献

Reddi 等人, 关于 Adam 和更高级方法的收敛性, 2023

奶酪#

optax.fromage(learning_rate: float, min_norm: float = 1e-06) optax.GradientTransformation[来源]#

弗罗贝尼乌斯匹配梯度下降(Fromage)优化器。

Fromage 是一种不需要学习率调节的学习算法。 优化器基于通过深度相对信任(深度神经网络上的一种距离函数)建模神经网络梯度。Fromage 类似于 LARS 优化器,并且可以在一系列标准神经网络基准上工作, 例如自然语言转换器和生成对抗网络。

Parameters:
  • learning_rate – 一个全局缩放因子,可以是固定的,也可以随着 迭代使用调度器而变化,见 optax.scale_by_learning_rate().

  • min_norm – 一个最小值,梯度更新的范数和层参数的范数可以被截断到该值,以避免在计算信任比率时除以零(如LARS论文中所述)。

Returns:

对应的 optax.GradientTransformation

示例

>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.fromage(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...  grad = jax.grad(f)(params)
...  updates, opt_state = solver.update(grad, opt_state, params)
...  params = optax.apply_updates(params, updates)
...  print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.39E+01
Objective function: 1.38E+01
Objective function: 1.37E+01
Objective function: 1.37E+01
Objective function: 1.36E+01

参考文献

伯恩斯坦等人, 关于两个神经网络之间的距离和学习的稳定性, 2020

Lamb#

optax.lamb(learning_rate: base.ScalarOrSchedule, b1: float = 0.9, b2: float = 0.999, eps: float = 1e-06, eps_root: float = 0.0, weight_decay: float = 0.0, mask: MaskOrFn = None) 基础.渐变变换[来源]#

LAMB优化器。

LAMB是一种通用的逐层自适应大批量优化器,旨在在广泛的任务中提供一致的训练性能,包括那些使用基于注意力的模型(如Transformers)和ResNet-50的任务。该优化器能够处理小批量和大批量的训练。LAMB的灵感来源于LARS学习算法。

Parameters:
  • learning_rate – 一个全局缩放因子,可以是固定的,也可以随着 迭代使用调度器而变化,见 optax.scale_by_learning_rate().

  • b1 – 指数衰减率,用于跟踪过去梯度的第一个时刻。

  • b2 – 指数衰减率,用于跟踪过去梯度的第二矩。

  • eps – 在平方根外应用于分母的小常数(如在Adam论文中)以避免在重新缩放时除以零。

  • eps_root – 应用在平方根内分母上的一个小常数(如RMSProp中),以避免在重新缩放时出现除以零的情况。这在通过Adam计算(元)梯度时是必需的。

  • weight_decay – 权重衰减正则化的强度。

  • mask – 具有与参数 PyTree 相同结构(或其前缀)的树,或者是一个可调用的对象,该对象根据参数/更新返回这样的 pytree。叶子应该是布尔值,True 表示要应用变换的叶子/子树,False 表示要跳过的部分。

Returns:

对应的 optax.GradientTransformation

示例

>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.lamb(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...  grad = jax.grad(f)(params)
...  updates, opt_state = solver.update(grad, opt_state, params)
...  params = optax.apply_updates(params, updates)
...  print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.39E+01
Objective function: 1.38E+01
Objective function: 1.38E+01
Objective function: 1.37E+01
Objective function: 1.36E+01

参考文献

您等人, 大批量优化深度学习:在76分钟内训练BERT,2020

Lars#

optax.lars(learning_rate: base.ScalarOrSchedule, weight_decay: float = 0.0, weight_decay_mask: MaskOrFn = True, trust_coefficient: float = 0.001, eps: float = 0.0, trust_ratio_mask: MaskOrFn = True, momentum: float = 0.9, nesterov: bool = False) 基础.渐变变换[来源]#

LARS优化器。

LARS是一种逐层自适应优化器,旨在帮助将SGD扩展到更大的批量大小。LARS后来启发了LAMB优化器。

Parameters:
  • learning_rate – 一个全局缩放因子,可以是固定的,也可以随着 迭代使用调度器而变化,见 optax.scale_by_learning_rate().

  • weight_decay – 权重衰减正则化的强度。

  • weight_decay_mask – 一个与params PyTree结构相同(或其前缀)的树,或者一个返回这样的pytree的可调用对象,给定params/updates。叶子应该是布尔值,True表示您想对其应用转换的叶子/子树,而False表示您想跳过的那些。

  • trust_coefficient – 信任比率的乘数。

  • eps – 信任比率分母中的可选加性常数。

  • trust_ratio_mask – 一棵具有与参数 PyTree 相同结构(或前缀)的树,或者一个根据参数/更新返回此类 pytree 的可调用对象。叶子应该是布尔值,True 表示您想要应用转换的叶子/子树,False 表示您想要跳过的叶子/子树。

  • 动量 – 动量的衰减率。

  • nesterov – 是否使用Nesterov动量。

Returns:

对应的 optax.GradientTransformation

示例

>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.lars(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...  grad = jax.grad(f)(params)
...  updates, opt_state = solver.update(grad, opt_state, params)
...  params = optax.apply_updates(params, updates)
...  print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.40E+01
Objective function: 1.40E+01
Objective function: 1.40E+01
Objective function: 1.40E+01
Objective function: 1.40E+01

参考文献

你等人, Large Batch Training of Convolutional Networks, 2017

LBFGS#

optax.lbfgs(learning_rate: Optional[base.ScalarOrSchedule] = None, memory_size: int = 10, scale_init_precond: bool = True, linesearch: Optional[base.GradientTransformationExtraArgs] = (<function scale_by_zoom_linesearch.<locals>.init_fn>, <function scale_by_zoom_linesearch.<locals>.update_fn>)) 基础.渐变变换额外参数[来源]#

L-BFGS 优化器。

L-BFGS 是一种拟牛顿方法,它将更新(梯度)与逆Hessian的近似值相乘。该算法不需要访问Hessian,因为这个近似是在优化过程中利用梯度评估构建的。L-BFGS 是Broyden-Fletcher-Goldfarb-Shanno (BFGS) 算法的有限内存变体。BFGS 算法需要存储一个大小为 \(p \times p\) 的矩阵,其中 \(p\) 是参数的维度。有限变体通过仅使用 \(m\) (memory_size) 过去的参数/梯度差异来计算逆的近似值,从而规避了这个问题。即,Hessian逆的近似记作 \(P_k = P_{k, k}\), 其中

\[\begin{align*} P_{k, j+1} & = V_j^\top P_{k, j} V_j + \rho_j \delta w_j \delta w_j^\top \quad \text{对于} \ j \in \{k-m, \ldots, k-1\}\\ P_{k, k-m} & = \gamma_k I \\ V_k & = I - \rho_k \delta u_k \delta w_k^\top \\ \rho_k & = 1/(\delta u_k^\top \delta w_k) \\ \delta w_k & = w_{k+1} - w_k \\ \delta u_k & = u_{k+1} - u_k \\ \gamma_k & = \begin{cases} (\delta w_{k-1}^\top \delta u_{k-1}) / (\delta u_{k-1}^\top \delta u_{k-1}) & \text{如果} \ \texttt{scale\_init\_hess} \\ 1 & \text{否则} \end{cases}, \end{align*}\]

对于 \(u_k\) 在第 \(k\) 次迭代中的梯度/更新, \(w_k\) 在第 \(k\) 次迭代中的参数。

更新 \(P_k\) 的公式是通过计算最佳预处理矩阵来获得的,符合某些割线条件,详情请参见参考文献。计算 \(P_k u_k\) 可以通过一系列向量操作来完成,使用存储在内存缓冲区中的过去参数和梯度的差异。

当前函数仅输出LBFGS方向 \(P_k u_k\)。它可以与线搜索链式连接,以确保足够的减少和低曲率,例如缩放线搜索。线搜索计算步长 \(\eta_k\),使得更新后的参数(使用 optax.apply_updates())的形式为 \(w_{k+1} = w_k - \eta_k P_k u_k\)

Parameters:
  • learning_rate – 可选的全局缩放因子,可以是固定的,也可以是随着迭代而变化的,使用调度器,见 optax.scale_by_learning_rate()。默认情况下,学习率由线搜索处理。

  • memory_size – 保持在内存中用于近似Hessian逆的过去更新数量。

  • scale_init_precond – 是否使用缩放的单位矩阵作为初始预处理器,见上面的\(\gamma_k\)公式。

  • linesearch – 一个optax.GradientTransformationExtraArgs的实例,例如optax.scale_by_zoom_linesearch(),用于计算学习率(即步长),以满足某些标准,例如通过额外调用目标函数实现目标的充分减少。

Returns:

一个 optax.GradientTransformationExtraArgs 对象。

示例

>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)
>>> solver = optax.lbfgs()
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> value_and_grad = optax.value_and_grad_from_state(f)
>>> for _ in range(2):
...   value, grad = value_and_grad(params, state=opt_state)
...   updates, opt_state = solver.update(
...      grad, opt_state, params, value=value, grad=grad, value_fn=f
...   )
...   params = optax.apply_updates(params, updates)
...   print('Objective function: {:.2E}'.format(f(params)))
Objective function: 7.52E+00
Objective function: 7.46E-14

参考文献

Nocedal等人的《数值优化》中的算法7.4、7.5(第199页),1999年

Liu et al., On the limited memory BFGS method for large scale optimization , 1989.

警告

该优化器对内存要求高,最适合用于小型到中型问题。

警告

这个优化器在使用线搜索时效果最好(当前默认是缩放线搜索)。请参见上面的示例,以在非随机设置中获得最佳使用,我们可以利用线搜索计算的梯度来重复使用 optax.value_and_grad_from_state()

注意

我们将身份的缩放初始化为梯度范数的一个上限倒数。这避免了在第一步浪费线搜索迭代,通过考虑梯度的大小。换句话说,我们将第一步的信任区域限制在第一次迭代时半径为1的欧几里得球内。\(\gamma_0\)的选择在上述参考文献中没有详细说明,因此这是一个启发式选择。

注意

该算法可以支持复杂的输入。

狮子#

optax.lion(learning_rate: base.ScalarOrSchedule, b1: float = 0.9, b2: float = 0.99, mu_dtype: Any | None = None, weight_decay: float = 0.001, mask: Any | Callable[[base.Params], Any] | None = None) 基础.渐变变换[来源]#

狮子优化器。

Lion 是通过符号程序搜索发现的。与大多数自适应优化器(如 AdamW)不同,Lion 仅跟踪动量,从而使其在内存使用上更为高效。Lion 的更新是通过符号操作产生的,导致其范数相比于其他优化器(如 SGD 和 AdamW)产生的更新更大。Lion 的合适学习率通常比 AdamW 小 3-10 倍,而 Lion 的权重衰减应该比 AdamW 大 3-10 倍,以保持类似的强度 (lr * wd)。

Let \(\alpha_t\) represent the learning rate and \(\beta_1, \beta_2\), represent the arguments b1 and b2 respectively. The learning rate is indexed by \(t\) since the learning rate may also be provided by a schedule function. Let \(\lambda\) be the weight decay and \(\theta_t\) the parameter vector at time \(t\).

The init function of this optimizer initializes an internal state \(S_0 := (m_0) = (0)\), representing the intial estimate for the first moment. In practice these values are stored as pytrees containing all zeros, with the same shape as the model updates. At step \(t\), the update function of this optimizer takes as arguments the incoming gradients \(g_t\), the optimizer state \(S_t\) and the parameters \(\theta_t\) and computes updates \(u_t\) and new state \(S_{t+1}\). Thus, for \(t > 0\), we have,

\[\begin{align*} c_t &\leftarrow \beta_1 \cdot m_{t-1} + (1-\beta_1) \cdot g_t \\ u_t &\leftarrow -\alpha_t \cdot \left( sign \left( c_t \right) + \lambda \theta_{t} \right)\\ m_t &\leftarrow \beta_2 \cdot m_{t-1} + (1-\beta_2) \cdot g_t \\ S_t &\leftarrow (m_t). \end{align*}\]
Parameters:
  • learning_rate – 一个全局缩放因子,可以是固定的,也可以随着 迭代使用调度器而变化,见 optax.scale_by_learning_rate().

  • b1 – 结合动量和当前梯度的比率。

  • b2 – 指数衰减率,用于跟踪过去梯度的动量。

  • mu_dtype – 可选的 dtype 用于一阶累加器;如果 Nonedtype 将从 paramsupdates 中推断。

  • weight_decay – 权重衰减正则化的强度。请注意,这个权重衰减是与学习率相乘的。这与其他框架如PyTorch是一致的,但与(Loshchilov et al, 2019)不同,在该框架中,权重衰减仅与“调度乘子”相乘,而不是与基本学习率相乘。

  • mask – 一个与params PyTree具有相同结构的树(或其前缀),或者是一个在给定params/updates时返回这种pytree的Callable。叶子应为布尔值,True表示要对其应用权重衰减的叶子/子树,False表示要跳过的叶子/子树。注意,Adam梯度变换会应用于所有参数。

Returns:

对应的 optax.GradientTransformation

示例

>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.lion(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...  grad = jax.grad(f)(params)
...  updates, opt_state = solver.update(grad, opt_state, params)
...  params = optax.apply_updates(params, updates)
...  print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.40E+01
Objective function: 1.39E+01
Objective function: 1.39E+01
Objective function: 1.39E+01
Objective function: 1.38E+01

参考文献

陈等人, 符号发现优化算法, 2023

Nadam#

optax.nadam(learning_rate: base.ScalarOrSchedule, b1: float = 0.9, b2: float = 0.999, eps: float = 1e-08, eps_root: float = 0.0, mu_dtype: Any | None = None, *, nesterov: bool = True) 基础.渐变变换#

NAdam优化器。

Nadam 是一种变体,具有 Nesterov 动量的 optax.adam()。该求解器的更新规则如下:

\[\begin{align*} m_t &\leftarrow \beta_1 \cdot m_{t-1} + (1-\beta_1) \cdot g_t \\ v_t &\leftarrow \beta_2 \cdot v_{t-1} + (1-\beta_2) \cdot {g_t}^2 \\ \hat{m}_t &\leftarrow \beta_1 m_t / {(1-\beta_1^{t+1})} + (1 - \beta_1) g_t / {(1-\beta_1^t)}\\ \hat{v}_t &\leftarrow v_t / {(1-\beta_2^t)} \\ u_t &\leftarrow -\alpha_t \cdot \hat{m}_t / \left({\sqrt{\hat{v}_t + \bar{\varepsilon}} + \varepsilon} \right)\\ S_t &\leftarrow (m_t, v_t). \end{align*}\]
Parameters:
  • learning_rate – 一个全局缩放因子,可以是固定的,也可以随着 迭代使用调度器而变化,见 optax.scale_by_learning_rate().

  • b1 – 指数衰减率,用于跟踪过去梯度的第一个时刻。

  • b2 – 指数衰减率,用于跟踪过去梯度的第二矩。

  • eps – 一个小常数,应用于平方根外的分母(如在Adam论文中),以避免在重新缩放时除以零。

  • eps_root – 一个应用于平方根内部分母的小常数(如在RMSProp中),用于避免在重新缩放时除以零。这在通过Adam计算(元)梯度时是必要的。

  • mu_dtype – 可选的 dtype 用于一阶累加器;如果 Nonedtype 将从 paramsupdates 中推断。

Returns:

对应的 optax.GradientTransformation

示例

>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.nadam(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...  grad = jax.grad(f)(params)
...  updates, opt_state = solver.update(grad, opt_state, params)
...  params = optax.apply_updates(params, updates)
...  print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.39E+01
Objective function: 1.39E+01
Objective function: 1.39E+01
Objective function: 1.38E+01
Objective function: 1.38E+01

参考文献

Dozat, 将Nesterov动量纳入Adam, 2016

另请参见

optax.adam(), optax.nadamw().

在版本 0.1.9 中添加。

NadamW#

optax.nadamw(learning_rate: base.ScalarOrSchedule, b1: float = 0.9, b2: float = 0.999, eps: float = 1e-08, eps_root: float = 0.0, mu_dtype: Any | None = None, weight_decay: float = 0.0001, mask: Any | Callable[[base.Params], Any] | None = None, *, nesterov: bool = True) 基础.渐变变换#

NAdamW优化器,作为AdamW优化器的一部分实现。

NadamW 是带有 Nesterov 动量的 optax.adamw() 的变体。与 AdamW 相比,这种优化器替换了赋值

\[\hat{m}_t \leftarrow m_t / {(1-\beta_1^t)}\]

\[\hat{m}_t \leftarrow \beta_1 m_t / {(1-\beta_1^{t+1})} + (1 - \beta_1) g_t / {(1-\beta_1^t)}.\]
Parameters:
  • learning_rate – 一个全局缩放因子,可以是固定的,也可以随着 迭代使用调度器而变化,见 optax.scale_by_learning_rate().

  • b1 – 指数衰减率,用于跟踪过去梯度的第一个时刻。

  • b2 – 指数衰减率,用于跟踪过去梯度的第二矩。

  • eps – 一个小常数,应用于平方根外的分母(如在Adam论文中),以避免在重新缩放时除以零。

  • eps_root – 应用在平方根内分母上的一个小常数(如RMSProp中),以避免在重新缩放时出现除以零的情况。这在通过Adam计算(元)梯度时是必需的。

  • mu_dtype – 可选的 dtype 用于一阶累加器;如果 Nonedtype 将从 paramsupdates 中推断。

  • weight_decay – 权重衰减正则化的强度。注意,这个权重衰减与学习率相乘。这与其他框架如PyTorch是一致的,但与(Loshchilov et al, 2019)不同,在该框架中,权重衰减只是与“调度乘子”相乘,而不是基础学习率。

  • mask - 一棵与 params PyTree 具有相同结构(或其前缀)的树,或一个可调用对象,该对象在给定 params/updates 时返回这样的 pytree。叶子应该是布尔值,对于您想要应用权重衰减的叶子/子树,应为 True,对于您想要跳过的则应为 False。请注意,Adam 梯度变换应用于所有参数。

Returns:

对应的 optax.GradientTransformation

示例

>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.nadamw(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...  grad = jax.grad(f)(params)
...  updates, opt_state = solver.update(grad, opt_state, params)
...  params = optax.apply_updates(params, updates)
...  print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.39E+01
Objective function: 1.39E+01
Objective function: 1.39E+01
Objective function: 1.38E+01
Objective function: 1.38E+01

参考文献

Loshchilov 等, 解耦权重衰退正则化,2019

Dozat, 将Nesterov动量纳入Adam, 2016

另请参见

optax.adam(), optax.adamw().

在0.1.9版本中添加。

嘈杂的SGD#

optax.noisy_sgd(learning_rate: base.ScalarOrSchedule, eta: float = 0.01, gamma: float = 0.55, seed: int = 0) 基础.渐变变换[来源]#

带有附加噪声的SGD变体。

噪声随机梯度下降是一个变体的 optax.sgd(),它将在更新中加入高斯噪声。研究发现,向梯度中添加噪声可以改进非常深层网络的训练误差和泛化误差。

更新 \(u_t\) 被修改为包含如下噪声:

\[u_t \leftarrow -\alpha_t (g_t + N(0, \sigma_t^2)), \]

其中 \(N(0, \sigma_t^2)\) 表示均值为零且方差为 \(\sigma_t^2\) 的高斯噪声。

该噪声的方差随时间根据以下公式衰减

\[\sigma_t^2 = \frac{\eta}{(1+t)^\gamma}, \]

其中 \(\gamma\) 是衰减率参数 gamma\(\eta\) 表示初始方差 eta

Parameters:
  • learning_rate – 一个全局缩放因子,可以是固定的,也可以随着 迭代使用调度器而变化,见 optax.scale_by_learning_rate().

  • eta – 添加到梯度中的高斯噪声的初始方差。

  • gamma – 一个控制噪声随时间退火的参数 t, 方差根据 (1+t)**(-gamma) 进行衰减。

  • seed – 伪随机生成过程的种子。

Returns:

对应的 optax.GradientTransformation

示例

>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.noisy_sgd(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...  grad = jax.grad(f)(params)
...  updates, opt_state = solver.update(grad, opt_state, params)
...  params = optax.apply_updates(params, updates)
...  print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.38E+01
Objective function: 1.37E+01
Objective function: 1.35E+01
Objective function: 1.33E+01
Objective function: 1.32E+01

参考文献

Neelakantan 等, Adding Gradient Noise Improves Learning for Very Deep Networks, 2015

诺沃格拉德#

optax.novograd(learning_rate: base.ScalarOrSchedule, b1: float = 0.9, b2: float = 0.25, eps: float = 1e-06, eps_root: float = 0.0, weight_decay: float = 0.0) 基础.渐变变换[来源]#

NovoGrad 优化器。

NovoGrad 对初始学习率和权重初始化的鲁棒性比其他方法更强。例如,NovoGrad 在没有学习率预热的情况下也能很好地工作,而其他方法则需要预热。NovoGrad 在大批量训练中表现尤为出色,比如它在 ResNet-50 上的表现优于其他方法,对于所有批次直到 32K。此外,NovoGrad 所需的内存仅为 Adam 的一半。它与 Jasper ASR 模型一起推出。

Parameters:
  • learning_rate – 一个全局缩放因子,可以是固定的,也可以随着 迭代使用调度器而变化,见 optax.scale_by_learning_rate().

  • b1 – 一个指数衰减率,用于追踪过去梯度的第一时刻。

  • b2 – 用于跟踪过去梯度的第二时刻的指数衰减率。

  • eps – 在平方根外应用于分母的小常数(如在Adam论文中)以避免在重新缩放时除以零。

  • eps_root – 应用在平方根内分母上的一个小常数(如RMSProp中),以避免在重新缩放时出现除以零的情况。这在通过Adam计算(元)梯度时是必需的。

  • weight_decay – 权重衰减正则化的强度。

Returns:

对应的 optax.GradientTransformation

示例

>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.novograd(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...  grad = jax.grad(f)(params)
...  updates, opt_state = solver.update(grad, opt_state, params)
...  params = optax.apply_updates(params, updates)
...  print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.40E+01
Objective function: 1.39E+01
Objective function: 1.39E+01
Objective function: 1.38E+01
Objective function: 1.37E+01

参考文献

Ginsburg 等, 带层自适应矩的随机梯度方法用于深度网络的训练, 2019

Li等人,Jasper: 一种端到端卷积神经声学模型,2019

乐观GD#

optax.optimistic_gradient_descent(learning_rate: base.ScalarOrSchedule, alpha: base.ScalarOrSchedule = 1.0, beta: base.ScalarOrSchedule = 1.0) 基础.渐变变换[来源]#

一个乐观的梯度下降优化器。

乐观梯度下降是一种额外梯度方法的近似,这些方法需要多次梯度调用来计算下一个更新。它在最小-最大博弈中对于最后一次迭代收敛具有强有力的形式保证,而标准梯度下降可能会振荡甚至发散。

Parameters:
  • learning_rate – 一个全局缩放因子,可以是固定的,也可以随着 迭代使用调度器而变化,见 optax.scale_by_learning_rate().

  • alpha – 广义OGD的系数。

  • beta – 广义OGD负动量的系数。

Returns:

对应的 optax.GradientTransformation

示例

>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.optimistic_gradient_descent(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...  grad = jax.grad(f)(params)
...  updates, opt_state = solver.update(grad, opt_state, params)
...  params = optax.apply_updates(params, updates)
...  print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.38E+01
Objective function: 1.37E+01
Objective function: 1.35E+01
Objective function: 1.33E+01
Objective function: 1.32E+01

参考文献

Mokhtari 等, 对鞍点问题的额外梯度和乐观梯度方法的统一分析:近端点方法, 2019

乐观的Adam#

optax.optimistic_adam(learning_rate: base.ScalarOrSchedule, optimism: float | None = None, b1: float = 0.9, b2: float = 0.999, eps: float = 1e-08, eps_root: float = 0.0, mu_dtype: Any | None = None, *, nesterov: bool = True) 基础.渐变变换[来源]#

乐观的Adam优化器。

这是Adam优化器的一个乐观版本。它解决了在训练生成对抗网络和其他鞍点极小极大问题中的限制循环行为问题。

算法如下。首先,我们定义以下参数:

  • \(\alpha_t\): 在迭代\(t\)时的学习率或步长。

  • \(o_t\) 在迭代 \(t\) 时的乐观率。

  • \(\beta_1\) 第一个时刻估计的指数衰减率。

  • \(\beta_2\) 第二矩估计的指数衰减率。

其次,我们定义以下变量:

  • \(g_t\): 输入的梯度。

  • \(m_t\): 偏置的第一次矩估计。

  • \(v_t\): 偏置的二阶矩估计。

  • \(\hat{m}_t\): 偏差修正的第一矩估计。

  • \(\hat{v}_t\): 偏差修正的第二阶原始矩估计。

  • \(r_t\): 信号噪声比 (SNR) 向量。

  • \(u_t\): 外部更新向量。

  • \(S_t\): 优化器的状态。

最后,在每次迭代中,变量按如下方式更新:

\[\begin{align*} m_t &\leftarrow \beta_1 \cdot m_{t - 1} + (1 - \beta_1) \cdot g_t \\ v_t &\leftarrow \beta_2 \cdot v_{t - 1} + (1 - \beta_2) \cdot g_t^2 \\ \hat{m}_t &\leftarrow m_t / {(1 - \beta_1^t)} \\ \hat{v}_t &\leftarrow v_t / {(1 - \beta_2^t)} \\ r_t &\leftarrow \hat{m}_t / \left({\sqrt{\hat{v}_t + \bar{\varepsilon}} + \varepsilon} \right) \\ u_t &\leftarrow -\alpha_t r_t - o_t (r_t - r_{t - 1}) \\ S_t &\leftarrow (m_t, v_t, r_t). \end{align*}\]
Parameters:
  • learning_rate – 一个全局缩放因子,可以是固定的,也可以随着 迭代使用调度器而变化,见 optax.scale_by_learning_rate().

  • 乐观 – 要应用的乐观程度。如果为 None,默认为学习率,如论文中所述。

  • b1 – 指数衰减率,用于跟踪过去梯度的第一个时刻。

  • b2 – 指数衰减率,用于跟踪过去梯度的第二矩。

  • eps – 添加到分母的术语,以提高数值稳定性。

  • eps_root – 添加到预测误差的第二矩的术语,以提高数值稳定性。如果通过梯度变换(例如用于元学习)反向传播梯度,则此值必须非零。

  • mu_dtype – 可选的 dtype 用于一阶累加器;如果 Nonedtype 将从 paramsupdates 中推断。

  • nesterov – 是否使用Nesterov动量。

Returns:

对应的 optax.GradientTransformation

示例

>>> import optax
>>> import jax
>>> from jax import numpy as jnp, lax
>>> def f(x, y):
...  return x * y  # simple bilinear function
>>> opt = optax.optimistic_adam(1e-2, 1.0)
>>> def step(state, _):
...  params, opt_state = state
...  distance = jnp.hypot(*params)
...  grads = jax.grad(f, argnums=(0, 1))(*params)
...  grads = grads[0], -grads[1]
...  updates, opt_state = opt.update(grads, opt_state, params)
...  params = optax.apply_updates(params, updates)
...  return (params, opt_state), distance
>>> params = 1.0, 2.0
>>> opt_state = opt.init(params)
>>> _, distances = lax.scan(step, (params, opt_state), length=1025)
>>> for i in range(6):
...  print(f"{distances[4**i]:.3f}")
2.243
2.195
2.161
2.055
0.796
0.001

参考文献

Daskalakis 等人, Training GANs with Optimism, 2017

Polyak步长SGD#

optax.polyak_sgd(max_learning_rate: float = 1.0, scaling: base.ScalarOrSchedule = 1.0, f_min: float = 0.0, eps: float = 0.0, variant: str = 'sps') 基础.渐变变换额外参数[来源]#

带有Polyak步长的SGD。

该求解器实现了具有Polyak步长的SGD(Loizou等,2021)。它将步长设置为

\[s \min\left\{\frac{f(x) - f^\star}{\|\nabla f(x)\|^2 + \epsilon}, \gamma_{\max}\right\}\,, \]

where \(f\) is the function from which a gradient is computed, \(\gamma_{\max}\) is a maximal acceptable learning rate set by max_learning_rate, \(\epsilon\) is a constant preventing division by zero set with eps, \(s\) scales the formula by scaling, and \(f^\star\) is a guess of the minimum value of the function set with f_min.

设置 variant="sps+" (Garrigos et al. 2023) 仅使用非负部分的次优性差距。也就是说,它将 \(f(x) - f^\star\) 替换为 \((f(x) - f^\star)_+\),其中 \(a_+ = \max \{x, 0\}\)

Parameters:
  • max_learning_rate – 用于使用的最大步长(默认为1)。

  • 缩放 – 一个全局缩放因子,可以是固定的或随着调度器的迭代而演变(默认为1)。

  • f_min – 目标函数的下界(默认为0)。对应上面公式中的 \(f^\star\)

  • eps – 更新分母中要添加的值(默认值为0)。

  • variant – 可以是 'sps''sps+'(默认为 'sps')。

Returns:

A optax.GradientTransformationExtraArgs,其中 update 函数接受一个额外的关键字参数 value,包含目标函数的当前值。

示例

>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.polyak_sgd()
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...  value, grad = jax.value_and_grad(f)(params)
...  params, opt_state = solver.update(grad, opt_state, params, value=value)
...  print('Objective function: ', f(params))
Objective function:  3.5
Objective function:  0.875
Objective function:  0.21875
Objective function:  0.0546875
Objective function:  0.013671875

参考文献

Loizou 等人 随机 Polyak 步长用于 SGD:一种加速收敛的自适应学习率,2021

Berrada 等人,通过插值训练神经网络,2020

Garrigos 等人, 函数值学习:基于 Polyak 步长和 ERM 中函数分裂的自适应学习率,2023

警告

此方法需要了解目标函数最小值的近似值,该值通过 f_min 参数传递。对于插值数据的模型,可以将其设置为 0(默认值)。未能为 f_min 设置适当的值可能会导致发散或收敛到次优解。

RAdam#

optax.radam(learning_rate: base.ScalarOrSchedule, b1: float = 0.9, b2: float = 0.999, eps: float = 1e-08, eps_root: float = 0.0, threshold: float = 5.0, *, nesterov: bool = False) 基础.渐变变换[来源]#

修正后的Adam优化器。

Adam中的自适应学习率在训练的早期阶段由于用于估计优化器统计数据的训练样本数量有限而具有不理想的大方差。修正的Adam通过分析性地减少大方差来解决此问题。

Parameters:
  • learning_rate – 一个全局缩放因子,可以是固定的,也可以随着 迭代使用调度器而变化,见 optax.scale_by_learning_rate().

  • b1 – 指数衰减率,用于跟踪过去梯度的第一个时刻。

  • b2 – 指数衰减率,用于跟踪过去梯度的第二矩。

  • eps – 在平方根外应用于分母的小常数(如在Adam论文中)以避免在重新缩放时除以零。

  • eps_root – 应用在平方根内分母上的一个小常数(如RMSProp中),以避免在重新缩放时出现除以零的情况。这在通过Adam计算(元)梯度时是必需的。

  • threshold – 方差可处理性的阈值。

  • nesterov – 是否使用Nesterov动量。

Returns:

对应的 optax.GradientTransformation

示例

>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.radam(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...  grad = jax.grad(f)(params)
...  updates, opt_state = solver.update(grad, opt_state, params)
...  params = optax.apply_updates(params, updates)
...  print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.38E+01
Objective function: 1.37E+01
Objective function: 1.35E+01
Objective function: 1.33E+01
Objective function: 1.32E+01

参考文献

Liu et al, 2020: 关于自适应学习率的方差及其超越, 2020

RMSProp#

optax.rmsprop(learning_rate: base.ScalarOrSchedule, decay: float = 0.9, eps: float = 1e-08, initial_scale: float = 0.0, eps_in_sqrt: bool = True, centered: bool = False, momentum: float | None = None, nesterov: bool = False, bias_correction: bool = False) 基础.渐变变换[来源]#

一个灵活的 RMSProp 优化器。

RMSProp 是一种带有学习率自适应的 SGD 变体。用于每个权重的 learning_rate 是通过对前一步梯度大小的合适估计进行缩放的。文献中可以找到几种 RMSProp 的变体。此别名提供了一种易于配置的 RMSProp 优化器,可以用于在这些变体之间切换。

Parameters:
  • learning_rate – 一个全局缩放因子,可以是固定的,也可以随着 迭代使用调度器而变化,见 optax.scale_by_learning_rate().

  • decay – 用于跟踪之前梯度的大小的衰减。

  • eps – 一个小的数字常量,用于避免在重缩放时除以零。

  • initial_scale – 跟踪先前更新幅度的累加器的初始值。PyTorch使用0,TF1使用1。在重现论文中的结果时,请验证作者使用的值。

  • eps_in_sqrt – 是否在分母的平方根内或外添加 eps

  • centered – 是否使用过去梯度的二阶矩或方差来重新缩放最新的梯度。

  • 动量 – 动量项使用的衰减率,当它被设置为 None 时,动量根本不被使用。

  • nesterov – 是否使用Nesterov动量。

  • bias_correction – 是否对第二时刻的估计(如果 centered=True,还包括第一时刻)应用偏差修正。

Returns:

对应的 optax.GradientTransformation

示例

>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.rmsprop(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...  grad = jax.grad(f)(params)
...  updates, opt_state = solver.update(grad, opt_state, params)
...  params = optax.apply_updates(params, updates)
...  print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.39E+01
Objective function: 1.38E+01
Objective function: 1.37E+01
Objective function: 1.37E+01
Objective function: 1.36E+01

参考文献

Hinton, 小批量梯度下降概述 `_, 2012

Graves, 生成序列与递归神经网络, 2014

子音, LaProp: 在Adam中分离动量和自适应性 <https://arxiv.org/pdf/2002.04839>`_, 2021

警告

optax的RMSprop的默认行为(eps_in_sqrt=True)与Pytorch的实现有所不同,这可能会影响性能。如果eps_in_sqrt=True,在分母中,optax使用\(\sqrt{v + \epsilon}\),而PyTorch使用\(\sqrt{v} + \epsilon\)。在optax中使用eps_in_sqrt=False将匹配PyTorch的行为。详情请参见google-deepmind/optax#532

RProp#

optax.rprop(learning_rate: float, eta_minus: float = 0.5, eta_plus: float = 1.2, min_step_size: float = 1e-06, max_step_size: float = 50.0) optax.GradientTransformation[来源]#

Rprop 优化器。

Rprop,意为弹性反向传播,是一种一阶的梯度下降变体。它仅对梯度的符号做出反应,通过指数方式增加或减少每个参数所选择的步长,以加速收敛并避免振荡。

Parameters:
  • learning_rate – 初始步长。

  • eta_minus – 逐步大小减少的乘数因子。当梯度在一步到下一步之间改变符号时应用此因子。

  • eta_plus – 增加步长的乘法因子。当梯度从一步到下一步具有相同符号时应用此因子。

  • min_step_size – 最小允许步长。较小的步长将被缩减到该值。

  • max_step_size – 最大允许的步长。较大的步长将被限制为该值。

Returns:

对应的 optax.GradientTransformation

示例

>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.rprop(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...  grad = jax.grad(f)(params)
...  updates, opt_state = solver.update(grad, opt_state, params)
...  params = optax.apply_updates(params, updates)
...  print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.40E+01
Objective function: 1.40E+01
Objective function: 1.39E+01
Objective function: 1.39E+01
Objective function: 1.38E+01

参考文献

Riedmiller 等人 一种用于更快反向传播学习的直接自适应方法:RPROP 算法, 1993

Igel et al. 改进的Rprop学习算法的实证评估, 2003

随机梯度下降#

optax.sgd(learning_rate: base.ScalarOrSchedule, momentum: float | None = None, nesterov: bool = False, accumulator_dtype: Any | None = None) 基础.渐变变换[来源]#

一个标准的随机梯度下降优化器。

这实现了随机梯度下降。它还包括对动量和Nesterov加速的支持,因为在使用随机梯度下降训练深度神经网络时,这些是标准做法。

规范的随机梯度下降返回一个形式为\(u_t\) 的更新

\[u_t \leftarrow -\alpha_t g_t, \]

其中 \(g_t\) 是目标的梯度(可能经过其他变换的预处理),而 \(\alpha_t\) 是在时间 \(t\)learning_rate(常量或由一个 optax.Schedule 选择的)。

带动量的随机梯度下降有两种可能的形式。

\[\begin{align*} m_t &\leftarrow g_t + \mu m_{t-1} \\ u_t &\leftarrow \begin{cases} -\alpha_t m_t & \text{ 如果 } \texttt{nesterov = False} \\ -\alpha_t (g_t + \mu m_t) & \text{ 如果 } \texttt{nesterov = True} \end{cases} \\ S_t &\leftarrow m_t, \end{align*}\]

其中 \(\mu\)momentum 参数,而 \(S_t\) 是优化器的状态。

Parameters:
  • learning_rate – 一个全局缩放因子,可以是固定的,也可以随着 迭代使用调度器而变化,见 optax.scale_by_learning_rate().

  • 动量 – 动量项使用的衰减率,当其设置为 None 时,动量根本不使用。

  • nesterov – 是否使用Nesterov动量。

  • accumulator_dtype – Optional dtype to be used for the accumulator; if None then the dtype is inferred from params and updates.

Returns:

对应的 optax.GradientTransformation

示例

>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.sgd(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...  grad = jax.grad(f)(params)
...  updates, opt_state = solver.update(grad, opt_state, params)
...  params = optax.apply_updates(params, updates)
...  print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.38E+01
Objective function: 1.37E+01
Objective function: 1.35E+01
Objective function: 1.33E+01
Objective function: 1.32E+01

参考文献

Sutskever 等人, 关于初始化和动量在深度学习中的重要性,2013

签名SGD#

optax.sign_sgd(learning_rate: base.ScalarOrSchedule) 基础.渐变变换[来源]#

一种仅使用梯度分量符号的SGD变体。

SignSGD 是一种 SGD 的变体,它在更新中使用梯度分量的符号,而不是它们的实际值。更新 \(u_t\) 被修改如下:

\[u_t \leftarrow -\alpha_t\, \text{sign}\,(g_t), \]

对于 \(\alpha_t\) 在迭代 \(t\) 时给定的学习率,以及 \(\text{sign}\,(g_t)\) 梯度每个分量的符号 \(g_t\)

仅使用梯度更新符号的SGD变种自RProp以来一直被使用,现代形式包括RMSProp、Adam和Lion。SignSGD仅使用梯度更新的符号。SignSGD实现了显著的梯度压缩,显著减少了在多个工作者之间分配学习时通信梯度所带来的瓶颈。

Parameters:

learning_rate – 一个全局缩放因子,可以是固定的,也可以随着调度器在迭代中演变,见 optax.scale_by_learning_rate()

Returns:

对应的 optax.GradientTransformation

示例

>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.sign_sgd(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...  grad = jax.grad(f)(params)
...  updates, opt_state = solver.update(grad, opt_state, params)
...  params = optax.apply_updates(params, updates)
...  print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.40E+01
Objective function: 1.39E+01
Objective function: 1.39E+01
Objective function: 1.39E+01
Objective function: 1.38E+01

参考文献

伯恩斯坦等人,signSGD: 非凸问题的压缩优化,2018

Balles 等人, The Geometry of Sign Gradient Descent,2020

SM3#

optax.sm3(learning_rate: float, momentum: float = 0.9) optax.GradientTransformation[来源]#

SM3优化器。

SM3(平方根最小值和平方梯度的最大值方法)是一种节省内存的自适应优化器,旨在减少在训练非常大模型时的内存开销,例如用于机器翻译的Transformer、用于语言建模的BERT以及用于图像分类的AmoebaNet-D。SM3:1)适用于任意维度的张量以及参数的任何预定义覆盖;2)以自适应和数据驱动的方式调整学习率(类似于Adagrad与Adafactor不同);3)在随机凸优化环境中具有严格的收敛保证。

该优化器的初始化函数初始化一个内部状态 \(S_0 := \{\mu_0, w_1\} = \{0, 0\}\),表示累积平方梯度和权重的初始估计。这些值以包含全零的pytrees存储,形状与模型更新相同。在 步骤\(t\),该优化器的更新函数将传入的梯度\(g_t\)和优化器状态\(S_t\)作为参数,并计算更新\(u_t\)和新状态\(S_{t+1}\)。因此,对于 \(t > 0\),我们有:

SM3-I 算法

\[\begin{array}{l} \text{参数: 学习率 } \eta \\ \text{初始化 } w_1 = 0; \forall r \in [k]: \mu_0(r) = 0 \\ \text{对于 } t = 1, \ldots, T \text{ 执行} \\ \quad \text{接收梯度 } g_t = \nabla \ell_t(w_t) \\ \quad \text{对于 } r = 1, \ldots, k \text{ 执行} \\ \quad \quad \mu_t(r) \leftarrow \mu_{t-1}(r) + \max_{j \in S_r} g_t^2(j) \\ \quad \text{对于 } i = 1, \ldots, d \text{ 执行} \\ \quad \quad \nu_t(i) \leftarrow \min_{r:S_r \ni i} \mu_t(r) \\ \quad \quad w_{t+1}(i) \leftarrow w_t(i) - \eta \frac{g_t(i)}{\sqrt{\nu_t(i)}} \\ \quad \quad \text{根据约定 } 0/0 = 0 \end{array}\]

SM3-II 算法

SM3-II优化器使用学习率:math:eta和权重:math:w_1初始化参数。它通过使用梯度:math:g_t迭代地更新权重,调整每个组件的最小累积值:math:nu’_t(i),并保持子集:math:S_r的累积最大值:math:mu’_t(r)。SM3-II从初始状态:math:S_0 := (m_0, s_0)设置为零,存储与模型更新形状匹配的第一和第二矩的估计值。

\[\begin{array}{l} \text{参数:学习率 } \eta \\ \text{初始化 } w_1 = 0; \forall r \in [k]: \mu'_0(r) = 0 \\ \text{对于 } t = 1, \ldots, T \text{ 执行} \\ \quad \text{接收梯度 } g_t = \nabla \ell_t(w_t) \\ \quad \text{初始化 } \mu'_t(r) = 0 \text{ 对于所有 } r \in [k] \\ \quad \text{对于 } i = 1, \ldots, d \text{ 执行} \\ \quad \quad \nu'_t(i) \leftarrow \min_{r:S_r \ni i} \mu'_{t-1}(r) + g_t^2(i) \\ \quad \quad w_{t+1}(i) \leftarrow w_t(i) - \eta \frac{g_t(i)}{\sqrt{\nu'_t(i)}} \\ \quad \quad \text{按约定 } 0/0 = 0 \\ \quad \text{对于所有 } r : S_r \ni i \text{ 执行} \\ \quad \quad \mu'_t(r) \leftarrow \max\{\mu'_t(r), \nu'_t(i)\} \end{array}\]
Parameters:
  • learning_rate – 一个全局缩放因子,可以是固定的,也可以随着 迭代使用调度器而变化,见 optax.scale_by_learning_rate().

  • 动量 – 动量项使用的衰减率(当它未设置为 None 时,则完全不使用动量)。

Returns:

对应的 optax.GradientTransformation

示例

>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.sm3(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...  grad = jax.grad(f)(params)
...  updates, opt_state = solver.update(grad, opt_state, params)
...  params = optax.apply_updates(params, updates)
...  print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.40E+01
Objective function: 1.40E+01
Objective function: 1.40E+01
Objective function: 1.40E+01
Objective function: 1.40E+01

参考文献

Anil et al, Memory-Efficient Adaptive Optimization, 2019

瑜伽#

optax.yogi(learning_rate: base.ScalarOrSchedule, b1: float = 0.9, b2: float = 0.999, eps: float = 0.001) 基础.渐变变换[来源]#

瑜伽优化器。

Yogi 是一个自适应优化器,提供控制以调整有效学习率,防止其增加。通过这样做,它专注于解决基于指数移动平均的自适应方法(如 Adam 和 RMSprop)中的收敛性和泛化性问题。Yogi 是 Adam 的一种修改,使用相同的参数。

Parameters:
  • learning_rate – 一个全局缩放因子,可以是固定的,也可以随着 迭代使用调度器而变化,见 optax.scale_by_learning_rate().

  • b1 – 指数衰减率,用于跟踪过去梯度的第一个时刻。

  • b2 – 指数衰减率,用于跟踪过去梯度的第二矩。

  • eps – 在平方根外应用于分母的小常数(如在Adam论文中)以避免在重新缩放时除以零。

Returns:

对应的 optax.GradientTransformation

示例

>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.yogi(learning_rate=0.002)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...  grad = jax.grad(f)(params)
...  updates, opt_state = solver.update(grad, opt_state, params)
...  params = optax.apply_updates(params, updates)
...  print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.40E+01
Objective function: 1.40E+01
Objective function: 1.39E+01
Objective function: 1.39E+01
Objective function: 1.39E+01

参考文献

Zaheer et al, 自适应非凸优化方法, 2018