Shortcuts

torch.autograd.Function.vmap

static Function.vmap(info, in_dims, *args)[源代码]

在此处定义此 autograd.Function 的行为 torch.vmap()

对于支持torch.autograd.Function()torch.vmap(),你必须重写此静态方法,或者将generate_vmap_rule设置为True(你不能同时执行这两者)。

如果你选择覆盖这个静态方法:它必须接受

  • 一个 info 对象作为第一个参数。info.batch_size 指定了被 vmapped 的维度的尺寸, 而 info.randomness 是传递给 torch.vmap() 的随机性选项。

  • 作为第二个参数的 in_dims 元组。 对于 args 中的每个参数,in_dims 都有一个对应的 Optional[int]。如果参数不是张量或参数不被 vmapped 覆盖,则为 None,否则,它是一个整数 指定张量的哪个维度被 vmapped 覆盖。

  • *args,这与传递给forward()的参数相同。

vmap静态方法的返回值是一个包含(output, out_dims)的元组。 类似于in_dimsout_dims应该与 output具有相同的结构,并且为每个输出包含一个out_dim,用于指定输出是否具有vmapped维度以及其索引位置。

请参阅使用 autograd.Function 扩展 torch.func了解更多详情。

优云智算