diff --git a/configs/det/det_mv3_pse.yml b/configs/det/det_mv3_pse.yml new file mode 100644 index 0000000000000000000000000000000000000000..61ac24727acbd4f0b1eea15af08c0f9e71ce95a3 --- /dev/null +++ b/configs/det/det_mv3_pse.yml @@ -0,0 +1,135 @@ +Global: + use_gpu: true + epoch_num: 600 + log_smooth_window: 20 + print_batch_step: 10 + save_model_dir: ./output/det_mv3_pse/ + save_epoch_step: 600 + # evaluation is run every 63 iterations + eval_batch_step: [ 0,63 ] + cal_metric_during_train: False + pretrained_model: ./pretrain_models/MobileNetV3_large_x0_5_pretrained + checkpoints: #./output/det_r50_vd_pse_batch8_ColorJitter/best_accuracy + save_inference_dir: + use_visualdl: False + infer_img: doc/imgs_en/img_10.jpg + save_res_path: ./output/det_pse/predicts_pse.txt + +Architecture: + model_type: det + algorithm: PSE + Transform: null + Backbone: + name: MobileNetV3 + scale: 0.5 + model_name: large + Neck: + name: FPN + out_channels: 96 + Head: + name: PSEHead + hidden_dim: 96 + out_channels: 7 + +Loss: + name: PSELoss + alpha: 0.7 + ohem_ratio: 3 + kernel_sample_mask: pred + reduction: none + +Optimizer: + name: Adam + beta1: 0.9 + beta2: 0.999 + lr: + name: Step + learning_rate: 0.001 + step_size: 200 + gamma: 0.1 + regularizer: + name: 'L2' + factor: 0.0005 + +PostProcess: + name: PSEPostProcess + thresh: 0 + box_thresh: 0.85 + min_area: 16 + box_type: box # 'box' or 'poly' + scale: 1 + +Metric: + name: DetMetric + main_indicator: hmean + +Train: + dataset: + name: SimpleDataSet + data_dir: ./train_data/icdar2015/text_localization/ + label_file_list: + - ./train_data/icdar2015/text_localization/train_icdar2015_label.txt + ratio_list: [ 1.0 ] + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - DetLabelEncode: # Class handling label + - ColorJitter: + brightness: 0.12549019607843137 + saturation: 0.5 + - IaaAugment: + augmenter_args: + - { 'type': Resize, 'args': { 'size': [ 0.5, 3 ] } } + - { 'type': Fliplr, 'args': { 'p': 0.5 } } + - { 'type': Affine, 'args': { 'rotate': [ -10, 10 ] } } + - MakePseGt: + kernel_num: 7 + min_shrink_ratio: 0.4 + size: 640 + - RandomCropImgMask: + size: [ 640,640 ] + main_key: gt_text + crop_keys: [ 'image', 'gt_text', 'gt_kernels', 'mask' ] + - NormalizeImage: + scale: 1./255. + mean: [ 0.485, 0.456, 0.406 ] + std: [ 0.229, 0.224, 0.225 ] + order: 'hwc' + - ToCHWImage: + - KeepKeys: + keep_keys: [ 'image', 'gt_text', 'gt_kernels', 'mask' ] # the order of the dataloader list + loader: + shuffle: True + drop_last: False + batch_size_per_card: 16 + num_workers: 8 + +Eval: + dataset: + name: SimpleDataSet + data_dir: ./train_data/icdar2015/text_localization/ + label_file_list: + - ./train_data/icdar2015/text_localization/test_icdar2015_label.txt + ratio_list: [ 1.0 ] + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - DetLabelEncode: # Class handling label + - DetResizeForTest: + limit_side_len: 736 + limit_type: min + - NormalizeImage: + scale: 1./255. + mean: [ 0.485, 0.456, 0.406 ] + std: [ 0.229, 0.224, 0.225 ] + order: 'hwc' + - ToCHWImage: + - KeepKeys: + keep_keys: [ 'image', 'shape', 'polys', 'ignore_tags' ] + loader: + shuffle: False + drop_last: False + batch_size_per_card: 1 # must be 1 + num_workers: 8 \ No newline at end of file diff --git a/configs/det/det_r50_vd_pse.yml b/configs/det/det_r50_vd_pse.yml new file mode 100644 index 0000000000000000000000000000000000000000..4629210747d3b61344cc47b11dcff01e6b738586 --- /dev/null +++ b/configs/det/det_r50_vd_pse.yml @@ -0,0 +1,134 @@ +Global: + use_gpu: true + epoch_num: 600 + log_smooth_window: 20 + print_batch_step: 10 + save_model_dir: ./output/det_r50_vd_pse/ + save_epoch_step: 600 + # evaluation is run every 125 iterations + eval_batch_step: [ 0,125 ] + cal_metric_during_train: False + pretrained_model: ./pretrain_models/ResNet50_vd_ssld_pretrained + checkpoints: #./output/det_r50_vd_pse_batch8_ColorJitter/best_accuracy + save_inference_dir: + use_visualdl: False + infer_img: doc/imgs_en/img_10.jpg + save_res_path: ./output/det_pse/predicts_pse.txt + +Architecture: + model_type: det + algorithm: PSE + Transform: + Backbone: + name: ResNet + layers: 50 + Neck: + name: FPN + out_channels: 256 + Head: + name: PSEHead + hidden_dim: 256 + out_channels: 7 + +Loss: + name: PSELoss + alpha: 0.7 + ohem_ratio: 3 + kernel_sample_mask: pred + reduction: none + +Optimizer: + name: Adam + beta1: 0.9 + beta2: 0.999 + lr: + name: Step + learning_rate: 0.0001 + step_size: 200 + gamma: 0.1 + regularizer: + name: 'L2' + factor: 0.0005 + +PostProcess: + name: PSEPostProcess + thresh: 0 + box_thresh: 0.85 + min_area: 16 + box_type: box # 'box' or 'poly' + scale: 1 + +Metric: + name: DetMetric + main_indicator: hmean + +Train: + dataset: + name: SimpleDataSet + data_dir: ./train_data/icdar2015/text_localization/ + label_file_list: + - ./train_data/icdar2015/text_localization/train_icdar2015_label.txt + ratio_list: [ 1.0 ] + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - DetLabelEncode: # Class handling label + - ColorJitter: + brightness: 0.12549019607843137 + saturation: 0.5 + - IaaAugment: + augmenter_args: + - { 'type': Resize, 'args': { 'size': [ 0.5, 3 ] } } + - { 'type': Fliplr, 'args': { 'p': 0.5 } } + - { 'type': Affine, 'args': { 'rotate': [ -10, 10 ] } } + - MakePseGt: + kernel_num: 7 + min_shrink_ratio: 0.4 + size: 640 + - RandomCropImgMask: + size: [ 640,640 ] + main_key: gt_text + crop_keys: [ 'image', 'gt_text', 'gt_kernels', 'mask' ] + - NormalizeImage: + scale: 1./255. + mean: [ 0.485, 0.456, 0.406 ] + std: [ 0.229, 0.224, 0.225 ] + order: 'hwc' + - ToCHWImage: + - KeepKeys: + keep_keys: [ 'image', 'gt_text', 'gt_kernels', 'mask' ] # the order of the dataloader list + loader: + shuffle: True + drop_last: False + batch_size_per_card: 8 + num_workers: 8 + +Eval: + dataset: + name: SimpleDataSet + data_dir: ./train_data/icdar2015/text_localization/ + label_file_list: + - ./train_data/icdar2015/text_localization/test_icdar2015_label.txt + ratio_list: [ 1.0 ] + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - DetLabelEncode: # Class handling label + - DetResizeForTest: + limit_side_len: 736 + limit_type: min + - NormalizeImage: + scale: 1./255. + mean: [ 0.485, 0.456, 0.406 ] + std: [ 0.229, 0.224, 0.225 ] + order: 'hwc' + - ToCHWImage: + - KeepKeys: + keep_keys: [ 'image', 'shape', 'polys', 'ignore_tags' ] + loader: + shuffle: False + drop_last: False + batch_size_per_card: 1 # must be 1 + num_workers: 8 \ No newline at end of file diff --git a/doc/doc_ch/algorithm_overview.md b/doc/doc_ch/algorithm_overview.md index d465539166130bb5db9bc587ccbfc3e97b15a236..6daacaf7f08e5bd0c5e5aba8b435c5bd26eaca7b 100755 --- a/doc/doc_ch/algorithm_overview.md +++ b/doc/doc_ch/algorithm_overview.md @@ -9,11 +9,13 @@ ### 1.文本检测算法 PaddleOCR开源的文本检测算法列表: -- [x] DB([paper]( https://arxiv.org/abs/1911.08947)) [2](ppocr推荐) -- [x] EAST([paper](https://arxiv.org/abs/1704.03155))[1] -- [x] SAST([paper](https://arxiv.org/abs/1908.05498))[4] +- [x] DB([paper]( https://arxiv.org/abs/1911.08947))(ppocr推荐) +- [x] EAST([paper](https://arxiv.org/abs/1704.03155)) +- [x] SAST([paper](https://arxiv.org/abs/1908.05498)) +- [x] PSENet([paper](https://arxiv.org/abs/1903.12473v2)) 在ICDAR2015文本检测公开数据集上,算法效果如下: + |模型|骨干网络|precision|recall|Hmean|下载链接| | --- | --- | --- | --- | --- | --- | |EAST|ResNet50_vd|85.80%|86.71%|86.25%|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_east_v2.0_train.tar)| @@ -21,6 +23,8 @@ PaddleOCR开源的文本检测算法列表: |DB|ResNet50_vd|86.41%|78.72%|82.38%|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_db_v2.0_train.tar)| |DB|MobileNetV3|77.29%|73.08%|75.12%|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_db_v2.0_train.tar)| |SAST|ResNet50_vd|91.39%|83.77%|87.42%|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_icdar15_v2.0_train.tar)| +|PSE|ResNet50_vd|85.81%|79.53%|82.55%|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/det_r50_vd_pse_v2.0_train.tar)| +|PSE|MobileNetV3|82.20%|70.48%|75.89%|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/det_mv3_pse_v2.0_train.tar)| 在Total-text文本检测公开数据集上,算法效果如下: @@ -39,15 +43,15 @@ PaddleOCR文本检测算法的训练和使用请参考文档教程中[模型训 ### 2.文本识别算法 PaddleOCR基于动态图开源的文本识别算法列表: -- [x] CRNN([paper](https://arxiv.org/abs/1507.05717))[7](ppocr推荐) -- [x] Rosetta([paper](https://arxiv.org/abs/1910.05085))[10] -- [x] STAR-Net([paper](http://www.bmva.org/bmvc/2016/papers/paper043/index.html))[11] -- [x] RARE([paper](https://arxiv.org/abs/1603.03915v1))[12] -- [x] SRN([paper](https://arxiv.org/abs/2003.12294))[5] +- [x] CRNN([paper](https://arxiv.org/abs/1507.05717))(ppocr推荐) +- [x] Rosetta([paper](https://arxiv.org/abs/1910.05085)) +- [x] STAR-Net([paper](http://www.bmva.org/bmvc/2016/papers/paper043/index.html)) +- [x] RARE([paper](https://arxiv.org/abs/1603.03915v1)) +- [x] SRN([paper](https://arxiv.org/abs/2003.12294)) - [x] NRTR([paper](https://arxiv.org/abs/1806.00926v2)) - [x] SAR([paper](https://arxiv.org/abs/1811.00751v2)) -参考[DTRB][3](https://arxiv.org/abs/1904.01906)文字识别训练和评估流程,使用MJSynth和SynthText两个文字识别数据集训练,在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上进行评估,算法效果如下: +参考[DTRB](https://arxiv.org/abs/1904.01906) 文字识别训练和评估流程,使用MJSynth和SynthText两个文字识别数据集训练,在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上进行评估,算法效果如下: |模型|骨干网络|Avg Accuracy|模型存储命名|下载链接| |---|---|---|---|---| diff --git a/doc/doc_en/algorithm_overview_en.md b/doc/doc_en/algorithm_overview_en.md index 05728153239ba1a52ae08c7e59b45340648612d5..df8a4ce3ef5fbcadb7ebdfd8ddf2bdf59637783e 100755 --- a/doc/doc_en/algorithm_overview_en.md +++ b/doc/doc_en/algorithm_overview_en.md @@ -11,9 +11,10 @@ This tutorial lists the text detection algorithms and text recognition algorithm ### 1. Text Detection Algorithm PaddleOCR open source text detection algorithms list: -- [x] EAST([paper](https://arxiv.org/abs/1704.03155))[2] -- [x] DB([paper](https://arxiv.org/abs/1911.08947))[1] -- [x] SAST([paper](https://arxiv.org/abs/1908.05498))[4] +- [x] EAST([paper](https://arxiv.org/abs/1704.03155)) +- [x] DB([paper](https://arxiv.org/abs/1911.08947)) +- [x] SAST([paper](https://arxiv.org/abs/1908.05498)) +- [x] PSE([paper](https://arxiv.org/abs/1903.12473v2)) On the ICDAR2015 dataset, the text detection result is as follows: @@ -24,6 +25,8 @@ On the ICDAR2015 dataset, the text detection result is as follows: |DB|ResNet50_vd|86.41%|78.72%|82.38%|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_db_v2.0_train.tar)| |DB|MobileNetV3|77.29%|73.08%|75.12%|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_db_v2.0_train.tar)| |SAST|ResNet50_vd|91.39%|83.77%|87.42%|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_icdar15_v2.0_train.tar)| +|PSE|ResNet50_vd|85.81%|79.53%|82.55%|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/det_r50_vd_pse_v2.0_train.tar)| +|PSE|MobileNetV3|82.20%|70.48%|75.89%|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/det_mv3_pse_v2.0_train.tar)| On Total-Text dataset, the text detection result is as follows: @@ -41,11 +44,11 @@ For the training guide and use of PaddleOCR text detection algorithms, please re ### 2. Text Recognition Algorithm PaddleOCR open-source text recognition algorithms list: -- [x] CRNN([paper](https://arxiv.org/abs/1507.05717))[7] -- [x] Rosetta([paper](https://arxiv.org/abs/1910.05085))[10] -- [x] STAR-Net([paper](http://www.bmva.org/bmvc/2016/papers/paper043/index.html))[11] -- [x] RARE([paper](https://arxiv.org/abs/1603.03915v1))[12] -- [x] SRN([paper](https://arxiv.org/abs/2003.12294))[5] +- [x] CRNN([paper](https://arxiv.org/abs/1507.05717)) +- [x] Rosetta([paper](https://arxiv.org/abs/1910.05085)) +- [x] STAR-Net([paper](http://www.bmva.org/bmvc/2016/papers/paper043/index.html)) +- [x] RARE([paper](https://arxiv.org/abs/1603.03915v1)) +- [x] SRN([paper](https://arxiv.org/abs/2003.12294)) - [x] NRTR([paper](https://arxiv.org/abs/1806.00926v2)) - [x] SAR([paper](https://arxiv.org/abs/1811.00751v2)) diff --git a/ppocr/data/imaug/ColorJitter.py b/ppocr/data/imaug/ColorJitter.py new file mode 100644 index 0000000000000000000000000000000000000000..4b542abc8f9dc5af76529f9feb4bcb8b47b5f7d0 --- /dev/null +++ b/ppocr/data/imaug/ColorJitter.py @@ -0,0 +1,26 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from paddle.vision.transforms import ColorJitter as pp_ColorJitter + +__all__ = ['ColorJitter'] + +class ColorJitter(object): + def __init__(self, brightness=0, contrast=0, saturation=0, hue=0,**kwargs): + self.aug = pp_ColorJitter(brightness, contrast, saturation, hue) + + def __call__(self, data): + image = data['image'] + image = self.aug(image) + data['image'] = image + return data diff --git a/ppocr/data/imaug/__init__.py b/ppocr/data/imaug/__init__.py index 8bfc17508398224b86d7e9d9a03a250c266a54a8..5aaa1cd71eb791efa94e6bd812f3ab76632c96c6 100644 --- a/ppocr/data/imaug/__init__.py +++ b/ppocr/data/imaug/__init__.py @@ -19,11 +19,13 @@ from __future__ import unicode_literals from .iaa_augment import IaaAugment from .make_border_map import MakeBorderMap from .make_shrink_map import MakeShrinkMap -from .random_crop_data import EastRandomCropData, PSERandomCrop +from .random_crop_data import EastRandomCropData, RandomCropImgMask +from .make_pse_gt import MakePseGt from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg, NRTRRecResizeImg, SARRecResizeImg from .randaugment import RandAugment from .copy_paste import CopyPaste +from .ColorJitter import ColorJitter from .operators import * from .label_ops import * diff --git a/ppocr/data/imaug/make_pse_gt.py b/ppocr/data/imaug/make_pse_gt.py new file mode 100644 index 0000000000000000000000000000000000000000..55abc8970784fd00843d2e91f259c58b65ae8579 --- /dev/null +++ b/ppocr/data/imaug/make_pse_gt.py @@ -0,0 +1,85 @@ +# -*- coding:utf-8 -*- + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import cv2 +import numpy as np +import pyclipper +from shapely.geometry import Polygon + +__all__ = ['MakePseGt'] + +class MakePseGt(object): + r''' + Making binary mask from detection data with ICDAR format. + Typically following the process of class `MakeICDARData`. + ''' + + def __init__(self, kernel_num=7, size=640, min_shrink_ratio=0.4, **kwargs): + self.kernel_num = kernel_num + self.min_shrink_ratio = min_shrink_ratio + self.size = size + + def __call__(self, data): + + image = data['image'] + text_polys = data['polys'] + ignore_tags = data['ignore_tags'] + + h, w, _ = image.shape + short_edge = min(h, w) + if short_edge < self.size: + # keep short_size >= self.size + scale = self.size / short_edge + image = cv2.resize(image, dsize=None, fx=scale, fy=scale) + text_polys *= scale + + gt_kernels = [] + for i in range(1,self.kernel_num+1): + # s1->sn, from big to small + rate = 1.0 - (1.0 - self.min_shrink_ratio) / (self.kernel_num - 1) * i + text_kernel, ignore_tags = self.generate_kernel(image.shape[0:2], rate, text_polys, ignore_tags) + gt_kernels.append(text_kernel) + + training_mask = np.ones(image.shape[0:2], dtype='uint8') + for i in range(text_polys.shape[0]): + if ignore_tags[i]: + cv2.fillPoly(training_mask, text_polys[i].astype(np.int32)[np.newaxis, :, :], 0) + + gt_kernels = np.array(gt_kernels) + gt_kernels[gt_kernels > 0] = 1 + + data['image'] = image + data['polys'] = text_polys + data['gt_kernels'] = gt_kernels[0:] + data['gt_text'] = gt_kernels[0] + data['mask'] = training_mask.astype('float32') + return data + + def generate_kernel(self, img_size, shrink_ratio, text_polys, ignore_tags=None): + h, w = img_size + text_kernel = np.zeros((h, w), dtype=np.float32) + for i, poly in enumerate(text_polys): + polygon = Polygon(poly) + distance = polygon.area * (1 - shrink_ratio * shrink_ratio) / (polygon.length + 1e-6) + subject = [tuple(l) for l in poly] + pco = pyclipper.PyclipperOffset() + pco.AddPath(subject, pyclipper.JT_ROUND, + pyclipper.ET_CLOSEDPOLYGON) + shrinked = np.array(pco.Execute(-distance)) + + if len(shrinked) == 0 or shrinked.size == 0: + if ignore_tags is not None: + ignore_tags[i] = True + continue + try: + shrinked = np.array(shrinked[0]).reshape(-1, 2) + except: + if ignore_tags is not None: + ignore_tags[i] = True + continue + cv2.fillPoly(text_kernel, [shrinked.astype(np.int32)], i + 1) + return text_kernel, ignore_tags diff --git a/ppocr/data/imaug/random_crop_data.py b/ppocr/data/imaug/random_crop_data.py index 4d67cff61d6f340be6d80d8243c68909a94c4e88..7c1c25abb56a0cf7d4d59b8523962bd5d81c873a 100644 --- a/ppocr/data/imaug/random_crop_data.py +++ b/ppocr/data/imaug/random_crop_data.py @@ -164,47 +164,55 @@ class EastRandomCropData(object): return data -class PSERandomCrop(object): - def __init__(self, size, **kwargs): +class RandomCropImgMask(object): + def __init__(self, size, main_key, crop_keys, p=3 / 8, **kwargs): self.size = size + self.main_key = main_key + self.crop_keys = crop_keys + self.p = p def __call__(self, data): - imgs = data['imgs'] + image = data['image'] - h, w = imgs[0].shape[0:2] + h, w = image.shape[0:2] th, tw = self.size if w == tw and h == th: - return imgs + return data - # label中存在文本实例,并且按照概率进行裁剪,使用threshold_label_map控制 - if np.max(imgs[2]) > 0 and random.random() > 3 / 8: - # 文本实例的左上角点 - tl = np.min(np.where(imgs[2] > 0), axis=1) - self.size + mask = data[self.main_key] + if np.max(mask) > 0 and random.random() > self.p: + # make sure to crop the text region + tl = np.min(np.where(mask > 0), axis=1) - (th, tw) tl[tl < 0] = 0 - # 文本实例的右下角点 - br = np.max(np.where(imgs[2] > 0), axis=1) - self.size + br = np.max(np.where(mask > 0), axis=1) - (th, tw) br[br < 0] = 0 - # 保证选到右下角点时,有足够的距离进行crop + br[0] = min(br[0], h - th) br[1] = min(br[1], w - tw) - for _ in range(50000): - i = random.randint(tl[0], br[0]) - j = random.randint(tl[1], br[1]) - # 保证shrink_label_map有文本 - if imgs[1][i:i + th, j:j + tw].sum() <= 0: - continue - else: - break + i = random.randint(tl[0], br[0]) if tl[0] < br[0] else 0 + j = random.randint(tl[1], br[1]) if tl[1] < br[1] else 0 else: - i = random.randint(0, h - th) - j = random.randint(0, w - tw) + i = random.randint(0, h - th) if h - th > 0 else 0 + j = random.randint(0, w - tw) if w - tw > 0 else 0 # return i, j, th, tw - for idx in range(len(imgs)): - if len(imgs[idx].shape) == 3: - imgs[idx] = imgs[idx][i:i + th, j:j + tw, :] - else: - imgs[idx] = imgs[idx][i:i + th, j:j + tw] - data['imgs'] = imgs + for k in data: + if k in self.crop_keys: + if len(data[k].shape) == 3: + if np.argmin(data[k].shape) == 0: + img = data[k][:, i:i + th, j:j + tw] + if img.shape[1] != img.shape[2]: + a = 1 + elif np.argmin(data[k].shape) == 2: + img = data[k][i:i + th, j:j + tw, :] + if img.shape[1] != img.shape[0]: + a = 1 + else: + img = data[k] + else: + img = data[k][i:i + th, j:j + tw] + if img.shape[0] != img.shape[1]: + a = 1 + data[k] = img return data diff --git a/ppocr/losses/__init__.py b/ppocr/losses/__init__.py index 0484542f0b8c3cc46ce643d9f4577fe87c6346b7..467d7b6579cef5eca2532cd30df43169184967fc 100755 --- a/ppocr/losses/__init__.py +++ b/ppocr/losses/__init__.py @@ -20,6 +20,7 @@ import paddle.nn as nn from .det_db_loss import DBLoss from .det_east_loss import EASTLoss from .det_sast_loss import SASTLoss +from .det_pse_loss import PSELoss # rec loss from .rec_ctc_loss import CTCLoss @@ -42,10 +43,12 @@ from .combined_loss import CombinedLoss # table loss from .table_att_loss import TableAttentionLoss + def build_loss(config): support_dict = [ - 'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss', - 'SRNLoss', 'PGLoss', 'CombinedLoss', 'NRTRLoss', 'TableAttentionLoss', 'SARLoss' + 'DBLoss', 'PSELoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', + 'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss', 'NRTRLoss', + 'TableAttentionLoss', 'SARLoss' ] config = copy.deepcopy(config) diff --git a/ppocr/losses/det_basic_loss.py b/ppocr/losses/det_basic_loss.py index eba5526dd2bd1c0328130b50817172df437cc360..7017236c284e55710f242275a413d56d32158d34 100644 --- a/ppocr/losses/det_basic_loss.py +++ b/ppocr/losses/det_basic_loss.py @@ -75,12 +75,6 @@ class BalanceLoss(nn.Layer): mask (variable): masked maps. return: (variable) balanced loss """ - # if self.main_loss_type in ['DiceLoss']: - # # For the loss that returns to scalar value, perform ohem on the mask - # mask = ohem_batch(pred, gt, mask, self.negative_ratio) - # loss = self.loss(pred, gt, mask) - # return loss - positive = gt * mask negative = (1 - gt) * mask @@ -153,53 +147,4 @@ class BCELoss(nn.Layer): def forward(self, input, label, mask=None, weight=None, name=None): loss = F.binary_cross_entropy(input, label, reduction=self.reduction) - return loss - - -def ohem_single(score, gt_text, training_mask, ohem_ratio): - pos_num = (int)(np.sum(gt_text > 0.5)) - ( - int)(np.sum((gt_text > 0.5) & (training_mask <= 0.5))) - - if pos_num == 0: - # selected_mask = gt_text.copy() * 0 # may be not good - selected_mask = training_mask - selected_mask = selected_mask.reshape( - 1, selected_mask.shape[0], selected_mask.shape[1]).astype('float32') - return selected_mask - - neg_num = (int)(np.sum(gt_text <= 0.5)) - neg_num = (int)(min(pos_num * ohem_ratio, neg_num)) - - if neg_num == 0: - selected_mask = training_mask - selected_mask = selected_mask.reshape( - 1, selected_mask.shape[0], selected_mask.shape[1]).astype('float32') - return selected_mask - - neg_score = score[gt_text <= 0.5] - # 将负样本得分从高到低排序 - neg_score_sorted = np.sort(-neg_score) - threshold = -neg_score_sorted[neg_num - 1] - # 选出 得分高的 负样本 和正样本 的 mask - selected_mask = ((score >= threshold) | - (gt_text > 0.5)) & (training_mask > 0.5) - selected_mask = selected_mask.reshape( - 1, selected_mask.shape[0], selected_mask.shape[1]).astype('float32') - return selected_mask - - -def ohem_batch(scores, gt_texts, training_masks, ohem_ratio): - scores = scores.numpy() - gt_texts = gt_texts.numpy() - training_masks = training_masks.numpy() - - selected_masks = [] - for i in range(scores.shape[0]): - selected_masks.append( - ohem_single(scores[i, :, :], gt_texts[i, :, :], training_masks[ - i, :, :], ohem_ratio)) - - selected_masks = np.concatenate(selected_masks, 0) - selected_masks = paddle.to_tensor(selected_masks) - - return selected_masks + return loss \ No newline at end of file diff --git a/ppocr/losses/det_pse_loss.py b/ppocr/losses/det_pse_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..78423091f841f29b1217f73f79beb26fe1575844 --- /dev/null +++ b/ppocr/losses/det_pse_loss.py @@ -0,0 +1,145 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +from paddle import nn +from paddle.nn import functional as F +import numpy as np +from ppocr.utils.iou import iou + + +class PSELoss(nn.Layer): + def __init__(self, + alpha, + ohem_ratio=3, + kernel_sample_mask='pred', + reduction='sum', + eps=1e-6, + **kwargs): + """Implement PSE Loss. + """ + super(PSELoss, self).__init__() + assert reduction in ['sum', 'mean', 'none'] + self.alpha = alpha + self.ohem_ratio = ohem_ratio + self.kernel_sample_mask = kernel_sample_mask + self.reduction = reduction + self.eps = eps + + def forward(self, outputs, labels): + predicts = outputs['maps'] + predicts = F.interpolate(predicts, scale_factor=4) + + texts = predicts[:, 0, :, :] + kernels = predicts[:, 1:, :, :] + gt_texts, gt_kernels, training_masks = labels[1:] + + # text loss + selected_masks = self.ohem_batch(texts, gt_texts, training_masks) + + loss_text = self.dice_loss(texts, gt_texts, selected_masks) + iou_text = iou((texts > 0).astype('int64'), + gt_texts, + training_masks, + reduce=False) + losses = dict(loss_text=loss_text, iou_text=iou_text) + + # kernel loss + loss_kernels = [] + if self.kernel_sample_mask == 'gt': + selected_masks = gt_texts * training_masks + elif self.kernel_sample_mask == 'pred': + selected_masks = ( + F.sigmoid(texts) > 0.5).astype('float32') * training_masks + + for i in range(kernels.shape[1]): + kernel_i = kernels[:, i, :, :] + gt_kernel_i = gt_kernels[:, i, :, :] + loss_kernel_i = self.dice_loss(kernel_i, gt_kernel_i, + selected_masks) + loss_kernels.append(loss_kernel_i) + loss_kernels = paddle.mean(paddle.stack(loss_kernels, axis=1), axis=1) + iou_kernel = iou((kernels[:, -1, :, :] > 0).astype('int64'), + gt_kernels[:, -1, :, :], + training_masks * gt_texts, + reduce=False) + losses.update(dict(loss_kernels=loss_kernels, iou_kernel=iou_kernel)) + loss = self.alpha * loss_text + (1 - self.alpha) * loss_kernels + losses['loss'] = loss + if self.reduction == 'sum': + losses = {x: paddle.sum(v) for x, v in losses.items()} + elif self.reduction == 'mean': + losses = {x: paddle.mean(v) for x, v in losses.items()} + return losses + + def dice_loss(self, input, target, mask): + input = F.sigmoid(input) + + input = input.reshape([input.shape[0], -1]) + target = target.reshape([target.shape[0], -1]) + mask = mask.reshape([mask.shape[0], -1]) + + input = input * mask + target = target * mask + + a = paddle.sum(input * target, 1) + b = paddle.sum(input * input, 1) + self.eps + c = paddle.sum(target * target, 1) + self.eps + d = (2 * a) / (b + c) + return 1 - d + + def ohem_single(self, score, gt_text, training_mask, ohem_ratio=3): + pos_num = int(paddle.sum((gt_text > 0.5).astype('float32'))) - int( + paddle.sum( + paddle.logical_and((gt_text > 0.5), (training_mask <= 0.5)) + .astype('float32'))) + + if pos_num == 0: + selected_mask = training_mask + selected_mask = selected_mask.reshape( + [1, selected_mask.shape[0], selected_mask.shape[1]]).astype( + 'float32') + return selected_mask + + neg_num = int(paddle.sum((gt_text <= 0.5).astype('float32'))) + neg_num = int(min(pos_num * ohem_ratio, neg_num)) + + if neg_num == 0: + selected_mask = training_mask + selected_mask = selected_mask.view( + 1, selected_mask.shape[0], + selected_mask.shape[1]).astype('float32') + return selected_mask + + neg_score = paddle.masked_select(score, gt_text <= 0.5) + neg_score_sorted = paddle.sort(-neg_score) + threshold = -neg_score_sorted[neg_num - 1] + + selected_mask = paddle.logical_and( + paddle.logical_or((score >= threshold), (gt_text > 0.5)), + (training_mask > 0.5)) + selected_mask = selected_mask.reshape( + [1, selected_mask.shape[0], selected_mask.shape[1]]).astype( + 'float32') + return selected_mask + + def ohem_batch(self, scores, gt_texts, training_masks, ohem_ratio=3): + selected_masks = [] + for i in range(scores.shape[0]): + selected_masks.append( + self.ohem_single(scores[i, :, :], gt_texts[i, :, :], + training_masks[i, :, :], ohem_ratio)) + + selected_masks = paddle.concat(selected_masks, 0).astype('float32') + return selected_masks diff --git a/ppocr/metrics/eval_det_iou.py b/ppocr/metrics/eval_det_iou.py index 0e32b2d19281de9a18a1fe0343bd7e8237825b7b..bc05e7df7d1d21abfb9d9fbd224ecd7254d9f393 100644 --- a/ppocr/metrics/eval_det_iou.py +++ b/ppocr/metrics/eval_det_iou.py @@ -169,21 +169,10 @@ class DetectionIoUEvaluator(object): numGlobalCareDet += numDetCare perSampleMetrics = { - 'precision': precision, - 'recall': recall, - 'hmean': hmean, - 'pairs': pairs, - 'iouMat': [] if len(detPols) > 100 else iouMat.tolist(), - 'gtPolPoints': gtPolPoints, - 'detPolPoints': detPolPoints, 'gtCare': numGtCare, 'detCare': numDetCare, - 'gtDontCare': gtDontCarePolsNum, - 'detDontCare': detDontCarePolsNum, 'detMatched': detMatched, - 'evaluationLog': evaluationLog } - return perSampleMetrics def combine_results(self, results): diff --git a/ppocr/modeling/heads/__init__.py b/ppocr/modeling/heads/__init__.py index 80311f92a6c0c7a67673b90ee19ff5d8778ac0e8..20eb150fc3f91c3abd9a3fd8ce655c24ccab8179 100755 --- a/ppocr/modeling/heads/__init__.py +++ b/ppocr/modeling/heads/__init__.py @@ -20,6 +20,7 @@ def build_head(config): from .det_db_head import DBHead from .det_east_head import EASTHead from .det_sast_head import SASTHead + from .det_pse_head import PSEHead from .e2e_pg_head import PGHead # rec head @@ -32,8 +33,9 @@ def build_head(config): # cls head from .cls_head import ClsHead support_dict = [ - 'DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead', 'AttentionHead', - 'SRNHead', 'PGHead', 'Transformer', 'TableAttentionHead', 'SARHead' + 'DBHead', 'PSEHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead', + 'AttentionHead', 'SRNHead', 'PGHead', 'Transformer', + 'TableAttentionHead', 'SARHead' ] #table head diff --git a/ppocr/modeling/heads/det_pse_head.py b/ppocr/modeling/heads/det_pse_head.py new file mode 100644 index 0000000000000000000000000000000000000000..db800f57a216ab437b724988ce692a9ac0c545d9 --- /dev/null +++ b/ppocr/modeling/heads/det_pse_head.py @@ -0,0 +1,35 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from paddle import nn + + +class PSEHead(nn.Layer): + def __init__(self, + in_channels, + hidden_dim=256, + out_channels=7, + **kwargs): + super(PSEHead, self).__init__() + self.conv1 = nn.Conv2D(in_channels, hidden_dim, kernel_size=3, stride=1, padding=1) + self.bn1 = nn.BatchNorm2D(hidden_dim) + self.relu1 = nn.ReLU() + + self.conv2 = nn.Conv2D(hidden_dim, out_channels, kernel_size=1, stride=1, padding=0) + + + def forward(self, x, **kwargs): + out = self.conv1(x) + out = self.relu1(self.bn1(out)) + out = self.conv2(out) + return {'maps': out} diff --git a/ppocr/modeling/necks/__init__.py b/ppocr/modeling/necks/__init__.py index e97c4f64bdc9acd6729d67a9c6ff7a7563f6c95e..5606a4c35f68021e7f151a7eae4a0da4d5b6b95e 100644 --- a/ppocr/modeling/necks/__init__.py +++ b/ppocr/modeling/necks/__init__.py @@ -22,7 +22,8 @@ def build_neck(config): from .rnn import SequenceEncoder from .pg_fpn import PGFPN from .table_fpn import TableFPN - support_dict = ['DBFPN', 'EASTFPN', 'SASTFPN', 'SequenceEncoder', 'PGFPN', 'TableFPN'] + from .fpn import FPN + support_dict = ['FPN','DBFPN', 'EASTFPN', 'SASTFPN', 'SequenceEncoder', 'PGFPN', 'TableFPN'] module_name = config.pop('name') assert module_name in support_dict, Exception('neck only support {}'.format( diff --git a/ppocr/modeling/necks/fpn.py b/ppocr/modeling/necks/fpn.py new file mode 100644 index 0000000000000000000000000000000000000000..8728a5c9ded5b9c174fd34f088d8012961f65ec0 --- /dev/null +++ b/ppocr/modeling/necks/fpn.py @@ -0,0 +1,100 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle.nn as nn +import paddle +import math +import paddle.nn.functional as F + +class Conv_BN_ReLU(nn.Layer): + def __init__(self, in_planes, out_planes, kernel_size=1, stride=1, padding=0): + super(Conv_BN_ReLU, self).__init__() + self.conv = nn.Conv2D(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, + bias_attr=False) + self.bn = nn.BatchNorm2D(out_planes, momentum=0.1) + self.relu = nn.ReLU() + + for m in self.sublayers(): + if isinstance(m, nn.Conv2D): + n = m._kernel_size[0] * m._kernel_size[1] * m._out_channels + m.weight = paddle.create_parameter(shape=m.weight.shape, dtype='float32', default_initializer=paddle.nn.initializer.Normal(0, math.sqrt(2. / n))) + elif isinstance(m, nn.BatchNorm2D): + m.weight = paddle.create_parameter(shape=m.weight.shape, dtype='float32', default_initializer=paddle.nn.initializer.Constant(1.0)) + m.bias = paddle.create_parameter(shape=m.bias.shape, dtype='float32', default_initializer=paddle.nn.initializer.Constant(0.0)) + + def forward(self, x): + return self.relu(self.bn(self.conv(x))) + +class FPN(nn.Layer): + def __init__(self, in_channels, out_channels): + super(FPN, self).__init__() + + # Top layer + self.toplayer_ = Conv_BN_ReLU(in_channels[3], out_channels, kernel_size=1, stride=1, padding=0) + # Lateral layers + self.latlayer1_ = Conv_BN_ReLU(in_channels[2], out_channels, kernel_size=1, stride=1, padding=0) + + self.latlayer2_ = Conv_BN_ReLU(in_channels[1], out_channels, kernel_size=1, stride=1, padding=0) + + self.latlayer3_ = Conv_BN_ReLU(in_channels[0], out_channels, kernel_size=1, stride=1, padding=0) + + # Smooth layers + self.smooth1_ = Conv_BN_ReLU(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + + self.smooth2_ = Conv_BN_ReLU(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + + self.smooth3_ = Conv_BN_ReLU(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + + + self.out_channels = out_channels * 4 + for m in self.sublayers(): + if isinstance(m, nn.Conv2D): + n = m._kernel_size[0] * m._kernel_size[1] * m._out_channels + m.weight = paddle.create_parameter(shape=m.weight.shape, dtype='float32', + default_initializer=paddle.nn.initializer.Normal(0, + math.sqrt(2. / n))) + elif isinstance(m, nn.BatchNorm2D): + m.weight = paddle.create_parameter(shape=m.weight.shape, dtype='float32', + default_initializer=paddle.nn.initializer.Constant(1.0)) + m.bias = paddle.create_parameter(shape=m.bias.shape, dtype='float32', + default_initializer=paddle.nn.initializer.Constant(0.0)) + + def _upsample(self, x, scale=1): + return F.upsample(x, scale_factor=scale, mode='bilinear') + + def _upsample_add(self, x, y, scale=1): + return F.upsample(x, scale_factor=scale, mode='bilinear') + y + + def forward(self, x): + f2, f3, f4, f5 = x + p5 = self.toplayer_(f5) + + f4 = self.latlayer1_(f4) + p4 = self._upsample_add(p5, f4,2) + p4 = self.smooth1_(p4) + + f3 = self.latlayer2_(f3) + p3 = self._upsample_add(p4, f3,2) + p3 = self.smooth2_(p3) + + f2 = self.latlayer3_(f2) + p2 = self._upsample_add(p3, f2,2) + p2 = self.smooth3_(p2) + + p3 = self._upsample(p3, 2) + p4 = self._upsample(p4, 4) + p5 = self._upsample(p5, 8) + + fuse = paddle.concat([p2, p3, p4, p5], axis=1) + return fuse \ No newline at end of file diff --git a/ppocr/postprocess/__init__.py b/ppocr/postprocess/__init__.py index 77081abeb691f01d0d3fcc8285d78cf5878411a7..3eb5e28da34bf4d9ed478c52ab38b1de21c0d1e3 100644 --- a/ppocr/postprocess/__init__.py +++ b/ppocr/postprocess/__init__.py @@ -28,12 +28,14 @@ from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, Di TableLabelDecode, SARLabelDecode from .cls_postprocess import ClsPostProcess from .pg_postprocess import PGPostProcess +from .pse_postprocess import PSEPostProcess + def build_post_process(config, global_config=None): support_dict = [ - 'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode', - 'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess', - 'DistillationCTCLabelDecode', 'TableLabelDecode', + 'DBPostProcess', 'PSEPostProcess', 'EASTPostProcess', 'SASTPostProcess', + 'CTCLabelDecode', 'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', + 'PGPostProcess', 'DistillationCTCLabelDecode', 'TableLabelDecode', 'DistillationDBPostProcess', 'NRTRLabelDecode', 'SARLabelDecode' ] diff --git a/ppocr/postprocess/pse_postprocess/__init__.py b/ppocr/postprocess/pse_postprocess/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..680473bf4b1863ac695dc8173778e59bd4fdacf9 --- /dev/null +++ b/ppocr/postprocess/pse_postprocess/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .pse_postprocess import PSEPostProcess \ No newline at end of file diff --git a/ppocr/postprocess/pse_postprocess/pse/README.md b/ppocr/postprocess/pse_postprocess/pse/README.md new file mode 100644 index 0000000000000000000000000000000000000000..9c2d9eaeaa5f93550358ebdd4d9161330b78a86f --- /dev/null +++ b/ppocr/postprocess/pse_postprocess/pse/README.md @@ -0,0 +1,5 @@ +## 编译 +code from https://github.com/whai362/pan_pp.pytorch +```python +python3 setup.py build_ext --inplace +``` diff --git a/ppocr/postprocess/pse_postprocess/pse/__init__.py b/ppocr/postprocess/pse_postprocess/pse/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..97b8d8aff0cf229a4e3ec1961638273bd201822a --- /dev/null +++ b/ppocr/postprocess/pse_postprocess/pse/__init__.py @@ -0,0 +1,23 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import sys +import os +import subprocess + +python_path = sys.executable + +if subprocess.call('cd ppocr/postprocess/pse_postprocess/pse;{} setup.py build_ext --inplace;cd -'.format(python_path), shell=True) != 0: + raise RuntimeError('Cannot compile pse: {}'.format(os.path.dirname(os.path.realpath(__file__)))) + +from .pse import pse \ No newline at end of file diff --git a/ppocr/postprocess/pse_postprocess/pse/pse.pyx b/ppocr/postprocess/pse_postprocess/pse/pse.pyx new file mode 100644 index 0000000000000000000000000000000000000000..b2be49e9471865c11b840207f922258e67a554b6 --- /dev/null +++ b/ppocr/postprocess/pse_postprocess/pse/pse.pyx @@ -0,0 +1,70 @@ + +import numpy as np +import cv2 +cimport numpy as np +cimport cython +cimport libcpp +cimport libcpp.pair +cimport libcpp.queue +from libcpp.pair cimport * +from libcpp.queue cimport * + +@cython.boundscheck(False) +@cython.wraparound(False) +cdef np.ndarray[np.int32_t, ndim=2] _pse(np.ndarray[np.uint8_t, ndim=3] kernels, + np.ndarray[np.int32_t, ndim=2] label, + int kernel_num, + int label_num, + float min_area=0): + cdef np.ndarray[np.int32_t, ndim=2] pred + pred = np.zeros((label.shape[0], label.shape[1]), dtype=np.int32) + + for label_idx in range(1, label_num): + if np.sum(label == label_idx) < min_area: + label[label == label_idx] = 0 + + cdef libcpp.queue.queue[libcpp.pair.pair[np.int16_t,np.int16_t]] que = \ + queue[libcpp.pair.pair[np.int16_t,np.int16_t]]() + cdef libcpp.queue.queue[libcpp.pair.pair[np.int16_t,np.int16_t]] nxt_que = \ + queue[libcpp.pair.pair[np.int16_t,np.int16_t]]() + cdef np.int16_t* dx = [-1, 1, 0, 0] + cdef np.int16_t* dy = [0, 0, -1, 1] + cdef np.int16_t tmpx, tmpy + + points = np.array(np.where(label > 0)).transpose((1, 0)) + for point_idx in range(points.shape[0]): + tmpx, tmpy = points[point_idx, 0], points[point_idx, 1] + que.push(pair[np.int16_t,np.int16_t](tmpx, tmpy)) + pred[tmpx, tmpy] = label[tmpx, tmpy] + + cdef libcpp.pair.pair[np.int16_t,np.int16_t] cur + cdef int cur_label + for kernel_idx in range(kernel_num - 1, -1, -1): + while not que.empty(): + cur = que.front() + que.pop() + cur_label = pred[cur.first, cur.second] + + is_edge = True + for j in range(4): + tmpx = cur.first + dx[j] + tmpy = cur.second + dy[j] + if tmpx < 0 or tmpx >= label.shape[0] or tmpy < 0 or tmpy >= label.shape[1]: + continue + if kernels[kernel_idx, tmpx, tmpy] == 0 or pred[tmpx, tmpy] > 0: + continue + + que.push(pair[np.int16_t,np.int16_t](tmpx, tmpy)) + pred[tmpx, tmpy] = cur_label + is_edge = False + if is_edge: + nxt_que.push(cur) + + que, nxt_que = nxt_que, que + + return pred + +def pse(kernels, min_area): + kernel_num = kernels.shape[0] + label_num, label = cv2.connectedComponents(kernels[-1], connectivity=4) + return _pse(kernels[:-1], label, kernel_num, label_num, min_area) \ No newline at end of file diff --git a/ppocr/postprocess/pse_postprocess/pse/setup.py b/ppocr/postprocess/pse_postprocess/pse/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..03746782af791938bff31c24e4a760f566c73b49 --- /dev/null +++ b/ppocr/postprocess/pse_postprocess/pse/setup.py @@ -0,0 +1,14 @@ +from distutils.core import setup, Extension +from Cython.Build import cythonize +import numpy + +setup(ext_modules=cythonize(Extension( + 'pse', + sources=['pse.pyx'], + language='c++', + include_dirs=[numpy.get_include()], + library_dirs=[], + libraries=[], + extra_compile_args=['-O3'], + extra_link_args=[] +))) diff --git a/ppocr/postprocess/pse_postprocess/pse_postprocess.py b/ppocr/postprocess/pse_postprocess/pse_postprocess.py new file mode 100755 index 0000000000000000000000000000000000000000..4b89d221d284602933ab3d4f21468fcae79ef310 --- /dev/null +++ b/ppocr/postprocess/pse_postprocess/pse_postprocess.py @@ -0,0 +1,112 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import cv2 +import paddle +from paddle.nn import functional as F + +from ppocr.postprocess.pse_postprocess.pse import pse + + +class PSEPostProcess(object): + """ + The post process for PSE. + """ + + def __init__(self, + thresh=0.5, + box_thresh=0.85, + min_area=16, + box_type='box', + scale=4, + **kwargs): + assert box_type in ['box', 'poly'], 'Only box and poly is supported' + self.thresh = thresh + self.box_thresh = box_thresh + self.min_area = min_area + self.box_type = box_type + self.scale = scale + + def __call__(self, outs_dict, shape_list): + pred = outs_dict['maps'] + if not isinstance(pred, paddle.Tensor): + pred = paddle.to_tensor(pred) + pred = F.interpolate(pred, scale_factor=4 // self.scale, mode='bilinear') + + score = F.sigmoid(pred[:, 0, :, :]) + + kernels = (pred > self.thresh).astype('float32') + text_mask = kernels[:, 0, :, :] + kernels[:, 0:, :, :] = kernels[:, 0:, :, :] * text_mask + + score = score.numpy() + kernels = kernels.numpy().astype(np.uint8) + + boxes_batch = [] + for batch_index in range(pred.shape[0]): + boxes, scores = self.boxes_from_bitmap(score[batch_index], kernels[batch_index], shape_list[batch_index]) + + boxes_batch.append({'points': boxes, 'scores': scores}) + return boxes_batch + + def boxes_from_bitmap(self, score, kernels, shape): + label = pse(kernels, self.min_area) + return self.generate_box(score, label, shape) + + def generate_box(self, score, label, shape): + src_h, src_w, ratio_h, ratio_w = shape + label_num = np.max(label) + 1 + + boxes = [] + scores = [] + for i in range(1, label_num): + ind = label == i + points = np.array(np.where(ind)).transpose((1, 0))[:, ::-1] + + if points.shape[0] < self.min_area: + label[ind] = 0 + continue + + score_i = np.mean(score[ind]) + if score_i < self.box_thresh: + label[ind] = 0 + continue + + if self.box_type == 'box': + rect = cv2.minAreaRect(points) + bbox = cv2.boxPoints(rect) + elif self.box_type == 'poly': + box_height = np.max(points[:, 1]) + 10 + box_width = np.max(points[:, 0]) + 10 + + mask = np.zeros((box_height, box_width), np.uint8) + mask[points[:, 1], points[:, 0]] = 255 + + contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + bbox = np.squeeze(contours[0], 1) + else: + raise NotImplementedError + + bbox[:, 0] = np.clip( + np.round(bbox[:, 0] / ratio_w), 0, src_w) + bbox[:, 1] = np.clip( + np.round(bbox[:, 1] / ratio_h), 0, src_h) + boxes.append(bbox) + scores.append(score_i) + return boxes, scores diff --git a/ppocr/utils/iou.py b/ppocr/utils/iou.py new file mode 100644 index 0000000000000000000000000000000000000000..20529dee2d14083f3de4ac034668d004136c56e2 --- /dev/null +++ b/ppocr/utils/iou.py @@ -0,0 +1,48 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle + +EPS = 1e-6 + +def iou_single(a, b, mask, n_class): + valid = mask == 1 + a = a.masked_select(valid) + b = b.masked_select(valid) + miou = [] + for i in range(n_class): + if a.shape == [0] and a.shape==b.shape: + inter = paddle.to_tensor(0.0) + union = paddle.to_tensor(0.0) + else: + inter = ((a == i).logical_and(b == i)).astype('float32') + union = ((a == i).logical_or(b == i)).astype('float32') + miou.append(paddle.sum(inter) / (paddle.sum(union) + EPS)) + miou = sum(miou) / len(miou) + return miou + +def iou(a, b, mask, n_class=2, reduce=True): + batch_size = a.shape[0] + + a = a.reshape([batch_size, -1]) + b = b.reshape([batch_size, -1]) + mask = mask.reshape([batch_size, -1]) + + iou = paddle.zeros((batch_size,), dtype='float32') + for i in range(batch_size): + iou[i] = iou_single(a[i], b[i], mask[i], n_class) + + if reduce: + iou = paddle.mean(iou) + return iou \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 2c7baa8516932f56f77b71b4e6dc7d45cd43072e..0b2366c5cd344260d7afab811b27e19499a89b26 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,6 +8,7 @@ numpy visualdl python-Levenshtein opencv-contrib-python==4.4.0.46 +cython lxml premailer openpyxl \ No newline at end of file diff --git a/tools/infer/predict_det.py b/tools/infer/predict_det.py index 6347ca6dc719f0d489736dbf285eedd775d3790e..b24ad2bbb504caf1f262b4e47625348ce32d6fce 100755 --- a/tools/infer/predict_det.py +++ b/tools/infer/predict_det.py @@ -89,6 +89,14 @@ class TextDetector(object): postprocess_params["sample_pts_num"] = 2 postprocess_params["expand_scale"] = 1.0 postprocess_params["shrink_ratio_of_width"] = 0.3 + elif self.det_algorithm == "PSE": + postprocess_params['name'] = 'PSEPostProcess' + postprocess_params["thresh"] = args.det_pse_thresh + postprocess_params["box_thresh"] = args.det_pse_box_thresh + postprocess_params["min_area"] = args.det_pse_min_area + postprocess_params["box_type"] = args.det_pse_box_type + postprocess_params["scale"] = args.det_pse_scale + self.det_pse_box_type = args.det_pse_box_type else: logger.info("unknown det_algorithm:{}".format(self.det_algorithm)) sys.exit(0) @@ -209,7 +217,7 @@ class TextDetector(object): preds['f_score'] = outputs[1] preds['f_tco'] = outputs[2] preds['f_tvo'] = outputs[3] - elif self.det_algorithm == 'DB': + elif self.det_algorithm in ['DB', 'PSE']: preds['maps'] = outputs[0] else: raise NotImplementedError @@ -217,7 +225,9 @@ class TextDetector(object): #self.predictor.try_shrink_memory() post_result = self.postprocess_op(preds, shape_list) dt_boxes = post_result[0]['points'] - if self.det_algorithm == "SAST" and self.det_sast_polygon: + if (self.det_algorithm == "SAST" and + self.det_sast_polygon) or (self.det_algorithm == "PSE" and + self.det_pse_box_type == 'poly'): dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_im.shape) else: dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape) diff --git a/tools/infer/utility.py b/tools/infer/utility.py index 466f824c29d5493f56e56bc3243fc907aec24d60..538f55c42b223f9741c5c7006dd7d1478ce1920b 100755 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -63,6 +63,13 @@ def init_args(): parser.add_argument("--det_sast_nms_thresh", type=float, default=0.2) parser.add_argument("--det_sast_polygon", type=str2bool, default=False) + # PSE parmas + parser.add_argument("--det_pse_thresh", type=float, default=0) + parser.add_argument("--det_pse_box_thresh", type=float, default=0.85) + parser.add_argument("--det_pse_min_area", type=float, default=16) + parser.add_argument("--det_pse_box_type", type=str, default='box') + parser.add_argument("--det_pse_scale", type=int, default=1) + # params for text recognizer parser.add_argument("--rec_algorithm", type=str, default='CRNN') parser.add_argument("--rec_model_dir", type=str) diff --git a/tools/program.py b/tools/program.py index d6d47d047b5bc65024b5db866b67ee260f8f8bb3..f484cf4a1f512107bd755a6b446935751563bfcb 100755 --- a/tools/program.py +++ b/tools/program.py @@ -402,7 +402,7 @@ def preprocess(is_train=False): alg = config['Architecture']['algorithm'] assert alg in [ 'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN', - 'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR' + 'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE' ] device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'