diff --git a/finetune/sequence_label.py b/finetune/sequence_label.py index dff0f9c531a57758e08ba635478f901baa54fd5d..71610477040e0e6459e90886afc1dbbbac6513c2 100644 --- a/finetune/sequence_label.py +++ b/finetune/sequence_label.py @@ -181,11 +181,11 @@ def chunk_predict(np_inputs, np_probs, np_lens, dev_count=1): seq_st = base_index + i * max_len + 1 seq_en = seq_st + (lens[i] - 2) prob = probs[seq_st:seq_en, :] - infers = np.argmax(probs, -1) + infers = np.argmax(prob, -1) out.append(( inputs[seq_st:seq_en].tolist(), infers.tolist(), - probs.tolist())) + prob.tolist())) base_index += max_len * len(lens) return out @@ -199,14 +199,13 @@ def predict(exe, graph_vars["inputs"].name, graph_vars["probs"].name, graph_vars["seqlen"].name, - graph_vars["probs"].name, ] test_pyreader.start() res = [] while True: try: - inputs, probs, np_lens, np_probs = exe.run(program=test_program, + inputs, probs, np_lens = exe.run(program=test_program, fetch_list=fetch_list) r = chunk_predict(inputs, probs, np_lens, dev_count) res += r