🚀 入门指南#

Open in Colab

Optax 是一个简单的优化库,用于 JAX。主要对象是 GradientTransformation,它可以与其他变换链接以获得最终的更新操作和优化器状态。Optax 还包含一些简单的损失函数和实用工具,以帮助您编写完整的优化步骤。本笔记本将带您了解如何使用 Optax 的几个示例。

示例:拟合线性模型#

首先导入必要的包:

import jax.numpy as jnp
import jax
import optax
import functools

在这个例子中,我们首先设置了一个简单的线性模型和一个损失函数。你可以使用任何其他库,比如haikuFlax来构建你的网络。在这里,我们保持简单并自己编写。损失函数(L2损失)来自Optax的losses,通过l2_loss

@functools.partial(jax.vmap, in_axes=(None, 0))
def network(params, x):
  return jnp.dot(params, x)

def compute_loss(params, x, y):
  y_pred = network(params, x)
  loss = jnp.mean(optax.l2_loss(y_pred, y))
  return loss

在这里,我们在已知的线性模型下生成数据(使用target_params=0.5):

key = jax.random.PRNGKey(42)
target_params = 0.5

# Generate some data.
xs = jax.random.normal(key, (16, 2))
ys = jnp.sum(xs * target_params, axis=-1)

Optax的基本用法#

Optax 包含了许多流行的优化器的实现,这些优化器可以非常简单地使用。例如,Adam 优化器的梯度变换可以在optax.adam中找到。现在,让我们从调用GradientTransformation对象开始,将其作为 Adam 的optimizer。然后,我们使用init函数和网络的params来初始化优化器状态。

start_learning_rate = 1e-1
optimizer = optax.adam(start_learning_rate)

# Initialize parameters of the model + optimizer.
params = jnp.array([0.0, 0.0])
opt_state = optimizer.init(params)

接下来我们编写更新循环。GradientTransformation 对象包含一个 update 函数,该函数接收当前的优化器状态和梯度,并返回需要应用于参数的 updatesupdates, new_opt_state = optimizer.update(grads, opt_state)

Optax 提供了一些简单的更新规则,这些规则将梯度变换的更新应用到当前参数上以返回新的参数:new_params = optax.apply_updates(params, updates)

# A simple update loop.
for _ in range(1000):
  grads = jax.grad(compute_loss)(params, xs, ys)
  updates, opt_state = optimizer.update(grads, opt_state)
  params = optax.apply_updates(params, updates)

assert jnp.allclose(params, target_params), \
'Optimization should retrieve the target params used to generate the data.'

自定义优化器#

Optax 通过chain梯度变换使得创建自定义优化器变得容易。例如,这创建了一个基于Adam的优化器。请注意,缩放是-learning_rate,这是一个重要的细节,因为apply_updates是加性的。

# Exponential decay of the learning rate.
scheduler = optax.exponential_decay(
    init_value=start_learning_rate,
    transition_steps=1000,
    decay_rate=0.99)

# Combining gradient transforms using `optax.chain`.
gradient_transform = optax.chain(
    optax.clip_by_global_norm(1.0),  # Clip by the gradient by the global norm.
    optax.scale_by_adam(),  # Use the updates from adam.
    optax.scale_by_schedule(scheduler),  # Use the learning rate from the scheduler.
    # Scale updates by -1 since optax.apply_updates is additive and we want to descend on the loss.
    optax.scale(-1.0)
)
# Initialize parameters of the model + optimizer.
params = jnp.array([0.0, 0.0])  # Recall target_params=0.5.
opt_state = gradient_transform.init(params)

# A simple update loop.
for _ in range(1000):
  grads = jax.grad(compute_loss)(params, xs, ys)
  updates, opt_state = gradient_transform.update(grads, opt_state)
  params = optax.apply_updates(params, updates)

assert jnp.allclose(params, target_params), \
'Optimization should retrieve the target params used to generate the data.'

Optax的高级用法#

在计划中修改优化器的超参数。#

在某些情况下,更改优化器的超参数(除了学习率之外)可能有助于确保训练的可靠性。我们可以通过使用inject_hyperparams轻松实现这一点。例如,这段代码随着训练的进行衰减clip_by_global_norm梯度变换的max_norm

decaying_global_norm_tx = optax.inject_hyperparams(optax.clip_by_global_norm)(
    max_norm=optax.linear_schedule(1.0, 0.0, transition_steps=99))

opt_state = decaying_global_norm_tx.init(None)
assert opt_state.hyperparams['max_norm'] == 1.0, 'Max norm should start at 1.0'

for _ in range(100):
  _, opt_state = decaying_global_norm_tx.update(None, opt_state)

assert opt_state.hyperparams['max_norm'] == 0.0, 'Max norm should end at 0.0'

示例:拟合一个MLP#

让我们使用Optax来拟合一个参数化函数。我们将考虑学习识别一个值是奇数还是偶数的问题。

我们将首先创建一个数据集,该数据集由一批随机8位整数(使用其二进制表示表示)组成,每个值使用1-hot编码标记为“奇数”或“偶数”(即[1, 0]表示奇数,[0, 1]表示偶数)。

import optax
import jax.numpy as jnp
import jax
import numpy as np

BATCH_SIZE = 5
NUM_TRAIN_STEPS = 1_000
RAW_TRAINING_DATA = np.random.randint(255, size=(NUM_TRAIN_STEPS, BATCH_SIZE, 1))

TRAINING_DATA = np.unpackbits(RAW_TRAINING_DATA.astype(np.uint8), axis=-1)
LABELS = jax.nn.one_hot(RAW_TRAINING_DATA % 2, 2).astype(jnp.float32).reshape(NUM_TRAIN_STEPS, BATCH_SIZE, 2)

我们现在可以使用JAX定义一个参数化函数。这将使我们能够高效地计算梯度。

有许多库提供了参数化函数的常见构建块(例如flax和haiku)。不过,在这种情况下,我们将从头开始实现我们的函数。

我们的函数将是一个具有单个隐藏层和单个输出层的1层MLP(多层感知器)。我们使用标准高斯\(\mathcal{N}(0,1)\)分布初始化所有参数。

initial_params = {
    'hidden': jax.random.normal(shape=[8, 32], key=jax.random.PRNGKey(0)),
    'output': jax.random.normal(shape=[32, 2], key=jax.random.PRNGKey(1)),
}


def net(x: jnp.ndarray, params: optax.Params) -> jnp.ndarray:
  x = jnp.dot(x, params['hidden'])
  x = jax.nn.relu(x)
  x = jnp.dot(x, params['output'])
  return x


def loss(params: optax.Params, batch: jnp.ndarray, labels: jnp.ndarray) -> jnp.ndarray:
  y_hat = net(batch, params)

  # optax also provides a number of common loss functions.
  loss_value = optax.sigmoid_binary_cross_entropy(y_hat, labels).sum(axis=-1)

  return loss_value.mean()

我们将使用optax.adam来计算每个优化步骤中参数的梯度更新。

请注意,由于Optax优化器是使用纯函数实现的,我们还需要跟踪优化器的状态。对于Adam优化器,这个状态将包含动量值。

def fit(params: optax.Params, optimizer: optax.GradientTransformation) -> optax.Params:
  opt_state = optimizer.init(params)

  @jax.jit
  def step(params, opt_state, batch, labels):
    loss_value, grads = jax.value_and_grad(loss)(params, batch, labels)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss_value

  for i, (batch, labels) in enumerate(zip(TRAINING_DATA, LABELS)):
    params, opt_state, loss_value = step(params, opt_state, batch, labels)
    if i % 100 == 0:
      print(f'step {i}, loss: {loss_value}')

  return params

# Finally, we can fit our parametrized function using the Adam optimizer
# provided by optax.
optimizer = optax.adam(learning_rate=1e-2)
params = fit(initial_params, optimizer)
step 0, loss: 15.16637897491455
step 100, loss: 0.371690034866333
step 200, loss: 0.10601001232862473
step 300, loss: 0.00941196084022522
step 400, loss: 0.004067585803568363
step 500, loss: 0.0021123955957591534
step 600, loss: 0.012820865027606487
step 700, loss: 0.003282144432887435
step 800, loss: 0.004763564560562372
step 900, loss: 0.0020072811748832464

我们看到我们的损失似乎已经收敛,这表明我们已经成功地为我们的网络找到了更好的参数。

权重衰减、调度和剪枝#

许多研究模型使用诸如学习率调度和梯度裁剪等技术。这些可以通过将梯度变换(如optax.adamoptax.clip)链接在一起来实现。

在下面,我们将使用带有权重衰减的Adam(optax.adamw),一个余弦学习率计划(带预热)以及梯度裁剪。

schedule = optax.warmup_cosine_decay_schedule(
  init_value=0.0,
  peak_value=1.0,
  warmup_steps=50,
  decay_steps=1_000,
  end_value=0.0,
)

optimizer = optax.chain(
  optax.clip(1.0),
  optax.adamw(learning_rate=schedule),
)

params = fit(initial_params, optimizer)
step 0, loss: 15.16637897491455
step 100, loss: 4.140027296678506e-12
step 200, loss: 0.0
step 300, loss: 0.0
step 400, loss: 0.0
step 500, loss: 4.852311432576579e-24
step 600, loss: 4.094312996277827e-12
step 700, loss: 0.0
step 800, loss: 0.0
step 900, loss: 0.0

组件#

我们参考文档以获取可用Optax组件的详细列表。在这里,我们重点介绍Optax提供的主要构建模块类别。

梯度变换 (transform.py)#

Optax 的关键构建模块之一是 GradientTransformation。每个转换由两个函数定义:

state = init(params)

grads, state = update(grads, state, params=None)

init 函数初始化一组(可能为空的)统计信息(也称为状态),而 update 函数根据一些统计信息(以及可选的参数当前值)转换候选梯度。

例如:

tx = optax.scale_by_rms()
state = tx.init(params)  # init stats
grads = jax.grad(loss)(params, TRAINING_DATA, LABELS)
updates, state = tx.update(grads, state, params)  # transform & update stats.

组合梯度变换 (combine.py)#

转换将候选梯度作为输入并返回处理后的梯度作为输出(与返回更新后的参数相反),这一事实对于允许将任意转换组合成自定义优化器/梯度处理器至关重要,并且还允许为操作在共享变量集上的不同梯度组合转换。

例如,chain 将它们按顺序组合,并返回一个新的 GradientTransformation,该转换按顺序应用多个转换。

例如:

max_norm = 100.
learning_rate = 1e-3

my_optimizer = optax.chain(
    optax.clip_by_global_norm(max_norm),
    optax.scale_by_adam(eps=1e-4),
    optax.scale(-learning_rate))

包装梯度变换 (wrappers.py)#

Optax 还提供了几个包装器,这些包装器以 GradientTransformation 作为输入,并返回一个新的 GradientTransformation,该转换以特定方式修改内部转换的行为。

例如,flatten包装器在应用内部的GradientTransformation之前,将梯度展平为一个大向量。转换后的更新在返回给用户之前会被重新展开。这可以用于减少对大量小变量进行多次计算的开销,但代价是增加内存使用。

例如:

my_optimizer = optax.flatten(optax.adam(learning_rate))

包装器的其他示例包括在多个步骤中累积梯度或仅对特定参数或在特定步骤应用内部转换。

时间表 (schedule.py)#

许多流行的转换使用时间依赖的组件,例如退火一些超参数(例如学习率)。为此,Optax 提供了可以用于根据step计数衰减标量的调度。

例如,您可以使用polynomial_schedule(带有power=1)来在多个步骤中线性衰减超参数:

schedule_fn = optax.polynomial_schedule(
    init_value=1., end_value=0., power=1, transition_steps=5)

for step_count in range(6):
  print(schedule_fn(step_count))  # [1., 0.8, 0.6, 0.4, 0.2, 0.]
1.0
0.8
0.6
0.39999998
0.19999999
0.0

计划可以与其他转换结合使用,如下所示。

schedule_fn = optax.polynomial_schedule(
    init_value=-learning_rate, end_value=0., power=1, transition_steps=5)
optimizer = optax.chain(
    optax.clip_by_global_norm(max_norm),
    optax.scale_by_adam(eps=1e-4),
    optax.scale_by_schedule(schedule_fn))

计划也可以用来代替GradientTransformationlearning_rate参数。

optimizer = optax.adam(learning_rate=schedule_fn)

应用更新 (update.py)#

在使用GradientTransformation或任何自定义操作对更新进行转换后,通常会将更新应用于一组参数。这可以通过jax.tree.map轻松完成。

为了方便,我们公开了一个apply_updates函数来对参数应用更新。该函数只是将更新和参数相加,即jax.tree.map(lambda p, u: p + u, params, updates)

updates, state = tx.update(grads, state, params)  # transform & update stats.
new_params = optax.apply_updates(params, updates)  # update the parameters.

请注意,将梯度变换与参数更新分离对于支持组合一系列变换(例如chain)以及将多个更新组合到相同参数(例如在多任务设置中,不同任务需要不同的梯度变换集)至关重要。

损失函数 (loss.py)#

Optax 提供了许多用于深度学习的标准损失函数,例如 l2_loss, softmax_cross_entropy, cosine_distance 等。

predictions = net(TRAINING_DATA, params)
loss = optax.huber_loss(predictions, LABELS)

损失函数接受批次作为输入,但它们不会在批次维度上进行缩减。这在JAX中很容易实现,例如:

avg_loss = jnp.mean(optax.huber_loss(predictions, LABELS))
sum_loss = jnp.sum(optax.huber_loss(predictions, LABELS))

二阶 (second_order.py)#

计算神经网络的Hessian矩阵或Fisher信息矩阵通常由于二次内存需求而难以处理。求解这些矩阵的对角线通常是一个更好的解决方案。该库提供了计算这些对角线的函数,且具有次二次内存需求。