创建图数据集
尽管 PyG 已经包含了许多有用的数据集,您可能希望使用自己记录或非公开可用的数据创建自己的数据集。
自己实现数据集是直接的,你可能想查看源代码以了解各种数据集是如何实现的。 然而,我们简要介绍了设置自己的数据集所需的内容。
我们为数据集提供了两个抽象类:torch_geometric.data.Dataset 和 torch_geometric.data.InMemoryDataset。
torch_geometric.data.InMemoryDataset 继承自 torch_geometric.data.Dataset,并且应该在整个数据集适合放入CPU内存时使用。
遵循torchvision的惯例,每个数据集都会传递一个根文件夹,该文件夹指示数据集应存储的位置。
我们将根文件夹分为两个文件夹:raw_dir,数据集下载到该文件夹,以及processed_dir,处理后的数据集保存在该文件夹中。
此外,每个数据集可以传递一个transform、一个pre_transform和一个pre_filter函数,这些函数默认是None。
transform函数在访问之前动态转换数据对象(因此最好用于数据增强)。
pre_transform函数在将数据对象保存到磁盘之前应用转换(因此最好用于只需要进行一次的繁重预计算)。
pre_filter函数可以在保存之前手动过滤掉数据对象。
用例可能涉及限制数据对象属于特定类。
创建“内存数据集”
为了创建一个torch_geometric.data.InMemoryDataset,你需要实现四个基本方法:
torch_geometric.data.InMemoryDataset.raw_file_names(): 一个文件列表,位于raw_dir中,需要找到这些文件以便跳过下载。torch_geometric.data.InMemoryDataset.processed_file_names(): 一个文件列表,位于processed_dir中,需要找到这些文件以便跳过处理。torch_geometric.data.InMemoryDataset.download(): 将原始数据下载到raw_dir中。torch_geometric.data.InMemoryDataset.process(): 处理原始数据并将其保存到processed_dir中。
你可以在 torch_geometric.data 中找到有用的方法来下载和提取数据。
真正的魔法发生在process()的主体中。
在这里,我们需要读取并创建一个Data对象列表,并将其保存到processed_dir中。
因为保存一个巨大的python列表非常慢,我们在保存之前通过torch_geometric.data.InMemoryDataset.collate()将列表整理成一个巨大的Data对象。
整理后的数据对象将所有示例连接成一个大数据对象,并返回一个slices字典,以便从该对象中重建单个示例。
最后,我们需要在构造函数中将这两个对象加载到属性self.data和self.slices中。
注意
从 PyG >= 2.4 开始,torch.save() 和 torch_geometric.data.InMemoryDataset.collate() 的功能被统一并在 torch_geometric.data.InMemoryDataset.save() 后面实现。
此外,self.data 和 self.slices 通过 torch_geometric.data.InMemoryDataset.load() 隐式加载。
让我们通过一个简化的例子来看看这个过程:
import torch
from torch_geometric.data import InMemoryDataset, download_url
class MyOwnDataset(InMemoryDataset):
def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
super().__init__(root, transform, pre_transform, pre_filter)
self.load(self.processed_paths[0])
# For PyG<2.4:
# self.data, self.slices = torch.load(self.processed_paths[0])
@property
def raw_file_names(self):
return ['some_file_1', 'some_file_2', ...]
@property
def processed_file_names(self):
return ['data.pt']
def download(self):
# Download to `self.raw_dir`.
download_url(url, self.raw_dir)
...
def process(self):
# Read data into huge `Data` list.
data_list = [...]
if self.pre_filter is not None:
data_list = [data for data in data_list if self.pre_filter(data)]
if self.pre_transform is not None:
data_list = [self.pre_transform(data) for data in data_list]
self.save(data_list, self.processed_paths[0])
# For PyG<2.4:
# torch.save(self.collate(data_list), self.processed_paths[0])
创建“更大”的数据集
对于创建不适合内存的数据集,可以使用torch_geometric.data.Dataset,它紧密遵循torchvision数据集的概念。
它期望额外实现以下方法:
torch_geometric.data.Dataset.len(): 返回数据集中示例的数量。torch_geometric.data.Dataset.get(): 实现加载单个图形的逻辑。
在内部,torch_geometric.data.Dataset.__getitem__() 从 torch_geometric.data.Dataset.get() 获取数据对象,并根据 transform 进行可选的转换。
让我们通过一个简化的例子来看看这个过程:
import os.path as osp
import torch
from torch_geometric.data import Dataset, download_url
class MyOwnDataset(Dataset):
def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
super().__init__(root, transform, pre_transform, pre_filter)
@property
def raw_file_names(self):
return ['some_file_1', 'some_file_2', ...]
@property
def processed_file_names(self):
return ['data_1.pt', 'data_2.pt', ...]
def download(self):
# Download to `self.raw_dir`.
path = download_url(url, self.raw_dir)
...
def process(self):
idx = 0
for raw_path in self.raw_paths:
# Read data from `raw_path`.
data = Data(...)
if self.pre_filter is not None and not self.pre_filter(data):
continue
if self.pre_transform is not None:
data = self.pre_transform(data)
torch.save(data, osp.join(self.processed_dir, f'data_{idx}.pt'))
idx += 1
def len(self):
return len(self.processed_file_names)
def get(self, idx):
data = torch.load(osp.join(self.processed_dir, f'data_{idx}.pt'))
return data
常见问题
如何跳过
download()和/或process()的执行?你可以通过不重写
download()和process()方法来跳过下载和/或处理:class MyOwnDataset(Dataset): def __init__(self, transform=None, pre_transform=None): super().__init__(None, transform, pre_transform)
我真的需要使用这些数据集接口吗?
不!就像在常规的PyTorch中一样,你不需要使用数据集,例如,当你想在不需要显式保存到磁盘的情况下动态创建合成数据时。 在这种情况下,只需传递一个包含
torch_geometric.data.Data对象的常规Python列表,并将它们传递给torch_geometric.loader.DataLoader:from torch_geometric.data import Data from torch_geometric.loader import DataLoader data_list = [Data(...), ..., Data(...)] loader = DataLoader(data_list, batch_size=32)
练习
考虑以下由一系列InMemoryDataset构建的Data对象列表:
class MyDataset(InMemoryDataset):
def __init__(self, root, data_list, transform=None):
self.data_list = data_list
super().__init__(root, transform)
self.load(self.processed_paths[0])
@property
def processed_file_names(self):
return 'data.pt'
def process(self):
self.save(self.data_list, self.processed_paths[0])
self.processed_paths[0]的输出是什么?save()是做什么的?