From d3c755df3fe6009ed2cde1b5dca41196e4024aa7 Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Thu, 2 Mar 2017 13:41:51 +0800 Subject: [PATCH] Refine code --- demo/sentiment/train_v2.py | 95 ++++++++++++++++---------------------- 1 file changed, 39 insertions(+), 56 deletions(-) diff --git a/demo/sentiment/train_v2.py b/demo/sentiment/train_v2.py index 779bfee5b6..3d595fad30 100644 --- a/demo/sentiment/train_v2.py +++ b/demo/sentiment/train_v2.py @@ -1,3 +1,4 @@ +import sys from os.path import join as join_path import paddle.trainer_config_helpers.attrs as attrs from paddle.trainer_config_helpers.poolings import MaxPooling @@ -188,88 +189,69 @@ def stacked_lstm_net(input_dim, return cost -def data_reader(): - data_dir = "./data/pre-imdb" - train_file = "train_part_000" - test_file = "test_part_000" - dict_file = "dict.txt" - train_file = join_path(data_dir, train_file) - test_file = join_path(data_dir, test_file) - dict_file = join_path(data_dir, dict_file) - - with open(dict_file, 'r') as fdict, open(train_file, 'r') as fdata: - dictionary = dict() - for i, line in enumerate(fdict): - dictionary[line.split('\t')[0]] = i - - for line_count, line in enumerate(fdata): - label, comment = line.strip().split('\t\t') - label = int(label) - words = comment.split() - word_slot = [dictionary[w] for w in words if w in dictionary] - yield (word_slot, label) - - -def test_reader(): - data_dir = "./data/pre-imdb" - train_file = "train_part_000" - test_file = "test_part_000" - dict_file = "dict.txt" - train_file = join_path(data_dir, train_file) - test_file = join_path(data_dir, test_file) - dict_file = join_path(data_dir, dict_file) - - with open(dict_file, 'r') as fdict, open(test_file, 'r') as ftest: - dictionary = dict() - for i, line in enumerate(fdict): - dictionary[line.split('\t')[0]] = i - - for line_count, line in enumerate(ftest): - label, comment = line.strip().split('\t\t') - label = int(label) - words = comment.split() - word_slot = [dictionary[w] for w in words if w in dictionary] - yield (word_slot, label) +def data_reader(data_file, dict_file): + def reader(): + with open(dict_file, 'r') as fdict, open(data_file, 'r') as fdata: + dictionary = dict() + for i, line in enumerate(fdict): + dictionary[line.split('\t')[0]] = i + + for line_count, line in enumerate(fdata): + label, comment = line.strip().split('\t\t') + label = int(label) + words = comment.split() + word_slot = [dictionary[w] for w in words if w in dictionary] + yield (word_slot, label) + + return reader if __name__ == '__main__': - data_dir = "./data/pre-imdb" - train_list = "train.list" - test_list = "test.list" - dict_file = "dict.txt" - dict_dim = len(open(join_path(data_dir, "dict.txt")).readlines()) - class_dim = len(open(join_path(data_dir, 'labels.list')).readlines()) - is_predict = False + # data file + train_file = "./data/pre-imdb/train_part_000" + test_file = "./data/pre-imdb/test_part_000" + dict_file = "./data/pre-imdb/dict.txt" + labels = "./data/pre-imdb/labels.list" # init paddle.init(use_gpu=True, trainer_count=4) # network config - # cost = convolution_net(dict_dim, class_dim=class_dim, is_predict=is_predict) - cost = stacked_lstm_net( - dict_dim, class_dim=class_dim, stacked_num=3, is_predict=is_predict) + dict_dim = len(open(dict_file).readlines()) + class_dim = len(open(labels).readlines()) + + # Please choose the way to build the network + # by uncommenting the corresponding line. + cost = convolution_net(dict_dim, class_dim=class_dim) + # cost = stacked_lstm_net(dict_dim, class_dim=class_dim, stacked_num=3) # create parameters parameters = paddle.parameters.create(cost) + # create optimizer adam_optimizer = paddle.optimizer.Adam( learning_rate=2e-3, regularization=paddle.optimizer.L2Regularization(rate=8e-4), model_average=paddle.optimizer.ModelAverage(average_window=0.5)) + # End batch and end pass event handler def event_handler(event): if isinstance(event, paddle.event.EndIteration): if event.batch_id % 100 == 0: - print "Pass %d, Batch %d, Cost %f, %s" % ( + print "\nPass %d, Batch %d, Cost %f, %s" % ( event.pass_id, event.batch_id, event.cost, event.metrics) + else: + sys.stdout.write('.') + sys.stdout.flush() if isinstance(event, paddle.event.EndPass): result = trainer.test( reader=paddle.reader.batched( - test_reader, batch_size=128), + data_reader(test_file, dict_file), batch_size=128), reader_dict={'word': 0, 'label': 1}) - print "Test with Pass %d, %s" % (event.pass_id, result.metrics) + print "\nTest with Pass %d, %s" % (event.pass_id, result.metrics) + # create trainer trainer = paddle.trainer.SGD(cost=cost, parameters=parameters, update_equation=adam_optimizer) @@ -277,7 +259,8 @@ if __name__ == '__main__': trainer.train( reader=paddle.reader.batched( paddle.reader.shuffle( - data_reader, buf_size=4096), batch_size=128), + data_reader(train_file, dict_file), buf_size=4096), + batch_size=128), event_handler=event_handler, reader_dict={'word': 0, 'label': 1}, -- GitLab