提交 16c247ac 编写于 作者: M MissPenguin

refine

上级 7c8b2c8d
Global: Global:
use_gpu: true use_gpu: true
epoch_num: 40 epoch_num: 50
log_smooth_window: 20 log_smooth_window: 20
print_batch_step: 5 print_batch_step: 5
save_model_dir: ./output/table_mv3/ save_model_dir: ./output/table_mv3/
save_epoch_step: 3 save_epoch_step: 5
# evaluation is run every 5000 iterations after the 4000th iteration # evaluation is run every 400 iterations after the 0th iteration
eval_batch_step: [0, 400] eval_batch_step: [0, 400]
# if pretrained_model is saved in static mode, load_static_weights must set to True
cal_metric_during_train: True cal_metric_during_train: True
pretrained_model: pretrained_model:
checkpoints: checkpoints:
...@@ -18,19 +17,20 @@ Global: ...@@ -18,19 +17,20 @@ Global:
character_dict_path: ppocr/utils/dict/table_structure_dict.txt character_dict_path: ppocr/utils/dict/table_structure_dict.txt
character_type: en character_type: en
max_text_length: 100 max_text_length: 100
max_elem_length: 800 max_elem_length: 500
max_cell_num: 500 max_cell_num: 500
infer_mode: False infer_mode: False
process_total_num: 0 process_total_num: 0
process_cut_num: 0 process_cut_num: 0
Optimizer: Optimizer:
name: Adam name: Adam
beta1: 0.9 beta1: 0.9
beta2: 0.999 beta2: 0.999
clip_norm: 5.0 clip_norm: 5.0
lr: lr:
learning_rate: 0.0001 learning_rate: 0.001
regularizer: regularizer:
name: 'L2' name: 'L2'
factor: 0.00000 factor: 0.00000
...@@ -41,12 +41,12 @@ Architecture: ...@@ -41,12 +41,12 @@ Architecture:
Backbone: Backbone:
name: MobileNetV3 name: MobileNetV3
scale: 1.0 scale: 1.0
model_name: large model_name: small
disable_se: True
Head: Head:
name: TableAttentionHead # AttentionHead name: TableAttentionHead
hidden_size: 256 # hidden_size: 256
l2_decay: 0.00001 l2_decay: 0.00001
# loc_type: 1
loc_type: 2 loc_type: 2
Loss: Loss:
...@@ -86,7 +86,7 @@ Train: ...@@ -86,7 +86,7 @@ Train:
shuffle: True shuffle: True
batch_size_per_card: 32 batch_size_per_card: 32
drop_last: True drop_last: True
num_workers: 4 num_workers: 1
Eval: Eval:
dataset: dataset:
...@@ -113,4 +113,4 @@ Eval: ...@@ -113,4 +113,4 @@ Eval:
shuffle: False shuffle: False
drop_last: False drop_last: False
batch_size_per_card: 16 batch_size_per_card: 16
num_workers: 4 num_workers: 1
...@@ -412,7 +412,6 @@ class TableLabelEncode(object): ...@@ -412,7 +412,6 @@ class TableLabelEncode(object):
return None return None
elem_num = len(structure) elem_num = len(structure)
structure = [0] + structure + [len(self.dict_elem) - 1] structure = [0] + structure + [len(self.dict_elem) - 1]
# structure = [0] + structure + [0]
structure = structure + [0] * (self.max_elem_length + 2 - len(structure)) structure = structure + [0] * (self.max_elem_length + 2 - len(structure))
structure = np.array(structure) structure = np.array(structure)
data['structure'] = structure data['structure'] = structure
...@@ -443,8 +442,6 @@ class TableLabelEncode(object): ...@@ -443,8 +442,6 @@ class TableLabelEncode(object):
if cand_span_idx < (self.max_elem_length + 2): if cand_span_idx < (self.max_elem_length + 2):
if structure[cand_span_idx] in span_idx_list: if structure[cand_span_idx] in span_idx_list:
structure_mask[cand_span_idx] = span_weight structure_mask[cand_span_idx] = span_weight
# structure_mask[td_idx] = self.span_weight
# structure_mask[cand_span_idx] = self.span_weight
data['bbox_list'] = bbox_list data['bbox_list'] = bbox_list
data['bbox_list_mask'] = bbox_list_mask data['bbox_list_mask'] = bbox_list_mask
...@@ -458,23 +455,6 @@ class TableLabelEncode(object): ...@@ -458,23 +455,6 @@ class TableLabelEncode(object):
self.max_elem_length, self.max_cell_num, elem_num]) self.max_elem_length, self.max_cell_num, elem_num])
return data return data
########
# for char decode
# cell_list = []
# for cell in cells:
# char_list = cell['tokens']
# cell = self.encode(char_list, 'char')
# if cell is None:
# return None
# cell = [0] + cell + [len(self.dict_character) - 1]
# cell = cell + [0] * (self.max_text_length + 2 - len(cell))
# cell_list.append(cell)
# cell_list_padding = np.zeros((self.max_cell_num, self.max_text_length + 2))
# cell_list = np.array(cell_list)
# cell_list_padding[0:cell_list.shape[0]] = cell_list
# data['cells'] = cell_list_padding
# return data
def encode(self, text, char_or_elem): def encode(self, text, char_or_elem):
"""convert text-label into text-index. """convert text-label into text-index.
""" """
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -19,6 +19,7 @@ import json ...@@ -19,6 +19,7 @@ import json
from .imaug import transform, create_operators from .imaug import transform, create_operators
class PubTabDataSet(Dataset): class PubTabDataSet(Dataset):
def __init__(self, config, mode, logger, seed=None): def __init__(self, config, mode, logger, seed=None):
super(PubTabDataSet, self).__init__() super(PubTabDataSet, self).__init__()
...@@ -58,23 +59,6 @@ class PubTabDataSet(Dataset): ...@@ -58,23 +59,6 @@ class PubTabDataSet(Dataset):
random.shuffle(self.data_lines) random.shuffle(self.data_lines)
return return
def load_hard_select_prob(self):
label_path = "./pretrained_model/teds_score_exp5_st2_train.txt"
img_select_prob = {}
with open(label_path, "rb") as fin:
lines = fin.readlines()
for lno in range(len(lines)):
substr = lines[lno].decode('utf-8').strip("\n").split(" ")
img_name = substr[0].strip(":")
score = float(substr[1])
if score <= 0.8:
img_select_prob[img_name] = self.hard_prob[0]
elif score <= 0.98:
img_select_prob[img_name] = self.hard_prob[1]
else:
img_select_prob[img_name] = self.hard_prob[2]
return img_select_prob
def __getitem__(self, idx): def __getitem__(self, idx):
try: try:
data_line = self.data_lines[idx] data_line = self.data_lines[idx]
...@@ -93,8 +77,6 @@ class PubTabDataSet(Dataset): ...@@ -93,8 +77,6 @@ class PubTabDataSet(Dataset):
table_type = "simple" table_type = "simple"
if 'colspan' in structure_str or 'rowspan' in structure_str: if 'colspan' in structure_str or 'rowspan' in structure_str:
table_type = "complex" table_type = "complex"
# if self.table_select_type != table_type:
# select_flag = False
if table_type == "complex": if table_type == "complex":
if self.table_select_prob < random.uniform(0, 1): if self.table_select_prob < random.uniform(0, 1):
select_flag = False select_flag = False
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
...@@ -21,13 +21,16 @@ import paddle.nn as nn ...@@ -21,13 +21,16 @@ import paddle.nn as nn
import paddle.nn.functional as F import paddle.nn.functional as F
import numpy as np import numpy as np
class TableAttentionHead(nn.Layer): class TableAttentionHead(nn.Layer):
def __init__(self, in_channels, hidden_size, loc_type, in_max_len=488, **kwargs): def __init__(self, in_channels, hidden_size, loc_type, in_max_len=488, **kwargs):
super(TableAttentionHead, self).__init__() super(TableAttentionHead, self).__init__()
self.input_size = in_channels[-1] self.input_size = in_channels[-1]
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.char_num = 280
self.elem_num = 30 self.elem_num = 30
self.max_text_length = 100
self.max_elem_length = 500
self.max_cell_num = 500
self.structure_attention_cell = AttentionGRUCell( self.structure_attention_cell = AttentionGRUCell(
self.input_size, hidden_size, self.elem_num, use_gru=False) self.input_size, hidden_size, self.elem_num, use_gru=False)
...@@ -39,11 +42,11 @@ class TableAttentionHead(nn.Layer): ...@@ -39,11 +42,11 @@ class TableAttentionHead(nn.Layer):
self.loc_generator = nn.Linear(hidden_size, 4) self.loc_generator = nn.Linear(hidden_size, 4)
else: else:
if self.in_max_len == 640: if self.in_max_len == 640:
self.loc_fea_trans = nn.Linear(400, 801) self.loc_fea_trans = nn.Linear(400, self.max_elem_length+1)
elif self.in_max_len == 800: elif self.in_max_len == 800:
self.loc_fea_trans = nn.Linear(625, 801) self.loc_fea_trans = nn.Linear(625, self.max_elem_length+1)
else: else:
self.loc_fea_trans = nn.Linear(256, 801) self.loc_fea_trans = nn.Linear(256, self.max_elem_length+1)
self.loc_generator = nn.Linear(self.input_size + hidden_size, 4) self.loc_generator = nn.Linear(self.input_size + hidden_size, 4)
def _char_to_onehot(self, input_char, onehot_dim): def _char_to_onehot(self, input_char, onehot_dim):
...@@ -61,18 +64,12 @@ class TableAttentionHead(nn.Layer): ...@@ -61,18 +64,12 @@ class TableAttentionHead(nn.Layer):
fea = paddle.reshape(fea, [fea.shape[0], fea.shape[1], last_shape]) fea = paddle.reshape(fea, [fea.shape[0], fea.shape[1], last_shape])
fea = fea.transpose([0, 2, 1]) # (NTC)(batch, width, channels) fea = fea.transpose([0, 2, 1]) # (NTC)(batch, width, channels)
batch_size = fea.shape[0] batch_size = fea.shape[0]
#sp_tokens = targets[2].numpy()
#char_beg_idx, char_end_idx = sp_tokens[0, 0:2]
#elem_beg_idx, elem_end_idx = sp_tokens[0, 2:4]
#elem_char_idx1, elem_char_idx2 = sp_tokens[0, 4:6]
#max_text_length, max_elem_length, max_cell_num = sp_tokens[0, 6:9]
max_text_length, max_elem_length, max_cell_num = 100, 800, 500
hidden = paddle.zeros((batch_size, self.hidden_size)) hidden = paddle.zeros((batch_size, self.hidden_size))
output_hiddens = [] output_hiddens = []
if mode == 'Train' and targets is not None: if mode == 'Train' and targets is not None:
structure = targets[0] structure = targets[0]
for i in range(max_elem_length+1): for i in range(self.max_elem_length+1):
elem_onehots = self._char_to_onehot( elem_onehots = self._char_to_onehot(
structure[:, i], onehot_dim=self.elem_num) structure[:, i], onehot_dim=self.elem_num)
(outputs, hidden), alpha = self.structure_attention_cell( (outputs, hidden), alpha = self.structure_attention_cell(
...@@ -97,7 +94,7 @@ class TableAttentionHead(nn.Layer): ...@@ -97,7 +94,7 @@ class TableAttentionHead(nn.Layer):
elem_onehots = None elem_onehots = None
outputs = None outputs = None
alpha = None alpha = None
max_elem_length = paddle.to_tensor(max_elem_length) max_elem_length = paddle.to_tensor(self.max_elem_length)
i = 0 i = 0
while i < max_elem_length+1: while i < max_elem_length+1:
elem_onehots = self._char_to_onehot( elem_onehots = self._char_to_onehot(
...@@ -124,6 +121,7 @@ class TableAttentionHead(nn.Layer): ...@@ -124,6 +121,7 @@ class TableAttentionHead(nn.Layer):
loc_preds = F.sigmoid(loc_preds) loc_preds = F.sigmoid(loc_preds)
return {'structure_probs':structure_probs, 'loc_preds':loc_preds} return {'structure_probs':structure_probs, 'loc_preds':loc_preds}
class AttentionGRUCell(nn.Layer): class AttentionGRUCell(nn.Layer):
def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False): def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False):
super(AttentionGRUCell, self).__init__() super(AttentionGRUCell, self).__init__()
......
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -31,70 +31,61 @@ class TableFPN(nn.Layer): ...@@ -31,70 +31,61 @@ class TableFPN(nn.Layer):
in_channels=in_channels[0], in_channels=in_channels[0],
out_channels=self.out_channels, out_channels=self.out_channels,
kernel_size=1, kernel_size=1,
weight_attr=ParamAttr( weight_attr=ParamAttr(initializer=weight_attr),
name='conv2d_51.w_0', initializer=weight_attr),
bias_attr=False) bias_attr=False)
self.in3_conv = nn.Conv2D( self.in3_conv = nn.Conv2D(
in_channels=in_channels[1], in_channels=in_channels[1],
out_channels=self.out_channels, out_channels=self.out_channels,
kernel_size=1, kernel_size=1,
stride = 1, stride = 1,
weight_attr=ParamAttr( weight_attr=ParamAttr(initializer=weight_attr),
name='conv2d_50.w_0', initializer=weight_attr),
bias_attr=False) bias_attr=False)
self.in4_conv = nn.Conv2D( self.in4_conv = nn.Conv2D(
in_channels=in_channels[2], in_channels=in_channels[2],
out_channels=self.out_channels, out_channels=self.out_channels,
kernel_size=1, kernel_size=1,
weight_attr=ParamAttr( weight_attr=ParamAttr(initializer=weight_attr),
name='conv2d_49.w_0', initializer=weight_attr),
bias_attr=False) bias_attr=False)
self.in5_conv = nn.Conv2D( self.in5_conv = nn.Conv2D(
in_channels=in_channels[3], in_channels=in_channels[3],
out_channels=self.out_channels, out_channels=self.out_channels,
kernel_size=1, kernel_size=1,
weight_attr=ParamAttr( weight_attr=ParamAttr(initializer=weight_attr),
name='conv2d_48.w_0', initializer=weight_attr),
bias_attr=False) bias_attr=False)
self.p5_conv = nn.Conv2D( self.p5_conv = nn.Conv2D(
in_channels=self.out_channels, in_channels=self.out_channels,
out_channels=self.out_channels // 4, out_channels=self.out_channels // 4,
kernel_size=3, kernel_size=3,
padding=1, padding=1,
weight_attr=ParamAttr( weight_attr=ParamAttr(initializer=weight_attr),
name='conv2d_52.w_0', initializer=weight_attr),
bias_attr=False) bias_attr=False)
self.p4_conv = nn.Conv2D( self.p4_conv = nn.Conv2D(
in_channels=self.out_channels, in_channels=self.out_channels,
out_channels=self.out_channels // 4, out_channels=self.out_channels // 4,
kernel_size=3, kernel_size=3,
padding=1, padding=1,
weight_attr=ParamAttr( weight_attr=ParamAttr(initializer=weight_attr),
name='conv2d_53.w_0', initializer=weight_attr),
bias_attr=False) bias_attr=False)
self.p3_conv = nn.Conv2D( self.p3_conv = nn.Conv2D(
in_channels=self.out_channels, in_channels=self.out_channels,
out_channels=self.out_channels // 4, out_channels=self.out_channels // 4,
kernel_size=3, kernel_size=3,
padding=1, padding=1,
weight_attr=ParamAttr( weight_attr=ParamAttr(initializer=weight_attr),
name='conv2d_54.w_0', initializer=weight_attr),
bias_attr=False) bias_attr=False)
self.p2_conv = nn.Conv2D( self.p2_conv = nn.Conv2D(
in_channels=self.out_channels, in_channels=self.out_channels,
out_channels=self.out_channels // 4, out_channels=self.out_channels // 4,
kernel_size=3, kernel_size=3,
padding=1, padding=1,
weight_attr=ParamAttr( weight_attr=ParamAttr(initializer=weight_attr),
name='conv2d_55.w_0', initializer=weight_attr),
bias_attr=False) bias_attr=False)
self.fuse_conv = nn.Conv2D( self.fuse_conv = nn.Conv2D(
in_channels=self.out_channels * 4, in_channels=self.out_channels * 4,
out_channels=512, out_channels=512,
kernel_size=3, kernel_size=3,
padding=1, padding=1,
weight_attr=ParamAttr( weight_attr=ParamAttr(initializer=weight_attr), bias_attr=False)
name='conv2d_fuse.w_0', initializer=weight_attr), bias_attr=False)
def forward(self, x): def forward(self, x):
c2, c3, c4, c5 = x c2, c3, c4, c5 = x
......
...@@ -369,18 +369,6 @@ class TableLabelDecode(object): ...@@ -369,18 +369,6 @@ class TableLabelDecode(object):
list_character = [self.beg_str] + list_character + [self.end_str] list_character = [self.beg_str] + list_character + [self.end_str]
return list_character return list_character
def get_sp_tokens(self):
char_beg_idx = self.get_beg_end_flag_idx('beg', 'char')
char_end_idx = self.get_beg_end_flag_idx('end', 'char')
elem_beg_idx = self.get_beg_end_flag_idx('beg', 'elem')
elem_end_idx = self.get_beg_end_flag_idx('end', 'elem')
elem_char_idx1 = self.dict_elem['<td>']
elem_char_idx2 = self.dict_elem['<td']
sp_tokens = np.array([char_beg_idx, char_end_idx, elem_beg_idx,
elem_end_idx, elem_char_idx1, elem_char_idx2, self.max_text_length,
self.max_elem_length, self.max_cell_num])
return sp_tokens
def __call__(self, preds): def __call__(self, preds):
structure_probs = preds['structure_probs'] structure_probs = preds['structure_probs']
loc_preds = preds['loc_preds'] loc_preds = preds['loc_preds']
......
...@@ -60,7 +60,8 @@ def export_single_model(model, arch_config, save_path, logger): ...@@ -60,7 +60,8 @@ def export_single_model(model, arch_config, save_path, logger):
"When there is tps in the network, variable length input is not supported, and the input size needs to be the same as during training" "When there is tps in the network, variable length input is not supported, and the input size needs to be the same as during training"
) )
infer_shape[-1] = 100 infer_shape[-1] = 100
elif arch_config["model_type"] == "table":
infer_shape = [3, 488, 488]
model = to_static( model = to_static(
model, model,
input_spec=[ input_spec=[
......
...@@ -79,11 +79,9 @@ def main(config, device, logger, vdl_writer): ...@@ -79,11 +79,9 @@ def main(config, device, logger, vdl_writer):
img = f.read() img = f.read()
data = {'image': img} data = {'image': img}
batch = transform(data, ops) batch = transform(data, ops)
sp_tokens = post_process_class.get_sp_tokens()
targets = [[], [], paddle.to_tensor([sp_tokens])]
images = np.expand_dims(batch[0], axis=0) images = np.expand_dims(batch[0], axis=0)
images = paddle.to_tensor(images) images = paddle.to_tensor(images)
preds = model(images, data=targets, mode='Test') preds = model(images, data=None, mode='Test')
post_result = post_process_class(preds) post_result = post_process_class(preds)
res_html_code = post_result['res_html_code'] res_html_code = post_result['res_html_code']
res_loc = post_result['res_loc'] res_loc = post_result['res_loc']
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -276,6 +276,7 @@ def train(config, ...@@ -276,6 +276,7 @@ def train(config,
valid_dataloader, valid_dataloader,
post_process_class, post_process_class,
eval_class, eval_class,
"table",
use_srn=use_srn) use_srn=use_srn)
cur_metric_str = 'cur metric, {}'.format(', '.join( cur_metric_str = 'cur metric, {}'.format(', '.join(
['{}: {}'.format(k, v) for k, v in cur_metric.items()])) ['{}: {}'.format(k, v) for k, v in cur_metric.items()]))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册