Source code for torch.hub

import contextlib
import errno
import hashlib
import json
import os
import re
import shutil
import sys
import tempfile
import torch
import uuid
import warnings
import zipfile
from pathlib import Path
from typing import Dict, Optional, Any
from urllib.error import HTTPError, URLError
from urllib.request import urlopen, Request
from urllib.parse import urlparse  # noqa: F401
from torch.serialization import MAP_LOCATION

class _Faketqdm:  # type: ignore[no-redef]

    def __init__(self, total=None, disable=False,
                 unit=None, *args, **kwargs):
        self.total = total
        self.disable = disable
        self.n = 0
        # Ignore all extra *args and **kwargs lest you want to reinvent tqdm

    def update(self, n):
        if self.disable:
            return

        self.n += n
        if self.total is None:
            sys.stderr.write(f"\r{self.n:.1f} bytes")
        else:
            sys.stderr.write(f"\r{100 * self.n / float(self.total):.1f}%")
        sys.stderr.flush()

    # Don't bother implementing; use real tqdm if you want this functionality
    def set_description(self, *args, **kwargs):
        pass

    def write(self, s):
        sys.stderr.write(f"{s}\n")

    def close(self):
        self.disable = True

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.disable:
            return

        sys.stderr.write('\n')

try:
    from tqdm import tqdm  # If tqdm is installed, use it; otherwise fall back to the fake wrapper
except ImportError:
    tqdm = _Faketqdm
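
# Illustrative sketch (not part of the original source): whichever of the real
# tqdm or _Faketqdm ends up bound to the name `tqdm`, it is driven the same way
# while a download is streamed in chunks. Names such as `response`, `destination`,
# `file_size` and `progress` below are placeholders, roughly:
#
#     with tqdm(total=file_size, disable=not progress,
#               unit='B', unit_scale=True, unit_divisor=1024) as pbar:
#         while True:
#             buffer = response.read(READ_DATA_CHUNK)  # READ_DATA_CHUNK is defined below
#             if len(buffer) == 0:
#                 break
#             destination.write(buffer)
#             pbar.update(len(buffer))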

__all__ = [
    'download_url_to_file',
    'get_dir',
    'help',
    'list',
    'load',
    'load_state_dict_from_url',
    'set_dir',
]

# Matches bfd8deac in resnet18-bfd8deac.pth
HASH_REGEX = re.compile(r'-([a-f0-9]*)\.')
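# Illustrative example (not part of the original source):
#
#     >>> HASH_REGEX.search('resnet18-bfd8deac.pth').group(1)
#     'bfd8deac'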

_TRUSTED_REPO_OWNERS = ("facebookresearch", "facebookincubator", "pytorch", "fairinternal")
ENV_GITHUB_TOKEN = 'GITHUB_TOKEN'
ENV_TORCH_HOME = 'TORCH_HOME'
ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
DEFAULT_CACHE_DIR = '~/.cache'
VAR_DEPENDENCY = 'dependencies'
MODULE_HUBCONF = 'hubconf.py'
READ_DATA_CHUNK = 128 * 1024
_hub_dir: Optional[str] = None
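
# Illustrative note (not part of the original source): these environment
# variables control where downloaded repos and weights are cached; unless a
# directory is set explicitly via set_dir(), the hub directory resolves
# roughly in this order:
#
#     $TORCH_HOME/hub  ->  $XDG_CACHE_HOME/torch/hub  ->  ~/.cache/torch/hub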


@contextlib.contextmanager
def _add_to_sys_path(path):
    sys.path.insert(0, path)
    try:
        yield
    finally:
        sys.path.remove(path)
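
# Illustrative sketch (not part of the original source): the context manager
# above makes a downloaded repo importable just long enough to load its
# hubconf.py, roughly:
#
#     with _add_to_sys_path(repo_dir):
#         hub_module = _import_module(MODULE_HUBCONF,
#                                     os.path.join(repo_dir, MODULE_HUBCONF))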


# Copied from tools/shared/module_loader to be included in torch package
def _import_module(name, path):
    import importlib.util
    from importlib.abc import Loader
    spec = importlib.util.spec_from_file_location(name, path)
    assert spec is not None