diff --git a/demo/sentiment/train_v2.py b/demo/sentiment/train_v2.py index 3d595fad30b71ac259cbb14639e0c6792dc8481d..0fa74948533b4362a7a9206e7a787cf217ca5ca2 100644 --- a/demo/sentiment/train_v2.py +++ b/demo/sentiment/train_v2.py @@ -2,10 +2,11 @@ import sys from os.path import join as join_path import paddle.trainer_config_helpers.attrs as attrs from paddle.trainer_config_helpers.poolings import MaxPooling -import paddle.v2 as paddle import paddle.v2.layer as layer import paddle.v2.activation as activation import paddle.v2.data_type as data_type +import paddle.v2.dataset.imdb as imdb +import paddle.v2 as paddle def sequence_conv_pool(input, @@ -189,36 +190,15 @@ def stacked_lstm_net(input_dim, return cost -def data_reader(data_file, dict_file): - def reader(): - with open(dict_file, 'r') as fdict, open(data_file, 'r') as fdata: - dictionary = dict() - for i, line in enumerate(fdict): - dictionary[line.split('\t')[0]] = i - - for line_count, line in enumerate(fdata): - label, comment = line.strip().split('\t\t') - label = int(label) - words = comment.split() - word_slot = [dictionary[w] for w in words if w in dictionary] - yield (word_slot, label) - - return reader - - if __name__ == '__main__': - # data file - train_file = "./data/pre-imdb/train_part_000" - test_file = "./data/pre-imdb/test_part_000" - dict_file = "./data/pre-imdb/dict.txt" - labels = "./data/pre-imdb/labels.list" - # init paddle.init(use_gpu=True, trainer_count=4) # network config - dict_dim = len(open(dict_file).readlines()) - class_dim = len(open(labels).readlines()) + print 'load dictionary...' + word_dict = imdb.word_dict() + dict_dim = len(word_dict) + class_dim = 2 # Please choose the way to build the network # by uncommenting the corresponding line. @@ -246,7 +226,7 @@ if __name__ == '__main__': if isinstance(event, paddle.event.EndPass): result = trainer.test( reader=paddle.reader.batched( - data_reader(test_file, dict_file), batch_size=128), + lambda: imdb.test(word_dict), batch_size=128), reader_dict={'word': 0, 'label': 1}) print "\nTest with Pass %d, %s" % (event.pass_id, result.metrics) @@ -259,8 +239,8 @@ if __name__ == '__main__': trainer.train( reader=paddle.reader.batched( paddle.reader.shuffle( - data_reader(train_file, dict_file), buf_size=4096), - batch_size=128), + lambda: imdb.train(word_dict), buf_size=1000), + batch_size=100), event_handler=event_handler, reader_dict={'word': 0, 'label': 1}, diff --git a/python/paddle/v2/dataset/imdb.py b/python/paddle/v2/dataset/imdb.py index 433e37380f840f5b7ff619a5f64b99d2ad724b17..db388be1e06d636ef33ec1c2ecf2408e5e1d4d59 100644 --- a/python/paddle/v2/dataset/imdb.py +++ b/python/paddle/v2/dataset/imdb.py @@ -118,3 +118,8 @@ def test(word_idx): return reader_creator( re.compile("aclImdb/test/pos/.*\.txt$"), re.compile("aclImdb/test/neg/.*\.txt$"), word_idx, 1000) + + +def word_dict(): + return build_dict( + re.compile("aclImdb/((train)|(test))/((pos)|(neg))/.*\.txt$"), 150)