未验证 提交 b6a21419 编写于 作者: Z zhoujun 提交者: GitHub

fix gap between table structure train model and inference model (#4565)

* add indent in pipeline_rpc_client.py

* fix gap in table structure train model and inference model
上级 a8960021
Global: Global:
use_gpu: true use_gpu: true
epoch_num: 50 epoch_num: 400
log_smooth_window: 20 log_smooth_window: 20
print_batch_step: 5 print_batch_step: 5
save_model_dir: ./output/table_mv3/ save_model_dir: ./output/table_mv3/
save_epoch_step: 5 save_epoch_step: 3
# evaluation is run every 400 iterations after the 0th iteration # evaluation is run every 400 iterations after the 0th iteration
eval_batch_step: [0, 400] eval_batch_step: [0, 400]
cal_metric_during_train: True cal_metric_during_train: True
...@@ -12,18 +12,17 @@ Global: ...@@ -12,18 +12,17 @@ Global:
checkpoints: checkpoints:
save_inference_dir: save_inference_dir:
use_visualdl: False use_visualdl: False
infer_img: doc/imgs_words/ch/word_1.jpg infer_img: doc/table/table.jpg
# for data or label process # for data or label process
character_dict_path: ppocr/utils/dict/table_structure_dict.txt character_dict_path: ppocr/utils/dict/table_structure_dict.txt
character_type: en character_type: en
max_text_length: 100 max_text_length: 100
max_elem_length: 500 max_elem_length: 800
max_cell_num: 500 max_cell_num: 500
infer_mode: False infer_mode: False
process_total_num: 0 process_total_num: 0
process_cut_num: 0 process_cut_num: 0
Optimizer: Optimizer:
name: Adam name: Adam
beta1: 0.9 beta1: 0.9
...@@ -41,13 +40,15 @@ Architecture: ...@@ -41,13 +40,15 @@ Architecture:
Backbone: Backbone:
name: MobileNetV3 name: MobileNetV3
scale: 1.0 scale: 1.0
model_name: small model_name: large
disable_se: True
Head: Head:
name: TableAttentionHead name: TableAttentionHead
hidden_size: 256 hidden_size: 256
l2_decay: 0.00001 l2_decay: 0.00001
loc_type: 2 loc_type: 2
max_text_length: 100
max_elem_length: 800
max_cell_num: 500
Loss: Loss:
name: TableAttentionLoss name: TableAttentionLoss
......
...@@ -41,6 +41,6 @@ for img_file in os.listdir(test_img_dir): ...@@ -41,6 +41,6 @@ for img_file in os.listdir(test_img_dir):
image_data = file.read() image_data = file.read()
image = cv2_to_base64(image_data) image = cv2_to_base64(image_data)
for i in range(1): for i in range(1):
ret = client.predict(feed_dict={"image": image}, fetch=["res"]) ret = client.predict(feed_dict={"image": image}, fetch=["res"])
print(ret) print(ret)
...@@ -23,14 +23,22 @@ import numpy as np ...@@ -23,14 +23,22 @@ import numpy as np
class TableAttentionHead(nn.Layer): class TableAttentionHead(nn.Layer):
def __init__(self, in_channels, hidden_size, loc_type, in_max_len=488, **kwargs): def __init__(self,
in_channels,
hidden_size,
loc_type,
in_max_len=488,
max_text_length=100,
max_elem_length=800,
max_cell_num=500,
**kwargs):
super(TableAttentionHead, self).__init__() super(TableAttentionHead, self).__init__()
self.input_size = in_channels[-1] self.input_size = in_channels[-1]
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.elem_num = 30 self.elem_num = 30
self.max_text_length = 100 self.max_text_length = max_text_length
self.max_elem_length = 500 self.max_elem_length = max_elem_length
self.max_cell_num = 500 self.max_cell_num = max_cell_num
self.structure_attention_cell = AttentionGRUCell( self.structure_attention_cell = AttentionGRUCell(
self.input_size, hidden_size, self.elem_num, use_gru=False) self.input_size, hidden_size, self.elem_num, use_gru=False)
...@@ -42,11 +50,11 @@ class TableAttentionHead(nn.Layer): ...@@ -42,11 +50,11 @@ class TableAttentionHead(nn.Layer):
self.loc_generator = nn.Linear(hidden_size, 4) self.loc_generator = nn.Linear(hidden_size, 4)
else: else:
if self.in_max_len == 640: if self.in_max_len == 640:
self.loc_fea_trans = nn.Linear(400, self.max_elem_length+1) self.loc_fea_trans = nn.Linear(400, self.max_elem_length + 1)
elif self.in_max_len == 800: elif self.in_max_len == 800:
self.loc_fea_trans = nn.Linear(625, self.max_elem_length+1) self.loc_fea_trans = nn.Linear(625, self.max_elem_length + 1)
else: else:
self.loc_fea_trans = nn.Linear(256, self.max_elem_length+1) self.loc_fea_trans = nn.Linear(256, self.max_elem_length + 1)
self.loc_generator = nn.Linear(self.input_size + hidden_size, 4) self.loc_generator = nn.Linear(self.input_size + hidden_size, 4)
def _char_to_onehot(self, input_char, onehot_dim): def _char_to_onehot(self, input_char, onehot_dim):
...@@ -69,7 +77,7 @@ class TableAttentionHead(nn.Layer): ...@@ -69,7 +77,7 @@ class TableAttentionHead(nn.Layer):
output_hiddens = [] output_hiddens = []
if self.training and targets is not None: if self.training and targets is not None:
structure = targets[0] structure = targets[0]
for i in range(self.max_elem_length+1): for i in range(self.max_elem_length + 1):
elem_onehots = self._char_to_onehot( elem_onehots = self._char_to_onehot(
structure[:, i], onehot_dim=self.elem_num) structure[:, i], onehot_dim=self.elem_num)
(outputs, hidden), alpha = self.structure_attention_cell( (outputs, hidden), alpha = self.structure_attention_cell(
...@@ -96,7 +104,7 @@ class TableAttentionHead(nn.Layer): ...@@ -96,7 +104,7 @@ class TableAttentionHead(nn.Layer):
alpha = None alpha = None
max_elem_length = paddle.to_tensor(self.max_elem_length) max_elem_length = paddle.to_tensor(self.max_elem_length)
i = 0 i = 0
while i < max_elem_length+1: while i < max_elem_length + 1:
elem_onehots = self._char_to_onehot( elem_onehots = self._char_to_onehot(
temp_elem, onehot_dim=self.elem_num) temp_elem, onehot_dim=self.elem_num)
(outputs, hidden), alpha = self.structure_attention_cell( (outputs, hidden), alpha = self.structure_attention_cell(
...@@ -119,7 +127,7 @@ class TableAttentionHead(nn.Layer): ...@@ -119,7 +127,7 @@ class TableAttentionHead(nn.Layer):
loc_concat = paddle.concat([output, loc_fea], axis=2) loc_concat = paddle.concat([output, loc_fea], axis=2)
loc_preds = self.loc_generator(loc_concat) loc_preds = self.loc_generator(loc_concat)
loc_preds = F.sigmoid(loc_preds) loc_preds = F.sigmoid(loc_preds)
return {'structure_probs':structure_probs, 'loc_preds':loc_preds} return {'structure_probs': structure_probs, 'loc_preds': loc_preds}
class AttentionGRUCell(nn.Layer): class AttentionGRUCell(nn.Layer):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册