ray.train.torch.get_device#

ray.train.torch.get_device() torch.device[源代码]#

获取为此进程配置的正确 torch 设备。

返回当前工作者的 torch 设备。如果每个工作者请求超过 1 个 GPU,则返回设备索引最小的设备。

备注

如果你为每个工作节点请求了多个GPU,并且想要获取完整的torch设备列表,请使用 get_devices()

假设 CUDA_VISIBLE_DEVICES 已设置,并且是 ray.get_gpu_ids() 的超集。

示例

示例:在当前节点上启动了2个工作线程,每个工作线程使用1个GPU

os.environ["CUDA_VISIBLE_DEVICES"] = "2,3"
ray.get_gpu_ids() == [2]
torch.cuda.is_available() == True
get_device() == torch.device("cuda:0")

示例:在当前节点上启动了4个工作进程,每个进程使用1个GPU

os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
ray.get_gpu_ids() == [2]
torch.cuda.is_available() == True
get_device() == torch.device("cuda:2")

示例:在当前节点上启动了2个工作进程,每个工作进程使用2个GPU

os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
ray.get_gpu_ids() == [2,3]
torch.cuda.is_available() == True
get_device() == torch.device("cuda:2")

你可以通过以下方式将模型移动到设备上:

model.to(ray.train.torch.get_device())

而不是手动检查设备类型:

model.to("cuda" if torch.cuda.is_available() else "cpu")