# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys

__dir__ = os.path.dirname(os.path.abspath(__file__))
# Make the repository root importable so that `tools` and `ppocr` resolve
# when this script is run directly.
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))

import cv2
import numpy as np
import time

import paddle

import tools.infer.utility as utility
from ppocr.utils.logging import get_logger
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from ppocr.data import create_operators, transform
from ppocr.postprocess import build_post_process

# Bound at module level (not only under __main__) so TextDetector can log
# errors when this module is imported rather than executed as a script.
logger = get_logger()


class TextDetector(object):
    """Run a text-detection model loaded with ``paddle.jit.load``.

    Only the ``"DB"`` detection algorithm is handled; any other value of
    ``args.det_algorithm`` logs a message and terminates the process.
    """

    def __init__(self, args):
        """Build the pre/post-processing pipelines and load the model.

        Args:
            args: parsed argument namespace (as produced by
                ``utility.parse_args()``). Reads ``det_algorithm``,
                ``use_zero_copy_run``, ``det_limit_side_len``,
                ``det_limit_type``, ``det_db_thresh``, ``det_db_box_thresh``,
                ``det_db_unclip_ratio`` and ``det_model_dir``.
        """
        self.det_algorithm = args.det_algorithm
        self.use_zero_copy_run = args.use_zero_copy_run
        postprocess_params = {}
        if self.det_algorithm == "DB":
            pre_process_list = [{
                'ResizeForTest': {
                    'limit_side_len': args.det_limit_side_len,
                    'limit_type': args.det_limit_type
                }
            }, {
                'NormalizeImage': {
                    'std': [0.229, 0.224, 0.225],
                    'mean': [0.485, 0.456, 0.406],
                    'scale': '1./255.',
                    'order': 'hwc'
                }
            }, {
                'ToCHWImage': None
            }, {
                'keepKeys': {
                    'keep_keys': ['image', 'shape']
                }
            }]
            postprocess_params['name'] = 'DBPostProcess'
            postprocess_params["thresh"] = args.det_db_thresh
            postprocess_params["box_thresh"] = args.det_db_box_thresh
            postprocess_params["max_candidates"] = 1000
            postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio
        else:
            logger.info("unknown det_algorithm:{}".format(self.det_algorithm))
            # Non-zero status: an unsupported algorithm is a failure, not a
            # successful run.
            sys.exit(1)

        self.preprocess_op = create_operators(pre_process_list)
        self.postprocess_op = build_post_process(postprocess_params)
        self.predictor = paddle.jit.load(args.det_model_dir)
        self.predictor.eval()

    def order_points_clockwise(self, pts):
        """Order four corner points as top-left, top-right, bottom-right,
        bottom-left.

        Reference: https://github.com/jrosebr1/imutils/blob/master/imutils/perspective.py

        Args:
            pts: array of points with (x, y) in the columns; presumably
                shape (4, 2) since it is split into two pairs below.

        Returns:
            A new float32 array ``[tl, tr, br, bl]``.
        """
        # Sort the points by their x-coordinates.
        xSorted = pts[np.argsort(pts[:, 0]), :]

        # The two smallest x values form the left side, the rest the right.
        leftMost = xSorted[:2, :]
        rightMost = xSorted[2:, :]

        # Within each side, the smaller y is the top point.
        leftMost = leftMost[np.argsort(leftMost[:, 1]), :]
        (tl, bl) = leftMost

        rightMost = rightMost[np.argsort(rightMost[:, 1]), :]
        (tr, br) = rightMost

        rect = np.array([tl, tr, br, bl], dtype="float32")
        return rect

    def clip_det_res(self, points, img_height, img_width):
        """Clamp point coordinates into the image and truncate them.

        x is clamped to ``[0, img_width - 1]`` and y to
        ``[0, img_height - 1]``; both are then truncated toward zero (as
        ``int()`` would). ``points`` is modified in place and also returned.
        """
        points[:, 0] = np.trunc(np.clip(points[:, 0], 0, img_width - 1))
        points[:, 1] = np.trunc(np.clip(points[:, 1], 0, img_height - 1))
        return points

    def filter_tag_det_res(self, dt_boxes, image_shape):
        """Order, clip and size-filter detected boxes.

        Each box is reordered with ``order_points_clockwise`` and clipped to
        the image. Boxes whose top edge (tl->tr) or left edge (tl->bl) is at
        most 10 pixels long, after truncation to int, are dropped.

        Args:
            dt_boxes: iterable of 4-point boxes.
            image_shape: shape of the original image; only ``[0:2]``
                (height, width) is used.

        Returns:
            ``np.array`` of the kept boxes.
        """
        img_height, img_width = image_shape[0:2]
        dt_boxes_new = []
        for box in dt_boxes:
            box = self.order_points_clockwise(box)
            box = self.clip_det_res(box, img_height, img_width)
            rect_width = int(np.linalg.norm(box[0] - box[1]))
            rect_height = int(np.linalg.norm(box[0] - box[3]))
            if rect_width <= 10 or rect_height <= 10:
                continue
            dt_boxes_new.append(box)
        dt_boxes = np.array(dt_boxes_new)
        return dt_boxes

    def filter_tag_det_res_only_clip(self, dt_boxes, image_shape):
        """Clip every box to the image without reordering or filtering.

        Returns:
            ``np.array`` of the clipped boxes.
        """
        img_height, img_width = image_shape[0:2]
        return np.array([
            self.clip_det_res(box, img_height, img_width) for box in dt_boxes
        ])

    def __call__(self, img):
        """Detect text boxes in ``img``.

        Args:
            img: image array; ``img.shape[0:2]`` is used as (height, width)
                when clipping the results.

        Returns:
            ``(dt_boxes, elapse)``. ``dt_boxes`` is the array produced by
            ``filter_tag_det_res``; ``elapse`` is the wall-clock seconds
            spent on inference, post-processing and filtering (preprocessing
            is excluded). Returns ``(None, 0)`` if preprocessing yields no
            image.
        """
        ori_im = img.copy()
        data = {'image': img}
        data = transform(data, self.preprocess_op)
        # Guard before unpacking: a None result (presumably a preprocessing
        # op rejecting the input) would otherwise raise TypeError here.
        if data is None:
            return None, 0
        img, shape_list = data
        if img is None:
            return None, 0
        # Add a leading batch dimension of size 1.
        img = np.expand_dims(img, axis=0)
        shape_list = np.expand_dims(shape_list, axis=0)
        starttime = time.time()

        preds = self.predictor(img)
        post_result = self.postprocess_op(preds, shape_list)

        dt_boxes = post_result[0]['points']
        dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape)
        elapse = time.time() - starttime
        return dt_boxes, elapse


if __name__ == "__main__":
    args = utility.parse_args()
    place = paddle.CPUPlace()
    paddle.disable_static(place)
    image_file_list = get_image_file_list(args.image_dir)
    text_detector = TextDetector(args)
    count = 0
    total_time = 0
    draw_img_save = "./inference_results"
    os.makedirs(draw_img_save, exist_ok=True)
    for image_file in image_file_list:
        img, flag = check_and_read_gif(image_file)
        if not flag:
            img = cv2.imread(image_file)
        if img is None:
            logger.info("error in loading image:{}".format(image_file))
            continue
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        dt_boxes, elapse = text_detector(img)
        if dt_boxes is None:
            # Preprocessing produced no image; nothing to time or draw.
            logger.info("error in detecting image:{}".format(image_file))
            continue
        # The first successful image is excluded from the average, which
        # presumably acts as a warm-up run.
        if count > 0:
            total_time += elapse
        count += 1
        print("Predict time of %s:" % image_file, elapse)
        src_im = utility.draw_text_det_res(dt_boxes, image_file)
        img_name_pure = os.path.basename(image_file)
        cv2.imwrite(
            os.path.join(draw_img_save, "det_res_%s" % img_name_pure), src_im)
    if count > 1:
        print("Avg Time:", total_time / (count - 1))