Shortcuts

每个样本的梯度

Open In Colab

这是什么?

逐样本梯度计算是计算一批数据中每个样本的梯度。它在差分隐私、元学习和优化研究中是一个有用的量。

import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial

torch.manual_seed(0);
# Here's a simple CNN and loss function:

class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        output = x
        return output

def loss_fn(predictions, targets):
    return F.nll_loss(predictions, targets)

让我们生成一批虚拟数据,并假装我们正在处理一个MNIST数据集。

虚拟图像的大小为28x28,我们使用大小为64的小批量。

device = 'cuda'

num_models = 10
batch_size = 64
data = torch.randn(batch_size, 1, 28, 28, device=device)

targets = torch.randint(10, (64,), device=device)

在常规模型训练中,人们会通过模型前向传播小批量数据,然后调用.backward()来计算梯度。这将生成整个小批量的“平均”梯度:

model = SimpleCNN().to(device=device)
predictions = model(data) # move the entire mini-batch through the model

loss = loss_fn(predictions, targets)
loss.backward() # back propogate the 'average' gradient of this mini-batch

与上述方法相反,每个样本的梯度计算等同于:

  • 对于数据的每个单独样本,执行前向和后向传递以获取单个(每个样本的)梯度。

def compute_grad(sample, target):
    
    sample = sample.unsqueeze(0)  # prepend batch dimension for processing
    target = target.unsqueeze(0)

    prediction = model(sample)
    loss = loss_fn(prediction, target)

    return torch.autograd.grad(loss, list(model.parameters()))


def compute_sample_grads(data, targets):
    """ manually process each sample with per sample gradient """
    sample_grads = [compute_grad(data[i], targets[i]) for i in range(batch_size)]
    sample_grads = zip(*sample_grads)
    sample_grads = [torch.stack(shards) for shards in sample_grads]
    return sample_grads

per_sample_grads = compute_sample_grads(data, targets)

sample_grads[0] 是模型.conv1.weight的每个样本梯度。model.conv1.weight.shape[32, 1, 3, 3];注意在批次中每个样本有一个梯度,总共有64个。

print(per_sample_grads[0].shape)
torch.Size([64, 32, 1, 3, 3])

每个样本的梯度,高效的方法,使用 functorch

我们可以通过使用函数变换有效地计算每个样本的梯度。

首先,让我们使用functorch.make_functional_with_buffers创建一个无状态的功能版本的model

这将把状态(参数)与模型分离,并将模型转换为纯函数:

from functorch import make_functional_with_buffers, vmap, grad

fmodel, params, buffers = make_functional_with_buffers(model)

让我们回顾一下变化 - 首先,模型已经变成了无状态的 FunctionalModuleWithBuffers:

fmodel
FunctionalModuleWithBuffers(
  (stateless_model): SimpleCNN(
    (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
    (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
    (fc1): Linear(in_features=9216, out_features=128, bias=True)
    (fc2): Linear(in_features=128, out_features=10, bias=True)
  )
)

模型参数现在独立于模型存在,存储为一个元组:

for x in params:
  print(f"{x.shape}")

print(f"\n{type(params)}")
torch.Size([32, 1, 3, 3])
torch.Size([32])
torch.Size([64, 32, 3, 3])
torch.Size([64])
torch.Size([128, 9216])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])

<class 'tuple'>

接下来,让我们定义一个函数来计算模型在给定单个输入而不是一批输入时的损失。这个函数接受参数、输入和目标非常重要,因为我们将在它们上进行转换。

注意 - 因为模型最初是为了处理批次而编写的,我们将使用 torch.unsqueeze 来添加一个批次维度。

def compute_loss_stateless_model (params, buffers, sample, target):
    batch = sample.unsqueeze(0)
    targets = target.unsqueeze(0)

    predictions = fmodel(params, buffers, batch) 
    loss = loss_fn(predictions, targets)
    return loss

现在,让我们使用functorch的grad来创建一个新函数,该函数计算相对于compute_loss的第一个参数(即参数)的梯度。

ft_compute_grad = grad(compute_loss_stateless_model)

ft_compute_grad 函数计算单个(样本,目标)对的梯度。我们可以使用 vmap 来使其计算整个批次的样本和目标的梯度。注意 in_dims=(None, None, 0, 0) 因为我们希望将 ft_compute_grad 映射到数据和目标的第0维度,并为每个使用相同的参数和缓冲区。

ft_compute_sample_grad = vmap(ft_compute_grad, in_dims=(None, None, 0, 0))

最后,让我们使用转换后的函数来计算每个样本的梯度:

ft_per_sample_grads = ft_compute_sample_grad(params, buffers, data, targets)

# we can double check that the results using functorch grad and vmap match the results of hand processing each one individually:
for per_sample_grad, ft_per_sample_grad in zip(per_sample_grads, ft_per_sample_grads):
    assert torch.allclose(per_sample_grad, ft_per_sample_grad, atol=3e-3, rtol=1e-5)

快速说明:vmap 可以转换的函数类型存在一些限制。最适合转换的函数是纯函数:即输出仅由输入决定,并且没有副作用(例如突变)的函数。vmap 无法处理任意 Python 数据结构的突变,但它能够处理许多原地 PyTorch 操作。

性能比较

想了解vmap的性能如何比较吗?

目前,在较新的GPU(如A100(Ampere))上获得了最佳结果,在这个示例中我们看到了高达25倍的加速,但以下是一些在Colab中完成的结果:

def get_perf(first, first_descriptor, second, second_descriptor):
  """  takes torch.benchmark objects and compares delta of second vs first. """
  second_res = second.times[0]
  first_res = first.times[0]

  gain = (first_res-second_res)/first_res
  if gain < 0: gain *=-1 
  final_gain = gain*100

  print(f" Performance delta: {final_gain:.4f} percent improvement with {first_descriptor} ")
from torch.utils.benchmark import Timer

without_vmap = Timer( stmt="compute_sample_grads(data, targets)", globals=globals())
with_vmap = Timer(stmt="ft_compute_sample_grad(params, buffers, data, targets)",globals=globals())
no_vmap_timing = without_vmap.timeit(100)
with_vmap_timing = with_vmap.timeit(100)

print(f'Per-sample-grads without vmap {no_vmap_timing}')
print(f'Per-sample-grads with vmap {with_vmap_timing}')
Per-sample-grads without vmap <torch.utils.benchmark.utils.common.Measurement object at 0x7f71ac3f1850>
compute_sample_grads(data, targets)
  79.86 ms
  1 measurement, 100 runs , 1 thread
Per-sample-grads with vmap <torch.utils.benchmark.utils.common.Measurement object at 0x7f7143e26f10>
ft_compute_sample_grad(params, buffers, data, targets)
  12.93 ms
  1 measurement, 100 runs , 1 thread
get_perf(with_vmap_timing, "vmap", no_vmap_timing,"no vmap" )
 Performance delta: 517.5791 percent improvement with vmap 

在PyTorch中,还有其他优化的解决方案(如https://github.com/pytorch/opacus)用于计算每个样本的梯度,这些方法也比简单的方法表现更好。但很酷的是,组合vmapgrad可以给我们带来很好的加速效果。

一般来说,使用vmap进行向量化应该比在for循环中运行函数更快,并且与手动批处理具有竞争力。不过也有一些例外,比如如果我们还没有为特定操作实现vmap规则,或者底层内核没有针对旧硬件(GPU)进行优化。如果您发现任何这些情况,请在我们的GitHub上提交问题告知我们!