functorch.compile.aot_function¶
-
functorch.compile.aot_function(fn, fw_compiler, bw_compiler=None, partition_fn=<function default_partition>, decompositions=None, num_params_buffers=0, hasher_type=None, static_argnums=None, keep_inference_input_mutations=False)[source]¶ 使用torch调度机制跟踪
fn的前向和后向图,然后通过fw_compiler和bw_compiler编译生成的前向和后向图。aot_function()提前追踪前向和后向图,并生成一个联合的前向和后向图。partition_fn然后用于分离前向和后向图。分区函数可以用于执行诸如重新计算等优化。可以设置decompositions字典,将操作符分解为后端编译器支持的核心或更简单操作符的序列。aot_function()使用基于输入张量属性的编译缓存,以检测何时需要重新编译。警告
此API是实验性的,可能会发生变化。
- Parameters
fn (可调用) – 一个接受一个或多个参数的Python函数。必须返回一个或多个张量。
fw_compiler (Callable) – 一个Python函数,它接受一个带有Aten操作和输入参数的Fx图,并返回一个在语义上等同于输入Fx图的可调用对象。
bw_compiler (可选[可调用]) – 一个Python函数,它接受一个带有Aten操作和输入参数的Fx图,并返回一个在语义上等同于输入Fx图的可调用对象。默认值:None(当为None时,默认为
fw_compiler)partition_fn (Callable) – 一个Python函数,它接受一个联合的前向和后向图,并将其分割成独立的前向和后向图。
分解 (字典) – 一个字典,用于定义将较大的Aten操作分解为更简单或核心的Aten操作。
- Returns
返回一个
Callable,它保留了原始fn的急切行为,但通过fw_compile和bw_compile编译了前向和后向图。
aot_function()的一个简单使用示例如下。此示例将打印函数fn的前向和后向图。>>> fn = lambda x : x.sin().cos() >>> def print_compile_fn(fx_module, args): >>> print(fx_module) >>> return fx_module >>> aot_fn = aot_function(fn, print_compile_fn) >>> x = torch.randn(4, 5, requires_grad=True) >>> aot_fn(x)