# Source code for torch.hub
import contextlib
import errno
import hashlib
import json
import os
import re
import shutil
import sys
import tempfile
import torch
import uuid
import warnings
import zipfile
from pathlib import Path
from typing import Dict, Optional, Any
from urllib.error import HTTPError, URLError
from urllib.request import urlopen, Request
from urllib.parse import urlparse # noqa: F401
from torch.serialization import MAP_LOCATION
class _Faketqdm:  # type: ignore[no-redef]
    """Minimal stand-in for ``tqdm`` that reports progress on ``sys.stderr``.

    A percentage is printed when ``total`` is known and non-zero; otherwise the
    running count is printed as a number of bytes.

    Args:
        total: Expected final value of the counter, or ``None`` if unknown.
        disable: When true, ``update`` and ``__exit__`` write nothing.
        unit: Accepted for signature compatibility; not used.
        *args: Accepted and ignored.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, total=None, disable=False,
                 unit=None, *args, **kwargs):
        self.total = total
        self.disable = disable
        self.n = 0
        # Ignore all extra *args and **kwargs lest you want to reinvent tqdm

    def update(self, n=1):
        """Advance the counter by ``n`` (default 1) and redraw the progress line."""
        if self.disable:
            return

        self.n += n
        if not self.total:
            # Unknown (None) or zero total: a percentage is undefined, so
            # print the raw count instead of dividing by zero.
            sys.stderr.write(f"\r{self.n:.1f} bytes")
        else:
            sys.stderr.write(f"\r{100 * self.n / float(self.total):.1f}%")
        sys.stderr.flush()

    # Don't bother implementing; use real tqdm if you want
    def set_description(self, *args, **kwargs):
        """Accept and ignore any description."""
        pass

    def write(self, s):
        """Write ``s`` followed by a newline to ``sys.stderr``."""
        sys.stderr.write(f"{s}\n")

    def close(self):
        """Silence the bar: later ``update``/``__exit__`` calls write nothing."""
        self.disable = True

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.disable:
            return

        # Terminate the "\r"-rewritten progress line.
        sys.stderr.write('\n')
# Use the real tqdm when it is installed; otherwise fall back to the
# fake wrapper.
try:
    from tqdm import tqdm
except ImportError:
    tqdm = _Faketqdm
# Names exported by a star-import of this module.
__all__ = [
'download_url_to_file',
'get_dir',
'help',
'list',
'load',
'load_state_dict_from_url',
'set_dir',
]
# Matches bfd8deac from resnet18-bfd8deac.pth
HASH_REGEX = re.compile(r"-([a-f0-9]*)\.")

_TRUSTED_REPO_OWNERS = (
    "facebookresearch",
    "facebookincubator",
    "pytorch",
    "fairinternal",
)

# String keys and defaults used by the rest of the module.
ENV_GITHUB_TOKEN = "GITHUB_TOKEN"
ENV_TORCH_HOME = "TORCH_HOME"
ENV_XDG_CACHE_HOME = "XDG_CACHE_HOME"
DEFAULT_CACHE_DIR = "~/.cache"
VAR_DEPENDENCY = "dependencies"
MODULE_HUBCONF = "hubconf.py"

READ_DATA_CHUNK = 128 * 1024

# Module-level state; starts out unset.
_hub_dir: Optional[str] = None
@contextlib.contextmanager
def _add_to_sys_path(path):
    """Temporarily place ``path`` at the front of ``sys.path``.

    ``path`` is inserted at index 0 on entry. On exit, whether or not the
    managed body raised, the first occurrence of ``path`` is removed again.
    """
    sys.path[:0] = [path]
    try:
        yield
    finally:
        sys.path.remove(path)
# Copied from tools/shared/module_loader to be included in the torch package
def _import_module(name, path):
import importlib.util
from importlib.abc import Loader
spec = importlib.util.spec_from_file_location(name, path)
assert spec is not None