MNIST Classification with Stochastic Gradient Hamiltonian Monte Carlo (SGHMC)¶
In this notebook we show how to use Fortuna to obtain predictive uncertainty estimates from a simple neural network trained for an MNIST classification task, using the SGHMC method.
Download the MNIST data from TensorFlow¶
First, let us download the MNIST data from TensorFlow Datasets. Other sources would work just as well.
[1]:
import tensorflow as tf
import tensorflow_datasets as tfds
def download(split_range, shuffle=False):
    ds = tfds.load(
        name="mnist",
        split=f"train[{split_range}]",
        as_supervised=True,
        shuffle_files=True,
    ).map(lambda x, y: (tf.cast(x, tf.float32) / 255.0, y))
    if shuffle:
        ds = ds.shuffle(10, reshuffle_each_iteration=True)
    return ds.batch(128).prefetch(1)
train_data_loader, val_data_loader, test_data_loader = (
    download(":80%", shuffle=True),
    download("80%:90%"),
    download("90%:"),
)
Convert the data into a compatible data loader¶
Fortuna helps you convert your data and data loaders into a data loader that Fortuna can digest.
[2]:
from fortuna.data import DataLoader
train_data_loader = DataLoader.from_tensorflow_data_loader(train_data_loader)
val_data_loader = DataLoader.from_tensorflow_data_loader(val_data_loader)
test_data_loader = DataLoader.from_tensorflow_data_loader(test_data_loader)
Build a probabilistic classifier¶
Let us build a probabilistic classifier. This is an interface object containing several configurable attributes, namely model, prior, and posterior_approximator. In this example, we use a multilayer perceptron (MLP) and an SGHMC posterior approximator. SGHMC (and SGMCMC methods more broadly) allows configuring a step-size schedule function. For simplicity, we create a constant step-size schedule.
[3]:
import flax.linen as nn
from fortuna.prob_model import ProbClassifier, SGHMCPosteriorApproximator
from fortuna.model import MLP
output_dim = 10
prob_model = ProbClassifier(
    model=MLP(output_dim=output_dim, activations=(nn.tanh, nn.tanh)),
    posterior_approximator=SGHMCPosteriorApproximator(
        burnin_length=300, step_schedule=4e-6
    ),
)
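Above we passed a constant step size (step_schedule=4e-6). If your Fortuna version also accepts a callable schedule mapping the step count to a step size (check the SGHMCPosteriorApproximator signature; treated here as an assumption), a decaying schedule such as a half-cosine can be sketched in plain Python:

```python
import math

def cosine_step_schedule(init_step: float, total_steps: int):
    """Return a function mapping step count -> step size, decayed along a half cosine."""
    def schedule(step: int) -> float:
        t = min(step, total_steps) / total_steps  # progress in [0, 1]
        return 0.5 * init_step * (1.0 + math.cos(math.pi * t))
    return schedule

# Hypothetical usage: step_schedule=cosine_step_schedule(4e-6, total_steps=3000)
schedule = cosine_step_schedule(init_step=4e-6, total_steps=3000)
```

Decaying schedules are common for SGMCMC samplers, since smaller late-stage steps reduce discretization bias.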
Train the probabilistic model: posterior fitting and calibration¶
We can now train the probabilistic model. This includes fitting the posterior distribution and calibrating the probabilistic model. We configured a burn-in phase for the Markov chain (burnin_length above); after burn-in, samples are collected from the approximate posterior.
[4]:
from fortuna.prob_model import FitConfig, FitMonitor, FitOptimizer
from fortuna.metric.classification import accuracy
status = prob_model.train(
    train_data_loader=train_data_loader,
    val_data_loader=val_data_loader,
    calib_data_loader=val_data_loader,
    fit_config=FitConfig(
        monitor=FitMonitor(metrics=(accuracy,)),
        optimizer=FitOptimizer(n_epochs=30),
    ),
)
Epoch: 30 | loss: -27976.84961 | accuracy: 0.97656: 100%|██████████| 30/30 [01:46<00:00, 3.54s/it]
Epoch: 100 | loss: 1417.83374: 100%|██████████| 100/100 [00:49<00:00, 2.01it/s]
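Under the hood, SGHMC augments stochastic-gradient updates with a momentum variable, a friction term, and injected Gaussian noise, so that the iterates approximately sample from the posterior. A minimal, self-contained NumPy sketch on a one-dimensional standard Gaussian target (this is an illustration of the update rule, not Fortuna's implementation) looks like:

```python
import numpy as np

rng = np.random.default_rng(0)

def grad_U(theta):
    # Gradient of the negative log-density, U(theta) = theta**2 / 2,
    # whose target distribution is a standard Gaussian.
    return theta

step, friction, n_steps, burnin = 0.1, 0.1, 5000, 1000
theta, v = 0.0, 0.0
samples = []
for t in range(n_steps):
    # SGHMC momentum update: damp by friction, move along the negative
    # gradient, and refresh with Gaussian noise of matching scale.
    v = (1 - friction) * v - step * grad_U(theta) \
        + np.sqrt(2 * friction * step) * rng.standard_normal()
    theta = theta + v
    if t >= burnin:  # discard the burn-in phase, as in the training above
        samples.append(theta)
samples = np.array(samples)
```

After burn-in, the retained samples should roughly match the target's mean (0) and variance (1), up to discretization error.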
Estimate predictive statistics¶
We can now compute some predictive statistics by invoking the predictive attribute of the probabilistic classifier and the method of interest. Most predictive statistics, such as the mean or the mode, require a loader of input data points. You can easily obtain one by calling the method to_inputs_loader on the data loader.
[5]:
test_log_probs = prob_model.predictive.log_prob(data_loader=test_data_loader)
test_inputs_loader = test_data_loader.to_inputs_loader()
test_means = prob_model.predictive.mean(inputs_loader=test_inputs_loader)
test_modes = prob_model.predictive.mode(
    inputs_loader=test_inputs_loader, means=test_means
)
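Beyond the mean and mode, a common scalar summary of predictive uncertainty is the entropy of the class probabilities. Fortuna's predictive may expose such statistics directly (check the API); the same quantity can be computed from an array like test_means with NumPy, as this self-contained sketch shows:

```python
import numpy as np

def predictive_entropy(probs, eps=1e-12):
    # Shannon entropy of each row of class probabilities; larger = more uncertain.
    p = np.clip(probs, eps, 1.0)
    return -np.sum(p * np.log(p), axis=-1)

probs = np.array([
    [0.99] + [0.01 / 9] * 9,  # a confident prediction
    [0.1] * 10,               # maximally uncertain over 10 classes
])
entropies = predictive_entropy(probs)
```

The uniform row attains the maximum possible entropy over 10 classes, log(10).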
Compute metrics¶
In classification, the predictive mode is a prediction of the target label, while the predictive mean is a prediction of the probability of each label. We can therefore use these to compute several metrics, such as accuracy, the Brier score, the expected calibration error (ECE), and more.
[6]:
from fortuna.metric.classification import (
    accuracy,
    expected_calibration_error,
    brier_score,
)
test_targets = test_data_loader.to_array_targets()
acc = accuracy(preds=test_modes, targets=test_targets)
brier = brier_score(probs=test_means, targets=test_targets)
ece = expected_calibration_error(
    preds=test_modes,
    probs=test_means,
    targets=test_targets,
    plot=True,
    plot_options=dict(figsize=(10, 2)),
)
print(f"Test accuracy: {acc}")
print(f"Brier score: {brier}")
print(f"ECE: {ece}")
Test accuracy: 0.9116666316986084
Brier score: 0.12807506322860718
ECE: 0.01654691807925701
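To make the Brier score concrete, here is a self-contained NumPy sketch of the standard multiclass definition, the mean squared distance between predicted probabilities and one-hot targets (library implementations, including Fortuna's, may differ by a constant normalization factor):

```python
import numpy as np

def multiclass_brier(probs, targets):
    # Mean over samples of the squared distance between the predicted
    # probability vector and the one-hot encoding of the true label.
    onehot = np.eye(probs.shape[1])[targets]
    return np.mean(np.sum((probs - onehot) ** 2, axis=1))

probs = np.array([[1.0, 0.0], [0.5, 0.5]])
targets = np.array([0, 1])
score = multiclass_brier(probs, targets)
```

A perfect prediction contributes 0, while the maximally uncertain binary prediction contributes 0.5, giving a mean of 0.25 here.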
Conformal prediction sets¶
Fortuna allows generating conformal prediction sets, i.e. sets of plausible labels that achieve a desired coverage probability. These can be computed starting from probability estimates, obtained either with or without Fortuna.
[7]:
from fortuna.conformal import AdaptivePredictionConformalClassifier
val_means = prob_model.predictive.mean(inputs_loader=val_data_loader.to_inputs_loader())
conformal_sets = AdaptivePredictionConformalClassifier().conformal_set(
    val_probs=val_means,
    test_probs=test_means,
    val_targets=val_data_loader.to_array_targets(),
    error=0.05,
)
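AdaptivePredictionConformalClassifier is based on the adaptive prediction sets (APS) idea: calibrate a threshold on the probability mass needed to cover the true label on validation data, then include test labels in descending probability order until that mass is reached. A simplified NumPy sketch of this recipe (omitting the randomized tie-breaking that the full method may use) looks like:

```python
import numpy as np

def aps_conformal_sets(val_probs, val_targets, test_probs, error=0.05):
    # Conformity score of a validation point: probability mass accumulated,
    # in descending order, up to and including the true label.
    scores = []
    for p, y in zip(val_probs, val_targets):
        order = np.argsort(-p)
        cum = np.cumsum(p[order])
        scores.append(cum[np.flatnonzero(order == y)[0]])
    n = len(scores)
    # Finite-sample-corrected (1 - error) empirical quantile of the scores.
    k = min(n - 1, int(np.ceil((n + 1) * (1 - error))) - 1)
    q = np.sort(scores)[k]
    sets = []
    for p in test_probs:
        order = np.argsort(-p)
        cum = np.cumsum(p[order])
        size = np.searchsorted(cum, q) + 1  # smallest prefix with mass >= q
        sets.append(order[:min(size, len(p))])
    return sets

val_probs = np.tile([0.9, 0.05, 0.05], (20, 1))
val_targets = np.zeros(20, dtype=int)
test_probs = np.array([[0.8, 0.15, 0.05]])
sets = aps_conformal_sets(val_probs, val_targets, test_probs)
```

In the toy example, every validation score is 0.9, so the test point needs its two most probable labels (mass 0.95) to reach the threshold.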
We can check whether, on average, the conformal sets of misclassified inputs are larger than those of correctly classified ones.
[8]:
import numpy as np
avg_size = np.mean([len(s) for s in np.array(conformal_sets, dtype="object")])
avg_size_wellclassified = np.mean(
    [
        len(s)
        for s in np.array(conformal_sets, dtype="object")[test_modes == test_targets]
    ]
)
avg_size_misclassified = np.mean(
    [
        len(s)
        for s in np.array(conformal_sets, dtype="object")[test_modes != test_targets]
    ]
)
print(f"Average conformal set size: {avg_size}")
print(
    f"Average conformal set size over well classified input: {avg_size_wellclassified}"
)
print(f"Average conformal set size over misclassified input: {avg_size_misclassified}")
Average conformal set size: 9.911
Average conformal set size over well classified input: 9.956672760511884
Average conformal set size over misclassified input: 9.439622641509434
Furthermore, we visualize some examples with the largest and smallest conformal sets. Intuitively, these correspond to the inputs about whose predictions the model is the most and the least uncertain, respectively.
[9]:
from matplotlib import pyplot as plt
N_EXAMPLES = 10
images = test_data_loader.to_array_inputs()
def visualize_examples(indices, n_examples=N_EXAMPLES):
    n_rows = min(len(indices), n_examples)
    _, axs = plt.subplots(1, n_rows, figsize=(10, 2))
    axs = axs.flatten()
    for i, ax in enumerate(axs):
        ax.imshow(images[indices[i]], cmap="gray")
        ax.axis("off")
    plt.show()
[10]:
indices = np.argsort([len(s) for s in np.array(conformal_sets, dtype="object")])
[11]:
print("Examples with the smallest conformal sets:")
visualize_examples(indices[:N_EXAMPLES])
Examples with the smallest conformal sets:
[12]:
print("Examples with the largest conformal sets:")
visualize_examples(np.flip(indices[-N_EXAMPLES:]))
Examples with the largest conformal sets: