未验证 提交 0f269066 编写于 作者: C chenjian 提交者: GitHub

Update ocr module version (#1637)

上级 07d13b73
...@@ -21,7 +21,7 @@ from chinese_ocr_db_crnn_mobile.utils import base64_to_cv2, draw_ocr, get_image_ ...@@ -21,7 +21,7 @@ from chinese_ocr_db_crnn_mobile.utils import base64_to_cv2, draw_ocr, get_image_
@moduleinfo( @moduleinfo(
name="chinese_ocr_db_crnn_mobile", name="chinese_ocr_db_crnn_mobile",
version="1.1.1", version="1.1.2",
summary="The module can recognize the chinese texts in an image. Firstly, it will detect the text box positions \ summary="The module can recognize the chinese texts in an image. Firstly, it will detect the text box positions \
based on the differentiable_binarization_chn module. Then it classifies the text angle and recognizes the chinese texts. ", based on the differentiable_binarization_chn module. Then it classifies the text angle and recognizes the chinese texts. ",
author="paddle-dev", author="paddle-dev",
...@@ -490,14 +490,3 @@ class ChineseOCRDBCRNN(hub.Module): ...@@ -490,14 +490,3 @@ class ChineseOCRDBCRNN(hub.Module):
Add the command input options Add the command input options
""" """
self.arg_input_group.add_argument('--input_path', type=str, default=None, help="diretory to image") self.arg_input_group.add_argument('--input_path', type=str, default=None, help="diretory to image")
if __name__ == '__main__':
ocr = ChineseOCRDBCRNN()
image_path = [
'/mnt/zhangxuefei/PaddleOCR/doc/imgs/2.jpg', '/mnt/zhangxuefei/PaddleOCR/doc/imgs/12.jpg',
'/mnt/zhangxuefei/PaddleOCR/doc/imgs/test_image.jpg'
]
res = ocr.recognize_text(paths=image_path, visualization=True)
ocr.save_inference_model('save')
print(res)
...@@ -25,7 +25,7 @@ from chinese_ocr_db_crnn_server.utils import base64_to_cv2, draw_ocr, get_image_ ...@@ -25,7 +25,7 @@ from chinese_ocr_db_crnn_server.utils import base64_to_cv2, draw_ocr, get_image_
@moduleinfo( @moduleinfo(
name="chinese_ocr_db_crnn_server", name="chinese_ocr_db_crnn_server",
version="1.1.1", version="1.1.2",
summary= summary=
"The module can recognize the chinese texts in an image. Firstly, it will detect the text box positions based on the differentiable_binarization_chn module. Then it recognizes the chinese texts. ", "The module can recognize the chinese texts in an image. Firstly, it will detect the text box positions based on the differentiable_binarization_chn module. Then it recognizes the chinese texts. ",
author="paddle-dev", author="paddle-dev",
...@@ -494,14 +494,3 @@ class ChineseOCRDBCRNNServer(hub.Module): ...@@ -494,14 +494,3 @@ class ChineseOCRDBCRNNServer(hub.Module):
Add the command input options Add the command input options
""" """
self.arg_input_group.add_argument('--input_path', type=str, default=None, help="diretory to image") self.arg_input_group.add_argument('--input_path', type=str, default=None, help="diretory to image")
if __name__ == '__main__':
ocr = ChineseOCRDBCRNNServer(enable_mkldnn=False)
image_path = [
'/mnt/zhangxuefei/PaddleOCR/doc/imgs/11.jpg', '/mnt/zhangxuefei/PaddleOCR/doc/imgs/12.jpg',
'/mnt/zhangxuefei/PaddleOCR/doc/imgs/test_image.jpg'
]
res = ocr.recognize_text(paths=image_path, visualization=True)
ocr.save_inference_model('save')
print(res)
...@@ -172,6 +172,6 @@ def sorted_boxes(dt_boxes): ...@@ -172,6 +172,6 @@ def sorted_boxes(dt_boxes):
def base64_to_cv2(b64str): def base64_to_cv2(b64str):
data = base64.b64decode(b64str.encode('utf8')) data = base64.b64decode(b64str.encode('utf8'))
data = np.frombuffer(data, np.uint8) data = np.fromstring(data, np.uint8)
data = cv2.imdecode(data, cv2.IMREAD_COLOR) data = cv2.imdecode(data, cv2.IMREAD_COLOR)
return data return data
...@@ -29,7 +29,7 @@ def base64_to_cv2(b64str): ...@@ -29,7 +29,7 @@ def base64_to_cv2(b64str):
@moduleinfo( @moduleinfo(
name="chinese_text_detection_db_mobile", name="chinese_text_detection_db_mobile",
version="1.0.3", version="1.0.4",
summary= summary=
"The module aims to detect chinese text position in the image, which is based on differentiable_binarization algorithm.", "The module aims to detect chinese text position in the image, which is based on differentiable_binarization algorithm.",
author="paddle-dev", author="paddle-dev",
...@@ -103,26 +103,6 @@ class ChineseTextDetectionDB(hub.Module): ...@@ -103,26 +103,6 @@ class ChineseTextDetectionDB(hub.Module):
images.append(img) images.append(img)
return images return images
def clip_det_res(self, points, img_height, img_width):
for pno in range(points.shape[0]):
points[pno, 0] = int(min(max(points[pno, 0], 0), img_width - 1))
points[pno, 1] = int(min(max(points[pno, 1], 0), img_height - 1))
return points
def filter_tag_det_res(self, dt_boxes, image_shape):
img_height, img_width = image_shape[0:2]
dt_boxes_new = []
for box in dt_boxes:
box = self.order_points_clockwise(box)
box = self.clip_det_res(box, img_height, img_width)
rect_width = int(np.linalg.norm(box[0] - box[1]))
rect_height = int(np.linalg.norm(box[0] - box[3]))
if rect_width <= 10 or rect_height <= 10:
continue
dt_boxes_new.append(box)
dt_boxes = np.array(dt_boxes_new)
return dt_boxes
def order_points_clockwise(self, pts): def order_points_clockwise(self, pts):
""" """
reference from: https://github.com/jrosebr1/imutils/blob/master/imutils/perspective.py reference from: https://github.com/jrosebr1/imutils/blob/master/imutils/perspective.py
...@@ -147,6 +127,35 @@ class ChineseTextDetectionDB(hub.Module): ...@@ -147,6 +127,35 @@ class ChineseTextDetectionDB(hub.Module):
rect = np.array([tl, tr, br, bl], dtype="float32") rect = np.array([tl, tr, br, bl], dtype="float32")
return rect return rect
def clip_det_res(self, points, img_height, img_width):
for pno in range(points.shape[0]):
points[pno, 0] = int(min(max(points[pno, 0], 0), img_width - 1))
points[pno, 1] = int(min(max(points[pno, 1], 0), img_height - 1))
return points
def filter_tag_det_res(self, dt_boxes, image_shape):
img_height, img_width = image_shape[0:2]
dt_boxes_new = []
for box in dt_boxes:
box = self.order_points_clockwise(box)
box = self.clip_det_res(box, img_height, img_width)
rect_width = int(np.linalg.norm(box[0] - box[1]))
rect_height = int(np.linalg.norm(box[0] - box[3]))
if rect_width <= 3 or rect_height <= 3:
continue
dt_boxes_new.append(box)
dt_boxes = np.array(dt_boxes_new)
return dt_boxes
def filter_tag_det_res_only_clip(self, dt_boxes, image_shape):
img_height, img_width = image_shape[0:2]
dt_boxes_new = []
for box in dt_boxes:
box = self.clip_det_res(box, img_height, img_width)
dt_boxes_new.append(box)
dt_boxes = np.array(dt_boxes_new)
return dt_boxes
def detect_text(self, def detect_text(self,
images=[], images=[],
paths=[], paths=[],
...@@ -193,7 +202,7 @@ class ChineseTextDetectionDB(hub.Module): ...@@ -193,7 +202,7 @@ class ChineseTextDetectionDB(hub.Module):
'thresh': 0.3, 'thresh': 0.3,
'box_thresh': 0.5, 'box_thresh': 0.5,
'max_candidates': 1000, 'max_candidates': 1000,
'unclip_ratio': 2.0 'unclip_ratio': 1.6
}) })
all_imgs = [] all_imgs = []
...@@ -314,14 +323,3 @@ class ChineseTextDetectionDB(hub.Module): ...@@ -314,14 +323,3 @@ class ChineseTextDetectionDB(hub.Module):
Add the command input options Add the command input options
""" """
self.arg_input_group.add_argument('--input_path', type=str, default=None, help="diretory to image") self.arg_input_group.add_argument('--input_path', type=str, default=None, help="diretory to image")
if __name__ == '__main__':
db = ChineseTextDetectionDB()
image_path = [
'/mnt/zhangxuefei/PaddleOCR/doc/imgs/2.jpg', '/mnt/zhangxuefei/PaddleOCR/doc/imgs/12.jpg',
'/mnt/zhangxuefei/PaddleOCR/doc/imgs/test_image.jpg'
]
res = db.detect_text(paths=image_path, visualization=True)
db.save_inference_model('save')
print(res)
...@@ -120,6 +120,7 @@ class DBPostProcess(object): ...@@ -120,6 +120,7 @@ class DBPostProcess(object):
self.max_candidates = params['max_candidates'] self.max_candidates = params['max_candidates']
self.unclip_ratio = params['unclip_ratio'] self.unclip_ratio = params['unclip_ratio']
self.min_size = 3 self.min_size = 3
self.dilation_kernel = np.array([[1, 1], [1, 1]])
def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height): def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
''' '''
...@@ -218,7 +219,9 @@ class DBPostProcess(object): ...@@ -218,7 +219,9 @@ class DBPostProcess(object):
boxes_batch = [] boxes_batch = []
for batch_index in range(pred.shape[0]): for batch_index in range(pred.shape[0]):
height, width = pred.shape[-2:] height, width = pred.shape[-2:]
tmp_boxes, tmp_scores = self.boxes_from_bitmap(pred[batch_index], segmentation[batch_index], width, height)
mask = cv2.dilate(np.array(segmentation[batch_index]).astype(np.uint8), self.dilation_kernel)
tmp_boxes, tmp_scores = self.boxes_from_bitmap(pred[batch_index], mask, width, height)
boxes = [] boxes = []
for k in range(len(tmp_boxes)): for k in range(len(tmp_boxes)):
......
...@@ -297,11 +297,3 @@ class ChineseTextDetectionDBServer(hub.Module): ...@@ -297,11 +297,3 @@ class ChineseTextDetectionDBServer(hub.Module):
Add the command input options Add the command input options
""" """
self.arg_input_group.add_argument('--input_path', type=str, default=None, help="diretory to image") self.arg_input_group.add_argument('--input_path', type=str, default=None, help="diretory to image")
if __name__ == '__main__':
db = ChineseTextDetectionDBServer()
image_path = ['/mnt/zhangxuefei/PaddleOCR/doc/imgs/11.jpg', '/mnt/zhangxuefei/PaddleOCR/doc/imgs/12.jpg']
res = db.detect_text(paths=image_path, visualization=True)
db.save_inference_model('save')
print(res)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册