From 902606499b5359a57941cda6e448e93804f9b629 Mon Sep 17 00:00:00 2001 From: WenmuZhou Date: Tue, 1 Dec 2020 16:42:10 +0800 Subject: [PATCH] add predict_cls to predict_system --- tools/infer/predict_system.py | 29 ++++++++++++++++++++++------- tools/infer/utility.py | 17 +++++++++++------ 2 files changed, 33 insertions(+), 13 deletions(-) diff --git a/tools/infer/predict_system.py b/tools/infer/predict_system.py index 4e810397..7ebe3ec3 100755 --- a/tools/infer/predict_system.py +++ b/tools/infer/predict_system.py @@ -23,17 +23,21 @@ import numpy as np import time from PIL import Image import tools.infer.utility as utility -from tools.infer.utility import draw_ocr import tools.infer.predict_rec as predict_rec import tools.infer.predict_det as predict_det +import tools.infer.predict_cls as predict_cls from ppocr.utils.utility import get_image_file_list, check_and_read_gif from ppocr.utils.logging import get_logger +from tools.infer.utility import draw_ocr_box_txt class TextSystem(object): def __init__(self, args): self.text_detector = predict_det.TextDetector(args) self.text_recognizer = predict_rec.TextRecognizer(args) + self.use_angle_cls = args.use_angle_cls + if self.use_angle_cls: + self.text_classifier = predict_cls.TextClassifier(args) def get_rotate_crop_image(self, img, points): ''' @@ -88,6 +92,15 @@ class TextSystem(object): tmp_box = copy.deepcopy(dt_boxes[bno]) img_crop = self.get_rotate_crop_image(ori_im, tmp_box) img_crop_list.append(img_crop) + cv2.imwrite( + '/home/zhoujun20/dygraph/PaddleOCR_rc/inference_results/{}.jpg'. + format(bno), img_crop) + if self.use_angle_cls: + img_crop_list, angle_list, elapse = self.text_classifier( + img_crop_list) + print("cls num : {}, elapse : {}".format( + len(img_crop_list), elapse)) + rec_res, elapse = self.text_recognizer(img_crop_list) print("rec_res num : {}, elapse : {}".format(len(rec_res), elapse)) # self.print_draw_crop_rec_res(img_crop_list, rec_res) @@ -119,7 +132,7 @@ def main(args): image_file_list = get_image_file_list(args.image_dir) text_sys = TextSystem(args) is_visualize = True - tackle_img_num = 0 + font_path = args.vis_font_path for image_file in image_file_list: img, flag = check_and_read_gif(image_file) if not flag: @@ -128,9 +141,6 @@ def main(args): logger.info("error in loading image:{}".format(image_file)) continue starttime = time.time() - tackle_img_num += 1 - if not args.use_gpu and args.enable_mkldnn and tackle_img_num % 30 == 0: - text_sys = TextSystem(args) dt_boxes, rec_res = text_sys(img) elapse = time.time() - starttime print("Predict time of %s: %.3fs" % (image_file, elapse)) @@ -149,8 +159,13 @@ def main(args): txts = [rec_res[i][0] for i in range(len(rec_res))] scores = [rec_res[i][1] for i in range(len(rec_res))] - draw_img = draw_ocr( - image, boxes, txts, scores, drop_score=drop_score) + draw_img = draw_ocr_box_txt( + image, + boxes, + txts, + scores, + drop_score=drop_score, + font_path=font_path) draw_img_save = "./inference_results/" if not os.path.exists(draw_img_save): os.makedirs(draw_img_save) diff --git a/tools/infer/utility.py b/tools/infer/utility.py index 5a524516..1d8cf22a 100755 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -202,7 +202,12 @@ def draw_ocr(image, return image -def draw_ocr_box_txt(image, boxes, txts): +def draw_ocr_box_txt(image, + boxes, + txts, + scores=None, + drop_score=0.5, + font_path="./doc/simfang.ttf"): h, w = image.height, image.width img_left = image.copy() img_right = Image.new('RGB', (w, h), (255, 255, 255)) @@ -212,7 +217,9 @@ def draw_ocr_box_txt(image, boxes, txts): random.seed(0) draw_left = ImageDraw.Draw(img_left) draw_right = ImageDraw.Draw(img_right) - for (box, txt) in zip(boxes, txts): + for idx, (box, txt) in enumerate(zip(boxes, txts)): + if scores is not None and scores[idx] < drop_score: + continue color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)) draw_left.polygon(box, fill=color) @@ -228,8 +235,7 @@ def draw_ocr_box_txt(image, boxes, txts): 1])**2) if box_height > 2 * box_width: font_size = max(int(box_width * 0.9), 10) - font = ImageFont.truetype( - "./doc/simfang.ttf", font_size, encoding="utf-8") + font = ImageFont.truetype(font_path, font_size, encoding="utf-8") cur_y = box[0][1] for c in txt: char_size = font.getsize(c) @@ -238,8 +244,7 @@ def draw_ocr_box_txt(image, boxes, txts): cur_y += char_size[1] else: font_size = max(int(box_height * 0.8), 10) - font = ImageFont.truetype( - "./doc/simfang.ttf", font_size, encoding="utf-8") + font = ImageFont.truetype(font_path, font_size, encoding="utf-8") draw_right.text( [box[0][0], box[0][1]], txt, fill=(0, 0, 0), font=font) img_left = Image.blend(image, img_left, 0.5) -- GitLab