add_nodepred_split
- class dgl.data.utils.add_nodepred_split(dataset, ratio, ntype=None)[source]
基础类:
将给定的数据集分割为训练集、验证集和测试集,用于传导节点预测任务。
它向数据集中的每个图添加了三个节点掩码数组
'train_mask'
,'val_mask'
和'test_mask'
。因此,数据集中的每个样本必须是一个DGLGraph
。修复NumPy的随机种子以使结果确定:
numpy.random.seed(42)
- Parameters:
dataset (DGLDataset) – 要修改的数据集。
ntype (str, optional) – 要为其添加掩码的节点类型。
示例
>>> dataset = dgl.data.AmazonCoBuyComputerDataset() >>> print('train_mask' in dataset[0].ndata) False >>> dgl.data.utils.add_nodepred_split(dataset, [0.8, 0.1, 0.1]) >>> print('train_mask' in dataset[0].ndata) True