未验证 提交 9f6ab586 编写于 作者: M Meiyim 提交者: GitHub

Merge pull request #314 from Meiyim/dev

fix NER infer results
...@@ -181,11 +181,11 @@ def chunk_predict(np_inputs, np_probs, np_lens, dev_count=1): ...@@ -181,11 +181,11 @@ def chunk_predict(np_inputs, np_probs, np_lens, dev_count=1):
seq_st = base_index + i * max_len + 1 seq_st = base_index + i * max_len + 1
seq_en = seq_st + (lens[i] - 2) seq_en = seq_st + (lens[i] - 2)
prob = probs[seq_st:seq_en, :] prob = probs[seq_st:seq_en, :]
infers = np.argmax(probs, -1) infers = np.argmax(prob, -1)
out.append(( out.append((
inputs[seq_st:seq_en].tolist(), inputs[seq_st:seq_en].tolist(),
infers.tolist(), infers.tolist(),
probs.tolist())) prob.tolist()))
base_index += max_len * len(lens) base_index += max_len * len(lens)
return out return out
...@@ -199,14 +199,13 @@ def predict(exe, ...@@ -199,14 +199,13 @@ def predict(exe,
graph_vars["inputs"].name, graph_vars["inputs"].name,
graph_vars["probs"].name, graph_vars["probs"].name,
graph_vars["seqlen"].name, graph_vars["seqlen"].name,
graph_vars["probs"].name,
] ]
test_pyreader.start() test_pyreader.start()
res = [] res = []
while True: while True:
try: try:
inputs, probs, np_lens, np_probs = exe.run(program=test_program, inputs, probs, np_lens = exe.run(program=test_program,
fetch_list=fetch_list) fetch_list=fetch_list)
r = chunk_predict(inputs, probs, np_lens, dev_count) r = chunk_predict(inputs, probs, np_lens, dev_count)
res += r res += r
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册