CLUSTERDataset

class dgl.data.CLUSTERDataset(mode='train', raw_dir=None, force_reload=False, verbose=False, transform=None)[source]

Bases: DGLBuiltinDataset

CLUSTER dataset for the semi-supervised clustering task.

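Each graph contains 6 SBM (stochastic block model) clusters whose sizes are drawn at random from [5, 35], with intra-cluster edge probability p = 0.55 and inter-cluster edge probability q = 0.25. Graphs range from 40 to 190 nodes. Each node carries an input feature value in {0, 1, 2, …, 6}: values 1~6 correspond to classes 0~5 respectively, while 0 means the node's class is unknown. Only one randomly chosen node per community is labeled this way, so most node features are 0.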
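The construction described above can be sketched in plain Python. This is an illustrative approximation, not DGL's or the benchmark's actual generator: the function name `sample_cluster_graph` and the constants are hypothetical, and the sketch does not enforce the 40 to 190 node range.

```python
import random

NUM_CLUSTERS = 6               # number of SBM communities per graph
P_INTRA, Q_INTER = 0.55, 0.25  # edge probability inside / across communities


def sample_cluster_graph(rng):
    """Sample one CLUSTER-style graph (illustrative sketch only)."""
    # Each community size is drawn uniformly from [5, 35].
    sizes = [rng.randint(5, 35) for _ in range(NUM_CLUSTERS)]
    # community[v] is the ground-truth class (0..5) of node v.
    community = [c for c, size in enumerate(sizes) for _ in range(size)]
    n = len(community)

    # Connect each node pair with probability p (same community) or q (different).
    edges = []
    for u in range(n):
        for v in range(u + 1, n):
            prob = P_INTRA if community[u] == community[v] else Q_INTER
            if rng.random() < prob:
                edges.append((u, v))

    # Input feature: 0 means "class unknown". Exactly one randomly chosen
    # node per community reveals its class, encoded as class + 1.
    feat = [0] * n
    for c in range(NUM_CLUSTERS):
        members = [v for v in range(n) if community[v] == c]
        feat[rng.choice(members)] = c + 1
    return edges, community, feat


rng = random.Random(0)
edges, community, feat = sample_cluster_graph(rng)
revealed = [v for v, f in enumerate(feat) if f != 0]
print(len(community), "nodes,", len(edges), "edges,", len(revealed), "revealed nodes")
```

The task is then to predict `community[v]` for every node given only the graph structure and the six revealed hints in `feat`.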

Reference https://arxiv.org/pdf/2003.00982.pdf

Statistics:

  • Train examples: 10,000

  • Valid examples: 1,000

  • Test examples: 1,000

  • Number of classes for each node: 6

Parameters:
  • mode (str) – Must be one of (‘train’, ‘valid’, ‘test’). Default: ‘train’

  • raw_dir (str) – Directory where the raw input data is downloaded to and stored. Default: ~/.dgl/

  • force_reload (bool) – Whether to reload the dataset. Default: False

  • verbose (bool) – Whether to print out progress information. Default: False

  • transform (callable, optional) – A transform that takes in a DGLGraph object and returns a transformed version. The DGLGraph object will be transformed before every access.
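As a rough illustration of how these parameters combine, the sketch below builds all three splits with a shared transform. The helper names `cast_labels_to_long` and `load_cluster_splits` are hypothetical and not part of DGL. The transform assumes the PyTorch backend and casts the labels to int64, which PyTorch classification losses generally expect. DGL is imported lazily so the helpers can be defined without triggering a download.

```python
def cast_labels_to_long(g):
    """Hypothetical transform: cast node labels to int64 (PyTorch backend assumed)."""
    g.ndata["label"] = g.ndata["label"].long()
    return g


def load_cluster_splits(raw_dir=None, transform=None):
    """Build the train/valid/test splits with the same settings (sketch)."""
    # Imported here so that defining this helper does not require DGL;
    # constructing a split downloads the data on first use.
    from dgl.data import CLUSTERDataset

    return {
        mode: CLUSTERDataset(mode=mode, raw_dir=raw_dir, transform=transform)
        for mode in ("train", "valid", "test")
    }


# Example call (downloads the dataset if it is not cached yet):
# splits = load_cluster_splits(transform=cast_labels_to_long)
# g = splits["valid"][0]   # the transform runs on this access
```

Because the transform is applied on every access, it is best kept cheap and idempotent, as the cast above is.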

num_classes

Number of classes for each node.

Type:

int

Examples

>>> from dgl.data import CLUSTERDataset
>>>
>>> trainset = CLUSTERDataset(mode='train')
>>>
>>> trainset.num_classes
6
>>> len(trainset)
10000
>>> trainset[0]
Graph(num_nodes=117, num_edges=4104,
      ndata_schemes={'label': Scheme(shape=(), dtype=torch.int16),
                     'feat': Scheme(shape=(), dtype=torch.int64)}
      edata_schemes={'feat': Scheme(shape=(1,), dtype=torch.float32)})
__getitem__(idx)[source]

Get the idx-th sample.

Parameters:

idx (int) – The sample index.

Returns:

The graph, with its node features, node labels and edge features stored in the following fields:

  • ndata['feat']: node features

  • ndata['label']: node labels

  • edata['feat']: edge features

Return type:

dgl.DGLGraph
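A minimal sketch of reading these fields from a returned sample follows. The helper `summarize_sample` is hypothetical. It relies only on `ndata`, `num_nodes()` and `num_edges()` of the returned graph and assumes the feature encoding described at the top of this page (value k in 1~6 hints at class k - 1).

```python
def summarize_sample(g):
    """Summarize one CLUSTER sample (hypothetical helper, not part of DGL)."""
    feat = g.ndata["feat"].tolist()    # per-node input value in {0, ..., 6}
    label = g.ndata["label"].tolist()  # per-node class in {0, ..., 5}
    # Nodes with a non-zero feature are the ones whose class is revealed.
    revealed = [v for v, f in enumerate(feat) if f != 0]
    return {
        "num_nodes": g.num_nodes(),
        "num_edges": g.num_edges(),
        "revealed_nodes": revealed,
        # True when every revealed hint agrees with the stored label.
        "hints_match_labels": all(feat[v] - 1 == label[v] for v in revealed),
    }


# Example call (requires DGL and downloads the dataset on first use):
# from dgl.data import CLUSTERDataset
# print(summarize_sample(CLUSTERDataset(mode="valid")[0]))
```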

__len__()[source]

The number of examples in the dataset.