diff --git a/hub_module/modules/image/text_recognition/chinese_ocr_db_crnn_mobile/README.md b/hub_module/modules/image/text_recognition/chinese_ocr_db_crnn_mobile/README.md
index 3c0bd1d6dbc6d8672341091fbe927e8e8cce17f6..74efcf2a6c4307e924909a6e8b1a1687a0094134 100644
--- a/hub_module/modules/image/text_recognition/chinese_ocr_db_crnn_mobile/README.md
+++ b/hub_module/modules/image/text_recognition/chinese_ocr_db_crnn_mobile/README.md
@@ -36,7 +36,8 @@ def recognize_text(images=[],
                    output_dir='ocr_result',
                    visualization=False,
                    box_thresh=0.5,
-                   text_thresh=0.5)
+                   text_thresh=0.5,
+                   angle_classification_thresh=0.9)
 ```
 
 Prediction API that detects and recognizes all Chinese text in the input image.
@@ -48,6 +49,7 @@ def recognize_text(images=[],
 * use\_gpu (bool): whether to use the GPU; **if the GPU is used, set the CUDA_VISIBLE_DEVICES environment variable first**
 * box\_thresh (float): confidence threshold for detected text boxes;
 * text\_thresh (float): confidence threshold for recognized Chinese text;
+* angle\_classification\_thresh (float): confidence threshold for text angle classification;
 * visualization (bool): whether to save the recognition result as an image file;
 * output\_dir (str): directory in which to save the images, ocr\_result by default;
 
diff --git a/hub_module/modules/image/text_recognition/chinese_ocr_db_crnn_mobile/module.py b/hub_module/modules/image/text_recognition/chinese_ocr_db_crnn_mobile/module.py
index cd519726e9a53b7a612c93224ede90c869070c7e..444d1db841618a2ca31fe8269f3005afcdbc4862 100644
--- a/hub_module/modules/image/text_recognition/chinese_ocr_db_crnn_mobile/module.py
+++ b/hub_module/modules/image/text_recognition/chinese_ocr_db_crnn_mobile/module.py
@@ -203,7 +203,8 @@ class ChineseOCRDBCRNN(hub.Module):
                        output_dir='ocr_result',
                        visualization=False,
                        box_thresh=0.5,
-                       text_thresh=0.5):
+                       text_thresh=0.5,
+                       angle_classification_thresh=0.9):
         """
         Get the chinese texts in the predicted images.
         Args:
@@ -214,7 +215,9 @@ class ChineseOCRDBCRNN(hub.Module):
             output_dir (str): The directory to store output images.
             visualization (bool): Whether to save image or not.
             box_thresh(float): the threshold of the detected text box's confidence
-            text_thresh(float): the threshold of the recognize chinese texts' confidence
+            text_thresh(float): the threshold of the Chinese text recognition confidence
+            angle_classification_thresh(float): the threshold of the angle classification confidence
+
         Returns:
             res (list): The result of chinese texts and save path of images.
""" @@ -259,7 +262,9 @@ class ChineseOCRDBCRNN(hub.Module): img_crop = self.get_rotate_crop_image( original_image, tmp_box) img_crop_list.append(img_crop) - img_crop_list, angle_list = self._classify_text(img_crop_list) + img_crop_list, angle_list = self._classify_text( + img_crop_list, + angle_classification_thresh=angle_classification_thresh) rec_results = self._recognize_text(img_crop_list) # if the recognized text confidence score is lower than text_thresh, then drop it @@ -294,12 +299,14 @@ class ChineseOCRDBCRNN(hub.Module): results = self.recognize_text(images_decode, **kwargs) return results - def save_result_image(self, - original_image, - detection_boxes, - rec_results, - output_dir='ocr_result', - text_thresh=0.5): + def save_result_image( + self, + original_image, + detection_boxes, + rec_results, + output_dir='ocr_result', + text_thresh=0.5, + ): image = Image.fromarray(cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB)) txts = [item[0] for item in rec_results] scores = [item[1] for item in rec_results] @@ -320,7 +327,7 @@ class ChineseOCRDBCRNN(hub.Module): cv2.imwrite(save_file_path, draw_img[:, :, ::-1]) return save_file_path - def _classify_text(self, image_list): + def _classify_text(self, image_list, angle_classification_thresh=0.9): img_list = copy.deepcopy(image_list) img_num = len(img_list) # Calculate the aspect ratio of all text bars @@ -360,7 +367,7 @@ class ChineseOCRDBCRNN(hub.Module): score = prob_out[rno][label_idx] label = label_list[label_idx] cls_res[indices[beg_img_no + rno]] = [label, score] - if '180' in label and score > 0.9999: + if '180' in label and score > angle_classification_thresh: img_list[indices[beg_img_no + rno]] = cv2.rotate( img_list[indices[beg_img_no + rno]], 1) return img_list, cls_res diff --git a/hub_module/modules/image/text_recognition/chinese_ocr_db_crnn_server/README.md b/hub_module/modules/image/text_recognition/chinese_ocr_db_crnn_server/README.md index b35ac4e03f87330261922df5886c9d26bdc35b9e..ac20c01ce05ce0ac4a7f3c12f3a9e95b189d94a7 100644 --- a/hub_module/modules/image/text_recognition/chinese_ocr_db_crnn_server/README.md +++ b/hub_module/modules/image/text_recognition/chinese_ocr_db_crnn_server/README.md @@ -35,7 +35,8 @@ def recognize_text(images=[], output_dir='ocr_result', visualization=False, box_thresh=0.5, - text_thresh=0.5) + text_thresh=0.5, + angle_classification_thresh=0.9) ``` 预测API,检测输入图片中的所有中文文本的位置。 @@ -47,6 +48,7 @@ def recognize_text(images=[], * use\_gpu (bool): 是否使用 GPU;**若使用GPU,请先设置CUDA_VISIBLE_DEVICES环境变量** * box\_thresh (float): 检测文本框置信度的阈值; * text\_thresh (float): 识别中文文本置信度的阈值; +* angle_classification_thresh(float): 文本角度分类置信度的阈值 * visualization (bool): 是否将识别结果保存为图片文件; * output\_dir (str): 图片的保存路径,默认设为 ocr\_result; @@ -141,3 +143,7 @@ pyclipper * 1.0.1 支持mkldnn加速CPU计算 + +* 1.1.0 + + 使用三阶段模型(文本框检测-角度分类-文字识别)识别图片文字。 diff --git a/hub_module/modules/image/text_recognition/chinese_ocr_db_crnn_server/character.py b/hub_module/modules/image/text_recognition/chinese_ocr_db_crnn_server/character.py index 8e5f10211ba441a7dd9b4948413b79c8721eab07..ff88b52e4af1f75b039ca6607b723f152b7a7ab5 100644 --- a/hub_module/modules/image/text_recognition/chinese_ocr_db_crnn_server/character.py +++ b/hub_module/modules/image/text_recognition/chinese_ocr_db_crnn_server/character.py @@ -22,17 +22,23 @@ class CharacterOps(object): def __init__(self, config): self.character_type = config['character_type'] self.loss_type = config['loss_type'] + self.max_text_len = config['max_text_length'] if self.character_type == "en": self.character_str 
= "0123456789abcdefghijklmnopqrstuvwxyz" dict_character = list(self.character_str) elif self.character_type == "ch": character_dict_path = config['character_dict_path'] + add_space = False + if 'use_space_char' in config: + add_space = config['use_space_char'] self.character_str = "" with open(character_dict_path, "rb") as fin: lines = fin.readlines() for line in lines: - line = line.decode('utf-8').strip("\n") + line = line.decode('utf-8').strip("\n").strip("\r\n") self.character_str += line + if add_space: + self.character_str += " " dict_character = list(self.character_str) elif self.character_type == "en_sensitive": # same with ASTER setting (use 94 char). @@ -46,6 +52,8 @@ class CharacterOps(object): self.end_str = "eos" if self.loss_type == "attention": dict_character = [self.beg_str, self.end_str] + dict_character + elif self.loss_type == "srn": + dict_character = dict_character + [self.beg_str, self.end_str] self.dict = {} for i, char in enumerate(dict_character): self.dict[char] = i @@ -90,7 +98,7 @@ class CharacterOps(object): if is_remove_duplicate: if idx > 0 and text_index[idx - 1] == text_index[idx]: continue - char_list.append(self.character[text_index[idx]]) + char_list.append(self.character[int(text_index[idx])]) text = ''.join(char_list) return text @@ -139,6 +147,42 @@ def cal_predicts_accuracy(char_ops, return acc, acc_num, img_num +def cal_predicts_accuracy_srn(char_ops, + preds, + labels, + max_text_len, + is_debug=False): + acc_num = 0 + img_num = 0 + + char_num = char_ops.get_char_num() + + total_len = preds.shape[0] + img_num = int(total_len / max_text_len) + for i in range(img_num): + cur_label = [] + cur_pred = [] + for j in range(max_text_len): + if labels[j + i * max_text_len] != int(char_num - 1): #0 + cur_label.append(labels[j + i * max_text_len][0]) + else: + break + + for j in range(max_text_len + 1): + if j < len(cur_label) and preds[ + j + i * max_text_len][0] != cur_label[j]: + break + elif j == len(cur_label) and j == max_text_len: + acc_num += 1 + break + elif j == len(cur_label) and preds[j + i * max_text_len][0] == int( + char_num - 1): + acc_num += 1 + break + acc = acc_num * 1.0 / img_num + return acc, acc_num, img_num + + def convert_rec_attention_infer_res(preds): img_num = preds.shape[0] target_lod = [0] diff --git a/hub_module/modules/image/text_recognition/chinese_ocr_db_crnn_server/module.py b/hub_module/modules/image/text_recognition/chinese_ocr_db_crnn_server/module.py index 88422344bbd05f1e88d64c3af73f6246c69806e0..c77880392d8c381ffdd40c2252da56c84d5c03ef 100644 --- a/hub_module/modules/image/text_recognition/chinese_ocr_db_crnn_server/module.py +++ b/hub_module/modules/image/text_recognition/chinese_ocr_db_crnn_server/module.py @@ -25,7 +25,7 @@ from chinese_ocr_db_crnn_server.utils import base64_to_cv2, draw_ocr, get_image_ @moduleinfo( name="chinese_ocr_db_crnn_server", - version="1.0.3", + version="1.1.0", summary= "The module can recognize the chinese texts in an image. Firstly, it will detect the text box positions based on the differentiable_binarization_chn module. Then it recognizes the chinese texts. 
", author="paddle-dev", @@ -41,24 +41,31 @@ class ChineseOCRDBCRNNServer(hub.Module): char_ops_params = { 'character_type': 'ch', 'character_dict_path': self.character_dict_path, - 'loss_type': 'ctc' + 'loss_type': 'ctc', + 'max_text_length': 25, + 'use_space_char': True } self.char_ops = CharacterOps(char_ops_params) self.rec_image_shape = [3, 32, 320] self._text_detector_module = text_detector_module self.font_file = os.path.join(self.directory, 'assets', 'simfang.ttf') - self.pretrained_model_path = os.path.join(self.directory, 'assets', - 'ch_rec_r34_vd_crnn') self.enable_mkldnn = enable_mkldnn - self._set_config() + self.rec_pretrained_model_path = os.path.join( + self.directory, 'inference_model', 'character_rec') + self.cls_pretrained_model_path = os.path.join( + self.directory, 'inference_model', 'angle_cls') + self.rec_predictor, self.rec_input_tensor, self.rec_output_tensors = self._set_config( + self.rec_pretrained_model_path) + self.cls_predictor, self.cls_input_tensor, self.cls_output_tensors = self._set_config( + self.cls_pretrained_model_path) - def _set_config(self): + def _set_config(self, pretrained_model_path): """ - predictor config setting + predictor config path """ - model_file_path = os.path.join(self.pretrained_model_path, 'model') - params_file_path = os.path.join(self.pretrained_model_path, 'params') + model_file_path = os.path.join(pretrained_model_path, 'model') + params_file_path = os.path.join(pretrained_model_path, 'params') config = AnalysisConfig(model_file_path, params_file_path) try: @@ -73,21 +80,25 @@ class ChineseOCRDBCRNNServer(hub.Module): else: config.disable_gpu() if self.enable_mkldnn: + # cache 10 different shapes for mkldnn to avoid memory leak + config.set_mkldnn_cache_capacity(10) config.enable_mkldnn() config.disable_glog_info() - - # use zero copy config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass") config.switch_use_feed_fetch_ops(False) - self.predictor = create_paddle_predictor(config) - input_names = self.predictor.get_input_names() - self.input_tensor = self.predictor.get_input_tensor(input_names[0]) - output_names = self.predictor.get_output_names() - self.output_tensors = [] + + predictor = create_paddle_predictor(config) + + input_names = predictor.get_input_names() + input_tensor = predictor.get_input_tensor(input_names[0]) + output_names = predictor.get_output_names() + output_tensors = [] for output_name in output_names: - output_tensor = self.predictor.get_output_tensor(output_name) - self.output_tensors.append(output_tensor) + output_tensor = predictor.get_output_tensor(output_name) + output_tensors.append(output_tensor) + + return predictor, input_tensor, output_tensors @property def text_detector_module(self): @@ -98,7 +109,7 @@ class ChineseOCRDBCRNNServer(hub.Module): self._text_detector_module = hub.Module( name='chinese_text_detection_db_server', enable_mkldnn=self.enable_mkldnn, - version='1.0.1') + version='1.0.2') return self._text_detector_module def read_images(self, paths=[]): @@ -114,6 +125,7 @@ class ChineseOCRDBCRNNServer(hub.Module): return images def get_rotate_crop_image(self, img, points): + ''' img_height, img_width = img.shape[0:2] left = int(np.min(points[:, 0])) right = int(np.max(points[:, 0])) @@ -122,23 +134,51 @@ class ChineseOCRDBCRNNServer(hub.Module): img_crop = img[top:bottom, left:right, :].copy() points[:, 0] = points[:, 0] - left points[:, 1] = points[:, 1] - top - img_crop_width = int(np.linalg.norm(points[0] - points[1])) - img_crop_height = int(np.linalg.norm(points[0] - points[3])) 
-        pts_std = np.float32([[0, 0], [img_crop_width, 0],\
-                              [img_crop_width, img_crop_height], [0, img_crop_height]])
+        '''
+        img_crop_width = int(
+            max(
+                np.linalg.norm(points[0] - points[1]),
+                np.linalg.norm(points[2] - points[3])))
+        img_crop_height = int(
+            max(
+                np.linalg.norm(points[0] - points[3]),
+                np.linalg.norm(points[1] - points[2])))
+        pts_std = np.float32([[0, 0], [img_crop_width, 0],
+                              [img_crop_width, img_crop_height],
+                              [0, img_crop_height]])
         M = cv2.getPerspectiveTransform(points, pts_std)
         dst_img = cv2.warpPerspective(
-            img_crop,
+            img,
             M, (img_crop_width, img_crop_height),
-            borderMode=cv2.BORDER_REPLICATE)
+            borderMode=cv2.BORDER_REPLICATE,
+            flags=cv2.INTER_CUBIC)
         dst_img_height, dst_img_width = dst_img.shape[0:2]
         if dst_img_height * 1.0 / dst_img_width >= 1.5:
             dst_img = np.rot90(dst_img)
         return dst_img
 
-    def resize_norm_img(self, img, max_wh_ratio):
+    def resize_norm_img_rec(self, img, max_wh_ratio):
         imgC, imgH, imgW = self.rec_image_shape
-        imgW = int(32 * max_wh_ratio)
+        assert imgC == img.shape[2]
+        imgW = int((32 * max_wh_ratio))
+        h, w = img.shape[:2]
+        ratio = w / float(h)
+        if math.ceil(imgH * ratio) > imgW:
+            resized_w = imgW
+        else:
+            resized_w = int(math.ceil(imgH * ratio))
+        resized_image = cv2.resize(img, (resized_w, imgH))
+        resized_image = resized_image.astype('float32')
+        resized_image = resized_image.transpose((2, 0, 1)) / 255
+        resized_image -= 0.5
+        resized_image /= 0.5
+        padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
+        padding_im[:, :, 0:resized_w] = resized_image
+        return padding_im
+
+    def resize_norm_img_cls(self, img):
+        cls_image_shape = [3, 48, 192]
+        imgC, imgH, imgW = cls_image_shape
         h = img.shape[0]
         w = img.shape[1]
         ratio = w / float(h)
@@ -148,7 +188,11 @@ class ChineseOCRDBCRNNServer(hub.Module):
             resized_w = int(math.ceil(imgH * ratio))
         resized_image = cv2.resize(img, (resized_w, imgH))
         resized_image = resized_image.astype('float32')
-        resized_image = resized_image.transpose((2, 0, 1)) / 255
+        if cls_image_shape[0] == 1:
+            resized_image = resized_image / 255
+            resized_image = resized_image[np.newaxis, :]
+        else:
+            resized_image = resized_image.transpose((2, 0, 1)) / 255
         resized_image -= 0.5
         resized_image /= 0.5
         padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
@@ -162,7 +206,8 @@ class ChineseOCRDBCRNNServer(hub.Module):
                        output_dir='ocr_result',
                        visualization=False,
                        box_thresh=0.5,
-                       text_thresh=0.5):
+                       text_thresh=0.5,
+                       angle_classification_thresh=0.9):
         """
         Get the chinese texts in the predicted images.
         Args:
@@ -173,7 +218,9 @@ class ChineseOCRDBCRNNServer(hub.Module):
             output_dir (str): The directory to store output images.
             visualization (bool): Whether to save image or not.
             box_thresh(float): the threshold of the detected text box's confidence
-            text_thresh(float): the threshold of the recognize chinese texts' confidence
+            text_thresh(float): the threshold of the Chinese text recognition confidence
+            angle_classification_thresh(float): the threshold of the angle classification confidence
+
         Returns:
             res (list): The result of chinese texts and save path of images.
""" @@ -199,6 +246,7 @@ class ChineseOCRDBCRNNServer(hub.Module): detection_results = self.text_detector_module.detect_text( images=predicted_data, use_gpu=self.use_gpu, box_thresh=box_thresh) + boxes = [ np.array(item['data']).astype(np.float32) for item in detection_results @@ -207,7 +255,7 @@ class ChineseOCRDBCRNNServer(hub.Module): for index, img_boxes in enumerate(boxes): original_image = predicted_data[index].copy() result = {'save_path': ''} - if img_boxes is None: + if img_boxes.size == 0: result['data'] = [] else: img_crop_list = [] @@ -217,8 +265,11 @@ class ChineseOCRDBCRNNServer(hub.Module): img_crop = self.get_rotate_crop_image( original_image, tmp_box) img_crop_list.append(img_crop) - + img_crop_list, angle_list = self._classify_text( + img_crop_list, + angle_classification_thresh=angle_classification_thresh) rec_results = self._recognize_text(img_crop_list) + # if the recognized text confidence score is lower than text_thresh, then drop it rec_res_final = [] for index, res in enumerate(rec_results): @@ -251,12 +302,14 @@ class ChineseOCRDBCRNNServer(hub.Module): results = self.recognize_text(images_decode, **kwargs) return results - def save_result_image(self, - original_image, - detection_boxes, - rec_results, - output_dir='ocr_result', - text_thresh=0.5): + def save_result_image( + self, + original_image, + detection_boxes, + rec_results, + output_dir='ocr_result', + text_thresh=0.5, + ): image = Image.fromarray(cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB)) txts = [item[0] for item in rec_results] scores = [item[1] for item in rec_results] @@ -277,32 +330,86 @@ class ChineseOCRDBCRNNServer(hub.Module): cv2.imwrite(save_file_path, draw_img[:, :, ::-1]) return save_file_path - def _recognize_text(self, image_list): - img_num = len(image_list) + def _classify_text(self, image_list, angle_classification_thresh=0.9): + img_list = copy.deepcopy(image_list) + img_num = len(img_list) + # Calculate the aspect ratio of all text bars + width_list = [] + for img in img_list: + width_list.append(img.shape[1] / float(img.shape[0])) + # Sorting can speed up the cls process + indices = np.argsort(np.array(width_list)) + + cls_res = [['', 0.0]] * img_num batch_num = 30 - rec_res = [] - predict_time = 0 for beg_img_no in range(0, img_num, batch_num): end_img_no = min(img_num, beg_img_no + batch_num) norm_img_batch = [] max_wh_ratio = 0 for ino in range(beg_img_no, end_img_no): - h, w = image_list[ino].shape[0:2] - wh_ratio = w / h + h, w = img_list[indices[ino]].shape[0:2] + wh_ratio = w * 1.0 / h max_wh_ratio = max(max_wh_ratio, wh_ratio) for ino in range(beg_img_no, end_img_no): - norm_img = self.resize_norm_img(image_list[ino], max_wh_ratio) + norm_img = self.resize_norm_img_cls(img_list[indices[ino]]) norm_img = norm_img[np.newaxis, :] norm_img_batch.append(norm_img) norm_img_batch = np.concatenate(norm_img_batch) norm_img_batch = norm_img_batch.copy() - self.input_tensor.copy_from_cpu(norm_img_batch) - self.predictor.zero_copy_run() - rec_idx_batch = self.output_tensors[0].copy_to_cpu() - rec_idx_lod = self.output_tensors[0].lod()[0] - predict_batch = self.output_tensors[1].copy_to_cpu() - predict_lod = self.output_tensors[1].lod()[0] + self.cls_input_tensor.copy_from_cpu(norm_img_batch) + self.cls_predictor.zero_copy_run() + + prob_out = self.cls_output_tensors[0].copy_to_cpu() + label_out = self.cls_output_tensors[1].copy_to_cpu() + if len(label_out.shape) != 1: + prob_out, label_out = label_out, prob_out + label_list = ['0', '180'] + for rno in range(len(label_out)): + label_idx 
+                score = prob_out[rno][label_idx]
+                label = label_list[label_idx]
+                cls_res[indices[beg_img_no + rno]] = [label, score]
+                if '180' in label and score > angle_classification_thresh:
+                    img_list[indices[beg_img_no + rno]] = cv2.rotate(
+                        img_list[indices[beg_img_no + rno]], 1)
+        return img_list, cls_res
+
+    def _recognize_text(self, img_list):
+        img_num = len(img_list)
+        # Calculate the aspect ratio of all text bars
+        width_list = []
+        for img in img_list:
+            width_list.append(img.shape[1] / float(img.shape[0]))
+        # Sorting can speed up the recognition process
+        indices = np.argsort(np.array(width_list))
+
+        rec_res = [['', 0.0]] * img_num
+        batch_num = 30
+        for beg_img_no in range(0, img_num, batch_num):
+            end_img_no = min(img_num, beg_img_no + batch_num)
+            norm_img_batch = []
+            max_wh_ratio = 0
+            for ino in range(beg_img_no, end_img_no):
+                h, w = img_list[indices[ino]].shape[0:2]
+                wh_ratio = w * 1.0 / h
+                max_wh_ratio = max(max_wh_ratio, wh_ratio)
+            for ino in range(beg_img_no, end_img_no):
+                norm_img = self.resize_norm_img_rec(img_list[indices[ino]],
+                                                    max_wh_ratio)
+                norm_img = norm_img[np.newaxis, :]
+                norm_img_batch.append(norm_img)
+
+            norm_img_batch = np.concatenate(norm_img_batch, axis=0)
+            norm_img_batch = norm_img_batch.copy()
+
+            self.rec_input_tensor.copy_from_cpu(norm_img_batch)
+            self.rec_predictor.zero_copy_run()
+
+            rec_idx_batch = self.rec_output_tensors[0].copy_to_cpu()
+            rec_idx_lod = self.rec_output_tensors[0].lod()[0]
+            predict_batch = self.rec_output_tensors[1].copy_to_cpu()
+            predict_lod = self.rec_output_tensors[1].lod()[0]
             for rno in range(len(rec_idx_lod) - 1):
                 beg = rec_idx_lod[rno]
                 end = rec_idx_lod[rno + 1]
@@ -317,9 +424,10 @@ class ChineseOCRDBCRNNServer(hub.Module):
                 if len(valid_ind) == 0:
                     continue
                 score = np.mean(probs[valid_ind, ind[valid_ind]])
-                rec_res.append([preds_text, score])
+                # rec_res.append([preds_text, score])
+                rec_res[indices[beg_img_no + rno]] = [preds_text, score]
 
-        return rec_res
+        return rec_res
 
     def save_inference_model(self,
                              dirname,
@@ -327,9 +435,12 @@ class ChineseOCRDBCRNNServer(hub.Module):
                              params_filename=None,
                              combined=True):
         detector_dir = os.path.join(dirname, 'text_detector')
+        classifier_dir = os.path.join(dirname, 'angle_classifier')
         recognizer_dir = os.path.join(dirname, 'text_recognizer')
         self._save_detector_model(detector_dir, model_filename,
                                   params_filename, combined)
+        self._save_classifier_model(classifier_dir, model_filename,
+                                    params_filename, combined)
         self._save_recognizer_model(recognizer_dir, model_filename,
                                     params_filename, combined)
         logger.info("The inference model has been saved in the path {}".format(
@@ -354,10 +465,40 @@ class ChineseOCRDBCRNNServer(hub.Module):
         place = fluid.CPUPlace()
         exe = fluid.Executor(place)
 
-        model_file_path = os.path.join(self.pretrained_model_path, 'model')
-        params_file_path = os.path.join(self.pretrained_model_path, 'params')
+        model_file_path = os.path.join(self.rec_pretrained_model_path, 'model')
+        params_file_path = os.path.join(self.rec_pretrained_model_path,
+                                        'params')
+        program, feeded_var_names, target_vars = fluid.io.load_inference_model(
+            dirname=self.rec_pretrained_model_path,
+            model_filename=model_file_path,
+            params_filename=params_file_path,
+            executor=exe)
+
+        fluid.io.save_inference_model(
+            dirname=dirname,
+            main_program=program,
+            executor=exe,
+            feeded_var_names=feeded_var_names,
+            target_vars=target_vars,
+            model_filename=model_filename,
+            params_filename=params_filename)
+
+    def _save_classifier_model(self,
+                               dirname,
+                               model_filename=None,
+                               params_filename=None,
+                               combined=True):
+        if combined:
+            model_filename = "__model__" if not model_filename else model_filename
+            params_filename = "__params__" if not params_filename else params_filename
+        place = fluid.CPUPlace()
+        exe = fluid.Executor(place)
+
+        model_file_path = os.path.join(self.cls_pretrained_model_path, 'model')
+        params_file_path = os.path.join(self.cls_pretrained_model_path,
+                                        'params')
         program, feeded_var_names, target_vars = fluid.io.load_inference_model(
-            dirname=self.pretrained_model_path,
+            dirname=self.cls_pretrained_model_path,
             model_filename=model_file_path,
             params_filename=params_file_path,
             executor=exe)
@@ -429,8 +570,7 @@ class ChineseOCRDBCRNNServer(hub.Module):
 
 if __name__ == '__main__':
-    ocr = ChineseOCRDBCRNNServer(enable_mkldnn=True)
-    print(ocr.name)
+    ocr = ChineseOCRDBCRNNServer(enable_mkldnn=False)
     image_path = [
         '/mnt/zhangxuefei/PaddleOCR/doc/imgs/11.jpg',
         '/mnt/zhangxuefei/PaddleOCR/doc/imgs/12.jpg',
diff --git a/hub_module/modules/image/text_recognition/chinese_ocr_db_crnn_server/utils.py b/hub_module/modules/image/text_recognition/chinese_ocr_db_crnn_server/utils.py
index cc9e9effc1b5904426377617b899d9aba9900d3e..8c41af300cc91de369a473cb7327b794b6cf5715 100644
--- a/hub_module/modules/image/text_recognition/chinese_ocr_db_crnn_server/utils.py
+++ b/hub_module/modules/image/text_recognition/chinese_ocr_db_crnn_server/utils.py
@@ -175,8 +175,8 @@ def sorted_boxes(dt_boxes):
     _boxes = list(sorted_boxes)
 
     for i in range(num_boxes - 1):
-        if abs(_boxes[i+1][0][1] - _boxes[i][0][1]) < 10 and \
-                (_boxes[i + 1][0][0] < _boxes[i][0][0]):
+        if abs(_boxes[i + 1][0][1] - _boxes[i][0][1]) < 10 and \
+                (_boxes[i + 1][0][0] < _boxes[i][0][0]):
             tmp = _boxes[i]
             _boxes[i] = _boxes[i + 1]
             _boxes[i + 1] = tmp
diff --git a/hub_module/modules/image/text_recognition/chinese_text_detection_db_mobile/README.md b/hub_module/modules/image/text_recognition/chinese_text_detection_db_mobile/README.md
index 7358497a78290738ff18296feabe86f6a650240e..9ec79fbe7dacdde3cbe4f12dd1a49324b82da22c 100644
--- a/hub_module/modules/image/text_recognition/chinese_text_detection_db_mobile/README.md
+++ b/hub_module/modules/image/text_recognition/chinese_text_detection_db_mobile/README.md
@@ -132,3 +132,15 @@ pyclipper
 * 1.0.1
 
   Fix the failure of invoking the model through the online service
+
+* 1.0.2
+
+  Support MKL-DNN acceleration for CPU computation
+
+* 1.0.3
+
+  Add more pretraining data and update the pretrained parameters
+
+* 1.1.0
+
+  Use an ultra-lightweight three-stage pipeline (text box detection - angle classification - text recognition) to recognize text in images.
diff --git a/hub_module/modules/image/text_recognition/chinese_text_detection_db_server/README.md b/hub_module/modules/image/text_recognition/chinese_text_detection_db_server/README.md
index 4410d176a0262ff0e9c1855b972b598aa1abbbed..756180838c8e67d04bddf8e0abd674587b56c6ed 100644
--- a/hub_module/modules/image/text_recognition/chinese_text_detection_db_server/README.md
+++ b/hub_module/modules/image/text_recognition/chinese_text_detection_db_server/README.md
@@ -129,3 +129,7 @@ pyclipper
 * 1.0.2
 
   Support MKL-DNN acceleration for CPU computation
+
+* 1.0.3
+
+  Add more pretraining data and update the pretrained parameters
diff --git a/hub_module/modules/image/text_recognition/chinese_text_detection_db_server/module.py b/hub_module/modules/image/text_recognition/chinese_text_detection_db_server/module.py
index 7baec43d9acba27c711546360e3e280f07cce4e4..53252010fcb3ef3eaf2eb24d832a683e28ceac7b 100644
--- a/hub_module/modules/image/text_recognition/chinese_text_detection_db_server/module.py
+++ b/hub_module/modules/image/text_recognition/chinese_text_detection_db_server/module.py
@@ -29,7 +29,7 @@ def base64_to_cv2(b64str):
 
 @moduleinfo(
     name="chinese_text_detection_db_server",
-    version="1.0.1",
+    version="1.0.2",
     summary=
     "The module aims to detect chinese text position in the image, which is based on differentiable_binarization algorithm.",
     author="paddle-dev",
@@ -41,7 +41,7 @@ class ChineseTextDetectionDBServer(hub.Module):
         """
         initialize with the necessary elements
         """
         self.pretrained_model_path = os.path.join(self.directory,
-                                                  'ch_det_r50_vd_db')
+                                                  'inference_model')
         self.enable_mkldnn = enable_mkldnn
         self._set_config()
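
Note for reviewers: below is a minimal usage sketch of the updated `recognize_text` API, showing the new `angle_classification_thresh` parameter next to the existing thresholds. The image path is a placeholder, and loading the module via `hub.Module(name=...)` is the usual PaddleHub pattern rather than anything introduced in this diff:

```python
import cv2
import paddlehub as hub

# Placeholder path; substitute a real image file.
image = cv2.imread('/PATH/TO/IMAGE')

ocr = hub.Module(name="chinese_ocr_db_crnn_server")
results = ocr.recognize_text(
    images=[image],
    visualization=True,               # save the rendered result to output_dir
    box_thresh=0.5,                   # confidence threshold for detected text boxes
    text_thresh=0.5,                  # drop recognized text scoring below this
    angle_classification_thresh=0.9)  # rotate a crop only above this confidence
print(results)
```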
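To make the effect of the new threshold concrete: in `_classify_text`, a text crop is rotated by 180 degrees only when the angle classifier predicts the '180' label and its score exceeds `angle_classification_thresh` (previously hard-coded as 0.9999 in the mobile module). A standalone sketch of that gating logic follows; `maybe_rotate_crop` is an illustrative name, not a function in the module:

```python
import cv2

def maybe_rotate_crop(img, label, score, angle_classification_thresh=0.9):
    # Mirrors the gating in _classify_text: flip the crop only when the
    # angle classifier is confident the text is upside down.
    if '180' in label and score > angle_classification_thresh:
        return cv2.rotate(img, 1)  # rotateCode 1 == cv2.ROTATE_180
    return img
```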