From b4800ad2b59617e82a6e1e9d2875a7ce67b2b3a6 Mon Sep 17 00:00:00 2001 From: zhoujun Date: Wed, 10 Nov 2021 20:21:30 +0800 Subject: [PATCH] fix gap in table structure train model and inference model (#4566) --- configs/table/table_mv3.yml | 17 +++++------ ppocr/modeling/heads/table_att_head.py | 40 +++++++++++++++----------- 2 files changed, 33 insertions(+), 24 deletions(-) diff --git a/configs/table/table_mv3.yml b/configs/table/table_mv3.yml index a74e18d3..1a91ea95 100755 --- a/configs/table/table_mv3.yml +++ b/configs/table/table_mv3.yml @@ -1,29 +1,28 @@ Global: use_gpu: true - epoch_num: 50 + epoch_num: 400 log_smooth_window: 20 print_batch_step: 5 save_model_dir: ./output/table_mv3/ - save_epoch_step: 5 + save_epoch_step: 3 # evaluation is run every 400 iterations after the 0th iteration eval_batch_step: [0, 400] cal_metric_during_train: True - pretrained_model: + pretrained_model: checkpoints: save_inference_dir: use_visualdl: False - infer_img: doc/imgs_words/ch/word_1.jpg + infer_img: doc/table/table.jpg # for data or label process character_dict_path: ppocr/utils/dict/table_structure_dict.txt character_type: en max_text_length: 100 - max_elem_length: 500 + max_elem_length: 800 max_cell_num: 500 infer_mode: False process_total_num: 0 process_cut_num: 0 - Optimizer: name: Adam beta1: 0.9 @@ -41,13 +40,15 @@ Architecture: Backbone: name: MobileNetV3 scale: 1.0 - model_name: small - disable_se: True + model_name: large Head: name: TableAttentionHead hidden_size: 256 l2_decay: 0.00001 loc_type: 2 + max_text_length: 100 + max_elem_length: 800 + max_cell_num: 500 Loss: name: TableAttentionLoss diff --git a/ppocr/modeling/heads/table_att_head.py b/ppocr/modeling/heads/table_att_head.py index 155f036d..e354f40d 100644 --- a/ppocr/modeling/heads/table_att_head.py +++ b/ppocr/modeling/heads/table_att_head.py @@ -23,32 +23,40 @@ import numpy as np class TableAttentionHead(nn.Layer): - def __init__(self, in_channels, hidden_size, loc_type, in_max_len=488, **kwargs): + def __init__(self, + in_channels, + hidden_size, + loc_type, + in_max_len=488, + max_text_length=100, + max_elem_length=800, + max_cell_num=500, + **kwargs): super(TableAttentionHead, self).__init__() self.input_size = in_channels[-1] self.hidden_size = hidden_size self.elem_num = 30 - self.max_text_length = 100 - self.max_elem_length = 500 - self.max_cell_num = 500 + self.max_text_length = max_text_length + self.max_elem_length = max_elem_length + self.max_cell_num = max_cell_num self.structure_attention_cell = AttentionGRUCell( self.input_size, hidden_size, self.elem_num, use_gru=False) self.structure_generator = nn.Linear(hidden_size, self.elem_num) self.loc_type = loc_type self.in_max_len = in_max_len - + if self.loc_type == 1: self.loc_generator = nn.Linear(hidden_size, 4) else: if self.in_max_len == 640: - self.loc_fea_trans = nn.Linear(400, self.max_elem_length+1) + self.loc_fea_trans = nn.Linear(400, self.max_elem_length + 1) elif self.in_max_len == 800: - self.loc_fea_trans = nn.Linear(625, self.max_elem_length+1) + self.loc_fea_trans = nn.Linear(625, self.max_elem_length + 1) else: - self.loc_fea_trans = nn.Linear(256, self.max_elem_length+1) + self.loc_fea_trans = nn.Linear(256, self.max_elem_length + 1) self.loc_generator = nn.Linear(self.input_size + hidden_size, 4) - + def _char_to_onehot(self, input_char, onehot_dim): input_ont_hot = F.one_hot(input_char, onehot_dim) return input_ont_hot @@ -60,16 +68,16 @@ class TableAttentionHead(nn.Layer): if len(fea.shape) == 3: pass else: - last_shape = int(np.prod(fea.shape[2:])) # gry added + last_shape = int(np.prod(fea.shape[2:])) # gry added fea = paddle.reshape(fea, [fea.shape[0], fea.shape[1], last_shape]) fea = fea.transpose([0, 2, 1]) # (NTC)(batch, width, channels) batch_size = fea.shape[0] - + hidden = paddle.zeros((batch_size, self.hidden_size)) output_hiddens = [] if self.training and targets is not None: structure = targets[0] - for i in range(self.max_elem_length+1): + for i in range(self.max_elem_length + 1): elem_onehots = self._char_to_onehot( structure[:, i], onehot_dim=self.elem_num) (outputs, hidden), alpha = self.structure_attention_cell( @@ -96,7 +104,7 @@ class TableAttentionHead(nn.Layer): alpha = None max_elem_length = paddle.to_tensor(self.max_elem_length) i = 0 - while i < max_elem_length+1: + while i < max_elem_length + 1: elem_onehots = self._char_to_onehot( temp_elem, onehot_dim=self.elem_num) (outputs, hidden), alpha = self.structure_attention_cell( @@ -105,7 +113,7 @@ class TableAttentionHead(nn.Layer): structure_probs_step = self.structure_generator(outputs) temp_elem = structure_probs_step.argmax(axis=1, dtype="int32") i += 1 - + output = paddle.concat(output_hiddens, axis=1) structure_probs = self.structure_generator(output) structure_probs = F.softmax(structure_probs) @@ -119,9 +127,9 @@ class TableAttentionHead(nn.Layer): loc_concat = paddle.concat([output, loc_fea], axis=2) loc_preds = self.loc_generator(loc_concat) loc_preds = F.sigmoid(loc_preds) - return {'structure_probs':structure_probs, 'loc_preds':loc_preds} + return {'structure_probs': structure_probs, 'loc_preds': loc_preds} + - class AttentionGRUCell(nn.Layer): def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False): super(AttentionGRUCell, self).__init__() -- GitLab