PPIDataset

class dgl.data.PPIDataset(mode='train', raw_dir=None, force_reload=False, verbose=False, transform=None)[source]

Bases: DGLBuiltinDataset

用于归纳节点分类的蛋白质-蛋白质相互作用数据集

一个玩具蛋白质-蛋白质相互作用网络数据集。该数据集包含24个图。每个图的平均节点数为2372。每个节点有50个特征和121个标签。20个图用于训练,2个用于验证,2个用于测试。

参考:http://snap.stanford.edu/graphsage/

统计:

  • 训练示例:20

  • 有效示例:2

  • 测试示例:2

Parameters:
  • mode (str) – Must be one of (‘train’, ‘valid’, ‘test’). Default: ‘train’

  • raw_dir (str) – Raw file directory to download/contains the input data directory. Default: ~/.dgl/

  • force_reload (bool) – Whether to reload the dataset. Default: False

  • verbose (bool) – 是否打印进度信息。 默认值:True。

  • transform (callable, optional) – A transform that takes in a DGLGraph object and returns a transformed version. The DGLGraph object will be transformed before every access.

num_labels

每个节点的标签数量

Type:

int

labels

节点标签

Type:

张量

features

节点特性

Type:

张量

示例

>>> dataset = PPIDataset(mode='valid')
>>> num_classes = dataset.num_classes
>>> for g in dataset:
....    feat = g.ndata['feat']
....    label = g.ndata['label']
....    # your code here
>>>
__getitem__(item)[source]

获取第 item^th 个样本。

Parameters:

item (int) – 样本索引。

Returns:

图结构、节点特征和节点标签。

  • ndata['feat']: 节点特征

  • ndata['label']: 节点标签

Return type:

dgl.DGLGraph

__len__()[source]

返回此数据集中的样本数量。