提交 8f5e5177 编写于 作者: C chenjian

fix

上级 9b3119df
......@@ -34,14 +34,14 @@ def base64_to_cv2(b64str):
@moduleinfo(
name="ppocrv3_det_ch",
name="ch_pp-ocrv3_det",
version="1.0.0",
summary=
"The module aims to detect chinese text position in the image, which is based on differentiable_binarization algorithm.",
author="paddle-dev",
author_email="paddle-dev@baidu.com",
type="cv/text_recognition")
class ChineseTextDetectionDB(hub.Module):
class ChPPOCRv3Det(hub.Module):
def _initialize(self, enable_mkldnn=False):
"""
......@@ -155,7 +155,8 @@ class ChineseTextDetectionDB(hub.Module):
use_gpu=False,
output_dir='detection_result',
visualization=False,
box_thresh=0.5):
box_thresh=0.5,
det_db_unclip_ratio=1.5):
"""
Get the text box in the predicted images.
Args:
......@@ -165,6 +166,7 @@ class ChineseTextDetectionDB(hub.Module):
output_dir (str): The directory to store output images.
visualization (bool): Whether to save image or not.
box_thresh(float): the threshold of the detected text box's confidence
det_db_unclip_ratio(float): unclip ratio for post processing in DB detection.
Returns:
res (list): The result of text detection box and save path of images.
"""
......@@ -195,7 +197,7 @@ class ChineseTextDetectionDB(hub.Module):
'thresh': 0.3,
'box_thresh': 0.6,
'max_candidates': 1000,
'unclip_ratio': 1.5
'unclip_ratio': det_db_unclip_ratio
})
all_imgs = []
......@@ -204,7 +206,6 @@ class ChineseTextDetectionDB(hub.Module):
for original_image in predicted_data:
ori_im = original_image.copy()
im, ratio_list = preprocessor(original_image)
print('after preprocess int det, shape{}'.format(im.shape))
res = {'save_path': ''}
if im is None:
res['data'] = []
......@@ -222,15 +223,10 @@ class ChineseTextDetectionDB(hub.Module):
outs_dict = {}
outs_dict['maps'] = outputs[0]
# data_out = self.output_tensors[0].copy_to_cpu()
print('Outputs[0] in det, shape: {}'.format(outputs[0].shape))
dt_boxes_list = postprocessor(outs_dict, [ratio_list])
dt_boxes = dt_boxes_list[0]
print('after postprocess int det, shape{}'.format(dt_boxes.shape))
boxes = self.filter_tag_det_res(dt_boxes_list[0], original_image.shape)
print('after fitler tag int det, shape{}'.format(boxes.shape))
res['data'] = boxes.astype(np.int).tolist()
print('boxes: {}'.format(boxes))
all_imgs.append(im)
all_ratios.append(ratio_list)
if visualization:
......@@ -278,6 +274,7 @@ class ChineseTextDetectionDB(hub.Module):
results = self.detect_text(paths=[args.input_path],
use_gpu=args.use_gpu,
output_dir=args.output_dir,
det_db_unclip_ratio=args.det_db_unclip_ratio,
visualization=args.visualization)
return results
......@@ -297,6 +294,10 @@ class ChineseTextDetectionDB(hub.Module):
type=ast.literal_eval,
default=False,
help="whether to save output as images.")
self.arg_config_group.add_argument('--det_db_unclip_ratio',
type=float,
default=1.5,
help="unclip ratio for post processing in DB detection.")
def add_module_input_arg(self):
"""
......
......@@ -25,7 +25,6 @@ class DBProcessTest(object):
self.resize_type = 0
if 'test_image_shape' in params:
self.image_shape = params['test_image_shape']
# print(self.image_shape)
self.resize_type = 1
if 'max_side_len' in params:
self.max_side_len = params['max_side_len']
......@@ -54,15 +53,14 @@ class DBProcessTest(object):
resize_h = int(h * ratio)
resize_w = int(w * ratio)
resize_h = int(round(resize_h / 32) * 32)
resize_w = int(round(resize_w / 32) * 32)
resize_h = max(int(round(resize_h / 32) * 32), 32)
resize_w = max(int(round(resize_w / 32) * 32), 32)
try:
if int(resize_w) <= 0 or int(resize_h) <= 0:
return None, (None, None)
img = cv2.resize(img, (int(resize_w), int(resize_h)))
except:
print(img.shape, resize_w, resize_h)
sys.exit(0)
ratio_h = resize_h / float(h)
ratio_w = resize_w / float(w)
......@@ -93,13 +91,14 @@ class DBProcessTest(object):
return im
def __call__(self, im):
src_h, src_w, _ = im.shape
if self.resize_type == 0:
im, (ratio_h, ratio_w) = self.resize_image_type0(im)
else:
im, (ratio_h, ratio_w) = self.resize_image_type1(im)
im = self.normalize(im)
im = im[np.newaxis, :]
return [im, (ratio_h, ratio_w)]
return [im, (src_h, src_w, ratio_h, ratio_w)]
class DBPostProcess(object):
......@@ -228,7 +227,7 @@ class DBPostProcess(object):
cv2.fillPoly(mask, contour.reshape(1, -1, 2).astype(np.int32), 1)
return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
def __call__(self, outs_dict, ratio_list):
def __call__(self, outs_dict, shape_list):
pred = outs_dict['maps']
pred = pred[:, 0, :, :]
......@@ -236,10 +235,10 @@ class DBPostProcess(object):
boxes_batch = []
for batch_index in range(pred.shape[0]):
height, width = pred.shape[-2:]
src_h, src_w, ratio_h, ratio_w = shape_list[batch_index]
mask = segmentation[batch_index]
tmp_boxes, tmp_scores = self.boxes_from_bitmap(pred[batch_index], mask, width, height)
tmp_boxes, tmp_scores = self.boxes_from_bitmap(pred[batch_index], mask, src_w, src_h)
boxes_batch.append(tmp_boxes)
return boxes_batch
......
......@@ -59,6 +59,7 @@ class CharacterOps(object):
self.character = dict_character
def add_special_char(self, dict_character):
dict_character = ['blank'] + dict_character
return dict_character
def encode(self, text):
......@@ -93,12 +94,6 @@ class CharacterOps(object):
selection[1:] = text_index[batch_idx][1:] != text_index[batch_idx][:-1]
for ignored_token in ignored_tokens:
selection &= text_index[batch_idx] != ignored_token
# print(text_index)
# print(batch_idx)
# print(selection)
# for text_id in text_index[batch_idx][selection]:
# print(text_id)
# print(self.character[text_id])
char_list = [self.character[text_id] for text_id in text_index[batch_idx][selection]]
if text_prob is not None:
conf_list = text_prob[batch_idx][selection]
......
......@@ -29,14 +29,14 @@ from paddlehub.module.module import serving
@moduleinfo(
name="ppocrv3_rec_ch",
name="ch_pp-ocrv3",
version="1.0.0",
summary="The module can recognize the chinese texts in an image. Firstly, it will detect the text box positions \
based on the differentiable_binarization_chn module. Then it classifies the text angle and recognizes the chinese texts. ",
author="paddle-dev",
author_email="paddle-dev@baidu.com",
type="cv/text_recognition")
class ChineseOCRDBCRNN(hub.Module):
class ChPPOCRv3(hub.Module):
def _initialize(self, text_detector_module=None, enable_mkldnn=False):
"""
......@@ -51,7 +51,7 @@ class ChineseOCRDBCRNN(hub.Module):
'use_space_char': True
}
self.char_ops = CharacterOps(char_ops_params)
self.rec_image_shape = [3, 32, 320]
self.rec_image_shape = [3, 48, 320]
self._text_detector_module = text_detector_module
self.font_file = os.path.join(self.directory, 'assets', 'simfang.ttf')
self.enable_mkldnn = enable_mkldnn
......@@ -109,7 +109,7 @@ class ChineseOCRDBCRNN(hub.Module):
text detect module
"""
if not self._text_detector_module:
self._text_detector_module = hub.Module(name='ppocrv3_det_ch',
self._text_detector_module = hub.Module(name='ch_pp-ocrv3_det',
enable_mkldnn=self.enable_mkldnn,
version='1.0.0')
return self._text_detector_module
......@@ -152,7 +152,7 @@ class ChineseOCRDBCRNN(hub.Module):
def resize_norm_img_rec(self, img, max_wh_ratio):
imgC, imgH, imgW = self.rec_image_shape
assert imgC == img.shape[2]
imgW = int((32 * max_wh_ratio))
imgW = int((imgH * max_wh_ratio))
h, w = img.shape[:2]
ratio = w / float(h)
if math.ceil(imgH * ratio) > imgW:
......@@ -199,7 +199,8 @@ class ChineseOCRDBCRNN(hub.Module):
visualization=False,
box_thresh=0.5,
text_thresh=0.5,
angle_classification_thresh=0.9):
angle_classification_thresh=0.9,
det_db_unclip_ratio=1.5):
"""
Get the chinese texts in the predicted images.
Args:
......@@ -212,7 +213,7 @@ class ChineseOCRDBCRNN(hub.Module):
box_thresh(float): the threshold of the detected text box's confidence
text_thresh(float): the threshold of the chinese text recognition confidence
angle_classification_thresh(float): the threshold of the angle classification confidence
det_db_unclip_ratio(float): unclip ratio for post processing in DB detection.
Returns:
res (list): The result of chinese texts and save path of images.
"""
......@@ -238,10 +239,10 @@ class ChineseOCRDBCRNN(hub.Module):
detection_results = self.text_detector_module.detect_text(images=predicted_data,
use_gpu=self.use_gpu,
box_thresh=box_thresh)
box_thresh=box_thresh,
det_db_unclip_ratio=det_db_unclip_ratio)
boxes = [np.array(item['data']).astype(np.float32) for item in detection_results]
print("dt_boxes num : {}".format(len(boxes[0])))
all_results = []
for index, img_boxes in enumerate(boxes):
original_image = predicted_data[index].copy()
......@@ -255,7 +256,6 @@ class ChineseOCRDBCRNN(hub.Module):
tmp_box = copy.deepcopy(boxes[num_box])
img_crop = self.get_rotate_crop_image(original_image, tmp_box)
img_crop_list.append(img_crop)
print('img_crop shape {}'.format(img_crop.shape))
img_crop_list, angle_list = self._classify_text(img_crop_list,
angle_classification_thresh=angle_classification_thresh)
rec_results = self._recognize_text(img_crop_list)
......@@ -371,7 +371,8 @@ class ChineseOCRDBCRNN(hub.Module):
for beg_img_no in range(0, img_num, batch_num):
end_img_no = min(img_num, beg_img_no + batch_num)
norm_img_batch = []
max_wh_ratio = 0
imgC, imgH, imgW = self.rec_image_shape
max_wh_ratio = imgW / imgH
for ino in range(beg_img_no, end_img_no):
h, w = img_list[indices[ino]].shape[0:2]
wh_ratio = w * 1.0 / h
......@@ -400,10 +401,8 @@ class ChineseOCRDBCRNN(hub.Module):
preds = preds[-1]
if isinstance(preds, paddle.Tensor):
preds = preds.numpy()
print('preds.shape: {}', preds.shape)
preds_idx = preds.argmax(axis=2)
preds_prob = preds.max(axis=2)
# print('preds_idx: {} \n preds_prob: {}'.format(preds_idx, preds_prob) )
rec_result = self.char_ops.decode(preds_idx, preds_prob, is_remove_duplicate=True)
for rno in range(len(rec_result)):
rec_res[indices[beg_img_no + rno]] = rec_result[rno]
......@@ -431,6 +430,7 @@ class ChineseOCRDBCRNN(hub.Module):
results = self.recognize_text(paths=[args.input_path],
use_gpu=args.use_gpu,
output_dir=args.output_dir,
det_db_unclip_ratio=args.det_db_unclip_ratio,
visualization=args.visualization)
return results
......@@ -450,6 +450,10 @@ class ChineseOCRDBCRNN(hub.Module):
type=ast.literal_eval,
default=False,
help="whether to save output as images.")
self.arg_config_group.add_argument('--det_db_unclip_ratio',
type=float,
default=1.5,
help="unclip ratio for post processing in DB detection.")
def add_module_input_arg(self):
"""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册