提交 338ce12b 编写于 作者: R root

for style check

上级 9e544d42
...@@ -98,27 +98,25 @@ def main(train_data_file, test_data_file, vocab_file, target_file, emb_file, ...@@ -98,27 +98,25 @@ def main(train_data_file, test_data_file, vocab_file, target_file, emb_file,
for pass_id in xrange(num_passes): for pass_id in xrange(num_passes):
chunk_evaluator.reset(exe) chunk_evaluator.reset(exe)
for data in train_reader(): for data in train_reader():
print len(data)
cost, batch_precision, batch_recall, batch_f1_score = exe.run( cost, batch_precision, batch_recall, batch_f1_score = exe.run(
fluid.default_main_program(), fluid.default_main_program(),
feed=feeder.feed(data), feed=feeder.feed(data),
fetch_list=[avg_cost] + chunk_evaluator.metrics) fetch_list=[avg_cost] + chunk_evaluator.metrics)
if batch_id % 5 == 0: if batch_id % 5 == 0:
print( print("Pass " + str(pass_id) + ", Batch " + str(
"Pass " + str(pass_id) + ", Batch " + str(batch_id) + batch_id) + ", Cost " + str(cost[0]) + ", Precision " + str(
", Cost " + str(cost[0]) + ", Precision " + batch_precision[0]) + ", Recall " + str(batch_recall[0])
str(batch_precision[0]) + ", Recall " + str(batch_recall[0]) + ", F1_score" + str(batch_f1_score[0]))
+ ", F1_score" + str(batch_f1_score[0]))
batch_id = batch_id + 1 batch_id = batch_id + 1
pass_precision, pass_recall, pass_f1_score = chunk_evaluator.eval(exe) pass_precision, pass_recall, pass_f1_score = chunk_evaluator.eval(exe)
print("[TrainSet] pass_id:" + str(pass_id) + " pass_precision:" + print("[TrainSet] pass_id:" + str(pass_id) + " pass_precision:" + str(
str(pass_precision) + " pass_recall:" + str(pass_recall) + pass_precision) + " pass_recall:" + str(pass_recall) +
" pass_f1_score:" + str(pass_f1_score)) " pass_f1_score:" + str(pass_f1_score))
pass_precision, pass_recall, pass_f1_score = test( pass_precision, pass_recall, pass_f1_score = test(
exe, chunk_evaluator, inference_program, test_reader, place) exe, chunk_evaluator, inference_program, test_reader, place)
print("[TestSet] pass_id:" + str(pass_id) + " pass_precision:" + print("[TestSet] pass_id:" + str(pass_id) + " pass_precision:" + str(
str(pass_precision) + " pass_recall:" + str(pass_recall) + pass_precision) + " pass_recall:" + str(pass_recall) +
" pass_f1_score:" + str(pass_f1_score)) " pass_f1_score:" + str(pass_f1_score))
save_dirname = os.path.join(model_save_dir, "params_pass_%d" % pass_id) save_dirname = os.path.join(model_save_dir, "params_pass_%d" % pass_id)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册