Optimistic Gradient Descent in a Bilinear Min-Max Problem


import functools
import jax
import optax
import matplotlib.pyplot as plt
from jax import lax, numpy as jnp

Consider the following min-max problem:

\[ \min_{x \in \mathbb R^m} \max_{y\in\mathbb R^n} f(x,y), \]

where \(f: \mathbb R^m \times \mathbb R^n \to \mathbb R\) is a convex-concave function. The solution to such a problem is a saddle point \((x^\star, y^\star)\in \mathbb R^m \times \mathbb R^n\) such that

\[ f(x^\star, y) \leq f(x^\star, y^\star) \leq f(x, y^\star). \]
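For instance, the bilinear objective \(f(x, y) = x y\) used later in this notebook has the saddle point \((x^\star, y^\star) = (0, 0)\), since both inequalities hold (here with equality):

```latex
% Saddle-point inequalities for f(x, y) = x y at (0, 0):
f(0, y) = 0 \;\leq\; f(0, 0) = 0 \;\leq\; f(x, 0) = 0
\qquad \text{for all } x \in \mathbb{R},\; y \in \mathbb{R}.
```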

Standard gradient descent-ascent (GDA) updates \(x\) and \(y\) at step \(k\) according to the following rule:

\[ x_{k+1} = x_k - \eta_k \nabla_x f(x_k, y_k) \\ y_{k+1} = y_k + \eta_k \nabla_y f(x_k, y_k), \]

where \(\eta_k\) is the step size. However, it is well known that GDA may fail to converge in this setting. This is an important problem, as gradient-based min-max optimization is increasingly prevalent in machine learning (e.g., GANs, constrained RL). Optimistic GDA (OGDA) addresses this shortcoming by introducing a form of memory-based negative momentum:

\[ x_{k+1} = x_k - 2 \eta_k \nabla_x f(x_k, y_k) + \eta_k \nabla_x f(x_{k-1}, y_{k-1}) \\ y_{k+1} = y_k + 2 \eta_k \nabla_y f(x_k, y_k) - \eta_k \nabla_y f(x_{k-1}, y_{k-1}). \]
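To see the difference between the two rules concretely, here is a plain-Python sketch that runs both on the bilinear objective \(f(x, y) = x y\) (the step size \(\eta = 0.1\), the starting point, and the iteration count are illustrative choices, not part of the original):

```python
import math

eta = 0.1  # illustrative constant step size


def gda_step(x, y):
    # For f(x, y) = x * y: grad_x f = y, grad_y f = x.
    return x - eta * y, y + eta * x


def ogda_step(x, y, gx_prev, gy_prev):
    gx, gy = y, x  # current gradients of f(x, y) = x * y
    x_new = x - 2 * eta * gx + eta * gx_prev
    y_new = y + 2 * eta * gy - eta * gy_prev
    return x_new, y_new, gx, gy


# Plain GDA spirals away from the saddle point (0, 0): each step
# multiplies the norm of (x, y) by sqrt(1 + eta**2) > 1.
x, y = 1.0, 1.0
for _ in range(500):
    x, y = gda_step(x, y)
gda_dist = math.hypot(x, y)

# OGDA contracts towards (0, 0). The gradient memory is initialized
# with the gradient at the starting point.
x, y = 1.0, 1.0
gx_prev, gy_prev = y, x
for _ in range(500):
    x, y, gx_prev, gy_prev = ogda_step(x, y, gx_prev, gy_prev)
ogda_dist = math.hypot(x, y)

print(gda_dist)   # grows far beyond the starting distance sqrt(2)
print(ogda_dist)  # shrinks well below the starting distance
```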

Thus, to implement OGD (or OGA), the optimizer needs to keep track of the gradient from the previous step. OGDA has been formally shown to converge to the optimal solution in this setting, \((x_k, y_k) \to (x^\star, y^\star)\). A generalized form of the OGDA update rule is given by

\[ x_{k+1} = x_k - (\alpha + \beta) \eta_k \nabla_x f(x_k, y_k) + \beta \eta_k \nabla_x f(x_{k-1}, y_{k-1}) \\ y_{k+1} = y_k + (\alpha + \beta) \eta_k \nabla_y f(x_k, y_k) - \beta \eta_k \nabla_y f(x_{k-1}, y_{k-1}), \]

\(\alpha=\beta=1\)时,恢复标准OGDA。更多详情请参见Mokhtari et al., 2019

We now walk through an example. First, we define our function \(f\):

def f(params):
    # Bilinear objective f(x, y) = x * y; its unique saddle point is (0, 0).
    x, y = params
    return x * y

Next, we define our helper functions:

def update(optimizer, state, _):
    params, opt_state = state
    grads = jax.grad(f)(params)
    # Negate the y-gradient: we descend in x but ascend in y.
    grads = grads.at[1].apply(jnp.negative)
    updates, new_opt_state = optimizer.update(grads, opt_state, params)
    new_params = optax.apply_updates(params, updates)
    return (new_params, new_opt_state), params

def optimize(optimizer, params, iters):
    opt_state = optimizer.init(params)
    # lax.scan runs `update` for `iters` steps and stacks the visited parameters.
    _, params_hist = lax.scan(functools.partial(update, optimizer), (params, opt_state), length=iters)
    return params_hist

Third, we run the optimizers and plot the results.

_, ax_params = plt.subplots()
_, ax_distances = plt.subplots()

params = jnp.array([1.0, 1.0])
ax_params.scatter(*params, label="start", color="black")

for label, optimizer in [
    ("SGD", optax.sgd(0.1)),
    ("Optimistic GD", optax.optimistic_gradient_descent(0.1)),
    ("Adam", optax.adam(0.05, nesterov=True)),
    ("Optimistic Adam", optax.optimistic_adam(0.05, 0.5, nesterov=True)),
]:
    params_hist = optimize(optimizer, params, 10**4)
    distances_to_origin = jnp.hypot(*params_hist.T)
    ax_params.plot(*params_hist.T, label=label, lw=1)
    ax_distances.plot(distances_to_origin, label=label, lw=1)

ax_params.legend()
ax_distances.legend()
ax_params.set(title="parameters", aspect="equal", xlim=(-3, 3), ylim=(-3, 3))
ax_distances.set(xlabel="iteration", ylabel="distance to origin", yscale="log")
plt.show()
[Output figures: parameter trajectories in the (x, y) plane, and distance to the origin per iteration on a log scale.]