From 803da664eddfc85bb55e192b7a98c696bf4fe112 Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Wed, 1 Mar 2017 19:49:17 +0800 Subject: [PATCH] Add test --- demo/sentiment/train_v2.py | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/demo/sentiment/train_v2.py b/demo/sentiment/train_v2.py index bec07de92a1..a764798addf 100644 --- a/demo/sentiment/train_v2.py +++ b/demo/sentiment/train_v2.py @@ -142,6 +142,28 @@ def data_reader(): yield (word_slot, label) +def test_reader(): + data_dir = "./data/pre-imdb" + train_file = "train_part_000" + test_file = "test_part_000" + dict_file = "dict.txt" + train_file = join_path(data_dir, train_file) + test_file = join_path(data_dir, test_file) + dict_file = join_path(data_dir, dict_file) + + with open(dict_file, 'r') as fdict, open(test_file, 'r') as ftest: + dictionary = dict() + for i, line in enumerate(fdict): + dictionary[line.split('\t')[0]] = i + + for line_count, line in enumerate(ftest): + label, comment = line.strip().split('\t\t') + label = int(label) + words = comment.split() + word_slot = [dictionary[w] for w in words if w in dictionary] + yield (word_slot, label) + + if __name__ == '__main__': data_dir = "./data/pre-imdb" train_list = "train.list" @@ -170,6 +192,13 @@ if __name__ == '__main__': if event.batch_id % 100 == 0: print "Pass %d, Batch %d, Cost %f, %s" % ( event.pass_id, event.batch_id, event.cost, event.metrics) + if isinstance(event, paddle.event.EndPass): + result = trainer.test( + reader=paddle.reader.batched( + test_reader, batch_size=128), + reader_dict={'word': 0, + 'label': 1}) + print "Test with Pass %d, %s" % (event.pass_id, result.metrics) trainer = paddle.trainer.SGD(cost=cost, parameters=parameters, -- GitLab