Shortcuts

torch.backends.cudnn 的源代码

```html
import os
import sys
import warnings
from contextlib import contextmanager
from typing import Optional

import torch
from torch.backends import __allow_nonbracketed_mutation, ContextProp, PropModule

try:
    from torch._C import _cudnn
except ImportError:
    _cudnn = None  # type: ignore[assignment]

# 写入:
#
#   torch.backends.cudnn.enabled = False
#
# 以全局禁用 CuDNN/MIOpen

__cudnn_version: Optional[int] = None

if _cudnn is not None:

    def _init():
        global __cudnn_version
        if __cudnn_version is None:
            __cudnn_version = _cudnn.getVersionInt()
            runtime_version = _cudnn.getRuntimeVersion()
            compile_version = _cudnn.getCompileVersion()
            runtime_major, runtime_minor, _ = runtime_version
            compile_major, compile_minor, _ = compile_version
            # 不同的主版本总是不兼容的
            # 从 cuDNN 7 开始,次版本是向后兼容的
            # 不确定 MIOpen(ROCm),所以总是进行严格的检查
            if runtime_major != compile_major:
                cudnn_compatible = False
            elif runtime_major < 7 or not _cudnn.is_cuda:
                cudnn_compatible = runtime_minor == compile_minor
            else:
                cudnn_compatible = runtime_minor >= compile_minor
            if not cudnn_compatible:
                if os.environ.get("PYTORCH_SKIP_CUDNN_COMPATIBILITY_CHECK", "0") == "1":
                    return True
                base_error_msg = (
                    f"cuDNN 版本不兼容: "
                    f"PyTorch 是针对 {compile_version} 编译的 "
                    f"但找到了运行时版本 {runtime_version}. "
                    f"PyTorch 已经捆绑了 cuDNN. "
                    f"解决此错误的一个选项是确保 PyTorch 可以找到捆绑的 cuDNN. "
                )

                if "LD_LIBRARY_PATH" in os.environ:
                    ld_library_path = os.environ.get("LD_LIBRARY_PATH", "")
                    if any(
                        substring in ld_library_path for substring in ["cuda", "cudnn"]
                    ):
                        raise RuntimeError(
                            f"{base_error_msg}"
                            f"看起来你的 LD_LIBRARY_PATH 包含了不兼容版本的 cudnn. "
                            f"请从路径中删除它或安装 cudnn {compile_version}"
                        )
                    else:
                        raise RuntimeError(
                            f"{base_error_msg}"
                            f"一个可能性是 LD_LIBRARY_PATH 中存在冲突的 cuDNN."
                        )
                else:
                    raise RuntimeError(base_error_msg)

        return True

else:

    def _init():
        return False


[docs]def version(): """返回 cuDNN 的版本。""" if not _init(): return None return __cudnn_version
CUDNN_TENSOR_DTYPES = { torch.half, torch.float, torch.double, }
[docs]def is_available(): r"""返回一个布尔值,指示 CUDNN 当前是否可用。""" return torch._C._has_cudnn
def is_acceptable(tensor): if not torch._C._get_cudnn_enabled(): return False if tensor.device.type != "cuda" or tensor.dtype not in CUDNN_TENSOR_DTYPES: return False if not is_available(): warnings.warn( "PyTorch 是在没有 cuDNN/MIOpen 支持的情况下编译的。要使用 cuDNN/MIOpen,请重新构建 " "PyTorch,确保库对构建系统可见。" ) return False if not _init(): warnings.warn( "未找到 cuDNN/MIOpen 库。请检查您的 {libpath}".format( libpath={"darwin": "DYLD_LIBRARY_PATH", "win32": "
优云智算