• Tutorials >
  • Writing Custom Datasets, DataLoaders and Transforms
Shortcuts

编写自定义数据集、数据加载器和转换

创建于:2017年6月10日 | 最后更新:2024年1月19日 | 最后验证:2024年11月5日

作者: Sasank Chilamkurthy

在解决任何机器学习问题时,大量的努力都投入到了数据准备中。PyTorch 提供了许多工具,使数据加载变得容易,并希望使您的代码更具可读性。在本教程中,我们将看到如何从一个非平凡的数据集中加载和预处理/增强数据。

要运行本教程,请确保已安装以下软件包:

  • scikit-image: 用于图像输入输出和变换

  • pandas: 用于更简单的csv解析

import os
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

plt.ion()   # interactive mode
<contextlib.ExitStack object at 0x7f10a5cb2320>

我们将要处理的数据集是人脸姿态的数据集。 这意味着人脸会被这样标注:

../_images/landmarked_face2.png

总体而言,每张脸标注了68个不同的关键点。

注意

这里下载数据集 以便图像位于名为‘data/faces/’的目录中。 这个数据集实际上是通过在imagenet中标记为‘face’的一些图像上应用优秀的dlib的姿态估计生成的。

数据集附带一个带有注释的.csv文件,看起来像这样:

image_name,part_0_x,part_0_y,part_1_x,part_1_y,part_2_x, ... ,part_67_x,part_67_y
0805personali01.jpg,27,83,27,98, ... 84,134
1084239450_e76e00b7e7.jpg,70,236,71,257, ... ,128,312

让我们从CSV中获取一个图像名称及其注释,在这个例子中,行索引号为65,以person-7.jpg为例。读取它,将图像名称存储在img_name中,并将其注释存储在一个(L, 2)数组landmarks中,其中L是该行中的地标数量。

landmarks_frame = pd.read_csv('data/faces/face_landmarks.csv')

n = 65
img_name = landmarks_frame.iloc[n, 0]
landmarks = landmarks_frame.iloc[n, 1:]
landmarks = np.asarray(landmarks, dtype=float).reshape(-1, 2)

print('Image name: {}'.format(img_name))
print('Landmarks shape: {}'.format(landmarks.shape))
print('First 4 Landmarks: {}'.format(landmarks[:4]))
Image name: person-7.jpg
Landmarks shape: (68, 2)
First 4 Landmarks: [[32. 65.]
 [33. 76.]
 [34. 86.]
 [34. 97.]]

让我们编写一个简单的辅助函数来显示图像及其地标,并使用它来显示一个样本。

def show_landmarks(image, landmarks):
    """Show image with landmarks"""
    plt.imshow(image)
    plt.scatter(landmarks[:, 0], landmarks[:, 1], s=10, marker='.', c='r')
    plt.pause(0.001)  # pause a bit so that plots are updated

plt.figure()
show_landmarks(io.imread(os.path.join('data/faces/', img_name)),
               landmarks)
plt.show()
data loading tutorial

数据集类

torch.utils.data.Dataset 是一个表示数据集的抽象类。 您的自定义数据集应继承 Dataset 并重写以下方法:

  • __len__ 使得 len(dataset) 返回数据集的大小。

  • __getitem__ 支持索引,使得 dataset[i] 可以用于获取第 \(i\) 个样本。

让我们为我们的面部标志数据集创建一个数据集类。我们将在__init__中读取csv文件,但将图像的读取留给__getitem__。这是内存高效的,因为所有图像不会一次性存储在内存中,而是根据需要读取。

我们的数据集样本将是一个字典 {'image': image, 'landmarks': landmarks}。我们的数据集将接受一个 可选参数 transform,以便可以对样本应用任何所需的处理。我们将在 下一节中看到 transform 的用处。

class FaceLandmarksDataset(Dataset):
    """Face Landmarks dataset."""

    def __init__(self, csv_file, root_dir, transform=None):
        """
        Arguments:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.landmarks_frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.landmarks_frame)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_name = os.path.join(self.root_dir,
                                self.landmarks_frame.iloc[idx, 0])
        image = io.imread(img_name)
        landmarks = self.landmarks_frame.iloc[idx, 1:]
        landmarks = np.array([landmarks], dtype=float).reshape(-1, 2)
        sample = {'image': image, 'landmarks': landmarks}

        if self.transform:
            sample = self.transform(sample)

        return sample

让我们实例化这个类并遍历数据样本。我们将打印前4个样本的大小并显示它们的地标。

face_dataset = FaceLandmarksDataset(csv_file='data/faces/face_landmarks.csv',
                                    root_dir='data/faces/')

fig = plt.figure()

for i, sample in enumerate(face_dataset):
    print(i, sample['image'].shape, sample['landmarks'].shape)

    ax = plt.subplot(1, 4, i + 1)
    plt.tight_layout()
    ax.set_title('Sample #{}'.format(i))
    ax.axis('off')
    show_landmarks(**sample)

    if i == 3:
        plt.show()
        break
Sample #0, Sample #1, Sample #2, Sample #3
0 (324, 215, 3) (68, 2)
1 (500, 333, 3) (68, 2)
2 (250, 258, 3) (68, 2)
3 (434, 290, 3) (68, 2)

转换

从上面我们可以看到的一个问题是样本的大小不一致。大多数神经网络期望图像具有固定的大小。因此,我们需要编写一些预处理代码。让我们创建三个转换:

  • Rescale: 缩放图像

  • RandomCrop: 从图像中随机裁剪。这是数据增强。

  • ToTensor: 将numpy图像转换为torch图像(我们需要交换轴)。

我们将它们写成可调用的类,而不是简单的函数,这样就不需要在每次调用时传递转换的参数。为此,我们只需要实现__call__方法,如果需要的话,还可以实现__init__方法。然后我们可以像这样使用一个转换:

tsfm = Transform(params)
transformed_sample = tsfm(sample)

观察下面这些变换是如何同时应用于图像和地标的。

class Rescale(object):
    """Rescale the image in a sample to a given size.

    Args:
        output_size (tuple or int): Desired output size. If tuple, output is
            matched to output_size. If int, smaller of image edges is matched
            to output_size keeping aspect ratio the same.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']

        h, w = image.shape[:2]
        if isinstance(self.output_size, int):
            if h > w:
                new_h, new_w = self.output_size * h / w, self.output_size
            else:
                new_h, new_w = self.output_size, self.output_size * w / h
        else:
            new_h, new_w = self.output_size

        new_h, new_w = int(new_h), int(new_w)

        img = transform.resize(image, (new_h, new_w))

        # h and w are swapped for landmarks because for images,
        # x and y axes are axis 1 and 0 respectively
        landmarks = landmarks * [new_w / w, new_h / h]

        return {'image': img, 'landmarks': landmarks}


class RandomCrop(object):
    """Crop randomly the image in a sample.

    Args:
        output_size (tuple or int): Desired output size. If int, square crop
            is made.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']

        h, w = image.shape[:2]
        new_h, new_w = self.output_size

        top = np.random.randint(0, h - new_h + 1)
        left = np.random.randint(0, w - new_w + 1)

        image = image[top: top + new_h,
                      left: left + new_w]

        landmarks = landmarks - [left, top]

        return {'image': image, 'landmarks': landmarks}


class ToTensor(object):
    """Convert ndarrays in sample to Tensors."""

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']

        # swap color axis because
        # numpy image: H x W x C
        # torch image: C x H x W
        image = image.transpose((2, 0, 1))
        return {'image': torch.from_numpy(image),
                'landmarks': torch.from_numpy(landmarks)}

注意

在上面的例子中,RandomCrop 使用了一个外部库的随机数生成器 (在这个例子中,是 Numpy 的 np.random.int)。这可能会导致与 DataLoader 的意外行为 (参见 这里)。 在实践中,使用 PyTorch 的随机数生成器更安全,例如使用 torch.randint 代替。

组合变换

现在,我们在一个样本上应用这些变换。

假设我们想要将图像的较短边缩放到256,然后从中随机裁剪一个224大小的正方形。也就是说,我们想要组合RescaleRandomCrop变换。torchvision.transforms.Compose是一个简单的可调用类,它允许我们这样做。

scale = Rescale(256)
crop = RandomCrop(128)
composed = transforms.Compose([Rescale(256),
                               RandomCrop(224)])

# Apply each of the above transforms on sample.
fig = plt.figure()
sample = face_dataset[65]
for i, tsfrm in enumerate([scale, crop, composed]):
    transformed_sample = tsfrm(sample)

    ax = plt.subplot(1, 3, i + 1)
    plt.tight_layout()
    ax.set_title(type(tsfrm).__name__)
    show_landmarks(**transformed_sample)

plt.show()
Rescale, RandomCrop, Compose

遍历数据集

让我们把所有这些放在一起,创建一个包含组合变换的数据集。 总结一下,每次采样这个数据集时:

  • 图像是从文件中动态读取的

  • 转换应用于读取的图像

  • 由于其中一个变换是随机的,数据在采样时会被增强

我们可以像之前一样使用for i in range循环来遍历创建的数据集。

transformed_dataset = FaceLandmarksDataset(csv_file='data/faces/face_landmarks.csv',
                                           root_dir='data/faces/',
                                           transform=transforms.Compose([
                                               Rescale(256),
                                               RandomCrop(224),
                                               ToTensor()
                                           ]))

for i, sample in enumerate(transformed_dataset):
    print(i, sample['image'].size(), sample['landmarks'].size())

    if i == 3:
        break
0 torch.Size([3, 224, 224]) torch.Size([68, 2])
1 torch.Size([3, 224, 224]) torch.Size([68, 2])
2 torch.Size([3, 224, 224]) torch.Size([68, 2])
3 torch.Size([3, 224, 224]) torch.Size([68, 2])

然而,我们通过使用简单的for循环来遍历数据,失去了很多功能。特别是,我们错过了以下内容:

  • 数据批处理

  • 打乱数据

  • 使用multiprocessing工作进程并行加载数据。

torch.utils.data.DataLoader 是一个迭代器,它提供了所有这些功能。下面使用的参数应该很清楚。一个值得关注的参数是 collate_fn。您可以使用 collate_fn 来指定样本需要如何精确地批处理。然而,默认的 collate 应该适用于大多数用例。

dataloader = DataLoader(transformed_dataset, batch_size=4,
                        shuffle=True, num_workers=0)


# Helper function to show a batch
def show_landmarks_batch(sample_batched):
    """Show image with landmarks for a batch of samples."""
    images_batch, landmarks_batch = \
            sample_batched['image'], sample_batched['landmarks']
    batch_size = len(images_batch)
    im_size = images_batch.size(2)
    grid_border_size = 2

    grid = utils.make_grid(images_batch)
    plt.imshow(grid.numpy().transpose((1, 2, 0)))

    for i in range(batch_size):
        plt.scatter(landmarks_batch[i, :, 0].numpy() + i * im_size + (i + 1) * grid_border_size,
                    landmarks_batch[i, :, 1].numpy() + grid_border_size,
                    s=10, marker='.', c='r')

        plt.title('Batch from dataloader')

# if you are using Windows, uncomment the next line and indent the for loop.
# you might need to go back and change ``num_workers`` to 0.

# if __name__ == '__main__':
for i_batch, sample_batched in enumerate(dataloader):
    print(i_batch, sample_batched['image'].size(),
          sample_batched['landmarks'].size())

    # observe 4th batch and stop.
    if i_batch == 3:
        plt.figure()
        show_landmarks_batch(sample_batched)
        plt.axis('off')
        plt.ioff()
        plt.show()
        break
Batch from dataloader
0 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])
1 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])
2 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])
3 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])

后记:torchvision

在本教程中,我们已经了解了如何编写和使用数据集、转换和数据加载器。torchvision 包提供了一些常见的数据集和转换。您甚至可能不需要编写自定义类。torchvision 中可用的一个更通用的数据集是 ImageFolder。它假设图像按以下方式组织:

root/ants/xxx.png
root/ants/xxy.jpeg
root/ants/xxz.png
.
.
.
root/bees/123.jpg
root/bees/nsdf3.png
root/bees/asd932_.png

其中‘ants’, ‘bees’等是类别标签。同样,通用的变换操作如PIL.Image上的RandomHorizontalFlip, Scale也是可用的。你可以使用这些来编写一个数据加载器,如下所示:

import torch
from torchvision import transforms, datasets

data_transform = transforms.Compose([
        transforms.RandomSizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
hymenoptera_dataset = datasets.ImageFolder(root='hymenoptera_data/train',
                                           transform=data_transform)
dataset_loader = torch.utils.data.DataLoader(hymenoptera_dataset,
                                             batch_size=4, shuffle=True,
                                             num_workers=4)

有关训练代码的示例,请参见 计算机视觉教程中的迁移学习

脚本总运行时间: ( 0 分钟 2.396 秒)

Gallery generated by Sphinx-Gallery

优云智算