提交 d03626b9 编写于 作者: Y Yu Yang

Fix SRL training

Fix #374
上级 b3754c77
...@@ -160,6 +160,9 @@ def main(): ...@@ -160,6 +160,9 @@ def main():
reader = paddle.batch( reader = paddle.batch(
paddle.reader.shuffle(conll05.test(), buf_size=8192), batch_size=10) paddle.reader.shuffle(conll05.test(), buf_size=8192), batch_size=10)
test_reader = paddle.batch(
paddle.reader.shuffle(conll05.test(), buf_size=8192), batch_size=10)
feeding = { feeding = {
'word_data': 0, 'word_data': 0,
'ctx_n2_data': 1, 'ctx_n2_data': 1,
...@@ -178,7 +181,7 @@ def main(): ...@@ -178,7 +181,7 @@ def main():
print "Pass %d, Batch %d, Cost %f, %s" % ( print "Pass %d, Batch %d, Cost %f, %s" % (
event.pass_id, event.batch_id, event.cost, event.metrics) event.pass_id, event.batch_id, event.cost, event.metrics)
if event.batch_id % 1000 == 0: if event.batch_id % 1000 == 0:
result = trainer.test(reader=reader, feeding=feeding) result = trainer.test(reader=test_reader, feeding=feeding)
print "\nTest with Pass %d, Batch %d, %s" % ( print "\nTest with Pass %d, Batch %d, %s" % (
event.pass_id, event.batch_id, result.metrics) event.pass_id, event.batch_id, result.metrics)
...@@ -187,7 +190,7 @@ def main(): ...@@ -187,7 +190,7 @@ def main():
with open('params_pass_%d.tar' % event.pass_id, 'w') as f: with open('params_pass_%d.tar' % event.pass_id, 'w') as f:
parameters.to_tar(f) parameters.to_tar(f)
result = trainer.test(reader=reader, feeding=feeding) result = trainer.test(reader=test_reader, feeding=feeding)
print "\nTest with Pass %d, %s" % (event.pass_id, result.metrics) print "\nTest with Pass %d, %s" % (event.pass_id, result.metrics)
trainer.train( trainer.train(
...@@ -211,6 +214,7 @@ def main(): ...@@ -211,6 +214,7 @@ def main():
output_layer=predict, output_layer=predict,
parameters=parameters, parameters=parameters,
input=test_data, input=test_data,
feeding=feeding,
field='id') field='id')
assert len(probs) == len(test_data[0][0]) assert len(probs) == len(test_data[0][0])
labels_reverse = {} labels_reverse = {}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册