From d03626b93e68bff16182be9012c7970d4f01f633 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Fri, 15 Sep 2017 15:48:54 -0700 Subject: [PATCH] Fix SRL training Fix #374 --- 07.label_semantic_roles/train.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/07.label_semantic_roles/train.py b/07.label_semantic_roles/train.py index d8a3698..ba86abc 100644 --- a/07.label_semantic_roles/train.py +++ b/07.label_semantic_roles/train.py @@ -160,6 +160,9 @@ def main(): reader = paddle.batch( paddle.reader.shuffle(conll05.test(), buf_size=8192), batch_size=10) + test_reader = paddle.batch( + paddle.reader.shuffle(conll05.test(), buf_size=8192), batch_size=10) + feeding = { 'word_data': 0, 'ctx_n2_data': 1, @@ -178,7 +181,7 @@ def main(): print "Pass %d, Batch %d, Cost %f, %s" % ( event.pass_id, event.batch_id, event.cost, event.metrics) if event.batch_id % 1000 == 0: - result = trainer.test(reader=reader, feeding=feeding) + result = trainer.test(reader=test_reader, feeding=feeding) print "\nTest with Pass %d, Batch %d, %s" % ( event.pass_id, event.batch_id, result.metrics) @@ -187,7 +190,7 @@ def main(): with open('params_pass_%d.tar' % event.pass_id, 'w') as f: parameters.to_tar(f) - result = trainer.test(reader=reader, feeding=feeding) + result = trainer.test(reader=test_reader, feeding=feeding) print "\nTest with Pass %d, %s" % (event.pass_id, result.metrics) trainer.train( @@ -211,6 +214,7 @@ def main(): output_layer=predict, parameters=parameters, input=test_data, + feeding=feeding, field='id') assert len(probs) == len(test_data[0][0]) labels_reverse = {} -- GitLab