torch.multiprocessing 的源代码
"""torch.multiprocessing 是围绕原生 :mod:`multiprocessing` 模块的封装。
它注册了自定义的 reducers,使用共享内存为不同进程提供对相同数据的共享视图。一旦张量/存储被移动到共享内存(参见 :func:`~torch.Tensor.share_memory_`),就可以将其发送到其他进程,而无需进行任何复制。
API 与原始模块 100% 兼容 - 只需将 ``import multiprocessing`` 更改为 ``import torch.multiprocessing``,就可以通过队列或其他机制发送的所有张量移动到共享内存中。
由于 API 的相似性,我们没有记录此包的大部分内容,我们建议参考原始模块的非常好的文档。
"""
import multiprocessing
import sys
import torch
from .reductions import init_reductions
__all__ = ["set_sharing_strategy", "get_sharing_strategy", "get_all_sharing_strategies"]
from multiprocessing import * # noqa: F403
__all__ += multiprocessing.__all__ # noqa: PLE0605 type: ignore[attr-defined]
# 此调用为此模块添加了一个特定于 Linux 的 prctl(2) 包装函数。
# 更多信息请参见 https://github.com/pytorch/pytorch/pull/14391。
torch._C._multiprocessing_init()
"""添加辅助函数以生成 N 个进程并等待其中任何一个完成。这依赖于 `mp.get_context`,它在 Python 3.4 中被添加。"""
from .spawn import (
ProcessContext,
ProcessExitedException,
ProcessRaisedException,
spawn,
SpawnContext,
start_processes,
)
if sys.platform == "darwin" or sys.platform == "win32":
_sharing_strategy = "file_system"
_all_sharing_strategies = {"file_system"}
else:
_sharing_strategy = "file_descriptor"
_all_sharing_strategies = {"file_descriptor", "file_system"}
[docs]def set_sharing_strategy(new_strategy):
"""设置共享 CPU 张量的策略。
参数:
new_strategy (str): 所选策略的名称。应该是 :func:`get_all_sharing_strategies()` 返回的值之一。
"""
global _sharing_strategy
assert new_strategy in _all_sharing_strategies
_sharing_strategy = new_strategy
[docs]def get_sharing_strategy():
"""返回当前共享 CPU 张量的策略。"""
return _sharing_strategy
[docs]def get_all_sharing_strategies():
"""返回当前系统支持的共享策略集。"""
return _all_sharing_strategies
init_reductions()