gen_table_mask.py 9.1 KB
Newer Older
W
WenmuZhou 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244
"""
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import sys
import six
import cv2
import numpy as np


class GenTableMask(object):
    """Build a text-region mask for a table image from its cell boxes.

    For every cell that carries a ``bbox``, the cell crop is split into text
    lines with projection profiles, each line box is shrunk slightly and the
    resulting rectangle is painted into a mask.

    Args:
        shrink_h_max: upper bound for the vertical margin that
            ``shrink_bbox`` removes from the top and from the bottom of a box.
        shrink_w_max: upper bound for the horizontal margin that
            ``shrink_bbox`` removes from the left and from the right of a box.
        mask_type: ``1`` stores a single-channel mask (painted with 1.0) in
            ``data['mask_img']``; any other value stores a 3-channel mask
            (painted with 255) in ``data['image']``.
        **kwargs: accepted and ignored.
    """

    def __init__(self, shrink_h_max, shrink_w_max, mask_type=0, **kwargs):
        # Honour the configured limits; they used to be overwritten with a
        # hard-coded 5 regardless of what the caller passed.
        self.shrink_h_max = shrink_h_max
        self.shrink_w_max = shrink_w_max
        self.mask_type = mask_type

    def projection(self, erosion, h, w, spilt_threshold=0):
        """Split a binary image into runs of rows that contain foreground.

        Args:
            erosion: 2-D array; elements equal to 255 count as foreground.
            h: number of rows to scan.
            w: number of columns to scan.
            spilt_threshold: a row belongs to a run when its foreground
                count is strictly greater than this value.

        Returns:
            A tuple ``(box_list, projection_map)``. ``box_list`` holds one
            ``(start, end)`` pair of row indices per run. ``projection_map``
            is an array shaped like ``erosion`` filled with ones, in which
            the first ``count`` entries of every row are set to zero, where
            ``count`` is that row's foreground count (a drawn histogram).
        """
        # Per-row foreground count (projection profile along the rows).
        row_counts = (erosion[:h, :w] == 255).sum(axis=1)

        # Walk the profile and record where runs start and stop.
        box_list = []
        start_idx = 0  # row at which the current run began
        in_text = False  # whether we are currently inside a run
        for idx, count in enumerate(row_counts):
            if not in_text and count > spilt_threshold:
                in_text = True
                start_idx = idx
            elif in_text and count <= spilt_threshold:
                in_text = False
                # Runs that are at most two rows long are discarded.
                if idx - start_idx > 2:
                    box_list.append((start_idx, idx + 1))

        # A run still open at the last row is closed at h - 1.
        if in_text:
            box_list.append((start_idx, h - 1))

        # Draw the projection histogram.
        projection_map = np.ones_like(erosion)
        for row, count in enumerate(row_counts):
            projection_map[row, :count] = 0
        return box_list, projection_map

    def projection_cx(self, box_img):
        """Split one cell crop into per-line text boxes.

        Args:
            box_img: cell crop passed to ``cv2.cvtColor`` with
                ``cv2.COLOR_BGR2GRAY`` (assumed to be a BGR image).

        Returns:
            A list of ``[left, top, right, bottom]`` boxes in crop
            coordinates. When fewer than two text lines are found, a single
            box covering the whole crop, ``[0, 0, w, h]``, is returned.
        """
        box_gray_img = cv2.cvtColor(box_img, cv2.COLOR_BGR2GRAY)
        h, w = box_gray_img.shape
        # Binarise the grayscale crop (inverse threshold at 200).
        _, thresh1 = cv2.threshold(box_gray_img, 200, 255,
                                   cv2.THRESH_BINARY_INV)
        # Vertical erosion, applied only when the crop is wider than tall.
        if h < w:
            erode = cv2.erode(thresh1, np.ones((2, 1), np.uint8), iterations=1)
        else:
            erode = thresh1
        # Horizontal dilation.
        erosion = cv2.dilate(erode, np.ones((1, 5), np.uint8), iterations=1)

        # Row-wise projection: one (start, end) pair per text line.
        box_list, _ = self.projection(erosion, h, w)
        if len(box_list) <= 1:
            return [[0, 0, w, h]]

        split_bbox_list = []
        last = len(box_list) - 1
        for i, (h_start, h_end) in enumerate(box_list):
            # The first line is extended to the top of the crop and the last
            # line to its bottom. (The bottom extension used to compare the
            # index with len(box_list) and therefore never triggered.)
            if i == 0:
                h_start = 0
            if i == last:
                h_end = h
            word_img = erosion[h_start:h_end + 1, :]
            word_h, word_w = word_img.shape
            # Column-wise projection of this line, done on the transpose.
            w_split_list, _ = self.projection(word_img.T, word_w, word_h)
            if w_split_list:
                w_start, w_end = w_split_list[0][0], w_split_list[-1][1]
            else:
                # No usable column run: fall back to the full crop width
                # instead of raising IndexError.
                w_start, w_end = 0, w
            # Pad the line by one row on each side without leaving the crop.
            if h_start > 0:
                h_start -= 1
            h_end = min(h_end + 1, h)
            split_bbox_list.append([w_start, h_start, w_end, h_end])
        return split_bbox_list

    def shrink_bbox(self, bbox):
        """Shrink a box by about 10% per side, within the configured limits.

        Args:
            bbox: ``[left, top, right, bottom]``.

        Returns:
            ``[left, top, right, bottom]`` after shrinking. The margin per
            side is 10% of the box extent, at least 1 and at most
            ``shrink_w_max`` / ``shrink_h_max``. An axis that would collapse
            (start >= end) keeps its original coordinates.
        """
        left, top, right, bottom = bbox
        sh_h = min(max(int((bottom - top) * 0.1), 1), self.shrink_h_max)
        sh_w = min(max(int((right - left) * 0.1), 1), self.shrink_w_max)
        left_new, right_new = left + sh_w, right - sh_w
        top_new, bottom_new = top + sh_h, bottom - sh_h
        if left_new >= right_new:
            left_new, right_new = left, right
        if top_new >= bottom_new:
            top_new, bottom_new = top, bottom
        return [left_new, top_new, right_new, bottom_new]

    def __call__(self, data):
        """Generate the mask and store it in ``data``.

        Args:
            data: mapping with ``'image'`` (array indexed with three axes)
                and ``'cells'`` (iterable of dicts; entries without a
                ``'bbox'`` key are skipped). Each ``bbox`` is
                ``[left, top, right, bottom]`` and is used directly as slice
                bounds, so integer coordinates are presumably expected.

        Returns:
            The same ``data`` object. With ``mask_type == 1`` the mask is
            stored under ``'mask_img'``; otherwise it replaces ``'image'``.
            The mask is stored even when no cell has a ``bbox``, so the
            output key is always present.
        """
        img = data['image']
        cells = data['cells']
        height, width = img.shape[0:2]
        single_channel = self.mask_type == 1
        if single_channel:
            mask_img = np.zeros((height, width), dtype=np.float32)
        else:
            mask_img = np.zeros((height, width, 3), dtype=np.float32)

        for cell in cells:
            if "bbox" not in cell:
                continue
            left, top, right, bottom = cell['bbox']
            box_img = img[top:bottom, left:right, :].copy()
            for x0, y0, x1, y1 in self.projection_cx(box_img):
                # Move the box from crop to image coordinates, then shrink.
                x0, y0, x1, y1 = self.shrink_bbox(
                    [x0 + left, y0 + top, x1 + left, y1 + top])
                if single_channel:
                    mask_img[y0:y1, x0:x1] = 1.0
                else:
                    mask_img[y0:y1, x0:x1, :] = (255, 255, 255)

        # Store the result once, after all cells have been painted.
        if single_channel:
            data['mask_img'] = mask_img
        else:
            data['image'] = mask_img
        return data

class ResizeTableImage(object):
    """Resize a table image so that its longer side equals ``max_len``.

    The bounding boxes found in ``data['cells']`` are scaled by the same
    ratio, and ``max_len`` is recorded in ``data['max_len']``.

    Args:
        max_len: target length of the longer image side.
        **kwargs: accepted and ignored.
    """

    def __init__(self, max_len, **kwargs):
        super(ResizeTableImage, self).__init__()
        self.max_len = max_len

    def get_img_bbox(self, cells):
        """Return the ``bbox`` of every cell that has one, in cell order."""
        return [cell['bbox'] for cell in cells if "bbox" in cell]

    def resize_img_table(self, img, bbox_list, max_len):
        """Resize ``img`` and scale ``bbox_list`` by the same ratio.

        Args:
            img: image array; only the first two entries of ``shape`` are
                read here.
            bbox_list: sequence of ``(left, top, right, bottom)`` boxes.
                Any 4-item sequence works (list, tuple, array).
            max_len: target length of the longer image side.

        Returns:
            ``(img_new, bbox_list_new)`` where ``bbox_list_new`` is a list of
            ``[left, top, right, bottom]`` lists with coordinates truncated
            to int after scaling.
        """
        height, width = img.shape[0:2]
        ratio = max_len / (max(height, width) * 1.0)
        # Keep both target dimensions at least 1: truncation can give 0 for
        # very elongated images, which cv2.resize does not accept.
        resize_h = max(int(height * ratio), 1)
        resize_w = max(int(width * ratio), 1)
        img_new = cv2.resize(img, (resize_w, resize_h))
        bbox_list_new = []
        for left, top, right, bottom in bbox_list:
            bbox_list_new.append([
                int(left * ratio),
                int(top * ratio),
                int(right * ratio),
                int(bottom * ratio),
            ])
        return img_new, bbox_list_new

    def __call__(self, data):
        """Resize ``data['image']`` and rescale the cell boxes in place.

        Args:
            data: mapping with ``'image'`` and, optionally, ``'cells'``
                (list of dicts, some of which carry a ``'bbox'``).

        Returns:
            The same ``data`` object with ``'image'`` replaced, each cell
            ``'bbox'`` replaced by its scaled version, and ``'max_len'`` set.
        """
        img = data['image']
        cells = data['cells'] if 'cells' in data else []
        bbox_list = self.get_img_bbox(cells)
        img_new, bbox_list_new = self.resize_img_table(img, bbox_list,
                                                       self.max_len)
        data['image'] = img_new
        # Scaled boxes come back in the same order as the cells with a bbox.
        scaled = iter(bbox_list_new)
        for cell in cells:
            if "bbox" in cell:
                cell['bbox'] = next(scaled)
        data['max_len'] = self.max_len
        return data

class PaddingTableImage(object):
    """Place the image in the top-left corner of a square zero canvas.

    The canvas is ``max_len`` x ``max_len`` x 3, dtype float32, where
    ``max_len`` is read from ``data['max_len']``.
    """

    def __init__(self, **kwargs):
        super(PaddingTableImage, self).__init__()

    def __call__(self, data):
        """Replace ``data['image']`` with its zero-padded square version.

        Args:
            data: mapping with ``'image'`` (array) and ``'max_len'``.

        Returns:
            The same ``data`` object with ``'image'`` replaced by the canvas.
        """
        src = data['image']
        side = data['max_len']
        canvas = np.zeros((side, side, 3), dtype=np.float32)
        rows, cols = src.shape[:2]
        # Slice assignment copies the pixel values into the canvas.
        canvas[:rows, :cols, :] = src
        data['image'] = canvas
        return data