From bf7e085ea27842d08bd047911c12826529ccbbe3 Mon Sep 17 00:00:00 2001 From: zhiminzhang0830 <452516515@qq.com> Date: Sun, 13 Feb 2022 11:48:18 +0800 Subject: [PATCH] modify fcenet --- ...50_fce_ctw.yml => det_r50_dcn_fce_ctw.yml} | 16 +- ppocr/data/imaug/fce_aug.py | 129 ++----- ppocr/data/imaug/fce_targets.py | 48 +-- ppocr/data/imaug/operators.py | 22 ++ ppocr/losses/det_fce_loss.py | 21 +- ppocr/modeling/heads/det_fce_head.py | 34 +- ppocr/modeling/necks/fce_fpn.py | 18 + ppocr/postprocess/fce_postprocess.py | 340 ++++++------------ ppocr/utils/poly_nms.py | 145 ++++++++ tools/program.py | 2 +- train.sh | 3 +- 11 files changed, 385 insertions(+), 393 deletions(-) rename configs/det/{det_r50_fce_ctw.yml => det_r50_dcn_fce_ctw.yml} (87%) create mode 100644 ppocr/utils/poly_nms.py diff --git a/configs/det/det_r50_fce_ctw.yml b/configs/det/det_r50_dcn_fce_ctw.yml similarity index 87% rename from configs/det/det_r50_fce_ctw.yml rename to configs/det/det_r50_dcn_fce_ctw.yml index a360465d..49d65583 100755 --- a/configs/det/det_r50_fce_ctw.yml +++ b/configs/det/det_r50_dcn_fce_ctw.yml @@ -3,17 +3,17 @@ Global: epoch_num: 1500 log_smooth_window: 20 print_batch_step: 20 - save_model_dir: ./output/fce_r50_ctw/ + save_model_dir: ./output/det_r50_dcn_fce_ctw/ save_epoch_step: 100 # evaluation is run every 835 iterations eval_batch_step: [0, 835] cal_metric_during_train: False - pretrained_model: ../pretrain_models/ResNet50_vd_ssld_pretrained - checkpoints: #output/fce_r50_ctw/latest + pretrained_model: ./pretrain_models/ResNet50_vd_ssld_pretrained + checkpoints: #output/det_r50_dcn_fce_ctw/latest save_inference_dir: use_visualdl: False infer_img: doc/imgs_en/img_10.jpg - save_res_path: ./output/fce_r50_ctw/predicts_ctw.txt + save_res_path: ./output/det_fce/predicts_fce.txt Architecture: @@ -65,9 +65,9 @@ Metric: Train: dataset: name: SimpleDataSet - data_dir: /data/Dataset/OCR_det/ctw1500/imgs/ + data_dir: ./train_data/ctw1500/imgs/ label_file_list: - - /data/Dataset/OCR_det/ctw1500/imgs/training.txt + - ./train_data/ctw1500/imgs/training.txt transforms: - DecodeImage: # load image img_mode: BGR @@ -113,9 +113,9 @@ Train: Eval: dataset: name: SimpleDataSet - data_dir: /data/Dataset/OCR_det/ctw1500/imgs/ + data_dir: ./train_data/ctw1500/imgs/ label_file_list: - - /data/Dataset/OCR_det/ctw1500/imgs/test.txt + - ./train_data/ctw1500/imgs/test.txt transforms: - DecodeImage: # load image img_mode: BGR diff --git a/ppocr/data/imaug/fce_aug.py b/ppocr/data/imaug/fce_aug.py index 6563a0d4..e1668d77 100644 --- a/ppocr/data/imaug/fce_aug.py +++ b/ppocr/data/imaug/fce_aug.py @@ -1,63 +1,26 @@ +# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This code is refer from: +https://github.com/open-mmlab/mmocr/blob/main/mmocr/datasets/pipelines/transforms.py +""" import numpy as np from PIL import Image, ImageDraw -import paddle.vision.transforms as paddle_trans import cv2 import Polygon as plg import math - - -def imresize(img, - size, - return_scale=False, - interpolation='bilinear', - out=None, - backend=None): - """Resize image to a given size. - - Args: - img (ndarray): The input image. - size (tuple[int]): Target size (w, h). - return_scale (bool): Whether to return `w_scale` and `h_scale`. - interpolation (str): Interpolation method, accepted values are - "nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2' - backend, "nearest", "bilinear" for 'pillow' backend. - out (ndarray): The output destination. - backend (str | None): The image resize backend type. Options are `cv2`, - `pillow`, `None`. If backend is None, the global imread_backend - specified by ``mmcv.use_backend()`` will be used. Default: None. - - Returns: - tuple | ndarray: (`resized_img`, `w_scale`, `h_scale`) or - `resized_img`. - """ - cv2_interp_codes = { - 'nearest': cv2.INTER_NEAREST, - 'bilinear': cv2.INTER_LINEAR, - 'bicubic': cv2.INTER_CUBIC, - 'area': cv2.INTER_AREA, - 'lanczos': cv2.INTER_LANCZOS4 - } - h, w = img.shape[:2] - if backend is None: - backend = 'cv2' - if backend not in ['cv2', 'pillow']: - raise ValueError(f'backend: {backend} is not supported for resize.' - f"Supported backends are 'cv2', 'pillow'") - - if backend == 'pillow': - assert img.dtype == np.uint8, 'Pillow backend only support uint8 type' - pil_image = Image.fromarray(img) - pil_image = pil_image.resize(size, pillow_interp_codes[interpolation]) - resized_img = np.array(pil_image) - else: - resized_img = cv2.resize( - img, size, dst=out, interpolation=cv2_interp_codes[interpolation]) - if not return_scale: - return resized_img - else: - w_scale = size[0] / w - h_scale = size[1] / h - return resized_img, w_scale, h_scale +from ppocr.utils.poly_nms import poly_intersection class RandomScaling: @@ -83,45 +46,16 @@ class RandomScaling: scales = self.size * 1.0 / max(h, w) * aspect_ratio scales = np.array([scales, scales]) out_size = (int(h * scales[1]), int(w * scales[0])) - image = imresize(image, out_size[::-1]) + image = cv2.resize(image, out_size[::-1]) data['image'] = image text_polys[:, :, 0::2] = text_polys[:, :, 0::2] * scales[1] text_polys[:, :, 1::2] = text_polys[:, :, 1::2] * scales[0] data['polys'] = text_polys - # import os - # base_name = os.path.split(data['img_path'])[-1] - # img = image[..., ::-1] - # img = Image.fromarray(img) - # draw = ImageDraw.Draw(img) - # for box in text_polys: - # draw.polygon(box, outline=(0, 255, 255,), ) - # import time - # img.save('tmp/{}.jpg'.format(base_name[:-4])) - return data -def poly_intersection(poly_det, poly_gt): - """Calculate the intersection area between two polygon. - - Args: - poly_det (Polygon): A polygon predicted by detector. - poly_gt (Polygon): A gt polygon. - - Returns: - intersection_area (float): The intersection area between two polygons. - """ - assert isinstance(poly_det, plg.Polygon) - assert isinstance(poly_gt, plg.Polygon) - - poly_inter = poly_det & poly_gt - if len(poly_inter) == 0: - return 0, poly_inter - return poly_inter.area(), poly_inter - - class RandomCropFlip: def __init__(self, pad_ratio=0.1, @@ -352,12 +286,7 @@ class RandomCropPolyInstances: max_y_start = max(np.min(selected_mask[:, 1]) - 2, 0) min_y_end = min(np.max(selected_mask[:, 1]) + 3, h - 1) - # for key in results.get('mask_fields', []): - # if len(results[key].masks) == 0: - # continue - # masks = results[key].masks for mask in key_masks: - # assert len(mask) == 1 mask = mask.reshape((-1, 2)).astype(np.int32) clip_x = np.clip(mask[:, 0], 0, w - 1) clip_y = np.clip(mask[:, 1], 0, h - 1) @@ -501,7 +430,8 @@ class RandomRotatePolyInstances: (h_ind, w_ind) = (np.random.randint(0, h * 7 // 8), np.random.randint(0, w * 7 // 8)) img_cut = img[h_ind:(h_ind + h // 9), w_ind:(w_ind + w // 9)] - img_cut = imresize(img_cut, (canvas_size[1], canvas_size[0])) + img_cut = cv2.resize(img_cut, (canvas_size[1], canvas_size[0])) + mask = cv2.warpAffine( mask, rotation_matrix, (canvas_size[1], canvas_size[0]), @@ -574,7 +504,7 @@ class SquareResizePad: t_w = self.target_size if h <= w else int(w * self.target_size / h) else: t_h = t_w = self.target_size - img = imresize(img, (t_w, t_h)) + img = cv2.resize(img, (t_w, t_h)) return img, (t_h, t_w) def square_pad(self, img): @@ -589,7 +519,7 @@ class SquareResizePad: (h_ind, w_ind) = (np.random.randint(0, h * 7 // 8), np.random.randint(0, w * 7 // 8)) img_cut = img[h_ind:(h_ind + h // 9), w_ind:(w_ind + w // 9)] - expand_img = imresize(img_cut, (pad_size, pad_size)) + expand_img = cv2.resize(img_cut, (pad_size, pad_size)) if h > w: y0, x0 = 0, (h - w) // 2 else: @@ -617,13 +547,14 @@ class SquareResizePad: else: image, out_size = self.resize_img(image, keep_ratio=False) offset = (0, 0) - # image, out_size = self.resize_img(image, keep_ratio=True) - # image, offset = self.square_pad(image) results['image'] = image - polygons[:, :, 0::2] = polygons[:, :, 0::2] * out_size[1] / w + offset[ - 0] - polygons[:, :, 1::2] = polygons[:, :, 1::2] * out_size[0] / h + offset[ - 1] + try: + polygons[:, :, 0::2] = polygons[:, :, 0::2] * out_size[ + 1] / w + offset[0] + polygons[:, :, 1::2] = polygons[:, :, 1::2] * out_size[ + 0] / h + offset[1] + except: + pass results['polys'] = polygons return results diff --git a/ppocr/data/imaug/fce_targets.py b/ppocr/data/imaug/fce_targets.py index 29bda579..18184808 100644 --- a/ppocr/data/imaug/fce_targets.py +++ b/ppocr/data/imaug/fce_targets.py @@ -1,3 +1,21 @@ +# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This code is refer from: +https://github.com/open-mmlab/mmocr/blob/main/mmocr/datasets/pipelines/textdet_targets/fcenet_targets.py +""" + import cv2 import numpy as np from numpy.fft import fft @@ -470,7 +488,6 @@ class FCENetTargets: """ assert isinstance(img_size, tuple) - # assert check_argument.is_2dlist(text_polys) h, w = img_size k = self.fourier_degree @@ -478,9 +495,6 @@ class FCENetTargets: imag_map = np.zeros((k * 2 + 1, h, w), dtype=np.float32) for poly in text_polys: - # assert len(poly) == 1 - # text_instance = [[poly[i], poly[i + 1]] - # for i in range(0, len(poly), 2)] mask = np.zeros((h, w), dtype=np.uint8) polygon = np.array(poly).reshape((1, -1, 2)) cv2.fillPoly(mask, polygon.astype(np.int32), 1) @@ -512,15 +526,11 @@ class FCENetTargets: """ assert isinstance(img_size, tuple) - # assert check_argument.is_2dlist(text_polys) h, w = img_size text_region_mask = np.zeros((h, w), dtype=np.uint8) for poly in text_polys: - # assert len(poly) == 1 - # text_instance = [[poly[i], poly[i + 1]] - # for i in range(0, len(poly), 2)] polygon = np.array(poly, dtype=np.int32).reshape((1, -1, 2)) cv2.fillPoly(text_region_mask, polygon, 1) @@ -539,8 +549,6 @@ class FCENetTargets: mask (ndarray): The effective mask of (height, width). """ - # assert check_argument.is_2dlist(polygons_ignore) - mask = np.ones(mask_size, dtype=np.uint8) for poly in polygons_ignore: @@ -566,9 +574,6 @@ class FCENetTargets: lv_ignore_polys = [[] for i in range(len(lv_size_divs))] level_maps = [] for poly in text_polys: - # assert len(poly) == 1 - # text_instance = [[poly[i], poly[i + 1]] - # for i in range(0, len(poly), 2)] polygon = np.array(poly, dtype=np.int).reshape((1, -1, 2)) _, _, box_w, box_h = cv2.boundingRect(polygon) proportion = max(box_h, box_w) / (h + 1e-8) @@ -578,9 +583,6 @@ class FCENetTargets: lv_text_polys[ind].append(poly / lv_size_divs[ind]) for ignore_poly in ignore_polys: - # assert len(ignore_poly) == 1 - # text_instance = [[ignore_poly[i], ignore_poly[i + 1]] - # for i in range(0, len(ignore_poly), 2)] polygon = np.array(ignore_poly, dtype=np.int).reshape((1, -1, 2)) _, _, box_w, box_h = cv2.boundingRect(polygon) proportion = max(box_h, box_w) / (h + 1e-8) @@ -630,18 +632,6 @@ class FCENetTargets: ignore_tags = results['ignore_tags'] h, w, _ = image.shape - # import time - # from PIL import Image, ImageDraw - # cur_time = time.time() - # image = results['image'] - # text_polys = results['polys'] - # img = image[..., ::-1] - # img = Image.fromarray(img) - # draw = ImageDraw.Draw(img) - # for box in text_polys: - # draw.polygon(box, outline=(0, 255, 255,), ) - # img.save('tmp/{}_resize_pad.jpg'.format(cur_time)) - polygon_masks = [] polygon_masks_ignore = [] for tag, polygon in zip(ignore_tags, polygons): @@ -653,8 +643,6 @@ class FCENetTargets: level_maps = self.generate_level_targets((h, w), polygon_masks, polygon_masks_ignore) - # results['mask_fields'].clear() # rm gt_masks encoded by polygons - # import remote_pdb as pdb;pdb.set_trace() mapping = { 'p3_maps': level_maps[0], 'p4_maps': level_maps[1], diff --git a/ppocr/data/imaug/operators.py b/ppocr/data/imaug/operators.py index 920e7fee..efae1f41 100644 --- a/ppocr/data/imaug/operators.py +++ b/ppocr/data/imaug/operators.py @@ -23,6 +23,7 @@ import sys import six import cv2 import numpy as np +import math class DecodeImage(object): @@ -165,6 +166,27 @@ class KeepKeys(object): return data_list +class Pad(object): + def __init__(self, size_div=32, **kwargs): + self.size_div = size_div + + def __call__(self, data): + + img = data['image'] + resize_h2 = max(int(math.ceil(img.shape[0] / 32) * 32), 32) + resize_w2 = max(int(math.ceil(img.shape[1] / 32) * 32), 32) + img = cv2.copyMakeBorder( + img, + 0, + resize_h2 - img.shape[0], + 0, + resize_w2 - img.shape[1], + cv2.BORDER_CONSTANT, + value=0) + data['image'] = img + return data + + class Resize(object): def __init__(self, size=(640, 640), **kwargs): self.size = size diff --git a/ppocr/losses/det_fce_loss.py b/ppocr/losses/det_fce_loss.py index 80d5b672..d7dfb5aa 100644 --- a/ppocr/losses/det_fce_loss.py +++ b/ppocr/losses/det_fce_loss.py @@ -1,3 +1,21 @@ +# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This code is refer from: +https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textdet/losses/fce_loss.py +""" + import numpy as np from paddle import nn import paddle @@ -39,7 +57,6 @@ class FCELoss(nn.Layer): assert p3_maps[0].shape[0] == 4 * self.fourier_degree + 5,\ 'fourier degree not equal in FCEhead and FCEtarget' - # device = preds[0][0].device # to tensor gts = [p3_maps, p4_maps, p5_maps] for idx, maps in enumerate(gts): @@ -94,7 +111,6 @@ class FCELoss(nn.Layer): [tr_train_mask.unsqueeze(1), tr_train_mask.unsqueeze(1)], axis=1) # tr loss loss_tr = self.ohem(tr_pred, tr_mask, train_mask) - # import pdb; pdb.set_trace() # tcl loss loss_tcl = paddle.to_tensor(0.).astype('float32') tr_neg_mask = tr_train_mask.logical_not() @@ -138,7 +154,6 @@ class FCELoss(nn.Layer): return loss_tr, loss_tcl, loss_reg_x, loss_reg_y def ohem(self, predict, target, train_mask): - # device = train_mask.device pos = (target * train_mask).astype('bool') neg = ((1 - target) * train_mask).astype('bool') diff --git a/ppocr/modeling/heads/det_fce_head.py b/ppocr/modeling/heads/det_fce_head.py index 8f932851..2b6629b1 100644 --- a/ppocr/modeling/heads/det_fce_head.py +++ b/ppocr/modeling/heads/det_fce_head.py @@ -1,3 +1,21 @@ +# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This code is refer from: +https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textdet/dense_heads/fce_head.py +""" + from paddle import nn from paddle import ParamAttr import paddle.nn.functional as F @@ -7,22 +25,6 @@ from functools import partial def multi_apply(func, *args, **kwargs): - """Apply function to a list of arguments. - - Note: - This function applies the ``func`` to multiple inputs and - map the multiple outputs of the ``func`` into different - list. Each list contains the same type of outputs corresponding - to different inputs. - - Args: - func (Function): A function that will be applied to a list of - arguments - - Returns: - tuple(list): A tuple containing multiple list, each list contains \ - a kind of returned results by the function - """ pfunc = partial(func, **kwargs) if kwargs else func map_results = map(pfunc, *args) return tuple(map(list, zip(*map_results))) diff --git a/ppocr/modeling/necks/fce_fpn.py b/ppocr/modeling/necks/fce_fpn.py index 6a9e410a..954e964e 100644 --- a/ppocr/modeling/necks/fce_fpn.py +++ b/ppocr/modeling/necks/fce_fpn.py @@ -1,3 +1,21 @@ +# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This code is refer from: +https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.3/ppdet/modeling/necks/fpn.py +""" + import paddle.nn as nn import paddle.nn.functional as F from paddle import ParamAttr diff --git a/ppocr/postprocess/fce_postprocess.py b/ppocr/postprocess/fce_postprocess.py index d97706b2..578bfe93 100755 --- a/ppocr/postprocess/fce_postprocess.py +++ b/ppocr/postprocess/fce_postprocess.py @@ -1,143 +1,26 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function +# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This code is refer from: +https://github.com/open-mmlab/mmocr/blob/v0.3.0/mmocr/models/textdet/postprocess/wrapper.py +""" -import numpy as np import cv2 import paddle +import numpy as np from numpy.fft import ifft -import Polygon as plg - - -def points2polygon(points): - """Convert k points to 1 polygon. - - Args: - points (ndarray or list): A ndarray or a list of shape (2k) - that indicates k points. - - Returns: - polygon (Polygon): A polygon object. - """ - if isinstance(points, list): - points = np.array(points) - - assert isinstance(points, np.ndarray) - assert (points.size % 2 == 0) and (points.size >= 8) - - point_mat = points.reshape([-1, 2]) - return plg.Polygon(point_mat) - - -def poly_intersection(poly_det, poly_gt): - """Calculate the intersection area between two polygon. - - Args: - poly_det (Polygon): A polygon predicted by detector. - poly_gt (Polygon): A gt polygon. - - Returns: - intersection_area (float): The intersection area between two polygons. - """ - assert isinstance(poly_det, plg.Polygon) - assert isinstance(poly_gt, plg.Polygon) - - poly_inter = poly_det & poly_gt - if len(poly_inter) == 0: - return 0, poly_inter - return poly_inter.area(), poly_inter - - -def poly_union(poly_det, poly_gt): - """Calculate the union area between two polygon. - - Args: - poly_det (Polygon): A polygon predicted by detector. - poly_gt (Polygon): A gt polygon. - - Returns: - union_area (float): The union area between two polygons. - """ - assert isinstance(poly_det, plg.Polygon) - assert isinstance(poly_gt, plg.Polygon) - - area_det = poly_det.area() - area_gt = poly_gt.area() - area_inters, _ = poly_intersection(poly_det, poly_gt) - return area_det + area_gt - area_inters - - -def valid_boundary(x, with_score=True): - num = len(x) - if num < 8: - return False - if num % 2 == 0 and (not with_score): - return True - if num % 2 == 1 and with_score: - return True - - return False - - -def boundary_iou(src, target): - """Calculate the IOU between two boundaries. - - Args: - src (list): Source boundary. - target (list): Target boundary. - - Returns: - iou (float): The iou between two boundaries. - """ - assert valid_boundary(src, False) - assert valid_boundary(target, False) - src_poly = points2polygon(src) - target_poly = points2polygon(target) - - return poly_iou(src_poly, target_poly) - - -def poly_iou(poly_det, poly_gt): - """Calculate the IOU between two polygons. - - Args: - poly_det (Polygon): A polygon predicted by detector. - poly_gt (Polygon): A gt polygon. - - Returns: - iou (float): The IOU between two polygons. - """ - assert isinstance(poly_det, plg.Polygon) - assert isinstance(poly_gt, plg.Polygon) - area_inters, _ = poly_intersection(poly_det, poly_gt) - area_union = poly_union(poly_det, poly_gt) - if area_union == 0: - return 0.0 - return area_inters / area_union - - -def poly_nms(polygons, threshold): - assert isinstance(polygons, list) - - polygons = np.array(sorted(polygons, key=lambda x: x[-1])) - - keep_poly = [] - index = [i for i in range(polygons.shape[0])] - - while len(index) > 0: - keep_poly.append(polygons[index[-1]].tolist()) - A = polygons[index[-1]][:-1] - index = np.delete(index, -1) - - iou_list = np.zeros((len(index), )) - for i in range(len(index)): - B = polygons[index[i]][:-1] - - iou_list[i] = boundary_iou(A, B) - remove_index = np.where(iou_list > threshold) - index = np.delete(index, remove_index) - - return keep_poly +from ppocr.utils.poly_nms import poly_nms, valid_boundary def fill_hole(input_mask): @@ -177,96 +60,6 @@ def fourier2poly(fourier_coeff, num_reconstr_points=50): return polygon.astype('int32').reshape((len(fourier_coeff), -1)) -def fcenet_decode(preds, - fourier_degree, - num_reconstr_points, - scale, - alpha=1.0, - beta=2.0, - text_repr_type='poly', - score_thr=0.3, - nms_thr=0.1): - """Decoding predictions of FCENet to instances. - - Args: - preds (list(Tensor)): The head output tensors. - fourier_degree (int): The maximum Fourier transform degree k. - num_reconstr_points (int): The points number of the polygon - reconstructed from predicted Fourier coefficients. - scale (int): The down-sample scale of the prediction. - alpha (float) : The parameter to calculate final scores. Score_{final} - = (Score_{text region} ^ alpha) - * (Score_{text center region}^ beta) - beta (float) : The parameter to calculate final score. - text_repr_type (str): Boundary encoding type 'poly' or 'quad'. - score_thr (float) : The threshold used to filter out the final - candidates. - nms_thr (float) : The threshold of nms. - - Returns: - boundaries (list[list[float]]): The instance boundary and confidence - list. - """ - assert isinstance(preds, list) - assert len(preds) == 2 - assert text_repr_type in ['poly', 'quad'] - - # import pdb;pdb.set_trace() - cls_pred = preds[0][0] - # tr_pred = F.softmax(cls_pred[0:2], axis=0).cpu().numpy() - # tcl_pred = F.softmax(cls_pred[2:], axis=0).cpu().numpy() - - tr_pred = cls_pred[0:2] - tcl_pred = cls_pred[2:] - - reg_pred = preds[1][0].transpose([1, 2, 0]) #.cpu().numpy() - x_pred = reg_pred[:, :, :2 * fourier_degree + 1] - y_pred = reg_pred[:, :, 2 * fourier_degree + 1:] - - score_pred = (tr_pred[1]**alpha) * (tcl_pred[1]**beta) - tr_pred_mask = (score_pred) > score_thr - tr_mask = fill_hole(tr_pred_mask) - - tr_contours, _ = cv2.findContours( - tr_mask.astype(np.uint8), cv2.RETR_TREE, - cv2.CHAIN_APPROX_SIMPLE) # opencv4 - - mask = np.zeros_like(tr_mask) - boundaries = [] - for cont in tr_contours: - deal_map = mask.copy().astype(np.int8) - cv2.drawContours(deal_map, [cont], -1, 1, -1) - - score_map = score_pred * deal_map - score_mask = score_map > 0 - xy_text = np.argwhere(score_mask) - dxy = xy_text[:, 1] + xy_text[:, 0] * 1j - - x, y = x_pred[score_mask], y_pred[score_mask] - c = x + y * 1j - c[:, fourier_degree] = c[:, fourier_degree] + dxy - c *= scale - - polygons = fourier2poly(c, num_reconstr_points) - score = score_map[score_mask].reshape(-1, 1) - polygons = poly_nms(np.hstack((polygons, score)).tolist(), nms_thr) - - boundaries = boundaries + polygons - - boundaries = poly_nms(boundaries, nms_thr) - - if text_repr_type == 'quad': - new_boundaries = [] - for boundary in boundaries: - poly = np.array(boundary[:-1]).reshape(-1, 2).astype(np.float32) - score = boundary[-1] - points = cv2.boxPoints(cv2.minAreaRect(poly)) - points = np.int0(points) - new_boundaries.append(points.reshape(-1).tolist() + [score]) - - return boundaries - - class FCEPostProcess(object): """ The post process for FCENet. @@ -316,10 +109,6 @@ class FCEPostProcess(object): Returns: boundaries (list[list[float]]): The scaled boundaries. """ - # assert check_argument.is_2dlist(boundaries) - # assert isinstance(scale_factor, np.ndarray) - # assert scale_factor.shape[0] == 4 - boxes = [] scores = [] for b in boundaries: @@ -335,7 +124,6 @@ class FCEPostProcess(object): def get_boundary(self, score_maps, shape_list): assert len(score_maps) == len(self.scales) - # import pdb;pdb.set_trace() boundaries = [] for idx, score_map in enumerate(score_maps): scale = self.scales[idx] @@ -344,8 +132,6 @@ class FCEPostProcess(object): # nms boundaries = poly_nms(boundaries, self.nms_thr) - # if rescale: - # import pdb;pdb.set_trace() boundaries, scores = self.resize_boundary( boundaries, (1 / shape_list[0, 2:]).tolist()[::-1]) @@ -356,7 +142,7 @@ class FCEPostProcess(object): assert len(score_map) == 2 assert score_map[1].shape[1] == 4 * self.fourier_degree + 2 - return fcenet_decode( + return self.fcenet_decode( preds=score_map, fourier_degree=self.fourier_degree, num_reconstr_points=self.num_reconstr_points, @@ -366,3 +152,89 @@ class FCEPostProcess(object): text_repr_type=self.text_repr_type, score_thr=self.score_thr, nms_thr=self.nms_thr) + + def fcenet_decode(self, + preds, + fourier_degree, + num_reconstr_points, + scale, + alpha=1.0, + beta=2.0, + text_repr_type='poly', + score_thr=0.3, + nms_thr=0.1): + """Decoding predictions of FCENet to instances. + + Args: + preds (list(Tensor)): The head output tensors. + fourier_degree (int): The maximum Fourier transform degree k. + num_reconstr_points (int): The points number of the polygon + reconstructed from predicted Fourier coefficients. + scale (int): The down-sample scale of the prediction. + alpha (float) : The parameter to calculate final scores. Score_{final} + = (Score_{text region} ^ alpha) + * (Score_{text center region}^ beta) + beta (float) : The parameter to calculate final score. + text_repr_type (str): Boundary encoding type 'poly' or 'quad'. + score_thr (float) : The threshold used to filter out the final + candidates. + nms_thr (float) : The threshold of nms. + + Returns: + boundaries (list[list[float]]): The instance boundary and confidence + list. + """ + assert isinstance(preds, list) + assert len(preds) == 2 + assert text_repr_type in ['poly', 'quad'] + + cls_pred = preds[0][0] + tr_pred = cls_pred[0:2] + tcl_pred = cls_pred[2:] + + reg_pred = preds[1][0].transpose([1, 2, 0]) + x_pred = reg_pred[:, :, :2 * fourier_degree + 1] + y_pred = reg_pred[:, :, 2 * fourier_degree + 1:] + + score_pred = (tr_pred[1]**alpha) * (tcl_pred[1]**beta) + tr_pred_mask = (score_pred) > score_thr + tr_mask = fill_hole(tr_pred_mask) + + tr_contours, _ = cv2.findContours( + tr_mask.astype(np.uint8), cv2.RETR_TREE, + cv2.CHAIN_APPROX_SIMPLE) # opencv4 + + mask = np.zeros_like(tr_mask) + boundaries = [] + for cont in tr_contours: + deal_map = mask.copy().astype(np.int8) + cv2.drawContours(deal_map, [cont], -1, 1, -1) + + score_map = score_pred * deal_map + score_mask = score_map > 0 + xy_text = np.argwhere(score_mask) + dxy = xy_text[:, 1] + xy_text[:, 0] * 1j + + x, y = x_pred[score_mask], y_pred[score_mask] + c = x + y * 1j + c[:, fourier_degree] = c[:, fourier_degree] + dxy + c *= scale + + polygons = fourier2poly(c, num_reconstr_points) + score = score_map[score_mask].reshape(-1, 1) + polygons = poly_nms(np.hstack((polygons, score)).tolist(), nms_thr) + + boundaries = boundaries + polygons + + boundaries = poly_nms(boundaries, nms_thr) + + if text_repr_type == 'quad': + new_boundaries = [] + for boundary in boundaries: + poly = np.array(boundary[:-1]).reshape(-1, 2).astype(np.float32) + score = boundary[-1] + points = cv2.boxPoints(cv2.minAreaRect(poly)) + points = np.int0(points) + new_boundaries.append(points.reshape(-1).tolist() + [score]) + + return boundaries diff --git a/ppocr/utils/poly_nms.py b/ppocr/utils/poly_nms.py new file mode 100644 index 00000000..2ee4ac0e --- /dev/null +++ b/ppocr/utils/poly_nms.py @@ -0,0 +1,145 @@ +# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import Polygon as plg + + +def points2polygon(points): + """Convert k points to 1 polygon. + + Args: + points (ndarray or list): A ndarray or a list of shape (2k) + that indicates k points. + + Returns: + polygon (Polygon): A polygon object. + """ + if isinstance(points, list): + points = np.array(points) + + assert isinstance(points, np.ndarray) + assert (points.size % 2 == 0) and (points.size >= 8) + + point_mat = points.reshape([-1, 2]) + return plg.Polygon(point_mat) + + +def poly_intersection(poly_det, poly_gt): + """Calculate the intersection area between two polygon. + + Args: + poly_det (Polygon): A polygon predicted by detector. + poly_gt (Polygon): A gt polygon. + + Returns: + intersection_area (float): The intersection area between two polygons. + """ + assert isinstance(poly_det, plg.Polygon) + assert isinstance(poly_gt, plg.Polygon) + + poly_inter = poly_det & poly_gt + if len(poly_inter) == 0: + return 0, poly_inter + return poly_inter.area(), poly_inter + + +def poly_union(poly_det, poly_gt): + """Calculate the union area between two polygon. + + Args: + poly_det (Polygon): A polygon predicted by detector. + poly_gt (Polygon): A gt polygon. + + Returns: + union_area (float): The union area between two polygons. + """ + assert isinstance(poly_det, plg.Polygon) + assert isinstance(poly_gt, plg.Polygon) + + area_det = poly_det.area() + area_gt = poly_gt.area() + area_inters, _ = poly_intersection(poly_det, poly_gt) + return area_det + area_gt - area_inters + + +def valid_boundary(x, with_score=True): + num = len(x) + if num < 8: + return False + if num % 2 == 0 and (not with_score): + return True + if num % 2 == 1 and with_score: + return True + + return False + + +def boundary_iou(src, target): + """Calculate the IOU between two boundaries. + + Args: + src (list): Source boundary. + target (list): Target boundary. + + Returns: + iou (float): The iou between two boundaries. + """ + assert valid_boundary(src, False) + assert valid_boundary(target, False) + src_poly = points2polygon(src) + target_poly = points2polygon(target) + + return poly_iou(src_poly, target_poly) + + +def poly_iou(poly_det, poly_gt): + """Calculate the IOU between two polygons. + + Args: + poly_det (Polygon): A polygon predicted by detector. + poly_gt (Polygon): A gt polygon. + + Returns: + iou (float): The IOU between two polygons. + """ + assert isinstance(poly_det, plg.Polygon) + assert isinstance(poly_gt, plg.Polygon) + area_inters, _ = poly_intersection(poly_det, poly_gt) + area_union = poly_union(poly_det, poly_gt) + if area_union == 0: + return 0.0 + return area_inters / area_union + + +def poly_nms(polygons, threshold): + assert isinstance(polygons, list) + + polygons = np.array(sorted(polygons, key=lambda x: x[-1])) + + keep_poly = [] + index = [i for i in range(polygons.shape[0])] + + while len(index) > 0: + keep_poly.append(polygons[index[-1]].tolist()) + A = polygons[index[-1]][:-1] + index = np.delete(index, -1) + iou_list = np.zeros((len(index), )) + for i in range(len(index)): + B = polygons[index[i]][:-1] + iou_list[i] = boundary_iou(A, B) + remove_index = np.where(iou_list > threshold) + index = np.delete(index, remove_index) + + return keep_poly \ No newline at end of file diff --git a/tools/program.py b/tools/program.py index 10299940..0e38ac92 100755 --- a/tools/program.py +++ b/tools/program.py @@ -503,7 +503,7 @@ def preprocess(is_train=False): assert alg in [ 'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN', 'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE', - 'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM' + 'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'FCE' ] device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu' diff --git a/train.sh b/train.sh index 24277ec8..4225470c 100644 --- a/train.sh +++ b/train.sh @@ -1,3 +1,2 @@ # recommended paddle.__version__ == 2.0.0 -# python3 -m paddle.distributed.launch --log_dir=./debug/ --gpus '0,1,2,3,4,5,6,7' tools/train.py -c configs/rec/rec_mv3_none_bilstm_ctc.yml -python -m paddle.distributed.launch --gpus '7' tools/train.py -c configs/det/det_r50_fce_ctw.yml +python3 -m paddle.distributed.launch --log_dir=./debug/ --gpus '0,1,2,3,4,5,6,7' tools/train.py -c configs/rec/rec_mv3_none_bilstm_ctc.yml -- GitLab