Shortcuts

torch.package.package_importer 的源代码

import builtins
import importlib
import importlib.machinery
import inspect
import io
import linecache
import os
import types
from contextlib import contextmanager
from typing import Any, BinaryIO, Callable, cast, Dict, Iterable, List, Optional, Union
from weakref import WeakValueDictionary

import torch
from torch.serialization import _get_restore_location, _maybe_decode_ascii

from ._directory_reader import DirectoryReader
from ._importlib import (
    _calc___package__,
    _normalize_line_endings,
    _normalize_path,
    _resolve_name,
    _sanity_check,
)
from ._mangling import demangle, PackageMangler
from ._package_unpickler import PackageUnpickler
from .file_structure_representation import _create_directory_from_file_list, Directory
from .glob_group import GlobPattern
from .importer import Importer

__all__ = ["PackageImporter"]


# 这是一个导入列表,即使它们没有被标记为外部导入,也会隐式允许。这是为了解决Torch隐式依赖numpy的问题,而包无法跟踪它。
# https://github.com/pytorch/MultiPy/issues/46
IMPLICIT_IMPORT_ALLOWLIST: Iterable[str] = [
    "numpy",
    "numpy.core",
    "numpy.core._multiarray_umath",
    # FX GraphModule可能依赖于builtins模块,而用户通常不会外部化builtins。这里我们默认导入它。
    "builtins",
]


[docs]class PackageImporter(Importer): """导入器允许你加载由 :class:`PackageExporter` 编写的包中的代码。 代码以一种封闭的方式加载,使用包中的文件而不是正常的python导入系统。这允许 打包PyTorch模型代码和数据,以便它可以在服务器上运行或在将来用于迁移学习。 包的导入器确保模块中的代码只能从包中加载,除非在导出期间明确列出为外部模块。 压缩存档中的文件 ``extern_modules`` 列出了包外部依赖的所有模块。 这防止了“隐式”依赖,即包在本地运行是因为它正在导入本地安装的包,但在将包复制到另一台机器时失败。 """ """从该包中已经加载的模块的字典,相当于 ``sys.modules`` 但 仅限于此导入器。 """ modules: Dict[str, types.ModuleType]
[docs] def __init__( self, file_or_buffer: Union[str, torch._C.PyTorchFileReader, os.PathLike, BinaryIO], module_allowed: Callable[[str], bool] = lambda module_name: True, ): """打开 ``file_or_buffer`` 以进行导入。这会检查导入的包是否只要求模块 被 ``module_allowed`` 允许 参数: file_or_buffer: 一个类文件对象(必须实现 :meth:`read`, :meth:`readline`, :meth:`tell`, 和 :meth:`seek`), 一个字符串,或一个 ``os.PathLike`` 对象,包含文件名。 module_allowed (Callable[[str], bool], 可选): 一个方法,用于确定是否允许外部提供的模块。 可以用来确保加载的包不依赖于服务器不支持的模块。默认为允许任何模块。 引发: ImportError: 如果包将使用不允许的模块。 """ torch._C._log_api_usage_once("torch.package.PackageImporter") self.zip_reader: Any if isinstance(file_or_buffer, torch._C.PyTorchFileReader): self.filename = "" self.zip_reader = file_or_buffer elif isinstance(file_or_buffer, (os.PathLike, str)): self.filename = os.fspath(file_or_buffer) if not os.path.isdir(self.filename): self.zip_reader = torch._C.PyTorchFileReader(self.filename) else: self.zip_reader = DirectoryReader(self.filename) else: self.filename = "" self.zip_reader = torch._C.PyTorchFileReader(file_or_buffer) torch._C._log_api_usage_metadata( "torch.package.PackageImporter.metadata", { "serialization_id": self.zip_reader.serialization_id(), "file_name": self.filename, }, ) self.root = _PackageNode(None) self.modules = {} self.extern_modules = self._read_extern() for extern_module in self.extern_modules:
优云智算