torch.func.functionalize¶
- torch.func.functionalize(func, *, remove='mutations')¶
functionalize 是一种转换,可以用来移除(中间)突变和别名,同时保留函数的语义。
functionalize(func)返回一个具有相同语义的新函数, 但删除了所有中间突变。 在中间张量上执行的每个就地操作:intermediate.foo_()被替换为其非就地等效操作:intermediate_updated = intermediate.foo()。functionalize 对于将 PyTorch 程序发送到无法轻松表示突变或别名操作的后端或编译器非常有用。
- Parameters
func (可调用对象) – 一个接受一个或多个参数的Python函数。
remove (str) – 一个可选的字符串参数,取值为“mutations”或“mutations_and_views”。 如果传入“mutations”,则所有可变操作符将被替换为其不可变等效操作符。 如果传入“mutations_and_views”,则此外,所有别名操作符将被替换为其非别名等效操作符。 默认值:“mutations”。
- Returns
返回一个新的“功能化”函数。它接受与
func相同的输入,并具有相同的行为,但在函数中对中间张量执行的任何突变(以及可选的别名)都将被移除。- Return type
functionalize 还会移除对函数输入执行的突变(和视图)。 然而,为了保持语义,functionalize 会在转换完成后通过检测是否有任何张量输入“应该”被突变,并在必要时将新数据复制回输入来“修复”这些突变。
示例:
```html>>> import torch >>> from torch.fx.experimental.proxy_tensor import make_fx >>> from torch.func import functionalize >>> >>> # 一个使用突变和视图的函数,但仅在中间张量上使用。 >>> def f(a): ... b = a + 1 ... c = b.view(-1) ... c.add_(1) ... return b ... >>> inpt = torch.randn(2) >>> >>> out1 = f(inpt) >>> out2 = functionalize(f)(inpt) >>> >>> # 语义相同(输出是等价的) >>> print(torch.allclose(out1, out2)) True >>> >>> f_traced = make_fx(f)(inpt) >>> f_no_mutations_traced = make_fx(functionalize(f))(inpt) >>> f_no_mutations_and_views_traced = make_fx(functionalize(f, remove='mutations_and_views'))(inpt) >>> >>> print(f_traced.code)
- There are a few “failure modes” for functionalize that are worth calling out:
与其他 torch.func 转换一样,functionalize() 不适用于直接使用 .backward() 的函数。对于 torch.autograd.grad 也是如此。如果你想使用 autograd,可以直接使用 functionalize(grad(f)) 计算梯度。
与其他 torch.func 转换一样,functionalize() 不支持全局状态。 如果你对一个包含非局部状态的视图/突变的函数调用 functionalize(f),功能化将简单地无操作并将视图/突变调用直接传递给后端。 解决这个问题的一种方法是确保任何非局部状态的创建都被包装到一个更大的函数中,然后对这个函数调用 functionalize。
resize_() 有一些限制:functionalize 只能用于使用 resize_() 的程序,只要被调整大小的张量不是一个视图。
as_strided() 有一些限制:functionalize 将无法在导致内存重叠的张量的 as_strided() 调用上工作。
最后,理解功能化的一个有用的思维模型是,大多数用户使用公共 torch API 编写的 pytorch 程序。 当执行时,torch 操作符通常会被分解为我们的内部 C++ “ATen” API。 功能化的逻辑完全发生在 ATen 级别。 功能化知道如何将 ATen 中的每个别名操作符映射到其非别名等效操作符 (例如
tensor.view({-1})->at::view_copy(tensor, {-1})), 以及如何将 ATen 中的每个变异操作符映射到其非变异等效操作符 (例如tensor.add_(1)->at::add(tensor, -1)), 同时跟踪别名和变异以了解何时修复事物。 关于哪些 ATen 操作符是别名或变异的信息全部来自 https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml。