torch.autograd.Function.vmap¶
- static Function.vmap(info, in_dims, *args)[源代码]¶
在此处定义此 autograd.Function 的行为
torch.vmap()。对于支持
torch.autograd.Function()的torch.vmap(),你必须重写此静态方法,或者将generate_vmap_rule设置为True(你不能同时执行这两者)。如果你选择覆盖这个静态方法:它必须接受
一个
info对象作为第一个参数。info.batch_size指定了被 vmapped 的维度的尺寸, 而info.randomness是传递给torch.vmap()的随机性选项。作为第二个参数的
in_dims元组。 对于args中的每个参数,in_dims都有一个对应的Optional[int]。如果参数不是张量或参数不被 vmapped 覆盖,则为None,否则,它是一个整数 指定张量的哪个维度被 vmapped 覆盖。*args,这与传递给forward()的参数相同。
vmap静态方法的返回值是一个包含
(output, out_dims)的元组。 类似于in_dims,out_dims应该与output具有相同的结构,并且为每个输出包含一个out_dim,用于指定输出是否具有vmapped维度以及其索引位置。请参阅使用 autograd.Function 扩展 torch.func了解更多详情。