From a6ceff1c0b60609efa591e490d7ad00befefd2ec Mon Sep 17 00:00:00 2001
From: Steffy-zxf <48793257+Steffy-zxf@users.noreply.github.com>
Date: Mon, 21 Sep 2020 09:59:48 +0800
Subject: [PATCH] Update ocr (#906)

---
 .../chinese_ocr_db_crnn_mobile/README.md   |  10 +-
 .../chinese_ocr_db_crnn_mobile/character.py |  48 +++-
 .../chinese_ocr_db_crnn_mobile/module.py    | 240 ++++++++++++++----
 .../chinese_ocr_db_crnn_mobile/utils.py     |   4 +-
 .../chinese_ocr_db_crnn_server/README.md    |   4 +
 .../README.md                               |   2 +-
 .../module.py                               |  50 ++--
 .../processor.py                            |  71 ++++--
 .../README.md                               |   4 +
 9 files changed, 340 insertions(+), 93 deletions(-)

diff --git a/hub_module/modules/image/text_recognition/chinese_ocr_db_crnn_mobile/README.md b/hub_module/modules/image/text_recognition/chinese_ocr_db_crnn_mobile/README.md
index b7dba8ce..3c0bd1d6 100644
--- a/hub_module/modules/image/text_recognition/chinese_ocr_db_crnn_mobile/README.md
+++ b/hub_module/modules/image/text_recognition/chinese_ocr_db_crnn_mobile/README.md
@@ -1,6 +1,6 @@
 ## Overview
 
-The chinese_ocr_db_crnn_mobile Module recognizes Chinese characters in images. Starting from the text boxes detected by the [chinese_text_detection_db_mobile Module](https://www.paddlepaddle.org.cn/hubdetail?name=chinese_text_detection_db_mobile&en_category=TextRecognition), it recognizes the Chinese text inside those boxes. The recognition algorithm is CRNN (Convolutional Recurrent Neural Network), a combination of DCNN and RNN built specifically for recognizing sequence-like objects in images. Used together with CTC loss, it learns text recognition directly from word- or line-level annotations and needs no detailed character-level labels. This Module is an ultra-lightweight Chinese OCR model that supports direct prediction.
+The chinese_ocr_db_crnn_mobile Module recognizes Chinese characters in images. Starting from the text boxes detected by the [chinese_text_detection_db_mobile Module](https://www.paddlepaddle.org.cn/hubdetail?name=chinese_text_detection_db_mobile&en_category=TextRecognition), it recognizes the Chinese text inside those boxes. The detected text boxes are then angle-classified. The final recognition algorithm is CRNN (Convolutional Recurrent Neural Network), a combination of DCNN and RNN built specifically for recognizing sequence-like objects in images. Used together with CTC loss, it learns text recognition directly from word- or line-level annotations and needs no detailed character-level labels. This Module is an ultra-lightweight Chinese OCR model that supports direct prediction.

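For context, the three-stage flow the paragraph above describes can be exercised end to end. A minimal sketch, assuming the module's `recognize_text` API and its documented result fields (`text`, `confidence`, `text_box_position`); the image path is a placeholder:

```python
import cv2
import paddlehub as hub

# Detection -> angle classification -> recognition, in a single call.
ocr = hub.Module(name="chinese_ocr_db_crnn_mobile", enable_mkldnn=True)
results = ocr.recognize_text(images=[cv2.imread('/PATH/TO/IMAGE')])

for res in results:
    for item in res['data']:
        print(item['text'], item['confidence'], item['text_box_position'])
```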
@@ -142,3 +142,11 @@ pyclipper
 * 1.0.1
 
   Fixed the failure of model invocation via the online service
+
+* 1.0.2
+
+  Support MKL-DNN acceleration for CPU computation
+
+* 1.1.0
+
+  Recognize text in images with an ultra-lightweight three-stage model (text detection - angle classification - text recognition).
diff --git a/hub_module/modules/image/text_recognition/chinese_ocr_db_crnn_mobile/character.py b/hub_module/modules/image/text_recognition/chinese_ocr_db_crnn_mobile/character.py
index 8e5f1021..ff88b52e 100644
--- a/hub_module/modules/image/text_recognition/chinese_ocr_db_crnn_mobile/character.py
+++ b/hub_module/modules/image/text_recognition/chinese_ocr_db_crnn_mobile/character.py
@@ -22,17 +22,23 @@ class CharacterOps(object):
     def __init__(self, config):
         self.character_type = config['character_type']
         self.loss_type = config['loss_type']
+        self.max_text_len = config['max_text_length']
         if self.character_type == "en":
             self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
             dict_character = list(self.character_str)
         elif self.character_type == "ch":
             character_dict_path = config['character_dict_path']
+            add_space = False
+            if 'use_space_char' in config:
+                add_space = config['use_space_char']
             self.character_str = ""
             with open(character_dict_path, "rb") as fin:
                 lines = fin.readlines()
                 for line in lines:
-                    line = line.decode('utf-8').strip("\n")
+                    line = line.decode('utf-8').strip("\n").strip("\r\n")
                     self.character_str += line
+            if add_space:
+                self.character_str += " "
             dict_character = list(self.character_str)
         elif self.character_type == "en_sensitive":
             # same with ASTER setting (use 94 char).
@@ -46,6 +52,8 @@ class CharacterOps(object):
         self.end_str = "eos"
         if self.loss_type == "attention":
             dict_character = [self.beg_str, self.end_str] + dict_character
+        elif self.loss_type == "srn":
+            dict_character = dict_character + [self.beg_str, self.end_str]
         self.dict = {}
         for i, char in enumerate(dict_character):
             self.dict[char] = i
@@ -90,7 +98,7 @@ class CharacterOps(object):
                 if is_remove_duplicate:
                     if idx > 0 and text_index[idx - 1] == text_index[idx]:
                         continue
-                char_list.append(self.character[text_index[idx]])
+                char_list.append(self.character[int(text_index[idx])])
         text = ''.join(char_list)
         return text
@@ -139,6 +147,42 @@ def cal_predicts_accuracy(char_ops,
     return acc, acc_num, img_num
 
 
+def cal_predicts_accuracy_srn(char_ops,
+                              preds,
+                              labels,
+                              max_text_len,
+                              is_debug=False):
+    acc_num = 0
+    img_num = 0
+
+    char_num = char_ops.get_char_num()
+
+    total_len = preds.shape[0]
+    img_num = int(total_len / max_text_len)
+    for i in range(img_num):
+        cur_label = []
+        cur_pred = []
+        for j in range(max_text_len):
+            if labels[j + i * max_text_len] != int(char_num - 1):  #0
+                cur_label.append(labels[j + i * max_text_len][0])
+            else:
+                break
+
+        for j in range(max_text_len + 1):
+            if j < len(cur_label) and preds[
+                    j + i * max_text_len][0] != cur_label[j]:
+                break
+            elif j == len(cur_label) and j == max_text_len:
+                acc_num += 1
+                break
+            elif j == len(cur_label) and preds[j + i * max_text_len][0] == int(
+                    char_num - 1):
+                acc_num += 1
+                break
+    acc = acc_num * 1.0 / img_num
+    return acc, acc_num, img_num
+
+
 def convert_rec_attention_infer_res(preds):
     img_num = preds.shape[0]
     target_lod = [0]
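For reference, the duplicate-removal branch patched above implements the usual CTC collapsing rule: merge consecutive repeats of the same raw index, then drop blanks. A self-contained sketch with a toy character table (the blank position here is an assumption of this example; the module itself treats the last index as blank):

```python
import numpy as np

def ctc_collapse(indices, charset, blank, remove_duplicate=True):
    """Collapse a raw CTC index sequence into text."""
    chars = []
    for i, idx in enumerate(indices):
        # A repeat of the previous raw index is the same emission: skip it.
        if remove_duplicate and i > 0 and indices[i - 1] == idx:
            continue
        if idx == blank:
            continue  # blanks carry no character
        chars.append(charset[int(idx)])
    return ''.join(chars)

charset = ['a', 'b', 'c', '<blank>']
seq = np.array([0, 0, 3, 1, 1, 3, 0])
print(ctc_collapse(seq, charset, blank=3))  # -> "aba"
```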
diff --git a/hub_module/modules/image/text_recognition/chinese_ocr_db_crnn_mobile/module.py b/hub_module/modules/image/text_recognition/chinese_ocr_db_crnn_mobile/module.py
index a3e4bd5b..cd519726 100644
--- a/hub_module/modules/image/text_recognition/chinese_ocr_db_crnn_mobile/module.py
+++ b/hub_module/modules/image/text_recognition/chinese_ocr_db_crnn_mobile/module.py
@@ -1,8 +1,4 @@
 # -*- coding:utf-8 -*-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
 import argparse
 import ast
 import copy
@@ -25,9 +21,10 @@ from chinese_ocr_db_crnn_mobile.utils import base64_to_cv2, draw_ocr, get_image_
 
 @moduleinfo(
     name="chinese_ocr_db_crnn_mobile",
-    version="1.0.4",
+    version="1.1.0",
     summary=
-    "The module can recognize the chinese texts in an image. Firstly, it will detect the text box positions based on the differentiable_binarization_chn module. Then it recognizes the chinese texts. ",
+    "The module can recognize the chinese texts in an image. Firstly, it will detect the text box positions \
+    based on the differentiable_binarization_chn module. Then it classifies the text angle and recognizes the chinese texts. ",
     author="paddle-dev",
     author_email="paddle-dev@baidu.com",
     type="cv/text_recognition")
@@ -41,23 +38,31 @@ class ChineseOCRDBCRNN(hub.Module):
         char_ops_params = {
             'character_type': 'ch',
             'character_dict_path': self.character_dict_path,
-            'loss_type': 'ctc'
+            'loss_type': 'ctc',
+            'max_text_length': 25,
+            'use_space_char': True
         }
         self.char_ops = CharacterOps(char_ops_params)
         self.rec_image_shape = [3, 32, 320]
         self._text_detector_module = text_detector_module
         self.font_file = os.path.join(self.directory, 'assets', 'simfang.ttf')
-        self.pretrained_model_path = os.path.join(self.directory,
-                                                  'inference_model')
         self.enable_mkldnn = enable_mkldnn
-        self._set_config()
 
-    def _set_config(self):
+        self.rec_pretrained_model_path = os.path.join(
+            self.directory, 'inference_model', 'character_rec')
+        self.cls_pretrained_model_path = os.path.join(
+            self.directory, 'inference_model', 'angle_cls')
+        self.rec_predictor, self.rec_input_tensor, self.rec_output_tensors = self._set_config(
+            self.rec_pretrained_model_path)
+        self.cls_predictor, self.cls_input_tensor, self.cls_output_tensors = self._set_config(
+            self.cls_pretrained_model_path)
+
+    def _set_config(self, pretrained_model_path):
         """
-        predictor config setting
+        predictor config path
         """
-        model_file_path = os.path.join(self.pretrained_model_path, 'model')
-        params_file_path = os.path.join(self.pretrained_model_path, 'params')
+        model_file_path = os.path.join(pretrained_model_path, 'model')
+        params_file_path = os.path.join(pretrained_model_path, 'params')
         config = AnalysisConfig(model_file_path, params_file_path)
 
         try:
@@ -72,21 +77,25 @@ class ChineseOCRDBCRNN(hub.Module):
         else:
             config.disable_gpu()
             if self.enable_mkldnn:
+                # cache 10 different shapes for mkldnn to avoid memory leak
+                config.set_mkldnn_cache_capacity(10)
                 config.enable_mkldnn()
 
         config.disable_glog_info()
-
-        # use zero copy
         config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
         config.switch_use_feed_fetch_ops(False)
-        self.predictor = create_paddle_predictor(config)
-        input_names = self.predictor.get_input_names()
-        self.input_tensor = self.predictor.get_input_tensor(input_names[0])
-        output_names = self.predictor.get_output_names()
-        self.output_tensors = []
+
+        predictor = create_paddle_predictor(config)
+
+        input_names = predictor.get_input_names()
+        input_tensor = predictor.get_input_tensor(input_names[0])
+        output_names = predictor.get_output_names()
+        output_tensors = []
         for output_name in output_names:
-            output_tensor = self.predictor.get_output_tensor(output_name)
-            self.output_tensors.append(output_tensor)
+            output_tensor = predictor.get_output_tensor(output_name)
+            output_tensors.append(output_tensor)
+
+        return predictor, input_tensor, output_tensors
 
     @property
     def text_detector_module(self):
@@ -97,7 +106,7 @@ class ChineseOCRDBCRNN(hub.Module):
             self._text_detector_module = hub.Module(
                 name='chinese_text_detection_db_mobile',
                 enable_mkldnn=self.enable_mkldnn,
-                version='1.0.2')
+                version='1.0.3')
         return self._text_detector_module
 
     def read_images(self, paths=[]):
@@ -113,6 +122,7 @@ class ChineseOCRDBCRNN(hub.Module):
         return images
 
     def get_rotate_crop_image(self, img, points):
+        '''
         img_height, img_width = img.shape[0:2]
         left = int(np.min(points[:, 0]))
         right = int(np.max(points[:, 0]))
@@ -121,23 +131,51 @@ class ChineseOCRDBCRNN(hub.Module):
         img_crop = img[top:bottom, left:right, :].copy()
         points[:, 0] = points[:, 0] - left
         points[:, 1] = points[:, 1] - top
-        img_crop_width = int(np.linalg.norm(points[0] - points[1]))
-        img_crop_height = int(np.linalg.norm(points[0] - points[3]))
-        pts_std = np.float32([[0, 0], [img_crop_width, 0],\
-            [img_crop_width, img_crop_height], [0, img_crop_height]])
+        '''
+        img_crop_width = int(
+            max(
+                np.linalg.norm(points[0] - points[1]),
+                np.linalg.norm(points[2] - points[3])))
+        img_crop_height = int(
+            max(
+                np.linalg.norm(points[0] - points[3]),
+                np.linalg.norm(points[1] - points[2])))
+        pts_std = np.float32([[0, 0], [img_crop_width, 0],
+                              [img_crop_width, img_crop_height],
+                              [0, img_crop_height]])
         M = cv2.getPerspectiveTransform(points, pts_std)
         dst_img = cv2.warpPerspective(
-            img_crop,
+            img,
             M, (img_crop_width, img_crop_height),
-            borderMode=cv2.BORDER_REPLICATE)
+            borderMode=cv2.BORDER_REPLICATE,
+            flags=cv2.INTER_CUBIC)
         dst_img_height, dst_img_width = dst_img.shape[0:2]
         if dst_img_height * 1.0 / dst_img_width >= 1.5:
             dst_img = np.rot90(dst_img)
         return dst_img
 
-    def resize_norm_img(self, img, max_wh_ratio):
+    def resize_norm_img_rec(self, img, max_wh_ratio):
         imgC, imgH, imgW = self.rec_image_shape
-        imgW = int(32 * max_wh_ratio)
+        assert imgC == img.shape[2]
+        imgW = int((32 * max_wh_ratio))
+        h, w = img.shape[:2]
+        ratio = w / float(h)
+        if math.ceil(imgH * ratio) > imgW:
+            resized_w = imgW
+        else:
+            resized_w = int(math.ceil(imgH * ratio))
+        resized_image = cv2.resize(img, (resized_w, imgH))
+        resized_image = resized_image.astype('float32')
+        resized_image = resized_image.transpose((2, 0, 1)) / 255
+        resized_image -= 0.5
+        resized_image /= 0.5
+        padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
+        padding_im[:, :, 0:resized_w] = resized_image
+        return padding_im
+
+    def resize_norm_img_cls(self, img):
+        cls_image_shape = [3, 48, 192]
+        imgC, imgH, imgW = cls_image_shape
         h = img.shape[0]
         w = img.shape[1]
         ratio = w / float(h)
@@ -147,7 +185,11 @@ class ChineseOCRDBCRNN(hub.Module):
             resized_w = int(math.ceil(imgH * ratio))
         resized_image = cv2.resize(img, (resized_w, imgH))
         resized_image = resized_image.astype('float32')
-        resized_image = resized_image.transpose((2, 0, 1)) / 255
+        if cls_image_shape[0] == 1:
+            resized_image = resized_image / 255
+            resized_image = resized_image[np.newaxis, :]
+        else:
+            resized_image = resized_image.transpose((2, 0, 1)) / 255
         resized_image -= 0.5
         resized_image /= 0.5
         padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
@@ -198,6 +240,7 @@ class ChineseOCRDBCRNN(hub.Module):
         detection_results = self.text_detector_module.detect_text(
             images=predicted_data, use_gpu=self.use_gpu, box_thresh=box_thresh)
+
         boxes = [
             np.array(item['data']).astype(np.float32)
             for item in detection_results
@@ -206,7 +249,7 @@ class ChineseOCRDBCRNN(hub.Module):
         for index, img_boxes in enumerate(boxes):
             original_image = predicted_data[index].copy()
             result = {'save_path': ''}
-            if img_boxes is None:
+            if img_boxes.size == 0:
                 result['data'] = []
             else:
                 img_crop_list = []
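The rewritten `get_rotate_crop_image` above warps the quadrilateral straight out of the full image, sizing the target from the longer edge of each opposite pair so skewed boxes are not squashed. The same idea in isolation, a minimal sketch assuming points ordered top-left, top-right, bottom-right, bottom-left (the image and quadrilateral are made up for this example):

```python
import cv2
import numpy as np

def warp_quad(img, points):
    """Rectify a quadrilateral text region, as in the patched method."""
    points = points.astype(np.float32)
    # Take the longer of each opposite-edge pair so a skewed box keeps its text.
    w = int(max(np.linalg.norm(points[0] - points[1]),
                np.linalg.norm(points[2] - points[3])))
    h = int(max(np.linalg.norm(points[0] - points[3]),
                np.linalg.norm(points[1] - points[2])))
    dst = np.float32([[0, 0], [w, 0], [w, h], [0, h]])
    M = cv2.getPerspectiveTransform(points, dst)
    out = cv2.warpPerspective(img, M, (w, h),
                              borderMode=cv2.BORDER_REPLICATE,
                              flags=cv2.INTER_CUBIC)
    # A tall, narrow crop is most likely vertical text; rotate it flat.
    if h * 1.0 / w >= 1.5:
        out = np.rot90(out)
    return out

img = np.zeros((200, 300, 3), np.uint8)  # placeholder image
quad = np.array([[40, 50], [250, 60], [245, 110], [35, 100]])
crop = warp_quad(img, quad)
```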
@@ -216,8 +259,9 @@ class ChineseOCRDBCRNN(hub.Module):
                     img_crop = self.get_rotate_crop_image(
                         original_image, tmp_box)
                     img_crop_list.append(img_crop)
-
+                img_crop_list, angle_list = self._classify_text(img_crop_list)
                 rec_results = self._recognize_text(img_crop_list)
+
                 # if the recognized text confidence score is lower than text_thresh, then drop it
                 rec_res_final = []
                 for index, res in enumerate(rec_results):
@@ -276,32 +320,86 @@ class ChineseOCRDBCRNN(hub.Module):
             cv2.imwrite(save_file_path, draw_img[:, :, ::-1])
         return save_file_path
 
-    def _recognize_text(self, image_list):
-        img_num = len(image_list)
+    def _classify_text(self, image_list):
+        img_list = copy.deepcopy(image_list)
+        img_num = len(img_list)
+        # Calculate the aspect ratio of all text bars
+        width_list = []
+        for img in img_list:
+            width_list.append(img.shape[1] / float(img.shape[0]))
+        # Sorting can speed up the cls process
+        indices = np.argsort(np.array(width_list))
+
+        cls_res = [['', 0.0]] * img_num
         batch_num = 30
-        rec_res = []
-        predict_time = 0
         for beg_img_no in range(0, img_num, batch_num):
             end_img_no = min(img_num, beg_img_no + batch_num)
             norm_img_batch = []
             max_wh_ratio = 0
             for ino in range(beg_img_no, end_img_no):
-                h, w = image_list[ino].shape[0:2]
-                wh_ratio = w / h
+                h, w = img_list[indices[ino]].shape[0:2]
+                wh_ratio = w * 1.0 / h
                 max_wh_ratio = max(max_wh_ratio, wh_ratio)
             for ino in range(beg_img_no, end_img_no):
-                norm_img = self.resize_norm_img(image_list[ino], max_wh_ratio)
+                norm_img = self.resize_norm_img_cls(img_list[indices[ino]])
                 norm_img = norm_img[np.newaxis, :]
                 norm_img_batch.append(norm_img)
             norm_img_batch = np.concatenate(norm_img_batch)
             norm_img_batch = norm_img_batch.copy()
 
-            self.input_tensor.copy_from_cpu(norm_img_batch)
-            self.predictor.zero_copy_run()
-            rec_idx_batch = self.output_tensors[0].copy_to_cpu()
-            rec_idx_lod = self.output_tensors[0].lod()[0]
-            predict_batch = self.output_tensors[1].copy_to_cpu()
-            predict_lod = self.output_tensors[1].lod()[0]
+            self.cls_input_tensor.copy_from_cpu(norm_img_batch)
+            self.cls_predictor.zero_copy_run()
+
+            prob_out = self.cls_output_tensors[0].copy_to_cpu()
+            label_out = self.cls_output_tensors[1].copy_to_cpu()
+            if len(label_out.shape) != 1:
+                prob_out, label_out = label_out, prob_out
+            label_list = ['0', '180']
+            for rno in range(len(label_out)):
+                label_idx = label_out[rno]
+                score = prob_out[rno][label_idx]
+                label = label_list[label_idx]
+                cls_res[indices[beg_img_no + rno]] = [label, score]
+                if '180' in label and score > 0.9999:
+                    img_list[indices[beg_img_no + rno]] = cv2.rotate(
+                        img_list[indices[beg_img_no + rno]], 1)
+        return img_list, cls_res
+
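Both `_classify_text` above and `_recognize_text` below use the same batching trick: sort crops by aspect ratio with `np.argsort`, batch neighbours together so each batch pads to a similar width, and write results back through `indices` so the output order matches the input. The bookkeeping, reduced to its essentials:

```python
import numpy as np

crops = [np.zeros((32, w, 3), np.uint8) for w in (400, 60, 120)]  # dummy crops
ratios = [c.shape[1] / float(c.shape[0]) for c in crops]
indices = np.argsort(np.array(ratios))  # narrow crops first: 60, 120, 400

results = [None] * len(crops)
for ino in range(len(crops)):
    crop = crops[indices[ino]]
    # ... the predictor would run on `crop` here ...
    results[indices[ino]] = crop.shape[1]  # write back through `indices`

print(results)  # -> [400, 60, 120]: the original input order is preserved
```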
+    def _recognize_text(self, img_list):
+        img_num = len(img_list)
+        # Calculate the aspect ratio of all text bars
+        width_list = []
+        for img in img_list:
+            width_list.append(img.shape[1] / float(img.shape[0]))
+        # Sorting can speed up the recognition process
+        indices = np.argsort(np.array(width_list))
+
+        rec_res = [['', 0.0]] * img_num
+        batch_num = 30
+        for beg_img_no in range(0, img_num, batch_num):
+            end_img_no = min(img_num, beg_img_no + batch_num)
+            norm_img_batch = []
+            max_wh_ratio = 0
+            for ino in range(beg_img_no, end_img_no):
+                h, w = img_list[indices[ino]].shape[0:2]
+                wh_ratio = w * 1.0 / h
+                max_wh_ratio = max(max_wh_ratio, wh_ratio)
+            for ino in range(beg_img_no, end_img_no):
+                norm_img = self.resize_norm_img_rec(img_list[indices[ino]],
+                                                    max_wh_ratio)
+                norm_img = norm_img[np.newaxis, :]
+                norm_img_batch.append(norm_img)
+
+            norm_img_batch = np.concatenate(norm_img_batch, axis=0)
+            norm_img_batch = norm_img_batch.copy()
+
+            self.rec_input_tensor.copy_from_cpu(norm_img_batch)
+            self.rec_predictor.zero_copy_run()
+
+            rec_idx_batch = self.rec_output_tensors[0].copy_to_cpu()
+            rec_idx_lod = self.rec_output_tensors[0].lod()[0]
+            predict_batch = self.rec_output_tensors[1].copy_to_cpu()
+            predict_lod = self.rec_output_tensors[1].lod()[0]
             for rno in range(len(rec_idx_lod) - 1):
                 beg = rec_idx_lod[rno]
                 end = rec_idx_lod[rno + 1]
@@ -316,9 +414,10 @@ class ChineseOCRDBCRNN(hub.Module):
                 if len(valid_ind) == 0:
                     continue
                 score = np.mean(probs[valid_ind, ind[valid_ind]])
-                rec_res.append([preds_text, score])
+                # rec_res.append([preds_text, score])
+                rec_res[indices[beg_img_no + rno]] = [preds_text, score]
 
-        return rec_res
+        return rec_res
 
     def save_inference_model(self,
                              dirname,
@@ -326,9 +425,12 @@ class ChineseOCRDBCRNN(hub.Module):
                              params_filename=None,
                              combined=True):
         detector_dir = os.path.join(dirname, 'text_detector')
+        classifier_dir = os.path.join(dirname, 'angle_classifier')
         recognizer_dir = os.path.join(dirname, 'text_recognizer')
         self._save_detector_model(detector_dir, model_filename, params_filename,
                                   combined)
+        self._save_classifier_model(classifier_dir, model_filename,
+                                    params_filename, combined)
         self._save_recognizer_model(recognizer_dir, model_filename,
                                     params_filename, combined)
         logger.info("The inference model has been saved in the path {}".format(
@@ -353,10 +455,40 @@ class ChineseOCRDBCRNN(hub.Module):
             place = fluid.CPUPlace()
         exe = fluid.Executor(place)
 
-        model_file_path = os.path.join(self.pretrained_model_path, 'model')
-        params_file_path = os.path.join(self.pretrained_model_path, 'params')
+        model_file_path = os.path.join(self.rec_pretrained_model_path, 'model')
+        params_file_path = os.path.join(self.rec_pretrained_model_path,
+                                        'params')
+        program, feeded_var_names, target_vars = fluid.io.load_inference_model(
+            dirname=self.rec_pretrained_model_path,
+            model_filename=model_file_path,
+            params_filename=params_file_path,
+            executor=exe)
+
+        fluid.io.save_inference_model(
+            dirname=dirname,
+            main_program=program,
+            executor=exe,
+            feeded_var_names=feeded_var_names,
+            target_vars=target_vars,
+            model_filename=model_filename,
+            params_filename=params_filename)
+
+    def _save_classifier_model(self,
+                               dirname,
+                               model_filename=None,
+                               params_filename=None,
+                               combined=True):
+        if combined:
+            model_filename = "__model__" if not model_filename else model_filename
+            params_filename = "__params__" if not params_filename else params_filename
+        place = fluid.CPUPlace()
+        exe = fluid.Executor(place)
+
+        model_file_path = os.path.join(self.cls_pretrained_model_path, 'model')
+        params_file_path = os.path.join(self.cls_pretrained_model_path,
+                                        'params')
         program, feeded_var_names, target_vars = fluid.io.load_inference_model(
-            dirname=self.pretrained_model_path,
+            dirname=self.cls_pretrained_model_path,
             model_filename=model_file_path,
             params_filename=params_file_path,
             executor=exe)
@@ -430,7 +562,7 @@ class ChineseOCRDBCRNN(hub.Module):
 if __name__ == '__main__':
     ocr = ChineseOCRDBCRNN()
     image_path = [
-        '/mnt/zhangxuefei/PaddleOCR/doc/imgs/11.jpg',
+        '/mnt/zhangxuefei/PaddleOCR/doc/imgs/2.jpg',
         '/mnt/zhangxuefei/PaddleOCR/doc/imgs/12.jpg',
         '/mnt/zhangxuefei/PaddleOCR/doc/imgs/test_image.jpg'
     ]
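With the angle classifier added, `save_inference_model` now exports three sub-models instead of two. A minimal sketch of the call (the output directory is arbitrary; the three subdirectory names come from the patch above):

```python
import paddlehub as hub

ocr = hub.Module(name="chinese_ocr_db_crnn_mobile")
# './ocr_inference' is an arbitrary output directory for this example.
ocr.save_inference_model(dirname='./ocr_inference', combined=True)
# Expected layout after this patch:
#   ./ocr_inference/text_detector/
#   ./ocr_inference/angle_classifier/
#   ./ocr_inference/text_recognizer/
```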
diff --git a/hub_module/modules/image/text_recognition/chinese_ocr_db_crnn_mobile/utils.py b/hub_module/modules/image/text_recognition/chinese_ocr_db_crnn_mobile/utils.py
index cc9e9eff..8c41af30 100644
--- a/hub_module/modules/image/text_recognition/chinese_ocr_db_crnn_mobile/utils.py
+++ b/hub_module/modules/image/text_recognition/chinese_ocr_db_crnn_mobile/utils.py
@@ -175,8 +175,8 @@ def sorted_boxes(dt_boxes):
     _boxes = list(sorted_boxes)
 
     for i in range(num_boxes - 1):
-        if abs(_boxes[i+1][0][1] - _boxes[i][0][1]) < 10 and \
-                (_boxes[i + 1][0][0] < _boxes[i][0][0]):
+        if abs(_boxes[i + 1][0][1] - _boxes[i][0][1]) < 10 and \
+                (_boxes[i + 1][0][0] < _boxes[i][0][0]):
             tmp = _boxes[i]
             _boxes[i] = _boxes[i + 1]
             _boxes[i + 1] = tmp
diff --git a/hub_module/modules/image/text_recognition/chinese_ocr_db_crnn_server/README.md b/hub_module/modules/image/text_recognition/chinese_ocr_db_crnn_server/README.md
index 2b0aa7d7..b35ac4e0 100644
--- a/hub_module/modules/image/text_recognition/chinese_ocr_db_crnn_server/README.md
+++ b/hub_module/modules/image/text_recognition/chinese_ocr_db_crnn_server/README.md
@@ -137,3 +137,7 @@ pyclipper
 * 1.0.0
 
   Initial release
+
+* 1.0.1
+
+  Support MKL-DNN acceleration for CPU computation
diff --git a/hub_module/modules/image/text_recognition/chinese_text_detection_db_mobile/README.md b/hub_module/modules/image/text_recognition/chinese_text_detection_db_mobile/README.md
index 6c43493f..7358497a 100644
--- a/hub_module/modules/image/text_recognition/chinese_text_detection_db_mobile/README.md
+++ b/hub_module/modules/image/text_recognition/chinese_text_detection_db_mobile/README.md
@@ -62,7 +62,7 @@ def detect_text(paths=[],
 import paddlehub as hub
 import cv2
 
-text_detector = hub.Module(name="chinese_text_detection_db_mobile", enable_mk)
+text_detector = hub.Module(name="chinese_text_detection_db_mobile", enable_mkldnn=True)
 result = text_detector.detect_text(images=[cv2.imread('/PATH/TO/IMAGE')])
 
 # or
diff --git a/hub_module/modules/image/text_recognition/chinese_text_detection_db_mobile/module.py b/hub_module/modules/image/text_recognition/chinese_text_detection_db_mobile/module.py
index be00f7e2..00de292e 100644
--- a/hub_module/modules/image/text_recognition/chinese_text_detection_db_mobile/module.py
+++ b/hub_module/modules/image/text_recognition/chinese_text_detection_db_mobile/module.py
@@ -29,7 +29,7 @@ def base64_to_cv2(b64str):
 
 @moduleinfo(
     name="chinese_text_detection_db_mobile",
-    version="1.0.2",
+    version="1.0.3",
     summary=
     "The module aims to detect chinese text position in the image, which is based on differentiable_binarization algorithm.",
     author="paddle-dev",
@@ -73,7 +73,10 @@ class ChineseTextDetectionDB(hub.Module):
             config.enable_use_gpu(8000, 0)
         else:
             config.disable_gpu()
+            config.set_cpu_math_library_num_threads(6)
             if self.enable_mkldnn:
+                # cache 10 different shapes for mkldnn to avoid memory leak
+                config.set_mkldnn_cache_capacity(10)
                 config.enable_mkldnn()
 
         config.disable_glog_info()
@@ -102,19 +105,18 @@ class ChineseTextDetectionDB(hub.Module):
             images.append(img)
         return images
 
+    def clip_det_res(self, points, img_height, img_width):
+        for pno in range(points.shape[0]):
+            points[pno, 0] = int(min(max(points[pno, 0], 0), img_width - 1))
+            points[pno, 1] = int(min(max(points[pno, 1], 0), img_height - 1))
+        return points
+
     def filter_tag_det_res(self, dt_boxes, image_shape):
         img_height, img_width = image_shape[0:2]
         dt_boxes_new = []
         for box in dt_boxes:
             box = self.order_points_clockwise(box)
-            left = int(np.min(box[:, 0]))
-            right = int(np.max(box[:, 0]))
-            top = int(np.min(box[:, 1]))
-            bottom = int(np.max(box[:, 1]))
-            bbox_height = bottom - top
-            bbox_width = right - left
-            diffh = math.fabs(box[0, 1] - box[1, 1])
-            diffw = math.fabs(box[0, 0] - box[3, 0])
+            box = self.clip_det_res(box, img_height, img_width)
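The new `clip_det_res` simply clamps every corner into the image before the size filter runs; `np.clip` expresses the same loop in vector form (toy values):

```python
import numpy as np

def clip_det_res(points, img_height, img_width):
    # Vectorised form of the loop above: x into [0, w-1], y into [0, h-1].
    points[:, 0] = np.clip(points[:, 0], 0, img_width - 1)
    points[:, 1] = np.clip(points[:, 1], 0, img_height - 1)
    return points

box = np.array([[-5., 10.], [700., 12.], [698., 40.], [-3., 42.]])
print(clip_det_res(box, img_height=480, img_width=640))
# every x now lies in [0, 639] and every y in [0, 479]
```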
             rect_width = int(np.linalg.norm(box[0] - box[1]))
             rect_height = int(np.linalg.norm(box[0] - box[3]))
             if rect_width <= 10 or rect_height <= 10:
@@ -168,7 +170,7 @@ class ChineseTextDetectionDB(hub.Module):
         """
         self.check_requirements()
 
-        from chinese_text_detection_db_mobile.processor import DBPreProcess, DBPostProcess, draw_boxes, get_image_ext
+        from chinese_text_detection_db_mobile.processor import DBProcessTest, DBPostProcess, draw_boxes, get_image_ext
 
         if use_gpu:
             try:
@@ -188,13 +190,20 @@ class ChineseTextDetectionDB(hub.Module):
 
         assert predicted_data != [], "There is not any image to be predicted. Please check the input data."
 
-        preprocessor = DBPreProcess()
-        postprocessor = DBPostProcess(box_thresh)
+        preprocessor = DBProcessTest(params={'max_side_len': 960})
+        postprocessor = DBPostProcess(
+            params={
+                'thresh': 0.3,
+                'box_thresh': 0.5,
+                'max_candidates': 1000,
+                'unclip_ratio': 2.0
+            })
 
         all_imgs = []
         all_ratios = []
         all_results = []
         for original_image in predicted_data:
+            ori_im = original_image.copy()
             im, ratio_list = preprocessor(original_image)
             res = {'save_path': ''}
             if im is None:
@@ -202,11 +211,20 @@ class ChineseTextDetectionDB(hub.Module):
             else:
                 im = im.copy()
-                starttime = time.time()
                 self.input_tensor.copy_from_cpu(im)
                 self.predictor.zero_copy_run()
-                data_out = self.output_tensors[0].copy_to_cpu()
-                dt_boxes_list = postprocessor(data_out, [ratio_list])
+
+                outputs = []
+                for output_tensor in self.output_tensors:
+                    output = output_tensor.copy_to_cpu()
+                    outputs.append(output)
+
+                outs_dict = {}
+                outs_dict['maps'] = outputs[0]
+
+                # data_out = self.output_tensors[0].copy_to_cpu()
+                dt_boxes_list = postprocessor(outs_dict, [ratio_list])
+
                 dt_boxes = dt_boxes_list[0]
                 boxes = self.filter_tag_det_res(dt_boxes_list[0],
                                                 original_image.shape)
                 res['data'] = boxes.astype(np.int).tolist()
@@ -328,7 +346,7 @@ class ChineseTextDetectionDB(hub.Module):
 if __name__ == '__main__':
     db = ChineseTextDetectionDB()
     image_path = [
-        '/mnt/zhangxuefei/PaddleOCR/doc/imgs/11.jpg',
+        '/mnt/zhangxuefei/PaddleOCR/doc/imgs/2.jpg',
        '/mnt/zhangxuefei/PaddleOCR/doc/imgs/12.jpg',
         '/mnt/zhangxuefei/PaddleOCR/doc/imgs/test_image.jpg'
     ]
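`detect_text` now hands the predictor output to `DBPostProcess` as a dict keyed `'maps'`, with all thresholds passed through a `params` dict. A sketch of that calling convention under stated assumptions: the import path assumes the module's package layout as installed by PaddleHub, and the random array merely stands in for a real probability map:

```python
import numpy as np
from chinese_text_detection_db_mobile.processor import DBPostProcess

postprocessor = DBPostProcess(
    params={
        'thresh': 0.3,           # binarization threshold on the probability map
        'box_thresh': 0.5,       # minimum mean score for a box to be kept
        'max_candidates': 1000,  # upper bound on contours considered
        'unclip_ratio': 2.0      # expansion factor used by unclip()
    })

fake_maps = np.random.rand(1, 1, 640, 640).astype('float32')  # stand-in output
ratio_list = [(1.0, 1.0)]  # (ratio_h, ratio_w) reported by the preprocessor
dt_boxes_list = postprocessor({'maps': fake_maps}, ratio_list)
```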
diff --git a/hub_module/modules/image/text_recognition/chinese_text_detection_db_mobile/processor.py b/hub_module/modules/image/text_recognition/chinese_text_detection_db_mobile/processor.py
index aec5a119..289104ab 100644
--- a/hub_module/modules/image/text_recognition/chinese_text_detection_db_mobile/processor.py
+++ b/hub_module/modules/image/text_recognition/chinese_text_detection_db_mobile/processor.py
@@ -12,25 +12,43 @@ import numpy as np
 import pyclipper
 
 
-class DBPreProcess(object):
-    def __init__(self, max_side_len=960):
-        self.max_side_len = max_side_len
+class DBProcessTest(object):
+    """
+    DB pre-process for Test mode
+    """
+
+    def __init__(self, params):
+        super(DBProcessTest, self).__init__()
+        self.resize_type = 0
+        if 'test_image_shape' in params:
+            self.image_shape = params['test_image_shape']
+            # print(self.image_shape)
+            self.resize_type = 1
+        if 'max_side_len' in params:
+            self.max_side_len = params['max_side_len']
+        else:
+            self.max_side_len = 2400
 
-    def resize_image_type(self, im):
+    def resize_image_type0(self, im):
         """
         resize image to a size multiple of 32 which is required by the network
+        args:
+            img(array): array with shape [h, w, c]
+        return(tuple):
+            img, (ratio_h, ratio_w)
         """
+        max_side_len = self.max_side_len
         h, w, _ = im.shape
 
         resize_w = w
         resize_h = h
 
         # limit the max side
-        if max(resize_h, resize_w) > self.max_side_len:
+        if max(resize_h, resize_w) > max_side_len:
             if resize_h > resize_w:
-                ratio = float(self.max_side_len) / resize_h
+                ratio = float(max_side_len) / resize_h
             else:
-                ratio = float(self.max_side_len) / resize_w
+                ratio = float(max_side_len) / resize_w
         else:
             ratio = 1.
         resize_h = int(resize_h * ratio)
@@ -58,19 +76,34 @@ class DBPreProcess(object):
         ratio_w = resize_w / float(w)
         return im, (ratio_h, ratio_w)
 
+    def resize_image_type1(self, im):
+        resize_h, resize_w = self.image_shape
+        ori_h, ori_w = im.shape[:2]  # (h, w, c)
+        im = cv2.resize(im, (int(resize_w), int(resize_h)))
+        ratio_h = float(resize_h) / ori_h
+        ratio_w = float(resize_w) / ori_w
+        return im, (ratio_h, ratio_w)
+
     def normalize(self, im):
         img_mean = [0.485, 0.456, 0.406]
         img_std = [0.229, 0.224, 0.225]
         im = im.astype(np.float32, copy=False)
         im = im / 255
-        im -= img_mean
-        im /= img_std
+        im[:, :, 0] -= img_mean[0]
+        im[:, :, 1] -= img_mean[1]
+        im[:, :, 2] -= img_mean[2]
+        im[:, :, 0] /= img_std[0]
+        im[:, :, 1] /= img_std[1]
+        im[:, :, 2] /= img_std[2]
         channel_swap = (2, 0, 1)
         im = im.transpose(channel_swap)
         return im
 
     def __call__(self, im):
-        im, (ratio_h, ratio_w) = self.resize_image_type(im)
+        if self.resize_type == 0:
+            im, (ratio_h, ratio_w) = self.resize_image_type0(im)
+        else:
+            im, (ratio_h, ratio_w) = self.resize_image_type1(im)
         im = self.normalize(im)
         im = im[np.newaxis, :]
         return [im, (ratio_h, ratio_w)]
@@ -81,10 +114,11 @@ class DBPostProcess(object):
     The post process for Differentiable Binarization (DB).
     """
 
-    def __init__(self, thresh=0.3, box_thresh=0.5, max_candidates=1000):
-        self.thresh = thresh
-        self.box_thresh = box_thresh
-        self.max_candidates = max_candidates
+    def __init__(self, params):
+        self.thresh = params['thresh']
+        self.box_thresh = params['box_thresh']
+        self.max_candidates = params['max_candidates']
+        self.unclip_ratio = params['unclip_ratio']
         self.min_size = 3
 
     def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
@@ -134,7 +168,8 @@ class DBPostProcess(object):
             scores[index] = score
         return boxes, scores
 
-    def unclip(self, box, unclip_ratio=2.0):
+    def unclip(self, box):
+        unclip_ratio = self.unclip_ratio
         poly = Polygon(box)
         distance = poly.area * unclip_ratio / poly.length
         offset = pyclipper.PyclipperOffset()
@@ -179,8 +214,10 @@ class DBPostProcess(object):
             cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1)
         return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
 
-    def __call__(self, predictions, ratio_list):
-        pred = predictions[:, 0, :, :]
+    def __call__(self, outs_dict, ratio_list):
+        pred = outs_dict['maps']
+
+        pred = pred[:, 0, :, :]
         segmentation = pred > self.thresh
 
         boxes_batch = []
diff --git a/hub_module/modules/image/text_recognition/chinese_text_detection_db_server/README.md b/hub_module/modules/image/text_recognition/chinese_text_detection_db_server/README.md
index 393eec37..4410d176 100644
--- a/hub_module/modules/image/text_recognition/chinese_text_detection_db_server/README.md
+++ b/hub_module/modules/image/text_recognition/chinese_text_detection_db_server/README.md
@@ -125,3 +125,7 @@ pyclipper
 * 1.0.0
 
   Initial release
+
+* 1.0.2
+
+  Support MKL-DNN acceleration for CPU computation
-- 
GitLab