CLUSTERDataset
- class dgl.data.CLUSTERDataset(mode='train', raw_dir=None, force_reload=False, verbose=False, transform=None)
Bases: DGLBuiltinDataset
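CLUSTER dataset for the semi-supervised clustering task.
Each graph contains 6 SBM (stochastic block model) clusters whose sizes are chosen at random from [5, 35], with probabilities p = 0.55 and q = 0.25. The graphs have between 40 and 190 nodes. Each node takes an input feature value in {0, 1, 2, …, 6}: values 1~6 correspond to classes 0~5 respectively, while value 0 means the class of the node is unknown. Each community has exactly one randomly assigned labeled node, so most node features are set to 0.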
Reference https://arxiv.org/pdf/2003.00982.pdf
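The construction described above can be sketched in plain Python. This is an illustrative sketch only, not the generator used to build the dataset: the function name `sample_cluster_graph` is hypothetical, and it assumes the usual SBM convention that p is the edge probability within a cluster and q the edge probability between clusters.

```python
import random

NUM_CLUSTERS = 6
P_INTRA = 0.55  # assumption: p is the edge probability inside a cluster
Q_INTER = 0.25  # assumption: q is the edge probability between clusters


def sample_cluster_graph(rng):
    """Sample one CLUSTER-style graph (hypothetical helper, for illustration)."""
    # Each of the 6 clusters gets a size drawn uniformly from [5, 35].
    sizes = [rng.randint(5, 35) for _ in range(NUM_CLUSTERS)]
    # Ground-truth class of every node: 0..5.
    labels = [c for c, size in enumerate(sizes) for _ in range(size)]
    n = len(labels)

    # Connect each unordered node pair with probability p or q.
    edges = []
    for u in range(n):
        for v in range(u + 1, n):
            prob = P_INTRA if labels[u] == labels[v] else Q_INTER
            if rng.random() < prob:
                edges.append((u, v))

    # Input features: 0 (unknown) everywhere, except one randomly chosen
    # node per cluster, which carries the value class + 1.
    feats = [0] * n
    for c in range(NUM_CLUSTERS):
        members = [i for i, lab in enumerate(labels) if lab == c]
        feats[rng.choice(members)] = c + 1
    return edges, feats, labels


edges, feats, labels = sample_cluster_graph(random.Random(0))
```

Only six entries of `feats` are non-zero, which is what makes the task semi-supervised: a model has to infer the class of every other node from the graph structure.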
Statistics:
Train examples: 10,000
Valid examples: 1,000
Test examples: 1,000
Number of classes for each node: 6
- Parameters:
mode (str) – Must be one of (‘train’, ‘valid’, ‘test’). Default: ‘train’
raw_dir (str) – Directory that the raw input data is downloaded to, or that already contains it. Default: ~/.dgl/
force_reload (bool) – Whether to reload the dataset. Default: False
verbose (bool) – Whether to print out progress information. Default: False
transform (callable, optional) – A transform that takes in a DGLGraph object and returns a transformed version. The DGLGraph object will be transformed before every access.
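As a hedged illustration of the transform parameter, the sketch below passes a function that normalizes self-loops. The names `with_self_loops` and `load_trainset` are hypothetical. `dgl` is imported lazily so that the definitions can be read and loaded without the library installed. Calling `load_trainset()` assumes DGL is available and may download the dataset.

```python
def with_self_loops(g):
    """Return a copy of g with exactly one self-loop per node."""
    import dgl  # imported lazily; requires DGL to be installed

    # Remove any existing self-loops first so none are duplicated.
    g = dgl.remove_self_loop(g)
    return dgl.add_self_loop(g)


def load_trainset():
    """Build the training split with the transform applied on every access."""
    from dgl.data import CLUSTERDataset

    return CLUSTERDataset(mode='train', transform=with_self_loops)
```

Because the transform runs on every access, `load_trainset()[0]` would return the transformed graph each time it is indexed rather than a cached result.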
Examples
>>> from dgl.data import CLUSTERDataset
>>>
>>> trainset = CLUSTERDataset(mode='train')
>>>
>>> trainset.num_classes
6
>>> len(trainset)
10000
>>> trainset[0]
Graph(num_nodes=117, num_edges=4104,
      ndata_schemes={'label': Scheme(shape=(), dtype=torch.int16), 'feat': Scheme(shape=(), dtype=torch.int64)}
      edata_schemes={'feat': Scheme(shape=(1,), dtype=torch.float32)})
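The 'feat' node data shown above holds the encoded input values, not the classes themselves. A minimal sketch for decoding them follows. The helper `decode_seed_labels` is hypothetical and works on a plain list of integers, which with DGL and PyTorch installed could be obtained via `trainset[0].ndata['feat'].tolist()`.

```python
def decode_seed_labels(feat_values):
    """Map node index -> class for nodes whose input feature reveals a class."""
    seeds = {}
    for node, value in enumerate(feat_values):
        if value == 0:
            continue  # 0 means the class of this node is unknown
        if not 1 <= value <= 6:
            raise ValueError(f"unexpected feature value {value} at node {node}")
        seeds[node] = value - 1  # values 1~6 correspond to classes 0~5
    return seeds


# Toy feature vector standing in for g.ndata['feat'].tolist()
example_feat = [0, 3, 0, 0, 1, 0, 6]
print(decode_seed_labels(example_feat))  # {1: 2, 4: 0, 6: 5}
```

For a graph from the dataset, the returned dictionary is expected to have six entries, one labeled node per community.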