AsNodePredDataset
- class dgl.data.AsNodePredDataset(dataset, split_ratio=None, target_ntype=None, **kwargs)[source]
基础类:
DGLDataset
将一个数据集重新用于标准的半监督传导节点预测任务。
该类将给定的数据集转换为一个新的数据集对象,使得:
仅包含一个图表,可从
dataset[0]
访问。图表存储:
节点标签在
g.ndata['label']
中。训练/验证/测试掩码分别在
g.ndata['train_mask']
,g.ndata['val_mask']
, 和g.ndata['test_mask']
中。
此外,数据集包含以下属性:
num_classes
,要预测的类别数量。train_idx
,val_idx
,test_idx
, 训练/验证/测试索引。
如果输入的数据集包含异构图,用户需要指定
target_ntype
参数来指示要为哪种节点类型进行预测。在这种情况下:节点标签存储在
g.nodes[target_ntype].data['label']
中。训练掩码存储在
g.nodes[target_ntype].data['train_mask']
中。 验证和测试掩码也是如此。
该类将仅保留提供的数据集中的第一个图,并根据给定的分割比例生成训练/验证/测试掩码。生成的掩码将被缓存到磁盘以便快速重新加载。如果提供的分割比例与缓存的不同,它将适当地重新处理数据集。
- Parameters:
dataset (DGLDataset) – 要转换的数据集。
split_ratio ((float, float, float), optional) – 训练集、验证集和测试集的分割比例。它们的总和必须为一。
target_ntype (str, optional) – 要添加分割掩码的节点类型。
- train_idx
一个一维整数张量,包含训练节点的ID。
- Type:
张量
- val_idx
一个一维整数张量,包含验证节点的ID。
- Type:
张量
- test_idx
一个一维整数张量的测试节点ID。
- Type:
张量
示例
>>> ds = dgl.data.AmazonCoBuyComputerDataset() >>> print(ds) Dataset("amazon_co_buy_computer", num_graphs=1, save_path=...) >>> new_ds = dgl.data.AsNodePredDataset(ds, [0.8, 0.1, 0.1]) >>> print(new_ds) Dataset("amazon_co_buy_computer-as-nodepred", num_graphs=1, save_path=...) >>> print('train_mask' in new_ds[0].ndata) True