AsNodePredDataset

class dgl.data.AsNodePredDataset(dataset, split_ratio=None, target_ntype=None, **kwargs)[source]

基础类:DGLDataset

将一个数据集重新用于标准的半监督传导节点预测任务。

该类将给定的数据集转换为一个新的数据集对象,使得:

  • 仅包含一个图表,可从 dataset[0] 访问。

  • 图表存储:

    • 节点标签在 g.ndata['label'] 中。

    • 训练/验证/测试掩码分别在 g.ndata['train_mask'], g.ndata['val_mask'], 和 g.ndata['test_mask'] 中。

  • 此外,数据集包含以下属性:

    • num_classes,要预测的类别数量。

    • train_idx, val_idx, test_idx, 训练/验证/测试索引。

如果输入的数据集包含异构图,用户需要指定target_ntype参数来指示要为哪种节点类型进行预测。在这种情况下:

  • 节点标签存储在 g.nodes[target_ntype].data['label'] 中。

  • 训练掩码存储在 g.nodes[target_ntype].data['train_mask'] 中。 验证和测试掩码也是如此。

该类将仅保留提供的数据集中的第一个图,并根据给定的分割比例生成训练/验证/测试掩码。生成的掩码将被缓存到磁盘以便快速重新加载。如果提供的分割比例与缓存的不同,它将适当地重新处理数据集。

Parameters:
  • dataset (DGLDataset) – 要转换的数据集。

  • split_ratio ((float, float, float), optional) – 训练集、验证集和测试集的分割比例。它们的总和必须为一。

  • target_ntype (str, optional) – 要添加分割掩码的节点类型。

num_classes

预测的类别数量。

Type:

int

train_idx

一个一维整数张量,包含训练节点的ID。

Type:

张量

val_idx

一个一维整数张量,包含验证节点的ID。

Type:

张量

test_idx

一个一维整数张量的测试节点ID。

Type:

张量

示例

>>> ds = dgl.data.AmazonCoBuyComputerDataset()
>>> print(ds)
Dataset("amazon_co_buy_computer", num_graphs=1, save_path=...)
>>> new_ds = dgl.data.AsNodePredDataset(ds, [0.8, 0.1, 0.1])
>>> print(new_ds)
Dataset("amazon_co_buy_computer-as-nodepred", num_graphs=1, save_path=...)
>>> print('train_mask' in new_ds[0].ndata)
True
__getitem__(idx)[source]

获取索引处的数据对象。

__len__()[source]

数据集中的示例数量。