cpp_extension
用于加载CPP / CUDA扩展的实用函数。
函数
如果当前的CUDA版本满足要求,使用torch.utils.cpp_extension.load()加载C++ / CUDA扩展。 |
- load_cpp_extension(name, sources, cuda_version_specifiers, fail_msg='', raise_if_failed=False, **load_kwargs)
如果当前的CUDA版本满足要求,使用torch.utils.cpp_extension.load()加载C++ / CUDA扩展。
由于编译的原因,第一次加载可能需要几分钟,但后续加载是即时的。
- Parameters:
name (str) – 扩展的名称。
sources (List[str | Path]) – 要编译的源文件。
cuda_version_specifiers (str | None) – 用于指定启用扩展所需的CUDA版本的说明符(例如“>=11.8,<12”)。
fail_msg (str) – 如果扩展加载失败时显示的附加消息。
raise_if_failed (bool) – 如果扩展加载失败,则引发异常。
**load_kwargs (Any) – 传递给 torch.utils.cpp_extension.load() 的关键字参数。
- Return type:
模块 | 无