functorch.functionalize¶
-
functorch.functionalize(func, *, remove='mutations')[source]¶ functionalize 是一种转换,可用于从函数中移除(中间)突变和别名,同时保留函数的语义。
functionalize(func)返回一个与func语义相同的新函数,但移除了所有中间突变。 对中间张量执行的每个原地操作:intermediate.foo_()都被其非原地等效操作替换:intermediate_updated = intermediate.foo()。functionalize 对于将 PyTorch 程序发送到无法轻松表示突变或别名操作符的后端或编译器非常有用。
- Parameters
func (Callable) – 一个接受一个或多个参数的Python函数。
remove (str) – 一个可选的字符串参数,取值为‘mutations’或‘mutations_and_views’。 如果传入‘mutations’,则所有变异操作符将被替换为其非变异等价物。 如果传入‘mutations_and_views’,则此外,所有别名操作符将被替换为其非别名等价物。 默认值:‘mutations’。
- Returns
返回一个新的“功能化”函数。它接受与
func相同的输入,并具有相同的行为,但在函数中对中间张量执行的任何突变(以及可选的别名)将被移除。
functionalize 还会移除对函数输入执行的突变(和视图)。 然而,为了保留语义,functionalize 会在转换完成后“修复”这些突变, 通过检测是否有任何张量输入“应该”被突变,并在必要时将新数据复制回输入。
示例:
>>> # xdoctest: +SKIP >>> import torch >>> from torch.fx.experimental.proxy_tensor import make_fx >>> from torch.func import functionalize >>> >>> # A function that uses mutations and views, but only on intermediate tensors. >>> def f(a): ... b = a + 1 ... c = b.view(-1) ... c.add_(1) ... return b ... >>> inpt = torch.randn(2) >>> >>> out1 = f(inpt) >>> out2 = functionalize(f)(inpt) >>> >>> # semantics are the same (outputs are equivalent) >>> print(torch.allclose(out1, out2)) True >>> >>> f_traced = make_fx(f)(inpt) >>> f_no_mutations_traced = make_fx(functionalize(f))(inpt) >>> f_no_mutations_and_views_traced = make_fx(functionalize(f, remove='mutations_and_views'))(inpt) >>> >>> print(f_traced.code) def forward(self, a_1): add = torch.ops.aten.add(a_1, 1); a_1 = None view = torch.ops.aten.view(add, [-1]) add_ = torch.ops.aten.add_(view, 1); view = None return add >>> print(f_no_mutations_traced.code) def forward(self, a_1): add = torch.ops.aten.add(a_1, 1); a_1 = None view = torch.ops.aten.view(add, [-1]); add = None add_1 = torch.ops.aten.add(view, 1); view = None view_1 = torch.ops.aten.view(add_1, [2]); add_1 = None return view_1 >>> print(f_no_mutations_and_views_traced.code) def forward(self, a_1): add = torch.ops.aten.add(a_1, 1); a_1 = None view_copy = torch.ops.aten.view_copy(add, [-1]); add = None add_1 = torch.ops.aten.add(view_copy, 1); view_copy = None view_copy_1 = torch.ops.aten.view_copy(add_1, [2]); add_1 = None return view_copy_1 >>> # A function that mutates its input tensor >>> def f(a): ... b = a.view(-1) ... b.add_(1) ... return a ... >>> f_no_mutations_and_views_traced = make_fx(functionalize(f, remove='mutations_and_views'))(inpt) >>> # >>> # All mutations and views have been removed, >>> # but there is an extra copy_ in the graph to correctly apply the mutation to the input >>> # after the function has completed. >>> print(f_no_mutations_and_views_traced.code) def forward(self, a_1): view_copy = torch.ops.aten.view_copy(a_1, [-1]) add = torch.ops.aten.add(view_copy, 1); view_copy = None view_copy_1 = torch.ops.aten.view_copy(add, [2]); add = None copy_ = torch.ops.aten.copy_(a_1, view_copy_1); a_1 = None return view_copy_1
- There are a few “failure modes” for functionalize that are worth calling out:
与其他 torch.func 转换一样,functionalize() 不适用于直接使用 .backward() 的函数。对于 torch.autograd.grad 也是如此。如果你想使用自动求导,可以直接使用 functionalize(grad(f)) 计算梯度。
与其他 torch.func 转换一样,functionalize() 不适用于全局状态。 如果你在一个函数上调用 functionalize(f),该函数涉及非局部状态的视图/突变,功能化将简单地不操作,并将视图/突变调用直接传递给后端。 解决这个问题的一种方法是确保任何非局部状态的创建都被包装到一个更大的函数中,然后你对该函数调用功能化。
resize_() 有一些限制:functionalize 只能在程序中使用 resize_(),只要被调整大小的张量不是视图。
as_strided() 有一些限制:functionalize 将无法处理导致张量内存重叠的 as_strided() 调用。
最后,理解功能化的一个有用的心智模型是,大多数用户编写的PyTorch程序都是使用公共的torch API编写的。在执行时,torch操作符通常被分解为我们的内部C++“ATen”API。功能化的逻辑完全发生在ATen的层面上。功能化知道如何处理ATen中的每一个别名操作符,并将其映射到其非别名等效操作符(例如
tensor.view({-1})->at::view_copy(tensor, {-1})),以及如何处理ATen中的每一个变异操作符,并将其映射到其非变异等效操作符(例如tensor.add_(1)->at::add(tensor, -1)),同时跟踪别名和变异以知道何时修复问题。关于哪些ATen操作符是别名或变异的信息都来自https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml。警告
我们已经将functorch集成到PyTorch中。作为集成的最后一步,functorch.functionalize在PyTorch 2.0中已被弃用,并将在未来版本PyTorch >= 2.3中删除。请改用torch.func.functionalize;有关更多详细信息,请参阅PyTorch 2.0发布说明和/或torch.func迁移指南https://pytorch.org/docs/master/func.migrating.html