From 01fd2d0f81d2f0079ea9df2f5b5190d1421b5caa Mon Sep 17 00:00:00 2001 From: littletomatodonkey Date: Mon, 7 Feb 2022 07:56:08 +0000 Subject: [PATCH] add system pred save --- tools/infer/predict_system.py | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/tools/infer/predict_system.py b/tools/infer/predict_system.py index 16789b81..64df865a 100755 --- a/tools/infer/predict_system.py +++ b/tools/infer/predict_system.py @@ -24,6 +24,7 @@ os.environ["FLAGS_allocator_strategy"] = 'auto_growth' import cv2 import copy import numpy as np +import json import time import logging from PIL import Image @@ -92,11 +93,11 @@ class TextSystem(object): self.draw_crop_rec_res(self.args.crop_res_save_dir, img_crop_list, rec_res) filter_boxes, filter_rec_res = [], [] - for box, rec_result in zip(dt_boxes, rec_res): - text, score = rec_result + for box, rec_reuslt in zip(dt_boxes, rec_res): + text, score = rec_reuslt if score >= self.drop_score: filter_boxes.append(box) - filter_rec_res.append(rec_result) + filter_rec_res.append(rec_reuslt) return filter_boxes, filter_rec_res @@ -128,6 +129,9 @@ def main(args): is_visualize = True font_path = args.vis_font_path drop_score = args.drop_score + draw_img_save_dir = args.draw_img_save_dir + os.makedirs(draw_img_save_dir, exist_ok=True) + save_results = [] # warm up 10 times if args.warmup: @@ -157,6 +161,14 @@ def main(args): for text, score in rec_res: logger.debug("{}, {:.3f}".format(text, score)) + res = [{ + "transcription": rec_res[idx][0], + "points": np.array(dt_boxes[idx]).astype(np.int32).tolist(), + } for idx in range(len(dt_boxes)) if rec_res[idx][1] >= drop_score] + save_pred = os.path.basename(image_file) + "\t" + json.dumps( + res, ensure_ascii=False) + "\n" + save_results.append(save_pred) + if is_visualize: image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) boxes = dt_boxes @@ -170,8 +182,6 @@ def main(args): scores, drop_score=drop_score, font_path=font_path) - draw_img_save_dir = args.draw_img_save_dir - os.makedirs(draw_img_save_dir, exist_ok=True) if flag: image_file = image_file[:-3] + "png" cv2.imwrite( @@ -185,6 +195,10 @@ def main(args): text_sys.text_detector.autolog.report() text_sys.text_recognizer.autolog.report() + with open(os.path.join(draw_img_save_dir, "system_results.txt"), 'w') as f: + f.writelines(save_results) + f.close() + if __name__ == "__main__": args = utility.parse_args() -- GitLab