Simple neural network with Flax

This notebook trains a simple one-layer neural network using Optax and Flax. For more advanced applications of these two libraries, we recommend looking at the cifar10_resnet example.

import jax
import jax.numpy as jnp
import optax
import matplotlib.pyplot as plt
from flax import linen as nn
# @markdown Learning rate for the optimizer:
LEARNING_RATE = 1e-2  # @param{type:"number"}
# @markdown Number of training steps:
NUM_STEPS = 100  # @param{type:"integer"}
# @markdown Number of samples in the training dataset:
NUM_SAMPLES = 20  # @param{type:"integer"}
# @markdown Shape of the input:
X_DIM = 10  # @param{type:"integer"}
# @markdown Shape of the target:
Y_DIM = 5  # @param{type:"integer"}

In this cell, we initialize a random number generator (RNG) and use it to create independent RNGs for everything related to randomness.

rng = jax.random.PRNGKey(0)
params_rng, w_rng, b_rng, samples_rng, noise_rng = jax.random.split(rng, num=5)

In the next cell, we define a model and obtain its initial parameters.

# Creates a one linear layer instance.
model = nn.Dense(features=Y_DIM)

# Initializes the parameters.
params = model.init(params_rng, jnp.ones((X_DIM,), dtype=jnp.float32))

In the next cell, we generate our training data.

We will approximate a function of the form y = wx + b, so we generate w and b, generate training samples x, and obtain y using the formula above.

# Generates ground truth w and b.
w = jax.random.normal(w_rng, (X_DIM, Y_DIM))
b = jax.random.normal(b_rng, (Y_DIM,))

# Generates training samples.
x_samples = jax.random.normal(samples_rng, (NUM_SAMPLES, X_DIM))
y_samples = jnp.dot(x_samples, w) + b
# Adds noise to the target.
y_samples += 0.1 * jax.random.normal(noise_rng, (NUM_SAMPLES, Y_DIM))

Next, we define a custom MSE loss function.

def make_mse_func(x_batched, y_batched):
  def mse(params):
    # Defines the squared loss for a single (x, y) pair.
    def squared_error(x, y):
      pred = model.apply(params, x)
      return jnp.inner(y-pred, y-pred) / 2.0
    # Vectorizes the squared error and computes mean over the loss values.
    return jnp.mean(jax.vmap(squared_error)(x_batched, y_batched), axis=0)
  return jax.jit(mse)  # `jit`s the result.

# Instantiates the sampled loss.
loss = make_mse_func(x_samples, y_samples)

# Creates a function that returns value and gradient of the loss.
loss_grad_fn = jax.value_and_grad(loss)
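As a standalone illustration (using a toy function rather than the notebook's loss), `jax.value_and_grad` returns both the value and the gradient in a single call:

```python
import jax
import jax.numpy as jnp

# A toy quadratic: f(w) = sum(w^2) / 2, whose gradient is w itself.
def f(w):
  return jnp.sum(w ** 2) / 2.0

value_and_grad_f = jax.value_and_grad(f)
w = jnp.array([1.0, 2.0, 3.0])
val, grad = value_and_grad_f(w)  # val = 7.0, grad = [1., 2., 3.]
```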

In the next cell, we build a simple Adam optimizer out of Optax gradient transformations passed to optax.chain.

The same result could be achieved with the optax.adam alias. Here, however, we demonstrate how to handle the gradient transformations manually, so that you can build your own custom optimizer when needed.

tx = optax.chain(
    # Sets the parameters of Adam. Note the learning_rate is not here.
    optax.scale_by_adam(b1=0.9, b2=0.999, eps=1e-8),
    # Puts a minus sign to *minimize* the loss.
    optax.scale(-LEARNING_RATE)
)

We then pass the model's initial parameters to the optimizer to initialize its state.

opt_state = tx.init(params)

Finally, we train the model for NUM_STEPS steps.

loss_history = []

# Minimizes the loss.
for _ in range(NUM_STEPS):
  # Computes gradient of the loss.
  loss_val, grads = loss_grad_fn(params)
  loss_history.append(loss_val)
  # Updates the optimizer state, creates an update to the params.
  updates, opt_state = tx.update(grads, opt_state)
  # Updates the parameters.
  params = optax.apply_updates(params, updates)
plt.plot(loss_history)
plt.title('Train loss')
plt.xlabel('Step')
plt.ylabel('MSE')
plt.show()