提交 16c247ac 编写于 作者: M MissPenguin

refine

上级 7c8b2c8d
Global:
use_gpu: true
epoch_num: 40
epoch_num: 50
log_smooth_window: 20
print_batch_step: 5
save_model_dir: ./output/table_mv3/
save_epoch_step: 3
# evaluation is run every 5000 iterations after the 4000th iteration
save_epoch_step: 5
# evaluation is run every 400 iterations after the 0th iteration
eval_batch_step: [0, 400]
# if pretrained_model is saved in static mode, load_static_weights must set to True
cal_metric_during_train: True
pretrained_model:
checkpoints:
......@@ -18,19 +17,20 @@ Global:
character_dict_path: ppocr/utils/dict/table_structure_dict.txt
character_type: en
max_text_length: 100
max_elem_length: 800
max_elem_length: 500
max_cell_num: 500
infer_mode: False
process_total_num: 0
process_cut_num: 0
Optimizer:
name: Adam
beta1: 0.9
beta2: 0.999
clip_norm: 5.0
lr:
learning_rate: 0.0001
learning_rate: 0.001
regularizer:
name: 'L2'
factor: 0.00000
......@@ -41,12 +41,12 @@ Architecture:
Backbone:
name: MobileNetV3
scale: 1.0
model_name: large
model_name: small
disable_se: True
Head:
name: TableAttentionHead # AttentionHead
hidden_size: 256 #
name: TableAttentionHead
hidden_size: 256
l2_decay: 0.00001
# loc_type: 1
loc_type: 2
Loss:
......@@ -86,7 +86,7 @@ Train:
shuffle: True
batch_size_per_card: 32
drop_last: True
num_workers: 4
num_workers: 1
Eval:
dataset:
......@@ -113,4 +113,4 @@ Eval:
shuffle: False
drop_last: False
batch_size_per_card: 16
num_workers: 4
num_workers: 1
......@@ -412,7 +412,6 @@ class TableLabelEncode(object):
return None
elem_num = len(structure)
structure = [0] + structure + [len(self.dict_elem) - 1]
# structure = [0] + structure + [0]
structure = structure + [0] * (self.max_elem_length + 2 - len(structure))
structure = np.array(structure)
data['structure'] = structure
......@@ -443,8 +442,6 @@ class TableLabelEncode(object):
if cand_span_idx < (self.max_elem_length + 2):
if structure[cand_span_idx] in span_idx_list:
structure_mask[cand_span_idx] = span_weight
# structure_mask[td_idx] = self.span_weight
# structure_mask[cand_span_idx] = self.span_weight
data['bbox_list'] = bbox_list
data['bbox_list_mask'] = bbox_list_mask
......@@ -458,23 +455,6 @@ class TableLabelEncode(object):
self.max_elem_length, self.max_cell_num, elem_num])
return data
########
# for char decode
# cell_list = []
# for cell in cells:
# char_list = cell['tokens']
# cell = self.encode(char_list, 'char')
# if cell is None:
# return None
# cell = [0] + cell + [len(self.dict_character) - 1]
# cell = cell + [0] * (self.max_text_length + 2 - len(cell))
# cell_list.append(cell)
# cell_list_padding = np.zeros((self.max_cell_num, self.max_text_length + 2))
# cell_list = np.array(cell_list)
# cell_list_padding[0:cell_list.shape[0]] = cell_list
# data['cells'] = cell_list_padding
# return data
def encode(self, text, char_or_elem):
"""convert text-label into text-index.
"""
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -19,6 +19,7 @@ import json
from .imaug import transform, create_operators
class PubTabDataSet(Dataset):
def __init__(self, config, mode, logger, seed=None):
super(PubTabDataSet, self).__init__()
......@@ -57,23 +58,6 @@ class PubTabDataSet(Dataset):
random.seed(self.seed)
random.shuffle(self.data_lines)
return
def load_hard_select_prob(self):
label_path = "./pretrained_model/teds_score_exp5_st2_train.txt"
img_select_prob = {}
with open(label_path, "rb") as fin:
lines = fin.readlines()
for lno in range(len(lines)):
substr = lines[lno].decode('utf-8').strip("\n").split(" ")
img_name = substr[0].strip(":")
score = float(substr[1])
if score <= 0.8:
img_select_prob[img_name] = self.hard_prob[0]
elif score <= 0.98:
img_select_prob[img_name] = self.hard_prob[1]
else:
img_select_prob[img_name] = self.hard_prob[2]
return img_select_prob
def __getitem__(self, idx):
try:
......@@ -93,8 +77,6 @@ class PubTabDataSet(Dataset):
table_type = "simple"
if 'colspan' in structure_str or 'rowspan' in structure_str:
table_type = "complex"
# if self.table_select_type != table_type:
# select_flag = False
if table_type == "complex":
if self.table_select_prob < random.uniform(0, 1):
select_flag = False
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
......@@ -21,13 +21,16 @@ import paddle.nn as nn
import paddle.nn.functional as F
import numpy as np
class TableAttentionHead(nn.Layer):
def __init__(self, in_channels, hidden_size, loc_type, in_max_len=488, **kwargs):
super(TableAttentionHead, self).__init__()
self.input_size = in_channels[-1]
self.hidden_size = hidden_size
self.char_num = 280
self.elem_num = 30
self.max_text_length = 100
self.max_elem_length = 500
self.max_cell_num = 500
self.structure_attention_cell = AttentionGRUCell(
self.input_size, hidden_size, self.elem_num, use_gru=False)
......@@ -39,11 +42,11 @@ class TableAttentionHead(nn.Layer):
self.loc_generator = nn.Linear(hidden_size, 4)
else:
if self.in_max_len == 640:
self.loc_fea_trans = nn.Linear(400, 801)
self.loc_fea_trans = nn.Linear(400, self.max_elem_length+1)
elif self.in_max_len == 800:
self.loc_fea_trans = nn.Linear(625, 801)
self.loc_fea_trans = nn.Linear(625, self.max_elem_length+1)
else:
self.loc_fea_trans = nn.Linear(256, 801)
self.loc_fea_trans = nn.Linear(256, self.max_elem_length+1)
self.loc_generator = nn.Linear(self.input_size + hidden_size, 4)
def _char_to_onehot(self, input_char, onehot_dim):
......@@ -61,18 +64,12 @@ class TableAttentionHead(nn.Layer):
fea = paddle.reshape(fea, [fea.shape[0], fea.shape[1], last_shape])
fea = fea.transpose([0, 2, 1]) # (NTC)(batch, width, channels)
batch_size = fea.shape[0]
#sp_tokens = targets[2].numpy()
#char_beg_idx, char_end_idx = sp_tokens[0, 0:2]
#elem_beg_idx, elem_end_idx = sp_tokens[0, 2:4]
#elem_char_idx1, elem_char_idx2 = sp_tokens[0, 4:6]
#max_text_length, max_elem_length, max_cell_num = sp_tokens[0, 6:9]
max_text_length, max_elem_length, max_cell_num = 100, 800, 500
hidden = paddle.zeros((batch_size, self.hidden_size))
output_hiddens = []
if mode == 'Train' and targets is not None:
structure = targets[0]
for i in range(max_elem_length+1):
for i in range(self.max_elem_length+1):
elem_onehots = self._char_to_onehot(
structure[:, i], onehot_dim=self.elem_num)
(outputs, hidden), alpha = self.structure_attention_cell(
......@@ -97,7 +94,7 @@ class TableAttentionHead(nn.Layer):
elem_onehots = None
outputs = None
alpha = None
max_elem_length = paddle.to_tensor(max_elem_length)
max_elem_length = paddle.to_tensor(self.max_elem_length)
i = 0
while i < max_elem_length+1:
elem_onehots = self._char_to_onehot(
......@@ -124,6 +121,7 @@ class TableAttentionHead(nn.Layer):
loc_preds = F.sigmoid(loc_preds)
return {'structure_probs':structure_probs, 'loc_preds':loc_preds}
class AttentionGRUCell(nn.Layer):
def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False):
super(AttentionGRUCell, self).__init__()
......
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -31,70 +31,61 @@ class TableFPN(nn.Layer):
in_channels=in_channels[0],
out_channels=self.out_channels,
kernel_size=1,
weight_attr=ParamAttr(
name='conv2d_51.w_0', initializer=weight_attr),
weight_attr=ParamAttr(initializer=weight_attr),
bias_attr=False)
self.in3_conv = nn.Conv2D(
in_channels=in_channels[1],
out_channels=self.out_channels,
kernel_size=1,
stride = 1,
weight_attr=ParamAttr(
name='conv2d_50.w_0', initializer=weight_attr),
weight_attr=ParamAttr(initializer=weight_attr),
bias_attr=False)
self.in4_conv = nn.Conv2D(
in_channels=in_channels[2],
out_channels=self.out_channels,
kernel_size=1,
weight_attr=ParamAttr(
name='conv2d_49.w_0', initializer=weight_attr),
weight_attr=ParamAttr(initializer=weight_attr),
bias_attr=False)
self.in5_conv = nn.Conv2D(
in_channels=in_channels[3],
out_channels=self.out_channels,
kernel_size=1,
weight_attr=ParamAttr(
name='conv2d_48.w_0', initializer=weight_attr),
weight_attr=ParamAttr(initializer=weight_attr),
bias_attr=False)
self.p5_conv = nn.Conv2D(
in_channels=self.out_channels,
out_channels=self.out_channels // 4,
kernel_size=3,
padding=1,
weight_attr=ParamAttr(
name='conv2d_52.w_0', initializer=weight_attr),
weight_attr=ParamAttr(initializer=weight_attr),
bias_attr=False)
self.p4_conv = nn.Conv2D(
in_channels=self.out_channels,
out_channels=self.out_channels // 4,
kernel_size=3,
padding=1,
weight_attr=ParamAttr(
name='conv2d_53.w_0', initializer=weight_attr),
weight_attr=ParamAttr(initializer=weight_attr),
bias_attr=False)
self.p3_conv = nn.Conv2D(
in_channels=self.out_channels,
out_channels=self.out_channels // 4,
kernel_size=3,
padding=1,
weight_attr=ParamAttr(
name='conv2d_54.w_0', initializer=weight_attr),
weight_attr=ParamAttr(initializer=weight_attr),
bias_attr=False)
self.p2_conv = nn.Conv2D(
in_channels=self.out_channels,
out_channels=self.out_channels // 4,
kernel_size=3,
padding=1,
weight_attr=ParamAttr(
name='conv2d_55.w_0', initializer=weight_attr),
weight_attr=ParamAttr(initializer=weight_attr),
bias_attr=False)
self.fuse_conv = nn.Conv2D(
in_channels=self.out_channels * 4,
out_channels=512,
kernel_size=3,
padding=1,
weight_attr=ParamAttr(
name='conv2d_fuse.w_0', initializer=weight_attr), bias_attr=False)
weight_attr=ParamAttr(initializer=weight_attr), bias_attr=False)
def forward(self, x):
c2, c3, c4, c5 = x
......
......@@ -368,18 +368,6 @@ class TableLabelDecode(object):
self.end_str = "eos"
list_character = [self.beg_str] + list_character + [self.end_str]
return list_character
def get_sp_tokens(self):
char_beg_idx = self.get_beg_end_flag_idx('beg', 'char')
char_end_idx = self.get_beg_end_flag_idx('end', 'char')
elem_beg_idx = self.get_beg_end_flag_idx('beg', 'elem')
elem_end_idx = self.get_beg_end_flag_idx('end', 'elem')
elem_char_idx1 = self.dict_elem['<td>']
elem_char_idx2 = self.dict_elem['<td']
sp_tokens = np.array([char_beg_idx, char_end_idx, elem_beg_idx,
elem_end_idx, elem_char_idx1, elem_char_idx2, self.max_text_length,
self.max_elem_length, self.max_cell_num])
return sp_tokens
def __call__(self, preds):
structure_probs = preds['structure_probs']
......
......@@ -60,7 +60,8 @@ def export_single_model(model, arch_config, save_path, logger):
"When there is tps in the network, variable length input is not supported, and the input size needs to be the same as during training"
)
infer_shape[-1] = 100
elif arch_config["model_type"] == "table":
infer_shape = [3, 488, 488]
model = to_static(
model,
input_spec=[
......
......@@ -79,11 +79,9 @@ def main(config, device, logger, vdl_writer):
img = f.read()
data = {'image': img}
batch = transform(data, ops)
sp_tokens = post_process_class.get_sp_tokens()
targets = [[], [], paddle.to_tensor([sp_tokens])]
images = np.expand_dims(batch[0], axis=0)
images = paddle.to_tensor(images)
preds = model(images, data=targets, mode='Test')
preds = model(images, data=None, mode='Test')
post_result = post_process_class(preds)
res_html_code = post_result['res_html_code']
res_loc = post_result['res_loc']
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -276,6 +276,7 @@ def train(config,
valid_dataloader,
post_process_class,
eval_class,
"table",
use_srn=use_srn)
cur_metric_str = 'cur metric, {}'.format(', '.join(
['{}: {}'.format(k, v) for k, v in cur_metric.items()]))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册