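HDC Learning
Having learned how to represent and manipulate information in hyperspace, we can now implement our first HDC classification model! We use the well-known MNIST dataset, which contains images of handwritten digits, as an example.
We start by importing Torchhd and the other required libraries, and by specifying the training parameters: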
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision.datasets import MNIST
# Note: this example requires the torchmetrics library: https://torchmetrics.readthedocs.io
import torchmetrics
# tqdm provides the progress bars used in the training and testing loops
from tqdm import tqdm
import torchhd
from torchhd.models import Centroid
from torchhd import embeddings
# Use the GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using {} device".format(device))
DIMENSIONS = 10000
IMG_SIZE = 28
NUM_LEVELS = 1000
BATCH_SIZE = 1  # for GPUs with enough memory we can process multiple images at once
Datasets
Next, we load the training and testing datasets:
transform = torchvision.transforms.ToTensor()
train_ds = MNIST("../data", train=True, transform=transform, download=True)
train_ld = torch.utils.data.DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
test_ds = MNIST("../data", train=False, transform=transform, download=True)
test_ld = torch.utils.data.DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)
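In addition to the various datasets available in the Torch ecosystem, such as MNIST, the torchhd.datasets module provides interfaces to several datasets commonly used in HDC. These interfaces inherit from PyTorch's dataset class, which ensures interoperability with other datasets.
Training
To perform the training, we first define an encoding. Besides specifying the basis-hypervector sets, the core part of learning is the encoding function. In the example below, we use random hypervectors to encode the position of each pixel and level hypervectors to encode its value: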
class Encoder(nn.Module):
    def __init__(self, out_features, size, levels):
        super(Encoder, self).__init__()
        self.flatten = torch.nn.Flatten()
        self.position = embeddings.Random(size * size, out_features)
        self.value = embeddings.Level(levels, out_features)

    def forward(self, x):
        x = self.flatten(x)
        sample_hv = torchhd.bind(self.position.weight, self.value(x))
        sample_hv = torchhd.multiset(sample_hv)
        return torchhd.hard_quantize(sample_hv)

encode = Encoder(DIMENSIONS, IMG_SIZE, NUM_LEVELS)
encode = encode.to(device)
num_classes = len(train_ds.classes)
model = Centroid(DIMENSIONS, num_classes)
model = model.to(device)
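To make the three steps of the encoder more concrete, the following dependency-free sketch mimics them with plain Python lists. It assumes bipolar (+1/-1) hypervectors in which binding is element-wise multiplication and bundling is element-wise addition; the helper names, the tiny 3x3 "image" with integer intensity levels, and the way the level hypervectors are built are all illustrative assumptions, not Torchhd's implementation.

```python
import random

def random_hv(dim, rng):
    """Random bipolar hypervector with entries in {-1, +1}."""
    return [rng.choice((-1, 1)) for _ in range(dim)]

def level_hvs(num_levels, dim, rng):
    """Hypervectors whose similarity decays with the distance between levels.

    One simple construction (an assumption for this sketch): start from a
    random hypervector and flip a fresh slice of shuffled positions per level,
    about dim // 2 flips in total, so the two extremes are nearly orthogonal.
    """
    order = list(range(dim))
    rng.shuffle(order)
    flips_per_step = (dim // 2) // (num_levels - 1)
    levels = [random_hv(dim, rng)]
    for step in range(1, num_levels):
        hv = list(levels[-1])
        for idx in order[(step - 1) * flips_per_step : step * flips_per_step]:
            hv[idx] = -hv[idx]
        levels.append(hv)
    return levels

def bind(a, b):
    """Element-wise multiplication: associates a position with a value."""
    return [x * y for x, y in zip(a, b)]

def multiset(hvs):
    """Element-wise sum: bundles all bound pixel hypervectors into one."""
    return [sum(column) for column in zip(*hvs)]

def hard_quantize(hv):
    """Map positive entries to +1 and all other entries to -1."""
    return [1 if x > 0 else -1 for x in hv]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

rng = random.Random(0)
DIM = 2000        # smaller than the tutorial's 10000 to keep the sketch fast
NUM_PIXELS = 9    # hypothetical 3x3 image
NUM_VALUES = 4    # hypothetical number of intensity levels

position = [random_hv(DIM, rng) for _ in range(NUM_PIXELS)]
value = level_hvs(NUM_VALUES, DIM, rng)

def encode(image):
    """Bind each pixel's position to its value, bundle, then quantize."""
    bound = [bind(position[i], value[v]) for i, v in enumerate(image)]
    return hard_quantize(multiset(bound))

image_a = [0, 1, 2, 3, 0, 1, 2, 3, 0]
image_b = [0, 1, 2, 3, 0, 1, 2, 3, 1]  # one pixel changed by one level
image_c = [3, 0, 3, 0, 2, 2, 1, 0, 3]  # every pixel changed

hv_a, hv_b, hv_c = encode(image_a), encode(image_b), encode(image_c)
sim_ab = dot(hv_a, hv_b) / DIM  # normalized dot similarity in [-1, 1]
sim_ac = dot(hv_a, hv_c) / DIM
print(f"similar images: {sim_ab:.3f}, dissimilar images: {sim_ac:.3f}")
```

Under these assumptions, images that differ in a single pixel map to hypervectors that are far more similar than images that differ everywhere, which is the property the classifier relies on.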
Having defined the model, we iterate over the training samples to create the class vectors:
with torch.no_grad():
    for samples, labels in tqdm(train_ld, desc="Training"):
        samples = samples.to(device)
        labels = labels.to(device)

        samples_hv = encode(samples)
        model.add(samples_hv, labels)
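Conceptually, building the class vectors amounts to summing the encoded samples of each class. The sketch below illustrates this accumulation with plain Python lists; the function name centroid_add, the zero-initialized class vectors, and the tiny hand-written batch are assumptions made for illustration rather than the library's code.

```python
def centroid_add(class_vectors, sample_hvs, labels):
    """Add each sample hypervector onto the class vector of its label."""
    for hv, label in zip(sample_hvs, labels):
        class_vector = class_vectors[label]
        for i, x in enumerate(hv):
            class_vector[i] += x

NUM_CLASSES = 3
DIM = 6  # tiny dimensionality so the result is easy to read

# One accumulator per class, assumed to start at zero.
class_vectors = [[0] * DIM for _ in range(NUM_CLASSES)]

# A hypothetical batch of three already-encoded bipolar samples.
batch = [
    [1, -1, 1, 1, -1, -1],
    [1, 1, 1, -1, -1, -1],
    [-1, -1, 1, 1, 1, -1],
]
labels = [0, 0, 2]

centroid_add(class_vectors, batch, labels)
for label, class_vector in enumerate(class_vectors):
    print(label, class_vector)
```

Positions where the samples of a class agree grow in magnitude, while positions where they disagree cancel out, so each class vector ends up summarizing what its samples have in common.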
Testing
With the model trained, we can classify the test samples by encoding them and comparing the result with the class vectors:
accuracy = torchmetrics.Accuracy("multiclass", num_classes=num_classes)

with torch.no_grad():
    model.normalize()

    for samples, labels in tqdm(test_ld, desc="Testing"):
        samples = samples.to(device)

        samples_hv = encode(samples)
        outputs = model(samples_hv, dot=True)
        accuracy.update(outputs.cpu(), labels)

print(f"Testing accuracy of {(accuracy.compute().item() * 100):.3f}%")
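The comparison step can be pictured as scaling every class vector to unit length, taking its dot product with the query hypervector, and picking the class with the highest score. The sketch below shows this under stated assumptions: the noisy prototypes, the uneven number of samples per class, and all helper names are hypothetical and only stand in for real encoded data.

```python
import math
import random

def normalize(vec):
    """Scale a vector to unit L2 norm (left unchanged if it is all zeros)."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec] if norm > 0 else list(vec)

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def classify(class_vectors, query_hv):
    """Return the index of the most similar class and all similarity scores."""
    scores = [dot(class_vector, query_hv) for class_vector in class_vectors]
    best = max(range(len(scores)), key=scores.__getitem__)
    return best, scores

rng = random.Random(1)
DIM = 1000

def noisy(hv, flip_prob=0.2):
    """Copy of hv with each entry flipped with probability flip_prob."""
    return [-x if rng.random() < flip_prob else x for x in hv]

# Hypothetical class prototypes that stand in for encoded digits.
prototypes = [[rng.choice((-1, 1)) for _ in range(DIM)] for _ in range(3)]

# Accumulate an uneven number of noisy samples per class.
samples_per_class = [2, 5, 50]
class_vectors = [[0] * DIM for _ in prototypes]
for label, count in enumerate(samples_per_class):
    for _ in range(count):
        for i, x in enumerate(noisy(prototypes[label])):
            class_vectors[label][i] += x

# Normalizing puts all class vectors on the same scale before comparison.
unit_class_vectors = [normalize(class_vector) for class_vector in class_vectors]

query = noisy(prototypes[0])  # an unseen sample of class 0
predicted, scores = classify(unit_class_vectors, query)
print(predicted, [round(score, 2) for score in scores])
```

Normalizing first means that a class seen many times during training does not obtain large dot products merely because its accumulated vector is longer.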