MNIST Classification with Stochastic Gradient Hamiltonian Monte Carlo (SGHMC)

In this notebook we show how to use Fortuna to obtain predictive uncertainty estimates from a simple neural network model trained for MNIST classification, using the SGHMC method.

Download MNIST data from TensorFlow

First, let us download the MNIST data from TensorFlow Datasets. Other sources would do just fine.

[1]:
import tensorflow as tf
import tensorflow_datasets as tfds


def download(split_range, shuffle=False):
    ds = tfds.load(
        name="mnist",
        split=f"train[{split_range}]",
        as_supervised=True,
        shuffle_files=True,
    ).map(lambda x, y: (tf.cast(x, tf.float32) / 255.0, y))
    if shuffle:
        ds = ds.shuffle(10, reshuffle_each_iteration=True)
    return ds.batch(128).prefetch(1)


train_data_loader, val_data_loader, test_data_loader = (
    download(":80%", shuffle=True),
    download("80%:90%"),
    download("90%:"),
)

Convert data to a compatible data loader

Fortuna helps you convert data and data loaders into a data loader that Fortuna can digest.

[2]:
from fortuna.data import DataLoader

train_data_loader = DataLoader.from_tensorflow_data_loader(train_data_loader)
val_data_loader = DataLoader.from_tensorflow_data_loader(val_data_loader)
test_data_loader = DataLoader.from_tensorflow_data_loader(test_data_loader)

Build a probabilistic classifier

Let us build a probabilistic classifier. This is an interface object containing several configurable attributes, namely model, prior, and posterior_approximator. In this example, we use a multilayer perceptron and an SGHMC posterior approximator. SGHMC (and SGMCMC methods more broadly) allow configuring a step-size schedule function. For simplicity, we create a constant step-size schedule.

[3]:
import flax.linen as nn

from fortuna.prob_model import ProbClassifier, SGHMCPosteriorApproximator
from fortuna.model import MLP

output_dim = 10
prob_model = ProbClassifier(
    model=MLP(output_dim=output_dim, activations=(nn.tanh, nn.tanh)),
    posterior_approximator=SGHMCPosteriorApproximator(
        burnin_length=300, step_schedule=4e-6
    ),
)
WARNING:root:No module named 'transformer' is installed. If you are not working with models from the `transformers` library ignore this warning, otherwise install the optional 'transformers' dependency of Fortuna using poetry. You can do so by entering: `poetry install --extras 'transformers'`.
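The constant step-size schedule above could be replaced by a decaying one. As a sketch — assuming, as a hypothesis about the installed Fortuna version, that step_schedule also accepts a callable mapping the step count to a step size — a half-cosine decay could be written as a plain Python function:

```python
import math


def cosine_step_schedule(init_step_size: float, total_steps: int):
    """Return a schedule that anneals the step size with a half-cosine.

    Hypothetical helper: whether `SGHMCPosteriorApproximator` accepts a
    callable for `step_schedule` depends on the installed Fortuna version.
    """

    def schedule(step: int) -> float:
        # Fraction of the way through training, clipped to [0, 1].
        frac = min(step, total_steps) / total_steps
        return init_step_size * 0.5 * (1.0 + math.cos(math.pi * frac))

    return schedule


# Starts at the full step size and decays smoothly towards zero.
schedule = cosine_step_schedule(4e-6, total_steps=1000)
```

Under this assumption, one would pass `step_schedule=cosine_step_schedule(4e-6, total_steps=...)` to SGHMCPosteriorApproximator instead of the constant `4e-6`.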

Train the probabilistic model: posterior fitting and calibration

We can now train the probabilistic model. This includes fitting the posterior distribution and calibrating the probabilistic model. Once the burn-in phase of the Markov chain is over, samples are drawn from the approximate posterior.

[4]:
from fortuna.prob_model import FitConfig, FitMonitor, FitOptimizer
from fortuna.metric.classification import accuracy

status = prob_model.train(
    train_data_loader=train_data_loader,
    val_data_loader=val_data_loader,
    calib_data_loader=val_data_loader,
    fit_config=FitConfig(
        monitor=FitMonitor(metrics=(accuracy,)),
        optimizer=FitOptimizer(n_epochs=30),
    ),
)
Epoch: 30 | loss: -27976.84961 | accuracy: 0.97656: 100%|██████████| 30/30 [01:46<00:00,  3.54s/it]
Epoch: 100 | loss: 1417.83374: 100%|██████████| 100/100 [00:49<00:00,  2.01it/s]

Estimate predictive statistics

We can now compute some predictive statistics by invoking the predictive attribute of the probabilistic classifier and the method of interest. Most predictive statistics, e.g. mean or mode, require a loader of input data points. You can easily get one by calling the method to_inputs_loader on the data loader.

[5]:
test_log_probs = prob_model.predictive.log_prob(data_loader=test_data_loader)
test_inputs_loader = test_data_loader.to_inputs_loader()
test_means = prob_model.predictive.mean(inputs_loader=test_inputs_loader)
test_modes = prob_model.predictive.mode(
    inputs_loader=test_inputs_loader, means=test_means
)

Compute metrics

In classification, the predictive mode is a prediction for the label, while the predictive mean is a prediction for the probability of each label. As such, we can use these to compute several metrics, e.g. the accuracy, the Brier score, the expected calibration error (ECE), etc.

[6]:
from fortuna.metric.classification import (
    accuracy,
    expected_calibration_error,
    brier_score,
)

test_targets = test_data_loader.to_array_targets()
acc = accuracy(preds=test_modes, targets=test_targets)
brier = brier_score(probs=test_means, targets=test_targets)
ece = expected_calibration_error(
    preds=test_modes,
    probs=test_means,
    targets=test_targets,
    plot=True,
    plot_options=dict(figsize=(10, 2)),
)
print(f"Test accuracy: {acc}")
print(f"Brier score: {brier}")
print(f"ECE: {ece}")
Test accuracy: 0.9116666316986084
Brier score: 0.12807506322860718
ECE: 0.01654691807925701
../_images/examples_mnist_classification_sghmc_13_1.svg
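For intuition, the Brier score reported above is the mean squared distance between the predicted probability vector and the one-hot encoded target. A minimal NumPy sketch of that definition (independent of Fortuna's exact implementation, which may differ in details such as normalization):

```python
import numpy as np


def brier_score_sketch(probs: np.ndarray, targets: np.ndarray) -> float:
    """Mean squared error between predicted probabilities and one-hot targets."""
    one_hot = np.eye(probs.shape[1])[targets]
    return float(np.mean(np.sum((probs - one_hot) ** 2, axis=1)))


# Toy example with 3 classes: the first prediction is confident and correct,
# the second is hesitant, so it contributes a larger squared error.
probs = np.array([[0.8, 0.1, 0.1], [0.2, 0.5, 0.3]])
targets = np.array([0, 2])
score = brier_score_sketch(probs, targets)  # ≈ 0.42
```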

Conformal prediction sets

Fortuna allows the construction of conformal prediction sets, that is sets of plausible labels grown until a desired coverage probability threshold is reached. These can be computed starting from probability estimates obtained with or without Fortuna.

[7]:
from fortuna.conformal import AdaptivePredictionConformalClassifier

val_means = prob_model.predictive.mean(inputs_loader=val_data_loader.to_inputs_loader())
conformal_sets = AdaptivePredictionConformalClassifier().conformal_set(
    val_probs=val_means,
    test_probs=test_means,
    val_targets=val_data_loader.to_array_targets(),
    error=0.05,
)

We can check that, on average, conformal sets for misclassified inputs are larger than those for well-classified ones.

[8]:
import numpy as np

avg_size = np.mean([len(s) for s in np.array(conformal_sets, dtype="object")])
avg_size_wellclassified = np.mean(
    [
        len(s)
        for s in np.array(conformal_sets, dtype="object")[test_modes == test_targets]
    ]
)
avg_size_misclassified = np.mean(
    [
        len(s)
        for s in np.array(conformal_sets, dtype="object")[test_modes != test_targets]
    ]
)
print(f"Average conformal set size: {avg_size}")
print(
    f"Average conformal set size over well classified input: {avg_size_wellclassified}"
)
print(f"Average conformal set size over misclassified input: {avg_size_misclassified}")
Average conformal set size: 9.911
Average conformal set size over well classified input: 9.956672760511884
Average conformal set size over misclassified input: 9.439622641509434
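To see where these set sizes come from, here is a hedged NumPy sketch of the adaptive prediction sets (APS) construction; Fortuna's AdaptivePredictionConformalClassifier may differ in details such as tie-breaking and randomization:

```python
import numpy as np


def adaptive_sets_sketch(val_probs, val_targets, test_probs, error=0.05):
    """Sketch of adaptive prediction sets (APS), not Fortuna's exact code."""
    # Conformity score: probability mass of all classes ranked at least as
    # high as the true class.
    order = np.argsort(-val_probs, axis=1)
    ranks = np.argsort(order, axis=1)  # rank of each class, per example
    cumsum = np.cumsum(np.take_along_axis(val_probs, order, axis=1), axis=1)
    n = len(val_targets)
    scores = cumsum[np.arange(n), ranks[np.arange(n), val_targets]]
    # Finite-sample-corrected (1 - error) quantile of the validation scores.
    level = min(1.0, np.ceil((n + 1) * (1 - error)) / n)
    q = np.quantile(scores, level, method="higher")
    sets = []
    for p in test_probs:
        o = np.argsort(-p)
        # Smallest prefix of top-ranked classes whose mass reaches q.
        k = int(np.searchsorted(np.cumsum(p[o]), q)) + 1
        sets.append(o[:k])
    return sets


# Toy example: a confident test prediction gets a small set,
# a uniform one gets the full label set.
val_probs = np.array(
    [[0.9, 0.05, 0.05], [0.8, 0.1, 0.1], [0.1, 0.8, 0.1], [0.2, 0.2, 0.6]]
)
val_targets = np.array([0, 0, 1, 2])
test_probs = np.array([[0.95, 0.03, 0.02], [1 / 3, 1 / 3, 1 / 3]])
sets = adaptive_sets_sketch(val_probs, val_targets, test_probs)
```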

Furthermore, we visualize some examples with the largest and smallest conformal sets. Intuitively, these correspond to inputs on which the model is most uncertain or most certain about its predictions, respectively.

[9]:
from matplotlib import pyplot as plt

N_EXAMPLES = 10
images = test_data_loader.to_array_inputs()


def visualize_examples(indices, n_examples=N_EXAMPLES):
    n_rows = min(len(indices), n_examples)
    _, axs = plt.subplots(1, n_rows, figsize=(10, 2))
    axs = axs.flatten()
    for i, ax in enumerate(axs):
        ax.imshow(images[indices[i]], cmap="gray")
        ax.axis("off")
    plt.show()
[10]:
indices = np.argsort([len(s) for s in np.array(conformal_sets, dtype="object")])
[11]:
print("Examples with the smallest conformal sets:")
visualize_examples(indices[:N_EXAMPLES])
Examples with the smallest conformal sets:
../_images/examples_mnist_classification_sghmc_21_1.svg
[12]:
print("Examples with the largest conformal sets:")
visualize_examples(np.flip(indices[-N_EXAMPLES:]))
Examples with the largest conformal sets:
../_images/examples_mnist_classification_sghmc_22_1.svg