diff --git a/configs/table/table_master.yml b/configs/table/table_master.yml new file mode 100755 index 0000000000000000000000000000000000000000..9dfc0e274686398836de892dbcec318432e1155c --- /dev/null +++ b/configs/table/table_master.yml @@ -0,0 +1,138 @@ +Global: + use_gpu: true + epoch_num: 17 + log_smooth_window: 20 + print_batch_step: 5 + save_model_dir: ./output/table_master/ + save_epoch_step: 17 + # evaluation is run every 400 iterations after the 0th iteration + eval_batch_step: [0, 400] + cal_metric_during_train: True + pretrained_model: + checkpoints: + save_inference_dir: + use_visualdl: False + infer_img: ppstructure/docs/table/table.jpg + save_res_path: output/table_master + # for data or label process + character_dict_path: ppocr/utils/dict/table_master_structure_dict.txt + infer_mode: False + max_text_length: 500 + process_total_num: 0 + process_cut_num: 0 + + +Optimizer: + name: Adam + beta1: 0.9 + beta2: 0.999 + lr: + name: MultiStepDecay + learning_rate: 0.001 + milestones: [12, 15] + gamma: 0.1 + warmup_epoch: 0.02 + regularizer: + name: 'L2' + factor: 0.00000 + +Architecture: + model_type: table + algorithm: TableMaster + Backbone: + name: TableResNetExtra + gcb_config: + ratio: 0.0625 + headers: 1 + att_scale: False + fusion_type: channel_add + layers: [False, True, True, True] + layers: [1,2,5,3] + Head: + name: TableMasterHead + hidden_size: 512 + headers: 8 + dropout: 0 + d_ff: 2024 + max_text_length: 500 + +Loss: + name: TableMasterLoss + ignore_index: 42 # set to len of dict + 3 + +PostProcess: + name: TableMasterLabelDecode + box_shape: pad + +Metric: + name: TableMetric + main_indicator: acc + compute_bbox_metric: true # cost many time, set False for training + +Train: + dataset: + name: PubTabDataSet + data_dir: /home/zhoujun20/table/PubTabNe/pubtabnet/train/ + label_file_list: [/home/zhoujun20/table/PubTabNe/pubtabnet/PubTabNet_2.0.0_train.jsonl] + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - TableMasterLabelEncode: + learn_empty_box: False + merge_no_span_structure: True + replace_empty_cell_token: True + - ResizeTableImage: + max_len: 480 + resize_bboxes: True + - PaddingTableImage: + size: [480, 480] + - TableBoxEncode: + use_xywh: true + - NormalizeImage: + scale: 1./255. + mean: [0.5, 0.5, 0.5] + std: [0.5, 0.5, 0.5] + order: 'hwc' + - ToCHWImage: + - KeepKeys: + keep_keys: ['image', 'structure', 'bboxes', 'bbox_masks','shape'] + loader: + shuffle: True + batch_size_per_card: 8 + drop_last: True + num_workers: 1 + +Eval: + dataset: + name: PubTabDataSet + data_dir: /home/zhoujun20/table/PubTabNe/pubtabnet/val/ + label_file_list: [/home/zhoujun20/table/PubTabNe/pubtabnet/val_500.jsonl] + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - TableMasterLabelEncode: + learn_empty_box: False + merge_no_span_structure: True + replace_empty_cell_token: True + - ResizeTableImage: + max_len: 480 + resize_bboxes: True + - PaddingTableImage: + size: [ 480, 480 ] + - TableBoxEncode: + use_xywh: true + - NormalizeImage: + scale: 1./255. + mean: [ 0.5, 0.5, 0.5 ] + std: [ 0.5, 0.5, 0.5 ] + order: 'hwc' + - ToCHWImage: + - KeepKeys: + keep_keys: [ 'image', 'structure', 'bboxes', 'bbox_masks','shape' ] + loader: + shuffle: False + drop_last: False + batch_size_per_card: 2 + num_workers: 8 diff --git a/configs/table/table_mv3.yml b/configs/table/table_mv3.yml index 1a91ea95afb4ff91d3fd68fe0df6afaac9304661..9159addc3151626f6b26a2a6bcd3eb1ef6937232 100755 --- a/configs/table/table_mv3.yml +++ b/configs/table/table_mv3.yml @@ -4,21 +4,20 @@ Global: log_smooth_window: 20 print_batch_step: 5 save_model_dir: ./output/table_mv3/ - save_epoch_step: 3 + save_epoch_step: 400 # evaluation is run every 400 iterations after the 0th iteration eval_batch_step: [0, 400] cal_metric_during_train: True pretrained_model: - checkpoints: + checkpoints: save_inference_dir: use_visualdl: False - infer_img: doc/table/table.jpg + infer_img: ppstructure/docs/table/table.jpg + save_res_path: output/table_mv3 # for data or label process character_dict_path: ppocr/utils/dict/table_structure_dict.txt character_type: en - max_text_length: 100 - max_elem_length: 800 - max_cell_num: 500 + max_text_length: 500 infer_mode: False process_total_num: 0 process_cut_num: 0 @@ -44,11 +43,8 @@ Architecture: Head: name: TableAttentionHead hidden_size: 256 - l2_decay: 0.00001 loc_type: 2 - max_text_length: 100 - max_elem_length: 800 - max_cell_num: 500 + max_text_length: 500 Loss: name: TableAttentionLoss @@ -61,6 +57,7 @@ PostProcess: Metric: name: TableMetric main_indicator: acc + compute_bbox_metric: False # cost many time, set False for training Train: dataset: @@ -71,18 +68,23 @@ Train: - DecodeImage: # load image img_mode: BGR channel_first: False + - TableLabelEncode: + learn_empty_box: False + merge_no_span_structure: False + replace_empty_cell_token: False + - TableBoxEncode: - ResizeTableImage: max_len: 488 - - TableLabelEncode: - NormalizeImage: scale: 1./255. mean: [0.485, 0.456, 0.406] std: [0.229, 0.224, 0.225] order: 'hwc' - PaddingTableImage: + size: [488, 488] - ToCHWImage: - KeepKeys: - keep_keys: ['image', 'structure', 'bbox_list', 'sp_tokens', 'bbox_list_mask'] + keep_keys: [ 'image', 'structure', 'bboxes', 'bbox_masks', 'shape' ] loader: shuffle: True batch_size_per_card: 32 @@ -92,24 +94,29 @@ Train: Eval: dataset: name: PubTabDataSet - data_dir: train_data/table/pubtabnet/val/ - label_file_path: train_data/table/pubtabnet/PubTabNet_2.0.0_val.jsonl + data_dir: /home/zhoujun20/table/PubTabNe/pubtabnet/val/ + label_file_list: [/home/zhoujun20/table/PubTabNe/pubtabnet/val_500.jsonl] transforms: - DecodeImage: # load image img_mode: BGR channel_first: False + - TableLabelEncode: + learn_empty_box: False + merge_no_span_structure: False + replace_empty_cell_token: False + - TableBoxEncode: - ResizeTableImage: max_len: 488 - - TableLabelEncode: - NormalizeImage: scale: 1./255. mean: [0.485, 0.456, 0.406] std: [0.229, 0.224, 0.225] order: 'hwc' - PaddingTableImage: + size: [488, 488] - ToCHWImage: - KeepKeys: - keep_keys: ['image', 'structure', 'bbox_list', 'sp_tokens', 'bbox_list_mask'] + keep_keys: [ 'image', 'structure', 'bboxes', 'bbox_masks', 'shape' ] loader: shuffle: False drop_last: False diff --git a/ppocr/data/imaug/gen_table_mask.py b/ppocr/data/imaug/gen_table_mask.py index 08e35d5d1df7f9663b4e008451308d0ee409cf5a..8d139190ab4b22c553036ddc8e31cfbc7ec3423d 100644 --- a/ppocr/data/imaug/gen_table_mask.py +++ b/ppocr/data/imaug/gen_table_mask.py @@ -32,7 +32,7 @@ class GenTableMask(object): self.shrink_h_max = 5 self.shrink_w_max = 5 self.mask_type = mask_type - + def projection(self, erosion, h, w, spilt_threshold=0): # 水平投影 projection_map = np.ones_like(erosion) @@ -48,10 +48,12 @@ class GenTableMask(object): in_text = False # 是否遍历到了字符区内 box_list = [] for i in range(len(project_val_array)): - if in_text == False and project_val_array[i] > spilt_threshold: # 进入字符区了 + if in_text == False and project_val_array[ + i] > spilt_threshold: # 进入字符区了 in_text = True start_idx = i - elif project_val_array[i] <= spilt_threshold and in_text == True: # 进入空白区了 + elif project_val_array[ + i] <= spilt_threshold and in_text == True: # 进入空白区了 end_idx = i in_text = False if end_idx - start_idx <= 2: @@ -70,7 +72,8 @@ class GenTableMask(object): box_gray_img = cv2.cvtColor(box_img, cv2.COLOR_BGR2GRAY) h, w = box_gray_img.shape # 灰度图片进行二值化处理 - ret, thresh1 = cv2.threshold(box_gray_img, 200, 255, cv2.THRESH_BINARY_INV) + ret, thresh1 = cv2.threshold(box_gray_img, 200, 255, + cv2.THRESH_BINARY_INV) # 纵向腐蚀 if h < w: kernel = np.ones((2, 1), np.uint8) @@ -95,10 +98,12 @@ class GenTableMask(object): box_list = [] spilt_threshold = 0 for i in range(len(project_val_array)): - if in_text == False and project_val_array[i] > spilt_threshold: # 进入字符区了 + if in_text == False and project_val_array[ + i] > spilt_threshold: # 进入字符区了 in_text = True start_idx = i - elif project_val_array[i] <= spilt_threshold and in_text == True: # 进入空白区了 + elif project_val_array[ + i] <= spilt_threshold and in_text == True: # 进入空白区了 end_idx = i in_text = False if end_idx - start_idx <= 2: @@ -120,7 +125,8 @@ class GenTableMask(object): h_end = h word_img = erosion[h_start:h_end + 1, :] word_h, word_w = word_img.shape - w_split_list, w_projection_map = self.projection(word_img.T, word_w, word_h) + w_split_list, w_projection_map = self.projection(word_img.T, + word_w, word_h) w_start, w_end = w_split_list[0][0], w_split_list[-1][1] if h_start > 0: h_start -= 1 @@ -170,75 +176,54 @@ class GenTableMask(object): for sno in range(len(split_bbox_list)): left, top, right, bottom = split_bbox_list[sno] - left, top, right, bottom = self.shrink_bbox([left, top, right, bottom]) + left, top, right, bottom = self.shrink_bbox( + [left, top, right, bottom]) if self.mask_type == 1: mask_img[top:bottom, left:right] = 1.0 data['mask_img'] = mask_img else: - mask_img[top:bottom, left:right, :] = (255, 255, 255) + mask_img[top:bottom, left:right, :] = (255, 255, 255) data['image'] = mask_img return data + class ResizeTableImage(object): - def __init__(self, max_len, **kwargs): + def __init__(self, max_len, resize_bboxes=False, infer_mode=False, + **kwargs): super(ResizeTableImage, self).__init__() self.max_len = max_len + self.resize_bboxes = resize_bboxes + self.infer_mode = infer_mode - def get_img_bbox(self, cells): - bbox_list = [] - if len(cells) == 0: - return bbox_list - cell_num = len(cells) - for cno in range(cell_num): - if "bbox" in cells[cno]: - bbox = cells[cno]['bbox'] - bbox_list.append(bbox) - return bbox_list - - def resize_img_table(self, img, bbox_list, max_len): + def __call__(self, data): + img = data['image'] height, width = img.shape[0:2] - ratio = max_len / (max(height, width) * 1.0) + ratio = self.max_len / (max(height, width) * 1.0) resize_h = int(height * ratio) resize_w = int(width * ratio) - img_new = cv2.resize(img, (resize_w, resize_h)) - bbox_list_new = [] - for bno in range(len(bbox_list)): - left, top, right, bottom = bbox_list[bno].copy() - left = int(left * ratio) - top = int(top * ratio) - right = int(right * ratio) - bottom = int(bottom * ratio) - bbox_list_new.append([left, top, right, bottom]) - return img_new, bbox_list_new - - def __call__(self, data): - img = data['image'] - if 'cells' not in data: - cells = [] - else: - cells = data['cells'] - bbox_list = self.get_img_bbox(cells) - img_new, bbox_list_new = self.resize_img_table(img, bbox_list, self.max_len) - data['image'] = img_new - cell_num = len(cells) - bno = 0 - for cno in range(cell_num): - if "bbox" in data['cells'][cno]: - data['cells'][cno]['bbox'] = bbox_list_new[bno] - bno += 1 + resize_img = cv2.resize(img, (resize_w, resize_h)) + if self.resize_bboxes and not self.infer_mode: + data['bboxes'] = data['bboxes'] * ratio + data['image'] = resize_img + data['src_img'] = img + data['shape'] = np.array([resize_h, resize_w, ratio, ratio]) data['max_len'] = self.max_len return data + class PaddingTableImage(object): - def __init__(self, **kwargs): + def __init__(self, size, **kwargs): super(PaddingTableImage, self).__init__() - + self.size = size + def __call__(self, data): img = data['image'] - max_len = data['max_len'] - padding_img = np.zeros((max_len, max_len, 3), dtype=np.float32) + pad_h, pad_w = self.size + padding_img = np.zeros((pad_h, pad_w, 3), dtype=np.float32) height, width = img.shape[0:2] padding_img[0:height, 0:width, :] = img.copy() data['image'] = padding_img + shape = data['shape'].tolist() + shape.extend([pad_h, pad_w]) + data['shape'] = np.array(shape) return data - \ No newline at end of file diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py index 02a5187dad27b76d04e866de45333d79383c1347..a55869a641f8b36a85b1771d487f04c60124651a 100644 --- a/ppocr/data/imaug/label_ops.py +++ b/ppocr/data/imaug/label_ops.py @@ -443,7 +443,9 @@ class KieLabelEncode(object): elif 'key_cls' in anno.keys(): labels.append(anno['key_cls']) else: - raise ValueError("Cannot found 'key_cls' in ann.keys(), please check your training annotation.") + raise ValueError( + "Cannot found 'key_cls' in ann.keys(), please check your training annotation." + ) edges.append(ann.get('edge', 0)) ann_infos = dict( image=data['image'], @@ -580,171 +582,197 @@ class SRNLabelEncode(BaseRecLabelEncode): return idx -class TableLabelEncode(object): +class TableLabelEncode(AttnLabelEncode): """ Convert between text-label and text-index """ def __init__(self, max_text_length, - max_elem_length, - max_cell_num, character_dict_path, - span_weight=1.0, + replace_empty_cell_token=False, + merge_no_span_structure=False, + learn_empty_box=False, + point_num=4, **kwargs): - self.max_text_length = max_text_length - self.max_elem_length = max_elem_length - self.max_cell_num = max_cell_num - list_character, list_elem = self.load_char_elem_dict( - character_dict_path) - list_character = self.add_special_char(list_character) - list_elem = self.add_special_char(list_elem) - self.dict_character = {} - for i, char in enumerate(list_character): - self.dict_character[char] = i - self.dict_elem = {} - for i, elem in enumerate(list_elem): - self.dict_elem[elem] = i - self.span_weight = span_weight - - def load_char_elem_dict(self, character_dict_path): - list_character = [] - list_elem = [] + self.max_text_len = max_text_length + self.lower = False + self.learn_empty_box = learn_empty_box + self.merge_no_span_structure = merge_no_span_structure + self.replace_empty_cell_token = replace_empty_cell_token + + dict_character = [] with open(character_dict_path, "rb") as fin: lines = fin.readlines() - substr = lines[0].decode('utf-8').strip("\r\n").split("\t") - character_num = int(substr[0]) - elem_num = int(substr[1]) - for cno in range(1, 1 + character_num): - character = lines[cno].decode('utf-8').strip("\r\n") - list_character.append(character) - for eno in range(1 + character_num, 1 + character_num + elem_num): - elem = lines[eno].decode('utf-8').strip("\r\n") - list_elem.append(elem) - return list_character, list_elem - - def add_special_char(self, list_character): - self.beg_str = "sos" - self.end_str = "eos" - list_character = [self.beg_str] + list_character + [self.end_str] - return list_character + for line in lines: + line = line.decode('utf-8').strip("\n").strip("\r\n") + dict_character.append(line) + + dict_character = self.add_special_char(dict_character) + self.dict = {} + for i, char in enumerate(dict_character): + self.dict[char] = i + self.idx2char = {v: k for k, v in self.dict.items()} - def get_span_idx_list(self): - span_idx_list = [] - for elem in self.dict_elem: - if 'span' in elem: - span_idx_list.append(self.dict_elem[elem]) - return span_idx_list + self.character = dict_character + self.point_num = point_num + self.pad_idx = self.dict[self.beg_str] + self.start_idx = self.dict[self.beg_str] + self.end_idx = self.dict[self.end_str] + + self.td_token = ['', '', ''] + self.empty_bbox_token_dict = { + "[]": '', + "[' ']": '', + "['', ' ', '']": '', + "['\\u2028', '\\u2028']": '', + "['', ' ', '']": '', + "['', '']": '', + "['', ' ', '']": '', + "['', '', '', '']": '', + "['', '', ' ', '', '']": '', + "['', '']": '', + "['', ' ', '\\u2028', ' ', '\\u2028', ' ', '']": + '', + } + + @property + def _max_text_len(self): + return self.max_text_len + 2 def __call__(self, data): cells = data['cells'] - structure = data['structure']['tokens'] - structure = self.encode(structure, 'elem') + structure = data['structure'] + if self.merge_no_span_structure: + structure = self._merge_no_span_structure(structure) + if self.replace_empty_cell_token: + structure = self._replace_empty_cell_token(structure, cells) + # remove empty token and add " " to span token + new_structure = [] + for token in structure: + if token != '': + if 'span' in token and token[0] != ' ': + token = ' ' + token + new_structure.append(token) + # encode structure + structure = self.encode(new_structure) if structure is None: return None - elem_num = len(structure) - structure = [0] + structure + [len(self.dict_elem) - 1] - structure = structure + [0] * (self.max_elem_length + 2 - len(structure) - ) + + structure = [self.start_idx] + structure + [self.end_idx + ] # add sos abd eos + structure = structure + [self.pad_idx] * (self._max_text_len - + len(structure)) # pad structure = np.array(structure) data['structure'] = structure - elem_char_idx1 = self.dict_elem[''] - elem_char_idx2 = self.dict_elem[' 0: - span_weight = len(td_idx_list) * 1.0 / len(span_idx_list) - span_weight = min(max(span_weight, 1.0), self.span_weight) - for cno in range(len(cells)): - if 'bbox' in cells[cno]: - bbox = cells[cno]['bbox'].copy() - bbox[0] = bbox[0] * 1.0 / img_width - bbox[1] = bbox[1] * 1.0 / img_height - bbox[2] = bbox[2] * 1.0 / img_width - bbox[3] = bbox[3] * 1.0 / img_height - td_idx = td_idx_list[cno] - bbox_list[td_idx] = bbox - bbox_list_mask[td_idx] = 1.0 - cand_span_idx = td_idx + 1 - if cand_span_idx < (self.max_elem_length + 2): - if structure[cand_span_idx] in span_idx_list: - structure_mask[cand_span_idx] = span_weight - - data['bbox_list'] = bbox_list - data['bbox_list_mask'] = bbox_list_mask - data['structure_mask'] = structure_mask - char_beg_idx = self.get_beg_end_flag_idx('beg', 'char') - char_end_idx = self.get_beg_end_flag_idx('end', 'char') - elem_beg_idx = self.get_beg_end_flag_idx('beg', 'elem') - elem_end_idx = self.get_beg_end_flag_idx('end', 'elem') - data['sp_tokens'] = np.array([ - char_beg_idx, char_end_idx, elem_beg_idx, elem_end_idx, - elem_char_idx1, elem_char_idx2, self.max_text_length, - self.max_elem_length, self.max_cell_num, elem_num - ]) - return data - def encode(self, text, char_or_elem): - """convert text-label into text-index. - """ - if char_or_elem == "char": - max_len = self.max_text_length - current_dict = self.dict_character - else: - max_len = self.max_elem_length - current_dict = self.dict_elem - if len(text) > max_len: + if len(structure) > self._max_text_len: return None - if len(text) == 0: - if char_or_elem == "char": - return [self.dict_character['space']] - else: - return None - text_list = [] - for char in text: - if char not in current_dict: - return None - text_list.append(current_dict[char]) - if len(text_list) == 0: - if char_or_elem == "char": - return [self.dict_character['space']] - else: - return None - return text_list - def get_ignored_tokens(self, char_or_elem): - beg_idx = self.get_beg_end_flag_idx("beg", char_or_elem) - end_idx = self.get_beg_end_flag_idx("end", char_or_elem) - return [beg_idx, end_idx] + # encode box + bboxes = np.zeros( + (self._max_text_len, self.point_num), dtype=np.float32) + bbox_masks = np.zeros((self._max_text_len, 1), dtype=np.float32) + + bbox_idx = 0 + for i, token in enumerate(structure): + if self.idx2char[token] in self.td_token: + if 'bbox' in cells[bbox_idx]: + bbox = cells[bbox_idx]['bbox'].copy() + bbox = np.array(bbox, dtype=np.float32).reshape(-1) + bboxes[i] = bbox + bbox_masks[i] = 1.0 + if self.learn_empty_box: + bbox_masks[i] = 1.0 + bbox_idx += 1 + data['bboxes'] = bboxes + data['bbox_masks'] = bbox_masks + return data - def get_beg_end_flag_idx(self, beg_or_end, char_or_elem): - if char_or_elem == "char": - if beg_or_end == "beg": - idx = np.array(self.dict_character[self.beg_str]) - elif beg_or_end == "end": - idx = np.array(self.dict_character[self.end_str]) - else: - assert False, "Unsupport type %s in get_beg_end_flag_idx of char" \ - % beg_or_end - elif char_or_elem == "elem": - if beg_or_end == "beg": - idx = np.array(self.dict_elem[self.beg_str]) - elif beg_or_end == "end": - idx = np.array(self.dict_elem[self.end_str]) + def _merge_no_span_structure(self, structure): + new_structure = [] + i = 0 + while i < len(structure): + token = structure[i] + if token == '': + token = '' + i += 1 + new_structure.append(token) + i += 1 + return new_structure + + def _replace_empty_cell_token(self, token_list, cells): + bbox_idx = 0 + add_empty_bbox_token_list = [] + for token in token_list: + if token in ['', '']: + if 'bbox' not in cells[bbox_idx].keys(): + content = str(cells[bbox_idx]['tokens']) + token = self.empty_bbox_token_dict[content] + add_empty_bbox_token_list.append(token) + bbox_idx += 1 else: - assert False, "Unsupport type %s in get_beg_end_flag_idx of elem" \ - % beg_or_end - else: - assert False, "Unsupport type %s in char_or_elem" \ - % char_or_elem - return idx + add_empty_bbox_token_list.append(token) + return add_empty_bbox_token_list + + +class TableMasterLabelEncode(TableLabelEncode): + """ Convert between text-label and text-index """ + + def __init__(self, + max_text_length, + character_dict_path, + replace_empty_cell_token=False, + merge_no_span_structure=False, + learn_empty_box=False, + point_num=4, + **kwargs): + super(TableMasterLabelEncode, self).__init__( + max_text_length, character_dict_path, replace_empty_cell_token, + merge_no_span_structure, learn_empty_box, point_num, **kwargs) + + @property + def _max_text_len(self): + return self.max_text_len + + def add_special_char(self, dict_character): + self.beg_str = '' + self.end_str = '' + self.unknown_str = '' + self.pad_str = '' + dict_character = dict_character + dict_character = dict_character + [ + self.unknown_str, self.beg_str, self.end_str, self.pad_str + ] + return dict_character + + +class TableBoxEncode(object): + def __init__(self, use_xywh=False, **kwargs): + self.use_xywh = use_xywh + + def __call__(self, data): + img_height, img_width = data['image'].shape[:2] + bboxes = data['bboxes'] + if self.use_xywh and bboxes.shape[1] == 4: + bboxes = self.xyxy2xywh(bboxes) + bboxes[:, 0::2] /= img_width + bboxes[:, 1::2] /= img_height + data['bboxes'] = bboxes + return data + + def xyxy2xywh(self, bboxes): + """ + Convert coord (x1,y1,x2,y2) to (x,y,w,h). + where (x1,y1) is top-left, (x2,y2) is bottom-right. + (x,y) is bbox center and (w,h) is width and height. + :param bboxes: (x1, y1, x2, y2) + :return: + """ + new_bboxes = np.empty_like(bboxes) + new_bboxes[:, 0] = (bboxes[:, 0] + bboxes[:, 2]) / 2 # x center + new_bboxes[:, 1] = (bboxes[:, 1] + bboxes[:, 3]) / 2 # y center + new_bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 0] # width + new_bboxes[:, 3] = bboxes[:, 3] - bboxes[:, 1] # height + return new_bboxes class SARLabelEncode(BaseRecLabelEncode): @@ -1030,7 +1058,6 @@ class MultiLabelEncode(BaseRecLabelEncode): use_space_char, **kwargs) def __call__(self, data): - data_ctc = copy.deepcopy(data) data_sar = copy.deepcopy(data) data_out = dict() diff --git a/ppocr/data/pubtab_dataset.py b/ppocr/data/pubtab_dataset.py index 671cda76fb4c36f3ac6bcc7da5a7fc4de241c0e2..105f28db420631e5b6b2e527b5a6536e03d18f7d 100644 --- a/ppocr/data/pubtab_dataset.py +++ b/ppocr/data/pubtab_dataset.py @@ -16,6 +16,7 @@ import os import random from paddle.io import Dataset import json +from copy import deepcopy from .imaug import transform, create_operators @@ -29,33 +30,63 @@ class PubTabDataSet(Dataset): dataset_config = config[mode]['dataset'] loader_config = config[mode]['loader'] - label_file_path = dataset_config.pop('label_file_path') + label_file_list = dataset_config.pop('label_file_list') + data_source_num = len(label_file_list) + ratio_list = dataset_config.get("ratio_list", [1.0]) + if isinstance(ratio_list, (float, int)): + ratio_list = [float(ratio_list)] * int(data_source_num) + + assert len( + ratio_list + ) == data_source_num, "The length of ratio_list should be the same as the file_list." self.data_dir = dataset_config['data_dir'] self.do_shuffle = loader_config['shuffle'] - self.do_hard_select = False - if 'hard_select' in loader_config: - self.do_hard_select = loader_config['hard_select'] - self.hard_prob = loader_config['hard_prob'] - if self.do_hard_select: - self.img_select_prob = self.load_hard_select_prob() - self.table_select_type = None - if 'table_select_type' in loader_config: - self.table_select_type = loader_config['table_select_type'] - self.table_select_prob = loader_config['table_select_prob'] self.seed = seed - logger.info("Initialize indexs of datasets:%s" % label_file_path) - with open(label_file_path, "rb") as f: - self.data_lines = f.readlines() - self.data_idx_order_list = list(range(len(self.data_lines))) - if mode.lower() == "train": + self.mode = mode.lower() + logger.info("Initialize indexs of datasets:%s" % label_file_list) + self.data_lines = self.get_image_info_list(label_file_list, ratio_list) + # self.check(config['Global']['max_text_length']) + + if mode.lower() == "train" and self.do_shuffle: self.shuffle_data_random() self.ops = create_operators(dataset_config['transforms'], global_config) - - ratio_list = dataset_config.get("ratio_list", [1.0]) self.need_reset = True in [x < 1 for x in ratio_list] + def get_image_info_list(self, file_list, ratio_list): + if isinstance(file_list, str): + file_list = [file_list] + data_lines = [] + for idx, file in enumerate(file_list): + with open(file, "rb") as f: + lines = f.readlines() + if self.mode == "train" or ratio_list[idx] < 1.0: + random.seed(self.seed) + lines = random.sample(lines, + round(len(lines) * ratio_list[idx])) + data_lines.extend(lines) + return data_lines + + def check(self, max_text_length): + data_lines = [] + for line in self.data_lines: + data_line = line.decode('utf-8').strip("\n") + info = json.loads(data_line) + file_name = info['filename'] + cells = info['html']['cells'].copy() + structure = info['html']['structure']['tokens'].copy() + + img_path = os.path.join(self.data_dir, file_name) + if not os.path.exists(img_path): + self.logger.warning("{} does not exist!".format(img_path)) + continue + if len(structure) == 0 or len(structure) > max_text_length: + continue + # data = {'img_path': img_path, 'cells': cells, 'structure':structure,'file_name':file_name} + data_lines.append(line) + self.data_lines = data_lines + def shuffle_data_random(self): if self.do_shuffle: random.seed(self.seed) @@ -68,47 +99,34 @@ class PubTabDataSet(Dataset): data_line = data_line.decode('utf-8').strip("\n") info = json.loads(data_line) file_name = info['filename'] - select_flag = True - if self.do_hard_select: - prob = self.img_select_prob[file_name] - if prob < random.uniform(0, 1): - select_flag = False - - if self.table_select_type: - structure = info['html']['structure']['tokens'].copy() - structure_str = ''.join(structure) - table_type = "simple" - if 'colspan' in structure_str or 'rowspan' in structure_str: - table_type = "complex" - if table_type == "complex": - if self.table_select_prob < random.uniform(0, 1): - select_flag = False - - if select_flag: - cells = info['html']['cells'].copy() - structure = info['html']['structure'].copy() - img_path = os.path.join(self.data_dir, file_name) - data = { - 'img_path': img_path, - 'cells': cells, - 'structure': structure - } - if not os.path.exists(img_path): - raise Exception("{} does not exist!".format(img_path)) - with open(data['img_path'], 'rb') as f: - img = f.read() - data['image'] = img - outs = transform(data, self.ops) - else: - outs = None - except Exception as e: + cells = info['html']['cells'].copy() + structure = info['html']['structure']['tokens'].copy() + + img_path = os.path.join(self.data_dir, file_name) + if not os.path.exists(img_path): + raise Exception("{} does not exist!".format(img_path)) + data = { + 'img_path': img_path, + 'cells': cells, + 'structure': structure, + 'file_name': file_name + } + + with open(data['img_path'], 'rb') as f: + img = f.read() + data['image'] = img + outs = transform(data, self.ops) + except: + import traceback + err = traceback.format_exc() self.logger.error( - "When parsing line {}, error happened with msg: {}".format( - data_line, e)) + "When parsing line {}, error happened with msg: {}".format(err)) outs = None if outs is None: - return self.__getitem__(np.random.randint(self.__len__())) + rnd_idx = np.random.randint(self.__len__( + )) if self.mode == "train" else (idx + 1) % self.__len__() + return self.__getitem__(rnd_idx) return outs def __len__(self): - return len(self.data_idx_order_list) + return len(self.data_lines) diff --git a/ppocr/losses/__init__.py b/ppocr/losses/__init__.py index de8419b7c1cf6a30ab7195a1cbcbb10a5e52642d..0f208007b193888b12919547a02c7ea074f01c90 100755 --- a/ppocr/losses/__init__.py +++ b/ppocr/losses/__init__.py @@ -51,7 +51,7 @@ from .combined_loss import CombinedLoss # table loss from .table_att_loss import TableAttentionLoss - +from .table_master_loss import TableMasterLoss # vqa token loss from .vqa_token_layoutlm_loss import VQASerTokenLayoutLMLoss @@ -61,7 +61,8 @@ def build_loss(config): 'DBLoss', 'PSELoss', 'EASTLoss', 'SASTLoss', 'FCELoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss', 'NRTRLoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss', - 'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss' + 'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss', + 'TableMasterLoss' ] config = copy.deepcopy(config) module_name = config.pop('name') diff --git a/ppocr/losses/table_master_loss.py b/ppocr/losses/table_master_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..50a773dd9a9c39d6e7c323011847f8362414a43a --- /dev/null +++ b/ppocr/losses/table_master_loss.py @@ -0,0 +1,65 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import paddle +from paddle import nn + + +class TableMasterLoss(nn.Layer): + def __init__(self, ignore_index=-1): + super(TableMasterLoss, self).__init__() + self.structure_loss = nn.CrossEntropyLoss( + ignore_index=ignore_index, reduction='mean') + self.box_loss = nn.L1Loss(reduction='sum') + self.eps = 1e-12 + + def forward(self, predicts, batch): + # structure_loss + structure_probs = predicts['structure_probs'] + structure_targets = batch[1] + structure_targets = structure_targets[:, 1:] + structure_probs = structure_probs.reshape( + [-1, structure_probs.shape[-1]]) + structure_targets = structure_targets.reshape([-1]) + + structure_loss = self.structure_loss(structure_probs, structure_targets) + structure_loss = structure_loss.mean() + losses = dict(structure_loss=structure_loss) + + # box loss + bboxes_preds = predicts['loc_preds'] + bboxes_targets = batch[2][:, 1:, :] + bbox_masks = batch[3][:, 1:] + # mask empty-bbox or non-bbox structure token's bbox. + + masked_bboxes_preds = bboxes_preds * bbox_masks + masked_bboxes_targets = bboxes_targets * bbox_masks + + # horizon loss (x and width) + horizon_sum_loss = self.box_loss(masked_bboxes_preds[:, :, 0::2], + masked_bboxes_targets[:, :, 0::2]) + horizon_loss = horizon_sum_loss / (bbox_masks.sum() + self.eps) + # vertical loss (y and height) + vertical_sum_loss = self.box_loss(masked_bboxes_preds[:, :, 1::2], + masked_bboxes_targets[:, :, 1::2]) + vertical_loss = vertical_sum_loss / (bbox_masks.sum() + self.eps) + + horizon_loss = horizon_loss.mean() + vertical_loss = vertical_loss.mean() + all_loss = structure_loss + horizon_loss + vertical_loss + losses.update({ + 'loss': all_loss, + 'horizon_bbox_loss': horizon_loss, + 'vertical_bbox_loss': vertical_loss + }) + return losses diff --git a/ppocr/metrics/table_metric.py b/ppocr/metrics/table_metric.py index ca4d6474202b4e85cadf86ccb2fe2726c7fa9aeb..17f3dc92b27cda3e9a19dea2a3bf72988c00b415 100644 --- a/ppocr/metrics/table_metric.py +++ b/ppocr/metrics/table_metric.py @@ -12,29 +12,30 @@ # See the License for the specific language governing permissions and # limitations under the License. import numpy as np +from ppocr.metrics.det_metric import DetMetric -class TableMetric(object): - def __init__(self, main_indicator='acc', **kwargs): +class TableStructureMetric(object): + def __init__(self, main_indicator='acc', eps=1e-6, **kwargs): self.main_indicator = main_indicator - self.eps = 1e-5 + self.eps = eps self.reset() - def __call__(self, pred, batch, *args, **kwargs): - structure_probs = pred['structure_probs'].numpy() - structure_labels = batch[1] + def __call__(self, pred_label, batch=None, *args, **kwargs): + preds, labels = pred_label + pred_structure_batch_list = preds['structure_batch_list'] + gt_structure_batch_list = labels['structure_batch_list'] correct_num = 0 all_num = 0 - structure_probs = np.argmax(structure_probs, axis=2) - structure_labels = structure_labels[:, 1:] - batch_size = structure_probs.shape[0] - for bno in range(batch_size): - all_num += 1 - if (structure_probs[bno] == structure_labels[bno]).all(): + for (pred, pred_conf), target in zip(pred_structure_batch_list, + gt_structure_batch_list): + pred_str = ''.join(pred) + target_str = ''.join(target) + if pred_str == target_str: correct_num += 1 + all_num += 1 self.correct_num += correct_num self.all_num += all_num - return {'acc': correct_num * 1.0 / (all_num + self.eps), } def get_metric(self): """ @@ -49,3 +50,91 @@ class TableMetric(object): def reset(self): self.correct_num = 0 self.all_num = 0 + self.len_acc_num = 0 + self.token_nums = 0 + self.anys_dict = dict() + from collections import defaultdict + self.error_num_dict = defaultdict(int) + + +class TableMetric(object): + def __init__(self, + main_indicator='acc', + compute_bbox_metric=False, + point_num=4, + **kwargs): + """ + + @param sub_metrics: configs of sub_metric + @param main_matric: main_matric for save best_model + @param kwargs: + """ + self.structure_metric = TableStructureMetric() + self.bbox_metric = DetMetric() if compute_bbox_metric else None + self.main_indicator = main_indicator + self.point_num = point_num + self.reset() + + def __call__(self, pred_label, batch=None, *args, **kwargs): + self.structure_metric(pred_label) + if self.bbox_metric is not None: + self.bbox_metric(*self.prepare_bbox_metric_input(pred_label)) + + def prepare_bbox_metric_input(self, pred_label): + pred_bbox_batch_list = [] + gt_ignore_tags_batch_list = [] + gt_bbox_batch_list = [] + preds, labels = pred_label + + batch_num = len(preds['bbox_batch_list']) + for batch_idx in range(batch_num): + # pred + pred_bbox_list = [ + self.format_box(pred_box) + for pred_box in preds['bbox_batch_list'][batch_idx] + ] + pred_bbox_batch_list.append({'points': pred_bbox_list}) + + # gt + gt_bbox_list = [] + gt_ignore_tags_list = [] + for gt_box in labels['bbox_batch_list'][batch_idx]: + gt_bbox_list.append(self.format_box(gt_box)) + gt_ignore_tags_list.append(0) + gt_bbox_batch_list.append(gt_bbox_list) + gt_ignore_tags_batch_list.append(gt_ignore_tags_list) + + return [ + pred_bbox_batch_list, + [0, 0, gt_bbox_batch_list, gt_ignore_tags_batch_list] + ] + + def get_metric(self): + structure_metric = self.structure_metric.get_metric() + if self.bbox_metric is None: + return structure_metric + bbox_metric = self.bbox_metric.get_metric() + if self.main_indicator == self.bbox_metric.main_indicator: + output = bbox_metric + for sub_key in structure_metric: + output["structure_metric_{}".format( + sub_key)] = structure_metric[sub_key] + else: + output = structure_metric + for sub_key in bbox_metric: + output["bbox_metric_{}".format(sub_key)] = bbox_metric[sub_key] + return output + + def reset(self): + self.structure_metric.reset() + if self.bbox_metric is not None: + self.bbox_metric.reset() + + def format_box(self, box): + if self.point_num == 4: + x1, y1, x2, y2 = box + box = [[x1, y1], [x2, y1], [x2, y2], [x1, y2]] + elif self.point_num == 8: + x1, y1, x2, y2, x3, y3, x4, y4 = box + box = [[x1, y1], [x2, y2], [x3, y3], [x4, y4]] + return box diff --git a/ppocr/modeling/backbones/__init__.py b/ppocr/modeling/backbones/__init__.py index 072d6e0f84d4126d256c26aa5baf17c9dc4e63df..2b5fd9142eaa06947439a9c0b9a64ebf28c420f6 100755 --- a/ppocr/modeling/backbones/__init__.py +++ b/ppocr/modeling/backbones/__init__.py @@ -20,7 +20,10 @@ def build_backbone(config, model_type): from .det_mobilenet_v3 import MobileNetV3 from .det_resnet_vd import ResNet from .det_resnet_vd_sast import ResNet_SAST - support_dict = ["MobileNetV3", "ResNet", "ResNet_SAST"] + from .table_master_resnet import TableResNetExtra + support_dict = [ + "MobileNetV3", "ResNet", "ResNet_SAST", "TableResNetExtra" + ] elif model_type == "rec" or model_type == "cls": from .rec_mobilenet_v3 import MobileNetV3 from .rec_resnet_vd import ResNet diff --git a/ppocr/modeling/backbones/table_master_resnet.py b/ppocr/modeling/backbones/table_master_resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..82b4f37a7420982415f21fe50c6200aa16e58314 --- /dev/null +++ b/ppocr/modeling/backbones/table_master_resnet.py @@ -0,0 +1,369 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + + +class BasicBlock(nn.Layer): + expansion = 1 + + def __init__(self, + inplanes, + planes, + stride=1, + downsample=None, + gcb_config=None): + super(BasicBlock, self).__init__() + self.conv1 = nn.Conv2D( + inplanes, + planes, + kernel_size=3, + stride=stride, + padding=1, + bias_attr=False) + self.bn1 = nn.BatchNorm2D(planes, momentum=0.9) + self.relu = nn.ReLU() + self.conv2 = nn.Conv2D( + planes, planes, kernel_size=3, stride=1, padding=1, bias_attr=False) + self.bn2 = nn.BatchNorm2D(planes, momentum=0.9) + self.downsample = downsample + self.stride = stride + self.gcb_config = gcb_config + + if self.gcb_config is not None: + gcb_ratio = gcb_config['ratio'] + gcb_headers = gcb_config['headers'] + att_scale = gcb_config['att_scale'] + fusion_type = gcb_config['fusion_type'] + self.context_block = MultiAspectGCAttention( + inplanes=planes, + ratio=gcb_ratio, + headers=gcb_headers, + att_scale=att_scale, + fusion_type=fusion_type) + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.gcb_config is not None: + out = self.context_block(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +def get_gcb_config(gcb_config, layer): + if gcb_config is None or not gcb_config['layers'][layer]: + return None + else: + return gcb_config + + +class TableResNetExtra(nn.Layer): + def __init__(self, layers, in_channels=3, gcb_config=None): + assert len(layers) >= 4 + + super(TableResNetExtra, self).__init__() + self.inplanes = 128 + self.conv1 = nn.Conv2D( + in_channels, + 64, + kernel_size=3, + stride=1, + padding=1, + bias_attr=False) + self.bn1 = nn.BatchNorm2D(64) + self.relu1 = nn.ReLU() + + self.conv2 = nn.Conv2D( + 64, 128, kernel_size=3, stride=1, padding=1, bias_attr=False) + self.bn2 = nn.BatchNorm2D(128) + self.relu2 = nn.ReLU() + + self.maxpool1 = nn.MaxPool2D(kernel_size=2, stride=2) + + self.layer1 = self._make_layer( + BasicBlock, + 256, + layers[0], + stride=1, + gcb_config=get_gcb_config(gcb_config, 0)) + + self.conv3 = nn.Conv2D( + 256, 256, kernel_size=3, stride=1, padding=1, bias_attr=False) + self.bn3 = nn.BatchNorm2D(256) + self.relu3 = nn.ReLU() + + self.maxpool2 = nn.MaxPool2D(kernel_size=2, stride=2) + + self.layer2 = self._make_layer( + BasicBlock, + 256, + layers[1], + stride=1, + gcb_config=get_gcb_config(gcb_config, 1)) + + self.conv4 = nn.Conv2D( + 256, 256, kernel_size=3, stride=1, padding=1, bias_attr=False) + self.bn4 = nn.BatchNorm2D(256) + self.relu4 = nn.ReLU() + + self.maxpool3 = nn.MaxPool2D(kernel_size=2, stride=2) + + self.layer3 = self._make_layer( + BasicBlock, + 512, + layers[2], + stride=1, + gcb_config=get_gcb_config(gcb_config, 2)) + + self.conv5 = nn.Conv2D( + 512, 512, kernel_size=3, stride=1, padding=1, bias_attr=False) + self.bn5 = nn.BatchNorm2D(512) + self.relu5 = nn.ReLU() + + self.layer4 = self._make_layer( + BasicBlock, + 512, + layers[3], + stride=1, + gcb_config=get_gcb_config(gcb_config, 3)) + + self.conv6 = nn.Conv2D( + 512, 512, kernel_size=3, stride=1, padding=1, bias_attr=False) + self.bn6 = nn.BatchNorm2D(512) + self.relu6 = nn.ReLU() + + self.out_channels = [256, 256, 512] + + def _make_layer(self, block, planes, blocks, stride=1, gcb_config=None): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2D( + self.inplanes, + planes * block.expansion, + kernel_size=1, + stride=stride, + bias_attr=False), + nn.BatchNorm2D(planes * block.expansion), ) + + layers = [] + layers.append( + block( + self.inplanes, + planes, + stride, + downsample, + gcb_config=gcb_config)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + f = [] + x = self.conv1(x) # 1,64,480,480 + + x = self.bn1(x) + x = self.relu1(x) + + x = self.conv2(x) # 1,128,480,480 + x = self.bn2(x) + x = self.relu2(x) + # (48, 160) + + x = self.maxpool1(x) # 1,64,240,240 + x = self.layer1(x) + + x = self.conv3(x) # 1,256,240,240 + x = self.bn3(x) + x = self.relu3(x) + f.append(x) + # (24, 80) + + x = self.maxpool2(x) # 1,256,120,120 + x = self.layer2(x) + + x = self.conv4(x) # 1,256,120,120 + x = self.bn4(x) + x = self.relu4(x) + f.append(x) + # (12, 40) + + x = self.maxpool3(x) # 1,256,60,60 + + x = self.layer3(x) # 1,512,60,60 + x = self.conv5(x) # 1,512,60,60 + x = self.bn5(x) + x = self.relu5(x) + + x = self.layer4(x) # 1,512,60,60 + x = self.conv6(x) # 1,512,60,60 + x = self.bn6(x) + x = self.relu6(x) + f.append(x) + # (6, 40) + return f + + +class MultiAspectGCAttention(nn.Layer): + def __init__(self, + inplanes, + ratio, + headers, + pooling_type='att', + att_scale=False, + fusion_type='channel_add'): + super(MultiAspectGCAttention, self).__init__() + assert pooling_type in ['avg', 'att'] + + assert fusion_type in ['channel_add', 'channel_mul', 'channel_concat'] + assert inplanes % headers == 0 and inplanes >= 8 # inplanes must be divided by headers evenly + + self.headers = headers + self.inplanes = inplanes + self.ratio = ratio + self.planes = int(inplanes * ratio) + self.pooling_type = pooling_type + self.fusion_type = fusion_type + self.att_scale = False + + self.single_header_inplanes = int(inplanes / headers) + + if pooling_type == 'att': + self.conv_mask = nn.Conv2D( + self.single_header_inplanes, 1, kernel_size=1) + self.softmax = nn.Softmax(axis=2) + else: + self.avg_pool = nn.AdaptiveAvgPool2D(1) + + if fusion_type == 'channel_add': + self.channel_add_conv = nn.Sequential( + nn.Conv2D( + self.inplanes, self.planes, kernel_size=1), + nn.LayerNorm([self.planes, 1, 1]), + nn.ReLU(), + nn.Conv2D( + self.planes, self.inplanes, kernel_size=1)) + elif fusion_type == 'channel_concat': + self.channel_concat_conv = nn.Sequential( + nn.Conv2D( + self.inplanes, self.planes, kernel_size=1), + nn.LayerNorm([self.planes, 1, 1]), + nn.ReLU(), + nn.Conv2D( + self.planes, self.inplanes, kernel_size=1)) + # for concat + self.cat_conv = nn.Conv2D( + 2 * self.inplanes, self.inplanes, kernel_size=1) + elif fusion_type == 'channel_mul': + self.channel_mul_conv = nn.Sequential( + nn.Conv2D( + self.inplanes, self.planes, kernel_size=1), + nn.LayerNorm([self.planes, 1, 1]), + nn.ReLU(), + nn.Conv2D( + self.planes, self.inplanes, kernel_size=1)) + + def spatial_pool(self, x): + batch, channel, height, width = x.shape + if self.pooling_type == 'att': + # [N*headers, C', H , W] C = headers * C' + x = x.reshape([ + batch * self.headers, self.single_header_inplanes, height, width + ]) + input_x = x + + # [N*headers, C', H * W] C = headers * C' + # input_x = input_x.view(batch, channel, height * width) + input_x = input_x.reshape([ + batch * self.headers, self.single_header_inplanes, + height * width + ]) + + # [N*headers, 1, C', H * W] + input_x = input_x.unsqueeze(1) + # [N*headers, 1, H, W] + context_mask = self.conv_mask(x) + # [N*headers, 1, H * W] + context_mask = context_mask.reshape( + [batch * self.headers, 1, height * width]) + + # scale variance + if self.att_scale and self.headers > 1: + context_mask = context_mask / paddle.sqrt( + self.single_header_inplanes) + + # [N*headers, 1, H * W] + context_mask = self.softmax(context_mask) + + # [N*headers, 1, H * W, 1] + context_mask = context_mask.unsqueeze(-1) + # [N*headers, 1, C', 1] = [N*headers, 1, C', H * W] * [N*headers, 1, H * W, 1] + context = paddle.matmul(input_x, context_mask) + + # [N, headers * C', 1, 1] + context = context.reshape( + [batch, self.headers * self.single_header_inplanes, 1, 1]) + else: + # [N, C, 1, 1] + context = self.avg_pool(x) + + return context + + def forward(self, x): + # [N, C, 1, 1] + context = self.spatial_pool(x) + + out = x + + if self.fusion_type == 'channel_mul': + # [N, C, 1, 1] + channel_mul_term = F.sigmoid(self.channel_mul_conv(context)) + out = out * channel_mul_term + elif self.fusion_type == 'channel_add': + # [N, C, 1, 1] + channel_add_term = self.channel_add_conv(context) + out = out + channel_add_term + else: + # [N, C, 1, 1] + channel_concat_term = self.channel_concat_conv(context) + + # use concat + _, C1, _, _ = channel_concat_term.shape + N, C2, H, W = out.shape + + out = paddle.concat( + [out, channel_concat_term.expand([-1, -1, H, W])], axis=1) + out = self.cat_conv(out) + out = F.layer_norm(out, [self.inplanes, H, W]) + out = F.relu(out) + + return out diff --git a/ppocr/modeling/heads/__init__.py b/ppocr/modeling/heads/__init__.py index 1670ea38e66baa683e6faab0ec4b12bc517f3c41..da09e25c0e14ffaeb240e394ed9bb0c137afa5fd 100755 --- a/ppocr/modeling/heads/__init__.py +++ b/ppocr/modeling/heads/__init__.py @@ -41,12 +41,13 @@ def build_head(config): from .kie_sdmgr_head import SDMGRHead from .table_att_head import TableAttentionHead + from .table_master_head import TableMasterHead support_dict = [ 'DBHead', 'PSEHead', 'FCEHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead', 'AttentionHead', 'SRNHead', 'PGHead', 'Transformer', 'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead', - 'MultiHead' + 'MultiHead', 'TableMasterHead' ] #table head diff --git a/ppocr/modeling/heads/table_att_head.py b/ppocr/modeling/heads/table_att_head.py index e354f40d6518c1f7ca22e93694b1c6668fc003d2..b64713898d40d48f19b3fafc7c175153bcba09a4 100644 --- a/ppocr/modeling/heads/table_att_head.py +++ b/ppocr/modeling/heads/table_att_head.py @@ -21,6 +21,8 @@ import paddle.nn as nn import paddle.nn.functional as F import numpy as np +from .rec_att_head import AttentionGRUCell + class TableAttentionHead(nn.Layer): def __init__(self, @@ -28,17 +30,13 @@ class TableAttentionHead(nn.Layer): hidden_size, loc_type, in_max_len=488, - max_text_length=100, - max_elem_length=800, - max_cell_num=500, + max_text_length=800, **kwargs): super(TableAttentionHead, self).__init__() self.input_size = in_channels[-1] self.hidden_size = hidden_size self.elem_num = 30 self.max_text_length = max_text_length - self.max_elem_length = max_elem_length - self.max_cell_num = max_cell_num self.structure_attention_cell = AttentionGRUCell( self.input_size, hidden_size, self.elem_num, use_gru=False) @@ -50,11 +48,11 @@ class TableAttentionHead(nn.Layer): self.loc_generator = nn.Linear(hidden_size, 4) else: if self.in_max_len == 640: - self.loc_fea_trans = nn.Linear(400, self.max_elem_length + 1) + self.loc_fea_trans = nn.Linear(400, self.max_text_length + 1) elif self.in_max_len == 800: - self.loc_fea_trans = nn.Linear(625, self.max_elem_length + 1) + self.loc_fea_trans = nn.Linear(625, self.max_text_length + 1) else: - self.loc_fea_trans = nn.Linear(256, self.max_elem_length + 1) + self.loc_fea_trans = nn.Linear(256, self.max_text_length + 1) self.loc_generator = nn.Linear(self.input_size + hidden_size, 4) def _char_to_onehot(self, input_char, onehot_dim): @@ -77,7 +75,7 @@ class TableAttentionHead(nn.Layer): output_hiddens = [] if self.training and targets is not None: structure = targets[0] - for i in range(self.max_elem_length + 1): + for i in range(self.max_text_length + 1): elem_onehots = self._char_to_onehot( structure[:, i], onehot_dim=self.elem_num) (outputs, hidden), alpha = self.structure_attention_cell( @@ -102,9 +100,9 @@ class TableAttentionHead(nn.Layer): elem_onehots = None outputs = None alpha = None - max_elem_length = paddle.to_tensor(self.max_elem_length) + max_text_length = paddle.to_tensor(self.max_text_length) i = 0 - while i < max_elem_length + 1: + while i < max_text_length + 1: elem_onehots = self._char_to_onehot( temp_elem, onehot_dim=self.elem_num) (outputs, hidden), alpha = self.structure_attention_cell( @@ -128,119 +126,3 @@ class TableAttentionHead(nn.Layer): loc_preds = self.loc_generator(loc_concat) loc_preds = F.sigmoid(loc_preds) return {'structure_probs': structure_probs, 'loc_preds': loc_preds} - - -class AttentionGRUCell(nn.Layer): - def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False): - super(AttentionGRUCell, self).__init__() - self.i2h = nn.Linear(input_size, hidden_size, bias_attr=False) - self.h2h = nn.Linear(hidden_size, hidden_size) - self.score = nn.Linear(hidden_size, 1, bias_attr=False) - self.rnn = nn.GRUCell( - input_size=input_size + num_embeddings, hidden_size=hidden_size) - self.hidden_size = hidden_size - - def forward(self, prev_hidden, batch_H, char_onehots): - batch_H_proj = self.i2h(batch_H) - prev_hidden_proj = paddle.unsqueeze(self.h2h(prev_hidden), axis=1) - res = paddle.add(batch_H_proj, prev_hidden_proj) - res = paddle.tanh(res) - e = self.score(res) - alpha = F.softmax(e, axis=1) - alpha = paddle.transpose(alpha, [0, 2, 1]) - context = paddle.squeeze(paddle.mm(alpha, batch_H), axis=1) - concat_context = paddle.concat([context, char_onehots], 1) - cur_hidden = self.rnn(concat_context, prev_hidden) - return cur_hidden, alpha - - -class AttentionLSTM(nn.Layer): - def __init__(self, in_channels, out_channels, hidden_size, **kwargs): - super(AttentionLSTM, self).__init__() - self.input_size = in_channels - self.hidden_size = hidden_size - self.num_classes = out_channels - - self.attention_cell = AttentionLSTMCell( - in_channels, hidden_size, out_channels, use_gru=False) - self.generator = nn.Linear(hidden_size, out_channels) - - def _char_to_onehot(self, input_char, onehot_dim): - input_ont_hot = F.one_hot(input_char, onehot_dim) - return input_ont_hot - - def forward(self, inputs, targets=None, batch_max_length=25): - batch_size = inputs.shape[0] - num_steps = batch_max_length - - hidden = (paddle.zeros((batch_size, self.hidden_size)), paddle.zeros( - (batch_size, self.hidden_size))) - output_hiddens = [] - - if targets is not None: - for i in range(num_steps): - # one-hot vectors for a i-th char - char_onehots = self._char_to_onehot( - targets[:, i], onehot_dim=self.num_classes) - hidden, alpha = self.attention_cell(hidden, inputs, - char_onehots) - - hidden = (hidden[1][0], hidden[1][1]) - output_hiddens.append(paddle.unsqueeze(hidden[0], axis=1)) - output = paddle.concat(output_hiddens, axis=1) - probs = self.generator(output) - - else: - targets = paddle.zeros(shape=[batch_size], dtype="int32") - probs = None - - for i in range(num_steps): - char_onehots = self._char_to_onehot( - targets, onehot_dim=self.num_classes) - hidden, alpha = self.attention_cell(hidden, inputs, - char_onehots) - probs_step = self.generator(hidden[0]) - hidden = (hidden[1][0], hidden[1][1]) - if probs is None: - probs = paddle.unsqueeze(probs_step, axis=1) - else: - probs = paddle.concat( - [probs, paddle.unsqueeze( - probs_step, axis=1)], axis=1) - - next_input = probs_step.argmax(axis=1) - - targets = next_input - - return probs - - -class AttentionLSTMCell(nn.Layer): - def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False): - super(AttentionLSTMCell, self).__init__() - self.i2h = nn.Linear(input_size, hidden_size, bias_attr=False) - self.h2h = nn.Linear(hidden_size, hidden_size) - self.score = nn.Linear(hidden_size, 1, bias_attr=False) - if not use_gru: - self.rnn = nn.LSTMCell( - input_size=input_size + num_embeddings, hidden_size=hidden_size) - else: - self.rnn = nn.GRUCell( - input_size=input_size + num_embeddings, hidden_size=hidden_size) - - self.hidden_size = hidden_size - - def forward(self, prev_hidden, batch_H, char_onehots): - batch_H_proj = self.i2h(batch_H) - prev_hidden_proj = paddle.unsqueeze(self.h2h(prev_hidden[0]), axis=1) - res = paddle.add(batch_H_proj, prev_hidden_proj) - res = paddle.tanh(res) - e = self.score(res) - - alpha = F.softmax(e, axis=1) - alpha = paddle.transpose(alpha, [0, 2, 1]) - context = paddle.squeeze(paddle.mm(alpha, batch_H), axis=1) - concat_context = paddle.concat([context, char_onehots], 1) - cur_hidden = self.rnn(concat_context, prev_hidden) - - return cur_hidden, alpha diff --git a/ppocr/modeling/heads/table_master_head.py b/ppocr/modeling/heads/table_master_head.py new file mode 100644 index 0000000000000000000000000000000000000000..acd1a9145fc0aeb6d374a8555cd09347624f0172 --- /dev/null +++ b/ppocr/modeling/heads/table_master_head.py @@ -0,0 +1,276 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import copy +import math +import paddle +from paddle import nn +from paddle.nn import functional as F + + +class TableMasterHead(nn.Layer): + """ + Split to two transformer header at the last layer. + Cls_layer is used to structure token classification. + Bbox_layer is used to regress bbox coord. + """ + + def __init__(self, + in_channels, + out_channels=30, + headers=8, + d_ff=2048, + dropout=0, + max_text_length=500, + point_num=4, + **kwargs): + super(TableMasterHead, self).__init__() + hidden_size = in_channels[-1] + self.layers = clones( + DecoderLayer(headers, hidden_size, dropout, d_ff), 2) + self.cls_layer = clones( + DecoderLayer(headers, hidden_size, dropout, d_ff), 1) + self.bbox_layer = clones( + DecoderLayer(headers, hidden_size, dropout, d_ff), 1) + self.cls_fc = nn.Linear(hidden_size, out_channels) + self.bbox_fc = nn.Sequential( + # nn.Linear(hidden_size, hidden_size), + nn.Linear(hidden_size, point_num), + nn.Sigmoid()) + self.norm = nn.LayerNorm(hidden_size) + self.embedding = Embeddings(d_model=hidden_size, vocab=out_channels) + self.positional_encoding = PositionalEncoding(d_model=hidden_size) + + self.SOS = out_channels - 3 + self.PAD = out_channels - 1 + self.out_channels = out_channels + self.point_num = point_num + self.max_text_length = max_text_length + + def make_mask(self, tgt): + """ + Make mask for self attention. + :param src: [b, c, h, l_src] + :param tgt: [b, l_tgt] + :return: + """ + trg_pad_mask = (tgt != self.PAD).unsqueeze(1).unsqueeze(3) + + tgt_len = paddle.shape(tgt)[1] + trg_sub_mask = paddle.tril( + paddle.ones( + ([tgt_len, tgt_len]), dtype=paddle.float32)) + + tgt_mask = paddle.logical_and( + trg_pad_mask.astype(paddle.float32), trg_sub_mask) + return tgt_mask.astype(paddle.float32) + + def decode(self, input, feature, src_mask, tgt_mask): + # main process of transformer decoder. + x = self.embedding(input) # x: 1*x*512, feature: 1*3600,512 + x = self.positional_encoding(x) + + # origin transformer layers + for i, layer in enumerate(self.layers): + x = layer(x, feature, src_mask, tgt_mask) + + # cls head + for layer in self.cls_layer: + cls_x = layer(x, feature, src_mask, tgt_mask) + cls_x = self.norm(cls_x) + + # bbox head + for layer in self.bbox_layer: + bbox_x = layer(x, feature, src_mask, tgt_mask) + bbox_x = self.norm(bbox_x) + return self.cls_fc(cls_x), self.bbox_fc(bbox_x) + + def greedy_forward(self, SOS, feature): + input = SOS + output = paddle.zeros( + [input.shape[0], self.max_text_length + 1, self.out_channels]) + bbox_output = paddle.zeros( + [input.shape[0], self.max_text_length + 1, self.point_num]) + max_text_length = paddle.to_tensor(self.max_text_length) + for i in range(max_text_length + 1): + target_mask = self.make_mask(input) + out_step, bbox_output_step = self.decode(input, feature, None, + target_mask) + prob = F.softmax(out_step, axis=-1) + next_word = prob.argmax(axis=2, dtype="int64") + input = paddle.concat( + [input, next_word[:, -1].unsqueeze(-1)], axis=1) + if i == self.max_text_length: + output = out_step + bbox_output = bbox_output_step + return output, bbox_output + + def forward_train(self, out_enc, targets): + # x is token of label + # feat is feature after backbone before pe. + # out_enc is feature after pe. + padded_targets = targets[0] + src_mask = None + tgt_mask = self.make_mask(padded_targets[:, :-1]) + output, bbox_output = self.decode(padded_targets[:, :-1], out_enc, + src_mask, tgt_mask) + return {'structure_probs': output, 'loc_preds': bbox_output} + + def forward_test(self, out_enc): + batch_size = out_enc.shape[0] + SOS = paddle.zeros([batch_size, 1], dtype='int64') + self.SOS + output, bbox_output = self.greedy_forward(SOS, out_enc) + # output = F.softmax(output) + return {'structure_probs': output, 'loc_preds': bbox_output} + + def forward(self, feat, targets=None): + feat = feat[-1] + b, c, h, w = feat.shape + feat = feat.reshape([b, c, h * w]) # flatten 2D feature map + feat = feat.transpose((0, 2, 1)) + out_enc = self.positional_encoding(feat) + if self.training: + return self.forward_train(out_enc, targets) + + return self.forward_test(out_enc) + + +class DecoderLayer(nn.Layer): + """ + Decoder is made of self attention, srouce attention and feed forward. + """ + + def __init__(self, headers, d_model, dropout, d_ff): + super(DecoderLayer, self).__init__() + self.self_attn = MultiHeadAttention(headers, d_model, dropout) + self.src_attn = MultiHeadAttention(headers, d_model, dropout) + self.feed_forward = FeedForward(d_model, d_ff, dropout) + self.sublayer = clones(SubLayerConnection(d_model, dropout), 3) + + def forward(self, x, feature, src_mask, tgt_mask): + x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask)) + x = self.sublayer[1]( + x, lambda x: self.src_attn(x, feature, feature, src_mask)) + return self.sublayer[2](x, self.feed_forward) + + +class MultiHeadAttention(nn.Layer): + def __init__(self, headers, d_model, dropout): + super(MultiHeadAttention, self).__init__() + + assert d_model % headers == 0 + self.d_k = int(d_model / headers) + self.headers = headers + self.linears = clones(nn.Linear(d_model, d_model), 4) + self.attn = None + self.dropout = nn.Dropout(dropout) + + def forward(self, query, key, value, mask=None): + B = query.shape[0] + + # 1) Do all the linear projections in batch from d_model => h x d_k + query, key, value = \ + [l(x).reshape([B, 0, self.headers, self.d_k]).transpose([0, 2, 1, 3]) + for l, x in zip(self.linears, (query, key, value))] + # 2) Apply attention on all the projected vectors in batch + x, self.attn = self_attention( + query, key, value, mask=mask, dropout=self.dropout) + x = x.transpose([0, 2, 1, 3]).reshape([B, 0, self.headers * self.d_k]) + return self.linears[-1](x) + + +class FeedForward(nn.Layer): + def __init__(self, d_model, d_ff, dropout): + super(FeedForward, self).__init__() + self.w_1 = nn.Linear(d_model, d_ff) + self.w_2 = nn.Linear(d_ff, d_model) + self.dropout = nn.Dropout(dropout) + + def forward(self, x): + return self.w_2(self.dropout(F.relu(self.w_1(x)))) + + +class SubLayerConnection(nn.Layer): + """ + A residual connection followed by a layer norm. + Note for code simplicity the norm is first as opposed to last. + """ + + def __init__(self, size, dropout): + super(SubLayerConnection, self).__init__() + self.norm = nn.LayerNorm(size) + self.dropout = nn.Dropout(dropout) + + def forward(self, x, sublayer): + return x + self.dropout(sublayer(self.norm(x))) + + +def masked_fill(x, mask, value): + mask = mask.astype(x.dtype) + return x * paddle.logical_not(mask).astype(x.dtype) + mask * value + + +def self_attention(query, key, value, mask=None, dropout=None): + """ + Compute 'Scale Dot Product Attention' + """ + d_k = value.shape[-1] + + score = paddle.matmul(query, key.transpose([0, 1, 3, 2]) / math.sqrt(d_k)) + if mask is not None: + # score = score.masked_fill(mask == 0, -1e9) # b, h, L, L + score = masked_fill(score, mask == 0, -6.55e4) # for fp16 + + p_attn = F.softmax(score, axis=-1) + + if dropout is not None: + p_attn = dropout(p_attn) + return paddle.matmul(p_attn, value), p_attn + + +def clones(module, N): + """ Produce N identical layers """ + return nn.LayerList([copy.deepcopy(module) for _ in range(N)]) + + +class Embeddings(nn.Layer): + def __init__(self, d_model, vocab): + super(Embeddings, self).__init__() + self.lut = nn.Embedding(vocab, d_model) + self.d_model = d_model + + def forward(self, *input): + x = input[0] + return self.lut(x) * math.sqrt(self.d_model) + + +class PositionalEncoding(nn.Layer): + """ Implement the PE function. """ + + def __init__(self, d_model, dropout=0., max_len=5000): + super(PositionalEncoding, self).__init__() + self.dropout = nn.Dropout(p=dropout) + + # Compute the positional encodings once in log space. + pe = paddle.zeros([max_len, d_model]) + position = paddle.arange(0, max_len).unsqueeze(1).astype('float32') + div_term = paddle.exp( + paddle.arange(0, d_model, 2) * -math.log(10000.0) / d_model) + pe[:, 0::2] = paddle.sin(position * div_term) + pe[:, 1::2] = paddle.cos(position * div_term) + pe = pe.unsqueeze(0) + self.register_buffer('pe', pe) + + def forward(self, feat, **kwargs): + feat = feat + self.pe[:, :paddle.shape(feat)[1]] # pe 1*5000*512 + return self.dropout(feat) diff --git a/ppocr/optimizer/learning_rate.py b/ppocr/optimizer/learning_rate.py index fe251f36e736bb1eac8a71a8115c941cbd7443e6..d96ab51896884428d88a70c5a6a1e4ab59252c55 100644 --- a/ppocr/optimizer/learning_rate.py +++ b/ppocr/optimizer/learning_rate.py @@ -308,3 +308,46 @@ class Const(object): end_lr=self.learning_rate, last_epoch=self.last_epoch) return learning_rate + + +class MultiStepDecay(object): + """ + Piecewise learning rate decay + Args: + step_each_epoch(int): steps each epoch + learning_rate (float): The initial learning rate. It is a python float number. + step_size (int): the interval to update. + gamma (float, optional): The Ratio that the learning rate will be reduced. ``new_lr = origin_lr * gamma`` . + It should be less than 1.0. Default: 0.1. + last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate. + """ + + def __init__(self, + learning_rate, + milestones, + step_each_epoch, + gamma, + warmup_epoch=0, + last_epoch=-1, + **kwargs): + super(MultiStepDecay, self).__init__() + self.milestones = [step_each_epoch * e for e in milestones] + self.learning_rate = learning_rate + self.gamma = gamma + self.last_epoch = last_epoch + self.warmup_epoch = round(warmup_epoch * step_each_epoch) + + def __call__(self): + learning_rate = lr.MultiStepDecay( + learning_rate=self.learning_rate, + milestones=self.milestones, + gamma=self.gamma, + last_epoch=self.last_epoch) + if self.warmup_epoch > 0: + learning_rate = lr.LinearWarmup( + learning_rate=learning_rate, + warmup_steps=self.warmup_epoch, + start_lr=0.0, + end_lr=self.learning_rate, + last_epoch=self.last_epoch) + return learning_rate \ No newline at end of file diff --git a/ppocr/postprocess/__init__.py b/ppocr/postprocess/__init__.py index f50b5f1c5f8e617066bb47636c8f4d2b171b6ecb..4a08f1531f7aa4f521c360e59d74ab62ff2911ba 100644 --- a/ppocr/postprocess/__init__.py +++ b/ppocr/postprocess/__init__.py @@ -26,8 +26,9 @@ from .east_postprocess import EASTPostProcess from .sast_postprocess import SASTPostProcess from .fce_postprocess import FCEPostProcess from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, \ - DistillationCTCLabelDecode, TableLabelDecode, NRTRLabelDecode, SARLabelDecode, \ + DistillationCTCLabelDecode, NRTRLabelDecode, SARLabelDecode, \ SEEDLabelDecode, PRENLabelDecode +from .table_postprocess import TableMasterLabelDecode, TableLabelDecode from .cls_postprocess import ClsPostProcess from .pg_postprocess import PGPostProcess from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess @@ -42,7 +43,7 @@ def build_post_process(config, global_config=None): 'DistillationDBPostProcess', 'NRTRLabelDecode', 'SARLabelDecode', 'SEEDLabelDecode', 'VQASerTokenLayoutLMPostProcess', 'VQAReTokenLayoutLMPostProcess', 'PRENLabelDecode', - 'DistillationSARLabelDecode' + 'DistillationSARLabelDecode', 'TableMasterLabelDecode' ] if config['name'] == 'PSEPostProcess': diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py index bf0fd890bf25949361665d212bf8e1a657054e5b..0d01b342106dc04fa44bc8f9fb74f56b1b67ff8a 100644 --- a/ppocr/postprocess/rec_postprocess.py +++ b/ppocr/postprocess/rec_postprocess.py @@ -444,146 +444,6 @@ class SRNLabelDecode(BaseRecLabelDecode): return idx -class TableLabelDecode(object): - """ """ - - def __init__(self, character_dict_path, **kwargs): - list_character, list_elem = self.load_char_elem_dict( - character_dict_path) - list_character = self.add_special_char(list_character) - list_elem = self.add_special_char(list_elem) - self.dict_character = {} - self.dict_idx_character = {} - for i, char in enumerate(list_character): - self.dict_idx_character[i] = char - self.dict_character[char] = i - self.dict_elem = {} - self.dict_idx_elem = {} - for i, elem in enumerate(list_elem): - self.dict_idx_elem[i] = elem - self.dict_elem[elem] = i - - def load_char_elem_dict(self, character_dict_path): - list_character = [] - list_elem = [] - with open(character_dict_path, "rb") as fin: - lines = fin.readlines() - substr = lines[0].decode('utf-8').strip("\n").strip("\r\n").split( - "\t") - character_num = int(substr[0]) - elem_num = int(substr[1]) - for cno in range(1, 1 + character_num): - character = lines[cno].decode('utf-8').strip("\n").strip("\r\n") - list_character.append(character) - for eno in range(1 + character_num, 1 + character_num + elem_num): - elem = lines[eno].decode('utf-8').strip("\n").strip("\r\n") - list_elem.append(elem) - return list_character, list_elem - - def add_special_char(self, list_character): - self.beg_str = "sos" - self.end_str = "eos" - list_character = [self.beg_str] + list_character + [self.end_str] - return list_character - - def __call__(self, preds): - structure_probs = preds['structure_probs'] - loc_preds = preds['loc_preds'] - if isinstance(structure_probs, paddle.Tensor): - structure_probs = structure_probs.numpy() - if isinstance(loc_preds, paddle.Tensor): - loc_preds = loc_preds.numpy() - structure_idx = structure_probs.argmax(axis=2) - structure_probs = structure_probs.max(axis=2) - structure_str, structure_pos, result_score_list, result_elem_idx_list = self.decode( - structure_idx, structure_probs, 'elem') - res_html_code_list = [] - res_loc_list = [] - batch_num = len(structure_str) - for bno in range(batch_num): - res_loc = [] - for sno in range(len(structure_str[bno])): - text = structure_str[bno][sno] - if text in ['', ' 0 and tmp_elem_idx == end_idx: - break - if tmp_elem_idx in ignored_tokens: - continue - - char_list.append(current_dict[tmp_elem_idx]) - elem_pos_list.append(idx) - score_list.append(structure_probs[batch_idx, idx]) - elem_idx_list.append(tmp_elem_idx) - result_list.append(char_list) - result_pos_list.append(elem_pos_list) - result_score_list.append(score_list) - result_elem_idx_list.append(elem_idx_list) - return result_list, result_pos_list, result_score_list, result_elem_idx_list - - def get_ignored_tokens(self, char_or_elem): - beg_idx = self.get_beg_end_flag_idx("beg", char_or_elem) - end_idx = self.get_beg_end_flag_idx("end", char_or_elem) - return [beg_idx, end_idx] - - def get_beg_end_flag_idx(self, beg_or_end, char_or_elem): - if char_or_elem == "char": - if beg_or_end == "beg": - idx = self.dict_character[self.beg_str] - elif beg_or_end == "end": - idx = self.dict_character[self.end_str] - else: - assert False, "Unsupport type %s in get_beg_end_flag_idx of char" \ - % beg_or_end - elif char_or_elem == "elem": - if beg_or_end == "beg": - idx = self.dict_elem[self.beg_str] - elif beg_or_end == "end": - idx = self.dict_elem[self.end_str] - else: - assert False, "Unsupport type %s in get_beg_end_flag_idx of elem" \ - % beg_or_end - else: - assert False, "Unsupport type %s in char_or_elem" \ - % char_or_elem - return idx - - class SARLabelDecode(BaseRecLabelDecode): """ Convert between text-label and text-index """ diff --git a/ppocr/postprocess/table_postprocess.py b/ppocr/postprocess/table_postprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..4396ec4f701478e7bdcdd8c7752738c5c8ef148d --- /dev/null +++ b/ppocr/postprocess/table_postprocess.py @@ -0,0 +1,160 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import paddle + +from .rec_postprocess import AttnLabelDecode + + +class TableLabelDecode(AttnLabelDecode): + """ """ + + def __init__(self, character_dict_path, **kwargs): + super(TableLabelDecode, self).__init__(character_dict_path) + self.td_token = ['', '', ''] + + def __call__(self, preds, batch=None): + structure_probs = preds['structure_probs'] + bbox_preds = preds['loc_preds'] + if isinstance(structure_probs, paddle.Tensor): + structure_probs = structure_probs.numpy() + if isinstance(bbox_preds, paddle.Tensor): + bbox_preds = bbox_preds.numpy() + shape_list = batch[-1] + result = self.decode(structure_probs, bbox_preds, shape_list) + if len(batch) == 1: # only contains shape + return result + + label_decode_result = self.decode_label(batch) + return result, label_decode_result + + def decode(self, structure_probs, bbox_preds, shape_list): + """convert text-label into text-index. + """ + ignored_tokens = self.get_ignored_tokens() + end_idx = self.dict[self.end_str] + + structure_idx = structure_probs.argmax(axis=2) + structure_probs = structure_probs.max(axis=2) + + structure_batch_list = [] + bbox_batch_list = [] + batch_size = len(structure_idx) + for batch_idx in range(batch_size): + structure_list = [] + bbox_list = [] + score_list = [] + for idx in range(len(structure_idx[batch_idx])): + char_idx = int(structure_idx[batch_idx][idx]) + if idx > 0 and char_idx == end_idx: + break + if char_idx in ignored_tokens: + continue + text = self.character[char_idx] + if text in self.td_token: + bbox = bbox_preds[batch_idx, idx] + bbox = self._bbox_decode(bbox, shape_list[batch_idx]) + bbox_list.append(bbox) + structure_list.append(text) + score_list.append(structure_probs[batch_idx, idx]) + structure_batch_list.append([structure_list, np.mean(score_list)]) + bbox_batch_list.append(np.array(bbox_list)) + result = { + 'bbox_batch_list': bbox_batch_list, + 'structure_batch_list': structure_batch_list, + } + return result + + def decode_label(self, batch): + """convert text-label into text-index. + """ + structure_idx = batch[1] + gt_bbox_list = batch[2] + shape_list = batch[-1] + ignored_tokens = self.get_ignored_tokens() + end_idx = self.dict[self.end_str] + + structure_batch_list = [] + bbox_batch_list = [] + batch_size = len(structure_idx) + for batch_idx in range(batch_size): + structure_list = [] + bbox_list = [] + for idx in range(len(structure_idx[batch_idx])): + char_idx = int(structure_idx[batch_idx][idx]) + if idx > 0 and char_idx == end_idx: + break + if char_idx in ignored_tokens: + continue + structure_list.append(self.character[char_idx]) + + bbox = gt_bbox_list[batch_idx][idx] + if bbox.sum() != 0: + bbox = self._bbox_decode(bbox, shape_list[batch_idx]) + bbox_list.append(bbox) + structure_batch_list.append(structure_list) + bbox_batch_list.append(bbox_list) + result = { + 'bbox_batch_list': bbox_batch_list, + 'structure_batch_list': structure_batch_list, + } + return result + + def _bbox_decode(self, bbox, shape): + h, w, ratio_h, ratio_w, pad_h, pad_w = shape + src_h = h / ratio_h + src_w = w / ratio_w + bbox[0::2] *= src_w + bbox[1::2] *= src_h + return bbox + + +class TableMasterLabelDecode(TableLabelDecode): + """ """ + + def __init__(self, character_dict_path, box_shape='ori', **kwargs): + super(TableMasterLabelDecode, self).__init__(character_dict_path) + self.box_shape = box_shape + assert box_shape in [ + 'ori', 'pad' + ], 'The shape used for box normalization must be ori or pad' + + def add_special_char(self, dict_character): + self.beg_str = '' + self.end_str = '' + self.unknown_str = '' + self.pad_str = '' + dict_character = dict_character + dict_character = dict_character + [ + self.unknown_str, self.beg_str, self.end_str, self.pad_str + ] + return dict_character + + def get_ignored_tokens(self): + pad_idx = self.dict[self.pad_str] + start_idx = self.dict[self.beg_str] + end_idx = self.dict[self.end_str] + unknown_idx = self.dict[self.unknown_str] + return [start_idx, end_idx, pad_idx, unknown_idx] + + def _bbox_decode(self, bbox, shape): + h, w, ratio_h, ratio_w, pad_h, pad_w = shape + if self.box_shape == 'pad': + h, w = pad_h, pad_w + bbox[0::2] *= w + bbox[1::2] *= h + bbox[0::2] /= ratio_w + bbox[1::2] /= ratio_h + return bbox diff --git a/ppocr/utils/dict/table_master_structure_dict.txt b/ppocr/utils/dict/table_master_structure_dict.txt new file mode 100644 index 0000000000000000000000000000000000000000..95ab2539a70aca4f695c53a38cdc1c3e164fcfb3 --- /dev/null +++ b/ppocr/utils/dict/table_master_structure_dict.txt @@ -0,0 +1,39 @@ + + + + + + + + + + + colspan="2" + colspan="3" + + + rowspan="2" + colspan="4" + colspan="6" + rowspan="3" + colspan="9" + colspan="10" + colspan="7" + rowspan="4" + rowspan="5" + rowspan="9" + colspan="8" + rowspan="8" + rowspan="6" + rowspan="7" + rowspan="10" + + + + + + + + diff --git a/ppocr/utils/dict/table_structure_dict.txt b/ppocr/utils/dict/table_structure_dict.txt index 9c4531e5f3b8c498e70d3c2ea0471e5e746a2c30..8edb10b8817ad596af6c63b6b8fc5eb2349b7464 100644 --- a/ppocr/utils/dict/table_structure_dict.txt +++ b/ppocr/utils/dict/table_structure_dict.txt @@ -1,281 +1,3 @@ -277 28 1267 1186 - -V -a -r -i -b -l -e - -H -z -d - -t -o -9 -5 -% -C -I - -p - -v -u -* -A -g -( -m -n -) -0 -. -7 -1 -6 -≤ -> -8 -3 -– -2 -G -4 -M -F -T -y -f -s -L -w -c -U -h -D -S -Q -R -x -P -- -E -O -/ -k -, -+ -N -K -q -′ -[ -] -< -≥ - -− - -μ -± -J -j -W -_ -Δ -B -“ -: -Y -α -λ -; - - -? -∼ -= -° -# -̊ -̈ -̂ -’ -Z -X -∗ -— -β -' -† -~ -@ -" -γ -↓ -↑ -& -‡ -χ -” -σ -§ -| -¶ -‐ -× -$ -→ -√ -✓ -‘ -\ -∞ -π -• -® -^ -∆ -≧ - - -́ -♀ -♂ -‒ -⁎ -▲ -· -£ -φ -Ψ -ß -△ -☆ -▪ -η -€ -∧ -̃ -Φ -ρ -̄ -δ -‰ -̧ -Ω -♦ -{ -} -̀ -∑ -∫ -ø -κ -ε -¥ -※ -` -ω -Σ -➔ -‖ -Β -̸ -
 -─ -● -⩾ -Χ -Α -⋅ -◆ -★ -■ -ψ -ǂ -□ -ζ -! -Γ -↔ -θ -⁄ -〈 -〉 -― -υ -τ -⋆ -Ø -© -∥ -С -˂ -➢ -ɛ -⁡ -✗ -← -○ -¢ -⩽ -∖ -˃ -­ -≈ -Π -̌ -≦ -∅ -ᅟ - - -∣ -¤ -♯ -̆ -ξ -÷ -▼ - -ι -ν -║ - - -◦ -​ -◊ -∙ -« -» -ł -ı -Θ -∈ -„ -∘ -✔ -̇ -æ -ʹ -ˆ -♣ -⇓ -∩ -⊕ -⇒ -⇑ -̨ -Ι -Λ -⋯ -А -⋮ @@ -303,2457 +25,4 @@ $ rowspan="8" rowspan="6" rowspan="7" - rowspan="10" -0 2924682 -1 3405345 -2 2363468 -3 2709165 -4 4078680 -5 3250792 -6 1923159 -7 1617890 -8 1450532 -9 1717624 -10 1477550 -11 1489223 -12 915528 -13 819193 -14 593660 -15 518924 -16 682065 -17 494584 -18 400591 -19 396421 -20 340994 -21 280688 -22 250328 -23 226786 -24 199927 -25 182707 -26 164629 -27 141613 -28 127554 -29 116286 -30 107682 -31 96367 -32 88002 -33 79234 -34 72186 -35 65921 -36 60374 -37 55976 -38 52166 -39 47414 -40 44932 -41 41279 -42 38232 -43 35463 -44 33703 -45 30557 -46 29639 -47 27000 -48 25447 -49 23186 -50 22093 -51 20412 -52 19844 -53 18261 -54 17561 -55 16499 -56 15597 -57 14558 -58 14372 -59 13445 -60 13514 -61 12058 -62 11145 -63 10767 -64 10370 -65 9630 -66 9337 -67 8881 -68 8727 -69 8060 -70 7994 -71 7740 -72 7189 -73 6729 -74 6749 -75 6548 -76 6321 -77 5957 -78 5740 -79 5407 -80 5370 -81 5035 -82 4921 -83 4656 -84 4600 -85 4519 -86 4277 -87 4023 -88 3939 -89 3910 -90 3861 -91 3560 -92 3483 -93 3406 -94 3346 -95 3229 -96 3122 -97 3086 -98 3001 -99 2884 -100 2822 -101 2677 -102 2670 -103 2610 -104 2452 -105 2446 -106 2400 -107 2300 -108 2316 -109 2196 -110 2089 -111 2083 -112 2041 -113 1881 -114 1838 -115 1896 -116 1795 -117 1786 -118 1743 -119 1765 -120 1750 -121 1683 -122 1563 -123 1499 -124 1513 -125 1462 -126 1388 -127 1441 -128 1417 -129 1392 -130 1306 -131 1321 -132 1274 -133 1294 -134 1240 -135 1126 -136 1157 -137 1130 -138 1084 -139 1130 -140 1083 -141 1040 -142 980 -143 1031 -144 974 -145 980 -146 932 -147 898 -148 960 -149 907 -150 852 -151 912 -152 859 -153 847 -154 876 -155 792 -156 791 -157 765 -158 788 -159 787 -160 744 -161 673 -162 683 -163 697 -164 666 -165 680 -166 632 -167 677 -168 657 -169 618 -170 587 -171 585 -172 567 -173 549 -174 562 -175 548 -176 542 -177 539 -178 542 -179 549 -180 547 -181 526 -182 525 -183 514 -184 512 -185 505 -186 515 -187 467 -188 475 -189 458 -190 435 -191 443 -192 427 -193 424 -194 404 -195 389 -196 429 -197 404 -198 386 -199 351 -200 388 -201 408 -202 361 -203 346 -204 324 -205 361 -206 363 -207 364 -208 323 -209 336 -210 342 -211 315 -212 325 -213 328 -214 314 -215 327 -216 320 -217 300 -218 295 -219 315 -220 310 -221 295 -222 275 -223 248 -224 274 -225 232 -226 293 -227 259 -228 286 -229 263 -230 242 -231 214 -232 261 -233 231 -234 211 -235 250 -236 233 -237 206 -238 224 -239 210 -240 233 -241 223 -242 216 -243 222 -244 207 -245 212 -246 196 -247 205 -248 201 -249 202 -250 211 -251 201 -252 215 -253 179 -254 163 -255 179 -256 191 -257 188 -258 196 -259 150 -260 154 -261 176 -262 211 -263 166 -264 171 -265 165 -266 149 -267 182 -268 159 -269 161 -270 164 -271 161 -272 141 -273 151 -274 127 -275 129 -276 142 -277 158 -278 148 -279 135 -280 127 -281 134 -282 138 -283 131 -284 126 -285 125 -286 130 -287 126 -288 135 -289 125 -290 135 -291 131 -292 95 -293 135 -294 106 -295 117 -296 136 -297 128 -298 128 -299 118 -300 109 -301 112 -302 117 -303 108 -304 120 -305 100 -306 95 -307 108 -308 112 -309 77 -310 120 -311 104 -312 109 -313 89 -314 98 -315 82 -316 98 -317 93 -318 77 -319 93 -320 77 -321 98 -322 93 -323 86 -324 89 -325 73 -326 70 -327 71 -328 77 -329 87 -330 77 -331 93 -332 100 -333 83 -334 72 -335 74 -336 69 -337 77 -338 68 -339 78 -340 90 -341 98 -342 75 -343 80 -344 63 -345 71 -346 83 -347 66 -348 71 -349 70 -350 62 -351 62 -352 59 -353 63 -354 62 -355 52 -356 64 -357 64 -358 56 -359 49 -360 57 -361 63 -362 60 -363 68 -364 62 -365 55 -366 54 -367 40 -368 75 -369 70 -370 53 -371 58 -372 57 -373 55 -374 69 -375 57 -376 53 -377 43 -378 45 -379 47 -380 56 -381 51 -382 59 -383 51 -384 43 -385 34 -386 57 -387 49 -388 39 -389 46 -390 48 -391 43 -392 40 -393 54 -394 50 -395 41 -396 43 -397 33 -398 27 -399 49 -400 44 -401 44 -402 38 -403 30 -404 32 -405 37 -406 39 -407 42 -408 53 -409 39 -410 34 -411 31 -412 32 -413 52 -414 27 -415 41 -416 34 -417 36 -418 50 -419 35 -420 32 -421 33 -422 45 -423 35 -424 40 -425 29 -426 41 -427 40 -428 39 -429 32 -430 31 -431 34 -432 29 -433 27 -434 26 -435 22 -436 34 -437 28 -438 30 -439 38 -440 35 -441 36 -442 36 -443 27 -444 24 -445 33 -446 31 -447 25 -448 33 -449 27 -450 32 -451 46 -452 31 -453 35 -454 35 -455 34 -456 26 -457 21 -458 25 -459 26 -460 24 -461 27 -462 33 -463 30 -464 35 -465 21 -466 32 -467 19 -468 27 -469 16 -470 28 -471 26 -472 27 -473 26 -474 25 -475 25 -476 27 -477 20 -478 28 -479 22 -480 23 -481 16 -482 25 -483 27 -484 19 -485 23 -486 19 -487 15 -488 15 -489 23 -490 24 -491 19 -492 20 -493 18 -494 17 -495 30 -496 28 -497 20 -498 29 -499 17 -500 19 -501 21 -502 15 -503 24 -504 15 -505 19 -506 25 -507 16 -508 23 -509 26 -510 21 -511 15 -512 12 -513 16 -514 18 -515 24 -516 26 -517 18 -518 8 -519 25 -520 14 -521 8 -522 24 -523 20 -524 18 -525 15 -526 13 -527 17 -528 18 -529 22 -530 21 -531 9 -532 16 -533 17 -534 13 -535 17 -536 15 -537 13 -538 20 -539 13 -540 19 -541 29 -542 10 -543 8 -544 18 -545 13 -546 9 -547 18 -548 10 -549 18 -550 18 -551 9 -552 9 -553 15 -554 13 -555 15 -556 14 -557 14 -558 18 -559 8 -560 13 -561 9 -562 7 -563 12 -564 6 -565 9 -566 9 -567 18 -568 9 -569 10 -570 13 -571 14 -572 13 -573 21 -574 8 -575 16 -576 12 -577 9 -578 16 -579 17 -580 22 -581 6 -582 14 -583 13 -584 15 -585 11 -586 13 -587 5 -588 12 -589 13 -590 15 -591 13 -592 15 -593 12 -594 7 -595 18 -596 12 -597 13 -598 13 -599 13 -600 12 -601 12 -602 10 -603 11 -604 6 -605 6 -606 2 -607 9 -608 8 -609 12 -610 9 -611 12 -612 13 -613 12 -614 14 -615 9 -616 8 -617 9 -618 14 -619 13 -620 12 -621 6 -622 8 -623 8 -624 8 -625 12 -626 8 -627 7 -628 5 -629 8 -630 12 -631 6 -632 10 -633 10 -634 7 -635 8 -636 9 -637 6 -638 9 -639 4 -640 12 -641 4 -642 3 -643 11 -644 10 -645 6 -646 12 -647 12 -648 4 -649 4 -650 9 -651 8 -652 6 -653 5 -654 14 -655 10 -656 11 -657 8 -658 5 -659 5 -660 9 -661 13 -662 4 -663 5 -664 9 -665 11 -666 12 -667 7 -668 13 -669 2 -670 1 -671 7 -672 7 -673 7 -674 10 -675 9 -676 6 -677 5 -678 7 -679 6 -680 3 -681 3 -682 4 -683 9 -684 8 -685 5 -686 3 -687 11 -688 9 -689 2 -690 6 -691 5 -692 9 -693 5 -694 6 -695 5 -696 9 -697 8 -698 3 -699 7 -700 5 -701 9 -702 8 -703 7 -704 2 -705 3 -706 7 -707 6 -708 6 -709 10 -710 2 -711 10 -712 6 -713 7 -714 5 -715 6 -716 4 -717 6 -718 8 -719 4 -720 6 -721 7 -722 5 -723 7 -724 3 -725 10 -726 10 -727 3 -728 7 -729 7 -730 5 -731 2 -732 1 -733 5 -734 1 -735 5 -736 6 -737 2 -738 2 -739 3 -740 7 -741 2 -742 7 -743 4 -744 5 -745 4 -746 5 -747 3 -748 1 -749 4 -750 4 -751 2 -752 4 -753 6 -754 6 -755 6 -756 3 -757 2 -758 5 -759 5 -760 3 -761 4 -762 2 -763 1 -764 8 -765 3 -766 4 -767 3 -768 1 -769 5 -770 3 -771 3 -772 4 -773 4 -774 1 -775 3 -776 2 -777 2 -778 3 -779 3 -780 1 -781 4 -782 3 -783 4 -784 6 -785 3 -786 5 -787 4 -788 2 -789 4 -790 5 -791 4 -792 6 -794 4 -795 1 -796 1 -797 4 -798 2 -799 3 -800 3 -801 1 -802 5 -803 5 -804 3 -805 3 -806 3 -807 4 -808 4 -809 2 -811 5 -812 4 -813 6 -814 3 -815 2 -816 2 -817 3 -818 5 -819 3 -820 1 -821 1 -822 4 -823 3 -824 4 -825 8 -826 3 -827 5 -828 5 -829 3 -830 6 -831 3 -832 4 -833 8 -834 5 -835 3 -836 3 -837 2 -838 4 -839 2 -840 1 -841 3 -842 2 -843 1 -844 3 -846 4 -847 4 -848 3 -849 3 -850 2 -851 3 -853 1 -854 4 -855 4 -856 2 -857 4 -858 1 -859 2 -860 5 -861 1 -862 1 -863 4 -864 2 -865 2 -867 5 -868 1 -869 4 -870 1 -871 1 -872 1 -873 2 -875 5 -876 3 -877 1 -878 3 -879 3 -880 3 -881 2 -882 1 -883 6 -884 2 -885 2 -886 1 -887 1 -888 3 -889 2 -890 2 -891 3 -892 1 -893 3 -894 1 -895 5 -896 1 -897 3 -899 2 -900 2 -902 1 -903 2 -904 4 -905 4 -906 3 -907 1 -908 1 -909 2 -910 5 -911 2 -912 3 -914 1 -915 1 -916 2 -918 2 -919 2 -920 4 -921 4 -922 1 -923 1 -924 4 -925 5 -926 1 -928 2 -929 1 -930 1 -931 1 -932 1 -933 1 -934 2 -935 1 -936 1 -937 1 -938 2 -939 1 -941 1 -942 4 -944 2 -945 2 -946 2 -947 1 -948 1 -950 1 -951 2 -953 1 -954 2 -955 1 -956 1 -957 2 -958 1 -960 3 -962 4 -963 1 -964 1 -965 3 -966 2 -967 2 -968 1 -969 3 -970 3 -972 1 -974 4 -975 3 -976 3 -977 2 -979 2 -980 1 -981 1 -983 5 -984 1 -985 3 -986 1 -987 2 -988 4 -989 2 -991 2 -992 2 -993 1 -994 1 -996 2 -997 2 -998 1 -999 3 -1000 2 -1001 1 -1002 3 -1003 3 -1004 2 -1005 3 -1006 1 -1007 2 -1009 1 -1011 1 -1013 3 -1014 1 -1016 2 -1017 1 -1018 1 -1019 1 -1020 4 -1021 1 -1022 2 -1025 1 -1026 1 -1027 2 -1028 1 -1030 1 -1031 2 -1032 4 -1034 3 -1035 2 -1036 1 -1038 1 -1039 1 -1040 1 -1041 1 -1042 2 -1043 1 -1044 2 -1045 4 -1048 1 -1050 1 -1051 1 -1052 2 -1054 1 -1055 3 -1056 2 -1057 1 -1059 1 -1061 2 -1063 1 -1064 1 -1065 1 -1066 1 -1067 1 -1068 1 -1069 2 -1074 1 -1075 1 -1077 1 -1078 1 -1079 1 -1082 1 -1085 1 -1088 1 -1090 1 -1091 1 -1092 2 -1094 2 -1097 2 -1098 1 -1099 2 -1101 2 -1102 1 -1104 1 -1105 1 -1107 1 -1109 1 -1111 2 -1112 1 -1114 2 -1115 2 -1116 2 -1117 1 -1118 1 -1119 1 -1120 1 -1122 1 -1123 1 -1127 1 -1128 3 -1132 2 -1138 3 -1142 1 -1145 4 -1150 1 -1153 2 -1154 1 -1158 1 -1159 1 -1163 1 -1165 1 -1169 2 -1174 1 -1176 1 -1177 1 -1178 2 -1179 1 -1180 2 -1181 1 -1182 1 -1183 2 -1185 1 -1187 1 -1191 2 -1193 1 -1195 3 -1196 1 -1201 3 -1203 1 -1206 1 -1210 1 -1213 1 -1214 1 -1215 2 -1218 1 -1220 1 -1221 1 -1225 1 -1226 1 -1233 2 -1241 1 -1243 1 -1249 1 -1250 2 -1251 1 -1254 1 -1255 2 -1260 1 -1268 1 -1270 1 -1273 1 -1274 1 -1277 1 -1284 1 -1287 1 -1291 1 -1292 2 -1294 1 -1295 2 -1297 1 -1298 1 -1301 1 -1307 1 -1308 3 -1311 2 -1313 1 -1316 1 -1321 1 -1324 1 -1325 1 -1330 1 -1333 1 -1334 1 -1338 2 -1340 1 -1341 1 -1342 1 -1343 1 -1345 1 -1355 1 -1357 1 -1360 2 -1375 1 -1376 1 -1380 1 -1383 1 -1387 1 -1389 1 -1393 1 -1394 1 -1396 1 -1398 1 -1410 1 -1414 1 -1419 1 -1425 1 -1434 1 -1435 1 -1438 1 -1439 1 -1447 1 -1455 2 -1460 1 -1461 1 -1463 1 -1466 1 -1470 1 -1473 1 -1478 1 -1480 1 -1483 1 -1484 1 -1485 2 -1492 2 -1499 1 -1509 1 -1512 1 -1513 1 -1523 1 -1524 1 -1525 2 -1529 1 -1539 1 -1544 1 -1568 1 -1584 1 -1591 1 -1598 1 -1600 1 -1604 1 -1614 1 -1617 1 -1621 1 -1622 1 -1626 1 -1638 1 -1648 1 -1658 1 -1661 1 -1679 1 -1682 1 -1693 1 -1700 1 -1705 1 -1707 1 -1722 1 -1728 1 -1758 1 -1762 1 -1763 1 -1775 1 -1776 1 -1801 1 -1810 1 -1812 1 -1827 1 -1834 1 -1846 1 -1847 1 -1848 1 -1851 1 -1862 1 -1866 1 -1877 2 -1884 1 -1888 1 -1903 1 -1912 1 -1925 1 -1938 1 -1955 1 -1998 1 -2054 1 -2058 1 -2065 1 -2069 1 -2076 1 -2089 1 -2104 1 -2111 1 -2133 1 -2138 1 -2156 1 -2204 1 -2212 1 -2237 1 -2246 2 -2298 1 -2304 1 -2360 1 -2400 1 -2481 1 -2544 1 -2586 1 -2622 1 -2666 1 -2682 1 -2725 1 -2920 1 -3997 1 -4019 1 -5211 1 -12 19 -14 1 -16 401 -18 2 -20 421 -22 557 -24 625 -26 50 -28 4481 -30 52 -32 550 -34 5840 -36 4644 -38 87 -40 5794 -41 33 -42 571 -44 11805 -46 4711 -47 7 -48 597 -49 12 -50 678 -51 2 -52 14715 -53 3 -54 7322 -55 3 -56 508 -57 39 -58 3486 -59 11 -60 8974 -61 45 -62 1276 -63 4 -64 15693 -65 15 -66 657 -67 13 -68 6409 -69 10 -70 3188 -71 25 -72 1889 -73 27 -74 10370 -75 9 -76 12432 -77 23 -78 520 -79 15 -80 1534 -81 29 -82 2944 -83 23 -84 12071 -85 36 -86 1502 -87 10 -88 10978 -89 11 -90 889 -91 16 -92 4571 -93 17 -94 7855 -95 21 -96 2271 -97 33 -98 1423 -99 15 -100 11096 -101 21 -102 4082 -103 13 -104 5442 -105 25 -106 2113 -107 26 -108 3779 -109 43 -110 1294 -111 29 -112 7860 -113 29 -114 4965 -115 22 -116 7898 -117 25 -118 1772 -119 28 -120 1149 -121 38 -122 1483 -123 32 -124 10572 -125 25 -126 1147 -127 31 -128 1699 -129 22 -130 5533 -131 22 -132 4669 -133 34 -134 3777 -135 10 -136 5412 -137 21 -138 855 -139 26 -140 2485 -141 46 -142 1970 -143 27 -144 6565 -145 40 -146 933 -147 15 -148 7923 -149 16 -150 735 -151 23 -152 1111 -153 33 -154 3714 -155 27 -156 2445 -157 30 -158 3367 -159 10 -160 4646 -161 27 -162 990 -163 23 -164 5679 -165 25 -166 2186 -167 17 -168 899 -169 32 -170 1034 -171 22 -172 6185 -173 32 -174 2685 -175 17 -176 1354 -177 38 -178 1460 -179 15 -180 3478 -181 20 -182 958 -183 20 -184 6055 -185 23 -186 2180 -187 15 -188 1416 -189 30 -190 1284 -191 22 -192 1341 -193 21 -194 2413 -195 18 -196 4984 -197 13 -198 830 -199 22 -200 1834 -201 19 -202 2238 -203 9 -204 3050 -205 22 -206 616 -207 17 -208 2892 -209 22 -210 711 -211 30 -212 2631 -213 19 -214 3341 -215 21 -216 987 -217 26 -218 823 -219 9 -220 3588 -221 20 -222 692 -223 7 -224 2925 -225 31 -226 1075 -227 16 -228 2909 -229 18 -230 673 -231 20 -232 2215 -233 14 -234 1584 -235 21 -236 1292 -237 29 -238 1647 -239 25 -240 1014 -241 30 -242 1648 -243 19 -244 4465 -245 10 -246 787 -247 11 -248 480 -249 25 -250 842 -251 15 -252 1219 -253 23 -254 1508 -255 8 -256 3525 -257 16 -258 490 -259 12 -260 1678 -261 14 -262 822 -263 16 -264 1729 -265 28 -266 604 -267 11 -268 2572 -269 7 -270 1242 -271 15 -272 725 -273 18 -274 1983 -275 13 -276 1662 -277 19 -278 491 -279 12 -280 1586 -281 14 -282 563 -283 10 -284 2363 -285 10 -286 656 -287 14 -288 725 -289 28 -290 871 -291 9 -292 2606 -293 12 -294 961 -295 9 -296 478 -297 13 -298 1252 -299 10 -300 736 -301 19 -302 466 -303 13 -304 2254 -305 12 -306 486 -307 14 -308 1145 -309 13 -310 955 -311 13 -312 1235 -313 13 -314 931 -315 14 -316 1768 -317 11 -318 330 -319 10 -320 539 -321 23 -322 570 -323 12 -324 1789 -325 13 -326 884 -327 5 -328 1422 -329 14 -330 317 -331 11 -332 509 -333 13 -334 1062 -335 12 -336 577 -337 27 -338 378 -339 10 -340 2313 -341 9 -342 391 -343 13 -344 894 -345 17 -346 664 -347 9 -348 453 -349 6 -350 363 -351 15 -352 1115 -353 13 -354 1054 -355 8 -356 1108 -357 12 -358 354 -359 7 -360 363 -361 16 -362 344 -363 11 -364 1734 -365 12 -366 265 -367 10 -368 969 -369 16 -370 316 -371 12 -372 757 -373 7 -374 563 -375 15 -376 857 -377 9 -378 469 -379 9 -380 385 -381 12 -382 921 -383 15 -384 764 -385 14 -386 246 -387 6 -388 1108 -389 14 -390 230 -391 8 -392 266 -393 11 -394 641 -395 8 -396 719 -397 9 -398 243 -399 4 -400 1108 -401 7 -402 229 -403 7 -404 903 -405 7 -406 257 -407 12 -408 244 -409 3 -410 541 -411 6 -412 744 -413 8 -414 419 -415 8 -416 388 -417 19 -418 470 -419 14 -420 612 -421 6 -422 342 -423 3 -424 1179 -425 3 -426 116 -427 14 -428 207 -429 6 -430 255 -431 4 -432 288 -433 12 -434 343 -435 6 -436 1015 -437 3 -438 538 -439 10 -440 194 -441 6 -442 188 -443 15 -444 524 -445 7 -446 214 -447 7 -448 574 -449 6 -450 214 -451 5 -452 635 -453 9 -454 464 -455 5 -456 205 -457 9 -458 163 -459 2 -460 558 -461 4 -462 171 -463 14 -464 444 -465 11 -466 543 -467 5 -468 388 -469 6 -470 141 -471 4 -472 647 -473 3 -474 210 -475 4 -476 193 -477 7 -478 195 -479 7 -480 443 -481 10 -482 198 -483 3 -484 816 -485 6 -486 128 -487 9 -488 215 -489 9 -490 328 -491 7 -492 158 -493 11 -494 335 -495 8 -496 435 -497 6 -498 174 -499 1 -500 373 -501 5 -502 140 -503 7 -504 330 -505 9 -506 149 -507 5 -508 642 -509 3 -510 179 -511 3 -512 159 -513 8 -514 204 -515 7 -516 306 -517 4 -518 110 -519 5 -520 326 -521 6 -522 305 -523 6 -524 294 -525 7 -526 268 -527 5 -528 149 -529 4 -530 133 -531 2 -532 513 -533 10 -534 116 -535 5 -536 258 -537 4 -538 113 -539 4 -540 138 -541 6 -542 116 -544 485 -545 4 -546 93 -547 9 -548 299 -549 3 -550 256 -551 6 -552 92 -553 3 -554 175 -555 6 -556 253 -557 7 -558 95 -559 2 -560 128 -561 4 -562 206 -563 2 -564 465 -565 3 -566 69 -567 3 -568 157 -569 7 -570 97 -571 8 -572 118 -573 5 -574 130 -575 4 -576 301 -577 6 -578 177 -579 2 -580 397 -581 3 -582 80 -583 1 -584 128 -585 5 -586 52 -587 2 -588 72 -589 1 -590 84 -591 6 -592 323 -593 11 -594 77 -595 5 -596 205 -597 1 -598 244 -599 4 -600 69 -601 3 -602 89 -603 5 -604 254 -605 6 -606 147 -607 3 -608 83 -609 3 -610 77 -611 3 -612 194 -613 1 -614 98 -615 3 -616 243 -617 3 -618 50 -619 8 -620 188 -621 4 -622 67 -623 4 -624 123 -625 2 -626 50 -627 1 -628 239 -629 2 -630 51 -631 4 -632 65 -633 5 -634 188 -636 81 -637 3 -638 46 -639 3 -640 103 -641 1 -642 136 -643 3 -644 188 -645 3 -646 58 -648 122 -649 4 -650 47 -651 2 -652 155 -653 4 -654 71 -655 1 -656 71 -657 3 -658 50 -659 2 -660 177 -661 5 -662 66 -663 2 -664 183 -665 3 -666 50 -667 2 -668 53 -669 2 -670 115 -672 66 -673 2 -674 47 -675 1 -676 197 -677 2 -678 46 -679 3 -680 95 -681 3 -682 46 -683 3 -684 107 -685 1 -686 86 -687 2 -688 158 -689 4 -690 51 -691 1 -692 80 -694 56 -695 4 -696 40 -698 43 -699 3 -700 95 -701 2 -702 51 -703 2 -704 133 -705 1 -706 100 -707 2 -708 121 -709 2 -710 15 -711 3 -712 35 -713 2 -714 20 -715 3 -716 37 -717 2 -718 78 -720 55 -721 1 -722 42 -723 2 -724 218 -725 3 -726 23 -727 2 -728 26 -729 1 -730 64 -731 2 -732 65 -734 24 -735 2 -736 53 -737 1 -738 32 -739 1 -740 60 -742 81 -743 1 -744 77 -745 1 -746 47 -747 1 -748 62 -749 1 -750 19 -751 1 -752 86 -753 3 -754 40 -756 55 -757 2 -758 38 -759 1 -760 101 -761 1 -762 22 -764 67 -765 2 -766 35 -767 1 -768 38 -769 1 -770 22 -771 1 -772 82 -773 1 -774 73 -776 29 -777 1 -778 55 -780 23 -781 1 -782 16 -784 84 -785 3 -786 28 -788 59 -789 1 -790 33 -791 3 -792 24 -794 13 -795 1 -796 110 -797 2 -798 15 -800 22 -801 3 -802 29 -803 1 -804 87 -806 21 -808 29 -810 48 -812 28 -813 1 -814 58 -815 1 -816 48 -817 1 -818 31 -819 1 -820 66 -822 17 -823 2 -824 58 -826 10 -827 2 -828 25 -829 1 -830 29 -831 1 -832 63 -833 1 -834 26 -835 3 -836 52 -837 1 -838 18 -840 27 -841 2 -842 12 -843 1 -844 83 -845 1 -846 7 -847 1 -848 10 -850 26 -852 25 -853 1 -854 15 -856 27 -858 32 -859 1 -860 15 -862 43 -864 32 -865 1 -866 6 -868 39 -870 11 -872 25 -873 1 -874 10 -875 1 -876 20 -877 2 -878 19 -879 1 -880 30 -882 11 -884 53 -886 25 -887 1 -888 28 -890 6 -892 36 -894 10 -896 13 -898 14 -900 31 -902 14 -903 2 -904 43 -906 25 -908 9 -910 11 -911 1 -912 16 -913 1 -914 24 -916 27 -918 6 -920 15 -922 27 -923 1 -924 23 -926 13 -928 42 -929 1 -930 3 -932 27 -934 17 -936 8 -937 1 -938 11 -940 33 -942 4 -943 1 -944 18 -946 15 -948 13 -950 18 -952 12 -954 11 -956 21 -958 10 -960 13 -962 5 -964 32 -966 13 -968 8 -970 8 -971 1 -972 23 -973 2 -974 12 -975 1 -976 22 -978 7 -979 1 -980 14 -982 8 -984 22 -985 1 -986 6 -988 17 -989 1 -990 6 -992 13 -994 19 -996 11 -998 4 -1000 9 -1002 2 -1004 14 -1006 5 -1008 3 -1010 9 -1012 29 -1014 6 -1016 22 -1017 1 -1018 8 -1019 1 -1020 7 -1022 6 -1023 1 -1024 10 -1026 2 -1028 8 -1030 11 -1031 2 -1032 8 -1034 9 -1036 13 -1038 12 -1040 12 -1042 3 -1044 12 -1046 3 -1048 11 -1050 2 -1051 1 -1052 2 -1054 11 -1056 6 -1058 8 -1059 1 -1060 23 -1062 6 -1063 1 -1064 8 -1066 3 -1068 6 -1070 8 -1071 1 -1072 5 -1074 3 -1076 5 -1078 3 -1080 11 -1081 1 -1082 7 -1084 18 -1086 4 -1087 1 -1088 3 -1090 3 -1092 7 -1094 3 -1096 12 -1098 6 -1099 1 -1100 2 -1102 6 -1104 14 -1106 3 -1108 6 -1110 5 -1112 2 -1114 8 -1116 3 -1118 3 -1120 7 -1122 10 -1124 6 -1126 8 -1128 1 -1130 4 -1132 3 -1134 2 -1136 5 -1138 5 -1140 8 -1142 3 -1144 7 -1146 3 -1148 11 -1150 1 -1152 5 -1154 1 -1156 5 -1158 1 -1160 5 -1162 3 -1164 6 -1165 1 -1166 1 -1168 4 -1169 1 -1170 3 -1171 1 -1172 2 -1174 5 -1176 3 -1177 1 -1180 8 -1182 2 -1184 4 -1186 2 -1188 3 -1190 2 -1192 5 -1194 6 -1196 1 -1198 2 -1200 2 -1204 10 -1206 2 -1208 9 -1210 1 -1214 6 -1216 3 -1218 4 -1220 9 -1221 2 -1222 1 -1224 5 -1226 4 -1228 8 -1230 1 -1232 1 -1234 3 -1236 5 -1240 3 -1242 1 -1244 3 -1245 1 -1246 4 -1248 6 -1250 2 -1252 7 -1256 3 -1258 2 -1260 2 -1262 3 -1264 4 -1265 1 -1266 1 -1270 1 -1271 1 -1272 2 -1274 3 -1276 3 -1278 1 -1280 3 -1284 1 -1286 1 -1290 1 -1292 3 -1294 1 -1296 7 -1300 2 -1302 4 -1304 3 -1306 2 -1308 2 -1312 1 -1314 1 -1316 3 -1318 2 -1320 1 -1324 8 -1326 1 -1330 1 -1331 1 -1336 2 -1338 1 -1340 3 -1341 1 -1344 1 -1346 2 -1347 1 -1348 3 -1352 1 -1354 2 -1356 1 -1358 1 -1360 3 -1362 1 -1364 4 -1366 1 -1370 1 -1372 3 -1380 2 -1384 2 -1388 2 -1390 2 -1392 2 -1394 1 -1396 1 -1398 1 -1400 2 -1402 1 -1404 1 -1406 1 -1410 1 -1412 5 -1418 1 -1420 1 -1424 1 -1432 2 -1434 2 -1442 3 -1444 5 -1448 1 -1454 1 -1456 1 -1460 3 -1462 4 -1468 1 -1474 1 -1476 1 -1478 2 -1480 1 -1486 2 -1488 1 -1492 1 -1496 1 -1500 3 -1503 1 -1506 1 -1512 2 -1516 1 -1522 1 -1524 2 -1534 4 -1536 1 -1538 1 -1540 2 -1544 2 -1548 1 -1556 1 -1560 1 -1562 1 -1564 2 -1566 1 -1568 1 -1570 1 -1572 1 -1576 1 -1590 1 -1594 1 -1604 1 -1608 1 -1614 1 -1622 1 -1624 2 -1628 1 -1629 1 -1636 1 -1642 1 -1654 2 -1660 1 -1664 1 -1670 1 -1684 4 -1698 1 -1732 3 -1742 1 -1752 1 -1760 1 -1764 1 -1772 2 -1798 1 -1808 1 -1820 1 -1852 1 -1856 1 -1874 1 -1902 1 -1908 1 -1952 1 -2004 1 -2018 1 -2020 1 -2028 1 -2174 1 -2233 1 -2244 1 -2280 1 -2290 1 -2352 1 -2604 1 -4190 1 + rowspan="10" \ No newline at end of file diff --git a/ppstructure/table/predict_structure.py b/ppstructure/table/predict_structure.py index 0179c614ae4864677576f6073f291282fb772988..17ec909582a0d8ae70829730da23c7580104eb68 100755 --- a/ppstructure/table/predict_structure.py +++ b/ppstructure/table/predict_structure.py @@ -23,6 +23,7 @@ os.environ["FLAGS_allocator_strategy"] = 'auto_growth' import cv2 import numpy as np import time +import json import tools.infer.utility as utility from ppocr.data import create_operators, transform @@ -34,32 +35,50 @@ from ppstructure.utility import parse_args logger = get_logger() +def build_pre_process_list(args): + resize_op = {'ResizeTableImage': {'max_len': args.table_max_len, }} + pad_op = { + 'PaddingTableImage': { + 'size': [args.table_max_len, args.table_max_len] + } + } + normalize_op = { + 'NormalizeImage': { + 'std': [0.229, 0.224, 0.225] if + args.table_algorithm not in ['TableMaster'] else [0.5, 0.5, 0.5], + 'mean': [0.485, 0.456, 0.406] if + args.table_algorithm not in ['TableMaster'] else [0.5, 0.5, 0.5], + 'scale': '1./255.', + 'order': 'hwc' + } + } + to_chw_op = {'ToCHWImage': None} + keep_keys_op = {'KeepKeys': {'keep_keys': ['image', 'shape']}} + if args.table_algorithm not in ['TableMaster']: + pre_process_list = [ + resize_op, normalize_op, pad_op, to_chw_op, keep_keys_op + ] + else: + pre_process_list = [ + resize_op, pad_op, normalize_op, to_chw_op, keep_keys_op + ] + return pre_process_list + + class TableStructurer(object): def __init__(self, args): - pre_process_list = [{ - 'ResizeTableImage': { - 'max_len': args.table_max_len + pre_process_list = build_pre_process_list(args) + if args.table_algorithm not in ['TableMaster']: + postprocess_params = { + 'name': 'TableLabelDecode', + "character_dict_path": args.table_char_dict_path, } - }, { - 'NormalizeImage': { - 'std': [0.229, 0.224, 0.225], - 'mean': [0.485, 0.456, 0.406], - 'scale': '1./255.', - 'order': 'hwc' + else: + postprocess_params = { + 'name': 'TableMasterLabelDecode', + "character_dict_path": args.table_char_dict_path, + 'box_shape': 'pad' } - }, { - 'PaddingTableImage': None - }, { - 'ToCHWImage': None - }, { - 'KeepKeys': { - 'keep_keys': ['image'] - } - }] - postprocess_params = { - 'name': 'TableLabelDecode', - "character_dict_path": args.table_char_dict_path, - } self.preprocess_op = create_operators(pre_process_list) self.postprocess_op = build_post_process(postprocess_params) @@ -88,27 +107,30 @@ class TableStructurer(object): preds['structure_probs'] = outputs[1] preds['loc_preds'] = outputs[0] - post_result = self.postprocess_op(preds) - - structure_str_list = post_result['structure_str_list'] - res_loc = post_result['res_loc'] - imgh, imgw = ori_im.shape[0:2] - res_loc_final = [] - for rno in range(len(res_loc[0])): - x0, y0, x1, y1 = res_loc[0][rno] - left = max(int(imgw * x0), 0) - top = max(int(imgh * y0), 0) - right = min(int(imgw * x1), imgw - 1) - bottom = min(int(imgh * y1), imgh - 1) - res_loc_final.append([left, top, right, bottom]) - - structure_str_list = structure_str_list[0][:-1] + shape_list = np.expand_dims(data[-1], axis=0) + post_result = self.postprocess_op(preds, [shape_list]) + + structure_str_list = post_result['structure_batch_list'][0] + bbox_list = post_result['bbox_batch_list'][0] + structure_str_list = structure_str_list[0] structure_str_list = [ '', '', '' ] + structure_str_list + ['
', '', ''] - elapse = time.time() - starttime - return (structure_str_list, res_loc_final), elapse + return structure_str_list, bbox_list, elapse + + +def draw_rectangle(img_path, boxes, use_xywh=False): + img = cv2.imread(img_path) + img_show = img.copy() + for box in boxes.astype(int): + if use_xywh: + x, y, w, h = box + x1, y1, x2, y2 = x - w // 2, y - h // 2, x + w // 2, y + h // 2 + else: + x1, y1, x2, y2 = box + cv2.rectangle(img_show, (x1, y1), (x2, y2), (255, 0, 0), 2) + return img_show def main(args): @@ -116,21 +138,35 @@ def main(args): table_structurer = TableStructurer(args) count = 0 total_time = 0 - for image_file in image_file_list: - img, flag = check_and_read_gif(image_file) - if not flag: - img = cv2.imread(image_file) - if img is None: - logger.info("error in loading image:{}".format(image_file)) - continue - structure_res, elapse = table_structurer(img) - - logger.info("result: {}".format(structure_res)) - - if count > 0: - total_time += elapse - count += 1 - logger.info("Predict time of {}: {}".format(image_file, elapse)) + use_xywh = args.table_algorithm in ['TableMaster'] + os.makedirs(args.output, exist_ok=True) + with open( + os.path.join(args.output, 'infer.txt'), mode='w', + encoding='utf-8') as f_w: + for image_file in image_file_list: + img, flag = check_and_read_gif(image_file) + if not flag: + img = cv2.imread(image_file) + if img is None: + logger.info("error in loading image:{}".format(image_file)) + continue + structure_str_list, bbox_list, elapse = table_structurer(img) + + bbox_list_str = json.dumps(bbox_list.tolist()) + logger.info("result: {}, {}".format(structure_str_list, + bbox_list_str)) + f_w.write("result: {}, {}\n".format(structure_str_list, + bbox_list_str)) + + img = draw_rectangle(image_file, bbox_list, use_xywh) + img_save_path = os.path.join(args.output, + os.path.basename(image_file)) + cv2.imwrite(img_save_path, img) + logger.info("save vis result to {}".format(img_save_path)) + if count > 0: + total_time += elapse + count += 1 + logger.info("Predict time of {}: {}".format(image_file, elapse)) if __name__ == "__main__": diff --git a/ppstructure/utility.py b/ppstructure/utility.py index 1ad902e7e6be95a6901e3774420fad337f594861..05452c23b53356991d1684f0ed0f63649447e915 100644 --- a/ppstructure/utility.py +++ b/ppstructure/utility.py @@ -25,6 +25,7 @@ def init_args(): parser.add_argument("--output", type=str, default='./output') # params for table structure parser.add_argument("--table_max_len", type=int, default=488) + parser.add_argument("--table_algorithm", type=str, default='TableAttn') parser.add_argument("--table_model_dir", type=str) parser.add_argument( "--table_char_dict_path", @@ -65,7 +66,7 @@ def init_args(): "--recovery", type=bool, default=False, - help='Whether to enable layout of recovery') + help='Whether to enable layout of recovery') return parser diff --git a/tools/export_model.py b/tools/export_model.py index e971f6cb20025d529d0387d287ec87a76abbdbe7..fbb2201e39906660ac200350751f684091117f30 100755 --- a/tools/export_model.py +++ b/tools/export_model.py @@ -88,6 +88,8 @@ def export_single_model(model, arch_config, save_path, logger, quanter=None): infer_shape = [1, 32, 100] elif arch_config["model_type"] == "table": infer_shape = [3, 488, 488] + if arch_config["algorithm"] == "TableMaster": + infer_shape = [3, 480, 480] model = to_static( model, input_spec=[ diff --git a/tools/infer_table.py b/tools/infer_table.py index 66c2da4421a313c634d27eb7a1013638a7c005ed..58e7455cbb7feb0d87d72238aba52c72abc6f87b 100644 --- a/tools/infer_table.py +++ b/tools/infer_table.py @@ -40,6 +40,7 @@ import tools.program as program import cv2 +@paddle.no_grad() def main(config, device, logger, vdl_writer): global_config = config['Global'] @@ -53,53 +54,74 @@ def main(config, device, logger, vdl_writer): getattr(post_process_class, 'character')) model = build_model(config['Architecture']) + algorithm = config['Architecture']['algorithm'] + use_xywh = algorithm in ['TableMaster'] load_model(config, model) # create data ops transforms = [] - use_padding = False for op in config['Eval']['dataset']['transforms']: op_name = list(op)[0] - if 'Label' in op_name: + if 'Encode' in op_name: continue if op_name == 'KeepKeys': - op[op_name]['keep_keys'] = ['image'] - if op_name == "ResizeTableImage": - use_padding = True - padding_max_len = op['ResizeTableImage']['max_len'] + op[op_name]['keep_keys'] = ['image', 'shape'] transforms.append(op) global_config['infer_mode'] = True ops = create_operators(transforms, global_config) + save_res_path = config['Global']['save_res_path'] + os.makedirs(save_res_path, exist_ok=True) + model.eval() - for file in get_image_file_list(config['Global']['infer_img']): - logger.info("infer_img: {}".format(file)) - with open(file, 'rb') as f: - img = f.read() - data = {'image': img} - batch = transform(data, ops) - images = np.expand_dims(batch[0], axis=0) - images = paddle.to_tensor(images) - preds = model(images) - post_result = post_process_class(preds) - res_html_code = post_result['res_html_code'] - res_loc = post_result['res_loc'] - img = cv2.imread(file) - imgh, imgw = img.shape[0:2] - res_loc_final = [] - for rno in range(len(res_loc[0])): - x0, y0, x1, y1 = res_loc[0][rno] - left = max(int(imgw * x0), 0) - top = max(int(imgh * y0), 0) - right = min(int(imgw * x1), imgw - 1) - bottom = min(int(imgh * y1), imgh - 1) - cv2.rectangle(img, (left, top), (right, bottom), (0, 0, 255), 2) - res_loc_final.append([left, top, right, bottom]) - res_loc_str = json.dumps(res_loc_final) - logger.info("result: {}, {}".format(res_html_code, res_loc_final)) - logger.info("success!") + with open( + os.path.join(save_res_path, 'infer.txt'), mode='w', + encoding='utf-8') as f_w: + for file in get_image_file_list(config['Global']['infer_img']): + logger.info("infer_img: {}".format(file)) + with open(file, 'rb') as f: + img = f.read() + data = {'image': img} + batch = transform(data, ops) + images = np.expand_dims(batch[0], axis=0) + shape_list = np.expand_dims(batch[1], axis=0) + + images = paddle.to_tensor(images) + preds = model(images) + post_result = post_process_class(preds, [shape_list]) + + structure_str_list = post_result['structure_batch_list'][0] + bbox_list = post_result['bbox_batch_list'][0] + structure_str_list = structure_str_list[0] + structure_str_list = [ + '', '', '' + ] + structure_str_list + ['
', '', ''] + bbox_list_str = json.dumps(bbox_list.tolist()) + + logger.info("result: {}, {}".format(structure_str_list, + bbox_list_str)) + f_w.write("result: {}, {}\n".format(structure_str_list, + bbox_list_str)) + + img = draw_rectangle(file, bbox_list, use_xywh) + cv2.imwrite( + os.path.join(save_res_path, os.path.basename(file)), img) + logger.info("success!") + + +def draw_rectangle(img_path, boxes, use_xywh=False): + img = cv2.imread(img_path) + img_show = img.copy() + for box in boxes.astype(int): + if use_xywh: + x, y, w, h = box + x1, y1, x2, y2 = x - w // 2, y - h // 2, x + w // 2, y + h // 2 + else: + x1, y1, x2, y2 = box + cv2.rectangle(img_show, (x1, y1), (x2, y2), (255, 0, 0), 2) + return img_show if __name__ == '__main__': diff --git a/tools/program.py b/tools/program.py index 7c02dc0149f36085ef05ca378b79d27e92d6dd57..17079cb86e7762663a76951379fb8d7804b19f9e 100755 --- a/tools/program.py +++ b/tools/program.py @@ -274,8 +274,11 @@ def train(config, if cal_metric_during_train and epoch % calc_epoch_interval == 0: # only rec and cls need batch = [item.numpy() for item in batch] - if model_type in ['table', 'kie']: + if model_type in ['kie']: eval_class(preds, batch) + elif model_type in ['table']: + post_result = post_process_class(preds, batch) + eval_class(post_result, batch) else: if config['Loss']['name'] in ['MultiLoss', 'MultiLoss_v2' ]: # for multi head loss @@ -302,7 +305,8 @@ def train(config, train_stats.update(stats) if log_writer is not None and dist.get_rank() == 0: - log_writer.log_metrics(metrics=train_stats.get(), prefix="TRAIN", step=global_step) + log_writer.log_metrics( + metrics=train_stats.get(), prefix="TRAIN", step=global_step) if dist.get_rank() == 0 and ( (global_step > 0 and global_step % print_batch_step == 0) or @@ -349,7 +353,8 @@ def train(config, # logger metric if log_writer is not None: - log_writer.log_metrics(metrics=cur_metric, prefix="EVAL", step=global_step) + log_writer.log_metrics( + metrics=cur_metric, prefix="EVAL", step=global_step) if cur_metric[main_indicator] >= best_model_dict[ main_indicator]: @@ -372,11 +377,18 @@ def train(config, logger.info(best_str) # logger best metric if log_writer is not None: - log_writer.log_metrics(metrics={ - "best_{}".format(main_indicator): best_model_dict[main_indicator] - }, prefix="EVAL", step=global_step) - - log_writer.log_model(is_best=True, prefix="best_accuracy", metadata=best_model_dict) + log_writer.log_metrics( + metrics={ + "best_{}".format(main_indicator): + best_model_dict[main_indicator] + }, + prefix="EVAL", + step=global_step) + + log_writer.log_model( + is_best=True, + prefix="best_accuracy", + metadata=best_model_dict) reader_start = time.time() if dist.get_rank() == 0: @@ -408,7 +420,8 @@ def train(config, epoch=epoch, global_step=global_step) if log_writer is not None: - log_writer.log_model(is_best=False, prefix='iter_epoch_{}'.format(epoch)) + log_writer.log_model( + is_best=False, prefix='iter_epoch_{}'.format(epoch)) best_str = 'best metric, {}'.format(', '.join( ['{}: {}'.format(k, v) for k, v in best_model_dict.items()])) @@ -446,7 +459,6 @@ def eval(model, preds = model(batch) else: preds = model(images) - batch_numpy = [] for item in batch: if isinstance(item, paddle.Tensor): @@ -456,9 +468,9 @@ def eval(model, # Obtain usable results from post-processing methods total_time += time.time() - start # Evaluate the results of the current batch - if model_type in ['table', 'kie']: + if model_type in ['kie']: eval_class(preds, batch_numpy) - elif model_type in ['vqa']: + elif model_type in ['table', 'vqa']: post_result = post_process_class(preds, batch_numpy) eval_class(post_result, batch_numpy) else: @@ -559,7 +571,8 @@ def preprocess(is_train=False): assert alg in [ 'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN', 'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE', - 'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'PREN', 'FCE', 'SVTR' + 'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'PREN', 'FCE', 'SVTR', + 'TableMaster' ] device = 'cpu' @@ -578,7 +591,8 @@ def preprocess(is_train=False): vdl_writer_path = '{}/vdl/'.format(save_model_dir) log_writer = VDLLogger(save_model_dir) loggers.append(log_writer) - if ('use_wandb' in config['Global'] and config['Global']['use_wandb']) or 'wandb' in config: + if ('use_wandb' in config['Global'] and + config['Global']['use_wandb']) or 'wandb' in config: save_dir = config['Global']['save_model_dir'] wandb_writer_path = "{}/wandb".format(save_dir) if "wandb" in config: