diff --git a/ppocr/data/imaug/__init__.py b/ppocr/data/imaug/__init__.py index ff084a725a27a909fcc1b29d7dc3b309fa0623a2..52194eb964f7a7fd159cc1a42b73d280f8ee5fb4 100644 --- a/ppocr/data/imaug/__init__.py +++ b/ppocr/data/imaug/__init__.py @@ -23,6 +23,7 @@ from .random_crop_data import EastRandomCropData, PSERandomCrop from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg from .randaugment import RandAugment +from .copy_paste import CopyPaste from .operators import * from .label_ops import * diff --git a/ppocr/data/imaug/copy_paste.py b/ppocr/data/imaug/copy_paste.py new file mode 100644 index 0000000000000000000000000000000000000000..9e13e806f0ae77cc9b37c1275218c05152bfa166 --- /dev/null +++ b/ppocr/data/imaug/copy_paste.py @@ -0,0 +1,164 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import copy +import cv2 +import random +import numpy as np +from PIL import Image +from shapely.geometry import Polygon + +from ppocr.data.imaug.iaa_augment import IaaAugment +from ppocr.data.imaug.random_crop_data import is_poly_outside_rect +from tools.infer.utility import get_rotate_crop_image + + +class CopyPaste(object): + def __init__(self, objects_paste_ratio=0.2, limit_paste=True, **kwargs): + self.ext_data_num = 1 + self.objects_paste_ratio = objects_paste_ratio + self.limit_paste = limit_paste + augmenter_args = [{'type': 'Resize', 'args': {'size': [0.5, 3]}}] + self.aug = IaaAugment(augmenter_args) + + def __call__(self, data): + src_img = data['image'] + src_polys = data['polys'].tolist() + src_ignores = data['ignore_tags'].tolist() + ext_data = data['ext_data'][0] + ext_image = ext_data['image'] + ext_polys = ext_data['polys'] + ext_ignores = ext_data['ignore_tags'] + + indexs = [i for i in range(len(ext_ignores)) if not ext_ignores[i]] + select_num = max( + 1, min(int(self.objects_paste_ratio * len(ext_polys)), 30)) + + random.shuffle(indexs) + select_idxs = indexs[:select_num] + select_polys = ext_polys[select_idxs] + select_ignores = ext_ignores[select_idxs] + + src_img = cv2.cvtColor(src_img, cv2.COLOR_BGR2RGB) + ext_image = cv2.cvtColor(ext_image, cv2.COLOR_BGR2RGB) + src_img = Image.fromarray(src_img).convert('RGBA') + for poly, tag in zip(select_polys, select_ignores): + box_img = get_rotate_crop_image(ext_image, poly) + + src_img, box = self.paste_img(src_img, box_img, src_polys) + if box is not None: + src_polys.append(box) + src_ignores.append(tag) + src_img = cv2.cvtColor(np.array(src_img), cv2.COLOR_RGB2BGR) + h, w = src_img.shape[:2] + src_polys = np.array(src_polys) + src_polys[:, :, 0] = np.clip(src_polys[:, :, 0], 0, w) + src_polys[:, :, 1] = np.clip(src_polys[:, :, 1], 0, h) + data['image'] = src_img + data['polys'] = src_polys + data['ignore_tags'] = np.array(src_ignores) + return data + + def paste_img(self, src_img, box_img, src_polys): + box_img_pil = Image.fromarray(box_img).convert('RGBA') + src_w, src_h = src_img.size + box_w, box_h = box_img_pil.size + if box_w > src_w or box_h > src_h: + return src_img, None + angle = np.random.randint(0, 360) + box = np.array([[[0, 0], [box_w, 0], [box_w, box_h], [0, box_h]]]) + box = rotate_bbox(box_img, box, angle)[0] + + paste_x, paste_y = self.select_coord(src_polys, box, src_w - box_w, + src_h - box_h) + if paste_x is None: + return src_img, None + box[:, 0] += paste_x + box[:, 1] += paste_y + box_img_pil = box_img_pil.rotate(angle, expand=1) + r, g, b, A = box_img_pil.split() + src_img.paste(box_img_pil, (paste_x, paste_y), mask=A) + + return src_img, box + + def select_coord(self, src_polys, box, endx, endy): + if self.limit_paste: + xmin, ymin, xmax, ymax = box[:, 0].min(), box[:, 1].min( + ), box[:, 0].max(), box[:, 1].max() + for _ in range(50): + paste_x = random.randint(0, endx) + paste_y = random.randint(0, endy) + xmin1 = xmin + paste_x + xmax1 = xmax + paste_x + ymin1 = ymin + paste_y + ymax1 = ymax + paste_y + + num_poly_in_rect = 0 + for poly in src_polys: + if not is_poly_outside_rect(poly, xmax1, ymin1, + xmax1 - xmin1, ymax1 - ymin1): + num_poly_in_rect += 1 + break + if num_poly_in_rect == 0: + return paste_x, paste_y + return None, None + else: + paste_x = random.randint(0, endx) + paste_y = random.randint(0, endy) + return paste_x, paste_y + + +def get_union(pD, pG): + return Polygon(pD).union(Polygon(pG)).area + + +def get_intersection_over_union(pD, pG): + return get_intersection(pD, pG) / get_union(pD, pG) + + +def get_intersection(pD, pG): + return Polygon(pD).intersection(Polygon(pG)).area + + +def rotate_bbox(img, text_polys, angle, scale=1): + """ + from https://github.com/WenmuZhou/DBNet.pytorch/blob/master/data_loader/modules/augment.py + Args: + img: np.ndarray + text_polys: np.ndarray N*4*2 + angle: int + scale: int + + Returns: + + """ + w = img.shape[1] + h = img.shape[0] + + rangle = np.deg2rad(angle) + nw = (abs(np.sin(rangle) * h) + abs(np.cos(rangle) * w)) + nh = (abs(np.cos(rangle) * h) + abs(np.sin(rangle) * w)) + rot_mat = cv2.getRotationMatrix2D((nw * 0.5, nh * 0.5), angle, scale) + rot_move = np.dot(rot_mat, np.array([(nw - w) * 0.5, (nh - h) * 0.5, 0])) + rot_mat[0, 2] += rot_move[0] + rot_mat[1, 2] += rot_move[1] + + # ---------------------- rotate box ---------------------- + rot_text_polys = list() + for bbox in text_polys: + point1 = np.dot(rot_mat, np.array([bbox[0, 0], bbox[0, 1], 1])) + point2 = np.dot(rot_mat, np.array([bbox[1, 0], bbox[1, 1], 1])) + point3 = np.dot(rot_mat, np.array([bbox[2, 0], bbox[2, 1], 1])) + point4 = np.dot(rot_mat, np.array([bbox[3, 0], bbox[3, 1], 1])) + rot_text_polys.append([point1, point2, point3, point4]) + return np.array(rot_text_polys, dtype=np.float32) diff --git a/ppocr/data/simple_dataset.py b/ppocr/data/simple_dataset.py index 8f8fcb4dbdf3c68587875b50cb30a834a3943216..ce9e1b38675ae8df4a2e83b88c1adae4476a10b5 100644 --- a/ppocr/data/simple_dataset.py +++ b/ppocr/data/simple_dataset.py @@ -14,6 +14,7 @@ import numpy as np import os import random +import traceback from paddle.io import Dataset from .imaug import transform, create_operators @@ -69,6 +70,36 @@ class SimpleDataSet(Dataset): random.shuffle(self.data_lines) return + def get_ext_data(self): + ext_data_num = 0 + for op in self.ops: + if hasattr(op, 'ext_data_num'): + ext_data_num = getattr(op, 'ext_data_num') + break + load_data_ops = self.ops[:2] + ext_data = [] + + while len(ext_data) < ext_data_num: + file_idx = self.data_idx_order_list[np.random.randint(self.__len__( + ))] + data_line = self.data_lines[file_idx] + data_line = data_line.decode('utf-8') + substr = data_line.strip("\n").split(self.delimiter) + file_name = substr[0] + label = substr[1] + img_path = os.path.join(self.data_dir, file_name) + data = {'img_path': img_path, 'label': label} + if not os.path.exists(img_path): + continue + with open(data['img_path'], 'rb') as f: + img = f.read() + data['image'] = img + data = transform(data, load_data_ops) + if data is None: + continue + ext_data.append(data) + return ext_data + def __getitem__(self, idx): file_idx = self.data_idx_order_list[idx] data_line = self.data_lines[file_idx] @@ -84,11 +115,13 @@ class SimpleDataSet(Dataset): with open(data['img_path'], 'rb') as f: img = f.read() data['image'] = img + data['ext_data'] = self.get_ext_data() outs = transform(data, self.ops) - except Exception as e: + except: + error_meg = traceback.format_exc() self.logger.error( "When parsing line {}, error happened with msg: {}".format( - data_line, e)) + data_line, error_meg)) outs = None if outs is None: # during evaluation, we should fix the idx to get same results for many times of evaluation. diff --git a/tools/infer/predict_system.py b/tools/infer/predict_system.py index 35bc00b90917e5eb91c0124c44e2c7832ed1a9ff..715bd3fa9d596dd60f7f789f3e367734ffec608b 100755 --- a/tools/infer/predict_system.py +++ b/tools/infer/predict_system.py @@ -33,7 +33,7 @@ import tools.infer.predict_det as predict_det import tools.infer.predict_cls as predict_cls from ppocr.utils.utility import get_image_file_list, check_and_read_gif from ppocr.utils.logging import get_logger -from tools.infer.utility import draw_ocr_box_txt +from tools.infer.utility import draw_ocr_box_txt, get_rotate_crop_image logger = get_logger() @@ -49,39 +49,6 @@ class TextSystem(object): if self.use_angle_cls: self.text_classifier = predict_cls.TextClassifier(args) - def get_rotate_crop_image(self, img, points): - ''' - img_height, img_width = img.shape[0:2] - left = int(np.min(points[:, 0])) - right = int(np.max(points[:, 0])) - top = int(np.min(points[:, 1])) - bottom = int(np.max(points[:, 1])) - img_crop = img[top:bottom, left:right, :].copy() - points[:, 0] = points[:, 0] - left - points[:, 1] = points[:, 1] - top - ''' - img_crop_width = int( - max( - np.linalg.norm(points[0] - points[1]), - np.linalg.norm(points[2] - points[3]))) - img_crop_height = int( - max( - np.linalg.norm(points[0] - points[3]), - np.linalg.norm(points[1] - points[2]))) - pts_std = np.float32([[0, 0], [img_crop_width, 0], - [img_crop_width, img_crop_height], - [0, img_crop_height]]) - M = cv2.getPerspectiveTransform(points, pts_std) - dst_img = cv2.warpPerspective( - img, - M, (img_crop_width, img_crop_height), - borderMode=cv2.BORDER_REPLICATE, - flags=cv2.INTER_CUBIC) - dst_img_height, dst_img_width = dst_img.shape[0:2] - if dst_img_height * 1.0 / dst_img_width >= 1.5: - dst_img = np.rot90(dst_img) - return dst_img - def print_draw_crop_rec_res(self, img_crop_list, rec_res): bbox_num = len(img_crop_list) for bno in range(bbox_num): @@ -102,7 +69,7 @@ class TextSystem(object): for bno in range(len(dt_boxes)): tmp_box = copy.deepcopy(dt_boxes[bno]) - img_crop = self.get_rotate_crop_image(ori_im, tmp_box) + img_crop = get_rotate_crop_image(ori_im, tmp_box) img_crop_list.append(img_crop) if self.use_angle_cls and cls: img_crop_list, angle_list, elapse = self.text_classifier( diff --git a/tools/infer/utility.py b/tools/infer/utility.py index c17c5d2d59048f698b65da6994e2accc14716b12..021494ceea428709f4155e0d7c1142ca5a31858c 100755 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -241,7 +241,7 @@ def create_predictor(args, mode, logger): config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass") if mode == 'table': - config.delete_pass("fc_fuse_pass") # not supported for table + config.delete_pass("fc_fuse_pass") # not supported for table config.switch_use_feed_fetch_ops(False) config.switch_ir_optim(True) @@ -506,5 +506,40 @@ def draw_boxes(image, boxes, scores=None, drop_score=0.5): return image +def get_rotate_crop_image(img, points): + ''' + img_height, img_width = img.shape[0:2] + left = int(np.min(points[:, 0])) + right = int(np.max(points[:, 0])) + top = int(np.min(points[:, 1])) + bottom = int(np.max(points[:, 1])) + img_crop = img[top:bottom, left:right, :].copy() + points[:, 0] = points[:, 0] - left + points[:, 1] = points[:, 1] - top + ''' + assert len(points) == 4, "shape of points must be 4*2" + img_crop_width = int( + max( + np.linalg.norm(points[0] - points[1]), + np.linalg.norm(points[2] - points[3]))) + img_crop_height = int( + max( + np.linalg.norm(points[0] - points[3]), + np.linalg.norm(points[1] - points[2]))) + pts_std = np.float32([[0, 0], [img_crop_width, 0], + [img_crop_width, img_crop_height], + [0, img_crop_height]]) + M = cv2.getPerspectiveTransform(points, pts_std) + dst_img = cv2.warpPerspective( + img, + M, (img_crop_width, img_crop_height), + borderMode=cv2.BORDER_REPLICATE, + flags=cv2.INTER_CUBIC) + dst_img_height, dst_img_width = dst_img.shape[0:2] + if dst_img_height * 1.0 / dst_img_width >= 1.5: + dst_img = np.rot90(dst_img) + return dst_img + + if __name__ == '__main__': pass