""" # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ from __future__ import absolute_import from __future__ import division from __future__ import print_function from __future__ import unicode_literals import sys import six import cv2 import numpy as np class GenTableMask(object): """ gen table mask """ def __init__(self, shrink_h_max, shrink_w_max, mask_type=0, **kwargs): self.shrink_h_max = 5 self.shrink_w_max = 5 self.mask_type = mask_type def projection(self, erosion, h, w, spilt_threshold=0): # 水平投影 projection_map = np.ones_like(erosion) project_val_array = [0 for _ in range(0, h)] for j in range(0, h): for i in range(0, w): if erosion[j, i] == 255: project_val_array[j] += 1 # 根据数组,获取切割点 start_idx = 0 # 记录进入字符区的索引 end_idx = 0 # 记录进入空白区域的索引 in_text = False # 是否遍历到了字符区内 box_list = [] for i in range(len(project_val_array)): if in_text == False and project_val_array[ i] > spilt_threshold: # 进入字符区了 in_text = True start_idx = i elif project_val_array[ i] <= spilt_threshold and in_text == True: # 进入空白区了 end_idx = i in_text = False if end_idx - start_idx <= 2: continue box_list.append((start_idx, end_idx + 1)) if in_text: box_list.append((start_idx, h - 1)) # 绘制投影直方图 for j in range(0, h): for i in range(0, project_val_array[j]): projection_map[j, i] = 0 return box_list, projection_map def projection_cx(self, box_img): box_gray_img = cv2.cvtColor(box_img, cv2.COLOR_BGR2GRAY) h, w = box_gray_img.shape # 灰度图片进行二值化处理 ret, thresh1 = cv2.threshold(box_gray_img, 200, 255, cv2.THRESH_BINARY_INV) # 纵向腐蚀 if h < w: kernel = np.ones((2, 1), np.uint8) erode = cv2.erode(thresh1, kernel, iterations=1) else: erode = thresh1 # 水平膨胀 kernel = np.ones((1, 5), np.uint8) erosion = cv2.dilate(erode, kernel, iterations=1) # 水平投影 projection_map = np.ones_like(erosion) project_val_array = [0 for _ in range(0, h)] for j in range(0, h): for i in range(0, w): if erosion[j, i] == 255: project_val_array[j] += 1 # 根据数组,获取切割点 start_idx = 0 # 记录进入字符区的索引 end_idx = 0 # 记录进入空白区域的索引 in_text = False # 是否遍历到了字符区内 box_list = [] spilt_threshold = 0 for i in range(len(project_val_array)): if in_text == False and project_val_array[ i] > spilt_threshold: # 进入字符区了 in_text = True start_idx = i elif project_val_array[ i] <= spilt_threshold and in_text == True: # 进入空白区了 end_idx = i in_text = False if end_idx - start_idx <= 2: continue box_list.append((start_idx, end_idx + 1)) if in_text: box_list.append((start_idx, h - 1)) # 绘制投影直方图 for j in range(0, h): for i in range(0, project_val_array[j]): projection_map[j, i] = 0 split_bbox_list = [] if len(box_list) > 1: for i, (h_start, h_end) in enumerate(box_list): if i == 0: h_start = 0 if i == len(box_list): h_end = h word_img = erosion[h_start:h_end + 1, :] word_h, word_w = word_img.shape w_split_list, w_projection_map = self.projection(word_img.T, word_w, word_h) w_start, w_end = w_split_list[0][0], w_split_list[-1][1] if h_start > 0: h_start -= 1 h_end += 1 word_img = box_img[h_start:h_end + 1:, w_start:w_end + 1, :] 
split_bbox_list.append([w_start, h_start, w_end, h_end]) else: split_bbox_list.append([0, 0, w, h]) return split_bbox_list def shrink_bbox(self, bbox): left, top, right, bottom = bbox sh_h = min(max(int((bottom - top) * 0.1), 1), self.shrink_h_max) sh_w = min(max(int((right - left) * 0.1), 1), self.shrink_w_max) left_new = left + sh_w right_new = right - sh_w top_new = top + sh_h bottom_new = bottom - sh_h if left_new >= right_new: left_new = left right_new = right if top_new >= bottom_new: top_new = top bottom_new = bottom return [left_new, top_new, right_new, bottom_new] def __call__(self, data): img = data['image'] cells = data['cells'] height, width = img.shape[0:2] if self.mask_type == 1: mask_img = np.zeros((height, width), dtype=np.float32) else: mask_img = np.zeros((height, width, 3), dtype=np.float32) cell_num = len(cells) for cno in range(cell_num): if "bbox" in cells[cno]: bbox = cells[cno]['bbox'] left, top, right, bottom = bbox box_img = img[top:bottom, left:right, :].copy() split_bbox_list = self.projection_cx(box_img) for sno in range(len(split_bbox_list)): split_bbox_list[sno][0] += left split_bbox_list[sno][1] += top split_bbox_list[sno][2] += left split_bbox_list[sno][3] += top for sno in range(len(split_bbox_list)): left, top, right, bottom = split_bbox_list[sno] left, top, right, bottom = self.shrink_bbox( [left, top, right, bottom]) if self.mask_type == 1: mask_img[top:bottom, left:right] = 1.0 data['mask_img'] = mask_img else: mask_img[top:bottom, left:right, :] = (255, 255, 255) data['image'] = mask_img return data class ResizeTableImage(object): def __init__(self, max_len, resize_bboxes=False, infer_mode=False, **kwargs): super(ResizeTableImage, self).__init__() self.max_len = max_len self.resize_bboxes = resize_bboxes self.infer_mode = infer_mode def __call__(self, data): img = data['image'] height, width = img.shape[0:2] ratio = self.max_len / (max(height, width) * 1.0) resize_h = int(height * ratio) resize_w = int(width * ratio) resize_img = cv2.resize(img, (resize_w, resize_h)) if self.resize_bboxes and not self.infer_mode: data['bboxes'] = data['bboxes'] * ratio data['image'] = resize_img data['src_img'] = img data['shape'] = np.array([resize_h, resize_w, ratio, ratio]) data['max_len'] = self.max_len return data class PaddingTableImage(object): def __init__(self, size, **kwargs): super(PaddingTableImage, self).__init__() self.size = size def __call__(self, data): img = data['image'] pad_h, pad_w = self.size padding_img = np.zeros((pad_h, pad_w, 3), dtype=np.float32) height, width = img.shape[0:2] padding_img[0:height, 0:width, :] = img.copy() data['image'] = padding_img shape = data['shape'].tolist() shape.extend([pad_h, pad_w]) data['shape'] = np.array(shape) return data
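

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module):
# chains the three transforms on a synthetic table sample. The dict keys
# ('image', 'cells', 'bbox', 'bboxes') follow the __call__ signatures above;
# the image size, the single cell box, and the 488x488 target size are
# assumed example values, not requirements of this file.
if __name__ == "__main__":
    sample = {
        'image': np.full((200, 300, 3), 255, dtype=np.uint8),    # blank page
        'cells': [{'bbox': [10, 10, 150, 60]}],                   # one cell box
        'bboxes': np.array([[10, 10, 150, 60]], dtype=np.float32),
    }
    ops = [
        GenTableMask(shrink_h_max=5, shrink_w_max=5, mask_type=1),
        ResizeTableImage(max_len=488, resize_bboxes=True),
        PaddingTableImage(size=[488, 488]),
    ]
    for op in ops:
        sample = op(sample)
    # padded image shape and the recorded [resize_h, resize_w, ratio, ratio, pad_h, pad_w]
    print(sample['image'].shape, sample['shape'])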