Shortcuts

torch.func.vmap

torch.func.vmap(func, in_dims=0, out_dims=0, randomness='error', *, chunk_size=None)

vmap 是向量化映射;vmap(func) 返回一个新函数,该函数将 func 映射到输入的某个维度上。从语义上讲,vmap 将映射推入由 func 调用的 PyTorch 操作中,有效地向量化了这些操作。

vmap 对于处理批量维度非常有用:可以编写一个函数 func,该函数在示例上运行,然后将其提升为一个可以处理批量示例的函数,使用 vmap(func)。vmap 还可以与 autograd 组合使用来计算批量梯度。

注意

torch.vmap() 为了方便起见,被别名为 torch.func.vmap()。使用你喜欢的那个。

Parameters
  • func (函数) – 一个接受一个或多个参数的Python函数。 必须返回一个或多个张量。

  • in_dims (int嵌套结构) – 指定输入的哪个维度应该被映射。in_dims 应该具有与输入类似的结构。如果某个输入的 in_dim 为 None,则表示该输入没有映射维度。 默认值:0。

  • out_dims (intTuple[int]) – 指定映射维度应出现在输出中的位置。如果 out_dims 是一个元组,则它应为每个输出包含一个元素。默认值:0。

  • 随机性 (str) – 指定此 vmap 中的随机性在批次之间是相同还是不同。如果为 ‘different’,则每个批次的随机性将不同。如果为 ‘same’,则随机性将在批次之间相同。如果为 ‘error’,任何对随机函数的调用都将出错。默认值:‘error’。警告:此标志仅适用于随机的 PyTorch 操作,不适用于 Python 的 random 模块或 numpy 的随机性。

  • chunk_size (Noneint) – 如果为 None(默认),则在输入上应用单个 vmap。 如果不为 None,则一次计算 vmap chunk_size 个样本。 注意,chunk_size=1 等同于使用 for 循环计算 vmap。 如果在计算 vmap 时遇到内存问题,请尝试使用非 None 的 chunk_size。

Returns

返回一个新的“批处理”函数。它接受与 func相同的输入,除了每个输入在由 in_dims指定的索引处有一个额外的维度。它返回与 func相同的输出,除了每个输出在由 out_dims指定的索引处有一个额外的维度。

Return type

可调用

使用vmap()的一个例子是计算批量的点积。PyTorch 没有提供批量的torch.dot API;与其在文档中徒劳地寻找,不如使用vmap()来构建一个新的函数。

>>> torch.dot                            # [D], [D] -> []
>>> batched_dot = torch.func.vmap(torch.dot)  # [N, D], [N, D] -> [N]
>>> x, y = torch.randn(2, 5), torch.randn(2, 5)
>>> batched_dot(x, y)

vmap() 可以帮助隐藏批次维度,从而简化模型编写体验。

>>> batch_size, feature_size = 3, 5
>>> weights = torch.randn(feature_size, requires_grad=True)
>>>
>>> def model(feature_vec):
>>>     # 非常简单的线性模型,带有激活函数
>>>     return feature_vec.dot(weights).relu()
>>>
>>> examples = torch.randn(batch_size, feature_size)
>>> result = torch.vmap(model)(examples)

vmap() 也可以帮助向量化那些以前难以或不可能批量处理的计算。一个例子是高阶梯度计算。 PyTorch 的自动求导引擎计算 vjps(向量-雅可比积)。 对于某个函数 f: R^N -> R^N,通常需要 N 次调用 autograd.grad 来计算完整的雅可比矩阵,每次调用对应雅可比矩阵的一行。使用 vmap(), 我们可以向量化整个计算过程,通过一次调用 autograd.grad 来计算雅可比矩阵。

>>> # 设置
>>> N = 5
>>> f = lambda x: x ** 2
>>> x = torch.randn(N, requires_grad=True)
>>> y = f(x)
>>> I_N = torch.eye(N)
>>>
>>> # 顺序方法
>>> jacobian_rows = [torch.autograd.grad(y, x, v, retain_graph=True)[0]
>>>                  for v in I_N.unbind()]
>>> jacobian = torch.stack(jacobian_rows)
>>>
>>> # 向量化梯度计算
>>> def get_vjp(v):
>>>     return torch.autograd.grad(y, x, v)
>>> jacobian = torch.vmap(get_vjp)(I_N)

vmap() 也可以嵌套使用,生成具有多个批处理维度的输出

>>> torch.dot                            # [D], [D] -> []
>>> batched_dot = torch.vmap(torch.vmap(torch.dot))  # [N1, N0, D], [N1, N0, D] -> [N1, N0]
>>> x, y = torch.randn(2, 3, 5), torch.randn(2, 3, 5)
>>> batched_dot(x, y) # 大小为 [2, 3] 的张量

如果输入没有沿第一个维度进行批处理,in_dims 指定每个输入沿哪个维度进行批处理,如下所示:

>>> torch.dot                            # [N], [N] -> []
>>> batched_dot = torch.vmap(torch.dot, in_dims=1)  # [N, D], [N, D] -> [D]
>>> x, y = torch.randn(2, 5), torch.randn(2, 5)
>>> batched_dot(x, y)   # 输出是 [5] 而不是 [2] 如果沿第0维度批处理

如果有多个输入,每个输入在不同的维度上进行批处理, in_dims 必须是一个元组,其中包含每个输入的批处理维度

>>> torch.dot                            # [D], [D] -> []
>>> batched_dot = torch.vmap(torch.dot, in_dims=(0, None))  # [N, D], [D] -> [N]
>>> x, y = torch.randn(2, 5), torch.randn(5)
>>> batched_dot(x, y) # 第二个参数没有批次维度,因为 in_dim[1] 是 None

如果输入是一个Python结构,in_dims 必须是一个包含与输入形状匹配的结构的元组:

>>> f = lambda dict: torch.dot(dict['x'], dict['y'])
>>> x, y = torch.randn(2, 5), torch.randn(5)
>>> input = {'x': x, 'y': y}
>>> batched_dot = torch.vmap(f, in_dims=({'x': 0, 'y': None},))
>>> batched_dot(input)

默认情况下,输出是沿第一个维度进行批处理的。然而,可以通过使用out_dims沿任何维度进行批处理。

>>> f = lambda x: x ** 2
>>> x = torch.randn(2, 5)
>>> batched_pow = torch.vmap(f, out_dims=1)
>>> batched_pow(x) # [5, 2]

对于使用kwargs的任何函数,返回的函数不会批量处理kwargs,但会接受kwargs

>>> x = torch.randn([2, 5])
>>> def fn(x, scale=4.):
>>>   return x * scale
>>>
>>> batched_pow = torch.vmap(fn)
>>> assert torch.allclose(batched_pow(x), x * 4)
>>> batched_pow(x, scale=x) # scale 没有批处理,输出形状为 [2, 2, 5]

注意

vmap 不提供通用的自动批处理功能,也无法直接处理可变长度的序列。