提交 7c8b2c8d 编写于 作者: M MissPenguin

add train code for table

上级 e93735a2
Global:
use_gpu: true
epoch_num: 40
log_smooth_window: 20
print_batch_step: 5
save_model_dir: ./output/table_mv3/
save_epoch_step: 3
# evaluation is run every 5000 iterations after the 4000th iteration
eval_batch_step: [0, 400]
# if pretrained_model is saved in static mode, load_static_weights must set to True
cal_metric_during_train: True
pretrained_model:
checkpoints:
save_inference_dir:
use_visualdl: False
infer_img: doc/imgs_words/ch/word_1.jpg
# for data or label process
character_dict_path: ppocr/utils/dict/table_structure_dict.txt
character_type: en
max_text_length: 100
max_elem_length: 800
max_cell_num: 500
infer_mode: False
process_total_num: 0
process_cut_num: 0
Optimizer:
name: Adam
beta1: 0.9
beta2: 0.999
clip_norm: 5.0
lr:
learning_rate: 0.0001
regularizer:
name: 'L2'
factor: 0.00000
Architecture:
model_type: table
algorithm: TableAttn
Backbone:
name: MobileNetV3
scale: 1.0
model_name: large
Head:
name: TableAttentionHead # AttentionHead
hidden_size: 256 #
l2_decay: 0.00001
# loc_type: 1
loc_type: 2
Loss:
name: TableAttentionLoss
structure_weight: 100.0
loc_weight: 10000.0
PostProcess:
name: TableLabelDecode
Metric:
name: TableMetric
main_indicator: acc
Train:
dataset:
name: PubTabDataSet
data_dir: train_data/table/pubtabnet/train/
label_file_path: train_data/table/pubtabnet/PubTabNet_2.0.0_train.jsonl
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- ResizeTableImage:
max_len: 488
- TableLabelEncode:
- NormalizeImage:
scale: 1./255.
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: 'hwc'
- PaddingTableImage:
- ToCHWImage:
- KeepKeys:
keep_keys: ['image', 'structure', 'bbox_list', 'sp_tokens', 'bbox_list_mask']
loader:
shuffle: True
batch_size_per_card: 32
drop_last: True
num_workers: 4
Eval:
dataset:
name: PubTabDataSet
data_dir: train_data/table/pubtabnet/val/
label_file_path: train_data/table/pubtabnet/PubTabNet_2.0.0_val.jsonl
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- ResizeTableImage:
max_len: 488
- TableLabelEncode:
- NormalizeImage:
scale: 1./255.
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: 'hwc'
- PaddingTableImage:
- ToCHWImage:
- KeepKeys:
keep_keys: ['image', 'structure', 'bbox_list', 'sp_tokens', 'bbox_list_mask']
loader:
shuffle: False
drop_last: False
batch_size_per_card: 16
num_workers: 4
......@@ -35,6 +35,7 @@ from ppocr.data.imaug import transform, create_operators
from ppocr.data.simple_dataset import SimpleDataSet
from ppocr.data.lmdb_dataset import LMDBDataSet
from ppocr.data.pgnet_dataset import PGDataSet
from ppocr.data.pubtab_dataset import PubTabDataSet
__all__ = ['build_dataloader', 'transform', 'create_operators']
......@@ -55,7 +56,7 @@ signal.signal(signal.SIGTERM, term_mp)
def build_dataloader(config, mode, device, logger, seed=None):
config = copy.deepcopy(config)
support_dict = ['SimpleDataSet', 'LMDBDataSet', 'PGDataSet']
support_dict = ['SimpleDataSet', 'LMDBDataSet', 'PGDataSet', 'PubTabDataSet']
module_name = config[mode]['dataset']['name']
assert module_name in support_dict, Exception(
'DataSet only support {}'.format(support_dict))
......
......@@ -351,3 +351,182 @@ class SRNLabelEncode(BaseRecLabelEncode):
assert False, "Unsupport type %s in get_beg_end_flag_idx" \
% beg_or_end
return idx
class TableLabelEncode(object):
""" Convert between text-label and text-index """
def __init__(self,
max_text_length,
max_elem_length,
max_cell_num,
character_dict_path,
span_weight = 1.0,
**kwargs):
self.max_text_length = max_text_length
self.max_elem_length = max_elem_length
self.max_cell_num = max_cell_num
list_character, list_elem = self.load_char_elem_dict(character_dict_path)
list_character = self.add_special_char(list_character)
list_elem = self.add_special_char(list_elem)
self.dict_character = {}
for i, char in enumerate(list_character):
self.dict_character[char] = i
self.dict_elem = {}
for i, elem in enumerate(list_elem):
self.dict_elem[elem] = i
self.span_weight = span_weight
def load_char_elem_dict(self, character_dict_path):
list_character = []
list_elem = []
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
substr = lines[0].decode('utf-8').strip("\n").split("\t")
character_num = int(substr[0])
elem_num = int(substr[1])
for cno in range(1, 1+character_num):
character = lines[cno].decode('utf-8').strip("\n")
list_character.append(character)
for eno in range(1+character_num, 1+character_num+elem_num):
elem = lines[eno].decode('utf-8').strip("\n")
list_elem.append(elem)
return list_character, list_elem
def add_special_char(self, list_character):
self.beg_str = "sos"
self.end_str = "eos"
list_character = [self.beg_str] + list_character + [self.end_str]
return list_character
def get_span_idx_list(self):
span_idx_list = []
for elem in self.dict_elem:
if 'span' in elem:
span_idx_list.append(self.dict_elem[elem])
return span_idx_list
def __call__(self, data):
cells = data['cells']
structure = data['structure']['tokens']
structure = self.encode(structure, 'elem')
if structure is None:
return None
elem_num = len(structure)
structure = [0] + structure + [len(self.dict_elem) - 1]
# structure = [0] + structure + [0]
structure = structure + [0] * (self.max_elem_length + 2 - len(structure))
structure = np.array(structure)
data['structure'] = structure
elem_char_idx1 = self.dict_elem['<td>']
elem_char_idx2 = self.dict_elem['<td']
span_idx_list = self.get_span_idx_list()
td_idx_list = np.logical_or(structure == elem_char_idx1, structure == elem_char_idx2)
td_idx_list = np.where(td_idx_list)[0]
structure_mask = np.ones((self.max_elem_length + 2, 1), dtype=np.float32)
bbox_list = np.zeros((self.max_elem_length + 2, 4), dtype=np.float32)
bbox_list_mask = np.zeros((self.max_elem_length + 2, 1), dtype=np.float32)
img_height, img_width, img_ch = data['image'].shape
if len(span_idx_list) > 0:
span_weight = len(td_idx_list) * 1.0 / len(span_idx_list)
span_weight = min(max(span_weight, 1.0), self.span_weight)
for cno in range(len(cells)):
if 'bbox' in cells[cno]:
bbox = cells[cno]['bbox'].copy()
bbox[0] = bbox[0] * 1.0 / img_width
bbox[1] = bbox[1] * 1.0 / img_height
bbox[2] = bbox[2] * 1.0 / img_width
bbox[3] = bbox[3] * 1.0 / img_height
td_idx = td_idx_list[cno]
bbox_list[td_idx] = bbox
bbox_list_mask[td_idx] = 1.0
cand_span_idx = td_idx + 1
if cand_span_idx < (self.max_elem_length + 2):
if structure[cand_span_idx] in span_idx_list:
structure_mask[cand_span_idx] = span_weight
# structure_mask[td_idx] = self.span_weight
# structure_mask[cand_span_idx] = self.span_weight
data['bbox_list'] = bbox_list
data['bbox_list_mask'] = bbox_list_mask
data['structure_mask'] = structure_mask
char_beg_idx = self.get_beg_end_flag_idx('beg', 'char')
char_end_idx = self.get_beg_end_flag_idx('end', 'char')
elem_beg_idx = self.get_beg_end_flag_idx('beg', 'elem')
elem_end_idx = self.get_beg_end_flag_idx('end', 'elem')
data['sp_tokens'] = np.array([char_beg_idx, char_end_idx, elem_beg_idx,
elem_end_idx, elem_char_idx1, elem_char_idx2, self.max_text_length,
self.max_elem_length, self.max_cell_num, elem_num])
return data
########
# for char decode
# cell_list = []
# for cell in cells:
# char_list = cell['tokens']
# cell = self.encode(char_list, 'char')
# if cell is None:
# return None
# cell = [0] + cell + [len(self.dict_character) - 1]
# cell = cell + [0] * (self.max_text_length + 2 - len(cell))
# cell_list.append(cell)
# cell_list_padding = np.zeros((self.max_cell_num, self.max_text_length + 2))
# cell_list = np.array(cell_list)
# cell_list_padding[0:cell_list.shape[0]] = cell_list
# data['cells'] = cell_list_padding
# return data
def encode(self, text, char_or_elem):
"""convert text-label into text-index.
"""
if char_or_elem == "char":
max_len = self.max_text_length
current_dict = self.dict_character
else:
max_len = self.max_elem_length
current_dict = self.dict_elem
if len(text) > max_len:
return None
if len(text) == 0:
if char_or_elem == "char":
return [self.dict_character['space']]
else:
return None
text_list = []
for char in text:
if char not in current_dict:
return None
text_list.append(current_dict[char])
if len(text_list) == 0:
if char_or_elem == "char":
return [self.dict_character['space']]
else:
return None
return text_list
def get_ignored_tokens(self, char_or_elem):
beg_idx = self.get_beg_end_flag_idx("beg", char_or_elem)
end_idx = self.get_beg_end_flag_idx("end", char_or_elem)
return [beg_idx, end_idx]
def get_beg_end_flag_idx(self, beg_or_end, char_or_elem):
if char_or_elem == "char":
if beg_or_end == "beg":
idx = np.array(self.dict_character[self.beg_str])
elif beg_or_end == "end":
idx = np.array(self.dict_character[self.end_str])
else:
assert False, "Unsupport type %s in get_beg_end_flag_idx of char" \
% beg_or_end
elif char_or_elem == "elem":
if beg_or_end == "beg":
idx = np.array(self.dict_elem[self.beg_str])
elif beg_or_end == "end":
idx = np.array(self.dict_elem[self.end_str])
else:
assert False, "Unsupport type %s in get_beg_end_flag_idx of elem" \
% beg_or_end
else:
assert False, "Unsupport type %s in char_or_elem" \
% char_or_elem
return idx
\ No newline at end of file
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import os
import random
from paddle.io import Dataset
import json
from .imaug import transform, create_operators
class PubTabDataSet(Dataset):
def __init__(self, config, mode, logger, seed=None):
super(PubTabDataSet, self).__init__()
self.logger = logger
global_config = config['Global']
dataset_config = config[mode]['dataset']
loader_config = config[mode]['loader']
label_file_path = dataset_config.pop('label_file_path')
self.data_dir = dataset_config['data_dir']
self.do_shuffle = loader_config['shuffle']
self.do_hard_select = False
if 'hard_select' in loader_config:
self.do_hard_select = loader_config['hard_select']
self.hard_prob = loader_config['hard_prob']
if self.do_hard_select:
self.img_select_prob = self.load_hard_select_prob()
self.table_select_type = None
if 'table_select_type' in loader_config:
self.table_select_type = loader_config['table_select_type']
self.table_select_prob = loader_config['table_select_prob']
self.seed = seed
logger.info("Initialize indexs of datasets:%s" % label_file_path)
with open(label_file_path, "rb") as f:
self.data_lines = f.readlines()
self.data_idx_order_list = list(range(len(self.data_lines)))
if mode.lower() == "train":
self.shuffle_data_random()
self.ops = create_operators(dataset_config['transforms'], global_config)
def shuffle_data_random(self):
if self.do_shuffle:
random.seed(self.seed)
random.shuffle(self.data_lines)
return
def load_hard_select_prob(self):
label_path = "./pretrained_model/teds_score_exp5_st2_train.txt"
img_select_prob = {}
with open(label_path, "rb") as fin:
lines = fin.readlines()
for lno in range(len(lines)):
substr = lines[lno].decode('utf-8').strip("\n").split(" ")
img_name = substr[0].strip(":")
score = float(substr[1])
if score <= 0.8:
img_select_prob[img_name] = self.hard_prob[0]
elif score <= 0.98:
img_select_prob[img_name] = self.hard_prob[1]
else:
img_select_prob[img_name] = self.hard_prob[2]
return img_select_prob
def __getitem__(self, idx):
try:
data_line = self.data_lines[idx]
data_line = data_line.decode('utf-8').strip("\n")
info = json.loads(data_line)
file_name = info['filename']
select_flag = True
if self.do_hard_select:
prob = self.img_select_prob[file_name]
if prob < random.uniform(0, 1):
select_flag = False
if self.table_select_type:
structure = info['html']['structure']['tokens'].copy()
structure_str = ''.join(structure)
table_type = "simple"
if 'colspan' in structure_str or 'rowspan' in structure_str:
table_type = "complex"
# if self.table_select_type != table_type:
# select_flag = False
if table_type == "complex":
if self.table_select_prob < random.uniform(0, 1):
select_flag = False
if select_flag:
cells = info['html']['cells'].copy()
structure = info['html']['structure'].copy()
img_path = os.path.join(self.data_dir, file_name)
data = {'img_path': img_path, 'cells': cells, 'structure':structure}
if not os.path.exists(img_path):
raise Exception("{} does not exist!".format(img_path))
with open(data['img_path'], 'rb') as f:
img = f.read()
data['image'] = img
outs = transform(data, self.ops)
else:
outs = None
except Exception as e:
self.logger.error(
"When parsing line {}, error happened with msg: {}".format(
data_line, e))
outs = None
if outs is None:
return self.__getitem__(np.random.randint(self.__len__()))
return outs
def __len__(self):
return len(self.data_idx_order_list)
......@@ -38,11 +38,13 @@ from .basic_loss import DistanceLoss
# combined loss function
from .combined_loss import CombinedLoss
# table loss
from .table_att_loss import TableAttentionLoss
def build_loss(config):
support_dict = [
'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss',
'SRNLoss', 'PGLoss', 'CombinedLoss'
'SRNLoss', 'PGLoss', 'CombinedLoss', 'TableAttentionLoss'
]
config = copy.deepcopy(config)
module_name = config.pop('name')
......
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
from paddle import nn
from paddle.nn import functional as F
from paddle import fluid
class TableAttentionLoss(nn.Layer):
def __init__(self, structure_weight, loc_weight, use_giou=False, giou_weight=1.0, **kwargs):
super(TableAttentionLoss, self).__init__()
self.loss_func = nn.CrossEntropyLoss(weight=None, reduction='none')
self.structure_weight = structure_weight
self.loc_weight = loc_weight
self.use_giou = use_giou
self.giou_weight = giou_weight
def giou_loss(self, preds, bbox, eps=1e-7, reduction='mean'):
'''
:param preds:[[x1,y1,x2,y2], [x1,y1,x2,y2],,,]
:param bbox:[[x1,y1,x2,y2], [x1,y1,x2,y2],,,]
:return: loss
'''
ix1 = fluid.layers.elementwise_max(preds[:, 0], bbox[:, 0])
iy1 = fluid.layers.elementwise_max(preds[:, 1], bbox[:, 1])
ix2 = fluid.layers.elementwise_min(preds[:, 2], bbox[:, 2])
iy2 = fluid.layers.elementwise_min(preds[:, 3], bbox[:, 3])
iw = fluid.layers.clip(ix2 - ix1 + 1e-3, 0., 1e10)
ih = fluid.layers.clip(iy2 - iy1 + 1e-3, 0., 1e10)
# overlap
inters = iw * ih
# union
uni = (preds[:, 2] - preds[:, 0] + 1e-3) * (preds[:, 3] - preds[:, 1] + 1e-3
) + (bbox[:, 2] - bbox[:, 0] + 1e-3) * (
bbox[:, 3] - bbox[:, 1] + 1e-3) - inters + eps
# ious
ious = inters / uni
ex1 = fluid.layers.elementwise_min(preds[:, 0], bbox[:, 0])
ey1 = fluid.layers.elementwise_min(preds[:, 1], bbox[:, 1])
ex2 = fluid.layers.elementwise_max(preds[:, 2], bbox[:, 2])
ey2 = fluid.layers.elementwise_max(preds[:, 3], bbox[:, 3])
ew = fluid.layers.clip(ex2 - ex1 + 1e-3, 0., 1e10)
eh = fluid.layers.clip(ey2 - ey1 + 1e-3, 0., 1e10)
# enclose erea
enclose = ew * eh + eps
giou = ious - (enclose - uni) / enclose
loss = 1 - giou
if reduction == 'mean':
loss = paddle.mean(loss)
elif reduction == 'sum':
loss = paddle.sum(loss)
else:
raise NotImplementedError
return loss
def forward(self, predicts, batch):
structure_probs = predicts['structure_probs']
structure_targets = batch[1].astype("int64")
structure_targets = structure_targets[:, 1:]
if len(batch) == 6:
structure_mask = batch[5].astype("int64")
structure_mask = structure_mask[:, 1:]
structure_mask = paddle.reshape(structure_mask, [-1])
structure_probs = paddle.reshape(structure_probs, [-1, structure_probs.shape[-1]])
structure_targets = paddle.reshape(structure_targets, [-1])
structure_loss = self.loss_func(structure_probs, structure_targets)
if len(batch) == 6:
structure_loss = structure_loss * structure_mask
# structure_loss = paddle.sum(structure_loss) * self.structure_weight
structure_loss = paddle.mean(structure_loss) * self.structure_weight
loc_preds = predicts['loc_preds']
loc_targets = batch[2].astype("float32")
loc_targets_mask = batch[4].astype("float32")
loc_targets = loc_targets[:, 1:, :]
loc_targets_mask = loc_targets_mask[:, 1:, :]
loc_loss = F.mse_loss(loc_preds * loc_targets_mask, loc_targets) * self.loc_weight
if self.use_giou:
loc_loss_giou = self.giou_loss(loc_preds * loc_targets_mask, loc_targets) * self.giou_weight
total_loss = structure_loss + loc_loss + loc_loss_giou
return {'loss':total_loss, "structure_loss":structure_loss, "loc_loss":loc_loss, "loc_loss_giou":loc_loss_giou}
else:
total_loss = structure_loss + loc_loss
return {'loss':total_loss, "structure_loss":structure_loss, "loc_loss":loc_loss}
\ No newline at end of file
......@@ -26,11 +26,11 @@ from .rec_metric import RecMetric
from .cls_metric import ClsMetric
from .e2e_metric import E2EMetric
from .distillation_metric import DistillationMetric
from .table_metric import TableMetric
def build_metric(config):
support_dict = [
"DetMetric", "RecMetric", "ClsMetric", "E2EMetric", "DistillationMetric"
"DetMetric", "RecMetric", "ClsMetric", "E2EMetric", "DistillationMetric", "TableMetric"
]
config = copy.deepcopy(config)
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
class TableMetric(object):
def __init__(self, main_indicator='acc', **kwargs):
self.main_indicator = main_indicator
self.reset()
def __call__(self, pred, batch, *args, **kwargs):
structure_probs = pred['structure_probs'].numpy()
structure_labels = batch[1]
correct_num = 0
all_num = 0
structure_probs = np.argmax(structure_probs, axis=2)
structure_labels = structure_labels[:, 1:]
batch_size = structure_probs.shape[0]
for bno in range(batch_size):
all_num += 1
if (structure_probs[bno] == structure_labels[bno]).all():
correct_num += 1
self.correct_num += correct_num
self.all_num += all_num
return {
'acc': correct_num * 1.0 / all_num,
}
def get_metric(self):
"""
return metrics {
'acc': 0,
}
"""
acc = 1.0 * self.correct_num / self.all_num
self.reset()
return {'acc': acc}
def reset(self):
self.correct_num = 0
self.all_num = 0
......@@ -69,7 +69,7 @@ class BaseModel(nn.Layer):
self.return_all_feats = config.get("return_all_feats", False)
def forward(self, x, data=None):
def forward(self, x, data=None, mode='Train'):
y = dict()
if self.use_transform:
x = self.transform(x)
......@@ -81,7 +81,10 @@ class BaseModel(nn.Layer):
if data is None:
x = self.head(x)
else:
x = self.head(x, data)
if mode == 'Eval' or mode == 'Test':
x = self.head(x, targets=data, mode=mode)
else:
x = self.head(x, targets=data)
y["head_out"] = x
if self.return_all_feats:
return y
......
......@@ -29,6 +29,10 @@ def build_backbone(config, model_type):
elif model_type == 'e2e':
from .e2e_resnet_vd_pg import ResNet
support_dict = ['ResNet']
elif model_type == "table":
from .table_resnet_vd import ResNet
from .table_mobilenet_v3 import MobileNetV3
support_dict = ['ResNet', 'MobileNetV3']
else:
raise NotImplementedError
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
from paddle import nn
import paddle.nn.functional as F
from paddle import ParamAttr
__all__ = ['MobileNetV3']
def make_divisible(v, divisor=8, min_value=None):
if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
if new_v < 0.9 * v:
new_v += divisor
return new_v
class MobileNetV3(nn.Layer):
def __init__(self,
in_channels=3,
model_name='large',
scale=0.5,
disable_se=False,
**kwargs):
"""
the MobilenetV3 backbone network for detection module.
Args:
params(dict): the super parameters for build network
"""
super(MobileNetV3, self).__init__()
self.disable_se = disable_se
if model_name == "large":
cfg = [
# k, exp, c, se, nl, s,
[3, 16, 16, False, 'relu', 1],
[3, 64, 24, False, 'relu', 2],
[3, 72, 24, False, 'relu', 1],
[5, 72, 40, True, 'relu', 2],
[5, 120, 40, True, 'relu', 1],
[5, 120, 40, True, 'relu', 1],
[3, 240, 80, False, 'hardswish', 2],
[3, 200, 80, False, 'hardswish', 1],
[3, 184, 80, False, 'hardswish', 1],
[3, 184, 80, False, 'hardswish', 1],
[3, 480, 112, True, 'hardswish', 1],
[3, 672, 112, True, 'hardswish', 1],
[5, 672, 160, True, 'hardswish', 2],
[5, 960, 160, True, 'hardswish', 1],
[5, 960, 160, True, 'hardswish', 1],
]
cls_ch_squeeze = 960
elif model_name == "small":
cfg = [
# k, exp, c, se, nl, s,
[3, 16, 16, True, 'relu', 2],
[3, 72, 24, False, 'relu', 2],
[3, 88, 24, False, 'relu', 1],
[5, 96, 40, True, 'hardswish', 2],
[5, 240, 40, True, 'hardswish', 1],
[5, 240, 40, True, 'hardswish', 1],
[5, 120, 48, True, 'hardswish', 1],
[5, 144, 48, True, 'hardswish', 1],
[5, 288, 96, True, 'hardswish', 2],
[5, 576, 96, True, 'hardswish', 1],
[5, 576, 96, True, 'hardswish', 1],
]
cls_ch_squeeze = 576
else:
raise NotImplementedError("mode[" + model_name +
"_model] is not implemented!")
supported_scale = [0.35, 0.5, 0.75, 1.0, 1.25]
assert scale in supported_scale, \
"supported scale are {} but input scale is {}".format(supported_scale, scale)
inplanes = 16
# conv1
self.conv = ConvBNLayer(
in_channels=in_channels,
out_channels=make_divisible(inplanes * scale),
kernel_size=3,
stride=2,
padding=1,
groups=1,
if_act=True,
act='hardswish',
name='conv1')
self.stages = []
self.out_channels = []
block_list = []
i = 0
inplanes = make_divisible(inplanes * scale)
for (k, exp, c, se, nl, s) in cfg:
se = se and not self.disable_se
start_idx = 2 if model_name == 'large' else 0
if s == 2 and i > start_idx:
self.out_channels.append(inplanes)
self.stages.append(nn.Sequential(*block_list))
block_list = []
block_list.append(
ResidualUnit(
in_channels=inplanes,
mid_channels=make_divisible(scale * exp),
out_channels=make_divisible(scale * c),
kernel_size=k,
stride=s,
use_se=se,
act=nl,
name="conv" + str(i + 2)))
inplanes = make_divisible(scale * c)
i += 1
block_list.append(
ConvBNLayer(
in_channels=inplanes,
out_channels=make_divisible(scale * cls_ch_squeeze),
kernel_size=1,
stride=1,
padding=0,
groups=1,
if_act=True,
act='hardswish',
name='conv_last'))
self.stages.append(nn.Sequential(*block_list))
self.out_channels.append(make_divisible(scale * cls_ch_squeeze))
for i, stage in enumerate(self.stages):
self.add_sublayer(sublayer=stage, name="stage{}".format(i))
def forward(self, x):
x = self.conv(x)
out_list = []
for stage in self.stages:
x = stage(x)
out_list.append(x)
return out_list
class ConvBNLayer(nn.Layer):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride,
padding,
groups=1,
if_act=True,
act=None,
name=None):
super(ConvBNLayer, self).__init__()
self.if_act = if_act
self.act = act
self.conv = nn.Conv2D(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
groups=groups,
weight_attr=ParamAttr(name=name + '_weights'),
bias_attr=False)
self.bn = nn.BatchNorm(
num_channels=out_channels,
act=None,
param_attr=ParamAttr(name=name + "_bn_scale"),
bias_attr=ParamAttr(name=name + "_bn_offset"),
moving_mean_name=name + "_bn_mean",
moving_variance_name=name + "_bn_variance")
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
if self.if_act:
if self.act == "relu":
x = F.relu(x)
elif self.act == "hardswish":
x = F.hardswish(x)
else:
print("The activation function({}) is selected incorrectly.".
format(self.act))
exit()
return x
class ResidualUnit(nn.Layer):
def __init__(self,
in_channels,
mid_channels,
out_channels,
kernel_size,
stride,
use_se,
act=None,
name=''):
super(ResidualUnit, self).__init__()
self.if_shortcut = stride == 1 and in_channels == out_channels
self.if_se = use_se
self.expand_conv = ConvBNLayer(
in_channels=in_channels,
out_channels=mid_channels,
kernel_size=1,
stride=1,
padding=0,
if_act=True,
act=act,
name=name + "_expand")
self.bottleneck_conv = ConvBNLayer(
in_channels=mid_channels,
out_channels=mid_channels,
kernel_size=kernel_size,
stride=stride,
padding=int((kernel_size - 1) // 2),
groups=mid_channels,
if_act=True,
act=act,
name=name + "_depthwise")
if self.if_se:
self.mid_se = SEModule(mid_channels, name=name + "_se")
self.linear_conv = ConvBNLayer(
in_channels=mid_channels,
out_channels=out_channels,
kernel_size=1,
stride=1,
padding=0,
if_act=False,
act=None,
name=name + "_linear")
def forward(self, inputs):
x = self.expand_conv(inputs)
x = self.bottleneck_conv(x)
if self.if_se:
x = self.mid_se(x)
x = self.linear_conv(x)
if self.if_shortcut:
x = paddle.add(inputs, x)
return x
class SEModule(nn.Layer):
def __init__(self, in_channels, reduction=4, name=""):
super(SEModule, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2D(1)
self.conv1 = nn.Conv2D(
in_channels=in_channels,
out_channels=in_channels // reduction,
kernel_size=1,
stride=1,
padding=0,
weight_attr=ParamAttr(name=name + "_1_weights"),
bias_attr=ParamAttr(name=name + "_1_offset"))
self.conv2 = nn.Conv2D(
in_channels=in_channels // reduction,
out_channels=in_channels,
kernel_size=1,
stride=1,
padding=0,
weight_attr=ParamAttr(name + "_2_weights"),
bias_attr=ParamAttr(name=name + "_2_offset"))
def forward(self, inputs):
outputs = self.avg_pool(inputs)
outputs = self.conv1(outputs)
outputs = F.relu(outputs)
outputs = self.conv2(outputs)
outputs = F.hardsigmoid(outputs, slope=0.2, offset=0.5)
return inputs * outputs
\ No newline at end of file
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
from paddle import ParamAttr
import paddle.nn as nn
import paddle.nn.functional as F
__all__ = ["ResNet"]
class ConvBNLayer(nn.Layer):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
groups=1,
is_vd_mode=False,
act=None,
name=None, ):
super(ConvBNLayer, self).__init__()
self.is_vd_mode = is_vd_mode
self._pool2d_avg = nn.AvgPool2D(
kernel_size=2, stride=2, padding=0, ceil_mode=True)
self._conv = nn.Conv2D(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=(kernel_size - 1) // 2,
groups=groups,
weight_attr=ParamAttr(name=name + "_weights"),
bias_attr=False)
if name == "conv1":
bn_name = "bn_" + name
else:
bn_name = "bn" + name[3:]
self._batch_norm = nn.BatchNorm(
out_channels,
act=act,
param_attr=ParamAttr(name=bn_name + '_scale'),
bias_attr=ParamAttr(bn_name + '_offset'),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance')
def forward(self, inputs):
if self.is_vd_mode:
inputs = self._pool2d_avg(inputs)
y = self._conv(inputs)
y = self._batch_norm(y)
return y
class BottleneckBlock(nn.Layer):
def __init__(self,
in_channels,
out_channels,
stride,
shortcut=True,
if_first=False,
name=None):
super(BottleneckBlock, self).__init__()
self.conv0 = ConvBNLayer(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
act='relu',
name=name + "_branch2a")
self.conv1 = ConvBNLayer(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=3,
stride=stride,
act='relu',
name=name + "_branch2b")
self.conv2 = ConvBNLayer(
in_channels=out_channels,
out_channels=out_channels * 4,
kernel_size=1,
act=None,
name=name + "_branch2c")
if not shortcut:
self.short = ConvBNLayer(
in_channels=in_channels,
out_channels=out_channels * 4,
kernel_size=1,
stride=1,
is_vd_mode=False if if_first else True,
name=name + "_branch1")
self.shortcut = shortcut
def forward(self, inputs):
y = self.conv0(inputs)
conv1 = self.conv1(y)
conv2 = self.conv2(conv1)
if self.shortcut:
short = inputs
else:
short = self.short(inputs)
y = paddle.add(x=short, y=conv2)
y = F.relu(y)
return y
class BasicBlock(nn.Layer):
def __init__(self,
in_channels,
out_channels,
stride,
shortcut=True,
if_first=False,
name=None):
super(BasicBlock, self).__init__()
self.stride = stride
self.conv0 = ConvBNLayer(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
stride=stride,
act='relu',
name=name + "_branch2a")
self.conv1 = ConvBNLayer(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=3,
act=None,
name=name + "_branch2b")
if not shortcut:
self.short = ConvBNLayer(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
stride=1,
is_vd_mode=False if if_first else True,
name=name + "_branch1")
self.shortcut = shortcut
def forward(self, inputs):
y = self.conv0(inputs)
conv1 = self.conv1(y)
if self.shortcut:
short = inputs
else:
short = self.short(inputs)
y = paddle.add(x=short, y=conv1)
y = F.relu(y)
return y
class ResNet(nn.Layer):
def __init__(self, in_channels=3, layers=50, **kwargs):
super(ResNet, self).__init__()
self.layers = layers
supported_layers = [18, 34, 50, 101, 152, 200]
assert layers in supported_layers, \
"supported layers are {} but input layer is {}".format(
supported_layers, layers)
if layers == 18:
depth = [2, 2, 2, 2]
elif layers == 34 or layers == 50:
depth = [3, 4, 6, 3]
elif layers == 101:
depth = [3, 4, 23, 3]
elif layers == 152:
depth = [3, 8, 36, 3]
elif layers == 200:
depth = [3, 12, 48, 3]
num_channels = [64, 256, 512,
1024] if layers >= 50 else [64, 64, 128, 256]
num_filters = [64, 128, 256, 512]
self.conv1_1 = ConvBNLayer(
in_channels=in_channels,
out_channels=32,
kernel_size=3,
stride=2,
act='relu',
name="conv1_1")
self.conv1_2 = ConvBNLayer(
in_channels=32,
out_channels=32,
kernel_size=3,
stride=1,
act='relu',
name="conv1_2")
self.conv1_3 = ConvBNLayer(
in_channels=32,
out_channels=64,
kernel_size=3,
stride=1,
act='relu',
name="conv1_3")
self.pool2d_max = nn.MaxPool2D(kernel_size=3, stride=2, padding=1)
self.stages = []
self.out_channels = []
if layers >= 50:
for block in range(len(depth)):
block_list = []
shortcut = False
for i in range(depth[block]):
if layers in [101, 152] and block == 2:
if i == 0:
conv_name = "res" + str(block + 2) + "a"
else:
conv_name = "res" + str(block + 2) + "b" + str(i)
else:
conv_name = "res" + str(block + 2) + chr(97 + i)
bottleneck_block = self.add_sublayer(
'bb_%d_%d' % (block, i),
BottleneckBlock(
in_channels=num_channels[block]
if i == 0 else num_filters[block] * 4,
out_channels=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
shortcut=shortcut,
if_first=block == i == 0,
name=conv_name))
shortcut = True
block_list.append(bottleneck_block)
self.out_channels.append(num_filters[block] * 4)
self.stages.append(nn.Sequential(*block_list))
else:
for block in range(len(depth)):
block_list = []
shortcut = False
for i in range(depth[block]):
conv_name = "res" + str(block + 2) + chr(97 + i)
basic_block = self.add_sublayer(
'bb_%d_%d' % (block, i),
BasicBlock(
in_channels=num_channels[block]
if i == 0 else num_filters[block],
out_channels=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
shortcut=shortcut,
if_first=block == i == 0,
name=conv_name))
shortcut = True
block_list.append(basic_block)
self.out_channels.append(num_filters[block])
self.stages.append(nn.Sequential(*block_list))
def forward(self, inputs):
y = self.conv1_1(inputs)
y = self.conv1_2(y)
y = self.conv1_3(y)
y = self.pool2d_max(y)
out = []
for block in self.stages:
y = block(y)
out.append(y)
return out
......@@ -31,8 +31,10 @@ def build_head(config):
from .cls_head import ClsHead
support_dict = [
'DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead', 'AttentionHead',
'SRNHead', 'PGHead']
'SRNHead', 'PGHead', 'TableAttentionHead']
#table head
from .table_att_head import TableAttentionHead
module_name = config.pop('name')
assert module_name in support_dict, Exception('head only support {}'.format(
......
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import numpy as np
class TableAttentionHead(nn.Layer):
def __init__(self, in_channels, hidden_size, loc_type, in_max_len=488, **kwargs):
super(TableAttentionHead, self).__init__()
self.input_size = in_channels[-1]
self.hidden_size = hidden_size
self.char_num = 280
self.elem_num = 30
self.structure_attention_cell = AttentionGRUCell(
self.input_size, hidden_size, self.elem_num, use_gru=False)
self.structure_generator = nn.Linear(hidden_size, self.elem_num)
self.loc_type = loc_type
self.in_max_len = in_max_len
if self.loc_type == 1:
self.loc_generator = nn.Linear(hidden_size, 4)
else:
if self.in_max_len == 640:
self.loc_fea_trans = nn.Linear(400, 801)
elif self.in_max_len == 800:
self.loc_fea_trans = nn.Linear(625, 801)
else:
self.loc_fea_trans = nn.Linear(256, 801)
self.loc_generator = nn.Linear(self.input_size + hidden_size, 4)
def _char_to_onehot(self, input_char, onehot_dim):
input_ont_hot = F.one_hot(input_char, onehot_dim)
return input_ont_hot
def forward(self, inputs, targets=None, mode='Train'):
# if and else branch are both needed when you want to assign a variable
# if you modify the var in just one branch, then the modification will not work.
fea = inputs[-1]
if len(fea.shape) == 3:
pass
else:
last_shape = int(np.prod(fea.shape[2:])) # gry added
fea = paddle.reshape(fea, [fea.shape[0], fea.shape[1], last_shape])
fea = fea.transpose([0, 2, 1]) # (NTC)(batch, width, channels)
batch_size = fea.shape[0]
#sp_tokens = targets[2].numpy()
#char_beg_idx, char_end_idx = sp_tokens[0, 0:2]
#elem_beg_idx, elem_end_idx = sp_tokens[0, 2:4]
#elem_char_idx1, elem_char_idx2 = sp_tokens[0, 4:6]
#max_text_length, max_elem_length, max_cell_num = sp_tokens[0, 6:9]
max_text_length, max_elem_length, max_cell_num = 100, 800, 500
hidden = paddle.zeros((batch_size, self.hidden_size))
output_hiddens = []
if mode == 'Train' and targets is not None:
structure = targets[0]
for i in range(max_elem_length+1):
elem_onehots = self._char_to_onehot(
structure[:, i], onehot_dim=self.elem_num)
(outputs, hidden), alpha = self.structure_attention_cell(
hidden, fea, elem_onehots)
output_hiddens.append(paddle.unsqueeze(outputs, axis=1))
output = paddle.concat(output_hiddens, axis=1)
structure_probs = self.structure_generator(output)
if self.loc_type == 1:
loc_preds = self.loc_generator(output)
loc_preds = F.sigmoid(loc_preds)
else:
loc_fea = fea.transpose([0, 2, 1])
loc_fea = self.loc_fea_trans(loc_fea)
loc_fea = loc_fea.transpose([0, 2, 1])
loc_concat = paddle.concat([output, loc_fea], axis=2)
loc_preds = self.loc_generator(loc_concat)
loc_preds = F.sigmoid(loc_preds)
else:
temp_elem = paddle.zeros(shape=[batch_size], dtype="int32")
structure_probs = None
loc_preds = None
elem_onehots = None
outputs = None
alpha = None
max_elem_length = paddle.to_tensor(max_elem_length)
i = 0
while i < max_elem_length+1:
elem_onehots = self._char_to_onehot(
temp_elem, onehot_dim=self.elem_num)
(outputs, hidden), alpha = self.structure_attention_cell(
hidden, fea, elem_onehots)
output_hiddens.append(paddle.unsqueeze(outputs, axis=1))
structure_probs_step = self.structure_generator(outputs)
temp_elem = structure_probs_step.argmax(axis=1, dtype="int32")
i += 1
output = paddle.concat(output_hiddens, axis=1)
structure_probs = self.structure_generator(output)
structure_probs = F.softmax(structure_probs)
if self.loc_type == 1:
loc_preds = self.loc_generator(output)
loc_preds = F.sigmoid(loc_preds)
else:
loc_fea = fea.transpose([0, 2, 1])
loc_fea = self.loc_fea_trans(loc_fea)
loc_fea = loc_fea.transpose([0, 2, 1])
loc_concat = paddle.concat([output, loc_fea], axis=2)
loc_preds = self.loc_generator(loc_concat)
loc_preds = F.sigmoid(loc_preds)
return {'structure_probs':structure_probs, 'loc_preds':loc_preds}
class AttentionGRUCell(nn.Layer):
def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False):
super(AttentionGRUCell, self).__init__()
self.i2h = nn.Linear(input_size, hidden_size, bias_attr=False)
self.h2h = nn.Linear(hidden_size, hidden_size)
self.score = nn.Linear(hidden_size, 1, bias_attr=False)
self.rnn = nn.GRUCell(
input_size=input_size + num_embeddings, hidden_size=hidden_size)
self.hidden_size = hidden_size
def forward(self, prev_hidden, batch_H, char_onehots):
batch_H_proj = self.i2h(batch_H)
prev_hidden_proj = paddle.unsqueeze(self.h2h(prev_hidden), axis=1)
res = paddle.add(batch_H_proj, prev_hidden_proj)
res = paddle.tanh(res)
e = self.score(res)
alpha = F.softmax(e, axis=1)
alpha = paddle.transpose(alpha, [0, 2, 1])
context = paddle.squeeze(paddle.mm(alpha, batch_H), axis=1)
concat_context = paddle.concat([context, char_onehots], 1)
cur_hidden = self.rnn(concat_context, prev_hidden)
return cur_hidden, alpha
class AttentionLSTM(nn.Layer):
def __init__(self, in_channels, out_channels, hidden_size, **kwargs):
super(AttentionLSTM, self).__init__()
self.input_size = in_channels
self.hidden_size = hidden_size
self.num_classes = out_channels
self.attention_cell = AttentionLSTMCell(
in_channels, hidden_size, out_channels, use_gru=False)
self.generator = nn.Linear(hidden_size, out_channels)
def _char_to_onehot(self, input_char, onehot_dim):
input_ont_hot = F.one_hot(input_char, onehot_dim)
return input_ont_hot
def forward(self, inputs, targets=None, batch_max_length=25):
batch_size = inputs.shape[0]
num_steps = batch_max_length
hidden = (paddle.zeros((batch_size, self.hidden_size)), paddle.zeros(
(batch_size, self.hidden_size)))
output_hiddens = []
if targets is not None:
for i in range(num_steps):
# one-hot vectors for a i-th char
char_onehots = self._char_to_onehot(
targets[:, i], onehot_dim=self.num_classes)
hidden, alpha = self.attention_cell(hidden, inputs,
char_onehots)
hidden = (hidden[1][0], hidden[1][1])
output_hiddens.append(paddle.unsqueeze(hidden[0], axis=1))
output = paddle.concat(output_hiddens, axis=1)
probs = self.generator(output)
else:
targets = paddle.zeros(shape=[batch_size], dtype="int32")
probs = None
for i in range(num_steps):
char_onehots = self._char_to_onehot(
targets, onehot_dim=self.num_classes)
hidden, alpha = self.attention_cell(hidden, inputs,
char_onehots)
probs_step = self.generator(hidden[0])
hidden = (hidden[1][0], hidden[1][1])
if probs is None:
probs = paddle.unsqueeze(probs_step, axis=1)
else:
probs = paddle.concat(
[probs, paddle.unsqueeze(
probs_step, axis=1)], axis=1)
next_input = probs_step.argmax(axis=1)
targets = next_input
return probs
class AttentionLSTMCell(nn.Layer):
def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False):
super(AttentionLSTMCell, self).__init__()
self.i2h = nn.Linear(input_size, hidden_size, bias_attr=False)
self.h2h = nn.Linear(hidden_size, hidden_size)
self.score = nn.Linear(hidden_size, 1, bias_attr=False)
if not use_gru:
self.rnn = nn.LSTMCell(
input_size=input_size + num_embeddings, hidden_size=hidden_size)
else:
self.rnn = nn.GRUCell(
input_size=input_size + num_embeddings, hidden_size=hidden_size)
self.hidden_size = hidden_size
def forward(self, prev_hidden, batch_H, char_onehots):
batch_H_proj = self.i2h(batch_H)
prev_hidden_proj = paddle.unsqueeze(self.h2h(prev_hidden[0]), axis=1)
res = paddle.add(batch_H_proj, prev_hidden_proj)
res = paddle.tanh(res)
e = self.score(res)
alpha = F.softmax(e, axis=1)
alpha = paddle.transpose(alpha, [0, 2, 1])
context = paddle.squeeze(paddle.mm(alpha, batch_H), axis=1)
concat_context = paddle.concat([context, char_onehots], 1)
cur_hidden = self.rnn(concat_context, prev_hidden)
return cur_hidden, alpha
......@@ -21,7 +21,8 @@ def build_neck(config):
from .sast_fpn import SASTFPN
from .rnn import SequenceEncoder
from .pg_fpn import PGFPN
support_dict = ['DBFPN', 'EASTFPN', 'SASTFPN', 'SequenceEncoder', 'PGFPN']
from .table_fpn import TableFPN
support_dict = ['DBFPN', 'EASTFPN', 'SASTFPN', 'SequenceEncoder', 'PGFPN', 'TableFPN']
module_name = config.pop('name')
assert module_name in support_dict, Exception('neck only support {}'.format(
......
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
from paddle import nn
import paddle.nn.functional as F
from paddle import ParamAttr
class TableFPN(nn.Layer):
def __init__(self, in_channels, out_channels, **kwargs):
super(TableFPN, self).__init__()
self.out_channels = 512
weight_attr = paddle.nn.initializer.KaimingUniform()
self.in2_conv = nn.Conv2D(
in_channels=in_channels[0],
out_channels=self.out_channels,
kernel_size=1,
weight_attr=ParamAttr(
name='conv2d_51.w_0', initializer=weight_attr),
bias_attr=False)
self.in3_conv = nn.Conv2D(
in_channels=in_channels[1],
out_channels=self.out_channels,
kernel_size=1,
stride = 1,
weight_attr=ParamAttr(
name='conv2d_50.w_0', initializer=weight_attr),
bias_attr=False)
self.in4_conv = nn.Conv2D(
in_channels=in_channels[2],
out_channels=self.out_channels,
kernel_size=1,
weight_attr=ParamAttr(
name='conv2d_49.w_0', initializer=weight_attr),
bias_attr=False)
self.in5_conv = nn.Conv2D(
in_channels=in_channels[3],
out_channels=self.out_channels,
kernel_size=1,
weight_attr=ParamAttr(
name='conv2d_48.w_0', initializer=weight_attr),
bias_attr=False)
self.p5_conv = nn.Conv2D(
in_channels=self.out_channels,
out_channels=self.out_channels // 4,
kernel_size=3,
padding=1,
weight_attr=ParamAttr(
name='conv2d_52.w_0', initializer=weight_attr),
bias_attr=False)
self.p4_conv = nn.Conv2D(
in_channels=self.out_channels,
out_channels=self.out_channels // 4,
kernel_size=3,
padding=1,
weight_attr=ParamAttr(
name='conv2d_53.w_0', initializer=weight_attr),
bias_attr=False)
self.p3_conv = nn.Conv2D(
in_channels=self.out_channels,
out_channels=self.out_channels // 4,
kernel_size=3,
padding=1,
weight_attr=ParamAttr(
name='conv2d_54.w_0', initializer=weight_attr),
bias_attr=False)
self.p2_conv = nn.Conv2D(
in_channels=self.out_channels,
out_channels=self.out_channels // 4,
kernel_size=3,
padding=1,
weight_attr=ParamAttr(
name='conv2d_55.w_0', initializer=weight_attr),
bias_attr=False)
self.fuse_conv = nn.Conv2D(
in_channels=self.out_channels * 4,
out_channels=512,
kernel_size=3,
padding=1,
weight_attr=ParamAttr(
name='conv2d_fuse.w_0', initializer=weight_attr), bias_attr=False)
def forward(self, x):
c2, c3, c4, c5 = x
in5 = self.in5_conv(c5)
in4 = self.in4_conv(c4)
in3 = self.in3_conv(c3)
in2 = self.in2_conv(c2)
out4 = in4 + F.upsample(
in5, size=in4.shape[2:4], mode="nearest", align_mode=1) # 1/16
out3 = in3 + F.upsample(
out4, size=in3.shape[2:4], mode="nearest", align_mode=1) # 1/8
out2 = in2 + F.upsample(
out3, size=in2.shape[2:4], mode="nearest", align_mode=1) # 1/4
p4 = F.upsample(out4, size=in5.shape[2:4], mode="nearest", align_mode=1)
p3 = F.upsample(out3, size=in5.shape[2:4], mode="nearest", align_mode=1)
p2 = F.upsample(out2, size=in5.shape[2:4], mode="nearest", align_mode=1)
fuse = paddle.concat([in5, p4, p3, p2], axis=1)
fuse_conv = self.fuse_conv(fuse) * 0.005
return [c5 + fuse_conv]
......@@ -325,8 +325,14 @@ class TableLabelDecode(object):
""" """
def __init__(self,
max_text_length,
max_elem_length,
max_cell_num,
character_dict_path,
**kwargs):
self.max_text_length = max_text_length
self.max_elem_length = max_elem_length
self.max_cell_num = max_cell_num
list_character, list_elem = self.load_char_elem_dict(character_dict_path)
list_character = self.add_special_char(list_character)
list_elem = self.add_special_char(list_elem)
......@@ -363,6 +369,18 @@ class TableLabelDecode(object):
list_character = [self.beg_str] + list_character + [self.end_str]
return list_character
def get_sp_tokens(self):
char_beg_idx = self.get_beg_end_flag_idx('beg', 'char')
char_end_idx = self.get_beg_end_flag_idx('end', 'char')
elem_beg_idx = self.get_beg_end_flag_idx('beg', 'elem')
elem_end_idx = self.get_beg_end_flag_idx('end', 'elem')
elem_char_idx1 = self.dict_elem['<td>']
elem_char_idx2 = self.dict_elem['<td']
sp_tokens = np.array([char_beg_idx, char_end_idx, elem_beg_idx,
elem_end_idx, elem_char_idx1, elem_char_idx2, self.max_text_length,
self.max_elem_length, self.max_cell_num])
return sp_tokens
def __call__(self, preds):
structure_probs = preds['structure_probs']
loc_preds = preds['loc_preds']
......
......@@ -48,6 +48,7 @@ def main():
getattr(post_process_class, 'character'))
model = build_model(config['Architecture'])
use_srn = config['Architecture']['algorithm'] == "SRN"
model_type = config['Architecture']['model_type']
best_model_dict = init_model(config, model)
if len(best_model_dict):
......@@ -60,7 +61,7 @@ def main():
# start eval
metric = program.eval(model, valid_dataloader, post_process_class,
eval_class, use_srn)
eval_class, model_type, use_srn)
logger.info('metric eval ***************')
for k, v in metric.items():
logger.info('{}:{}'.format(k, v))
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import os
import sys
import json
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
import paddle
from paddle.jit import to_static
from ppocr.data import create_operators, transform
from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process
from ppocr.utils.save_load import init_model
from ppocr.utils.utility import get_image_file_list
import tools.program as program
import cv2
def main(config, device, logger, vdl_writer):
global_config = config['Global']
# build post process
post_process_class = build_post_process(config['PostProcess'],
global_config)
# build model
if hasattr(post_process_class, 'character'):
config['Architecture']["Head"]['out_channels'] = len(
getattr(post_process_class, 'character'))
model = build_model(config['Architecture'])
init_model(config, model, logger)
# create data ops
transforms = []
use_padding = False
for op in config['Eval']['dataset']['transforms']:
op_name = list(op)[0]
if 'Label' in op_name:
continue
if op_name == 'KeepKeys':
op[op_name]['keep_keys'] = ['image']
if op_name == "ResizeTableImage":
use_padding = True
padding_max_len = op['ResizeTableImage']['max_len']
transforms.append(op)
global_config['infer_mode'] = True
ops = create_operators(transforms, global_config)
model.eval()
for file in get_image_file_list(config['Global']['infer_img']):
logger.info("infer_img: {}".format(file))
with open(file, 'rb') as f:
img = f.read()
data = {'image': img}
batch = transform(data, ops)
sp_tokens = post_process_class.get_sp_tokens()
targets = [[], [], paddle.to_tensor([sp_tokens])]
images = np.expand_dims(batch[0], axis=0)
images = paddle.to_tensor(images)
preds = model(images, data=targets, mode='Test')
post_result = post_process_class(preds)
res_html_code = post_result['res_html_code']
res_loc = post_result['res_loc']
img = cv2.imread(file)
imgh, imgw = img.shape[0:2]
res_loc_final = []
for rno in range(len(res_loc[0])):
x0, y0, x1, y1 = res_loc[0][rno]
left = max(int(imgw * x0), 0)
top = max(int(imgh * y0), 0)
right = min(int(imgw * x1), imgw - 1)
bottom = min(int(imgh * y1), imgh - 1)
cv2.rectangle(img, (left, top), (right, bottom), (0, 0, 255), 2)
res_loc_final.append([left, top, right, bottom])
res_loc_str = json.dumps(res_loc_final)
logger.info("result: {}, {}".format(res_html_code, res_loc_final))
logger.info("success!")
if __name__ == '__main__':
config, device, logger, vdl_writer = program.preprocess()
main(config, device, logger, vdl_writer)
......@@ -186,7 +186,8 @@ def train(config,
model.train()
use_srn = config['Architecture']['algorithm'] == "SRN"
model_type = config['Architecture']['model_type']
if 'start_epoch' in best_model_dict:
start_epoch = best_model_dict['start_epoch']
else:
......@@ -211,6 +212,9 @@ def train(config,
others = batch[-4:]
preds = model(images, others)
model_average = True
elif model_type == "table":
others = batch[1:]
preds = model(images, others)
else:
preds = model(images)
loss = loss_class(preds, batch)
......@@ -232,8 +236,11 @@ def train(config,
if cal_metric_during_train: # only rec and cls need
batch = [item.numpy() for item in batch]
post_result = post_process_class(preds, batch[1])
eval_class(post_result, batch)
if model_type == 'table':
eval_class(preds, batch)
else:
post_result = post_process_class(preds, batch[1])
eval_class(post_result, batch)
metric = eval_class.get_metric()
train_stats.update(metric)
......@@ -337,7 +344,7 @@ def train(config,
def eval(model, valid_dataloader, post_process_class, eval_class,
use_srn=False):
model_type, use_srn=False):
model.eval()
with paddle.no_grad():
total_frame = 0.0
......@@ -359,10 +366,13 @@ def eval(model, valid_dataloader, post_process_class, eval_class,
batch = [item.numpy() for item in batch]
# Obtain usable results from post-processing methods
post_result = post_process_class(preds, batch[1])
total_time += time.time() - start
# Evaluate the results of the current batch
eval_class(post_result, batch)
if model_type == 'table':
eval_class(preds, batch)
else:
post_result = post_process_class(preds, batch[1])
eval_class(post_result, batch)
pbar.update(1)
total_frame += len(images)
# Get final metric,eg. acc or hmean
......@@ -386,7 +396,7 @@ def preprocess(is_train=False):
alg = config['Architecture']['algorithm']
assert alg in [
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
'CLS', 'PGNet', 'Distillation'
'CLS', 'PGNet', 'Distillation', 'TableAttn'
]
device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册