SST数据集

class dgl.data.SSTDataset(mode='train', glove_embed_file=None, vocab_file=None, raw_dir=None, force_reload=False, verbose=False, transform=None)[source]

Bases: DGLBuiltinDataset

斯坦福情感树库数据集。

每个样本都是一个句子的选区树。叶节点代表单词。单词是一个存储在x特征字段中的整数值。非叶节点在x字段中有一个特殊值PAD_WORD。每个节点还有一个情感注释:5个类别(非常负面、负面、中性、正面和非常正面)。情感标签是一个存储在y特征字段中的整数值。官方网站:http://nlp.stanford.edu/sentiment/index.html

统计:

  • 训练示例:8,544

  • 开发示例:1,101

  • 测试示例:2,210

  • 每个节点的类别数:5

Parameters:
  • mode (str, optional) – 应该是 [‘train’, ‘dev’, ‘test’, ‘tiny’] 中的一个 默认值: train

  • glove_embed_file (str, optional) – 预训练的glove嵌入文件的路径。 默认值:无

  • vocab_file (str, optional) – 可选的词汇表文件。如果未提供,则使用默认的词汇表文件。 默认值:无

  • raw_dir (str) – Raw file directory to download/contains the input data directory. Default: ~/.dgl/

  • force_reload (bool) – Whether to reload the dataset. Default: False

  • verbose (bool) – Whether to print out progress information. Default: True.

  • transform (callable, optional) – A transform that takes in a DGLGraph object and returns a transformed version. The DGLGraph object will be transformed before every access.

vocab

数据集的词汇表

Type:

有序字典

num_classes

每个节点的类别数量

Type:

int

pretrained_emb

预训练的glove嵌入与词汇表相关。

Type:

张量

vocab_size

词汇表的大小

Type:

int

注释

所有样本将首先在内存中加载和预处理。

示例

>>> # get dataset
>>> train_data = SSTDataset()
>>> dev_data = SSTDataset(mode='dev')
>>> test_data = SSTDataset(mode='test')
>>> tiny_data = SSTDataset(mode='tiny')
>>>
>>> len(train_data)
8544
>>> train_data.num_classes
5
>>> glove_embed = train_data.pretrained_emb
>>> train_data.vocab_size
19536
>>> train_data[0]
Graph(num_nodes=71, num_edges=70,
  ndata_schemes={'x': Scheme(shape=(), dtype=torch.int64), 'y': Scheme(shape=(), dtype=torch.int64), 'mask': Scheme(shape=(), dtype=torch.int64)}
  edata_schemes={})
>>> for tree in train_data:
...     input_ids = tree.ndata['x']
...     labels = tree.ndata['y']
...     mask = tree.ndata['mask']
...     # your code here
__getitem__(idx)[source]

通过索引获取图表

Parameters:

idx (int) –

Returns:

图结构,每个节点的单词ID,节点标签和掩码。

  • ndata['x']: 节点的单词ID

  • ndata['y']: 节点的标签

  • ndata['mask']: 如果节点是叶子节点则为1,否则为0

Return type:

dgl.DGLGraph

__len__()[source]

数据集中的图表数量。