From d15fca53d299f78d4a8bc5eefeaa7c6866cb07a2 Mon Sep 17 00:00:00 2001 From: Double_V Date: Thu, 21 Jul 2022 09:55:42 +0800 Subject: [PATCH] Merge pull request #6824 from ChenNima/release/2.5-kie-save-res [kie]add write_kie_result to kie infer tool --- tools/infer_kie.py | 27 ++++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/tools/infer_kie.py b/tools/infer_kie.py index 346e2e0a..9375434c 100755 --- a/tools/infer_kie.py +++ b/tools/infer_kie.py @@ -88,6 +88,29 @@ def draw_kie_result(batch, node, idx_to_cls, count): cv2.imwrite(save_path, vis_img) logger.info("The Kie Image saved in {}".format(save_path)) +def write_kie_result(fout, node, data): + """ + Write infer result to output file, sorted by the predict label of each line. + The format keeps the same as the input with additional score attribute. + """ + import json + label = data['label'] + annotations = json.loads(label) + max_value, max_idx = paddle.max(node, -1), paddle.argmax(node, -1) + node_pred_label = max_idx.numpy().tolist() + node_pred_score = max_value.numpy().tolist() + res = [] + for i, label in enumerate(node_pred_label): + pred_score = '{:.2f}'.format(node_pred_score[i]) + pred_res = { + 'label': label, + 'transcription': annotations[i]['transcription'], + 'score': pred_score, + 'points': annotations[i]['points'], + } + res.append(pred_res) + res.sort(key=lambda x: x['label']) + fout.writelines([json.dumps(res, ensure_ascii=False) + '\n']) def main(): global_config = config['Global'] @@ -114,7 +137,7 @@ def main(): warmup_times = 0 count_t = [] - with open(save_res_path, "wb") as fout: + with open(save_res_path, "w") as fout: with open(config['Global']['infer_img'], "rb") as f: lines = f.readlines() for index, data_line in enumerate(lines): @@ -139,6 +162,8 @@ def main(): node = F.softmax(node, -1) count_t.append(time.time() - st) draw_kie_result(batch, node, idx_to_cls, index) + write_kie_result(fout, node, data) + fout.close() logger.info("success!") logger.info("It took {} s for predict {} images.".format( np.sum(count_t), len(count_t))) -- GitLab