提交 8f5e5177 编写于 作者: C chenjian

fix

上级 9b3119df
...@@ -34,14 +34,14 @@ def base64_to_cv2(b64str): ...@@ -34,14 +34,14 @@ def base64_to_cv2(b64str):
@moduleinfo( @moduleinfo(
name="ppocrv3_det_ch", name="ch_pp-ocrv3_det",
version="1.0.0", version="1.0.0",
summary= summary=
"The module aims to detect chinese text position in the image, which is based on differentiable_binarization algorithm.", "The module aims to detect chinese text position in the image, which is based on differentiable_binarization algorithm.",
author="paddle-dev", author="paddle-dev",
author_email="paddle-dev@baidu.com", author_email="paddle-dev@baidu.com",
type="cv/text_recognition") type="cv/text_recognition")
class ChineseTextDetectionDB(hub.Module): class ChPPOCRv3Det(hub.Module):
def _initialize(self, enable_mkldnn=False): def _initialize(self, enable_mkldnn=False):
""" """
...@@ -155,7 +155,8 @@ class ChineseTextDetectionDB(hub.Module): ...@@ -155,7 +155,8 @@ class ChineseTextDetectionDB(hub.Module):
use_gpu=False, use_gpu=False,
output_dir='detection_result', output_dir='detection_result',
visualization=False, visualization=False,
box_thresh=0.5): box_thresh=0.5,
det_db_unclip_ratio=1.5):
""" """
Get the text box in the predicted images. Get the text box in the predicted images.
Args: Args:
...@@ -165,6 +166,7 @@ class ChineseTextDetectionDB(hub.Module): ...@@ -165,6 +166,7 @@ class ChineseTextDetectionDB(hub.Module):
output_dir (str): The directory to store output images. output_dir (str): The directory to store output images.
visualization (bool): Whether to save image or not. visualization (bool): Whether to save image or not.
box_thresh(float): the threshold of the detected text box's confidence box_thresh(float): the threshold of the detected text box's confidence
det_db_unclip_ratio(float): unclip ratio for post processing in DB detection.
Returns: Returns:
res (list): The result of text detection box and save path of images. res (list): The result of text detection box and save path of images.
""" """
...@@ -195,7 +197,7 @@ class ChineseTextDetectionDB(hub.Module): ...@@ -195,7 +197,7 @@ class ChineseTextDetectionDB(hub.Module):
'thresh': 0.3, 'thresh': 0.3,
'box_thresh': 0.6, 'box_thresh': 0.6,
'max_candidates': 1000, 'max_candidates': 1000,
'unclip_ratio': 1.5 'unclip_ratio': det_db_unclip_ratio
}) })
all_imgs = [] all_imgs = []
...@@ -204,7 +206,6 @@ class ChineseTextDetectionDB(hub.Module): ...@@ -204,7 +206,6 @@ class ChineseTextDetectionDB(hub.Module):
for original_image in predicted_data: for original_image in predicted_data:
ori_im = original_image.copy() ori_im = original_image.copy()
im, ratio_list = preprocessor(original_image) im, ratio_list = preprocessor(original_image)
print('after preprocess int det, shape{}'.format(im.shape))
res = {'save_path': ''} res = {'save_path': ''}
if im is None: if im is None:
res['data'] = [] res['data'] = []
...@@ -222,15 +223,10 @@ class ChineseTextDetectionDB(hub.Module): ...@@ -222,15 +223,10 @@ class ChineseTextDetectionDB(hub.Module):
outs_dict = {} outs_dict = {}
outs_dict['maps'] = outputs[0] outs_dict['maps'] = outputs[0]
# data_out = self.output_tensors[0].copy_to_cpu()
print('Outputs[0] in det, shape: {}'.format(outputs[0].shape))
dt_boxes_list = postprocessor(outs_dict, [ratio_list]) dt_boxes_list = postprocessor(outs_dict, [ratio_list])
dt_boxes = dt_boxes_list[0] dt_boxes = dt_boxes_list[0]
print('after postprocess int det, shape{}'.format(dt_boxes.shape))
boxes = self.filter_tag_det_res(dt_boxes_list[0], original_image.shape) boxes = self.filter_tag_det_res(dt_boxes_list[0], original_image.shape)
print('after fitler tag int det, shape{}'.format(boxes.shape))
res['data'] = boxes.astype(np.int).tolist() res['data'] = boxes.astype(np.int).tolist()
print('boxes: {}'.format(boxes))
all_imgs.append(im) all_imgs.append(im)
all_ratios.append(ratio_list) all_ratios.append(ratio_list)
if visualization: if visualization:
...@@ -278,6 +274,7 @@ class ChineseTextDetectionDB(hub.Module): ...@@ -278,6 +274,7 @@ class ChineseTextDetectionDB(hub.Module):
results = self.detect_text(paths=[args.input_path], results = self.detect_text(paths=[args.input_path],
use_gpu=args.use_gpu, use_gpu=args.use_gpu,
output_dir=args.output_dir, output_dir=args.output_dir,
det_db_unclip_ratio=args.det_db_unclip_ratio,
visualization=args.visualization) visualization=args.visualization)
return results return results
...@@ -297,6 +294,10 @@ class ChineseTextDetectionDB(hub.Module): ...@@ -297,6 +294,10 @@ class ChineseTextDetectionDB(hub.Module):
type=ast.literal_eval, type=ast.literal_eval,
default=False, default=False,
help="whether to save output as images.") help="whether to save output as images.")
self.arg_config_group.add_argument('--det_db_unclip_ratio',
type=float,
default=1.5,
help="unclip ratio for post processing in DB detection.")
def add_module_input_arg(self): def add_module_input_arg(self):
""" """
......
...@@ -25,7 +25,6 @@ class DBProcessTest(object): ...@@ -25,7 +25,6 @@ class DBProcessTest(object):
self.resize_type = 0 self.resize_type = 0
if 'test_image_shape' in params: if 'test_image_shape' in params:
self.image_shape = params['test_image_shape'] self.image_shape = params['test_image_shape']
# print(self.image_shape)
self.resize_type = 1 self.resize_type = 1
if 'max_side_len' in params: if 'max_side_len' in params:
self.max_side_len = params['max_side_len'] self.max_side_len = params['max_side_len']
...@@ -54,15 +53,14 @@ class DBProcessTest(object): ...@@ -54,15 +53,14 @@ class DBProcessTest(object):
resize_h = int(h * ratio) resize_h = int(h * ratio)
resize_w = int(w * ratio) resize_w = int(w * ratio)
resize_h = int(round(resize_h / 32) * 32) resize_h = max(int(round(resize_h / 32) * 32), 32)
resize_w = int(round(resize_w / 32) * 32) resize_w = max(int(round(resize_w / 32) * 32), 32)
try: try:
if int(resize_w) <= 0 or int(resize_h) <= 0: if int(resize_w) <= 0 or int(resize_h) <= 0:
return None, (None, None) return None, (None, None)
img = cv2.resize(img, (int(resize_w), int(resize_h))) img = cv2.resize(img, (int(resize_w), int(resize_h)))
except: except:
print(img.shape, resize_w, resize_h)
sys.exit(0) sys.exit(0)
ratio_h = resize_h / float(h) ratio_h = resize_h / float(h)
ratio_w = resize_w / float(w) ratio_w = resize_w / float(w)
...@@ -93,13 +91,14 @@ class DBProcessTest(object): ...@@ -93,13 +91,14 @@ class DBProcessTest(object):
return im return im
def __call__(self, im): def __call__(self, im):
src_h, src_w, _ = im.shape
if self.resize_type == 0: if self.resize_type == 0:
im, (ratio_h, ratio_w) = self.resize_image_type0(im) im, (ratio_h, ratio_w) = self.resize_image_type0(im)
else: else:
im, (ratio_h, ratio_w) = self.resize_image_type1(im) im, (ratio_h, ratio_w) = self.resize_image_type1(im)
im = self.normalize(im) im = self.normalize(im)
im = im[np.newaxis, :] im = im[np.newaxis, :]
return [im, (ratio_h, ratio_w)] return [im, (src_h, src_w, ratio_h, ratio_w)]
class DBPostProcess(object): class DBPostProcess(object):
...@@ -228,7 +227,7 @@ class DBPostProcess(object): ...@@ -228,7 +227,7 @@ class DBPostProcess(object):
cv2.fillPoly(mask, contour.reshape(1, -1, 2).astype(np.int32), 1) cv2.fillPoly(mask, contour.reshape(1, -1, 2).astype(np.int32), 1)
return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0] return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
def __call__(self, outs_dict, ratio_list): def __call__(self, outs_dict, shape_list):
pred = outs_dict['maps'] pred = outs_dict['maps']
pred = pred[:, 0, :, :] pred = pred[:, 0, :, :]
...@@ -236,10 +235,10 @@ class DBPostProcess(object): ...@@ -236,10 +235,10 @@ class DBPostProcess(object):
boxes_batch = [] boxes_batch = []
for batch_index in range(pred.shape[0]): for batch_index in range(pred.shape[0]):
height, width = pred.shape[-2:] src_h, src_w, ratio_h, ratio_w = shape_list[batch_index]
mask = segmentation[batch_index] mask = segmentation[batch_index]
tmp_boxes, tmp_scores = self.boxes_from_bitmap(pred[batch_index], mask, width, height) tmp_boxes, tmp_scores = self.boxes_from_bitmap(pred[batch_index], mask, src_w, src_h)
boxes_batch.append(tmp_boxes) boxes_batch.append(tmp_boxes)
return boxes_batch return boxes_batch
......
...@@ -59,6 +59,7 @@ class CharacterOps(object): ...@@ -59,6 +59,7 @@ class CharacterOps(object):
self.character = dict_character self.character = dict_character
def add_special_char(self, dict_character): def add_special_char(self, dict_character):
dict_character = ['blank'] + dict_character
return dict_character return dict_character
def encode(self, text): def encode(self, text):
...@@ -93,12 +94,6 @@ class CharacterOps(object): ...@@ -93,12 +94,6 @@ class CharacterOps(object):
selection[1:] = text_index[batch_idx][1:] != text_index[batch_idx][:-1] selection[1:] = text_index[batch_idx][1:] != text_index[batch_idx][:-1]
for ignored_token in ignored_tokens: for ignored_token in ignored_tokens:
selection &= text_index[batch_idx] != ignored_token selection &= text_index[batch_idx] != ignored_token
# print(text_index)
# print(batch_idx)
# print(selection)
# for text_id in text_index[batch_idx][selection]:
# print(text_id)
# print(self.character[text_id])
char_list = [self.character[text_id] for text_id in text_index[batch_idx][selection]] char_list = [self.character[text_id] for text_id in text_index[batch_idx][selection]]
if text_prob is not None: if text_prob is not None:
conf_list = text_prob[batch_idx][selection] conf_list = text_prob[batch_idx][selection]
......
...@@ -29,14 +29,14 @@ from paddlehub.module.module import serving ...@@ -29,14 +29,14 @@ from paddlehub.module.module import serving
@moduleinfo( @moduleinfo(
name="ppocrv3_rec_ch", name="ch_pp-ocrv3",
version="1.0.0", version="1.0.0",
summary="The module can recognize the chinese texts in an image. Firstly, it will detect the text box positions \ summary="The module can recognize the chinese texts in an image. Firstly, it will detect the text box positions \
based on the differentiable_binarization_chn module. Then it classifies the text angle and recognizes the chinese texts. ", based on the differentiable_binarization_chn module. Then it classifies the text angle and recognizes the chinese texts. ",
author="paddle-dev", author="paddle-dev",
author_email="paddle-dev@baidu.com", author_email="paddle-dev@baidu.com",
type="cv/text_recognition") type="cv/text_recognition")
class ChineseOCRDBCRNN(hub.Module): class ChPPOCRv3(hub.Module):
def _initialize(self, text_detector_module=None, enable_mkldnn=False): def _initialize(self, text_detector_module=None, enable_mkldnn=False):
""" """
...@@ -51,7 +51,7 @@ class ChineseOCRDBCRNN(hub.Module): ...@@ -51,7 +51,7 @@ class ChineseOCRDBCRNN(hub.Module):
'use_space_char': True 'use_space_char': True
} }
self.char_ops = CharacterOps(char_ops_params) self.char_ops = CharacterOps(char_ops_params)
self.rec_image_shape = [3, 32, 320] self.rec_image_shape = [3, 48, 320]
self._text_detector_module = text_detector_module self._text_detector_module = text_detector_module
self.font_file = os.path.join(self.directory, 'assets', 'simfang.ttf') self.font_file = os.path.join(self.directory, 'assets', 'simfang.ttf')
self.enable_mkldnn = enable_mkldnn self.enable_mkldnn = enable_mkldnn
...@@ -109,7 +109,7 @@ class ChineseOCRDBCRNN(hub.Module): ...@@ -109,7 +109,7 @@ class ChineseOCRDBCRNN(hub.Module):
text detect module text detect module
""" """
if not self._text_detector_module: if not self._text_detector_module:
self._text_detector_module = hub.Module(name='ppocrv3_det_ch', self._text_detector_module = hub.Module(name='ch_pp-ocrv3_det',
enable_mkldnn=self.enable_mkldnn, enable_mkldnn=self.enable_mkldnn,
version='1.0.0') version='1.0.0')
return self._text_detector_module return self._text_detector_module
...@@ -152,7 +152,7 @@ class ChineseOCRDBCRNN(hub.Module): ...@@ -152,7 +152,7 @@ class ChineseOCRDBCRNN(hub.Module):
def resize_norm_img_rec(self, img, max_wh_ratio): def resize_norm_img_rec(self, img, max_wh_ratio):
imgC, imgH, imgW = self.rec_image_shape imgC, imgH, imgW = self.rec_image_shape
assert imgC == img.shape[2] assert imgC == img.shape[2]
imgW = int((32 * max_wh_ratio)) imgW = int((imgH * max_wh_ratio))
h, w = img.shape[:2] h, w = img.shape[:2]
ratio = w / float(h) ratio = w / float(h)
if math.ceil(imgH * ratio) > imgW: if math.ceil(imgH * ratio) > imgW:
...@@ -199,7 +199,8 @@ class ChineseOCRDBCRNN(hub.Module): ...@@ -199,7 +199,8 @@ class ChineseOCRDBCRNN(hub.Module):
visualization=False, visualization=False,
box_thresh=0.5, box_thresh=0.5,
text_thresh=0.5, text_thresh=0.5,
angle_classification_thresh=0.9): angle_classification_thresh=0.9,
det_db_unclip_ratio=1.5):
""" """
Get the chinese texts in the predicted images. Get the chinese texts in the predicted images.
Args: Args:
...@@ -212,7 +213,7 @@ class ChineseOCRDBCRNN(hub.Module): ...@@ -212,7 +213,7 @@ class ChineseOCRDBCRNN(hub.Module):
box_thresh(float): the threshold of the detected text box's confidence box_thresh(float): the threshold of the detected text box's confidence
text_thresh(float): the threshold of the chinese text recognition confidence text_thresh(float): the threshold of the chinese text recognition confidence
angle_classification_thresh(float): the threshold of the angle classification confidence angle_classification_thresh(float): the threshold of the angle classification confidence
det_db_unclip_ratio(float): unclip ratio for post processing in DB detection.
Returns: Returns:
res (list): The result of chinese texts and save path of images. res (list): The result of chinese texts and save path of images.
""" """
...@@ -238,10 +239,10 @@ class ChineseOCRDBCRNN(hub.Module): ...@@ -238,10 +239,10 @@ class ChineseOCRDBCRNN(hub.Module):
detection_results = self.text_detector_module.detect_text(images=predicted_data, detection_results = self.text_detector_module.detect_text(images=predicted_data,
use_gpu=self.use_gpu, use_gpu=self.use_gpu,
box_thresh=box_thresh) box_thresh=box_thresh,
det_db_unclip_ratio=det_db_unclip_ratio)
boxes = [np.array(item['data']).astype(np.float32) for item in detection_results] boxes = [np.array(item['data']).astype(np.float32) for item in detection_results]
print("dt_boxes num : {}".format(len(boxes[0])))
all_results = [] all_results = []
for index, img_boxes in enumerate(boxes): for index, img_boxes in enumerate(boxes):
original_image = predicted_data[index].copy() original_image = predicted_data[index].copy()
...@@ -255,7 +256,6 @@ class ChineseOCRDBCRNN(hub.Module): ...@@ -255,7 +256,6 @@ class ChineseOCRDBCRNN(hub.Module):
tmp_box = copy.deepcopy(boxes[num_box]) tmp_box = copy.deepcopy(boxes[num_box])
img_crop = self.get_rotate_crop_image(original_image, tmp_box) img_crop = self.get_rotate_crop_image(original_image, tmp_box)
img_crop_list.append(img_crop) img_crop_list.append(img_crop)
print('img_crop shape {}'.format(img_crop.shape))
img_crop_list, angle_list = self._classify_text(img_crop_list, img_crop_list, angle_list = self._classify_text(img_crop_list,
angle_classification_thresh=angle_classification_thresh) angle_classification_thresh=angle_classification_thresh)
rec_results = self._recognize_text(img_crop_list) rec_results = self._recognize_text(img_crop_list)
...@@ -371,7 +371,8 @@ class ChineseOCRDBCRNN(hub.Module): ...@@ -371,7 +371,8 @@ class ChineseOCRDBCRNN(hub.Module):
for beg_img_no in range(0, img_num, batch_num): for beg_img_no in range(0, img_num, batch_num):
end_img_no = min(img_num, beg_img_no + batch_num) end_img_no = min(img_num, beg_img_no + batch_num)
norm_img_batch = [] norm_img_batch = []
max_wh_ratio = 0 imgC, imgH, imgW = self.rec_image_shape
max_wh_ratio = imgW / imgH
for ino in range(beg_img_no, end_img_no): for ino in range(beg_img_no, end_img_no):
h, w = img_list[indices[ino]].shape[0:2] h, w = img_list[indices[ino]].shape[0:2]
wh_ratio = w * 1.0 / h wh_ratio = w * 1.0 / h
...@@ -400,10 +401,8 @@ class ChineseOCRDBCRNN(hub.Module): ...@@ -400,10 +401,8 @@ class ChineseOCRDBCRNN(hub.Module):
preds = preds[-1] preds = preds[-1]
if isinstance(preds, paddle.Tensor): if isinstance(preds, paddle.Tensor):
preds = preds.numpy() preds = preds.numpy()
print('preds.shape: {}', preds.shape)
preds_idx = preds.argmax(axis=2) preds_idx = preds.argmax(axis=2)
preds_prob = preds.max(axis=2) preds_prob = preds.max(axis=2)
# print('preds_idx: {} \n preds_prob: {}'.format(preds_idx, preds_prob) )
rec_result = self.char_ops.decode(preds_idx, preds_prob, is_remove_duplicate=True) rec_result = self.char_ops.decode(preds_idx, preds_prob, is_remove_duplicate=True)
for rno in range(len(rec_result)): for rno in range(len(rec_result)):
rec_res[indices[beg_img_no + rno]] = rec_result[rno] rec_res[indices[beg_img_no + rno]] = rec_result[rno]
...@@ -431,6 +430,7 @@ class ChineseOCRDBCRNN(hub.Module): ...@@ -431,6 +430,7 @@ class ChineseOCRDBCRNN(hub.Module):
results = self.recognize_text(paths=[args.input_path], results = self.recognize_text(paths=[args.input_path],
use_gpu=args.use_gpu, use_gpu=args.use_gpu,
output_dir=args.output_dir, output_dir=args.output_dir,
det_db_unclip_ratio=args.det_db_unclip_ratio,
visualization=args.visualization) visualization=args.visualization)
return results return results
...@@ -450,6 +450,10 @@ class ChineseOCRDBCRNN(hub.Module): ...@@ -450,6 +450,10 @@ class ChineseOCRDBCRNN(hub.Module):
type=ast.literal_eval, type=ast.literal_eval,
default=False, default=False,
help="whether to save output as images.") help="whether to save output as images.")
self.arg_config_group.add_argument('--det_db_unclip_ratio',
type=float,
default=1.5,
help="unclip ratio for post processing in DB detection.")
def add_module_input_arg(self): def add_module_input_arg(self):
""" """
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册