创建图数据集

尽管 已经包含了许多有用的数据集,您可能希望使用自己记录或非公开可用的数据创建自己的数据集。

自己实现数据集是直接的,你可能想查看源代码以了解各种数据集是如何实现的。 然而,我们简要介绍了设置自己的数据集所需的内容。

我们为数据集提供了两个抽象类:torch_geometric.data.Datasettorch_geometric.data.InMemoryDatasettorch_geometric.data.InMemoryDataset 继承自 torch_geometric.data.Dataset,并且应该在整个数据集适合放入CPU内存时使用。

遵循torchvision的惯例,每个数据集都会传递一个根文件夹,该文件夹指示数据集应存储的位置。 我们将根文件夹分为两个文件夹:raw_dir,数据集下载到该文件夹,以及processed_dir,处理后的数据集保存在该文件夹中。

此外,每个数据集可以传递一个transform、一个pre_transform和一个pre_filter函数,这些函数默认是Nonetransform函数在访问之前动态转换数据对象(因此最好用于数据增强)。 pre_transform函数在将数据对象保存到磁盘之前应用转换(因此最好用于只需要进行一次的繁重预计算)。 pre_filter函数可以在保存之前手动过滤掉数据对象。 用例可能涉及限制数据对象属于特定类。

创建“内存数据集”

为了创建一个torch_geometric.data.InMemoryDataset,你需要实现四个基本方法:

你可以在 torch_geometric.data 中找到有用的方法来下载和提取数据。

真正的魔法发生在process()的主体中。 在这里,我们需要读取并创建一个Data对象列表,并将其保存到processed_dir中。 因为保存一个巨大的python列表非常慢,我们在保存之前通过torch_geometric.data.InMemoryDataset.collate()将列表整理成一个巨大的Data对象。 整理后的数据对象将所有示例连接成一个大数据对象,并返回一个slices字典,以便从该对象中重建单个示例。 最后,我们需要在构造函数中将这两个对象加载到属性self.dataself.slices中。

注意

PyG >= 2.4 开始,torch.save()torch_geometric.data.InMemoryDataset.collate() 的功能被统一并在 torch_geometric.data.InMemoryDataset.save() 后面实现。 此外,self.dataself.slices 通过 torch_geometric.data.InMemoryDataset.load() 隐式加载。

让我们通过一个简化的例子来看看这个过程:

import torch
from torch_geometric.data import InMemoryDataset, download_url


class MyOwnDataset(InMemoryDataset):
    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
        super().__init__(root, transform, pre_transform, pre_filter)
        self.load(self.processed_paths[0])
        # For PyG<2.4:
        # self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        return ['some_file_1', 'some_file_2', ...]

    @property
    def processed_file_names(self):
        return ['data.pt']

    def download(self):
        # Download to `self.raw_dir`.
        download_url(url, self.raw_dir)
        ...

    def process(self):
        # Read data into huge `Data` list.
        data_list = [...]

        if self.pre_filter is not None:
            data_list = [data for data in data_list if self.pre_filter(data)]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]

        self.save(data_list, self.processed_paths[0])
        # For PyG<2.4:
        # torch.save(self.collate(data_list), self.processed_paths[0])

创建“更大”的数据集

对于创建不适合内存的数据集,可以使用torch_geometric.data.Dataset,它紧密遵循torchvision数据集的概念。 它期望额外实现以下方法:

在内部,torch_geometric.data.Dataset.__getitem__()torch_geometric.data.Dataset.get() 获取数据对象,并根据 transform 进行可选的转换。

让我们通过一个简化的例子来看看这个过程:

import os.path as osp

import torch
from torch_geometric.data import Dataset, download_url


class MyOwnDataset(Dataset):
    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
        super().__init__(root, transform, pre_transform, pre_filter)

    @property
    def raw_file_names(self):
        return ['some_file_1', 'some_file_2', ...]

    @property
    def processed_file_names(self):
        return ['data_1.pt', 'data_2.pt', ...]

    def download(self):
        # Download to `self.raw_dir`.
        path = download_url(url, self.raw_dir)
        ...

    def process(self):
        idx = 0
        for raw_path in self.raw_paths:
            # Read data from `raw_path`.
            data = Data(...)

            if self.pre_filter is not None and not self.pre_filter(data):
                continue

            if self.pre_transform is not None:
                data = self.pre_transform(data)

            torch.save(data, osp.join(self.processed_dir, f'data_{idx}.pt'))
            idx += 1

    def len(self):
        return len(self.processed_file_names)

    def get(self, idx):
        data = torch.load(osp.join(self.processed_dir, f'data_{idx}.pt'))
        return data

在这里,每个图数据对象在process()中单独保存,并在get()中手动加载。

常见问题

  1. 如何跳过 download() 和/或 process() 的执行?

    你可以通过不重写download()process()方法来跳过下载和/或处理:

    class MyOwnDataset(Dataset):
        def __init__(self, transform=None, pre_transform=None):
            super().__init__(None, transform, pre_transform)
    
  2. 我真的需要使用这些数据集接口吗?

    不!就像在常规的中一样,你不需要使用数据集,例如,当你想在不需要显式保存到磁盘的情况下动态创建合成数据时。 在这种情况下,只需传递一个包含torch_geometric.data.Data对象的常规Python列表,并将它们传递给torch_geometric.loader.DataLoader

    from torch_geometric.data import Data
    from torch_geometric.loader import DataLoader
    
    data_list = [Data(...), ..., Data(...)]
    loader = DataLoader(data_list, batch_size=32)
    

练习

考虑以下由一系列InMemoryDataset构建的Data对象列表:

class MyDataset(InMemoryDataset):
    def __init__(self, root, data_list, transform=None):
        self.data_list = data_list
        super().__init__(root, transform)
        self.load(self.processed_paths[0])

    @property
    def processed_file_names(self):
        return 'data.pt'

    def process(self):
        self.save(self.data_list, self.processed_paths[0])
  1. self.processed_paths[0] 的输出是什么?

  2. save() 是做什么的?