Shortcuts

torch.hub

Pytorch Hub 是一个预训练模型库,旨在促进研究的可重复性。

发布模型

Pytorch Hub 支持通过添加一个简单的 hubconf.py 文件,将预训练模型(模型定义和预训练权重)发布到 GitHub 仓库;

hubconf.py 可以有多个入口点。每个入口点定义为一个Python函数(例如:您想要发布的预训练模型)。

def entrypoint_name(*args, **kwargs):
    # 参数 args 和 kwargs 是可选的,适用于接受位置参数/关键字参数的模型。
    ...

如何实现一个入口点?

这里是一个代码片段,指定了resnet18模型的入口点,如果我们扩展了pytorch/vision/hubconf.py中的实现。 在大多数情况下,在hubconf.py中导入正确的函数就足够了。这里我们只是想使用扩展版本作为一个例子来展示它是如何工作的。 你可以在pytorch/vision repo中看到完整的脚本。

dependencies = ['torch']
from torchvision.models.resnet import resnet18 as _resnet18

# resnet18 是入口点的名称
def resnet18(pretrained=False, **kwargs):
    """ # 此文档字符串显示在 hub.help() 中
    Resnet18 模型
    pretrained (bool): kwargs, 将预训练的权重加载到模型中
    """
    # 调用模型,加载预训练的权重
    model = _resnet18(pretrained=pretrained, **kwargs)
    return model
  • dependencies 变量是一个 列表,包含加载模型所需的包名称。请注意,这可能与训练模型所需的依赖项略有不同。

  • argskwargs 被传递给实际的可调用函数。

  • 函数的文档字符串作为帮助信息。它解释了模型做什么以及允许的位置/关键字参数是什么。强烈建议在这里添加一些示例。

  • 入口函数可以返回一个模型(nn.module),或者辅助工具以使用户工作流程更加顺畅,例如分词器。

  • 以下划线开头的可调用对象被视为辅助函数,不会出现在 torch.hub.list() 中。

  • 预训练权重可以本地存储在GitHub仓库中,或者可以通过 torch.hub.load_state_dict_from_url()加载。如果小于2GB,建议将其附加到项目发布 并使用发布中的url。 在上面的示例中,torchvision.models.resnet.resnet18处理pretrained,或者您可以将以下逻辑放入入口点定义中。

if pretrained:
    # 对于保存在本地GitHub仓库中的检查点,例如 =weights/save.pth
    dirname = os.path.dirname(__file__)
    checkpoint = os.path.join(dirname, <RELATIVE_PATH_TO_CHECKPOINT>)
    state_dict = torch.load(checkpoint)
    model.load_state_dict(state_dict)

    # 对于保存在其他地方的检查点
    checkpoint = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
    model.load_state_dict(torch.hub.load_state_dict_from_url(checkpoint, progress=False))

重要通知

  • 发布的模型应至少在一个分支/标签中。它不能是一个随机的提交。

从Hub加载模型

Pytorch Hub 提供了便捷的 API 来通过 torch.hub.list() 探索 hub 中所有可用的模型,通过 torch.hub.help() 显示文档字符串和示例,并使用 torch.hub.load() 加载预训练模型。

torch.hub.list(github, force_reload=False, skip_validation=False, trust_repo=None, verbose=True)[源代码]

列出指定仓库中所有可调用的入口点,仓库由 github 指定。

Parameters
  • github (str) – 一个格式为“repo_owner/repo_name[:ref]”的字符串,可选的ref(标签或分支)。如果未指定ref,则假设默认分支为main(如果存在),否则为master。 示例:‘pytorch/vision:0.10’

  • force_reload (bool, 可选) – 是否丢弃现有缓存并强制重新下载。 默认值为 False

  • skip_validation (布尔值, 可选) – 如果False,torchhub 将检查由github参数指定的分支或提交是否正确属于仓库所有者。这将向 GitHub API 发出请求;您可以通过设置GITHUB_TOKEN环境变量来指定非默认的 GitHub 令牌。默认值为False

  • trust_repo (bool, str or None) –

    "check", True, False or None. 此参数在v1.12中引入,有助于确保用户仅运行来自他们信任的仓库的代码。

    • 如果 False,将提示用户是否应信任该仓库。

    • 如果 True,该仓库将被添加到信任列表中,并在不需要明确确认的情况下加载。

    • 如果 "check",该仓库将与缓存中的信任仓库列表进行检查。如果它不在该列表中,行为将回退到 trust_repo=False 选项。

    • 如果 None:这将引发警告,邀请用户将 trust_repo 设置为 FalseTrue"check"。这仅用于向后兼容,将在v2.0中移除。

    默认值为 None,最终将在v2.0中更改为 "check"

  • 详细 (布尔值, 可选) – 如果False,静音关于命中本地缓存的消息。请注意,关于首次下载的消息无法静音。默认值为True

Returns

可用的可调用入口点

Return type

列表

示例

>>> entrypoints = torch.hub.list('pytorch/vision', force_reload=True)
torch.hub.help(github, model, force_reload=False, skip_validation=False, trust_repo=None)[源代码]

显示入口点 model 的文档字符串。

Parameters
  • github (str) – 一个格式为 的字符串,可选的 ref(标签或分支)。如果未指定 ref,则默认分支假设为 main(如果存在),否则为 master。 示例:‘pytorch/vision:0.10’

  • 模型 (字符串) – 在仓库的 hubconf.py 中定义的入口点名称的字符串

  • force_reload (bool, 可选) – 是否丢弃现有缓存并强制重新下载。 默认值为 False

  • skip_validation (布尔值, 可选) – 如果False,torchhub 将检查由github参数指定的引用是否正确属于仓库所有者。这将向 GitHub API 发出请求;您可以通过设置GITHUB_TOKEN环境变量来指定非默认的 GitHub 令牌。默认值为False

  • trust_repo (bool, str or None) –

    "check", True, False or None. 此参数在v1.12中引入,有助于确保用户仅运行来自他们信任的仓库的代码。

    • 如果False,将提示用户是否应信任该仓库。

    • 如果True,该仓库将被添加到信任列表中,并在不需要明确确认的情况下加载。

    • 如果"check",该仓库将与缓存中的信任仓库列表进行检查。如果它不在该列表中,行为将回退到trust_repo=False选项。

    • 如果None:这将引发警告,邀请用户将trust_repo设置为FalseTrue"check"。这仅用于向后兼容,将在v2.0中移除。

    默认值为None,最终将在v2.0中更改为"check"

示例

>>> print(torch.hub.help('pytorch/vision', 'resnet18', force_reload=True))
torch.hub.load(repo_or_dir, model, *args, source='github', trust_repo=None, force_reload=False, verbose=True, skip_validation=False, **kwargs)[源代码]

从 GitHub 仓库或本地目录加载模型。

注意:加载模型是典型的用例,但这也可用于加载其他对象,如分词器、损失函数等。

如果 source 是 ‘github’,repo_or_dir 应为 形式 repo_owner/repo_name[:ref],并带有一个可选的 ref(标签或分支)。

如果 source 是 ‘local’,则 repo_or_dir 应为本地目录的路径。

Parameters
  • repo_or_dir (str) – 如果 source 是 ‘github’, 这应该对应于一个格式为 repo_owner/repo_name[:ref] 的 github 仓库, 带有可选的 ref(标签或分支),例如 ‘pytorch/vision:0.10’。如果未指定 ref, 则默认分支假设为 main(如果存在),否则为 master。 如果 source 是 ‘local’,则它应该是一个本地目录的路径。

  • 模型 (字符串) – 在仓库/目录的 hubconf.py 中定义的可调用对象(入口点)的名称。

  • *args (可选) – 可调用对象 model 对应的参数。

  • (str, 可选) – ‘github’ 或 ‘local’。指定如何解释 repo_or_dir。默认值是 ‘github’。

  • trust_repo (bool, str or None) –

    "check", True, False or None. 此参数在v1.12中引入,有助于确保用户仅运行来自他们信任的仓库的代码。

    • 如果 False,将提示用户是否应信任该仓库。

    • 如果 True,该仓库将被添加到信任列表中,并在不需要明确确认的情况下加载。

    • 如果 "check",该仓库将与缓存中的信任仓库列表进行检查。如果它不在该列表中,行为将回退到 trust_repo=False 选项。

    • 如果 None:这将引发警告,邀请用户将 trust_repo 设置为 FalseTrue"check"。这仅用于向后兼容,将在v2.0中移除。

    默认值为 None,最终将在v2.0中更改为 "check"

  • force_reload (bool, 可选) – 是否强制重新下载GitHub仓库。如果source = 'local',则没有任何效果。默认值为False

  • 详细 (布尔值, 可选) – 如果False,静音关于命中本地缓存的消息。请注意,关于首次下载的消息无法静音。如果source = 'local',则没有任何效果。默认值为True

  • skip_validation (布尔值, 可选) – 如果False,torchhub 将检查由github参数指定的分支或提交是否正确属于仓库所有者。这将向 GitHub API 发出请求;您可以通过设置GITHUB_TOKEN环境变量来指定非默认的 GitHub 令牌。默认值为False

  • **kwargs (可选) – 可调用对象 model 对应的 kwargs。

Returns

当使用给定的 *args**kwargs 调用时,model 可调用对象的输出。

示例

>>> # 来自一个github仓库
>>> repo = 'pytorch/vision'
>>> model = torch.hub.load(repo, 'resnet50', weights='ResNet50_Weights.IMAGENET1K_V1')
>>> # 来自一个本地目录
>>> path = '/some/local/path/pytorch/vision'
>>> model = torch.hub.load(path, 'resnet50', weights='ResNet50_Weights.DEFAULT')
torch.hub.download_url_to_file(url, dst, hash_prefix=None, progress=True)[源代码]

将指定URL的对象下载到本地路径。

Parameters
  • url (str) – 要下载的对象的URL

  • dst (str) – 对象将保存的完整路径,例如 /tmp/temporary_file

  • hash_prefix (str, 可选) – 如果不是 None,下载的文件的 SHA256 应该以 hash_prefix 开头。 默认值: None

  • 进度 (bool, 可选) – 是否在stderr显示进度条 默认值: True

示例

>>> torch.hub.download_url_to_file('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth', '/tmp/temporary_file')
torch.hub.load_state_dict_from_url(url, model_dir=None, map_location=None, progress=True, check_hash=False, file_name=None, weights_only=False)[源代码]

加载位于给定URL的Torch序列化对象。

如果下载的文件是zip文件,它将自动解压缩。

如果对象已经存在于model_dir中,它将被反序列化并返回。 model_dir的默认值是/checkpoints,其中 hub_dir是由get_dir()返回的目录。

Parameters
  • url (str) – 要下载的对象的URL

  • model_dir (str, 可选) – 保存对象的目录

  • map_location可选)– 一个函数或字典,指定如何重新映射存储位置(参见 torch.load)

  • 进度 (布尔值, 可选) – 是否在标准错误输出中显示进度条。 默认值: True

  • check_hash (布尔值, 可选) – 如果为True,URL的文件名部分应遵循命名约定 filename-.ext 其中 是文件内容的SHA256哈希的前八位或更多位数字。该哈希用于确保唯一名称并验证文件内容。 默认值:False

  • file_name (str, 可选) – 下载文件的名称。如果未设置,将使用来自url的文件名。

  • weights_only (布尔值, 可选) – 如果为True,则只加载权重,不加载复杂的序列化对象。 推荐用于不可信的来源。更多详情请参见load()

Return type

字典[字符串, 任意]

示例

>>> state_dict = torch.hub.load_state_dict_from_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth')

运行加载的模型:

请注意,*args**kwargstorch.hub.load() 中用于 实例化 模型。在加载模型后,如何找出 你可以用模型做什么? 建议的工作流程是

  • dir(model) 查看模型所有可用的方法。

  • help(model.foo) 检查 model.foo 运行时所需的参数

为了帮助用户在不来回查阅文档的情况下进行探索,我们强烈建议仓库所有者使函数帮助信息清晰简洁。同时,包含一个最小的工作示例也是非常有帮助的。

我的下载模型保存在哪里?

位置按以下顺序使用

  • 调用 hub.set_dir()

  • $TORCH_HOME/hub,如果设置了环境变量TORCH_HOME

  • $XDG_CACHE_HOME/torch/hub,如果环境变量 XDG_CACHE_HOME 已设置。

  • ~/.cache/torch/hub

torch.hub.get_dir()[源代码]

获取用于存储下载模型和权重的Torch Hub缓存目录。

如果未调用 set_dir(),默认路径为 $TORCH_HOME/hub,其中环境变量 $TORCH_HOME 默认为 $XDG_CACHE_HOME/torch$XDG_CACHE_HOME 遵循 Linux 文件系统布局的 X 设计组规范,如果未设置环境变量,则默认值为 ~/.cache

torch.hub.set_dir(d)[源代码]

可选地设置用于保存下载的模型和权重的Torch Hub目录。

Parameters

d (str) – 保存下载模型和权重的本地文件夹路径。

缓存逻辑

默认情况下,我们在加载文件后不会清理文件。如果目录中已经存在缓存,Hub 默认使用缓存,该目录由 get_dir() 返回。

用户可以通过调用 hub.load(..., force_reload=True) 来强制重新加载。这将删除现有的 GitHub 文件夹和下载的权重,重新初始化一个新的下载。这在同一分支发布更新时非常有用,用户可以跟上最新的发布版本。

已知限制:

Torch hub 通过像安装包一样导入包来工作。在 Python 中导入时会引入一些副作用。例如,您可以在 Python 缓存中看到新项目 sys.modulessys.path_importer_cache,这是正常的 Python 行为。 这也意味着,如果您从不同的仓库导入不同的模型,如果这些仓库具有相同的子包名称(通常是一个 model 子包),您可能会遇到导入错误。解决这些导入错误的一种方法是 从 sys.modules 字典中删除有问题的子包;更多详细信息可以在 这个 GitHub 问题中找到。

这里值得一提的是一个已知的限制:用户不能同一个Python进程中加载同一仓库的两个不同分支。这就像在Python中安装两个同名包一样,并不好。如果你真的尝试这样做,缓存可能会加入其中并给你带来惊喜。当然,在单独的进程中加载它们是完全没问题的。

优云智算