From 12b6ea9f16972f1f5ac155821f62e5a256fb655e Mon Sep 17 00:00:00 2001 From: chenxuyi Date: Wed, 4 Sep 2019 19:25:13 +0800 Subject: [PATCH] fix NER infer results --- finetune/sequence_label.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/finetune/sequence_label.py b/finetune/sequence_label.py index dff0f9c..7161047 100644 --- a/finetune/sequence_label.py +++ b/finetune/sequence_label.py @@ -181,11 +181,11 @@ def chunk_predict(np_inputs, np_probs, np_lens, dev_count=1): seq_st = base_index + i * max_len + 1 seq_en = seq_st + (lens[i] - 2) prob = probs[seq_st:seq_en, :] - infers = np.argmax(probs, -1) + infers = np.argmax(prob, -1) out.append(( inputs[seq_st:seq_en].tolist(), infers.tolist(), - probs.tolist())) + prob.tolist())) base_index += max_len * len(lens) return out @@ -199,14 +199,13 @@ def predict(exe, graph_vars["inputs"].name, graph_vars["probs"].name, graph_vars["seqlen"].name, - graph_vars["probs"].name, ] test_pyreader.start() res = [] while True: try: - inputs, probs, np_lens, np_probs = exe.run(program=test_program, + inputs, probs, np_lens = exe.run(program=test_program, fetch_list=fetch_list) r = chunk_predict(inputs, probs, np_lens, dev_count) res += r -- GitLab