Shortcuts

用户体验限制

functorch,像JAX一样,对可以转换的内容有一些限制。一般来说,JAX的限制是转换仅适用于纯函数:即输出完全由输入决定且不涉及副作用(如突变)的函数。

我们有一个类似的保证:我们的转换与纯函数配合得很好。 然而,我们确实支持某些原地操作。一方面,编写与functorch转换兼容的代码可能涉及改变你编写PyTorch代码的方式,另一方面,你可能会发现我们的转换让你能够表达以前在PyTorch中难以表达的内容。

一般限制

所有functorch变换都有一个共同的限制,即函数不应分配给全局变量。相反,函数的所有输出都必须从函数中返回。这个限制来自于functorch的实现方式:每个变换都将Tensor输入包装在特殊的functorch Tensor子类中,以促进变换。

所以,不要使用以下内容:

import torch
from functorch import grad

# Don't do this
intermediate = None

def f(x):
  global intermediate
  intermediate = x.sin()
  z = intermediate.sin()
  return z

x = torch.randn([])
grad_x = grad(f)(x)

请重写 f 以返回 intermediate

def f(x):
  intermediate = x.sin()
  z = intermediate.sin()
  return z, intermediate

grad_x, intermediate = grad(f, has_aux=True)(x)

torch.autograd API

如果您尝试在由vmap()或functorch的AD变换(vjp()jvp()jacrev()jacfwd())转换的函数中使用torch.autograd API,如torch.autograd.gradtorch.autograd.backward,则变换可能无法对其进行变换。如果无法进行变换,您将收到错误消息。

这是PyTorch的AD支持实现中的一个基本设计限制,也是我们设计functorch库的原因。请改用functorch中与torch.autograd API等效的功能: - torch.autograd.grad, Tensor.backward -> functorch.vjpfunctorch.grad - torch.autograd.functional.jvp -> functorch.jvp - torch.autograd.functional.jacobian -> functorch.jacrevfunctorch.jacfwd - torch.autograd.functional.hessian -> functorch.hessian

vmap 限制

注意

vmap() is our most restrictive transform. The grad-related transforms (grad(), vjp(), jvp()) do not have these limitations. jacfwd() (and hessian(), which is implemented with jacfwd()) is a composition of vmap() and jvp() so it also has these limitations.

vmap(func) 是一个转换,它返回一个函数,该函数将 func 映射到每个输入张量的某个新维度上。vmap 的心理模型是它类似于运行一个 for 循环:对于纯函数(即没有副作用的情况),vmap(f)(x) 等价于:

torch.stack([f(x_i) for x_i in x.unbind(0)])

突变:Python数据结构的任意突变

在存在副作用的情况下,vmap() 不再像运行一个 for 循环那样工作。例如,以下函数:

def f(x, list):
  list.pop()
  print("hello!")
  return x.sum(0)

x = torch.randn(3, 1)
lst = [0, 1, 2, 3]

result = vmap(f, in_dims=(0, None))(x, lst)

将打印“hello!”一次,并从lst中仅弹出一个元素。

vmap() 执行 f 一次,因此所有副作用只发生一次。

这是vmap实现方式的结果。functorch有一个特殊的内部BatchedTensor类。vmap(f)(*inputs)会接收所有的Tensor输入,将它们转换为BatchedTensors,并调用f(*batched_tensor_inputs)。BatchedTensor重写了PyTorch API,以便为每个PyTorch操作符生成批处理(即向量化)行为。

突变:原地PyTorch操作

vmap() 如果遇到不支持的 PyTorch 原地操作,将会引发错误,否则将会成功。不支持的操作是那些会导致具有更多元素的张量被写入具有较少元素的张量的操作。以下是一个可能发生这种情况的示例:

def f(x, y):
  x.add_(y)
  return x

x = torch.randn(1)
y = torch.randn(3)

# Raises an error because `y` has fewer elements than `x`.
vmap(f, in_dims=(None, 0))(x, y)

x 是一个包含一个元素的张量,y 是一个包含三个元素的张量。 x + y 有三个元素(由于广播机制),但尝试将三个元素写回 x,它只有一个元素,会引发错误,因为尝试将三个元素写入一个只有一个元素的张量。

如果要写入的张量具有相同数量的元素(或更多),则没有问题:

def f(x, y):
  x.add_(y)
  return x

x = torch.randn(3)
y = torch.randn(3)
expected = x + y

# Does not raise an error because x and y have the same number of elements.
vmap(f, in_dims=(0, 0))(x, y)
assert torch.allclose(x, expected)

突变: 输出= PyTorch 操作

vmap() 在 PyTorch 操作中不支持 out= 关键字参数。 如果在代码中遇到这种情况,它会优雅地报错。

这不是一个根本性的限制;理论上我们可以在未来支持这一点,但目前我们选择不这样做。

数据依赖的Python控制流

我们目前还不支持对数据依赖控制流进行vmap操作。数据依赖控制流是指if语句、while循环或for循环的条件是一个正在被vmap操作的Tensor。例如,以下代码将引发错误信息:

def relu(x):
  if x > 0:
    return x
  return 0

x = torch.randn(3)
vmap(relu)(x)

然而,任何不依赖于vmap张量值的控制流都将有效:

def custom_dot(x):
  if x.dim() == 1:
    return torch.dot(x, x)
  return (x * x).sum()

x = torch.randn(3)
vmap(custom_dot)(x)

JAX 支持通过使用特殊的控制流操作符(例如 jax.lax.cond, jax.lax.while_loop)来转换数据依赖的控制流。我们正在研究在 functorch 中添加这些操作符的等效功能(在 GitHub 上提出问题以表达您的支持!)。

依赖于数据的操作 (.item())

我们不(也不会)支持在调用.item()的用户自定义函数上使用vmap。例如,以下代码将引发错误消息:

def f(x):
  return x.item()

x = torch.randn(3)
vmap(f)(x)

请尝试重写您的代码,不要使用.item()调用。

您可能还会遇到关于使用.item()的错误消息,但您可能并未使用它。在这些情况下,可能是PyTorch内部正在调用.item()——请在GitHub上提交问题,我们将修复PyTorch的内部问题。

动态形状操作(nonzero 及其相关函数)

vmap(f) 要求 f 应用于输入中的每个“示例”时返回一个具有相同形状的张量。诸如 torch.nonzerotorch.is_nonzero 等操作不受支持,因此会导致错误。

要了解原因,请考虑以下示例:

xs = torch.tensor([[0, 1, 2], [0, 0, 3]])
vmap(torch.nonzero)(xs)

torch.nonzero(xs[0]) 返回一个形状为2的张量; 但是 torch.nonzero(xs[1]) 返回一个形状为1的张量。 我们无法构造一个单一的张量作为输出; 输出将需要是一个不规则的张量(而PyTorch目前还没有不规则张量的概念)。

随机性

用户在调用随机操作时的意图可能不明确。具体来说,一些用户可能希望随机行为在批次之间保持一致,而另一些用户可能希望它在批次之间有所不同。为了解决这个问题,vmap 接受一个随机性标志。

该标志只能传递给vmap,并且可以取3个值,“error”、“different”或“same”,默认为error。在“error”模式下,任何对随机函数的调用都会产生一个错误,要求用户根据他们的使用情况使用其他两个标志之一。

在“不同”的随机性下,批次中的元素会产生不同的随机值。例如,

def add_noise(x):
  y = torch.randn(())  # y will be different across the batch
  return x + y

x = torch.ones(3)
result = vmap(add_noise, randomness="different")(x)  # we get 3 different values

在“相同”随机性下,批次中的元素产生相同的随机值。例如,

def add_noise(x):
  y = torch.randn(())  # y will be the same across the batch
  return x + y

x = torch.ones(3)
result = vmap(add_noise, randomness="same")(x)  # we get the same value, repeated 3 times

警告

我们的系统仅确定PyTorch操作符的随机性行为,无法控制其他库(如numpy)的行为。这与JAX解决方案的限制类似。

注意

使用任何一种支持的随机性进行多次vmap调用不会产生相同的结果。与标准的PyTorch一样,用户可以通过在vmap外部使用torch.manual_seed()或使用生成器来实现随机性的可重复性。

注意

最后,我们的随机性与JAX不同,因为我们没有使用无状态的伪随机数生成器(PRNG),部分原因是PyTorch对无状态PRNG的支持不完全。相反,我们引入了一个标志系统,以允许我们最常见的随机性形式。如果您的使用场景不符合这些随机性形式,请提交一个问题。