如何编写你自己的TVTensor类¶
本指南适用于高级用户和下游库维护者。我们将解释如何编写您自己的TVTensor类,以及如何使其与内置的Torchvision v2转换兼容。在继续之前,请确保您已阅读TVTensors 常见问题。
import torch
from torchvision import tv_tensors
from torchvision.transforms import v2
我们将创建一个非常简单的类,它只是继承自基类
TVTensor。这将足以涵盖
您需要了解的内容,以实现更复杂的使用案例。如果您需要
创建一个携带元数据的类,请查看
BoundingBoxes 类的 实现。
class MyTVTensor(tv_tensors.TVTensor):
pass
my_dp = MyTVTensor([1, 2, 3])
my_dp
MyTVTensor([1., 2., 3.])
既然我们已经定义了自定义的TVTensor类,我们希望它能与内置的torchvision转换和功能API兼容。为此,我们需要实现一个执行转换核心的内核,然后通过register_kernel()将其“挂钩”到我们想要支持的功能上。
我们在下面展示这个过程:我们为MyTVTensor类的“水平翻转”操作创建了一个内核,并将其注册到功能API中。
from torchvision.transforms.v2 import functional as F
@F.register_kernel(functional="hflip", tv_tensor_cls=MyTVTensor)
def hflip_my_tv_tensor(my_dp, *args, **kwargs):
print("Flipping!")
out = my_dp.flip(-1)
return tv_tensors.wrap(out, like=my_dp)
要理解为什么使用wrap(),请参阅
我有一个TVTensor但现在我有一个Tensor。帮助!。暂时忽略*args, **kwargs,
我们将在参数转发,并确保你的内核的未来兼容性中解释它。
注意
在我们上面的register_kernel调用中,我们使用了一个字符串functional="hflip"来引用我们想要挂钩的功能。我们也可以直接使用功能本身,即@register_kernel(functional=F.hflip, ...)。
现在我们已经注册了我们的内核,我们可以在一个MyTVTensor实例上调用功能API:
my_dp = MyTVTensor(torch.rand(3, 256, 256))
_ = F.hflip(my_dp)
Flipping!
我们也可以使用
RandomHorizontalFlip 变换,因为它内部依赖于 hflip():
t = v2.RandomHorizontalFlip(p=1)
_ = t(my_dp)
Flipping!
注意
我们无法为转换类注册内核,我们只能为功能注册内核。我们无法注册转换类的原因是因为一个转换可能在内部依赖于多个功能,因此通常我们无法为给定的类注册单个内核。
参数转发,并确保您的内核的未来兼容性¶
您正在使用的功能API是公开的,因此向后兼容:我们保证这些功能的参数在没有适当的弃用周期的情况下不会被删除或重命名。然而,我们不保证向前兼容性,并且我们可能会在未来添加新的参数。
想象一下,在未来的版本中,Torchvision 为其 hflip() 函数添加了一个新的 inplace 参数。如果你已经定义并注册了你自己的内核为
def hflip_my_tv_tensor(my_dp): # noqa
print("Flipping!")
out = my_dp.flip(-1)
return tv_tensors.wrap(out, like=my_dp)
然后调用 F.hflip(my_dp) 将会 失败,因为 hflip 会尝试将新的 inplace 参数传递给你的内核,但你的内核不接受它。
因此,我们建议始终在签名中使用*args, **kwargs来定义您的内核,如上所示。这样,您的内核将能够接受我们未来可能添加的任何新参数。(从技术上讲,仅添加**kwargs应该就足够了)。
脚本总运行时间: (0 分钟 0.004 秒)