functorch.vmap¶
-
functorch.vmap(func, in_dims=0, out_dims=0, randomness='error', *, chunk_size=None)[source]¶ vmap 是向量化映射;
vmap(func)返回一个新函数,该函数将func映射到输入的某个维度上。从语义上讲,vmap 将映射推入由func调用的 PyTorch 操作中,从而有效地向量化这些操作。vmap 对于处理批量维度非常有用:可以编写一个在示例上运行的函数
func,然后将其提升为可以使用vmap(func)处理批量示例的函数。vmap 还可以与 autograd 结合使用来计算批量梯度。注意
torch.vmap()是torch.func.vmap()的别名,为了方便使用。你可以随意使用任何一个。- Parameters
func (function) – 一个接受一个或多个参数的Python函数。 必须返回一个或多个张量。
in_dims (int 或 嵌套结构) – 指定输入的哪个维度应该被映射。
in_dims应该具有与输入类似的结构。如果某个特定输入的in_dim为 None,则表示没有映射维度。 默认值:0。out_dims (int 或 Tuple[int]) – 指定映射维度在输出中出现的位置。如果
out_dims是一个元组,那么它应该为每个输出包含一个元素。默认值:0。随机性 (str) – 指定此vmap中的随机性在批次之间是相同还是不同。如果为‘different’,则每个批次的随机性将不同。如果为‘same’,则批次之间的随机性将相同。如果为‘error’,则任何对随机函数的调用都将出错。默认值:‘error’。警告:此标志仅适用于随机PyTorch操作,不适用于Python的随机模块或numpy的随机性。
chunk_size (None 或 int) – 如果为 None(默认),则对输入应用单个 vmap。 如果不为 None,则每次计算 vmap 的
chunk_size个样本。 请注意,chunk_size=1等同于使用 for 循环计算 vmap。 如果在计算 vmap 时遇到内存问题,请尝试使用非 None 的 chunk_size。
- Returns
返回一个新的“批处理”函数。它接受与
func相同的输入,除了每个输入在in_dims指定的索引处有一个额外的维度。它返回与func相同的输出,除了每个输出在out_dims指定的索引处有一个额外的维度。
使用
vmap()的一个例子是计算批量点积。PyTorch 没有提供批量的torch.dotAPI;与其在文档中徒劳地搜索,不如使用vmap()来构建一个新函数。>>> torch.dot # [D], [D] -> [] >>> batched_dot = torch.func.vmap(torch.dot) # [N, D], [N, D] -> [N] >>> x, y = torch.randn(2, 5), torch.randn(2, 5) >>> batched_dot(x, y)
vmap()可以帮助隐藏批次维度,从而简化模型编写体验。>>> batch_size, feature_size = 3, 5 >>> weights = torch.randn(feature_size, requires_grad=True) >>> >>> def model(feature_vec): >>> # Very simple linear model with activation >>> return feature_vec.dot(weights).relu() >>> >>> examples = torch.randn(batch_size, feature_size) >>> result = torch.vmap(model)(examples)
vmap()还可以帮助向量化以前难以或无法批处理的计算。一个例子是高阶梯度计算。PyTorch 的自动求导引擎计算 vjps(向量-雅可比积)。对于某些函数 f: R^N -> R^N,计算完整的雅可比矩阵通常需要对autograd.grad进行 N 次调用,每次调用对应雅可比矩阵的一行。使用vmap(),我们可以将整个计算向量化,只需一次调用autograd.grad即可计算雅可比矩阵。>>> # Setup >>> N = 5 >>> f = lambda x: x ** 2 >>> x = torch.randn(N, requires_grad=True) >>> y = f(x) >>> I_N = torch.eye(N) >>> >>> # Sequential approach >>> jacobian_rows = [torch.autograd.grad(y, x, v, retain_graph=True)[0] >>> for v in I_N.unbind()] >>> jacobian = torch.stack(jacobian_rows) >>> >>> # vectorized gradient computation >>> def get_vjp(v): >>> return torch.autograd.grad(y, x, v) >>> jacobian = torch.vmap(get_vjp)(I_N)
vmap()也可以嵌套,生成具有多个批处理维度的输出>>> torch.dot # [D], [D] -> [] >>> batched_dot = torch.vmap(torch.vmap(torch.dot)) # [N1, N0, D], [N1, N0, D] -> [N1, N0] >>> x, y = torch.randn(2, 3, 5), torch.randn(2, 3, 5) >>> batched_dot(x, y) # tensor of size [2, 3]
如果输入没有沿第一个维度进行批处理,
in_dims指定每个输入沿哪个维度进行批处理。>>> torch.dot # [N], [N] -> [] >>> batched_dot = torch.vmap(torch.dot, in_dims=1) # [N, D], [N, D] -> [D] >>> x, y = torch.randn(2, 5), torch.randn(2, 5) >>> batched_dot(x, y) # output is [5] instead of [2] if batched along the 0th dimension
如果有多个输入,每个输入沿着不同的维度进行批处理,
in_dims必须是一个元组,其中包含每个输入的批处理维度。>>> torch.dot # [D], [D] -> [] >>> batched_dot = torch.vmap(torch.dot, in_dims=(0, None)) # [N, D], [D] -> [N] >>> x, y = torch.randn(2, 5), torch.randn(5) >>> batched_dot(x, y) # second arg doesn't have a batch dim because in_dim[1] was None
如果输入是Python结构体,
in_dims必须是一个包含与输入形状匹配的结构体的元组:>>> f = lambda dict: torch.dot(dict['x'], dict['y']) >>> x, y = torch.randn(2, 5), torch.randn(5) >>> input = {'x': x, 'y': y} >>> batched_dot = torch.vmap(f, in_dims=({'x': 0, 'y': None},)) >>> batched_dot(input)
默认情况下,输出沿第一维度进行批处理。然而,可以通过使用
out_dims沿任何维度进行批处理。>>> f = lambda x: x ** 2 >>> x = torch.randn(2, 5) >>> batched_pow = torch.vmap(f, out_dims=1) >>> batched_pow(x) # [5, 2]
对于任何使用kwargs的函数,返回的函数不会对kwargs进行批处理,但会接受kwargs
>>> x = torch.randn([2, 5]) >>> def fn(x, scale=4.): >>> return x * scale >>> >>> batched_pow = torch.vmap(fn) >>> assert torch.allclose(batched_pow(x), x * 4) >>> batched_pow(x, scale=x) # scale is not batched, output has shape [2, 2, 5]
注意
vmap 不提供通用的自动批处理或开箱即用地处理可变长度序列。
警告
我们已经将functorch集成到PyTorch中。作为集成的最后一步,functorch.vmap自PyTorch 2.0起已被弃用,并将在未来版本PyTorch >= 2.3中删除。请改用torch.vmap;更多详情请参阅PyTorch 2.0发布说明和/或torch.func迁移指南https://pytorch.org/docs/master/func.migrating.html