cpp_extension

用于加载CPP / CUDA扩展的实用函数。

函数

load_cpp_extension

如果当前的CUDA版本满足要求,使用torch.utils.cpp_extension.load()加载C++ / CUDA扩展。

load_cpp_extension(name, sources, cuda_version_specifiers, fail_msg='', raise_if_failed=False, **load_kwargs)

如果当前的CUDA版本满足要求,使用torch.utils.cpp_extension.load()加载C++ / CUDA扩展。

由于编译的原因,第一次加载可能需要几分钟,但后续加载是即时的。

Parameters:
  • name (str) – 扩展的名称。

  • sources (List[str | Path]) – 要编译的源文件。

  • cuda_version_specifiers (str | None) – 用于指定启用扩展所需的CUDA版本的说明符(例如“>=11.8,<12”)。

  • fail_msg (str) – 如果扩展加载失败时显示的附加消息。

  • raise_if_failed (bool) – 如果扩展加载失败,则引发异常。

  • **load_kwargs (Any) – 传递给 torch.utils.cpp_extension.load() 的关键字参数。

Return type:

模块 |