diff --git a/fluid/sequence_tagging_for_ner/train.py b/fluid/sequence_tagging_for_ner/train.py index 7ad5f28546c1971aa247f14079af8fa44d5a02ce..11b9d15d755d118ba2d0538ddcf4517ec929c846 100644 --- a/fluid/sequence_tagging_for_ner/train.py +++ b/fluid/sequence_tagging_for_ner/train.py @@ -15,17 +15,23 @@ from utils import logger, load_dict from utils_extend import to_lodtensor, get_embedding -def test(exe, chunk_evaluator, inference_program, test_data, place): - chunk_evaluator.reset(exe) +def test(exe, chunk_evaluator, inference_program, test_data, test_fetch_list, + place): + chunk_evaluator.reset() for data in test_data(): word = to_lodtensor([x[0] for x in data], place) mark = to_lodtensor([x[1] for x in data], place) target = to_lodtensor([x[2] for x in data], place) - acc = exe.run(inference_program, - feed={"word": word, - "mark": mark, - "target": target}) - return chunk_evaluator.eval(exe) + rets = exe.run(inference_program, + feed={"word": word, + "mark": mark, + "target": target}, + fetch_list=test_fetch_list) + num_infer = np.array(rets[0]) + num_label = np.array(rets[1]) + num_correct = np.array(rets[2]) + chunk_evaluator.update(num_infer[0], num_label[0], num_correct[0]) + return chunk_evaluator.eval() def main(train_data_file, @@ -58,16 +64,16 @@ def main(train_data_file, crf_decode = fluid.layers.crf_decoding( input=feature_out, param_attr=fluid.ParamAttr(name='crfw')) - chunk_evaluator = fluid.evaluator.ChunkEvaluator( - input=crf_decode, - label=target, - chunk_scheme="IOB", - num_chunk_types=int(math.ceil((label_dict_len - 1) / 2.0))) + (precision, recall, f1_score, num_infer_chunks, num_label_chunks, + num_correct_chunks) = fluid.layers.chunk_eval( + input=crf_decode, + label=target, + chunk_scheme="IOB", + num_chunk_types=int(math.ceil((label_dict_len - 1) / 2.0))) + chunk_evaluator = fluid.metrics.ChunkEvaluator() inference_program = fluid.default_main_program().clone(for_test=True) - with fluid.program_guard(inference_program): - test_target = chunk_evaluator.metrics + chunk_evaluator.states - inference_program = fluid.io.get_inference_program(test_target) + test_fetch_list = [num_infer_chunks, num_label_chunks, num_correct_chunks] if "CE_MODE_X" not in os.environ: train_reader = paddle.batch( @@ -100,26 +106,29 @@ def main(train_data_file, embedding_param = fluid.global_scope().find_var(embedding_name).get_tensor() embedding_param.set(word_vector_values, place) + time_begin = time.time() for pass_id in six.moves.xrange(num_passes): - chunk_evaluator.reset(exe) + chunk_evaluator.reset() for batch_id, data in enumerate(train_reader()): - cost, batch_precision, batch_recall, batch_f1_score = exe.run( + cost_var, nums_infer, nums_label, nums_correct = exe.run( fluid.default_main_program(), feed=feeder.feed(data), - fetch_list=[avg_cost] + chunk_evaluator.metrics) + fetch_list=[ + avg_cost, num_infer_chunks, num_label_chunks, + num_correct_chunks + ]) if batch_id % 5 == 0: - print(cost) - print("Pass " + str(pass_id) + ", Batch " + str( - batch_id) + ", Cost " + str(cost[0]) + ", Precision " + str( - batch_precision[0]) + ", Recall " + str(batch_recall[0]) - + ", F1_score" + str(batch_f1_score[0])) - - pass_precision, pass_recall, pass_f1_score = chunk_evaluator.eval(exe) + print("Pass " + str(pass_id) + ", Batch " + str(batch_id) + + ", Cost " + str(cost_var[0])) + chunk_evaluator.update(nums_infer, nums_label, nums_correct) + pass_precision, pass_recall, pass_f1_score = chunk_evaluator.eval() print("[TrainSet] pass_id:" + str(pass_id) + " pass_precision:" + str( pass_precision) + " pass_recall:" + str(pass_recall) + " pass_f1_score:" + str(pass_f1_score)) + test_pass_precision, test_pass_recall, test_pass_f1_score = test( - exe, chunk_evaluator, inference_program, test_reader, place) + exe, chunk_evaluator, inference_program, test_reader, + test_fetch_list, place) print("[TestSet] pass_id:" + str(pass_id) + " pass_precision:" + str( test_pass_precision) + " pass_recall:" + str(test_pass_recall) + " pass_f1_score:" + str(test_pass_f1_score)) @@ -128,12 +137,10 @@ def main(train_data_file, fluid.io.save_inference_model(save_dirname, ['word', 'mark', 'target'], crf_decode, exe) - if ("CE_MODE_X" in os.environ) and (pass_id % 50 == 0): - if pass_id > 0: - print("kpis train_precision %f" % pass_precision) - print("kpis test_precision %f" % test_pass_precision) - print("kpis train_duration %f" % (time.time() - time_begin)) - time_begin = time.time() + if "CE_MODE_X" in os.environ: + print("kpis train_precision %f" % pass_precision) + print("kpis test_precision %f" % test_pass_precision) + print("kpis train_duration %f" % (time.time() - time_begin)) if __name__ == "__main__":