4.1 DGLDataset 类
DGLDataset
是用于处理、加载和保存定义在 dgl.data 中的图数据集的基础类。它实现了处理图数据的基本流程。下面的流程图展示了该流程的工作原理。
要处理位于远程服务器或本地磁盘上的图数据集,可以定义一个类,例如 MyDataset
,继承自 dgl.data.DGLDataset
。MyDataset
的模板如下。

DGLDataset 类中定义的图数据输入管道的流程图。
from dgl.data import DGLDataset
class MyDataset(DGLDataset):
""" Template for customizing graph datasets in DGL.
Parameters
----------
url : str
URL to download the raw dataset
raw_dir : str
Specifying the directory that will store the
downloaded data or the directory that
already stores the input data.
Default: ~/.dgl/
save_dir : str
Directory to save the processed dataset.
Default: the value of `raw_dir`
force_reload : bool
Whether to reload the dataset. Default: False
verbose : bool
Whether to print out progress information
"""
def __init__(self,
url=None,
raw_dir=None,
save_dir=None,
force_reload=False,
verbose=False):
super(MyDataset, self).__init__(name='dataset_name',
url=url,
raw_dir=raw_dir,
save_dir=save_dir,
force_reload=force_reload,
verbose=verbose)
def download(self):
# download raw data to local disk
pass
def process(self):
# process raw data to graphs, labels, splitting masks
pass
def __getitem__(self, idx):
# get one example by index
pass
def __len__(self):
# number of data examples
pass
def save(self):
# save processed data to directory `self.save_path`
pass
def load(self):
# load processed data from directory `self.save_path`
pass
def has_cache(self):
# check whether there are processed data in `self.save_path`
pass
DGLDataset
类有抽象函数 process()
,
__getitem__(idx)
和 __len__()
,这些函数必须在子类中实现。DGL 还建议实现保存和加载功能,
因为它们可以节省处理大型数据集的大量时间,并且有几个 API 使其变得容易(参见 4.4 保存和加载数据)。
请注意,DGLDataset
的目的是提供一种标准和便捷的方式来加载图数据。可以存储图、特征、标签、掩码以及数据集的基本信息,例如类别数量、标签数量等。采样、分区或特征归一化等操作在 DGLDataset
子类之外完成。
本章的其余部分展示了在管道中实现函数的最佳实践。