""" # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ from __future__ import absolute_import from __future__ import division from __future__ import print_function from __future__ import unicode_literals import sys import six import cv2 import numpy as np class GenTableMask(object): """ gen table mask """ def __init__(self, shrink_h_max, shrink_w_max, mask_type=0, **kwargs): self.shrink_h_max = 5 self.shrink_w_max = 5 self.mask_type = mask_type def projection(self, erosion, h, w, spilt_threshold=0): # 水平投影 projection_map = np.ones_like(erosion) project_val_array = [0 for _ in range(0, h)] for j in range(0, h): for i in range(0, w): if erosion[j, i] == 255: project_val_array[j] += 1 # 根据数组,获取切割点 start_idx = 0 # 记录进入字符区的索引 end_idx = 0 # 记录进入空白区域的索引 in_text = False # 是否遍历到了字符区内 box_list = [] for i in range(len(project_val_array)): if in_text == False and project_val_array[i] > spilt_threshold: # 进入字符区了 in_text = True start_idx = i elif project_val_array[i] <= spilt_threshold and in_text == True: # 进入空白区了 end_idx = i in_text = False if end_idx - start_idx <= 2: continue box_list.append((start_idx, end_idx + 1)) if in_text: box_list.append((start_idx, h - 1)) # 绘制投影直方图 for j in range(0, h): for i in range(0, project_val_array[j]): projection_map[j, i] = 0 return box_list, projection_map def projection_cx(self, box_img): box_gray_img = cv2.cvtColor(box_img, cv2.COLOR_BGR2GRAY) h, w = box_gray_img.shape # 灰度图片进行二值化处理 ret, thresh1 = cv2.threshold(box_gray_img, 200, 255, cv2.THRESH_BINARY_INV) # 纵向腐蚀 if h < w: kernel = np.ones((2, 1), np.uint8) erode = cv2.erode(thresh1, kernel, iterations=1) else: erode = thresh1 # 水平膨胀 kernel = np.ones((1, 5), np.uint8) erosion = cv2.dilate(erode, kernel, iterations=1) # 水平投影 projection_map = np.ones_like(erosion) project_val_array = [0 for _ in range(0, h)] for j in range(0, h): for i in range(0, w): if erosion[j, i] == 255: project_val_array[j] += 1 # 根据数组,获取切割点 start_idx = 0 # 记录进入字符区的索引 end_idx = 0 # 记录进入空白区域的索引 in_text = False # 是否遍历到了字符区内 box_list = [] spilt_threshold = 0 for i in range(len(project_val_array)): if in_text == False and project_val_array[i] > spilt_threshold: # 进入字符区了 in_text = True start_idx = i elif project_val_array[i] <= spilt_threshold and in_text == True: # 进入空白区了 end_idx = i in_text = False if end_idx - start_idx <= 2: continue box_list.append((start_idx, end_idx + 1)) if in_text: box_list.append((start_idx, h - 1)) # 绘制投影直方图 for j in range(0, h): for i in range(0, project_val_array[j]): projection_map[j, i] = 0 split_bbox_list = [] if len(box_list) > 1: for i, (h_start, h_end) in enumerate(box_list): if i == 0: h_start = 0 if i == len(box_list): h_end = h word_img = erosion[h_start:h_end + 1, :] word_h, word_w = word_img.shape w_split_list, w_projection_map = self.projection(word_img.T, word_w, word_h) w_start, w_end = w_split_list[0][0], w_split_list[-1][1] if h_start > 0: h_start -= 1 h_end += 1 word_img = box_img[h_start:h_end + 1:, w_start:w_end + 1, :] split_bbox_list.append([w_start, h_start, w_end, h_end]) else: split_bbox_list.append([0, 0, w, h]) return split_bbox_list def shrink_bbox(self, bbox): left, top, right, bottom = bbox sh_h = min(max(int((bottom - top) * 0.1), 1), self.shrink_h_max) sh_w = min(max(int((right - left) * 0.1), 1), self.shrink_w_max) left_new = left + sh_w right_new = right - sh_w top_new = top + sh_h bottom_new = bottom - sh_h if left_new >= right_new: left_new = left right_new = right if top_new >= bottom_new: top_new = top bottom_new = bottom return [left_new, top_new, right_new, bottom_new] def __call__(self, data): img = data['image'] cells = data['cells'] height, width = img.shape[0:2] if self.mask_type == 1: mask_img = np.zeros((height, width), dtype=np.float32) else: mask_img = np.zeros((height, width, 3), dtype=np.float32) cell_num = len(cells) for cno in range(cell_num): if "bbox" in cells[cno]: bbox = cells[cno]['bbox'] left, top, right, bottom = bbox box_img = img[top:bottom, left:right, :].copy() split_bbox_list = self.projection_cx(box_img) for sno in range(len(split_bbox_list)): split_bbox_list[sno][0] += left split_bbox_list[sno][1] += top split_bbox_list[sno][2] += left split_bbox_list[sno][3] += top for sno in range(len(split_bbox_list)): left, top, right, bottom = split_bbox_list[sno] left, top, right, bottom = self.shrink_bbox([left, top, right, bottom]) if self.mask_type == 1: mask_img[top:bottom, left:right] = 1.0 data['mask_img'] = mask_img else: mask_img[top:bottom, left:right, :] = (255, 255, 255) data['image'] = mask_img return data class ResizeTableImage(object): def __init__(self, max_len, **kwargs): super(ResizeTableImage, self).__init__() self.max_len = max_len def get_img_bbox(self, cells): bbox_list = [] if len(cells) == 0: return bbox_list cell_num = len(cells) for cno in range(cell_num): if "bbox" in cells[cno]: bbox = cells[cno]['bbox'] bbox_list.append(bbox) return bbox_list def resize_img_table(self, img, bbox_list, max_len): height, width = img.shape[0:2] ratio = max_len / (max(height, width) * 1.0) resize_h = int(height * ratio) resize_w = int(width * ratio) img_new = cv2.resize(img, (resize_w, resize_h)) bbox_list_new = [] for bno in range(len(bbox_list)): left, top, right, bottom = bbox_list[bno].copy() left = int(left * ratio) top = int(top * ratio) right = int(right * ratio) bottom = int(bottom * ratio) bbox_list_new.append([left, top, right, bottom]) return img_new, bbox_list_new def __call__(self, data): img = data['image'] if 'cells' not in data: cells = [] else: cells = data['cells'] bbox_list = self.get_img_bbox(cells) img_new, bbox_list_new = self.resize_img_table(img, bbox_list, self.max_len) data['image'] = img_new cell_num = len(cells) bno = 0 for cno in range(cell_num): if "bbox" in data['cells'][cno]: data['cells'][cno]['bbox'] = bbox_list_new[bno] bno += 1 data['max_len'] = self.max_len return data class PaddingTableImage(object): def __init__(self, **kwargs): super(PaddingTableImage, self).__init__() def __call__(self, data): img = data['image'] max_len = data['max_len'] padding_img = np.zeros((max_len, max_len, 3), dtype=np.float32) height, width = img.shape[0:2] padding_img[0:height, 0:width, :] = img.copy() data['image'] = padding_img return data