Commit d15fca53, authored by Double_V, committed by qq_25193841

Merge pull request #6824 from ChenNima/release/2.5-kie-save-res

[kie]add write_kie_result to kie infer tool
Parent b7d99acd
@@ -88,6 +88,29 @@ def draw_kie_result(batch, node, idx_to_cls, count):
     cv2.imwrite(save_path, vis_img)
     logger.info("The Kie Image saved in {}".format(save_path))
+def write_kie_result(fout, node, data):
+    """
+    Write the infer result to the output file, sorted by the predicted label of each line.
+    The format is the same as the input, with an additional score attribute.
+    """
+    import json
+    label = data['label']
+    annotations = json.loads(label)
+    max_value, max_idx = paddle.max(node, -1), paddle.argmax(node, -1)
+    node_pred_label = max_idx.numpy().tolist()
+    node_pred_score = max_value.numpy().tolist()
+    res = []
+    for i, label in enumerate(node_pred_label):
+        pred_score = '{:.2f}'.format(node_pred_score[i])
+        pred_res = {
+            'label': label,
+            'transcription': annotations[i]['transcription'],
+            'score': pred_score,
+            'points': annotations[i]['points'],
+        }
+        res.append(pred_res)
+    res.sort(key=lambda x: x['label'])
+    fout.writelines([json.dumps(res, ensure_ascii=False) + '\n'])
 def main():
     global_config = config['Global']
@@ -114,7 +137,7 @@ def main():
     warmup_times = 0
     count_t = []
-    with open(save_res_path, "wb") as fout:
+    with open(save_res_path, "w") as fout:
         with open(config['Global']['infer_img'], "rb") as f:
             lines = f.readlines()
             for index, data_line in enumerate(lines):
@@ -139,6 +162,8 @@ def main():
                 node = F.softmax(node, -1)
                 count_t.append(time.time() - st)
                 draw_kie_result(batch, node, idx_to_cls, index)
+                write_kie_result(fout, node, data)
+    fout.close()
     logger.info("success!")
     logger.info("It took {} s for predict {} images.".format(
         np.sum(count_t), len(count_t)))
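
For readers skimming the diff: the sketch below is not part of the commit; it mirrors the logic of `write_kie_result` on made-up data to show the JSON line the tool now appends to the configured `save_res_path` for each image. The tensor values, class count, transcriptions, and points are invented for illustration only.

```python
import json
import paddle

# Fake per-node class probabilities (2 text regions x 3 classes), standing in
# for the softmax output `node` computed in the infer script.
node = paddle.to_tensor([[0.10, 0.85, 0.05],
                         [0.70, 0.20, 0.10]])

# Fake input data line: the 'label' field holds the per-region annotations.
data = {
    'label': json.dumps([
        {'transcription': 'Total: 128.00',
         'points': [[10, 10], [200, 10], [200, 40], [10, 40]]},
        {'transcription': 'Invoice No. 42',
         'points': [[10, 60], [200, 60], [200, 90], [10, 90]]},
    ])
}

annotations = json.loads(data['label'])
max_value, max_idx = paddle.max(node, -1), paddle.argmax(node, -1)
res = []
for i, (label, score) in enumerate(
        zip(max_idx.numpy().tolist(), max_value.numpy().tolist())):
    res.append({
        'label': label,
        'transcription': annotations[i]['transcription'],
        'score': '{:.2f}'.format(score),
        'points': annotations[i]['points'],
    })
res.sort(key=lambda x: x['label'])  # regions ordered by predicted class index
print(json.dumps(res, ensure_ascii=False))
# One such JSON line per input image is appended to the result file.
```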