SST数据集
- class dgl.data.SSTDataset(mode='train', glove_embed_file=None, vocab_file=None, raw_dir=None, force_reload=False, verbose=False, transform=None)[source]
Bases:
DGLBuiltinDataset
斯坦福情感树库数据集。
每个样本都是一个句子的选区树。叶节点代表单词。单词是一个存储在
x
特征字段中的整数值。非叶节点在x
字段中有一个特殊值PAD_WORD
。每个节点还有一个情感注释:5个类别(非常负面、负面、中性、正面和非常正面)。情感标签是一个存储在y
特征字段中的整数值。官方网站:http://nlp.stanford.edu/sentiment/index.html统计:
训练示例:8,544
开发示例:1,101
测试示例:2,210
每个节点的类别数:5
- Parameters:
mode (str, optional) – 应该是 [‘train’, ‘dev’, ‘test’, ‘tiny’] 中的一个 默认值: train
glove_embed_file (str, optional) – 预训练的glove嵌入文件的路径。 默认值:无
vocab_file (str, optional) – 可选的词汇表文件。如果未提供,则使用默认的词汇表文件。 默认值:无
raw_dir (str) – Raw file directory to download/contains the input data directory. Default: ~/.dgl/
force_reload (bool) – Whether to reload the dataset. Default: False
verbose (bool) – Whether to print out progress information. Default: True.
transform (callable, optional) – A transform that takes in a
DGLGraph
object and returns a transformed version. TheDGLGraph
object will be transformed before every access.
- vocab
数据集的词汇表
- Type:
有序字典
- pretrained_emb
预训练的glove嵌入与词汇表相关。
- Type:
张量
注释
所有样本将首先在内存中加载和预处理。
示例
>>> # get dataset >>> train_data = SSTDataset() >>> dev_data = SSTDataset(mode='dev') >>> test_data = SSTDataset(mode='test') >>> tiny_data = SSTDataset(mode='tiny') >>> >>> len(train_data) 8544 >>> train_data.num_classes 5 >>> glove_embed = train_data.pretrained_emb >>> train_data.vocab_size 19536 >>> train_data[0] Graph(num_nodes=71, num_edges=70, ndata_schemes={'x': Scheme(shape=(), dtype=torch.int64), 'y': Scheme(shape=(), dtype=torch.int64), 'mask': Scheme(shape=(), dtype=torch.int64)} edata_schemes={}) >>> for tree in train_data: ... input_ids = tree.ndata['x'] ... labels = tree.ndata['y'] ... mask = tree.ndata['mask'] ... # your code here