Shortcuts

functorch.compile.aot_function

functorch.compile.aot_function(fn, fw_compiler, bw_compiler=None, partition_fn=<function default_partition>, decompositions=None, num_params_buffers=0, hasher_type=None, static_argnums=None, keep_inference_input_mutations=False)[source]

使用torch调度机制跟踪fn的前向和后向图,然后通过fw_compilerbw_compiler编译生成的前向和后向图。

aot_function() 提前追踪前向和后向图,并生成一个联合的前向和后向图。partition_fn 然后用于分离前向和后向图。分区函数可以用于执行诸如重新计算等优化。可以设置decompositions字典,将操作符分解为后端编译器支持的核心或更简单操作符的序列。

aot_function() 使用基于输入张量属性的编译缓存,以检测何时需要重新编译。

警告

此API是实验性的,可能会发生变化。

Parameters
  • fn (可调用) – 一个接受一个或多个参数的Python函数。必须返回一个或多个张量。

  • fw_compiler (Callable) – 一个Python函数,它接受一个带有Aten操作和输入参数的Fx图,并返回一个在语义上等同于输入Fx图的可调用对象。

  • bw_compiler (可选[可调用]) – 一个Python函数,它接受一个带有Aten操作和输入参数的Fx图,并返回一个在语义上等同于输入Fx图的可调用对象。默认值:None(当为None时,默认为fw_compiler

  • partition_fn (Callable) – 一个Python函数,它接受一个联合的前向和后向图,并将其分割成独立的前向和后向图。

  • 分解 (字典) – 一个字典,用于定义将较大的Aten操作分解为更简单或核心的Aten操作。

Returns

返回一个Callable,它保留了原始fn的急切行为,但通过fw_compilebw_compile编译了前向和后向图。

aot_function() 的一个简单使用示例如下。此示例将打印函数 fn 的前向和后向图。

>>> fn = lambda x : x.sin().cos()
>>> def print_compile_fn(fx_module, args):
>>>     print(fx_module)
>>>     return fx_module
>>> aot_fn = aot_function(fn, print_compile_fn)
>>> x = torch.randn(4, 5, requires_grad=True)
>>> aot_fn(x)