未验证 提交 2758efec 编写于 作者: S Steffy-zxf 提交者: GitHub

Update ocr (#911)

上级 34e60467
...@@ -36,7 +36,8 @@ def recognize_text(images=[], ...@@ -36,7 +36,8 @@ def recognize_text(images=[],
output_dir='ocr_result', output_dir='ocr_result',
visualization=False, visualization=False,
box_thresh=0.5, box_thresh=0.5,
text_thresh=0.5) text_thresh=0.5,
angle_classification_thresh=0.9)
``` ```
预测API,检测输入图片中的所有中文文本的位置。 预测API,检测输入图片中的所有中文文本的位置。
...@@ -48,6 +49,7 @@ def recognize_text(images=[], ...@@ -48,6 +49,7 @@ def recognize_text(images=[],
* use\_gpu (bool): 是否使用 GPU;**若使用GPU,请先设置CUDA_VISIBLE_DEVICES环境变量** * use\_gpu (bool): 是否使用 GPU;**若使用GPU,请先设置CUDA_VISIBLE_DEVICES环境变量**
* box\_thresh (float): 检测文本框置信度的阈值; * box\_thresh (float): 检测文本框置信度的阈值;
* text\_thresh (float): 识别中文文本置信度的阈值; * text\_thresh (float): 识别中文文本置信度的阈值;
* angle_classification_thresh(float): 文本角度分类置信度的阈值
* visualization (bool): 是否将识别结果保存为图片文件; * visualization (bool): 是否将识别结果保存为图片文件;
* output\_dir (str): 图片的保存路径,默认设为 ocr\_result; * output\_dir (str): 图片的保存路径,默认设为 ocr\_result;
......
...@@ -203,7 +203,8 @@ class ChineseOCRDBCRNN(hub.Module): ...@@ -203,7 +203,8 @@ class ChineseOCRDBCRNN(hub.Module):
output_dir='ocr_result', output_dir='ocr_result',
visualization=False, visualization=False,
box_thresh=0.5, box_thresh=0.5,
text_thresh=0.5): text_thresh=0.5,
angle_classification_thresh=0.9):
""" """
Get the chinese texts in the predicted images. Get the chinese texts in the predicted images.
Args: Args:
...@@ -214,7 +215,9 @@ class ChineseOCRDBCRNN(hub.Module): ...@@ -214,7 +215,9 @@ class ChineseOCRDBCRNN(hub.Module):
output_dir (str): The directory to store output images. output_dir (str): The directory to store output images.
visualization (bool): Whether to save image or not. visualization (bool): Whether to save image or not.
box_thresh(float): the threshold of the detected text box's confidence box_thresh(float): the threshold of the detected text box's confidence
text_thresh(float): the threshold of the recognize chinese texts' confidence text_thresh(float): the threshold of the chinese text recognition confidence
angle_classification_thresh(float): the threshold of the angle classification confidence
Returns: Returns:
res (list): The result of chinese texts and save path of images. res (list): The result of chinese texts and save path of images.
""" """
...@@ -259,7 +262,9 @@ class ChineseOCRDBCRNN(hub.Module): ...@@ -259,7 +262,9 @@ class ChineseOCRDBCRNN(hub.Module):
img_crop = self.get_rotate_crop_image( img_crop = self.get_rotate_crop_image(
original_image, tmp_box) original_image, tmp_box)
img_crop_list.append(img_crop) img_crop_list.append(img_crop)
img_crop_list, angle_list = self._classify_text(img_crop_list) img_crop_list, angle_list = self._classify_text(
img_crop_list,
angle_classification_thresh=angle_classification_thresh)
rec_results = self._recognize_text(img_crop_list) rec_results = self._recognize_text(img_crop_list)
# if the recognized text confidence score is lower than text_thresh, then drop it # if the recognized text confidence score is lower than text_thresh, then drop it
...@@ -294,12 +299,14 @@ class ChineseOCRDBCRNN(hub.Module): ...@@ -294,12 +299,14 @@ class ChineseOCRDBCRNN(hub.Module):
results = self.recognize_text(images_decode, **kwargs) results = self.recognize_text(images_decode, **kwargs)
return results return results
def save_result_image(self, def save_result_image(
original_image, self,
detection_boxes, original_image,
rec_results, detection_boxes,
output_dir='ocr_result', rec_results,
text_thresh=0.5): output_dir='ocr_result',
text_thresh=0.5,
):
image = Image.fromarray(cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB)) image = Image.fromarray(cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB))
txts = [item[0] for item in rec_results] txts = [item[0] for item in rec_results]
scores = [item[1] for item in rec_results] scores = [item[1] for item in rec_results]
...@@ -320,7 +327,7 @@ class ChineseOCRDBCRNN(hub.Module): ...@@ -320,7 +327,7 @@ class ChineseOCRDBCRNN(hub.Module):
cv2.imwrite(save_file_path, draw_img[:, :, ::-1]) cv2.imwrite(save_file_path, draw_img[:, :, ::-1])
return save_file_path return save_file_path
def _classify_text(self, image_list): def _classify_text(self, image_list, angle_classification_thresh=0.9):
img_list = copy.deepcopy(image_list) img_list = copy.deepcopy(image_list)
img_num = len(img_list) img_num = len(img_list)
# Calculate the aspect ratio of all text bars # Calculate the aspect ratio of all text bars
...@@ -360,7 +367,7 @@ class ChineseOCRDBCRNN(hub.Module): ...@@ -360,7 +367,7 @@ class ChineseOCRDBCRNN(hub.Module):
score = prob_out[rno][label_idx] score = prob_out[rno][label_idx]
label = label_list[label_idx] label = label_list[label_idx]
cls_res[indices[beg_img_no + rno]] = [label, score] cls_res[indices[beg_img_no + rno]] = [label, score]
if '180' in label and score > 0.9999: if '180' in label and score > angle_classification_thresh:
img_list[indices[beg_img_no + rno]] = cv2.rotate( img_list[indices[beg_img_no + rno]] = cv2.rotate(
img_list[indices[beg_img_no + rno]], 1) img_list[indices[beg_img_no + rno]], 1)
return img_list, cls_res return img_list, cls_res
......
...@@ -35,7 +35,8 @@ def recognize_text(images=[], ...@@ -35,7 +35,8 @@ def recognize_text(images=[],
output_dir='ocr_result', output_dir='ocr_result',
visualization=False, visualization=False,
box_thresh=0.5, box_thresh=0.5,
text_thresh=0.5) text_thresh=0.5,
angle_classification_thresh=0.9)
``` ```
预测API,检测输入图片中的所有中文文本的位置。 预测API,检测输入图片中的所有中文文本的位置。
...@@ -47,6 +48,7 @@ def recognize_text(images=[], ...@@ -47,6 +48,7 @@ def recognize_text(images=[],
* use\_gpu (bool): 是否使用 GPU;**若使用GPU,请先设置CUDA_VISIBLE_DEVICES环境变量** * use\_gpu (bool): 是否使用 GPU;**若使用GPU,请先设置CUDA_VISIBLE_DEVICES环境变量**
* box\_thresh (float): 检测文本框置信度的阈值; * box\_thresh (float): 检测文本框置信度的阈值;
* text\_thresh (float): 识别中文文本置信度的阈值; * text\_thresh (float): 识别中文文本置信度的阈值;
* angle_classification_thresh(float): 文本角度分类置信度的阈值
* visualization (bool): 是否将识别结果保存为图片文件; * visualization (bool): 是否将识别结果保存为图片文件;
* output\_dir (str): 图片的保存路径,默认设为 ocr\_result; * output\_dir (str): 图片的保存路径,默认设为 ocr\_result;
...@@ -141,3 +143,7 @@ pyclipper ...@@ -141,3 +143,7 @@ pyclipper
* 1.0.1 * 1.0.1
支持mkldnn加速CPU计算 支持mkldnn加速CPU计算
* 1.1.0
使用三阶段模型(文本框检测-角度分类-文字识别)识别图片文字。
...@@ -22,17 +22,23 @@ class CharacterOps(object): ...@@ -22,17 +22,23 @@ class CharacterOps(object):
def __init__(self, config): def __init__(self, config):
self.character_type = config['character_type'] self.character_type = config['character_type']
self.loss_type = config['loss_type'] self.loss_type = config['loss_type']
self.max_text_len = config['max_text_length']
if self.character_type == "en": if self.character_type == "en":
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz" self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
dict_character = list(self.character_str) dict_character = list(self.character_str)
elif self.character_type == "ch": elif self.character_type == "ch":
character_dict_path = config['character_dict_path'] character_dict_path = config['character_dict_path']
add_space = False
if 'use_space_char' in config:
add_space = config['use_space_char']
self.character_str = "" self.character_str = ""
with open(character_dict_path, "rb") as fin: with open(character_dict_path, "rb") as fin:
lines = fin.readlines() lines = fin.readlines()
for line in lines: for line in lines:
line = line.decode('utf-8').strip("\n") line = line.decode('utf-8').strip("\n").strip("\r\n")
self.character_str += line self.character_str += line
if add_space:
self.character_str += " "
dict_character = list(self.character_str) dict_character = list(self.character_str)
elif self.character_type == "en_sensitive": elif self.character_type == "en_sensitive":
# same with ASTER setting (use 94 char). # same with ASTER setting (use 94 char).
...@@ -46,6 +52,8 @@ class CharacterOps(object): ...@@ -46,6 +52,8 @@ class CharacterOps(object):
self.end_str = "eos" self.end_str = "eos"
if self.loss_type == "attention": if self.loss_type == "attention":
dict_character = [self.beg_str, self.end_str] + dict_character dict_character = [self.beg_str, self.end_str] + dict_character
elif self.loss_type == "srn":
dict_character = dict_character + [self.beg_str, self.end_str]
self.dict = {} self.dict = {}
for i, char in enumerate(dict_character): for i, char in enumerate(dict_character):
self.dict[char] = i self.dict[char] = i
...@@ -90,7 +98,7 @@ class CharacterOps(object): ...@@ -90,7 +98,7 @@ class CharacterOps(object):
if is_remove_duplicate: if is_remove_duplicate:
if idx > 0 and text_index[idx - 1] == text_index[idx]: if idx > 0 and text_index[idx - 1] == text_index[idx]:
continue continue
char_list.append(self.character[text_index[idx]]) char_list.append(self.character[int(text_index[idx])])
text = ''.join(char_list) text = ''.join(char_list)
return text return text
...@@ -139,6 +147,42 @@ def cal_predicts_accuracy(char_ops, ...@@ -139,6 +147,42 @@ def cal_predicts_accuracy(char_ops,
return acc, acc_num, img_num return acc, acc_num, img_num
def cal_predicts_accuracy_srn(char_ops,
preds,
labels,
max_text_len,
is_debug=False):
acc_num = 0
img_num = 0
char_num = char_ops.get_char_num()
total_len = preds.shape[0]
img_num = int(total_len / max_text_len)
for i in range(img_num):
cur_label = []
cur_pred = []
for j in range(max_text_len):
if labels[j + i * max_text_len] != int(char_num - 1): #0
cur_label.append(labels[j + i * max_text_len][0])
else:
break
for j in range(max_text_len + 1):
if j < len(cur_label) and preds[
j + i * max_text_len][0] != cur_label[j]:
break
elif j == len(cur_label) and j == max_text_len:
acc_num += 1
break
elif j == len(cur_label) and preds[j + i * max_text_len][0] == int(
char_num - 1):
acc_num += 1
break
acc = acc_num * 1.0 / img_num
return acc, acc_num, img_num
def convert_rec_attention_infer_res(preds): def convert_rec_attention_infer_res(preds):
img_num = preds.shape[0] img_num = preds.shape[0]
target_lod = [0] target_lod = [0]
......
...@@ -25,7 +25,7 @@ from chinese_ocr_db_crnn_server.utils import base64_to_cv2, draw_ocr, get_image_ ...@@ -25,7 +25,7 @@ from chinese_ocr_db_crnn_server.utils import base64_to_cv2, draw_ocr, get_image_
@moduleinfo( @moduleinfo(
name="chinese_ocr_db_crnn_server", name="chinese_ocr_db_crnn_server",
version="1.0.3", version="1.1.0",
summary= summary=
"The module can recognize the chinese texts in an image. Firstly, it will detect the text box positions based on the differentiable_binarization_chn module. Then it recognizes the chinese texts. ", "The module can recognize the chinese texts in an image. Firstly, it will detect the text box positions based on the differentiable_binarization_chn module. Then it recognizes the chinese texts. ",
author="paddle-dev", author="paddle-dev",
...@@ -41,24 +41,31 @@ class ChineseOCRDBCRNNServer(hub.Module): ...@@ -41,24 +41,31 @@ class ChineseOCRDBCRNNServer(hub.Module):
char_ops_params = { char_ops_params = {
'character_type': 'ch', 'character_type': 'ch',
'character_dict_path': self.character_dict_path, 'character_dict_path': self.character_dict_path,
'loss_type': 'ctc' 'loss_type': 'ctc',
'max_text_length': 25,
'use_space_char': True
} }
self.char_ops = CharacterOps(char_ops_params) self.char_ops = CharacterOps(char_ops_params)
self.rec_image_shape = [3, 32, 320] self.rec_image_shape = [3, 32, 320]
self._text_detector_module = text_detector_module self._text_detector_module = text_detector_module
self.font_file = os.path.join(self.directory, 'assets', 'simfang.ttf') self.font_file = os.path.join(self.directory, 'assets', 'simfang.ttf')
self.pretrained_model_path = os.path.join(self.directory, 'assets',
'ch_rec_r34_vd_crnn')
self.enable_mkldnn = enable_mkldnn self.enable_mkldnn = enable_mkldnn
self._set_config() self.rec_pretrained_model_path = os.path.join(
self.directory, 'inference_model', 'character_rec')
self.cls_pretrained_model_path = os.path.join(
self.directory, 'inference_model', 'angle_cls')
self.rec_predictor, self.rec_input_tensor, self.rec_output_tensors = self._set_config(
self.rec_pretrained_model_path)
self.cls_predictor, self.cls_input_tensor, self.cls_output_tensors = self._set_config(
self.cls_pretrained_model_path)
def _set_config(self): def _set_config(self, pretrained_model_path):
""" """
predictor config setting predictor config path
""" """
model_file_path = os.path.join(self.pretrained_model_path, 'model') model_file_path = os.path.join(pretrained_model_path, 'model')
params_file_path = os.path.join(self.pretrained_model_path, 'params') params_file_path = os.path.join(pretrained_model_path, 'params')
config = AnalysisConfig(model_file_path, params_file_path) config = AnalysisConfig(model_file_path, params_file_path)
try: try:
...@@ -73,21 +80,25 @@ class ChineseOCRDBCRNNServer(hub.Module): ...@@ -73,21 +80,25 @@ class ChineseOCRDBCRNNServer(hub.Module):
else: else:
config.disable_gpu() config.disable_gpu()
if self.enable_mkldnn: if self.enable_mkldnn:
# cache 10 different shapes for mkldnn to avoid memory leak
config.set_mkldnn_cache_capacity(10)
config.enable_mkldnn() config.enable_mkldnn()
config.disable_glog_info() config.disable_glog_info()
# use zero copy
config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass") config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
config.switch_use_feed_fetch_ops(False) config.switch_use_feed_fetch_ops(False)
self.predictor = create_paddle_predictor(config)
input_names = self.predictor.get_input_names() predictor = create_paddle_predictor(config)
self.input_tensor = self.predictor.get_input_tensor(input_names[0])
output_names = self.predictor.get_output_names() input_names = predictor.get_input_names()
self.output_tensors = [] input_tensor = predictor.get_input_tensor(input_names[0])
output_names = predictor.get_output_names()
output_tensors = []
for output_name in output_names: for output_name in output_names:
output_tensor = self.predictor.get_output_tensor(output_name) output_tensor = predictor.get_output_tensor(output_name)
self.output_tensors.append(output_tensor) output_tensors.append(output_tensor)
return predictor, input_tensor, output_tensors
@property @property
def text_detector_module(self): def text_detector_module(self):
...@@ -98,7 +109,7 @@ class ChineseOCRDBCRNNServer(hub.Module): ...@@ -98,7 +109,7 @@ class ChineseOCRDBCRNNServer(hub.Module):
self._text_detector_module = hub.Module( self._text_detector_module = hub.Module(
name='chinese_text_detection_db_server', name='chinese_text_detection_db_server',
enable_mkldnn=self.enable_mkldnn, enable_mkldnn=self.enable_mkldnn,
version='1.0.1') version='1.0.2')
return self._text_detector_module return self._text_detector_module
def read_images(self, paths=[]): def read_images(self, paths=[]):
...@@ -114,6 +125,7 @@ class ChineseOCRDBCRNNServer(hub.Module): ...@@ -114,6 +125,7 @@ class ChineseOCRDBCRNNServer(hub.Module):
return images return images
def get_rotate_crop_image(self, img, points): def get_rotate_crop_image(self, img, points):
'''
img_height, img_width = img.shape[0:2] img_height, img_width = img.shape[0:2]
left = int(np.min(points[:, 0])) left = int(np.min(points[:, 0]))
right = int(np.max(points[:, 0])) right = int(np.max(points[:, 0]))
...@@ -122,23 +134,51 @@ class ChineseOCRDBCRNNServer(hub.Module): ...@@ -122,23 +134,51 @@ class ChineseOCRDBCRNNServer(hub.Module):
img_crop = img[top:bottom, left:right, :].copy() img_crop = img[top:bottom, left:right, :].copy()
points[:, 0] = points[:, 0] - left points[:, 0] = points[:, 0] - left
points[:, 1] = points[:, 1] - top points[:, 1] = points[:, 1] - top
img_crop_width = int(np.linalg.norm(points[0] - points[1])) '''
img_crop_height = int(np.linalg.norm(points[0] - points[3])) img_crop_width = int(
pts_std = np.float32([[0, 0], [img_crop_width, 0],\ max(
[img_crop_width, img_crop_height], [0, img_crop_height]]) np.linalg.norm(points[0] - points[1]),
np.linalg.norm(points[2] - points[3])))
img_crop_height = int(
max(
np.linalg.norm(points[0] - points[3]),
np.linalg.norm(points[1] - points[2])))
pts_std = np.float32([[0, 0], [img_crop_width, 0],
[img_crop_width, img_crop_height],
[0, img_crop_height]])
M = cv2.getPerspectiveTransform(points, pts_std) M = cv2.getPerspectiveTransform(points, pts_std)
dst_img = cv2.warpPerspective( dst_img = cv2.warpPerspective(
img_crop, img,
M, (img_crop_width, img_crop_height), M, (img_crop_width, img_crop_height),
borderMode=cv2.BORDER_REPLICATE) borderMode=cv2.BORDER_REPLICATE,
flags=cv2.INTER_CUBIC)
dst_img_height, dst_img_width = dst_img.shape[0:2] dst_img_height, dst_img_width = dst_img.shape[0:2]
if dst_img_height * 1.0 / dst_img_width >= 1.5: if dst_img_height * 1.0 / dst_img_width >= 1.5:
dst_img = np.rot90(dst_img) dst_img = np.rot90(dst_img)
return dst_img return dst_img
def resize_norm_img(self, img, max_wh_ratio): def resize_norm_img_rec(self, img, max_wh_ratio):
imgC, imgH, imgW = self.rec_image_shape imgC, imgH, imgW = self.rec_image_shape
imgW = int(32 * max_wh_ratio) assert imgC == img.shape[2]
imgW = int((32 * max_wh_ratio))
h, w = img.shape[:2]
ratio = w / float(h)
if math.ceil(imgH * ratio) > imgW:
resized_w = imgW
else:
resized_w = int(math.ceil(imgH * ratio))
resized_image = cv2.resize(img, (resized_w, imgH))
resized_image = resized_image.astype('float32')
resized_image = resized_image.transpose((2, 0, 1)) / 255
resized_image -= 0.5
resized_image /= 0.5
padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
padding_im[:, :, 0:resized_w] = resized_image
return padding_im
def resize_norm_img_cls(self, img):
cls_image_shape = [3, 48, 192]
imgC, imgH, imgW = cls_image_shape
h = img.shape[0] h = img.shape[0]
w = img.shape[1] w = img.shape[1]
ratio = w / float(h) ratio = w / float(h)
...@@ -148,7 +188,11 @@ class ChineseOCRDBCRNNServer(hub.Module): ...@@ -148,7 +188,11 @@ class ChineseOCRDBCRNNServer(hub.Module):
resized_w = int(math.ceil(imgH * ratio)) resized_w = int(math.ceil(imgH * ratio))
resized_image = cv2.resize(img, (resized_w, imgH)) resized_image = cv2.resize(img, (resized_w, imgH))
resized_image = resized_image.astype('float32') resized_image = resized_image.astype('float32')
resized_image = resized_image.transpose((2, 0, 1)) / 255 if cls_image_shape[0] == 1:
resized_image = resized_image / 255
resized_image = resized_image[np.newaxis, :]
else:
resized_image = resized_image.transpose((2, 0, 1)) / 255
resized_image -= 0.5 resized_image -= 0.5
resized_image /= 0.5 resized_image /= 0.5
padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32) padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
...@@ -162,7 +206,8 @@ class ChineseOCRDBCRNNServer(hub.Module): ...@@ -162,7 +206,8 @@ class ChineseOCRDBCRNNServer(hub.Module):
output_dir='ocr_result', output_dir='ocr_result',
visualization=False, visualization=False,
box_thresh=0.5, box_thresh=0.5,
text_thresh=0.5): text_thresh=0.5,
angle_classification_thresh=0.9):
""" """
Get the chinese texts in the predicted images. Get the chinese texts in the predicted images.
Args: Args:
...@@ -173,7 +218,9 @@ class ChineseOCRDBCRNNServer(hub.Module): ...@@ -173,7 +218,9 @@ class ChineseOCRDBCRNNServer(hub.Module):
output_dir (str): The directory to store output images. output_dir (str): The directory to store output images.
visualization (bool): Whether to save image or not. visualization (bool): Whether to save image or not.
box_thresh(float): the threshold of the detected text box's confidence box_thresh(float): the threshold of the detected text box's confidence
text_thresh(float): the threshold of the recognize chinese texts' confidence text_thresh(float): the threshold of the chinese text recognition confidence
angle_classification_thresh(float): the threshold of the angle classification confidence
Returns: Returns:
res (list): The result of chinese texts and save path of images. res (list): The result of chinese texts and save path of images.
""" """
...@@ -199,6 +246,7 @@ class ChineseOCRDBCRNNServer(hub.Module): ...@@ -199,6 +246,7 @@ class ChineseOCRDBCRNNServer(hub.Module):
detection_results = self.text_detector_module.detect_text( detection_results = self.text_detector_module.detect_text(
images=predicted_data, use_gpu=self.use_gpu, box_thresh=box_thresh) images=predicted_data, use_gpu=self.use_gpu, box_thresh=box_thresh)
boxes = [ boxes = [
np.array(item['data']).astype(np.float32) np.array(item['data']).astype(np.float32)
for item in detection_results for item in detection_results
...@@ -207,7 +255,7 @@ class ChineseOCRDBCRNNServer(hub.Module): ...@@ -207,7 +255,7 @@ class ChineseOCRDBCRNNServer(hub.Module):
for index, img_boxes in enumerate(boxes): for index, img_boxes in enumerate(boxes):
original_image = predicted_data[index].copy() original_image = predicted_data[index].copy()
result = {'save_path': ''} result = {'save_path': ''}
if img_boxes is None: if img_boxes.size == 0:
result['data'] = [] result['data'] = []
else: else:
img_crop_list = [] img_crop_list = []
...@@ -217,8 +265,11 @@ class ChineseOCRDBCRNNServer(hub.Module): ...@@ -217,8 +265,11 @@ class ChineseOCRDBCRNNServer(hub.Module):
img_crop = self.get_rotate_crop_image( img_crop = self.get_rotate_crop_image(
original_image, tmp_box) original_image, tmp_box)
img_crop_list.append(img_crop) img_crop_list.append(img_crop)
img_crop_list, angle_list = self._classify_text(
img_crop_list,
angle_classification_thresh=angle_classification_thresh)
rec_results = self._recognize_text(img_crop_list) rec_results = self._recognize_text(img_crop_list)
# if the recognized text confidence score is lower than text_thresh, then drop it # if the recognized text confidence score is lower than text_thresh, then drop it
rec_res_final = [] rec_res_final = []
for index, res in enumerate(rec_results): for index, res in enumerate(rec_results):
...@@ -251,12 +302,14 @@ class ChineseOCRDBCRNNServer(hub.Module): ...@@ -251,12 +302,14 @@ class ChineseOCRDBCRNNServer(hub.Module):
results = self.recognize_text(images_decode, **kwargs) results = self.recognize_text(images_decode, **kwargs)
return results return results
def save_result_image(self, def save_result_image(
original_image, self,
detection_boxes, original_image,
rec_results, detection_boxes,
output_dir='ocr_result', rec_results,
text_thresh=0.5): output_dir='ocr_result',
text_thresh=0.5,
):
image = Image.fromarray(cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB)) image = Image.fromarray(cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB))
txts = [item[0] for item in rec_results] txts = [item[0] for item in rec_results]
scores = [item[1] for item in rec_results] scores = [item[1] for item in rec_results]
...@@ -277,32 +330,86 @@ class ChineseOCRDBCRNNServer(hub.Module): ...@@ -277,32 +330,86 @@ class ChineseOCRDBCRNNServer(hub.Module):
cv2.imwrite(save_file_path, draw_img[:, :, ::-1]) cv2.imwrite(save_file_path, draw_img[:, :, ::-1])
return save_file_path return save_file_path
def _recognize_text(self, image_list): def _classify_text(self, image_list, angle_classification_thresh=0.9):
img_num = len(image_list) img_list = copy.deepcopy(image_list)
img_num = len(img_list)
# Calculate the aspect ratio of all text bars
width_list = []
for img in img_list:
width_list.append(img.shape[1] / float(img.shape[0]))
# Sorting can speed up the cls process
indices = np.argsort(np.array(width_list))
cls_res = [['', 0.0]] * img_num
batch_num = 30 batch_num = 30
rec_res = []
predict_time = 0
for beg_img_no in range(0, img_num, batch_num): for beg_img_no in range(0, img_num, batch_num):
end_img_no = min(img_num, beg_img_no + batch_num) end_img_no = min(img_num, beg_img_no + batch_num)
norm_img_batch = [] norm_img_batch = []
max_wh_ratio = 0 max_wh_ratio = 0
for ino in range(beg_img_no, end_img_no): for ino in range(beg_img_no, end_img_no):
h, w = image_list[ino].shape[0:2] h, w = img_list[indices[ino]].shape[0:2]
wh_ratio = w / h wh_ratio = w * 1.0 / h
max_wh_ratio = max(max_wh_ratio, wh_ratio) max_wh_ratio = max(max_wh_ratio, wh_ratio)
for ino in range(beg_img_no, end_img_no): for ino in range(beg_img_no, end_img_no):
norm_img = self.resize_norm_img(image_list[ino], max_wh_ratio) norm_img = self.resize_norm_img_cls(img_list[indices[ino]])
norm_img = norm_img[np.newaxis, :] norm_img = norm_img[np.newaxis, :]
norm_img_batch.append(norm_img) norm_img_batch.append(norm_img)
norm_img_batch = np.concatenate(norm_img_batch) norm_img_batch = np.concatenate(norm_img_batch)
norm_img_batch = norm_img_batch.copy() norm_img_batch = norm_img_batch.copy()
self.input_tensor.copy_from_cpu(norm_img_batch)
self.predictor.zero_copy_run()
rec_idx_batch = self.output_tensors[0].copy_to_cpu()
rec_idx_lod = self.output_tensors[0].lod()[0]
predict_batch = self.output_tensors[1].copy_to_cpu()
predict_lod = self.output_tensors[1].lod()[0]
self.cls_input_tensor.copy_from_cpu(norm_img_batch)
self.cls_predictor.zero_copy_run()
prob_out = self.cls_output_tensors[0].copy_to_cpu()
label_out = self.cls_output_tensors[1].copy_to_cpu()
if len(label_out.shape) != 1:
prob_out, label_out = label_out, prob_out
label_list = ['0', '180']
for rno in range(len(label_out)):
label_idx = label_out[rno]
score = prob_out[rno][label_idx]
label = label_list[label_idx]
cls_res[indices[beg_img_no + rno]] = [label, score]
if '180' in label and score > angle_classification_thresh:
img_list[indices[beg_img_no + rno]] = cv2.rotate(
img_list[indices[beg_img_no + rno]], 1)
return img_list, cls_res
def _recognize_text(self, img_list):
img_num = len(img_list)
# Calculate the aspect ratio of all text bars
width_list = []
for img in img_list:
width_list.append(img.shape[1] / float(img.shape[0]))
# Sorting can speed up the recognition process
indices = np.argsort(np.array(width_list))
rec_res = [['', 0.0]] * img_num
batch_num = 30
for beg_img_no in range(0, img_num, batch_num):
end_img_no = min(img_num, beg_img_no + batch_num)
norm_img_batch = []
max_wh_ratio = 0
for ino in range(beg_img_no, end_img_no):
h, w = img_list[indices[ino]].shape[0:2]
wh_ratio = w * 1.0 / h
max_wh_ratio = max(max_wh_ratio, wh_ratio)
for ino in range(beg_img_no, end_img_no):
norm_img = self.resize_norm_img_rec(img_list[indices[ino]],
max_wh_ratio)
norm_img = norm_img[np.newaxis, :]
norm_img_batch.append(norm_img)
norm_img_batch = np.concatenate(norm_img_batch, axis=0)
norm_img_batch = norm_img_batch.copy()
self.rec_input_tensor.copy_from_cpu(norm_img_batch)
self.rec_predictor.zero_copy_run()
rec_idx_batch = self.rec_output_tensors[0].copy_to_cpu()
rec_idx_lod = self.rec_output_tensors[0].lod()[0]
predict_batch = self.rec_output_tensors[1].copy_to_cpu()
predict_lod = self.rec_output_tensors[1].lod()[0]
for rno in range(len(rec_idx_lod) - 1): for rno in range(len(rec_idx_lod) - 1):
beg = rec_idx_lod[rno] beg = rec_idx_lod[rno]
end = rec_idx_lod[rno + 1] end = rec_idx_lod[rno + 1]
...@@ -317,9 +424,10 @@ class ChineseOCRDBCRNNServer(hub.Module): ...@@ -317,9 +424,10 @@ class ChineseOCRDBCRNNServer(hub.Module):
if len(valid_ind) == 0: if len(valid_ind) == 0:
continue continue
score = np.mean(probs[valid_ind, ind[valid_ind]]) score = np.mean(probs[valid_ind, ind[valid_ind]])
rec_res.append([preds_text, score]) # rec_res.append([preds_text, score])
rec_res[indices[beg_img_no + rno]] = [preds_text, score]
return rec_res return rec_res
def save_inference_model(self, def save_inference_model(self,
dirname, dirname,
...@@ -327,9 +435,12 @@ class ChineseOCRDBCRNNServer(hub.Module): ...@@ -327,9 +435,12 @@ class ChineseOCRDBCRNNServer(hub.Module):
params_filename=None, params_filename=None,
combined=True): combined=True):
detector_dir = os.path.join(dirname, 'text_detector') detector_dir = os.path.join(dirname, 'text_detector')
classifier_dir = os.path.join(dirname, 'angle_classifier')
recognizer_dir = os.path.join(dirname, 'text_recognizer') recognizer_dir = os.path.join(dirname, 'text_recognizer')
self._save_detector_model(detector_dir, model_filename, params_filename, self._save_detector_model(detector_dir, model_filename, params_filename,
combined) combined)
self._save_classifier_model(classifier_dir, model_filename,
params_filename, combined)
self._save_recognizer_model(recognizer_dir, model_filename, self._save_recognizer_model(recognizer_dir, model_filename,
params_filename, combined) params_filename, combined)
logger.info("The inference model has been saved in the path {}".format( logger.info("The inference model has been saved in the path {}".format(
...@@ -354,10 +465,40 @@ class ChineseOCRDBCRNNServer(hub.Module): ...@@ -354,10 +465,40 @@ class ChineseOCRDBCRNNServer(hub.Module):
place = fluid.CPUPlace() place = fluid.CPUPlace()
exe = fluid.Executor(place) exe = fluid.Executor(place)
model_file_path = os.path.join(self.pretrained_model_path, 'model') model_file_path = os.path.join(self.rec_pretrained_model_path, 'model')
params_file_path = os.path.join(self.pretrained_model_path, 'params') params_file_path = os.path.join(self.rec_pretrained_model_path,
'params')
program, feeded_var_names, target_vars = fluid.io.load_inference_model(
dirname=self.rec_pretrained_model_path,
model_filename=model_file_path,
params_filename=params_file_path,
executor=exe)
fluid.io.save_inference_model(
dirname=dirname,
main_program=program,
executor=exe,
feeded_var_names=feeded_var_names,
target_vars=target_vars,
model_filename=model_filename,
params_filename=params_filename)
def _save_classifier_model(self,
dirname,
model_filename=None,
params_filename=None,
combined=True):
if combined:
model_filename = "__model__" if not model_filename else model_filename
params_filename = "__params__" if not params_filename else params_filename
place = fluid.CPUPlace()
exe = fluid.Executor(place)
model_file_path = os.path.join(self.cls_pretrained_model_path, 'model')
params_file_path = os.path.join(self.cls_pretrained_model_path,
'params')
program, feeded_var_names, target_vars = fluid.io.load_inference_model( program, feeded_var_names, target_vars = fluid.io.load_inference_model(
dirname=self.pretrained_model_path, dirname=self.cls_pretrained_model_path,
model_filename=model_file_path, model_filename=model_file_path,
params_filename=params_file_path, params_filename=params_file_path,
executor=exe) executor=exe)
...@@ -429,8 +570,7 @@ class ChineseOCRDBCRNNServer(hub.Module): ...@@ -429,8 +570,7 @@ class ChineseOCRDBCRNNServer(hub.Module):
if __name__ == '__main__': if __name__ == '__main__':
ocr = ChineseOCRDBCRNNServer(enable_mkldnn=True) ocr = ChineseOCRDBCRNNServer(enable_mkldnn=False)
print(ocr.name)
image_path = [ image_path = [
'/mnt/zhangxuefei/PaddleOCR/doc/imgs/11.jpg', '/mnt/zhangxuefei/PaddleOCR/doc/imgs/11.jpg',
'/mnt/zhangxuefei/PaddleOCR/doc/imgs/12.jpg', '/mnt/zhangxuefei/PaddleOCR/doc/imgs/12.jpg',
......
...@@ -175,8 +175,8 @@ def sorted_boxes(dt_boxes): ...@@ -175,8 +175,8 @@ def sorted_boxes(dt_boxes):
_boxes = list(sorted_boxes) _boxes = list(sorted_boxes)
for i in range(num_boxes - 1): for i in range(num_boxes - 1):
if abs(_boxes[i+1][0][1] - _boxes[i][0][1]) < 10 and \ if abs(_boxes[i + 1][0][1] - _boxes[i][0][1]) < 10 and \
(_boxes[i + 1][0][0] < _boxes[i][0][0]): (_boxes[i + 1][0][0] < _boxes[i][0][0]):
tmp = _boxes[i] tmp = _boxes[i]
_boxes[i] = _boxes[i + 1] _boxes[i] = _boxes[i + 1]
_boxes[i + 1] = tmp _boxes[i + 1] = tmp
......
...@@ -132,3 +132,15 @@ pyclipper ...@@ -132,3 +132,15 @@ pyclipper
* 1.0.1 * 1.0.1
修复使用在线服务调用模型失败问题 修复使用在线服务调用模型失败问题
* 1.0.2
支持mkldnn加速CPU计算
* 1.0.3
增加更多预训练数据,更新预训练参数
1.1.0
使用超轻量级的三阶段模型(文本框检测-角度分类-文字识别)识别图片文字。
...@@ -129,3 +129,7 @@ pyclipper ...@@ -129,3 +129,7 @@ pyclipper
* 1.0.2 * 1.0.2
支持mkldnn加速CPU计算 支持mkldnn加速CPU计算
* 1.0.3
增加更多预训练数据,更新预训练参数
...@@ -29,7 +29,7 @@ def base64_to_cv2(b64str): ...@@ -29,7 +29,7 @@ def base64_to_cv2(b64str):
@moduleinfo( @moduleinfo(
name="chinese_text_detection_db_server", name="chinese_text_detection_db_server",
version="1.0.1", version="1.0.2",
summary= summary=
"The module aims to detect chinese text position in the image, which is based on differentiable_binarization algorithm.", "The module aims to detect chinese text position in the image, which is based on differentiable_binarization algorithm.",
author="paddle-dev", author="paddle-dev",
...@@ -41,7 +41,7 @@ class ChineseTextDetectionDBServer(hub.Module): ...@@ -41,7 +41,7 @@ class ChineseTextDetectionDBServer(hub.Module):
initialize with the necessary elements initialize with the necessary elements
""" """
self.pretrained_model_path = os.path.join(self.directory, self.pretrained_model_path = os.path.join(self.directory,
'ch_det_r50_vd_db') 'inference_model')
self.enable_mkldnn = enable_mkldnn self.enable_mkldnn = enable_mkldnn
self._set_config() self._set_config()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册