MinesweeperDataset
- class dgl.data.MinesweeperDataset(raw_dir=None, force_reload=False, verbose=True, transform=None)
Bases: HeterophilousGraphDataset
Minesweeper dataset from the paper "A Critical Look at the Evaluation of GNNs under Heterophily: Are We Really Making Progress?" <https://arxiv.org/abs/2302.11640>.
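The dataset is inspired by the Minesweeper game. The graph is a regular 100x100 grid in which each node (cell) is connected to its eight neighboring nodes, except for nodes on the edge of the grid, which have fewer neighbors. 20% of the nodes are randomly selected as mines, and the task is to predict which nodes are mines. The node features are one-hot encodings of the number of neighboring mines. However, for a randomly selected 50% of the nodes the features are unknown, which is indicated by a separate binary feature.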
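The construction described above can be sketched in plain Python. This is a hypothetical illustration, not the authors' generation script: the function names, the seed, and the use of `random.Random.sample` are assumptions, and it stops at the raw neighbor-mine counts rather than reproducing the exact 7-dimensional feature encoding of the released data. It does show where the edge count in the statistics comes from: every undirected grid link is stored in both directions.

```python
import random


def build_king_grid(n):
    """Directed edge list of an n x n grid where every cell is linked to its
    up-to-eight neighbours; each undirected link is stored in both directions."""
    edges = []
    for r in range(n):
        for c in range(n):
            u = r * n + c  # row-major node id
            for dr in (-1, 0, 1):
                for dc in (-1, 0, 1):
                    if dr == 0 and dc == 0:
                        continue  # no self-loops
                    rr, cc = r + dr, c + dc
                    if 0 <= rr < n and 0 <= cc < n:  # edge cells have fewer neighbours
                        edges.append((u, rr * n + cc))
    return edges


def make_minesweeper_like(n=100, mine_frac=0.2, unknown_frac=0.5, seed=0):
    """Hypothetical generator: random mines (labels), per-node counts of
    neighbouring mines, and a random 'features unknown' flag per node."""
    rng = random.Random(seed)
    num_nodes = n * n
    edges = build_king_grid(n)
    mines = set(rng.sample(range(num_nodes), int(mine_frac * num_nodes)))
    is_mine = [i in mines for i in range(num_nodes)]
    counts = [0] * num_nodes
    for u, v in edges:
        if is_mine[v]:
            counts[u] += 1  # the value that would be one-hot encoded
    unknown = set(rng.sample(range(num_nodes), int(unknown_frac * num_nodes)))
    is_unknown = [i in unknown for i in range(num_nodes)]
    return edges, is_mine, counts, is_unknown


edges, is_mine, counts, is_unknown = make_minesweeper_like()
print(len(edges))  # 78804
```

For a 100x100 grid this gives 9604 interior nodes of degree 8, 392 border nodes of degree 5 and 4 corners of degree 3, i.e. 78804 directed edges, matching the statistics.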
Statistics:
Nodes: 10000
Edges: 78804
Classes: 2
Node features: 7
10 train/val/test splits
- Parameters:
raw_dir (str, optional) – Raw file directory to store the processed data. Default: ~/.dgl/
force_reload (bool, optional) – Whether to re-download the data source. Default: False
verbose (bool, optional) – Whether to print progress information. Default: True
transform (callable, optional) – A transform that takes in a DGLGraph object and returns a transformed version. The DGLGraph object will be transformed before every access. Default: None
Examples
>>> from dgl.data import MinesweeperDataset
>>> dataset = MinesweeperDataset()
>>> g = dataset[0]
>>> num_classes = dataset.num_classes

>>> # get node features
>>> feat = g.ndata["feat"]

>>> # get the first data split
>>> train_mask = g.ndata["train_mask"][:, 0]
>>> val_mask = g.ndata["val_mask"][:, 0]
>>> test_mask = g.ndata["test_mask"][:, 0]

>>> # get labels
>>> label = g.ndata["label"]
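Because the example selects a split with `[:, 0]`, each mask holds one column per split, so results are typically reported as an average over all 10 columns. The sketch below illustrates that loop with plain Python lists standing in for the mask tensors; the helper names, the toy data, and the choice of accuracy as the metric are assumptions for illustration only, and in practice a separate model would be trained on each split before its predictions are scored.

```python
def column(mask_rows, k):
    """Node ids whose mask is True in split k (mirrors mask[:, k] on a tensor)."""
    return [i for i, row in enumerate(mask_rows) if row[k]]


def mean_accuracy_over_splits(pred_per_split, label, test_mask_rows):
    """Average test accuracy over every split column of a 2-D mask.
    pred_per_split[k] holds the predictions of the model trained on split k."""
    num_splits = len(test_mask_rows[0])
    scores = []
    for k in range(num_splits):
        idx = column(test_mask_rows, k)
        correct = sum(1 for i in idx if pred_per_split[k][i] == label[i])
        scores.append(correct / len(idx))
    return sum(scores) / num_splits


# Toy data: 4 nodes, 2 splits (the real masks have 10000 rows and 10 columns).
label = [0, 1, 1, 0]
test_mask_rows = [[True, False], [True, True], [False, True], [False, False]]
pred_per_split = [[0, 1, 0, 0], [1, 1, 0, 1]]
print(mean_accuracy_over_splits(pred_per_split, label, test_mask_rows))  # 0.75
```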
- __getitem__(idx)
Gets the data object at index.
- __len__()
Number of examples in the dataset.