diff --git a/test_tipc/configs/ppocr_det_mobile_params.txt b/test_tipc/configs/ppocr_det_mobile_params.txt index 0e8edb62a8aa881993a95fe9a550de17aceaa435..53bfdd03260f6be8d3679b4873c91e47115a614e 100644 --- a/test_tipc/configs/ppocr_det_mobile_params.txt +++ b/test_tipc/configs/ppocr_det_mobile_params.txt @@ -108,3 +108,15 @@ infer_model:./models/ch_ppocr_mobile_v2.0_det_opt.nb|./models/ch_ppocr_mobile_v2 --config_dir:./config.txt --rec_dict_dir:./ppocr_keys_v1.txt --benchmark:True +===========================paddle2onnx_params=========================== +2onnx: paddle2onnx +--model_dir:./inference/ch_ppocr_mobile_v2.0_det_infer/ +--model_filename:inference.pdmodel +--params_filename:inference.pdiparams +--save_file:./inference/det_mobile_onnx/model.onnx +--opset_version:10 +--enable_onnx_checker:True +inference:test_tipc/onnx_inference/predict_det.py +--use_gpu:False +--det_model_dir: +--image_dir:./inference/ch_det_data_50/all-sum-510/ \ No newline at end of file diff --git a/test_tipc/onnx_inference/predict_cls.py b/test_tipc/onnx_inference/predict_cls.py new file mode 100644 index 0000000000000000000000000000000000000000..ea70e0f84f3973ae162a5081482126b82d4a80d4 --- /dev/null +++ b/test_tipc/onnx_inference/predict_cls.py @@ -0,0 +1,149 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
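+
+# Hypothetical usage (a sketch; the ONNX classifier path below is an
+# assumption -- this PR only wires the detection model into the TIPC config):
+#
+#   python3 test_tipc/onnx_inference/predict_cls.py \
+#       --cls_model_dir=./inference/cls_onnx/model.onnx \
+#       --image_dir=./doc/imgs_words/ch/ \
+#       --use_gpu=False
+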
+import os +import sys + +__dir__ = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(__dir__) +sys.path.append(os.path.abspath(os.path.join(__dir__, '../..'))) + +os.environ["FLAGS_allocator_strategy"] = 'auto_growth' + +import cv2 +import copy +import numpy as np +import math +import time +import traceback + +import utility +from ppocr.postprocess import build_post_process +from ppocr.utils.logging import get_logger +from ppocr.utils.utility import get_image_file_list, check_and_read_gif + +logger = get_logger() + + +class TextClassifier(object): + def __init__(self, args): + self.cls_image_shape = [int(v) for v in args.cls_image_shape.split(",")] + self.cls_batch_num = args.cls_batch_num + self.cls_thresh = args.cls_thresh + postprocess_params = { + 'name': 'ClsPostProcess', + "label_list": args.label_list, + } + self.postprocess_op = build_post_process(postprocess_params) + self.predictor, self.input_tensor, self.output_tensors = \ + utility.create_predictor(args, 'cls', logger) + + def resize_norm_img(self, img): + imgC, imgH, imgW = self.cls_image_shape + h = img.shape[0] + w = img.shape[1] + ratio = w / float(h) + if math.ceil(imgH * ratio) > imgW: + resized_w = imgW + else: + resized_w = int(math.ceil(imgH * ratio)) + resized_image = cv2.resize(img, (resized_w, imgH)) + resized_image = resized_image.astype('float32') + if self.cls_image_shape[0] == 1: + resized_image = resized_image / 255 + resized_image = resized_image[np.newaxis, :] + else: + resized_image = resized_image.transpose((2, 0, 1)) / 255 + resized_image -= 0.5 + resized_image /= 0.5 + padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32) + padding_im[:, :, 0:resized_w] = resized_image + return padding_im + + def __call__(self, img_list): + img_list = copy.deepcopy(img_list) + img_num = len(img_list) + # Calculate the aspect ratio of all text bars + width_list = [] + for img in img_list: + width_list.append(img.shape[1] / float(img.shape[0])) + # Sorting can speed up the cls process + indices = np.argsort(np.array(width_list)) + + cls_res = [['', 0.0]] * img_num + batch_num = self.cls_batch_num + elapse = 0 + for beg_img_no in range(0, img_num, batch_num): + end_img_no = min(img_num, beg_img_no + batch_num) + norm_img_batch = [] + max_wh_ratio = 0 + for ino in range(beg_img_no, end_img_no): + h, w = img_list[indices[ino]].shape[0:2] + wh_ratio = w * 1.0 / h + max_wh_ratio = max(max_wh_ratio, wh_ratio) + for ino in range(beg_img_no, end_img_no): + norm_img = self.resize_norm_img(img_list[indices[ino]]) + norm_img = norm_img[np.newaxis, :] + norm_img_batch.append(norm_img) + norm_img_batch = np.concatenate(norm_img_batch) + norm_img_batch = norm_img_batch.copy() + starttime = time.time() + input_dict = {} + input_dict[self.input_tensor.name] = norm_img_batch + outputs = self.predictor.run(self.output_tensors, input_dict) + prob_out = outputs[0] + cls_result = self.postprocess_op(prob_out) + elapse += time.time() - starttime + for rno in range(len(cls_result)): + label, score = cls_result[rno] + cls_res[indices[beg_img_no + rno]] = [label, score] + if '180' in label and score > self.cls_thresh: + img_list[indices[beg_img_no + rno]] = cv2.rotate( + img_list[indices[beg_img_no + rno]], 1) + return img_list, cls_res, elapse + + +def main(args): + image_file_list = get_image_file_list(args.image_dir) + text_classifier = TextClassifier(args) + valid_image_file_list = [] + img_list = [] + for image_file in image_file_list: + img, flag = check_and_read_gif(image_file) + if not flag: + img = cv2.imread(image_file) 
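+        # cv2.imread returns None when the file cannot be decoded;
+        # such images are logged and skipped below.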
+        if img is None:
+            logger.info("error in loading image:{}".format(image_file))
+            continue
+        valid_image_file_list.append(image_file)
+        img_list.append(img)
+    try:
+        img_list, cls_res, predict_time = text_classifier(img_list)
+    except Exception:
+        logger.info(traceback.format_exc())
+        logger.info(
+            "ERROR!!!! \n"
+            "Please read the FAQ:https://github.com/PaddlePaddle/PaddleOCR#faq \n"
+            "If your model has tps module: "
+            "TPS does not support variable shape.\n"
+            "Please set --rec_image_shape='3,32,100' and --rec_char_type='en' ")
+        sys.exit()
+    for ino in range(len(img_list)):
+        logger.info("Predicts of {}:{}".format(valid_image_file_list[ino],
+                                               cls_res[ino]))
+    logger.info("Total predict time for {} images, cost: {:.3f}".format(
+        len(img_list), predict_time))
+
+
+if __name__ == "__main__":
+    main(utility.parse_args())
diff --git a/test_tipc/onnx_inference/predict_det.py b/test_tipc/onnx_inference/predict_det.py
new file mode 100644
index 0000000000000000000000000000000000000000..48ab0ec75a82dee944a9e59bd7419eea40e20973
--- /dev/null
+++ b/test_tipc/onnx_inference/predict_det.py
@@ -0,0 +1,223 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import sys
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(__dir__)
+sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
+
+os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
+
+import cv2
+import numpy as np
+import time
+
+import utility
+from ppocr.utils.logging import get_logger
+from ppocr.utils.utility import get_image_file_list, check_and_read_gif
+from ppocr.data import create_operators, transform
+from ppocr.postprocess import build_post_process
+
+logger = get_logger()
+
+
+class TextDetector(object):
+    def __init__(self, args):
+        self.args = args
+        self.det_algorithm = args.det_algorithm
+        pre_process_list = [{
+            'DetResizeForTest': {
+                'image_shape': [640, 640]
+            }
+        }, {
+            'NormalizeImage': {
+                'std': [0.229, 0.224, 0.225],
+                'mean': [0.485, 0.456, 0.406],
+                'scale': '1./255.',
+                'order': 'hwc'
+            }
+        }, {
+            'ToCHWImage': None
+        }, {
+            'KeepKeys': {
+                'keep_keys': ['image', 'shape']
+            }
+        }]
+        postprocess_params = {}
+        if self.det_algorithm == "DB":
+            postprocess_params['name'] = 'DBPostProcess'
+            postprocess_params["thresh"] = args.det_db_thresh
+            postprocess_params["box_thresh"] = args.det_db_box_thresh
+            postprocess_params["max_candidates"] = 1000
+            postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio
+            postprocess_params["use_dilation"] = True
+        elif self.det_algorithm == "EAST":
+            postprocess_params['name'] = 'EASTPostProcess'
+            postprocess_params["score_thresh"] = args.det_east_score_thresh
+            postprocess_params["cover_thresh"] = args.det_east_cover_thresh
+            postprocess_params["nms_thresh"] = args.det_east_nms_thresh
+        elif self.det_algorithm == "SAST":
+            pre_process_list[0] = {
+                'DetResizeForTest': {
+                    'resize_long': args.det_limit_side_len
+                }
+            }
+            postprocess_params['name'] = 'SASTPostProcess'
+ postprocess_params["score_thresh"] = args.det_sast_score_thresh + postprocess_params["nms_thresh"] = args.det_sast_nms_thresh + self.det_sast_polygon = args.det_sast_polygon + if self.det_sast_polygon: + postprocess_params["sample_pts_num"] = 6 + postprocess_params["expand_scale"] = 1.2 + postprocess_params["shrink_ratio_of_width"] = 0.2 + else: + postprocess_params["sample_pts_num"] = 2 + postprocess_params["expand_scale"] = 1.0 + postprocess_params["shrink_ratio_of_width"] = 0.3 + else: + logger.info("unknown det_algorithm:{}".format(self.det_algorithm)) + sys.exit(0) + + self.preprocess_op = create_operators(pre_process_list) + self.postprocess_op = build_post_process(postprocess_params) + print() + self.predictor, self.input_tensor, self.output_tensors = utility.create_predictor( + args, 'det', logger) # paddle.jit.load(args.det_model_dir) + # self.predictor.eval() + + def order_points_clockwise(self, pts): + """ + reference from: https://github.com/jrosebr1/imutils/blob/master/imutils/perspective.py + # sort the points based on their x-coordinates + """ + xSorted = pts[np.argsort(pts[:, 0]), :] + + # grab the left-most and right-most points from the sorted + # x-roodinate points + leftMost = xSorted[:2, :] + rightMost = xSorted[2:, :] + + # now, sort the left-most coordinates according to their + # y-coordinates so we can grab the top-left and bottom-left + # points, respectively + leftMost = leftMost[np.argsort(leftMost[:, 1]), :] + (tl, bl) = leftMost + + rightMost = rightMost[np.argsort(rightMost[:, 1]), :] + (tr, br) = rightMost + + rect = np.array([tl, tr, br, bl], dtype="float32") + return rect + + def clip_det_res(self, points, img_height, img_width): + for pno in range(points.shape[0]): + points[pno, 0] = int(min(max(points[pno, 0], 0), img_width - 1)) + points[pno, 1] = int(min(max(points[pno, 1], 0), img_height - 1)) + return points + + def filter_tag_det_res(self, dt_boxes, image_shape): + img_height, img_width = image_shape[0:2] + dt_boxes_new = [] + for box in dt_boxes: + box = self.order_points_clockwise(box) + box = self.clip_det_res(box, img_height, img_width) + rect_width = int(np.linalg.norm(box[0] - box[1])) + rect_height = int(np.linalg.norm(box[0] - box[3])) + if rect_width <= 3 or rect_height <= 3: + continue + dt_boxes_new.append(box) + dt_boxes = np.array(dt_boxes_new) + return dt_boxes + + def filter_tag_det_res_only_clip(self, dt_boxes, image_shape): + img_height, img_width = image_shape[0:2] + dt_boxes_new = [] + for box in dt_boxes: + box = self.clip_det_res(box, img_height, img_width) + dt_boxes_new.append(box) + dt_boxes = np.array(dt_boxes_new) + return dt_boxes + + def __call__(self, img): + ori_im = img.copy() + data = {'image': img} + data = transform(data, self.preprocess_op) + img, shape_list = data + if img is None: + return None, 0 + img = np.expand_dims(img, axis=0) + shape_list = np.expand_dims(shape_list, axis=0) + img = img.copy() + starttime = time.time() + + input_dict = {} + input_dict[self.input_tensor.name] = img + + outputs = self.predictor.run(self.output_tensors, input_dict) + + preds = {} + if self.det_algorithm == "EAST": + preds['f_geo'] = outputs[0] + preds['f_score'] = outputs[1] + elif self.det_algorithm == 'SAST': + preds['f_border'] = outputs[0] + preds['f_score'] = outputs[1] + preds['f_tco'] = outputs[2] + preds['f_tvo'] = outputs[3] + elif self.det_algorithm == 'DB': + preds['maps'] = outputs[0] + else: + raise NotImplementedError + + post_result = self.postprocess_op(preds, shape_list) + dt_boxes = 
+        if self.det_algorithm == "SAST" and self.det_sast_polygon:
+            dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes,
+                                                         ori_im.shape)
+        else:
+            dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape)
+        elapse = time.time() - starttime
+        return dt_boxes, elapse
+
+
+if __name__ == "__main__":
+    args = utility.parse_args()
+    image_file_list = get_image_file_list(args.image_dir)
+    text_detector = TextDetector(args)
+    count = 0
+    total_time = 0
+    draw_img_save = "./inference_results"
+    if not os.path.exists(draw_img_save):
+        os.makedirs(draw_img_save)
+    for image_file in image_file_list:
+        img, flag = check_and_read_gif(image_file)
+        if not flag:
+            img = cv2.imread(image_file)
+        if img is None:
+            logger.info("error in loading image:{}".format(image_file))
+            continue
+        dt_boxes, elapse = text_detector(img)
+        if count > 0:
+            total_time += elapse
+        count += 1
+        logger.info("Predict time of {}: {}".format(image_file, elapse))
+        src_im = utility.draw_text_det_res(dt_boxes, image_file)
+        img_name_pure = os.path.split(image_file)[-1]
+        img_path = os.path.join(draw_img_save,
+                                "det_res_{}".format(img_name_pure))
+        cv2.imwrite(img_path, src_im)
+        logger.info("The visualized image saved in {}".format(img_path))
+    if count > 1:
+        logger.info("Avg Time: {}".format(total_time / (count - 1)))
diff --git a/test_tipc/onnx_inference/predict_rec.py b/test_tipc/onnx_inference/predict_rec.py
new file mode 100644
index 0000000000000000000000000000000000000000..648cca2c90dbf10979f1a17583455b86885e2e69
--- /dev/null
+++ b/test_tipc/onnx_inference/predict_rec.py
@@ -0,0 +1,148 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
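+
+# Hypothetical invocation against an ONNX recognizer (the model path is an
+# assumption; only the detector is covered by the TIPC config in this PR):
+#
+#   python3 test_tipc/onnx_inference/predict_rec.py \
+#       --rec_model_dir=./inference/rec_onnx/model.onnx \
+#       --rec_char_dict_path=./ppocr/utils/ppocr_keys_v1.txt \
+#       --image_dir=./doc/imgs_words/ch/ \
+#       --use_gpu=False
+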
+import os +import sys + +__dir__ = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(__dir__) +sys.path.append(os.path.abspath(os.path.join(__dir__, '../..'))) + +os.environ["FLAGS_allocator_strategy"] = 'auto_growth' + +import cv2 +import numpy as np +import math +import time +import traceback + +import utility +from ppocr.postprocess import build_post_process +from ppocr.utils.logging import get_logger +from ppocr.utils.utility import get_image_file_list, check_and_read_gif + +logger = get_logger() + + +class TextRecognizer(object): + def __init__(self, args): + self.rec_image_shape = [int(v) for v in args.rec_image_shape.split(",")] + self.character_type = args.rec_char_type + self.rec_batch_num = args.rec_batch_num + self.rec_algorithm = args.rec_algorithm + postprocess_params = { + 'name': 'CTCLabelDecode', + "character_type": args.rec_char_type, + "character_dict_path": args.rec_char_dict_path, + "use_space_char": args.use_space_char + } + self.postprocess_op = build_post_process(postprocess_params) + self.predictor, self.input_tensor, self.output_tensors = \ + utility.create_predictor(args, 'rec', logger) + + def resize_norm_img(self, img, max_wh_ratio): + imgC, imgH, imgW = self.rec_image_shape + assert imgC == img.shape[2] + # if self.character_type == "ch": + # imgW = int((32 * max_wh_ratio)) + h, w = img.shape[:2] + ratio = w / float(h) + if math.ceil(imgH * ratio) > imgW: + resized_w = imgW + else: + resized_w = int(math.ceil(imgH * ratio)) + resized_image = cv2.resize(img, (resized_w, imgH)) + resized_image = resized_image.astype('float32') + resized_image = resized_image.transpose((2, 0, 1)) / 255 + resized_image -= 0.5 + resized_image /= 0.5 + padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32) + padding_im[:, :, 0:resized_w] = resized_image + return padding_im + + def __call__(self, img_list): + img_num = len(img_list) + # Calculate the aspect ratio of all text bars + width_list = [] + for img in img_list: + width_list.append(img.shape[1] / float(img.shape[0])) + # Sorting can speed up the recognition process + indices = np.argsort(np.array(width_list)) + + # rec_res = [] + rec_res = [['', 0.0]] * img_num + batch_num = self.rec_batch_num + elapse = 0 + for beg_img_no in range(0, img_num, batch_num): + end_img_no = min(img_num, beg_img_no + batch_num) + norm_img_batch = [] + max_wh_ratio = 0 + for ino in range(beg_img_no, end_img_no): + # h, w = img_list[ino].shape[0:2] + h, w = img_list[indices[ino]].shape[0:2] + wh_ratio = w * 1.0 / h + max_wh_ratio = max(max_wh_ratio, wh_ratio) + for ino in range(beg_img_no, end_img_no): + # norm_img = self.resize_norm_img(img_list[ino], max_wh_ratio) + norm_img = self.resize_norm_img(img_list[indices[ino]], + max_wh_ratio) + norm_img = norm_img[np.newaxis, :] + norm_img_batch.append(norm_img) + norm_img_batch = np.concatenate(norm_img_batch) + norm_img_batch = norm_img_batch.copy() + starttime = time.time() + input_dict = {} + input_dict[self.input_tensor.name] = norm_img_batch + outputs = self.predictor.run(self.output_tensors, input_dict) + preds = outputs[0] + rec_result = self.postprocess_op(preds) + for rno in range(len(rec_result)): + rec_res[indices[beg_img_no + rno]] = rec_result[rno] + elapse += time.time() - starttime + return rec_res, elapse + + +def main(args): + image_file_list = get_image_file_list(args.image_dir) + text_recognizer = TextRecognizer(args) + valid_image_file_list = [] + img_list = [] + for image_file in image_file_list: + img, flag = check_and_read_gif(image_file) + if not flag: + img = 
cv2.imread(image_file)
+        if img is None:
+            logger.info("error in loading image:{}".format(image_file))
+            continue
+        valid_image_file_list.append(image_file)
+        img_list.append(img)
+    try:
+        rec_res, predict_time = text_recognizer(img_list)
+    except Exception:
+        logger.info(traceback.format_exc())
+        logger.info(
+            "ERROR!!!! \n"
+            "Please read the FAQ:https://github.com/PaddlePaddle/PaddleOCR#faq \n"
+            "If your model has tps module: "
+            "TPS does not support variable shape.\n"
+            "Please set --rec_image_shape='3,32,100' and --rec_char_type='en' ")
+        sys.exit()
+    for ino in range(len(img_list)):
+        logger.info("Predicts of {}:{}".format(valid_image_file_list[ino],
+                                               rec_res[ino]))
+    logger.info("Total predict time for {} images, cost: {:.3f}".format(
+        len(img_list), predict_time))
+
+
+if __name__ == "__main__":
+    main(utility.parse_args())
diff --git a/test_tipc/onnx_inference/predict_system.py b/test_tipc/onnx_inference/predict_system.py
new file mode 100644
index 0000000000000000000000000000000000000000..bebb47269f943ed5bb8bacd68212bbebacb01a11
--- /dev/null
+++ b/test_tipc/onnx_inference/predict_system.py
@@ -0,0 +1,190 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import sys
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(__dir__)
+sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
+
+os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
+
+import cv2
+import copy
+import numpy as np
+import time
+from PIL import Image
+import utility
+import predict_rec
+import predict_det
+import predict_cls
+from ppocr.utils.utility import get_image_file_list, check_and_read_gif
+from ppocr.utils.logging import get_logger
+from utility import draw_ocr_box_txt
+
+logger = get_logger()
+
+
+class TextSystem(object):
+    def __init__(self, args):
+        self.text_detector = predict_det.TextDetector(args)
+        self.text_recognizer = predict_rec.TextRecognizer(args)
+        self.use_angle_cls = args.use_angle_cls
+        self.drop_score = args.drop_score
+        if self.use_angle_cls:
+            self.text_classifier = predict_cls.TextClassifier(args)
+
+    def get_rotate_crop_image(self, img, points):
+        img_crop_width = int(
+            max(
+                np.linalg.norm(points[0] - points[1]),
+                np.linalg.norm(points[2] - points[3])))
+        img_crop_height = int(
+            max(
+                np.linalg.norm(points[0] - points[3]),
+                np.linalg.norm(points[1] - points[2])))
+        pts_std = np.float32([[0, 0], [img_crop_width, 0],
+                              [img_crop_width, img_crop_height],
+                              [0, img_crop_height]])
+        M = cv2.getPerspectiveTransform(points, pts_std)
+        dst_img = cv2.warpPerspective(
+            img,
+            M, (img_crop_width, img_crop_height),
+            borderMode=cv2.BORDER_REPLICATE,
+            flags=cv2.INTER_CUBIC)
+        dst_img_height, dst_img_width = dst_img.shape[0:2]
+        if dst_img_height * 1.0 / dst_img_width >= 1.5:
+            dst_img = np.rot90(dst_img)
+        return dst_img
+
+    def print_draw_crop_rec_res(self, img_crop_list, rec_res):
+        bbox_num = len(img_crop_list)
+        for bno in range(bbox_num):
+            cv2.imwrite("./output/img_crop_%d.jpg" % bno, img_crop_list[bno])
+            logger.info("{}: {}".format(bno, rec_res[bno]))
+
+    def __call__(self, img):
+        ori_im = img.copy()
+        dt_boxes, elapse = self.text_detector(img)
+        logger.info("dt_boxes num : {}, elapse : {}".format(
+            len(dt_boxes), elapse))
+        if dt_boxes is None:
+            return None, None
+        img_crop_list = []
+
+        dt_boxes = sorted_boxes(dt_boxes)
+
+        for bno in range(len(dt_boxes)):
+            tmp_box = copy.deepcopy(dt_boxes[bno])
+            img_crop = self.get_rotate_crop_image(ori_im, tmp_box)
+            img_crop_list.append(img_crop)
+        if self.use_angle_cls:
+            img_crop_list, angle_list, elapse = self.text_classifier(
+                img_crop_list)
+            logger.info("cls num : {}, elapse : {}".format(
+                len(img_crop_list), elapse))
+
+        rec_res, elapse = self.text_recognizer(img_crop_list)
+        logger.info("rec_res num : {}, elapse : {}".format(
+            len(rec_res), elapse))
+        # self.print_draw_crop_rec_res(img_crop_list, rec_res)
+        filter_boxes, filter_rec_res = [], []
+        for box, rec_result in zip(dt_boxes, rec_res):
+            text, score = rec_result
+            if score >= self.drop_score:
+                filter_boxes.append(box)
+                filter_rec_res.append(rec_result)
+        return filter_boxes, filter_rec_res
+
+
+def sorted_boxes(dt_boxes):
+    """
+    Sort text boxes in order from top to bottom, left to right
+    args:
+        dt_boxes(array): detected text boxes with shape (N, 4, 2)
+    return:
+        sorted boxes(list): sorted text boxes, each with shape (4, 2)
+    """
+    num_boxes = dt_boxes.shape[0]
+    sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
+    _boxes = list(sorted_boxes)
+
+    for i in range(num_boxes - 1):
+        if abs(_boxes[i + 1][0][1] - _boxes[i][0][1]) < 10 and \
+                (_boxes[i + 1][0][0] < _boxes[i][0][0]):
+            tmp = _boxes[i]
+            _boxes[i] = _boxes[i + 1]
+            _boxes[i + 1] = tmp
+    return _boxes
+
+
+def main(args):
+    image_file_list = get_image_file_list(args.image_dir)
+    text_sys = TextSystem(args)
+    is_visualize = True
+    font_path = args.vis_font_path
+    drop_score = args.drop_score
+    for image_file in image_file_list:
+        img, flag = check_and_read_gif(image_file)
+        if not flag:
+            img = cv2.imread(image_file)
+        if img is None:
+            logger.info("error in loading image:{}".format(image_file))
+            continue
+        starttime = time.time()
+        dt_boxes, rec_res = text_sys(img)
+        elapse = time.time() - starttime
+        logger.info("Predict time of %s: %.3fs" % (image_file, elapse))
+
+        for text, score in rec_res:
+            logger.info("{}, {:.3f}".format(text, score))
+
+        if is_visualize:
+            image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
+            boxes = dt_boxes
+            txts = [rec_res[i][0] for i in range(len(rec_res))]
+            scores = [rec_res[i][1] for i in range(len(rec_res))]
+
+            draw_img = draw_ocr_box_txt(
+                image,
+                boxes,
+                txts,
+                scores,
+                drop_score=drop_score,
+                font_path=font_path)
+            draw_img_save = "./results/"
+            if not os.path.exists(draw_img_save):
+                os.makedirs(draw_img_save)
+            cv2.imwrite(
+                os.path.join(draw_img_save, os.path.basename(image_file)),
+                draw_img[:, :, ::-1])
+            logger.info("The visualized image saved in {}".format(
+                os.path.join(draw_img_save, os.path.basename(image_file))))
+
+
+if __name__ == "__main__":
+    main(utility.parse_args())
\ No newline at end of file
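Usage sketch (not part of the patch): predict_system.py chains the detector, the optional angle classifier, and the recognizer over ONNX Runtime sessions. A hypothetical end-to-end run, assuming det and rec models have both been exported to ONNX (both paths are assumptions):

    python3 test_tipc/onnx_inference/predict_system.py \
        --use_gpu=False --use_angle_cls=False \
        --det_model_dir=./inference/det_mobile_onnx/model.onnx \
        --rec_model_dir=./inference/rec_onnx/model.onnx \
        --image_dir=./inference/ch_det_data_50/all-sum-510/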
diff --git a/test_tipc/onnx_inference/utility.py b/test_tipc/onnx_inference/utility.py
new file mode 100644
index 0000000000000000000000000000000000000000..c770d8db82c6baabecabafe91b0167100b848eea
--- /dev/null
+++ b/test_tipc/onnx_inference/utility.py
@@ -0,0 +1,375 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import os
+import sys
+import cv2
+import numpy as np
+import json
+from PIL import Image, ImageDraw, ImageFont
+import math
+import onnxruntime as ort
+
+
+def parse_args():
+    def str2bool(v):
+        return v.lower() in ("true", "t", "1")
+
+    parser = argparse.ArgumentParser()
+    # params for prediction engine
+    parser.add_argument("--use_gpu", type=str2bool, default=True)
+    parser.add_argument("--ir_optim", type=str2bool, default=True)
+    parser.add_argument("--use_tensorrt", type=str2bool, default=False)
+    parser.add_argument("--use_fp16", type=str2bool, default=False)
+    parser.add_argument("--gpu_mem", type=int, default=500)
+
+    # params for text detector
+    parser.add_argument("--image_dir", type=str)
+    parser.add_argument("--det_algorithm", type=str, default='DB')
+    parser.add_argument("--det_model_dir", type=str)
+    parser.add_argument("--det_limit_side_len", type=float, default=960)
+    parser.add_argument("--det_limit_type", type=str, default='max')
+
+    # DB params
+    parser.add_argument("--det_db_thresh", type=float, default=0.3)
+    parser.add_argument("--det_db_box_thresh", type=float, default=0.5)
+    parser.add_argument("--det_db_unclip_ratio", type=float, default=1.6)
+    parser.add_argument("--max_batch_size", type=int, default=10)
+    # EAST params
+    parser.add_argument("--det_east_score_thresh", type=float, default=0.8)
+    parser.add_argument("--det_east_cover_thresh", type=float, default=0.1)
+    parser.add_argument("--det_east_nms_thresh", type=float, default=0.2)
+
+    # SAST params
+    parser.add_argument("--det_sast_score_thresh", type=float, default=0.5)
+    parser.add_argument("--det_sast_nms_thresh", type=float, default=0.2)
+    parser.add_argument("--det_sast_polygon", type=str2bool, default=False)
+
+    # params for text recognizer
+    parser.add_argument("--rec_algorithm", type=str, default='CRNN')
+    parser.add_argument("--rec_model_dir", type=str)
+    parser.add_argument("--rec_image_shape", type=str, default="3, 32, 100")
+    parser.add_argument("--rec_char_type", type=str, default='ch')
+    parser.add_argument("--rec_batch_num", type=int, default=6)
+    parser.add_argument("--max_text_length", type=int, default=25)
+    parser.add_argument(
+        "--rec_char_dict_path",
+        type=str,
+        default="./ppocr/utils/ppocr_keys_v1.txt")
+    parser.add_argument("--use_space_char", type=str2bool, default=True)
+    parser.add_argument(
+        "--vis_font_path", type=str, default="./doc/simfang.ttf")
+    parser.add_argument("--drop_score", type=float, default=0.5)
+
+    # params for text classifier
+    parser.add_argument("--use_angle_cls", type=str2bool, default=False)
+    parser.add_argument("--cls_model_dir", type=str)
parser.add_argument("--cls_image_shape", type=str, default="3, 48, 192") + parser.add_argument("--label_list", type=list, default=['0', '180']) + parser.add_argument("--cls_batch_num", type=int, default=6) + parser.add_argument("--cls_thresh", type=float, default=0.9) + + parser.add_argument("--enable_mkldnn", type=str2bool, default=False) + parser.add_argument("--use_pdserving", type=str2bool, default=False) + + return parser.parse_args() + + +def create_predictor(args, mode, logger): + if mode == "det": + model_dir = args.det_model_dir + elif mode == 'cls': + model_dir = args.cls_model_dir + else: + model_dir = args.rec_model_dir + + if model_dir is None: + logger.info("not find {} model file path {}".format(mode, model_dir)) + sys.exit(0) + model_file_path = model_dir + + if not os.path.exists(model_file_path): + logger.info("not find model file path {}".format(model_file_path)) + sys.exit(0) + + sess = ort.InferenceSession(model_file_path) + + return sess, sess.get_inputs()[0], None + + +def draw_text_det_res(dt_boxes, img_path): + src_im = cv2.imread(img_path) + for box in dt_boxes: + box = np.array(box).astype(np.int32).reshape(-1, 2) + cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2) + return src_im + + +def resize_img(img, input_size=600): + """ + resize img and limit the longest side of the image to input_size + """ + img = np.array(img) + im_shape = img.shape + im_size_max = np.max(im_shape[0:2]) + im_scale = float(input_size) / float(im_size_max) + img = cv2.resize(img, None, None, fx=im_scale, fy=im_scale) + return img + + +def padding_img(img, input_size=[640, 640]): + padding_im_h = input_size[0] + padding_im_w = input_size[1] + im_c = img.shape[2] + im_h = img.shape[0] + im_w = img.shape[1] + padding_im = np.zeros( + (padding_im_h, padding_im_w, im_c), dtype=np.float32) + 255 + padding_im[:im_h, :im_w, :] = img + return padding_im + + +def draw_ocr(image, + boxes, + txts=None, + scores=None, + drop_score=0.5, + font_path="./doc/simfang.ttf"): + """ + Visualize the results of OCR detection and recognition + args: + image(Image|array): RGB image + boxes(list): boxes with shape(N, 4, 2) + txts(list): the texts + scores(list): txxs corresponding scores + drop_score(float): only scores greater than drop_threshold will be visualized + font_path: the path of font which is used to draw text + return(array): + the visualized img + """ + if scores is None: + scores = [1] * len(boxes) + box_num = len(boxes) + for i in range(box_num): + if scores is not None and (scores[i] < drop_score or + math.isnan(scores[i])): + continue + box = np.reshape(np.array(boxes[i]), [-1, 1, 2]).astype(np.int64) + image = cv2.polylines(np.array(image), [box], True, (255, 0, 0), 2) + if txts is not None: + img = np.array(resize_img(image, input_size=600)) + txt_img = text_visual( + txts, + scores, + img_h=img.shape[0], + img_w=600, + threshold=drop_score, + font_path=font_path) + img = np.concatenate([np.array(img), np.array(txt_img)], axis=1) + return img + return image + + +def draw_ocr_box_txt(image, + boxes, + txts, + scores=None, + drop_score=0.5, + font_path="./doc/fonts/simfang.ttf"): + h, w = image.height, image.width + img_left = image.copy() + img_right = Image.new('RGB', (w, h), (255, 255, 255)) + + import random + + random.seed(0) + draw_left = ImageDraw.Draw(img_left) + draw_right = ImageDraw.Draw(img_right) + for idx, (box, txt) in enumerate(zip(boxes, txts)): + if scores is not None and scores[idx] < drop_score: + continue + color = (random.randint(0, 255), 
+                 random.randint(0, 255))
+        draw_left.polygon(box, fill=color)
+        draw_right.polygon(
+            [
+                box[0][0], box[0][1], box[1][0], box[1][1], box[2][0],
+                box[2][1], box[3][0], box[3][1]
+            ],
+            outline=color)
+        box_height = math.sqrt((box[0][0] - box[3][0])**2 +
+                               (box[0][1] - box[3][1])**2)
+        box_width = math.sqrt((box[0][0] - box[1][0])**2 +
+                              (box[0][1] - box[1][1])**2)
+        if box_height > 2 * box_width:
+            font_size = max(int(box_width * 0.9), 10)
+            font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
+            cur_y = box[0][1]
+            for c in txt:
+                char_size = font.getsize(c)
+                draw_right.text(
+                    (box[0][0] + 3, cur_y), c, fill=(0, 0, 0), font=font)
+                cur_y += char_size[1]
+        else:
+            font_size = max(int(box_height * 0.8), 10)
+            font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
+            draw_right.text(
+                [box[0][0], box[0][1]], txt, fill=(0, 0, 0), font=font)
+    img_left = Image.blend(image, img_left, 0.5)
+    img_show = Image.new('RGB', (w * 2, h), (255, 255, 255))
+    img_show.paste(img_left, (0, 0, w, h))
+    img_show.paste(img_right, (w, 0, w * 2, h))
+    return np.array(img_show)
+
+
+def str_count(s):
+    """
+    Count the display length of a string in units of Chinese characters;
+    a single English letter, digit, or space counts as half a Chinese
+    character.
+    args:
+        s(string): the input string
+    return(int):
+        the equivalent length in Chinese characters
+    """
+    import string
+    count_zh = count_pu = 0
+    s_len = len(s)
+    en_dg_count = 0
+    for c in s:
+        if c in string.ascii_letters or c.isdigit() or c.isspace():
+            en_dg_count += 1
+        elif c.isalpha():
+            count_zh += 1
+        else:
+            count_pu += 1
+    return s_len - math.ceil(en_dg_count / 2)
+
+
+def text_visual(texts,
+                scores,
+                img_h=400,
+                img_w=600,
+                threshold=0.,
+                font_path="./doc/simfang.ttf"):
+    """
+    create new blank img and draw txt on it
+    args:
+        texts(list): the texts to draw
+        scores(list|None): corresponding score of each txt
+        img_h(int): the height of blank img
+        img_w(int): the width of blank img
+        font_path: the path of font which is used to draw text
+    return(array):
+        the visualized image
+    """
+    if scores is not None:
+        assert len(texts) == len(
+            scores), "The number of txts and corresponding scores must match"
+
+    def create_blank_img():
+        blank_img = np.ones(shape=[img_h, img_w], dtype=np.int8) * 255
+        blank_img[:, img_w - 1:] = 0
+        blank_img = Image.fromarray(blank_img).convert("RGB")
+        draw_txt = ImageDraw.Draw(blank_img)
+        return blank_img, draw_txt
+
+    blank_img, draw_txt = create_blank_img()
+
+    font_size = 20
+    txt_color = (0, 0, 0)
+    font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
+
+    gap = font_size + 5
+    txt_img_list = []
+    count, index = 1, 0
+    for idx, txt in enumerate(texts):
+        index += 1
+        if scores[idx] < threshold or math.isnan(scores[idx]):
+            index -= 1
+            continue
+        first_line = True
+        while str_count(txt) >= img_w // font_size - 4:
+            tmp = txt
+            txt = tmp[:img_w // font_size - 4]
+            if first_line:
+                new_txt = str(index) + ': ' + txt
+                first_line = False
+            else:
+                new_txt = '    ' + txt
+            draw_txt.text((0, gap * count), new_txt, txt_color, font=font)
+            txt = tmp[img_w // font_size - 4:]
+            if count >= img_h // gap - 1:
+                txt_img_list.append(np.array(blank_img))
+                blank_img, draw_txt = create_blank_img()
+                count = 0
+            count += 1
+        if first_line:
+            new_txt = str(index) + ': ' + txt + '  ' + '%.3f' % (scores[idx])
+        else:
+            new_txt = "  " + txt + "  " + '%.3f' % (scores[idx])
+        draw_txt.text((0, gap * count), new_txt, txt_color, font=font)
+        # whether add new blank img or not
+        if count >= img_h // gap - 1 and idx + 1 < len(texts):
+            txt_img_list.append(np.array(blank_img))
+            blank_img, draw_txt = create_blank_img()
+            count = 0
+        count += 1
+    txt_img_list.append(np.array(blank_img))
+    if len(txt_img_list) == 1:
+        blank_img = np.array(txt_img_list[0])
+    else:
+        blank_img = np.concatenate(txt_img_list, axis=1)
+    return np.array(blank_img)
+
+
+def base64_to_cv2(b64str):
+    import base64
+    data = base64.b64decode(b64str.encode('utf8'))
+    data = np.frombuffer(data, np.uint8)
+    data = cv2.imdecode(data, cv2.IMREAD_COLOR)
+    return data
+
+
+def draw_boxes(image, boxes, scores=None, drop_score=0.5):
+    if scores is None:
+        scores = [1] * len(boxes)
+    for (box, score) in zip(boxes, scores):
+        if score < drop_score:
+            continue
+        box = np.reshape(np.array(box), [-1, 1, 2]).astype(np.int64)
+        image = cv2.polylines(np.array(image), [box], True, (255, 0, 0), 2)
+    return image
+
+
+if __name__ == '__main__':
+    test_img = "./doc/test_v2"
+    predict_txt = "./doc/predict.txt"
+    f = open(predict_txt, 'r')
+    data = f.readlines()
+    img_path, anno = data[0].strip().split('\t')
+    img_name = os.path.basename(img_path)
+    img_path = os.path.join(test_img, img_name)
+    image = Image.open(img_path)
+
+    data = json.loads(anno)
+    boxes, txts, scores = [], [], []
+    for dic in data:
+        boxes.append(dic['points'])
+        txts.append(dic['transcription'])
+        scores.append(round(dic['scores'], 3))
+
+    new_img = draw_ocr(image, boxes, txts, scores)
+
+    cv2.imwrite(img_name, new_img)
diff --git a/test_tipc/test_paddle2onnx.sh b/test_tipc/test_paddle2onnx.sh
new file mode 100644
index 0000000000000000000000000000000000000000..7dbced8eecea04a7efcdbfe5476f4c85b44d4140
--- /dev/null
+++ b/test_tipc/test_paddle2onnx.sh
@@ -0,0 +1,72 @@
+#!/bin/bash
+source test_tipc/common_func.sh
+
+FILENAME=$1
+
+dataline=$(cat ${FILENAME})
+lines=(${dataline})
+# common params
+model_name=$(func_parser_value "${lines[1]}")
+python=$(func_parser_value "${lines[2]}")
+
+
+# parser params
+dataline=$(awk 'NR==111, NR==122{print}' $FILENAME)
+IFS=$'\n'
+lines=(${dataline})
+
+# parser paddle2onnx
+paddle2onnx_cmd=$(func_parser_value "${lines[1]}")
+infer_model_dir_key=$(func_parser_key "${lines[2]}")
+infer_model_dir_value=$(func_parser_value "${lines[2]}")
+model_filename_key=$(func_parser_key "${lines[3]}")
+model_filename_value=$(func_parser_value "${lines[3]}")
+params_filename_key=$(func_parser_key "${lines[4]}")
+params_filename_value=$(func_parser_value "${lines[4]}")
+save_file_key=$(func_parser_key "${lines[5]}")
+save_file_value=$(func_parser_value "${lines[5]}")
+opset_version_key=$(func_parser_key "${lines[6]}")
+opset_version_value=$(func_parser_value "${lines[6]}")
+enable_onnx_checker_key=$(func_parser_key "${lines[7]}")
+enable_onnx_checker_value=$(func_parser_value "${lines[7]}")
+# parser onnx inference
+inference_py=$(func_parser_value "${lines[8]}")
+use_gpu_key=$(func_parser_key "${lines[9]}")
+use_gpu_value=$(func_parser_value "${lines[9]}")
+det_model_key=$(func_parser_key "${lines[10]}")
+image_dir_key=$(func_parser_key "${lines[11]}")
+image_dir_value=$(func_parser_value "${lines[11]}")
+
+
+LOG_PATH="./test_tipc/output"
+mkdir -p ${LOG_PATH}
+status_log="${LOG_PATH}/results_paddle2onnx.log"
+
+
+function func_paddle2onnx(){
+    IFS='|'
+    _script=$1
+
+    # paddle2onnx
+    set_dirname=$(func_set_params "${infer_model_dir_key}" "${infer_model_dir_value}")
+    set_model_filename=$(func_set_params "${model_filename_key}" "${model_filename_value}")
+    set_params_filename=$(func_set_params "${params_filename_key}" "${params_filename_value}")
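+    # func_set_params (from common_func.sh) is assumed to expand each
+    # "key:value" config pair into a "key=value" command-line argument.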
+    set_save_model=$(func_set_params "${save_file_key}" "${save_file_value}")
+    set_opset_version=$(func_set_params "${opset_version_key}" "${opset_version_value}")
+    set_enable_onnx_checker=$(func_set_params "${enable_onnx_checker_key}" "${enable_onnx_checker_value}")
+    trans_model_cmd="${paddle2onnx_cmd} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_save_model} ${set_opset_version} ${set_enable_onnx_checker}"
+    eval $trans_model_cmd
+    # python inference
+    set_gpu=$(func_set_params "${use_gpu_key}" "${use_gpu_value}")
+    set_model_dir=$(func_set_params "${det_model_key}" "${save_file_value}")
+    set_img_dir=$(func_set_params "${image_dir_key}" "${image_dir_value}")
+    infer_model_cmd="${python} ${inference_py} ${set_gpu} ${set_img_dir} ${set_model_dir}"
+    eval $infer_model_cmd
+}
+
+
+echo "################### run test ###################"
+
+export Count=0
+IFS="|"
+func_paddle2onnx
\ No newline at end of file
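
Usage sketch (not part of the patch): the test is driven through the TIPC entry point with the config modified above:

    bash test_tipc/test_paddle2onnx.sh test_tipc/configs/ppocr_det_mobile_params.txt

Given the paddle2onnx section of that config (lines 111-122, the range the awk call extracts), trans_model_cmd and infer_model_cmd should evaluate to roughly:

    paddle2onnx --model_dir=./inference/ch_ppocr_mobile_v2.0_det_infer/ \
        --model_filename=inference.pdmodel --params_filename=inference.pdiparams \
        --save_file=./inference/det_mobile_onnx/model.onnx \
        --opset_version=10 --enable_onnx_checker=True
    python3 test_tipc/onnx_inference/predict_det.py --use_gpu=False \
        --image_dir=./inference/ch_det_data_50/all-sum-510/ \
        --det_model_dir=./inference/det_mobile_onnx/model.onnx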