From f8c0fb44eb9452f837335fb60a2bd1ba3da43585 Mon Sep 17 00:00:00 2001 From: Xiangci Li Date: Wed, 4 Mar 2020 13:56:11 -0800 Subject: [PATCH] Fixed errors when only evaluating the fine-tuned model for sequence labeling. --- ernie/finetune/sequence_label.py | 5 +++++ ernie/run_sequence_labeling.py | 4 ++++ 2 files changed, 9 insertions(+) diff --git a/ernie/finetune/sequence_label.py b/ernie/finetune/sequence_label.py index 1c76ec1..5615c3b 100644 --- a/ernie/finetune/sequence_label.py +++ b/ernie/finetune/sequence_label.py @@ -109,6 +109,11 @@ def create_model(args, pyreader_name, ernie_config, is_prediction=False): def calculate_f1(num_label, num_infer, num_correct): + + num_infer = np.sum(num_infer) + num_label = np.sum(num_label) + num_correct = np.sum(num_correct) + if num_infer == 0: precision = 0.0 else: diff --git a/ernie/run_sequence_labeling.py b/ernie/run_sequence_labeling.py index 2896154..806d9e5 100644 --- a/ernie/run_sequence_labeling.py +++ b/ernie/run_sequence_labeling.py @@ -289,10 +289,14 @@ def main(args): # final eval on dev set if nccl2_trainer_id ==0 and args.do_val: + if not args.do_train: + current_example, current_epoch = reader.get_train_progress() evaluate_wrapper(reader, exe, test_prog, test_pyreader, graph_vars, current_epoch, 'final') if nccl2_trainer_id == 0 and args.do_test: + if not args.do_train: + current_example, current_epoch = reader.get_train_progress() predict_wrapper(reader, exe, test_prog, test_pyreader, graph_vars, current_epoch, 'final') -- GitLab