jax.extend.linear_util.WrappedFun#
- class jax.extend.linear_util.WrappedFun(f, transforms, stores, params, in_type, debug_info)[源代码][源代码]#
表示一个函数 f ,其中 transforms 将被应用。
- 参数:
f – 要转换的函数。
transforms – 一个 (gen, gen_static_args) 元组列表,表示要应用于 f 的转换。这里 gen 是一个生成器函数,gen_static_args 是生成器的静态参数元组。有关生成器预期行为的描述,请参见本模块开头的说明。
stores – transforms 辅助输出的 out_store 列表。
params – 传递给 f 的额外参数,以及转换后的关键字参数。
方法
__init__(f, transforms, stores, params, ...)call_wrapped(*args, **kwargs)调用底层函数,应用转换。
populate_stores(stores)将 stores 中的值复制到 self.stores 中。
wrap(gen, gen_static_args, out_store)添加另一个转换及其存储。
属性
ftransformsstoresparamsin_typedebug_info