From 16c247ac465af18ac8616f5b5d04646f9b99414f Mon Sep 17 00:00:00 2001 From: MissPenguin Date: Mon, 21 Jun 2021 12:20:25 +0000 Subject: [PATCH] refine --- configs/table/table_mv3.yml | 24 +++++++++--------- ppocr/data/imaug/label_ops.py | 20 --------------- ppocr/data/pubtab_dataset.py | 22 ++-------------- ppocr/modeling/architectures/base_model.py | 2 +- ppocr/modeling/heads/table_att_head.py | 22 ++++++++-------- ppocr/modeling/necks/table_fpn.py | 29 ++++++++-------------- ppocr/postprocess/rec_postprocess.py | 12 --------- tools/export_model.py | 3 ++- tools/infer_table.py | 4 +-- tools/program.py | 3 ++- 10 files changed, 40 insertions(+), 101 deletions(-) diff --git a/configs/table/table_mv3.yml b/configs/table/table_mv3.yml index 32164fe3..a74e18d3 100755 --- a/configs/table/table_mv3.yml +++ b/configs/table/table_mv3.yml @@ -1,13 +1,12 @@ Global: use_gpu: true - epoch_num: 40 + epoch_num: 50 log_smooth_window: 20 print_batch_step: 5 save_model_dir: ./output/table_mv3/ - save_epoch_step: 3 - # evaluation is run every 5000 iterations after the 4000th iteration + save_epoch_step: 5 + # evaluation is run every 400 iterations after the 0th iteration eval_batch_step: [0, 400] - # if pretrained_model is saved in static mode, load_static_weights must set to True cal_metric_during_train: True pretrained_model: checkpoints: @@ -18,19 +17,20 @@ Global: character_dict_path: ppocr/utils/dict/table_structure_dict.txt character_type: en max_text_length: 100 - max_elem_length: 800 + max_elem_length: 500 max_cell_num: 500 infer_mode: False process_total_num: 0 process_cut_num: 0 + Optimizer: name: Adam beta1: 0.9 beta2: 0.999 clip_norm: 5.0 lr: - learning_rate: 0.0001 + learning_rate: 0.001 regularizer: name: 'L2' factor: 0.00000 @@ -41,12 +41,12 @@ Architecture: Backbone: name: MobileNetV3 scale: 1.0 - model_name: large + model_name: small + disable_se: True Head: - name: TableAttentionHead # AttentionHead - hidden_size: 256 # + name: TableAttentionHead + hidden_size: 256 l2_decay: 0.00001 -# loc_type: 1 loc_type: 2 Loss: @@ -86,7 +86,7 @@ Train: shuffle: True batch_size_per_card: 32 drop_last: True - num_workers: 4 + num_workers: 1 Eval: dataset: @@ -113,4 +113,4 @@ Eval: shuffle: False drop_last: False batch_size_per_card: 16 - num_workers: 4 + num_workers: 1 diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py index cd883d1b..e25cce79 100644 --- a/ppocr/data/imaug/label_ops.py +++ b/ppocr/data/imaug/label_ops.py @@ -412,7 +412,6 @@ class TableLabelEncode(object): return None elem_num = len(structure) structure = [0] + structure + [len(self.dict_elem) - 1] -# structure = [0] + structure + [0] structure = structure + [0] * (self.max_elem_length + 2 - len(structure)) structure = np.array(structure) data['structure'] = structure @@ -443,8 +442,6 @@ class TableLabelEncode(object): if cand_span_idx < (self.max_elem_length + 2): if structure[cand_span_idx] in span_idx_list: structure_mask[cand_span_idx] = span_weight -# structure_mask[td_idx] = self.span_weight -# structure_mask[cand_span_idx] = self.span_weight data['bbox_list'] = bbox_list data['bbox_list_mask'] = bbox_list_mask @@ -458,23 +455,6 @@ class TableLabelEncode(object): self.max_elem_length, self.max_cell_num, elem_num]) return data - ######## - # for char decode -# cell_list = [] -# for cell in cells: -# char_list = cell['tokens'] -# cell = self.encode(char_list, 'char') -# if cell is None: -# return None -# cell = [0] + cell + [len(self.dict_character) - 1] -# cell = cell + [0] * (self.max_text_length + 2 - len(cell)) -# cell_list.append(cell) -# cell_list_padding = np.zeros((self.max_cell_num, self.max_text_length + 2)) -# cell_list = np.array(cell_list) -# cell_list_padding[0:cell_list.shape[0]] = cell_list -# data['cells'] = cell_list_padding -# return data - def encode(self, text, char_or_elem): """convert text-label into text-index. """ diff --git a/ppocr/data/pubtab_dataset.py b/ppocr/data/pubtab_dataset.py index a2c3eebf..78b76c5a 100644 --- a/ppocr/data/pubtab_dataset.py +++ b/ppocr/data/pubtab_dataset.py @@ -1,4 +1,4 @@ -# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -19,6 +19,7 @@ import json from .imaug import transform, create_operators + class PubTabDataSet(Dataset): def __init__(self, config, mode, logger, seed=None): super(PubTabDataSet, self).__init__() @@ -57,23 +58,6 @@ class PubTabDataSet(Dataset): random.seed(self.seed) random.shuffle(self.data_lines) return - - def load_hard_select_prob(self): - label_path = "./pretrained_model/teds_score_exp5_st2_train.txt" - img_select_prob = {} - with open(label_path, "rb") as fin: - lines = fin.readlines() - for lno in range(len(lines)): - substr = lines[lno].decode('utf-8').strip("\n").split(" ") - img_name = substr[0].strip(":") - score = float(substr[1]) - if score <= 0.8: - img_select_prob[img_name] = self.hard_prob[0] - elif score <= 0.98: - img_select_prob[img_name] = self.hard_prob[1] - else: - img_select_prob[img_name] = self.hard_prob[2] - return img_select_prob def __getitem__(self, idx): try: @@ -93,8 +77,6 @@ class PubTabDataSet(Dataset): table_type = "simple" if 'colspan' in structure_str or 'rowspan' in structure_str: table_type = "complex" -# if self.table_select_type != table_type: -# select_flag = False if table_type == "complex": if self.table_select_prob < random.uniform(0, 1): select_flag = False diff --git a/ppocr/modeling/architectures/base_model.py b/ppocr/modeling/architectures/base_model.py index 49160b52..c1bdaaaf 100644 --- a/ppocr/modeling/architectures/base_model.py +++ b/ppocr/modeling/architectures/base_model.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/ppocr/modeling/heads/table_att_head.py b/ppocr/modeling/heads/table_att_head.py index 9e5c438a..61dacd35 100644 --- a/ppocr/modeling/heads/table_att_head.py +++ b/ppocr/modeling/heads/table_att_head.py @@ -21,13 +21,16 @@ import paddle.nn as nn import paddle.nn.functional as F import numpy as np + class TableAttentionHead(nn.Layer): def __init__(self, in_channels, hidden_size, loc_type, in_max_len=488, **kwargs): super(TableAttentionHead, self).__init__() self.input_size = in_channels[-1] self.hidden_size = hidden_size - self.char_num = 280 self.elem_num = 30 + self.max_text_length = 100 + self.max_elem_length = 500 + self.max_cell_num = 500 self.structure_attention_cell = AttentionGRUCell( self.input_size, hidden_size, self.elem_num, use_gru=False) @@ -39,11 +42,11 @@ class TableAttentionHead(nn.Layer): self.loc_generator = nn.Linear(hidden_size, 4) else: if self.in_max_len == 640: - self.loc_fea_trans = nn.Linear(400, 801) + self.loc_fea_trans = nn.Linear(400, self.max_elem_length+1) elif self.in_max_len == 800: - self.loc_fea_trans = nn.Linear(625, 801) + self.loc_fea_trans = nn.Linear(625, self.max_elem_length+1) else: - self.loc_fea_trans = nn.Linear(256, 801) + self.loc_fea_trans = nn.Linear(256, self.max_elem_length+1) self.loc_generator = nn.Linear(self.input_size + hidden_size, 4) def _char_to_onehot(self, input_char, onehot_dim): @@ -61,18 +64,12 @@ class TableAttentionHead(nn.Layer): fea = paddle.reshape(fea, [fea.shape[0], fea.shape[1], last_shape]) fea = fea.transpose([0, 2, 1]) # (NTC)(batch, width, channels) batch_size = fea.shape[0] - #sp_tokens = targets[2].numpy() - #char_beg_idx, char_end_idx = sp_tokens[0, 0:2] - #elem_beg_idx, elem_end_idx = sp_tokens[0, 2:4] - #elem_char_idx1, elem_char_idx2 = sp_tokens[0, 4:6] - #max_text_length, max_elem_length, max_cell_num = sp_tokens[0, 6:9] - max_text_length, max_elem_length, max_cell_num = 100, 800, 500 hidden = paddle.zeros((batch_size, self.hidden_size)) output_hiddens = [] if mode == 'Train' and targets is not None: structure = targets[0] - for i in range(max_elem_length+1): + for i in range(self.max_elem_length+1): elem_onehots = self._char_to_onehot( structure[:, i], onehot_dim=self.elem_num) (outputs, hidden), alpha = self.structure_attention_cell( @@ -97,7 +94,7 @@ class TableAttentionHead(nn.Layer): elem_onehots = None outputs = None alpha = None - max_elem_length = paddle.to_tensor(max_elem_length) + max_elem_length = paddle.to_tensor(self.max_elem_length) i = 0 while i < max_elem_length+1: elem_onehots = self._char_to_onehot( @@ -124,6 +121,7 @@ class TableAttentionHead(nn.Layer): loc_preds = F.sigmoid(loc_preds) return {'structure_probs':structure_probs, 'loc_preds':loc_preds} + class AttentionGRUCell(nn.Layer): def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False): super(AttentionGRUCell, self).__init__() diff --git a/ppocr/modeling/necks/table_fpn.py b/ppocr/modeling/necks/table_fpn.py index d72bff4f..734f15af 100644 --- a/ppocr/modeling/necks/table_fpn.py +++ b/ppocr/modeling/necks/table_fpn.py @@ -1,4 +1,4 @@ -# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -31,70 +31,61 @@ class TableFPN(nn.Layer): in_channels=in_channels[0], out_channels=self.out_channels, kernel_size=1, - weight_attr=ParamAttr( - name='conv2d_51.w_0', initializer=weight_attr), + weight_attr=ParamAttr(initializer=weight_attr), bias_attr=False) self.in3_conv = nn.Conv2D( in_channels=in_channels[1], out_channels=self.out_channels, kernel_size=1, stride = 1, - weight_attr=ParamAttr( - name='conv2d_50.w_0', initializer=weight_attr), + weight_attr=ParamAttr(initializer=weight_attr), bias_attr=False) self.in4_conv = nn.Conv2D( in_channels=in_channels[2], out_channels=self.out_channels, kernel_size=1, - weight_attr=ParamAttr( - name='conv2d_49.w_0', initializer=weight_attr), + weight_attr=ParamAttr(initializer=weight_attr), bias_attr=False) self.in5_conv = nn.Conv2D( in_channels=in_channels[3], out_channels=self.out_channels, kernel_size=1, - weight_attr=ParamAttr( - name='conv2d_48.w_0', initializer=weight_attr), + weight_attr=ParamAttr(initializer=weight_attr), bias_attr=False) self.p5_conv = nn.Conv2D( in_channels=self.out_channels, out_channels=self.out_channels // 4, kernel_size=3, padding=1, - weight_attr=ParamAttr( - name='conv2d_52.w_0', initializer=weight_attr), + weight_attr=ParamAttr(initializer=weight_attr), bias_attr=False) self.p4_conv = nn.Conv2D( in_channels=self.out_channels, out_channels=self.out_channels // 4, kernel_size=3, padding=1, - weight_attr=ParamAttr( - name='conv2d_53.w_0', initializer=weight_attr), + weight_attr=ParamAttr(initializer=weight_attr), bias_attr=False) self.p3_conv = nn.Conv2D( in_channels=self.out_channels, out_channels=self.out_channels // 4, kernel_size=3, padding=1, - weight_attr=ParamAttr( - name='conv2d_54.w_0', initializer=weight_attr), + weight_attr=ParamAttr(initializer=weight_attr), bias_attr=False) self.p2_conv = nn.Conv2D( in_channels=self.out_channels, out_channels=self.out_channels // 4, kernel_size=3, padding=1, - weight_attr=ParamAttr( - name='conv2d_55.w_0', initializer=weight_attr), + weight_attr=ParamAttr(initializer=weight_attr), bias_attr=False) self.fuse_conv = nn.Conv2D( in_channels=self.out_channels * 4, out_channels=512, kernel_size=3, padding=1, - weight_attr=ParamAttr( - name='conv2d_fuse.w_0', initializer=weight_attr), bias_attr=False) + weight_attr=ParamAttr(initializer=weight_attr), bias_attr=False) def forward(self, x): c2, c3, c4, c5 = x diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py index 9429d6b4..912d9bba 100644 --- a/ppocr/postprocess/rec_postprocess.py +++ b/ppocr/postprocess/rec_postprocess.py @@ -368,18 +368,6 @@ class TableLabelDecode(object): self.end_str = "eos" list_character = [self.beg_str] + list_character + [self.end_str] return list_character - - def get_sp_tokens(self): - char_beg_idx = self.get_beg_end_flag_idx('beg', 'char') - char_end_idx = self.get_beg_end_flag_idx('end', 'char') - elem_beg_idx = self.get_beg_end_flag_idx('beg', 'elem') - elem_end_idx = self.get_beg_end_flag_idx('end', 'elem') - elem_char_idx1 = self.dict_elem[''] - elem_char_idx2 = self.dict_elem['