提交 d03626b9 编写于 作者: Y Yu Yang

Fix SRL training

Fix #374
上级 b3754c77
......@@ -160,6 +160,9 @@ def main():
reader = paddle.batch(
paddle.reader.shuffle(conll05.test(), buf_size=8192), batch_size=10)
test_reader = paddle.batch(
paddle.reader.shuffle(conll05.test(), buf_size=8192), batch_size=10)
feeding = {
'word_data': 0,
'ctx_n2_data': 1,
......@@ -178,7 +181,7 @@ def main():
print "Pass %d, Batch %d, Cost %f, %s" % (
event.pass_id, event.batch_id, event.cost, event.metrics)
if event.batch_id % 1000 == 0:
result = trainer.test(reader=reader, feeding=feeding)
result = trainer.test(reader=test_reader, feeding=feeding)
print "\nTest with Pass %d, Batch %d, %s" % (
event.pass_id, event.batch_id, result.metrics)
......@@ -187,7 +190,7 @@ def main():
with open('params_pass_%d.tar' % event.pass_id, 'w') as f:
parameters.to_tar(f)
result = trainer.test(reader=reader, feeding=feeding)
result = trainer.test(reader=test_reader, feeding=feeding)
print "\nTest with Pass %d, %s" % (event.pass_id, result.metrics)
trainer.train(
......@@ -211,6 +214,7 @@ def main():
output_layer=predict,
parameters=parameters,
input=test_data,
feeding=feeding,
field='id')
assert len(probs) == len(test_data[0][0])
labels_reverse = {}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册