使用Flax的简单神经网络。#
本笔记本使用Optax和Flax训练一个简单的单层神经网络。对于这两个库的更高级应用,我们建议查看cifar10_resnet示例。
import jax
import jax.numpy as jnp
import optax
import matplotlib.pyplot as plt
from flax import linen as nn
# @markdown Learning rate for the optimizer:
LEARNING_RATE = 1e-2 # @param{type:"number"}
# @markdown Number of training steps:
NUM_STEPS = 100 # @param{type:"integer"}
# @markdown Number of samples in the training dataset:
NUM_SAMPLES = 20 # @param{type:"integer"}
# @markdown Shape of the input:
X_DIM = 10 # @param{type:"integer"}
# @markdown Shape of the target:
Y_DIM = 5 # @param{type:"integer"}
在这个单元格中,我们初始化了一个随机数生成器(RNG),并使用它为所有与随机性相关的事情创建了独立的RNG。
rng = jax.random.PRNGKey(0)
params_rng, w_rng, b_rng, samples_rng, noise_rng = jax.random.split(rng, num=5)
在下一个单元格中,我们定义一个模型并获取其初始参数。
# Creates a one linear layer instance.
model = nn.Dense(features=Y_DIM)
# Initializes the parameters.
params = model.init(params_rng, jnp.ones((X_DIM,), dtype=jnp.float32))
在下一个单元格中,我们生成我们的训练数据。
我们将近似一个形式为 y = wx + b 的函数,因此我们生成 w、b、训练样本 x 并使用上述公式获得 y。
# Generates ground truth w and b.
w = jax.random.normal(w_rng, (X_DIM, Y_DIM))
b = jax.random.normal(b_rng, (Y_DIM,))
# Generates training samples.
x_samples = jax.random.normal(samples_rng, (NUM_SAMPLES, X_DIM))
y_samples = jnp.dot(x_samples, w) + b
# Adds noise to the target.
y_samples += 0.1 * jax.random.normal(noise_rng, (NUM_SAMPLES, Y_DIM))
接下来我们定义一个自定义的MSE损失函数。
def make_mse_func(x_batched, y_batched):
def mse(params):
# Defines the squared loss for a single (x, y) pair.
def squared_error(x, y):
pred = model.apply(params, x)
return jnp.inner(y-pred, y-pred) / 2.0
# Vectorizes the squared error and computes mean over the loss values.
return jnp.mean(jax.vmap(squared_error)(x_batched, y_batched), axis=0)
return jax.jit(mse) # `jit`s the result.
# Instantiates the sampled loss.
loss = make_mse_func(x_samples, y_samples)
# Creates a function that returns value and gradient of the loss.
loss_grad_fn = jax.value_and_grad(loss)
在下一个单元格中,我们使用传递给optax.chain的Optax梯度变换构建了一个简单的Adam优化器。
同样的结果可以通过使用optax.adam别名来实现。然而,在这里,我们演示了如何手动处理梯度变换,以便在需要时构建自己的自定义优化器。
tx = optax.chain(
# Sets the parameters of Adam. Note the learning_rate is not here.
optax.scale_by_adam(b1=0.9, b2=0.999, eps=1e-8),
# Puts a minus sign to *minimize* the loss.
optax.scale(-LEARNING_RATE)
)
然后我们将模型的初始参数传递给优化器以进行初始化。
opt_state = tx.init(params)
最后,我们训练模型 NUM_STEPS 步。
loss_history = []
# Minimizes the loss.
for _ in range(NUM_STEPS):
# Computes gradient of the loss.
loss_val, grads = loss_grad_fn(params)
loss_history.append(loss_val)
# Updates the optimizer state, creates an update to the params.
updates, opt_state = tx.update(grads, opt_state)
# Updates the parameters.
params = optax.apply_updates(params, updates)
plt.plot(loss_history)
plt.title('Train loss')
plt.xlabel('Step')
plt.ylabel('MSE')
plt.show()