开始使用 transforms v2¶
这个例子展示了开始使用新的torchvision.transforms.v2 API所需了解的所有内容。我们将涵盖简单的任务,如图像分类,以及更高级的任务,如目标检测/分割。
首先,进行一些设置
from pathlib import Path
import torch
import matplotlib.pyplot as plt
plt.rcParams["savefig.bbox"] = 'tight'
from torchvision.transforms import v2
from torchvision.io import decode_image
torch.manual_seed(1)
# If you're trying to run that on Colab, you can download the assets and the
# helpers from https://github.com/pytorch/vision/tree/main/gallery/
from helpers import plot
img = decode_image(str(Path('../assets') / 'astronaut.jpg'))
print(f"{type(img) = }, {img.dtype = }, {img.shape = }")
type(img) = <class 'torch.Tensor'>, img.dtype = torch.uint8, img.shape = torch.Size([3, 512, 512])
基础¶
Torchvision 的 transforms 行为类似于常规的 torch.nn.Module(事实上,它们中的大多数都是):实例化一个 transform,传递一个输入,获取一个转换后的输出:
transform = v2.RandomCrop(size=(224, 224))
out = transform(img)
plot([img, out])

我只想做图像分类¶
如果你只关心图像分类,事情就非常简单。一个基本的分类流程可能看起来像这样:
transforms = v2.Compose([
v2.RandomResizedCrop(size=(224, 224), antialias=True),
v2.RandomHorizontalFlip(p=0.5),
v2.ToDtype(torch.float32, scale=True),
v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
out = transforms(img)
plot([img, out])

这种转换管道通常作为transform参数传递给数据集,例如ImageNet(...,
transform=transforms)。
差不多就这些了。从这里开始,阅读我们的主要文档以了解更多关于推荐实践和约定的信息,或者探索更多示例,例如如何使用增强变换如CutMix和MixUp。
注意
如果您已经在使用torchvision.transforms v1 API,
我们建议您切换到新的v2 transforms。这非常容易:
v2 transforms与v1 API完全兼容,因此您只需要更改导入即可!
检测、分割、视频¶
新的Torchvision变换在torchvision.transforms.v2命名空间中
支持图像分类之外的任务:它们还可以变换边界
框、分割/检测掩码或视频。
让我们简要看一下带有边界框的检测示例。
from torchvision import tv_tensors # we'll describe this a bit later, bare with us
boxes = tv_tensors.BoundingBoxes(
[
[15, 10, 370, 510],
[275, 340, 510, 510],
[130, 345, 210, 425]
],
format="XYXY", canvas_size=img.shape[-2:])
transforms = v2.Compose([
v2.RandomResizedCrop(size=(224, 224), antialias=True),
v2.RandomPhotometricDistort(p=1),
v2.RandomHorizontalFlip(p=1),
])
out_img, out_boxes = transforms(img, boxes)
print(type(boxes), type(out_boxes))
plot([(img, boxes), (out_img, out_boxes)])

<class 'torchvision.tv_tensors._bounding_boxes.BoundingBoxes'> <class 'torchvision.tv_tensors._bounding_boxes.BoundingBoxes'>
上面的例子主要关注对象检测。但如果我们有用于对象分割或语义分割的掩码
(torchvision.tv_tensors.Mask),或者视频
(torchvision.tv_tensors.Video),我们可以以完全相同的方式将它们传递给变换。
到目前为止,你可能会有几个问题:这些TVTensors是什么,我们如何使用它们,以及这些变换的预期输入/输出是什么?我们将在接下来的部分中回答这些问题。
什么是TVTensors?¶
TVTensors是torch.Tensor的子类。可用的TVTensors包括
Image,
BoundingBoxes,
Mask, 和
视频。
TVTensors 看起来和感觉就像普通的张量 - 它们就是张量。
所有在普通 torch.Tensor 上支持的操作,比如 .sum()
或任何 torch.* 操作符,也都可以在 TVTensor 上使用:
img_dp = tv_tensors.Image(torch.randint(0, 256, (3, 256, 256), dtype=torch.uint8))
print(f"{isinstance(img_dp, torch.Tensor) = }")
print(f"{img_dp.dtype = }, {img_dp.shape = }, {img_dp.sum() = }")
isinstance(img_dp, torch.Tensor) = True
img_dp.dtype = torch.uint8, img_dp.shape = torch.Size([3, 256, 256]), img_dp.sum() = tensor(25087958)
这些TVTensor类是转换的核心:为了转换给定的输入,转换首先查看对象的类,然后相应地分派到适当的实现。
此时您不需要对TVTensors了解太多,但想要了解更多的高级用户可以参考 TVTensors 常见问题。
我应该传递什么作为输入?¶
在上面,我们已经看到了两个例子:一个是我们传递了一张图片作为输入,即 out = transforms(img),另一个是我们同时传递了一张图片和边界框,即 out_img, out_boxes = transforms(img, boxes)。
事实上,转换支持任意输入结构。输入可以是单个图像、元组、任意嵌套的字典……几乎任何东西。相同的结构将作为输出返回。下面,我们使用相同的检测转换,但传递一个元组(image, target_dict)作为输入,并且我们得到了相同的结构作为输出:
target = {
"boxes": boxes,
"labels": torch.arange(boxes.shape[0]),
"this_is_ignored": ("arbitrary", {"structure": "!"})
}
# Re-using the transforms and definitions from above.
out_img, out_target = transforms(img, target)
plot([(img, target["boxes"]), (out_img, out_target["boxes"])])
print(f"{out_target['this_is_ignored']}")

('arbitrary', {'structure': '!'})
我们传递了一个元组,所以我们得到了一个元组返回,第二个元素是转换后的目标字典。转换并不真正关心输入的结构;如上所述,它们只关心对象的类型并相应地转换它们。
外部对象如字符串或整数会直接传递。这在调试时非常有用,例如,如果你想为每个样本关联一个路径!
注意
免责声明 本说明稍微高级一些,初次阅读时可以安全跳过。
纯torch.Tensor对象通常被视为图像(或对于特定视频的变换被视为视频)。确实,你可能已经注意到在上面的代码中我们根本没有使用Image类,但我们的图像仍然得到了正确的变换。变换遵循以下逻辑来确定纯张量是否应被视为图像(或视频),或者只是被忽略:
如果输入中存在
Image、Video、 或PIL.Image.Image实例,所有其他纯张量将被传递。如果没有
Image或Video实例,只有第一个纯torch.Tensor会被转换为图像或视频,而所有 其他内容将直接通过。这里的“第一个”指的是“深度优先遍历中的第一个”。
这是在上面的检测示例中发生的情况:第一个纯张量是图像,因此它被正确转换了,而所有其他纯张量实例(如labels)都被传递了(尽管标签仍然可以通过一些转换进行转换,例如SanitizeBoundingBoxes!)。
转换和数据集互操作性¶
大致来说,数据集的输出必须与转换的输入相对应。如何做到这一点取决于您使用的是torchvision的内置数据集,还是您自己的自定义数据集。
使用内置数据集¶
如果你只是进行图像分类,你不需要做任何事情。只需使用数据集的transform参数,例如ImageNet(...,
transform=transforms),你就可以开始了。
Torchvision 还支持用于目标检测或分割的数据集,例如
torchvision.datasets.CocoDetection。这些数据集在
torchvision.transforms.v2 模块和 TVTensors 存在之前就已经存在,因此它们不会直接返回 TVTensors。
强制这些数据集返回TVTensors并使它们与v2转换兼容的简单方法是使用
torchvision.datasets.wrap_dataset_for_transforms_v2() 函数:
from torchvision.datasets import CocoDetection, wrap_dataset_for_transforms_v2
dataset = CocoDetection(..., transforms=my_transforms)
dataset = wrap_dataset_for_transforms_v2(dataset)
# Now the dataset returns TVTensors!
使用您自己的数据集¶
如果您有一个自定义数据集,那么您需要将您的对象转换为适当的TVTensor类。创建TVTensor实例非常简单,更多详情请参考如何构建TVTensor?。
有两个主要的地方可以实现该转换逻辑:
在数据集的
__getitem__方法结束时,返回样本之前(或通过子类化数据集)。作为转换管道的第一步
无论哪种方式,逻辑都将取决于您的特定数据集。
脚本总运行时间: (0 分钟 0.792 秒)