make_pse_gt.py 3.8 KB
Newer Older
W
WenmuZhou 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
W
WenmuZhou 已提交
14 15 16 17 18 19 20 21 22 23 24 25 26

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import cv2
import numpy as np
import pyclipper
from shapely.geometry import Polygon

__all__ = ['MakePseGt']

W
WenmuZhou 已提交
27

W
WenmuZhou 已提交
28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48
class MakePseGt(object):
    """Build PSENet-style training targets: shrunk text kernels and a mask.

    Reads ``data['image']``, ``data['polys']`` and ``data['ignore_tags']``
    and adds ``gt_kernels``, ``gt_text`` and ``mask`` (see ``__call__``).

    Args:
        kernel_num: number of kernel maps generated per image.
        size: images whose short edge is smaller than ``size`` are upscaled
            (same factor on both axes) so the short edge becomes ~``size``.
        min_shrink_ratio: parameter of the linear shrink-rate schedule used
            in ``__call__``.
        **kwargs: accepted and ignored (presumably so the transform can be
            built from a config with extra keys — TODO confirm).
    """

    def __init__(self, kernel_num=7, size=640, min_shrink_ratio=0.4, **kwargs):
        self.kernel_num = kernel_num
        self.min_shrink_ratio = min_shrink_ratio
        self.size = size

    def __call__(self, data):
        """Compute the targets for one sample and store them in ``data``.

        Args:
            data: dict with ``'image'`` (an ``(H, W)`` or ``(H, W, C)``
                array), ``'polys'`` (array of polygons, indexed per text
                instance) and ``'ignore_tags'`` (mutable sequence of
                booleans, one per polygon).

        Returns:
            The same dict, with ``'image'`` and ``'polys'`` replaced by the
            (possibly upscaled) versions, ``'ignore_tags'`` updated in place
            for polygons that could not be shrunk, and the new keys:
            ``'gt_kernels'`` (``(kernel_num, H, W)`` float32 0/1 maps),
            ``'gt_text'`` (``gt_kernels[0]``) and ``'mask'`` (``(H, W)``
            float32, 0 inside ignored polygons and 1 elsewhere).

        Raises:
            ZeroDivisionError: if ``kernel_num == 1`` (the rate step divides
                by ``kernel_num - 1``).
        """
        image = data['image']
        text_polys = data['polys']
        ignore_tags = data['ignore_tags']

        # Slice instead of unpacking three values so 2-D images also work.
        h, w = image.shape[:2]
        short_edge = min(h, w)
        if short_edge < self.size:
            # keep short_size >= self.size
            scale = self.size / short_edge
            image = cv2.resize(image, dsize=None, fx=scale, fy=scale)
            # Out-of-place on purpose: the caller's array is not mutated, and
            # integer-typed polygons no longer hit a numpy casting error.
            text_polys = text_polys * scale

        # Rates decrease linearly: i=1 -> 1 - step, ..., i=kernel_num ->
        # 1 - step * kernel_num.
        # NOTE(review): with the defaults (7, 0.4) this gives 0.9 ... 0.3, so
        # the last rate is below min_shrink_ratio and gt_kernels[0] (used as
        # gt_text) is already a shrunk map rather than the full polygon. Kept
        # unchanged because the loss consuming these targets is not visible
        # here — confirm the intended schedule before altering it.
        step = (1.0 - self.min_shrink_ratio) / (self.kernel_num - 1)
        gt_kernels = []
        for i in range(1, self.kernel_num + 1):
            # s1->sn, from big to small
            rate = 1.0 - step * i
            text_kernel, ignore_tags = self.generate_kernel(
                image.shape[0:2], rate, text_polys, ignore_tags)
            gt_kernels.append(text_kernel)

        # ignore_tags now also includes polygons that generate_kernel marked
        # because they vanished or could not be shrunk.
        training_mask = np.ones(image.shape[0:2], dtype='uint8')
        for i, poly in enumerate(text_polys):
            if ignore_tags[i]:
                cv2.fillPoly(training_mask,
                             poly.astype(np.int32)[np.newaxis, :, :], 0)

        gt_kernels = np.array(gt_kernels)
        # generate_kernel labels instance i with i + 1; collapse to 0/1.
        gt_kernels[gt_kernels > 0] = 1

        data['image'] = image
        data['polys'] = text_polys
        data['gt_kernels'] = gt_kernels
        data['gt_text'] = gt_kernels[0]
        data['mask'] = training_mask.astype('float32')
        return data

    def generate_kernel(self,
                        img_size,
                        shrink_ratio,
                        text_polys,
                        ignore_tags=None):
        """Rasterise every polygon, shrunk by ``shrink_ratio``, into one map.

        Refer to part of the code:
        https://github.com/open-mmlab/mmocr/blob/main/mmocr/datasets/pipelines/textdet_targets/base_textdet_targets.py

        Args:
            img_size: ``(h, w)`` of the output map.
            shrink_ratio: each polygon is offset inwards by
                ``area * (1 - shrink_ratio ** 2) / perimeter``.
            text_polys: iterable of polygons, each a sequence of ``(x, y)``
                points.
            ignore_tags: optional mutable sequence indexed like
                ``text_polys``. Entry ``i`` is set to ``True`` in place when
                polygon ``i`` vanishes after shrinking or its shrunk outline
                cannot be converted to points.

        Returns:
            ``(text_kernel, ignore_tags)``. ``text_kernel`` is a float32
            ``(h, w)`` map in which shrunk polygon ``i`` is filled with
            ``i + 1`` and everything else is 0. ``ignore_tags`` is the same
            object that was passed in.
        """
        h, w = img_size
        text_kernel = np.zeros((h, w), dtype=np.float32)
        for i, poly in enumerate(text_polys):
            polygon = Polygon(poly)
            distance = polygon.area * (1 - shrink_ratio * shrink_ratio) / (
                polygon.length + 1e-6)
            subject = [tuple(point) for point in poly]
            pco = pyclipper.PyclipperOffset()
            pco.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
            # Keep the raw list of paths. Shrinking can split one polygon into
            # several paths of different lengths, and np.array() on such a
            # ragged list raises ValueError on NumPy >= 1.24.
            paths = pco.Execute(-distance)

            if not paths or len(paths[0]) == 0:
                if ignore_tags is not None:
                    ignore_tags[i] = True
                continue
            try:
                # Only the first resulting path is rasterised.
                shrunk = np.array(paths[0]).reshape(-1, 2)
            except (TypeError, ValueError):
                if ignore_tags is not None:
                    ignore_tags[i] = True
                continue
            cv2.fillPoly(text_kernel, [shrunk.astype(np.int32)], i + 1)
        return text_kernel, ignore_tags