提交 12b6ea9f 编写于 作者: C chenxuyi

fix NER infer results

上级 58066a1e
......@@ -181,11 +181,11 @@ def chunk_predict(np_inputs, np_probs, np_lens, dev_count=1):
seq_st = base_index + i * max_len + 1
seq_en = seq_st + (lens[i] - 2)
prob = probs[seq_st:seq_en, :]
infers = np.argmax(probs, -1)
infers = np.argmax(prob, -1)
out.append((
inputs[seq_st:seq_en].tolist(),
infers.tolist(),
probs.tolist()))
prob.tolist()))
base_index += max_len * len(lens)
return out
......@@ -199,14 +199,13 @@ def predict(exe,
graph_vars["inputs"].name,
graph_vars["probs"].name,
graph_vars["seqlen"].name,
graph_vars["probs"].name,
]
test_pyreader.start()
res = []
while True:
try:
inputs, probs, np_lens, np_probs = exe.run(program=test_program,
inputs, probs, np_lens = exe.run(program=test_program,
fetch_list=fetch_list)
r = chunk_predict(inputs, probs, np_lens, dev_count)
res += r
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册