4.3 处理数据
可以在函数process()
中实现数据处理代码,并且假设原始数据已经位于self.raw_dir
中。在图上的机器学习中,通常有三种类型的任务:图分类、节点分类和链接预测。本节将展示如何处理与这些任务相关的数据集。
本节重点介绍处理图、特征和掩码的标准方法。 它将使用内置数据集作为示例,并跳过从文件构建图的实现, 但会添加指向详细实现的链接。请参阅1.4 从外部源创建图以查看 如何从外部源构建图的完整指南。
处理图分类数据集
图分类数据集与大多数典型的机器学习任务中的数据集几乎相同,其中使用了小批量训练。因此,可以将原始数据处理为一系列dgl.DGLGraph
对象和一系列标签张量。此外,如果原始数据已被分割成多个文件,可以添加一个参数split
来加载数据的特定部分。
以QM7bDataset
为例:
from dgl.data import DGLDataset
class QM7bDataset(DGLDataset):
_url = 'http://deepchem.io.s3-website-us-west-1.amazonaws.com/' \
'datasets/qm7b.mat'
_sha1_str = '4102c744bb9d6fd7b40ac67a300e49cd87e28392'
def __init__(self, raw_dir=None, force_reload=False, verbose=False):
super(QM7bDataset, self).__init__(name='qm7b',
url=self._url,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose)
def process(self):
mat_path = self.raw_path + '.mat'
# process data to a list of graphs and a list of labels
self.graphs, self.label = self._load_graph(mat_path)
def __getitem__(self, idx):
""" Get graph and label by index
Parameters
----------
idx : int
Item index
Returns
-------
(dgl.DGLGraph, Tensor)
"""
return self.graphs[idx], self.label[idx]
def __len__(self):
"""Number of graphs in the dataset"""
return len(self.graphs)
在process()
中,原始数据被处理成图列表和标签列表。必须实现__getitem__(idx)
和__len__()
以进行迭代。DGL建议使__getitem__(idx)
返回一个元组(graph, label)
,如上所述。请查看QM7bDataset源代码以了解self._load_graph()
和__getitem__
的详细信息。
还可以向类添加属性以指示数据集的一些有用信息。在QM7bDataset
中,可以添加一个属性num_tasks
来指示此多任务数据集中预测任务的总数:
@property
def num_tasks(self):
"""Number of labels for each graph, i.e. number of prediction tasks."""
return 14
在完成所有这些编码之后,最终可以如下使用QM7bDataset
:
import dgl
import torch
from dgl.dataloading import GraphDataLoader
# load data
dataset = QM7bDataset()
num_tasks = dataset.num_tasks
# create dataloaders
dataloader = GraphDataLoader(dataset, batch_size=1, shuffle=True)
# training
for epoch in range(100):
for g, labels in dataloader:
# your training code here
pass
训练图分类模型的完整指南可以在5.4 图分类中找到。
有关图分类数据集的更多示例,请参考DGL内置的图分类数据集:
gindataset
minigcdataset
qm7bdata
tudata
处理节点分类数据集
与图分类不同,节点分类通常是在单个图上进行的。因此,数据集的划分是在图的节点上进行的。DGL 建议使用节点掩码来指定划分。本节使用内置数据集 CitationGraphDataset 作为示例:
此外,DGL建议重新排列节点和边,使得彼此靠近的节点具有相近的ID范围。这一过程可以提高访问节点邻居的局部性,这可能有利于在图上的后续计算和分析。DGL为此提供了一个名为dgl.reorder_graph()
的API。请参考下面示例中的process()
部分以获取更多详细信息。
from dgl.data import DGLBuiltinDataset
from dgl.data.utils import _get_dgl_url
class CitationGraphDataset(DGLBuiltinDataset):
_urls = {
'cora_v2' : 'dataset/cora_v2.zip',
'citeseer' : 'dataset/citeseer.zip',
'pubmed' : 'dataset/pubmed.zip',
}
def __init__(self, name, raw_dir=None, force_reload=False, verbose=True):
assert name.lower() in ['cora', 'citeseer', 'pubmed']
if name.lower() == 'cora':
name = 'cora_v2'
url = _get_dgl_url(self._urls[name])
super(CitationGraphDataset, self).__init__(name,
url=url,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose)
def process(self):
# Skip some processing code
# === data processing skipped ===
# build graph
g = dgl.graph(graph)
# splitting masks
g.ndata['train_mask'] = train_mask
g.ndata['val_mask'] = val_mask
g.ndata['test_mask'] = test_mask
# node labels
g.ndata['label'] = torch.tensor(labels)
# node features
g.ndata['feat'] = torch.tensor(_preprocess_features(features),
dtype=F.data_type_dict['float32'])
self._num_tasks = onehot_labels.shape[1]
self._labels = labels
# reorder graph to obtain better locality.
self._g = dgl.reorder_graph(g)
def __getitem__(self, idx):
assert idx == 0, "This dataset has only one graph"
return self._g
def __len__(self):
return 1
为简洁起见,本节跳过process()
中的一些代码,以突出处理节点分类数据集的关键部分:分割掩码。节点特征和节点标签存储在g.ndata
中。有关详细实现,请参阅CitationGraphDataset源代码。
请注意,__getitem__(idx)
和
__len__()
的实现也有所改变,因为在节点分类任务中通常只有一个图。
掩码在 PyTorch 和 TensorFlow 中是 bool tensors
,而在 MXNet 中是 float tensors
。
本节使用CitationGraphDataset
的子类,dgl.data.CiteseerGraphDataset
,来展示其用法:
# load data
dataset = CiteseerGraphDataset(raw_dir='')
graph = dataset[0]
# get split masks
train_mask = graph.ndata['train_mask']
val_mask = graph.ndata['val_mask']
test_mask = graph.ndata['test_mask']
# get node features
feats = graph.ndata['feat']
# get labels
labels = graph.ndata['label']
训练节点分类模型的完整指南可以在 5.1 节点分类/回归中找到。
有关节点分类数据集的更多示例,请参考DGL的内置数据集:
引用数据
corafulldata
amazoncobuydata
coauthordata
karateclubdata
ppidata
redditdata
sbmdata
sstdata
rdfdata
处理用于链接预测的数据集
链接预测数据集的处理与节点分类的处理类似,数据集中通常有一个图。
本节使用内置数据集 KnowledgeGraphDataset 作为示例,并仍然跳过详细的数据处理代码,以突出处理链接预测数据集的关键部分:
# Example for creating Link Prediction datasets
class KnowledgeGraphDataset(DGLBuiltinDataset):
def __init__(self, name, reverse=True, raw_dir=None, force_reload=False, verbose=True):
self._name = name
self.reverse = reverse
url = _get_dgl_url('dataset/') + '{}.tgz'.format(name)
super(KnowledgeGraphDataset, self).__init__(name,
url=url,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose)
def process(self):
# Skip some processing code
# === data processing skipped ===
# splitting mask
g.edata['train_mask'] = train_mask
g.edata['val_mask'] = val_mask
g.edata['test_mask'] = test_mask
# edge type
g.edata['etype'] = etype
# node type
g.ndata['ntype'] = ntype
self._g = g
def __getitem__(self, idx):
assert idx == 0, "This dataset has only one graph"
return self._g
def __len__(self):
return 1
如代码所示,它将分割掩码添加到图的edata
字段中。查看KnowledgeGraphDataset源代码以查看完整代码。以下代码使用KnowledgeGraphDataset
的子类dgl.data.FB15k237Dataset
来展示其用法:
from dgl.data import FB15k237Dataset
# load data
dataset = FB15k237Dataset()
graph = dataset[0]
# get training mask
train_mask = graph.edata['train_mask']
train_idx = torch.nonzero(train_mask, as_tuple=False).squeeze()
src, dst = graph.edges(train_idx)
# get edge types in training set
rel = graph.edata['etype'][train_idx]
训练链接预测模型的完整指南可以在 5.3 链接预测中找到。
有关链接预测数据集的更多示例,请参考DGL的内置数据集:
kgdata
bitcoinotcdata