用户体验限制¶
torch.func,类似于JAX,在可以转换的内容上有限制。一般来说,JAX的限制是转换只能用于纯函数:也就是说,函数的输出完全由输入决定,并且不涉及副作用(如突变)。
我们有一个类似的保证:我们的转换适用于纯函数。 然而,我们确实支持某些就地操作。一方面,编写与函数转换兼容的代码可能涉及改变您编写 PyTorch 代码的方式,另一方面,您可能会发现我们的转换让您能够表达以前在 PyTorch 中难以表达的内容。
一般限制¶
所有 torch.func 转换都有一个共同的限制,即函数不应分配给全局变量。相反,函数的所有输出都必须从函数中返回。这一限制源于 torch.func 的实现方式:每个转换都将张量输入包装在特殊的 torch.func 张量子类中,以促进转换。
因此,不要使用以下内容:
import torch
from torch.func import grad
# 不要这样做
intermediate = None
def f(x):
global intermediate
intermediate = x.sin()
z = intermediate.sin()
return z
x = torch.randn([])
grad_x = grad(f)(x)
请重写 f 以返回 intermediate:
def f(x):
intermediate = x.sin()
z = intermediate.sin()
return z, intermediate
grad_x, intermediate = grad(f, has_aux=True)(x)
torch.autograd API¶
如果你尝试在一个由vmap()或torch.func的AD变换(vjp()、jvp()、jacrev()、jacfwd())转换的函数内部使用torch.autograd API,如torch.autograd.grad或torch.autograd.backward,该变换可能无法对其进行转换。如果无法进行转换,你将收到一条错误消息。
这是PyTorch的自动微分支持实现方式中的一个基本设计限制,也是我们设计torch.func库的原因。请改用torch.func中与torch.autograd API等效的函数:
- torch.autograd.grad, Tensor.backward -> torch.func.vjp 或 torch.func.grad
- torch.autograd.functional.jvp -> torch.func.jvp
- torch.autograd.functional.jacobian -> torch.func.jacrev 或 torch.func.jacfwd
- torch.autograd.functional.hessian -> torch.func.hessian
vmap 限制¶
注意
vmap() 是我们最严格的变换。
与梯度相关的变换(grad()、vjp()、jvp())没有这些限制。jacfwd()(以及 hessian(),它是通过 jacfwd() 实现的)是 vmap() 和
jvp() 的组合,因此它也有这些限制。
vmap(func) 是一个变换,它返回一个函数,该函数将 func 映射到每个输入张量的某个新维度上。vmap 的思维模型是它类似于运行一个 for 循环:对于纯函数(即在没有副作用的情况下),vmap(f)(x) 等价于:
torch.stack([f(x_i) for x_i in x.unbind(0)])
变异: Python 数据结构的任意变异¶
在存在副作用的情况下,vmap()不再像是在运行一个for循环。例如,以下函数:
def f(x, list):
list.pop()
print("你好!")
return x.sum(0)
x = torch.randn(3, 1)
lst = [0, 1, 2, 3]
result = vmap(f, in_dims=(0, None))(x, lst)
将打印“hello!”一次,并从lst中弹出一个元素。
vmap() 执行 f 一次,因此所有副作用只会发生一次。
这是由于vmap的实现方式所导致的。torch.func有一个特殊的、内部的BatchedTensor类。vmap(f)(*inputs) 将所有Tensor输入转换为BatchedTensors,并调用 f(*batched_tensor_inputs)。BatchedTensor重写了PyTorch API,以生成每个PyTorch操作符的批处理(即向量化)行为。
变异:就地 PyTorch 操作¶
您可能是因为收到关于 vmap 不兼容的就地操作的错误而来到这里。vmap() 如果在遇到不支持的 PyTorch 就地操作时会引发错误,否则它会成功。不支持的操作是指会导致将更多元素的张量写入元素较少的张量的操作。以下是一个可能发生这种情况的示例:
def f(x, y):
x.add_(y)
return x
x = torch.randn(1)
y = torch.randn(3, 1) # 当被vmapped时,看起来它具有形状[1]
# 引发错误,因为`x`的元素数量少于`y`。
vmap(f, in_dims=(None, 0))(x, y)
x 是一个包含一个元素的张量,y 是一个包含三个元素的张量。
x + y 有三个元素(由于广播),但尝试将三个元素写回 x,而 x 只有一个元素,会引发错误,因为尝试将三个元素写入一个只有一个元素的张量。
如果正在写入的张量在vmap()下是批处理的(即它正在被vmap处理),则没有问题。
def f(x, y):
x.add_(y)
return x
x = torch.randn(3, 1)
y = torch.randn(3, 1)
expected = x + y
# 不会引发错误,因为 x 正在被 vmapped 处理。
vmap(f, in_dims=(0, 0))(x, y)
assert torch.allclose(x, expected)
常见的解决方法是将调用工厂函数的代码替换为其“new_*”等价形式。例如:
将
torch.zeros()替换为Tensor.new_zeros()将
torch.empty()替换为Tensor.new_empty()
要了解为什么这有帮助,请考虑以下内容。
def diag_embed(vec):
assert vec.dim() == 1
result = torch.zeros(vec.shape[0], vec.shape[0])
result.diagonal().copy_(vec)
return result
vecs = torch.tensor([[0., 1, 2], [3., 4, 5]])
# 运行时错误: vmap: 不支持原地算术运算(self, *extra_args) ...
vmap(diag_embed)(vecs)
在 vmap() 内部,result 是一个形状为 [3, 3] 的张量。
然而,尽管 vec 看起来具有形状 [3],vec 实际上具有底层形状 [2, 3]。
无法将 vec 复制到 result.diagonal() 中,后者具有形状 [3],因为它包含的元素过多。
def diag_embed(vec):
assert vec.dim() == 1
result = vec.new_zeros(vec.shape[0], vec.shape[0])
result.diagonal().copy_(vec)
return result
vecs = torch.tensor([[0., 1, 2], [3., 4, 5]])
vmap(diag_embed)(vecs)
将 torch.zeros() 替换为 Tensor.new_zeros() 使得
result 具有形状为 [2, 3, 3] 的底层张量,因此现在可以将具有底层形状 [2, 3] 的 vec 复制到 result.diagonal() 中。
变异: out= PyTorch 操作¶
vmap() 不支持在 PyTorch 操作中使用 out= 关键字参数。如果在代码中遇到这种情况,它将优雅地报错。
这并不是一个根本性的限制;理论上我们可以在未来支持这一点,但我们目前选择不这样做。
数据依赖的Python控制流¶
我们尚不支持在依赖数据的控制流上使用 vmap。依赖数据的控制流是指 if 语句、while 循环或 for 循环的条件是一个正在被 vmap 处理的 Tensor。例如,以下代码将会引发错误信息:
def relu(x):
if x > 0:
return x
return 0
x = torch.randn(3)
vmap(relu)(x)
然而,任何不依赖于 vmap 张量值的控制流都可以正常工作:
def custom_dot(x):
if x.dim() == 1:
return torch.dot(x, x)
return (x * x).sum()
x = torch.randn(3)
vmap(custom_dot)(x)
JAX 支持使用特殊的控制流操作符(例如 jax.lax.cond、jax.lax.while_loop)对依赖数据的控制流进行转换。我们正在研究为 PyTorch 添加这些的等效功能。
数据依赖操作 (.item())¶
我们不支持(也不会支持)在调用.item()的用户定义函数上使用vmap。例如,以下代码将引发错误信息:
def f(x):
return x.item()
x = torch.randn(3)
vmap(f)(x)
请尝试重写您的代码,以避免使用 .item() 调用。
您可能还会遇到关于使用 .item() 的错误消息,但您可能并未使用它。在这些情况下,可能是 PyTorch 内部在调用 .item() – 请在 GitHub 上提交问题,我们将修复 PyTorch 内部问题。
动态形状操作(非零及其他相关操作)¶
vmap(f) 要求 f 应用于输入中的每个“示例”时,返回一个形状相同的张量。诸如 torch.nonzero、torch.is_nonzero 等操作不受支持,并因此会导致错误。
要了解原因,请考虑以下示例:
xs = torch.tensor([[0, 1, 2], [0, 0, 3]])
vmap(torch.nonzero)(xs)
torch.nonzero(xs[0]) 返回一个形状为 2 的张量;
但 torch.nonzero(xs[1]) 返回一个形状为 1 的张量。
我们无法构建一个单一的张量作为输出;
输出将需要是一个不规则张量(而 PyTorch 尚未有
不规则张量的概念)。
随机性¶
用户在调用随机操作时的意图可能不明确。具体来说,一些用户可能希望随机行为在批次之间保持一致,而另一些用户可能希望它在批次之间有所不同。为了解决这个问题,vmap 采用了一个随机性标志。
该标志只能传递给 vmap,并且可以取 3 个值,“error”、“different”或“same”,默认为 error。在“error”模式下,任何对随机函数的调用都会产生一个错误,要求用户根据其使用情况使用其他两个标志之一。
在“不同”的随机性下,批次中的元素会产生不同的随机值。例如,
def add_noise(x):
y = torch.randn(()) # y 在批次中会有所不同
return x + y
x = torch.ones(3)
result = vmap(add_noise, randomness="different")(x) # 我们得到3个不同的值
在“相同”的随机性下,批次中的元素会产生相同的随机值。例如,
def add_noise(x):
y = torch.randn(()) # y 在整个批次中将是相同的
return x + y
x = torch.ones(3)
result = vmap(add_noise, randomness="same")(x) # 我们得到相同的值,重复3次
警告
我们的系统只能确定PyTorch操作符的随机性行为,无法控制其他库(如numpy)的行为。这与JAX在其解决方案中的限制类似
注意
使用任一类型的支持随机性的多次vmap调用不会产生相同的结果。与标准PyTorch一样,用户可以通过在vmap外部使用torch.manual_seed()或使用生成器来获得随机性可重复性。
注意
最后,我们的随机性不同于JAX,因为我们没有使用无状态的伪随机数生成器(PRNG),部分原因是PyTorch对无状态PRNG的支持并不完全。相反,我们引入了一个标志系统,以支持我们常见的大多数随机性形式。如果您的使用场景不符合这些随机性形式,请提交问题。