From 4a265b5200bb86ef81f08d9fce516330b2c2f41a Mon Sep 17 00:00:00 2001
From: hedaoyuan
Date: Thu, 2 Mar 2017 21:42:11 +0800
Subject: [PATCH] Use reader in dataset imdb.py

---
 demo/sentiment/train_v2.py       | 38 ++++++++------------------------
 python/paddle/v2/dataset/imdb.py |  5 +++++
 2 files changed, 14 insertions(+), 29 deletions(-)

diff --git a/demo/sentiment/train_v2.py b/demo/sentiment/train_v2.py
index 3d595fad30..0fa7494853 100644
--- a/demo/sentiment/train_v2.py
+++ b/demo/sentiment/train_v2.py
@@ -2,10 +2,11 @@ import sys
 from os.path import join as join_path
 import paddle.trainer_config_helpers.attrs as attrs
 from paddle.trainer_config_helpers.poolings import MaxPooling
-import paddle.v2 as paddle
 import paddle.v2.layer as layer
 import paddle.v2.activation as activation
 import paddle.v2.data_type as data_type
+import paddle.v2.dataset.imdb as imdb
+import paddle.v2 as paddle
 
 
 def sequence_conv_pool(input,
@@ -189,36 +190,15 @@ def stacked_lstm_net(input_dim,
     return cost
 
 
-def data_reader(data_file, dict_file):
-    def reader():
-        with open(dict_file, 'r') as fdict, open(data_file, 'r') as fdata:
-            dictionary = dict()
-            for i, line in enumerate(fdict):
-                dictionary[line.split('\t')[0]] = i
-
-            for line_count, line in enumerate(fdata):
-                label, comment = line.strip().split('\t\t')
-                label = int(label)
-                words = comment.split()
-                word_slot = [dictionary[w] for w in words if w in dictionary]
-                yield (word_slot, label)
-
-    return reader
-
-
 if __name__ == '__main__':
-    # data file
-    train_file = "./data/pre-imdb/train_part_000"
-    test_file = "./data/pre-imdb/test_part_000"
-    dict_file = "./data/pre-imdb/dict.txt"
-    labels = "./data/pre-imdb/labels.list"
-
     # init
     paddle.init(use_gpu=True, trainer_count=4)
 
     # network config
-    dict_dim = len(open(dict_file).readlines())
-    class_dim = len(open(labels).readlines())
+    print 'load dictionary...'
+    word_dict = imdb.word_dict()
+    dict_dim = len(word_dict)
+    class_dim = 2
 
     # Please choose the way to build the network
     # by uncommenting the corresponding line.
@@ -246,7 +226,7 @@ if __name__ == '__main__':
         if isinstance(event, paddle.event.EndPass):
             result = trainer.test(
                 reader=paddle.reader.batched(
-                    data_reader(test_file, dict_file), batch_size=128),
+                    lambda: imdb.test(word_dict), batch_size=128),
                 reader_dict={'word': 0,
                              'label': 1})
             print "\nTest with Pass %d, %s" % (event.pass_id, result.metrics)
@@ -259,8 +239,8 @@ if __name__ == '__main__':
     trainer.train(
         reader=paddle.reader.batched(
             paddle.reader.shuffle(
-                data_reader(train_file, dict_file), buf_size=4096),
-            batch_size=128),
+                lambda: imdb.train(word_dict), buf_size=1000),
+            batch_size=100),
         event_handler=event_handler,
         reader_dict={'word': 0,
                      'label': 1},
diff --git a/python/paddle/v2/dataset/imdb.py b/python/paddle/v2/dataset/imdb.py
index 433e37380f..db388be1e0 100644
--- a/python/paddle/v2/dataset/imdb.py
+++ b/python/paddle/v2/dataset/imdb.py
@@ -118,3 +118,8 @@ def test(word_idx):
     return reader_creator(
         re.compile("aclImdb/test/pos/.*\.txt$"),
         re.compile("aclImdb/test/neg/.*\.txt$"), word_idx, 1000)
+
+
+def word_dict():
+    return build_dict(
+        re.compile("aclImdb/((train)|(test))/((pos)|(neg))/.*\.txt$"), 150)
--
GitLab
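
Note for reviewers: a minimal sketch of how the reader-based imdb API introduced in this patch is consumed. It simply mirrors the train_v2.py changes above; the buf_size/batch_size values are the ones used in the patch, and everything else is illustrative rather than part of the change itself.

    import paddle.v2 as paddle
    import paddle.v2.dataset.imdb as imdb

    # Build the word dictionary once from the aclImdb corpus and reuse it
    # for both the train and test readers, as train_v2.py now does.
    word_dict = imdb.word_dict()

    # imdb.train(word_dict) / imdb.test(word_dict) yield (word_ids, label)
    # samples; wrapping them in a lambda defers construction so each pass
    # over the data starts from a fresh reader.
    train_reader = paddle.reader.batched(
        paddle.reader.shuffle(lambda: imdb.train(word_dict), buf_size=1000),
        batch_size=100)
    test_reader = paddle.reader.batched(
        lambda: imdb.test(word_dict), batch_size=128)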