TUDataset
- class dgl.data.TUDataset(name, raw_dir=None, force_reload=False, verbose=False, transform=None)[source]
Bases:
DGLBuiltinDataset
TUDataset 包含许多用于图分类的图核数据集。
- Parameters:
name (str) – Dataset Name, such as
ENZYMES
,DD
,COLLAB
,MUTAG
, can be the datasets name on https://chrsmrrs.github.io/datasets/docs/datasets/.transform (callable, optional) – A transform that takes in a
DGLGraph
object and returns a transformed version. TheDGLGraph
object will be transformed before every access.
注释
重要提示: 一些数据集中存在图中的重复边,例如
IMDB-BINARY
中的边都是重复的。DGL 忠实地保留了原始数据中的重复边。其他框架如 PyTorch Geometric 默认会移除重复边。你可以使用dgl.to_simple()
来移除重复边。图可能具有节点标签、节点属性、边标签和边属性,这些属性因不同的数据集而异。
标签被映射到 \(\lbrace 0,\cdots,n-1 \rbrace\) 其中 \(n\) 是 标签的数量(一些数据集的原始标签为 \(\lbrace -1, 1 \rbrace\) 这些 将被映射到 \(\lbrace 0, 1 \rbrace\))。在之前的版本中,最小 标签被添加,因此 \(\lbrace -1, 1 \rbrace\) 被映射到 \(\lbrace 0, 2 \rbrace\)。
数据集根据标签对图表进行排序。 在手动进行训练/验证分割之前,建议先进行洗牌。
示例
>>> data = TUDataset('DD')
数据集实例是可迭代的
>>> len(data) 1178 >>> g, label = data[1024] >>> g Graph(num_nodes=88, num_edges=410, ndata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64), 'node_labels': Scheme(shape=(1,), dtype=torch.int64)} edata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)}) >>> label tensor([1])
将图和标签分批用于小批量训练
>>> graphs, labels = zip(*[data[i] for i in range(16)]) >>> batched_graphs = dgl.batch(graphs) >>> batched_labels = torch.tensor(labels) >>> batched_graphs Graph(num_nodes=9539, num_edges=47382, ndata_schemes={'node_labels': Scheme(shape=(1,), dtype=torch.int64), '_ID': Scheme(shape=(), dtype=torch.int64)} edata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)})
- __getitem__(idx)[source]
获取第 idx 个样本。
- Parameters:
idx (int) – The sample index.
- Returns:
图中节点的特征存储在
feat
字段中,节点的标签存储在node_labels
字段中(如果存在)。 以及它的标签。- Return type:
(
dgl.DGLGraph
, Tensor)