From fdaf1f1577778daf1e48f8edb65fbdfa9d0a32bb Mon Sep 17 00:00:00 2001 From: WenmuZhou Date: Wed, 30 Jun 2021 10:45:27 +0800 Subject: [PATCH] Extraction get_rotate_crop_image --- ppocr/data/imaug/copy_paste.py | 46 +++++++++------------------------- tools/infer/predict_system.py | 38 ++-------------------------- tools/infer/utility.py | 37 +++++++++++++++++++++++++-- 3 files changed, 49 insertions(+), 72 deletions(-) diff --git a/ppocr/data/imaug/copy_paste.py b/ppocr/data/imaug/copy_paste.py index 357884fe..9e13e806 100644 --- a/ppocr/data/imaug/copy_paste.py +++ b/ppocr/data/imaug/copy_paste.py @@ -20,40 +20,7 @@ from shapely.geometry import Polygon from ppocr.data.imaug.iaa_augment import IaaAugment from ppocr.data.imaug.random_crop_data import is_poly_outside_rect - - -def get_rotate_crop_image(img, points): - ''' - img_height, img_width = img.shape[0:2] - left = int(np.min(points[:, 0])) - right = int(np.max(points[:, 0])) - top = int(np.min(points[:, 1])) - bottom = int(np.max(points[:, 1])) - img_crop = img[top:bottom, left:right, :].copy() - points[:, 0] = points[:, 0] - left - points[:, 1] = points[:, 1] - top - ''' - img_crop_width = int( - max( - np.linalg.norm(points[0] - points[1]), - np.linalg.norm(points[2] - points[3]))) - img_crop_height = int( - max( - np.linalg.norm(points[0] - points[3]), - np.linalg.norm(points[1] - points[2]))) - pts_std = np.float32([[0, 0], [img_crop_width, 0], - [img_crop_width, img_crop_height], - [0, img_crop_height]]) - M = cv2.getPerspectiveTransform(points, pts_std) - dst_img = cv2.warpPerspective( - img, - M, (img_crop_width, img_crop_height), - borderMode=cv2.BORDER_REPLICATE, - flags=cv2.INTER_CUBIC) - dst_img_height, dst_img_width = dst_img.shape[0:2] - if dst_img_height * 1.0 / dst_img_width >= 1.5: - dst_img = np.rot90(dst_img) - return dst_img +from tools.infer.utility import get_rotate_crop_image class CopyPaste(object): @@ -164,6 +131,17 @@ def get_intersection(pD, pG): def rotate_bbox(img, text_polys, angle, scale=1): + """ + from https://github.com/WenmuZhou/DBNet.pytorch/blob/master/data_loader/modules/augment.py + Args: + img: np.ndarray + text_polys: np.ndarray N*4*2 + angle: int + scale: int + + Returns: + + """ w = img.shape[1] h = img.shape[0] diff --git a/tools/infer/predict_system.py b/tools/infer/predict_system.py index ad1b7d4e..ab4a2ea1 100755 --- a/tools/infer/predict_system.py +++ b/tools/infer/predict_system.py @@ -31,7 +31,7 @@ import tools.infer.predict_det as predict_det import tools.infer.predict_cls as predict_cls from ppocr.utils.utility import get_image_file_list, check_and_read_gif from ppocr.utils.logging import get_logger -from tools.infer.utility import draw_ocr_box_txt, get_current_memory_mb +from tools.infer.utility import draw_ocr_box_txt, get_current_memory_mb, get_rotate_crop_image import tools.infer.benchmark_utils as benchmark_utils logger = get_logger() @@ -45,39 +45,6 @@ class TextSystem(object): if self.use_angle_cls: self.text_classifier = predict_cls.TextClassifier(args) - def get_rotate_crop_image(self, img, points): - ''' - img_height, img_width = img.shape[0:2] - left = int(np.min(points[:, 0])) - right = int(np.max(points[:, 0])) - top = int(np.min(points[:, 1])) - bottom = int(np.max(points[:, 1])) - img_crop = img[top:bottom, left:right, :].copy() - points[:, 0] = points[:, 0] - left - points[:, 1] = points[:, 1] - top - ''' - img_crop_width = int( - max( - np.linalg.norm(points[0] - points[1]), - np.linalg.norm(points[2] - points[3]))) - img_crop_height = int( - max( - np.linalg.norm(points[0] - points[3]), - np.linalg.norm(points[1] - points[2]))) - pts_std = np.float32([[0, 0], [img_crop_width, 0], - [img_crop_width, img_crop_height], - [0, img_crop_height]]) - M = cv2.getPerspectiveTransform(points, pts_std) - dst_img = cv2.warpPerspective( - img, - M, (img_crop_width, img_crop_height), - borderMode=cv2.BORDER_REPLICATE, - flags=cv2.INTER_CUBIC) - dst_img_height, dst_img_width = dst_img.shape[0:2] - if dst_img_height * 1.0 / dst_img_width >= 1.5: - dst_img = np.rot90(dst_img) - return dst_img - def print_draw_crop_rec_res(self, img_crop_list, rec_res): bbox_num = len(img_crop_list) for bno in range(bbox_num): @@ -89,7 +56,6 @@ class TextSystem(object): dt_boxes, elapse = self.text_detector(img) logger.info("dt_boxes num : {}, elapse : {}".format( - len(dt_boxes), elapse)) if dt_boxes is None: return None, None @@ -99,7 +65,7 @@ class TextSystem(object): for bno in range(len(dt_boxes)): tmp_box = copy.deepcopy(dt_boxes[bno]) - img_crop = self.get_rotate_crop_image(ori_im, tmp_box) + img_crop = get_rotate_crop_image(ori_im, tmp_box) img_crop_list.append(img_crop) if self.use_angle_cls and cls: img_crop_list, angle_list, elapse = self.text_classifier( diff --git a/tools/infer/utility.py b/tools/infer/utility.py index 69f28e00..3fa62c77 100755 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -109,11 +109,10 @@ def init_args(): parser.add_argument("--use_mp", type=str2bool, default=False) parser.add_argument("--total_process_num", type=int, default=1) parser.add_argument("--process_id", type=int, default=0) - + parser.add_argument("--benchmark", type=bool, default=False) parser.add_argument("--save_log_path", type=str, default="./log_output/") - return parser @@ -615,5 +614,39 @@ def get_current_memory_mb(gpu_id=None): return round(cpu_mem, 4), round(gpu_mem, 4), round(gpu_percent, 4) +def get_rotate_crop_image(img, points): + ''' + img_height, img_width = img.shape[0:2] + left = int(np.min(points[:, 0])) + right = int(np.max(points[:, 0])) + top = int(np.min(points[:, 1])) + bottom = int(np.max(points[:, 1])) + img_crop = img[top:bottom, left:right, :].copy() + points[:, 0] = points[:, 0] - left + points[:, 1] = points[:, 1] - top + ''' + img_crop_width = int( + max( + np.linalg.norm(points[0] - points[1]), + np.linalg.norm(points[2] - points[3]))) + img_crop_height = int( + max( + np.linalg.norm(points[0] - points[3]), + np.linalg.norm(points[1] - points[2]))) + pts_std = np.float32([[0, 0], [img_crop_width, 0], + [img_crop_width, img_crop_height], + [0, img_crop_height]]) + M = cv2.getPerspectiveTransform(points, pts_std) + dst_img = cv2.warpPerspective( + img, + M, (img_crop_width, img_crop_height), + borderMode=cv2.BORDER_REPLICATE, + flags=cv2.INTER_CUBIC) + dst_img_height, dst_img_width = dst_img.shape[0:2] + if dst_img_height * 1.0 / dst_img_width >= 1.5: + dst_img = np.rot90(dst_img) + return dst_img + + if __name__ == '__main__': pass -- GitLab