diff --git a/tools/infer/predict_cls.py b/tools/infer/predict_cls.py
index 9ec03396f95bd24704be014633916631ff98e627..420213ee5a6fce1f11c72b960d7e90344dd295ee 100755
--- a/tools/infer/predict_cls.py
+++ b/tools/infer/predict_cls.py
@@ -23,7 +23,7 @@ import copy
 import numpy as np
 import math
 import time
-
+import traceback
 import paddle.fluid as fluid
 
 import tools.infer.utility as utility
@@ -106,10 +106,10 @@ class TextClassifier(object):
             norm_img_batch = fluid.core.PaddleTensor(norm_img_batch)
             self.predictor.run([norm_img_batch])
             prob_out = self.output_tensors[0].copy_to_cpu()
-            cls_res = self.postprocess_op(prob_out)
+            cls_result = self.postprocess_op(prob_out)
             elapse += time.time() - starttime
-            for rno in range(len(cls_res)):
-                label, score = cls_res[rno]
+            for rno in range(len(cls_result)):
+                label, score = cls_result[rno]
                 cls_res[indices[beg_img_no + rno]] = [label, score]
                 if '180' in label and score > self.cls_thresh:
                     img_list[indices[beg_img_no + rno]] = cv2.rotate(
@@ -133,8 +133,8 @@ def main(args):
         img_list.append(img)
     try:
         img_list, cls_res, predict_time = text_classifier(img_list)
-    except Exception as e:
-        print(e)
+    except:
+        logger.info(traceback.format_exc())
         logger.info(
             "ERROR!!!! \n"
             "Please read the FAQ:https://github.com/PaddlePaddle/PaddleOCR#faq \n"
@@ -143,10 +143,10 @@ def main(args):
             "Please set --rec_image_shape='3,32,100' and --rec_char_type='en' ")
         exit()
     for ino in range(len(img_list)):
-        print("Predicts of {}:{}".format(valid_image_file_list[ino], cls_res[
+        logger.info("Predicts of {}:{}".format(valid_image_file_list[ino], cls_res[
             ino]))
-    print("Total predict time for {} images, cost: {:.3f}".format(
+    logger.info("Total predict time for {} images, cost: {:.3f}".format(
         len(img_list), predict_time))
 
-    if __name__ == "__main__":
-        main(utility.parse_args())
+if __name__ == "__main__":
+    main(utility.parse_args())
diff --git a/tools/infer/predict_det.py b/tools/infer/predict_det.py
index 4b4825a66a145faf78d96446604329730a453381..5be27339dbae07c8d99fe442f18e64288d831f79 100755
--- a/tools/infer/predict_det.py
+++ b/tools/infer/predict_det.py
@@ -178,11 +178,12 @@ if __name__ == "__main__":
         if count > 0:
             total_time += elapse
         count += 1
-        print("Predict time of {}: {}".format(image_file, elapse))
+        logger.info("Predict time of {}: {}".format(image_file, elapse))
         src_im = utility.draw_text_det_res(dt_boxes, image_file)
         img_name_pure = os.path.split(image_file)[-1]
         img_path = os.path.join(draw_img_save,
                                 "det_res_{}".format(img_name_pure))
         cv2.imwrite(img_path, src_im)
+        logger.info("The visualized image saved in {}".format(img_path))
     if count > 1:
-        print("Avg Time:", total_time / (count - 1))
+        logger.info("Avg Time: {}".format(total_time / (count - 1)))
diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py
index c1f20ef3c6b42772d47665032504b1fae039cbcd..c615fa0d36e9179d0e11d7e5588d223361aab349 100755
--- a/tools/infer/predict_rec.py
+++ b/tools/infer/predict_rec.py
@@ -22,7 +22,7 @@ import cv2
 import numpy as np
 import math
 import time
-
+import traceback
 import paddle.fluid as fluid
 
 import tools.infer.utility as utility
@@ -135,8 +135,8 @@ def main(args):
         img_list.append(img)
     try:
         rec_res, predict_time = text_recognizer(img_list)
-    except Exception as e:
-        print(e)
+    except:
+        logger.info(traceback.format_exc())
         logger.info(
             "ERROR!!!! \n"
             "Please read the FAQ:https://github.com/PaddlePaddle/PaddleOCR#faq \n"
@@ -145,9 +145,9 @@ def main(args):
             "Please set --rec_image_shape='3,32,100' and --rec_char_type='en' ")
         exit()
     for ino in range(len(img_list)):
-        print("Predicts of {}:{}".format(valid_image_file_list[ino], rec_res[
+        logger.info("Predicts of {}:{}".format(valid_image_file_list[ino], rec_res[
             ino]))
-    print("Total predict time for {} images, cost: {:.3f}".format(
+    logger.info("Total predict time for {} images, cost: {:.3f}".format(
         len(img_list), predict_time))
 
 
diff --git a/tools/infer/predict_system.py b/tools/infer/predict_system.py
index 4e81039745ff63985728113ef99fb9e3c54daca5..ae660fdedad9580f098420119946a0291d3aa1f8 100755
--- a/tools/infer/predict_system.py
+++ b/tools/infer/predict_system.py
@@ -23,17 +23,21 @@ import numpy as np
 import time
 from PIL import Image
 import tools.infer.utility as utility
-from tools.infer.utility import draw_ocr
 import tools.infer.predict_rec as predict_rec
 import tools.infer.predict_det as predict_det
+import tools.infer.predict_cls as predict_cls
 from ppocr.utils.utility import get_image_file_list, check_and_read_gif
 from ppocr.utils.logging import get_logger
+from tools.infer.utility import draw_ocr_box_txt
 
 
 class TextSystem(object):
     def __init__(self, args):
         self.text_detector = predict_det.TextDetector(args)
         self.text_recognizer = predict_rec.TextRecognizer(args)
+        self.use_angle_cls = args.use_angle_cls
+        if self.use_angle_cls:
+            self.text_classifier = predict_cls.TextClassifier(args)
 
     def get_rotate_crop_image(self, img, points):
         '''
@@ -72,12 +76,12 @@ class TextSystem(object):
         bbox_num = len(img_crop_list)
         for bno in range(bbox_num):
             cv2.imwrite("./output/img_crop_%d.jpg" % bno, img_crop_list[bno])
-            print(bno, rec_res[bno])
+            logger.info("{}, {}".format(bno, rec_res[bno]))
 
     def __call__(self, img):
         ori_im = img.copy()
         dt_boxes, elapse = self.text_detector(img)
-        print("dt_boxes num : {}, elapse : {}".format(len(dt_boxes), elapse))
+        logger.info("dt_boxes num : {}, elapse : {}".format(len(dt_boxes), elapse))
         if dt_boxes is None:
             return None, None
         img_crop_list = []
@@ -88,8 +92,14 @@ class TextSystem(object):
             tmp_box = copy.deepcopy(dt_boxes[bno])
             img_crop = self.get_rotate_crop_image(ori_im, tmp_box)
             img_crop_list.append(img_crop)
+        if self.use_angle_cls:
+            img_crop_list, angle_list, elapse = self.text_classifier(
+                img_crop_list)
+            logger.info("cls num : {}, elapse : {}".format(
+                len(img_crop_list), elapse))
+
         rec_res, elapse = self.text_recognizer(img_crop_list)
-        print("rec_res num : {}, elapse : {}".format(len(rec_res), elapse))
+        logger.info("rec_res num : {}, elapse : {}".format(len(rec_res), elapse))
         # self.print_draw_crop_rec_res(img_crop_list, rec_res)
         return dt_boxes, rec_res
 
@@ -119,7 +129,8 @@ def main(args):
     image_file_list = get_image_file_list(args.image_dir)
     text_sys = TextSystem(args)
     is_visualize = True
-    tackle_img_num = 0
+    font_path = args.vis_font_path
+    drop_score = args.drop_score
     for image_file in image_file_list:
         img, flag = check_and_read_gif(image_file)
         if not flag:
@@ -128,20 +139,16 @@ def main(args):
             logger.info("error in loading image:{}".format(image_file))
             continue
         starttime = time.time()
-        tackle_img_num += 1
-        if not args.use_gpu and args.enable_mkldnn and tackle_img_num % 30 == 0:
-            text_sys = TextSystem(args)
         dt_boxes, rec_res = text_sys(img)
         elapse = time.time() - starttime
-        print("Predict time of %s: %.3fs" % (image_file, elapse))
+        logger.info("Predict time of %s: %.3fs" % (image_file, elapse))
 
-        drop_score = 0.5
         dt_num = len(dt_boxes)
         for dno in range(dt_num):
             text, score = rec_res[dno]
             if score >= drop_score:
                 text_str = "%s, %.3f" % (text, score)
-                print(text_str)
+                logger.info(text_str)
 
         if is_visualize:
             image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
@@ -149,15 +156,20 @@ def main(args):
             txts = [rec_res[i][0] for i in range(len(rec_res))]
             scores = [rec_res[i][1] for i in range(len(rec_res))]
 
-            draw_img = draw_ocr(
-                image, boxes, txts, scores, drop_score=drop_score)
+            draw_img = draw_ocr_box_txt(
+                image,
+                boxes,
+                txts,
+                scores,
+                drop_score=drop_score,
+                font_path=font_path)
             draw_img_save = "./inference_results/"
             if not os.path.exists(draw_img_save):
                 os.makedirs(draw_img_save)
             cv2.imwrite(
                 os.path.join(draw_img_save, os.path.basename(image_file)),
                 draw_img[:, :, ::-1])
-            print("The visualized image saved in {}".format(
+            logger.info("The visualized image saved in {}".format(
                 os.path.join(draw_img_save, os.path.basename(image_file))))
 
 
diff --git a/tools/infer/utility.py b/tools/infer/utility.py
index 5a524516a28bd3fa1f7a9f2800ebace7b9776088..ee1f954dcc4b6518cfe454a86650b397b9db449e 100755
--- a/tools/infer/utility.py
+++ b/tools/infer/utility.py
@@ -71,6 +71,7 @@ def parse_args():
     parser.add_argument("--use_space_char", type=str2bool, default=True)
     parser.add_argument(
         "--vis_font_path", type=str, default="./doc/simfang.ttf")
+    parser.add_argument("--drop_score", type=float, default=0.5)
 
     # params for text classifier
     parser.add_argument("--use_angle_cls", type=str2bool, default=False)
@@ -202,7 +203,12 @@ def draw_ocr(image,
     return image
 
 
-def draw_ocr_box_txt(image, boxes, txts):
+def draw_ocr_box_txt(image,
+                     boxes,
+                     txts,
+                     scores=None,
+                     drop_score=0.5,
+                     font_path="./doc/simfang.ttf"):
     h, w = image.height, image.width
     img_left = image.copy()
     img_right = Image.new('RGB', (w, h), (255, 255, 255))
@@ -212,7 +218,9 @@ def draw_ocr_box_txt(image, boxes, txts):
     random.seed(0)
     draw_left = ImageDraw.Draw(img_left)
     draw_right = ImageDraw.Draw(img_right)
-    for (box, txt) in zip(boxes, txts):
+    for idx, (box, txt) in enumerate(zip(boxes, txts)):
+        if scores is not None and scores[idx] < drop_score:
+            continue
         color = (random.randint(0, 255), random.randint(0, 255),
                  random.randint(0, 255))
         draw_left.polygon(box, fill=color)
@@ -222,14 +230,13 @@ def draw_ocr_box_txt(image, boxes, txts):
                 box[2][1], box[3][0], box[3][1]
             ],
             outline=color)
-        box_height = math.sqrt((box[0][0] - box[3][0])**2 + (box[0][1] - box[3][
-            1])**2)
-        box_width = math.sqrt((box[0][0] - box[1][0])**2 + (box[0][1] - box[1][
-            1])**2)
+        box_height = math.sqrt((box[0][0] - box[3][0]) ** 2 + (box[0][1] - box[3][
+            1]) ** 2)
+        box_width = math.sqrt((box[0][0] - box[1][0]) ** 2 + (box[0][1] - box[1][
+            1]) ** 2)
         if box_height > 2 * box_width:
             font_size = max(int(box_width * 0.9), 10)
-            font = ImageFont.truetype(
-                "./doc/simfang.ttf", font_size, encoding="utf-8")
+            font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
             cur_y = box[0][1]
             for c in txt:
                 char_size = font.getsize(c)
@@ -238,8 +245,7 @@ def draw_ocr_box_txt(image, boxes, txts):
                 cur_y += char_size[1]
         else:
             font_size = max(int(box_height * 0.8), 10)
-            font = ImageFont.truetype(
-                "./doc/simfang.ttf", font_size, encoding="utf-8")
+            font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
             draw_right.text(
                 [box[0][0], box[0][1]], txt, fill=(0, 0, 0), font=font)
     img_left = Image.blend(image, img_left, 0.5)
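
Below is a minimal usage sketch of the pipeline after this patch (not part of the diff). It assumes the repository root is on the import path and that the detector/recognizer/classifier inference models are supplied through the existing model-dir flags of utility.parse_args(), which lie outside the hunks shown above; the image path is purely illustrative.

# Minimal sketch: drive the reworked TextSystem directly, assuming exported
# inference models are passed on the command line (e.g. --use_angle_cls=True
# --drop_score=0.5 plus the usual det/rec/cls model-dir flags).
import cv2

import tools.infer.utility as utility
from tools.infer.predict_system import TextSystem

args = utility.parse_args()                # reads the CLI flags, including the new --drop_score
text_sys = TextSystem(args)                # builds the angle classifier only when use_angle_cls is set
img = cv2.imread("./doc/imgs/11.jpg")      # illustrative path
dt_boxes, rec_res = text_sys(img)          # detection -> optional angle cls -> recognition
if dt_boxes is not None:
    for text, score in rec_res:
        if score >= args.drop_score:       # drop_score now comes from the CLI flag, not a hard-coded 0.5
            print(text, score)

The same flags also drive visualization: draw_ocr_box_txt now skips boxes whose score falls below drop_score and loads the font from --vis_font_path instead of the hard-coded ./doc/simfang.ttf.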