diff --git a/configs/table/table_mv3.yml b/configs/table/table_mv3.yml new file mode 100755 index 0000000000000000000000000000000000000000..32164fe30619e3fa3838f6b021d95925e86708c2 --- /dev/null +++ b/configs/table/table_mv3.yml @@ -0,0 +1,116 @@ +Global: + use_gpu: true + epoch_num: 40 + log_smooth_window: 20 + print_batch_step: 5 + save_model_dir: ./output/table_mv3/ + save_epoch_step: 3 + # evaluation is run every 5000 iterations after the 4000th iteration + eval_batch_step: [0, 400] + # if pretrained_model is saved in static mode, load_static_weights must set to True + cal_metric_during_train: True + pretrained_model: + checkpoints: + save_inference_dir: + use_visualdl: False + infer_img: doc/imgs_words/ch/word_1.jpg + # for data or label process + character_dict_path: ppocr/utils/dict/table_structure_dict.txt + character_type: en + max_text_length: 100 + max_elem_length: 800 + max_cell_num: 500 + infer_mode: False + process_total_num: 0 + process_cut_num: 0 + +Optimizer: + name: Adam + beta1: 0.9 + beta2: 0.999 + clip_norm: 5.0 + lr: + learning_rate: 0.0001 + regularizer: + name: 'L2' + factor: 0.00000 + +Architecture: + model_type: table + algorithm: TableAttn + Backbone: + name: MobileNetV3 + scale: 1.0 + model_name: large + Head: + name: TableAttentionHead # AttentionHead + hidden_size: 256 # + l2_decay: 0.00001 +# loc_type: 1 + loc_type: 2 + +Loss: + name: TableAttentionLoss + structure_weight: 100.0 + loc_weight: 10000.0 + +PostProcess: + name: TableLabelDecode + +Metric: + name: TableMetric + main_indicator: acc + +Train: + dataset: + name: PubTabDataSet + data_dir: train_data/table/pubtabnet/train/ + label_file_path: train_data/table/pubtabnet/PubTabNet_2.0.0_train.jsonl + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - ResizeTableImage: + max_len: 488 + - TableLabelEncode: + - NormalizeImage: + scale: 1./255. + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: 'hwc' + - PaddingTableImage: + - ToCHWImage: + - KeepKeys: + keep_keys: ['image', 'structure', 'bbox_list', 'sp_tokens', 'bbox_list_mask'] + loader: + shuffle: True + batch_size_per_card: 32 + drop_last: True + num_workers: 4 + +Eval: + dataset: + name: PubTabDataSet + data_dir: train_data/table/pubtabnet/val/ + label_file_path: train_data/table/pubtabnet/PubTabNet_2.0.0_val.jsonl + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - ResizeTableImage: + max_len: 488 + - TableLabelEncode: + - NormalizeImage: + scale: 1./255. + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: 'hwc' + - PaddingTableImage: + - ToCHWImage: + - KeepKeys: + keep_keys: ['image', 'structure', 'bbox_list', 'sp_tokens', 'bbox_list_mask'] + loader: + shuffle: False + drop_last: False + batch_size_per_card: 16 + num_workers: 4 diff --git a/ppocr/data/__init__.py b/ppocr/data/__init__.py index 728b8317f54687ee76b519cba18f4d7807493821..e860c5a6986f495e6384d9df93c24795c04a0d5f 100644 --- a/ppocr/data/__init__.py +++ b/ppocr/data/__init__.py @@ -35,6 +35,7 @@ from ppocr.data.imaug import transform, create_operators from ppocr.data.simple_dataset import SimpleDataSet from ppocr.data.lmdb_dataset import LMDBDataSet from ppocr.data.pgnet_dataset import PGDataSet +from ppocr.data.pubtab_dataset import PubTabDataSet __all__ = ['build_dataloader', 'transform', 'create_operators'] @@ -55,7 +56,7 @@ signal.signal(signal.SIGTERM, term_mp) def build_dataloader(config, mode, device, logger, seed=None): config = copy.deepcopy(config) - support_dict = ['SimpleDataSet', 'LMDBDataSet', 'PGDataSet'] + support_dict = ['SimpleDataSet', 'LMDBDataSet', 'PGDataSet', 'PubTabDataSet'] module_name = config[mode]['dataset']['name'] assert module_name in support_dict, Exception( 'DataSet only support {}'.format(support_dict)) diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py index bba3209f7560f19b74a54c102caf697319814803..cd883d1b433701f27044eb76675b07d9ea234d00 100644 --- a/ppocr/data/imaug/label_ops.py +++ b/ppocr/data/imaug/label_ops.py @@ -351,3 +351,182 @@ class SRNLabelEncode(BaseRecLabelEncode): assert False, "Unsupport type %s in get_beg_end_flag_idx" \ % beg_or_end return idx + +class TableLabelEncode(object): + """ Convert between text-label and text-index """ + def __init__(self, + max_text_length, + max_elem_length, + max_cell_num, + character_dict_path, + span_weight = 1.0, + **kwargs): + self.max_text_length = max_text_length + self.max_elem_length = max_elem_length + self.max_cell_num = max_cell_num + list_character, list_elem = self.load_char_elem_dict(character_dict_path) + list_character = self.add_special_char(list_character) + list_elem = self.add_special_char(list_elem) + self.dict_character = {} + for i, char in enumerate(list_character): + self.dict_character[char] = i + self.dict_elem = {} + for i, elem in enumerate(list_elem): + self.dict_elem[elem] = i + self.span_weight = span_weight + + def load_char_elem_dict(self, character_dict_path): + list_character = [] + list_elem = [] + with open(character_dict_path, "rb") as fin: + lines = fin.readlines() + substr = lines[0].decode('utf-8').strip("\n").split("\t") + character_num = int(substr[0]) + elem_num = int(substr[1]) + for cno in range(1, 1+character_num): + character = lines[cno].decode('utf-8').strip("\n") + list_character.append(character) + for eno in range(1+character_num, 1+character_num+elem_num): + elem = lines[eno].decode('utf-8').strip("\n") + list_elem.append(elem) + return list_character, list_elem + + def add_special_char(self, list_character): + self.beg_str = "sos" + self.end_str = "eos" + list_character = [self.beg_str] + list_character + [self.end_str] + return list_character + + def get_span_idx_list(self): + span_idx_list = [] + for elem in self.dict_elem: + if 'span' in elem: + span_idx_list.append(self.dict_elem[elem]) + return span_idx_list + + def __call__(self, data): + cells = data['cells'] + structure = data['structure']['tokens'] + structure = self.encode(structure, 'elem') + if structure is None: + return None + elem_num = len(structure) + structure = [0] + structure + [len(self.dict_elem) - 1] +# structure = [0] + structure + [0] + structure = structure + [0] * (self.max_elem_length + 2 - len(structure)) + structure = np.array(structure) + data['structure'] = structure + elem_char_idx1 = self.dict_elem[''] + elem_char_idx2 = self.dict_elem[' 0: + span_weight = len(td_idx_list) * 1.0 / len(span_idx_list) + span_weight = min(max(span_weight, 1.0), self.span_weight) + for cno in range(len(cells)): + if 'bbox' in cells[cno]: + bbox = cells[cno]['bbox'].copy() + bbox[0] = bbox[0] * 1.0 / img_width + bbox[1] = bbox[1] * 1.0 / img_height + bbox[2] = bbox[2] * 1.0 / img_width + bbox[3] = bbox[3] * 1.0 / img_height + td_idx = td_idx_list[cno] + bbox_list[td_idx] = bbox + bbox_list_mask[td_idx] = 1.0 + cand_span_idx = td_idx + 1 + if cand_span_idx < (self.max_elem_length + 2): + if structure[cand_span_idx] in span_idx_list: + structure_mask[cand_span_idx] = span_weight +# structure_mask[td_idx] = self.span_weight +# structure_mask[cand_span_idx] = self.span_weight + + data['bbox_list'] = bbox_list + data['bbox_list_mask'] = bbox_list_mask + data['structure_mask'] = structure_mask + char_beg_idx = self.get_beg_end_flag_idx('beg', 'char') + char_end_idx = self.get_beg_end_flag_idx('end', 'char') + elem_beg_idx = self.get_beg_end_flag_idx('beg', 'elem') + elem_end_idx = self.get_beg_end_flag_idx('end', 'elem') + data['sp_tokens'] = np.array([char_beg_idx, char_end_idx, elem_beg_idx, + elem_end_idx, elem_char_idx1, elem_char_idx2, self.max_text_length, + self.max_elem_length, self.max_cell_num, elem_num]) + return data + + ######## + # for char decode +# cell_list = [] +# for cell in cells: +# char_list = cell['tokens'] +# cell = self.encode(char_list, 'char') +# if cell is None: +# return None +# cell = [0] + cell + [len(self.dict_character) - 1] +# cell = cell + [0] * (self.max_text_length + 2 - len(cell)) +# cell_list.append(cell) +# cell_list_padding = np.zeros((self.max_cell_num, self.max_text_length + 2)) +# cell_list = np.array(cell_list) +# cell_list_padding[0:cell_list.shape[0]] = cell_list +# data['cells'] = cell_list_padding +# return data + + def encode(self, text, char_or_elem): + """convert text-label into text-index. + """ + if char_or_elem == "char": + max_len = self.max_text_length + current_dict = self.dict_character + else: + max_len = self.max_elem_length + current_dict = self.dict_elem + if len(text) > max_len: + return None + if len(text) == 0: + if char_or_elem == "char": + return [self.dict_character['space']] + else: + return None + text_list = [] + for char in text: + if char not in current_dict: + return None + text_list.append(current_dict[char]) + if len(text_list) == 0: + if char_or_elem == "char": + return [self.dict_character['space']] + else: + return None + return text_list + + def get_ignored_tokens(self, char_or_elem): + beg_idx = self.get_beg_end_flag_idx("beg", char_or_elem) + end_idx = self.get_beg_end_flag_idx("end", char_or_elem) + return [beg_idx, end_idx] + + def get_beg_end_flag_idx(self, beg_or_end, char_or_elem): + if char_or_elem == "char": + if beg_or_end == "beg": + idx = np.array(self.dict_character[self.beg_str]) + elif beg_or_end == "end": + idx = np.array(self.dict_character[self.end_str]) + else: + assert False, "Unsupport type %s in get_beg_end_flag_idx of char" \ + % beg_or_end + elif char_or_elem == "elem": + if beg_or_end == "beg": + idx = np.array(self.dict_elem[self.beg_str]) + elif beg_or_end == "end": + idx = np.array(self.dict_elem[self.end_str]) + else: + assert False, "Unsupport type %s in get_beg_end_flag_idx of elem" \ + % beg_or_end + else: + assert False, "Unsupport type %s in char_or_elem" \ + % char_or_elem + return idx + \ No newline at end of file diff --git a/ppocr/data/pubtab_dataset.py b/ppocr/data/pubtab_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..a2c3eebf7b480219f71bf7ec04375029aa7db613 --- /dev/null +++ b/ppocr/data/pubtab_dataset.py @@ -0,0 +1,125 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np +import os +import random +from paddle.io import Dataset +import json + +from .imaug import transform, create_operators + +class PubTabDataSet(Dataset): + def __init__(self, config, mode, logger, seed=None): + super(PubTabDataSet, self).__init__() + self.logger = logger + + global_config = config['Global'] + dataset_config = config[mode]['dataset'] + loader_config = config[mode]['loader'] + + label_file_path = dataset_config.pop('label_file_path') + + self.data_dir = dataset_config['data_dir'] + self.do_shuffle = loader_config['shuffle'] + self.do_hard_select = False + if 'hard_select' in loader_config: + self.do_hard_select = loader_config['hard_select'] + self.hard_prob = loader_config['hard_prob'] + if self.do_hard_select: + self.img_select_prob = self.load_hard_select_prob() + self.table_select_type = None + if 'table_select_type' in loader_config: + self.table_select_type = loader_config['table_select_type'] + self.table_select_prob = loader_config['table_select_prob'] + + self.seed = seed + logger.info("Initialize indexs of datasets:%s" % label_file_path) + with open(label_file_path, "rb") as f: + self.data_lines = f.readlines() + self.data_idx_order_list = list(range(len(self.data_lines))) + if mode.lower() == "train": + self.shuffle_data_random() + self.ops = create_operators(dataset_config['transforms'], global_config) + + def shuffle_data_random(self): + if self.do_shuffle: + random.seed(self.seed) + random.shuffle(self.data_lines) + return + + def load_hard_select_prob(self): + label_path = "./pretrained_model/teds_score_exp5_st2_train.txt" + img_select_prob = {} + with open(label_path, "rb") as fin: + lines = fin.readlines() + for lno in range(len(lines)): + substr = lines[lno].decode('utf-8').strip("\n").split(" ") + img_name = substr[0].strip(":") + score = float(substr[1]) + if score <= 0.8: + img_select_prob[img_name] = self.hard_prob[0] + elif score <= 0.98: + img_select_prob[img_name] = self.hard_prob[1] + else: + img_select_prob[img_name] = self.hard_prob[2] + return img_select_prob + + def __getitem__(self, idx): + try: + data_line = self.data_lines[idx] + data_line = data_line.decode('utf-8').strip("\n") + info = json.loads(data_line) + file_name = info['filename'] + select_flag = True + if self.do_hard_select: + prob = self.img_select_prob[file_name] + if prob < random.uniform(0, 1): + select_flag = False + + if self.table_select_type: + structure = info['html']['structure']['tokens'].copy() + structure_str = ''.join(structure) + table_type = "simple" + if 'colspan' in structure_str or 'rowspan' in structure_str: + table_type = "complex" +# if self.table_select_type != table_type: +# select_flag = False + if table_type == "complex": + if self.table_select_prob < random.uniform(0, 1): + select_flag = False + + if select_flag: + cells = info['html']['cells'].copy() + structure = info['html']['structure'].copy() + img_path = os.path.join(self.data_dir, file_name) + data = {'img_path': img_path, 'cells': cells, 'structure':structure} + if not os.path.exists(img_path): + raise Exception("{} does not exist!".format(img_path)) + with open(data['img_path'], 'rb') as f: + img = f.read() + data['image'] = img + outs = transform(data, self.ops) + else: + outs = None + except Exception as e: + self.logger.error( + "When parsing line {}, error happened with msg: {}".format( + data_line, e)) + outs = None + if outs is None: + return self.__getitem__(np.random.randint(self.__len__())) + return outs + + def __len__(self): + return len(self.data_idx_order_list) diff --git a/ppocr/losses/__init__.py b/ppocr/losses/__init__.py index bf10d2982dcdd36021a7385ab8828398b51af3d3..025ae7ca5cc604eea59423ca7f523c37c1492e35 100755 --- a/ppocr/losses/__init__.py +++ b/ppocr/losses/__init__.py @@ -38,11 +38,13 @@ from .basic_loss import DistanceLoss # combined loss function from .combined_loss import CombinedLoss +# table loss +from .table_att_loss import TableAttentionLoss def build_loss(config): support_dict = [ 'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss', - 'SRNLoss', 'PGLoss', 'CombinedLoss' + 'SRNLoss', 'PGLoss', 'CombinedLoss', 'TableAttentionLoss' ] config = copy.deepcopy(config) module_name = config.pop('name') diff --git a/ppocr/losses/table_att_loss.py b/ppocr/losses/table_att_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..d7fd99e6952aacc0182a482ca5ae5ddaf959a026 --- /dev/null +++ b/ppocr/losses/table_att_loss.py @@ -0,0 +1,109 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +from paddle import nn +from paddle.nn import functional as F +from paddle import fluid + +class TableAttentionLoss(nn.Layer): + def __init__(self, structure_weight, loc_weight, use_giou=False, giou_weight=1.0, **kwargs): + super(TableAttentionLoss, self).__init__() + self.loss_func = nn.CrossEntropyLoss(weight=None, reduction='none') + self.structure_weight = structure_weight + self.loc_weight = loc_weight + self.use_giou = use_giou + self.giou_weight = giou_weight + + def giou_loss(self, preds, bbox, eps=1e-7, reduction='mean'): + ''' + :param preds:[[x1,y1,x2,y2], [x1,y1,x2,y2],,,] + :param bbox:[[x1,y1,x2,y2], [x1,y1,x2,y2],,,] + :return: loss + ''' + ix1 = fluid.layers.elementwise_max(preds[:, 0], bbox[:, 0]) + iy1 = fluid.layers.elementwise_max(preds[:, 1], bbox[:, 1]) + ix2 = fluid.layers.elementwise_min(preds[:, 2], bbox[:, 2]) + iy2 = fluid.layers.elementwise_min(preds[:, 3], bbox[:, 3]) + + iw = fluid.layers.clip(ix2 - ix1 + 1e-3, 0., 1e10) + ih = fluid.layers.clip(iy2 - iy1 + 1e-3, 0., 1e10) + + # overlap + inters = iw * ih + + # union + uni = (preds[:, 2] - preds[:, 0] + 1e-3) * (preds[:, 3] - preds[:, 1] + 1e-3 + ) + (bbox[:, 2] - bbox[:, 0] + 1e-3) * ( + bbox[:, 3] - bbox[:, 1] + 1e-3) - inters + eps + + # ious + ious = inters / uni + + ex1 = fluid.layers.elementwise_min(preds[:, 0], bbox[:, 0]) + ey1 = fluid.layers.elementwise_min(preds[:, 1], bbox[:, 1]) + ex2 = fluid.layers.elementwise_max(preds[:, 2], bbox[:, 2]) + ey2 = fluid.layers.elementwise_max(preds[:, 3], bbox[:, 3]) + ew = fluid.layers.clip(ex2 - ex1 + 1e-3, 0., 1e10) + eh = fluid.layers.clip(ey2 - ey1 + 1e-3, 0., 1e10) + + # enclose erea + enclose = ew * eh + eps + giou = ious - (enclose - uni) / enclose + + loss = 1 - giou + + if reduction == 'mean': + loss = paddle.mean(loss) + elif reduction == 'sum': + loss = paddle.sum(loss) + else: + raise NotImplementedError + return loss + + def forward(self, predicts, batch): + structure_probs = predicts['structure_probs'] + structure_targets = batch[1].astype("int64") + structure_targets = structure_targets[:, 1:] + if len(batch) == 6: + structure_mask = batch[5].astype("int64") + structure_mask = structure_mask[:, 1:] + structure_mask = paddle.reshape(structure_mask, [-1]) + structure_probs = paddle.reshape(structure_probs, [-1, structure_probs.shape[-1]]) + structure_targets = paddle.reshape(structure_targets, [-1]) + structure_loss = self.loss_func(structure_probs, structure_targets) + + if len(batch) == 6: + structure_loss = structure_loss * structure_mask + +# structure_loss = paddle.sum(structure_loss) * self.structure_weight + structure_loss = paddle.mean(structure_loss) * self.structure_weight + + loc_preds = predicts['loc_preds'] + loc_targets = batch[2].astype("float32") + loc_targets_mask = batch[4].astype("float32") + loc_targets = loc_targets[:, 1:, :] + loc_targets_mask = loc_targets_mask[:, 1:, :] + loc_loss = F.mse_loss(loc_preds * loc_targets_mask, loc_targets) * self.loc_weight + if self.use_giou: + loc_loss_giou = self.giou_loss(loc_preds * loc_targets_mask, loc_targets) * self.giou_weight + total_loss = structure_loss + loc_loss + loc_loss_giou + return {'loss':total_loss, "structure_loss":structure_loss, "loc_loss":loc_loss, "loc_loss_giou":loc_loss_giou} + else: + total_loss = structure_loss + loc_loss + return {'loss':total_loss, "structure_loss":structure_loss, "loc_loss":loc_loss} \ No newline at end of file diff --git a/ppocr/metrics/__init__.py b/ppocr/metrics/__init__.py index 9e9060fa999bd3175c31dfc0797cd293d4e7afec..64f62e51cdf922773c03bb784a4edffdc17f506f 100644 --- a/ppocr/metrics/__init__.py +++ b/ppocr/metrics/__init__.py @@ -26,11 +26,11 @@ from .rec_metric import RecMetric from .cls_metric import ClsMetric from .e2e_metric import E2EMetric from .distillation_metric import DistillationMetric - +from .table_metric import TableMetric def build_metric(config): support_dict = [ - "DetMetric", "RecMetric", "ClsMetric", "E2EMetric", "DistillationMetric" + "DetMetric", "RecMetric", "ClsMetric", "E2EMetric", "DistillationMetric", "TableMetric" ] config = copy.deepcopy(config) diff --git a/ppocr/metrics/table_metric.py b/ppocr/metrics/table_metric.py new file mode 100644 index 0000000000000000000000000000000000000000..80d1c789ecc3979bd4c33620af91ccd28012f7a8 --- /dev/null +++ b/ppocr/metrics/table_metric.py @@ -0,0 +1,50 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np +class TableMetric(object): + def __init__(self, main_indicator='acc', **kwargs): + self.main_indicator = main_indicator + self.reset() + + def __call__(self, pred, batch, *args, **kwargs): + structure_probs = pred['structure_probs'].numpy() + structure_labels = batch[1] + correct_num = 0 + all_num = 0 + structure_probs = np.argmax(structure_probs, axis=2) + structure_labels = structure_labels[:, 1:] + batch_size = structure_probs.shape[0] + for bno in range(batch_size): + all_num += 1 + if (structure_probs[bno] == structure_labels[bno]).all(): + correct_num += 1 + self.correct_num += correct_num + self.all_num += all_num + return { + 'acc': correct_num * 1.0 / all_num, + } + + def get_metric(self): + """ + return metrics { + 'acc': 0, + } + """ + acc = 1.0 * self.correct_num / self.all_num + self.reset() + return {'acc': acc} + + def reset(self): + self.correct_num = 0 + self.all_num = 0 diff --git a/ppocr/modeling/architectures/base_model.py b/ppocr/modeling/architectures/base_model.py index 4c941fcf65573d9314c0badda49895d0b6b5c4f9..49160b52898a50984c3036d4dea48513ca53bb0d 100644 --- a/ppocr/modeling/architectures/base_model.py +++ b/ppocr/modeling/architectures/base_model.py @@ -69,7 +69,7 @@ class BaseModel(nn.Layer): self.return_all_feats = config.get("return_all_feats", False) - def forward(self, x, data=None): + def forward(self, x, data=None, mode='Train'): y = dict() if self.use_transform: x = self.transform(x) @@ -81,7 +81,10 @@ class BaseModel(nn.Layer): if data is None: x = self.head(x) else: - x = self.head(x, data) + if mode == 'Eval' or mode == 'Test': + x = self.head(x, targets=data, mode=mode) + else: + x = self.head(x, targets=data) y["head_out"] = x if self.return_all_feats: return y diff --git a/ppocr/modeling/backbones/__init__.py b/ppocr/modeling/backbones/__init__.py index fe2c9bc30a4f2abd1ba7d3d6989b9ef9b20c1f4f..13b70b203371b3be58ee82c6808d744bf6098333 100755 --- a/ppocr/modeling/backbones/__init__.py +++ b/ppocr/modeling/backbones/__init__.py @@ -29,6 +29,10 @@ def build_backbone(config, model_type): elif model_type == 'e2e': from .e2e_resnet_vd_pg import ResNet support_dict = ['ResNet'] + elif model_type == "table": + from .table_resnet_vd import ResNet + from .table_mobilenet_v3 import MobileNetV3 + support_dict = ['ResNet', 'MobileNetV3'] else: raise NotImplementedError diff --git a/ppocr/modeling/backbones/table_mobilenet_v3.py b/ppocr/modeling/backbones/table_mobilenet_v3.py new file mode 100644 index 0000000000000000000000000000000000000000..daa87f976038d8d5eeafadceb869b9232ba22cd9 --- /dev/null +++ b/ppocr/modeling/backbones/table_mobilenet_v3.py @@ -0,0 +1,287 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +from paddle import nn +import paddle.nn.functional as F +from paddle import ParamAttr + +__all__ = ['MobileNetV3'] + + +def make_divisible(v, divisor=8, min_value=None): + if min_value is None: + min_value = divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + if new_v < 0.9 * v: + new_v += divisor + return new_v + + +class MobileNetV3(nn.Layer): + def __init__(self, + in_channels=3, + model_name='large', + scale=0.5, + disable_se=False, + **kwargs): + """ + the MobilenetV3 backbone network for detection module. + Args: + params(dict): the super parameters for build network + """ + super(MobileNetV3, self).__init__() + + self.disable_se = disable_se + + if model_name == "large": + cfg = [ + # k, exp, c, se, nl, s, + [3, 16, 16, False, 'relu', 1], + [3, 64, 24, False, 'relu', 2], + [3, 72, 24, False, 'relu', 1], + [5, 72, 40, True, 'relu', 2], + [5, 120, 40, True, 'relu', 1], + [5, 120, 40, True, 'relu', 1], + [3, 240, 80, False, 'hardswish', 2], + [3, 200, 80, False, 'hardswish', 1], + [3, 184, 80, False, 'hardswish', 1], + [3, 184, 80, False, 'hardswish', 1], + [3, 480, 112, True, 'hardswish', 1], + [3, 672, 112, True, 'hardswish', 1], + [5, 672, 160, True, 'hardswish', 2], + [5, 960, 160, True, 'hardswish', 1], + [5, 960, 160, True, 'hardswish', 1], + ] + cls_ch_squeeze = 960 + elif model_name == "small": + cfg = [ + # k, exp, c, se, nl, s, + [3, 16, 16, True, 'relu', 2], + [3, 72, 24, False, 'relu', 2], + [3, 88, 24, False, 'relu', 1], + [5, 96, 40, True, 'hardswish', 2], + [5, 240, 40, True, 'hardswish', 1], + [5, 240, 40, True, 'hardswish', 1], + [5, 120, 48, True, 'hardswish', 1], + [5, 144, 48, True, 'hardswish', 1], + [5, 288, 96, True, 'hardswish', 2], + [5, 576, 96, True, 'hardswish', 1], + [5, 576, 96, True, 'hardswish', 1], + ] + cls_ch_squeeze = 576 + else: + raise NotImplementedError("mode[" + model_name + + "_model] is not implemented!") + + supported_scale = [0.35, 0.5, 0.75, 1.0, 1.25] + assert scale in supported_scale, \ + "supported scale are {} but input scale is {}".format(supported_scale, scale) + inplanes = 16 + # conv1 + self.conv = ConvBNLayer( + in_channels=in_channels, + out_channels=make_divisible(inplanes * scale), + kernel_size=3, + stride=2, + padding=1, + groups=1, + if_act=True, + act='hardswish', + name='conv1') + + self.stages = [] + self.out_channels = [] + block_list = [] + i = 0 + inplanes = make_divisible(inplanes * scale) + for (k, exp, c, se, nl, s) in cfg: + se = se and not self.disable_se + start_idx = 2 if model_name == 'large' else 0 + if s == 2 and i > start_idx: + self.out_channels.append(inplanes) + self.stages.append(nn.Sequential(*block_list)) + block_list = [] + block_list.append( + ResidualUnit( + in_channels=inplanes, + mid_channels=make_divisible(scale * exp), + out_channels=make_divisible(scale * c), + kernel_size=k, + stride=s, + use_se=se, + act=nl, + name="conv" + str(i + 2))) + inplanes = make_divisible(scale * c) + i += 1 + block_list.append( + ConvBNLayer( + in_channels=inplanes, + out_channels=make_divisible(scale * cls_ch_squeeze), + kernel_size=1, + stride=1, + padding=0, + groups=1, + if_act=True, + act='hardswish', + name='conv_last')) + self.stages.append(nn.Sequential(*block_list)) + self.out_channels.append(make_divisible(scale * cls_ch_squeeze)) + for i, stage in enumerate(self.stages): + self.add_sublayer(sublayer=stage, name="stage{}".format(i)) + + def forward(self, x): + x = self.conv(x) + out_list = [] + for stage in self.stages: + x = stage(x) + out_list.append(x) + return out_list + + +class ConvBNLayer(nn.Layer): + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride, + padding, + groups=1, + if_act=True, + act=None, + name=None): + super(ConvBNLayer, self).__init__() + self.if_act = if_act + self.act = act + self.conv = nn.Conv2D( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + groups=groups, + weight_attr=ParamAttr(name=name + '_weights'), + bias_attr=False) + + self.bn = nn.BatchNorm( + num_channels=out_channels, + act=None, + param_attr=ParamAttr(name=name + "_bn_scale"), + bias_attr=ParamAttr(name=name + "_bn_offset"), + moving_mean_name=name + "_bn_mean", + moving_variance_name=name + "_bn_variance") + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + if self.if_act: + if self.act == "relu": + x = F.relu(x) + elif self.act == "hardswish": + x = F.hardswish(x) + else: + print("The activation function({}) is selected incorrectly.". + format(self.act)) + exit() + return x + + +class ResidualUnit(nn.Layer): + def __init__(self, + in_channels, + mid_channels, + out_channels, + kernel_size, + stride, + use_se, + act=None, + name=''): + super(ResidualUnit, self).__init__() + self.if_shortcut = stride == 1 and in_channels == out_channels + self.if_se = use_se + + self.expand_conv = ConvBNLayer( + in_channels=in_channels, + out_channels=mid_channels, + kernel_size=1, + stride=1, + padding=0, + if_act=True, + act=act, + name=name + "_expand") + self.bottleneck_conv = ConvBNLayer( + in_channels=mid_channels, + out_channels=mid_channels, + kernel_size=kernel_size, + stride=stride, + padding=int((kernel_size - 1) // 2), + groups=mid_channels, + if_act=True, + act=act, + name=name + "_depthwise") + if self.if_se: + self.mid_se = SEModule(mid_channels, name=name + "_se") + self.linear_conv = ConvBNLayer( + in_channels=mid_channels, + out_channels=out_channels, + kernel_size=1, + stride=1, + padding=0, + if_act=False, + act=None, + name=name + "_linear") + + def forward(self, inputs): + x = self.expand_conv(inputs) + x = self.bottleneck_conv(x) + if self.if_se: + x = self.mid_se(x) + x = self.linear_conv(x) + if self.if_shortcut: + x = paddle.add(inputs, x) + return x + + +class SEModule(nn.Layer): + def __init__(self, in_channels, reduction=4, name=""): + super(SEModule, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2D(1) + self.conv1 = nn.Conv2D( + in_channels=in_channels, + out_channels=in_channels // reduction, + kernel_size=1, + stride=1, + padding=0, + weight_attr=ParamAttr(name=name + "_1_weights"), + bias_attr=ParamAttr(name=name + "_1_offset")) + self.conv2 = nn.Conv2D( + in_channels=in_channels // reduction, + out_channels=in_channels, + kernel_size=1, + stride=1, + padding=0, + weight_attr=ParamAttr(name + "_2_weights"), + bias_attr=ParamAttr(name=name + "_2_offset")) + + def forward(self, inputs): + outputs = self.avg_pool(inputs) + outputs = self.conv1(outputs) + outputs = F.relu(outputs) + outputs = self.conv2(outputs) + outputs = F.hardsigmoid(outputs, slope=0.2, offset=0.5) + return inputs * outputs \ No newline at end of file diff --git a/ppocr/modeling/backbones/table_resnet_vd.py b/ppocr/modeling/backbones/table_resnet_vd.py new file mode 100644 index 0000000000000000000000000000000000000000..1c07c2684eec8d0c4a445cc88c543bfe1da9c864 --- /dev/null +++ b/ppocr/modeling/backbones/table_resnet_vd.py @@ -0,0 +1,280 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +from paddle import ParamAttr +import paddle.nn as nn +import paddle.nn.functional as F + +__all__ = ["ResNet"] + + +class ConvBNLayer(nn.Layer): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + groups=1, + is_vd_mode=False, + act=None, + name=None, ): + super(ConvBNLayer, self).__init__() + + self.is_vd_mode = is_vd_mode + self._pool2d_avg = nn.AvgPool2D( + kernel_size=2, stride=2, padding=0, ceil_mode=True) + self._conv = nn.Conv2D( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=(kernel_size - 1) // 2, + groups=groups, + weight_attr=ParamAttr(name=name + "_weights"), + bias_attr=False) + if name == "conv1": + bn_name = "bn_" + name + else: + bn_name = "bn" + name[3:] + self._batch_norm = nn.BatchNorm( + out_channels, + act=act, + param_attr=ParamAttr(name=bn_name + '_scale'), + bias_attr=ParamAttr(bn_name + '_offset'), + moving_mean_name=bn_name + '_mean', + moving_variance_name=bn_name + '_variance') + + def forward(self, inputs): + if self.is_vd_mode: + inputs = self._pool2d_avg(inputs) + y = self._conv(inputs) + y = self._batch_norm(y) + return y + + +class BottleneckBlock(nn.Layer): + def __init__(self, + in_channels, + out_channels, + stride, + shortcut=True, + if_first=False, + name=None): + super(BottleneckBlock, self).__init__() + + self.conv0 = ConvBNLayer( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + act='relu', + name=name + "_branch2a") + self.conv1 = ConvBNLayer( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=3, + stride=stride, + act='relu', + name=name + "_branch2b") + self.conv2 = ConvBNLayer( + in_channels=out_channels, + out_channels=out_channels * 4, + kernel_size=1, + act=None, + name=name + "_branch2c") + + if not shortcut: + self.short = ConvBNLayer( + in_channels=in_channels, + out_channels=out_channels * 4, + kernel_size=1, + stride=1, + is_vd_mode=False if if_first else True, + name=name + "_branch1") + + self.shortcut = shortcut + + def forward(self, inputs): + y = self.conv0(inputs) + conv1 = self.conv1(y) + conv2 = self.conv2(conv1) + + if self.shortcut: + short = inputs + else: + short = self.short(inputs) + y = paddle.add(x=short, y=conv2) + y = F.relu(y) + return y + + +class BasicBlock(nn.Layer): + def __init__(self, + in_channels, + out_channels, + stride, + shortcut=True, + if_first=False, + name=None): + super(BasicBlock, self).__init__() + self.stride = stride + self.conv0 = ConvBNLayer( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=stride, + act='relu', + name=name + "_branch2a") + self.conv1 = ConvBNLayer( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=3, + act=None, + name=name + "_branch2b") + + if not shortcut: + self.short = ConvBNLayer( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + stride=1, + is_vd_mode=False if if_first else True, + name=name + "_branch1") + + self.shortcut = shortcut + + def forward(self, inputs): + y = self.conv0(inputs) + conv1 = self.conv1(y) + + if self.shortcut: + short = inputs + else: + short = self.short(inputs) + y = paddle.add(x=short, y=conv1) + y = F.relu(y) + return y + + +class ResNet(nn.Layer): + def __init__(self, in_channels=3, layers=50, **kwargs): + super(ResNet, self).__init__() + + self.layers = layers + supported_layers = [18, 34, 50, 101, 152, 200] + assert layers in supported_layers, \ + "supported layers are {} but input layer is {}".format( + supported_layers, layers) + + if layers == 18: + depth = [2, 2, 2, 2] + elif layers == 34 or layers == 50: + depth = [3, 4, 6, 3] + elif layers == 101: + depth = [3, 4, 23, 3] + elif layers == 152: + depth = [3, 8, 36, 3] + elif layers == 200: + depth = [3, 12, 48, 3] + num_channels = [64, 256, 512, + 1024] if layers >= 50 else [64, 64, 128, 256] + num_filters = [64, 128, 256, 512] + + self.conv1_1 = ConvBNLayer( + in_channels=in_channels, + out_channels=32, + kernel_size=3, + stride=2, + act='relu', + name="conv1_1") + self.conv1_2 = ConvBNLayer( + in_channels=32, + out_channels=32, + kernel_size=3, + stride=1, + act='relu', + name="conv1_2") + self.conv1_3 = ConvBNLayer( + in_channels=32, + out_channels=64, + kernel_size=3, + stride=1, + act='relu', + name="conv1_3") + self.pool2d_max = nn.MaxPool2D(kernel_size=3, stride=2, padding=1) + + self.stages = [] + self.out_channels = [] + if layers >= 50: + for block in range(len(depth)): + block_list = [] + shortcut = False + for i in range(depth[block]): + if layers in [101, 152] and block == 2: + if i == 0: + conv_name = "res" + str(block + 2) + "a" + else: + conv_name = "res" + str(block + 2) + "b" + str(i) + else: + conv_name = "res" + str(block + 2) + chr(97 + i) + bottleneck_block = self.add_sublayer( + 'bb_%d_%d' % (block, i), + BottleneckBlock( + in_channels=num_channels[block] + if i == 0 else num_filters[block] * 4, + out_channels=num_filters[block], + stride=2 if i == 0 and block != 0 else 1, + shortcut=shortcut, + if_first=block == i == 0, + name=conv_name)) + shortcut = True + block_list.append(bottleneck_block) + self.out_channels.append(num_filters[block] * 4) + self.stages.append(nn.Sequential(*block_list)) + else: + for block in range(len(depth)): + block_list = [] + shortcut = False + for i in range(depth[block]): + conv_name = "res" + str(block + 2) + chr(97 + i) + basic_block = self.add_sublayer( + 'bb_%d_%d' % (block, i), + BasicBlock( + in_channels=num_channels[block] + if i == 0 else num_filters[block], + out_channels=num_filters[block], + stride=2 if i == 0 and block != 0 else 1, + shortcut=shortcut, + if_first=block == i == 0, + name=conv_name)) + shortcut = True + block_list.append(basic_block) + self.out_channels.append(num_filters[block]) + self.stages.append(nn.Sequential(*block_list)) + + def forward(self, inputs): + y = self.conv1_1(inputs) + y = self.conv1_2(y) + y = self.conv1_3(y) + y = self.pool2d_max(y) + out = [] + for block in self.stages: + y = block(y) + out.append(y) + return out diff --git a/ppocr/modeling/heads/__init__.py b/ppocr/modeling/heads/__init__.py index 4852c7f2d14d72b9e4d59f40532469f7226c966d..5096479415f504aa9f074d55bd9b2e4a31c730b4 100755 --- a/ppocr/modeling/heads/__init__.py +++ b/ppocr/modeling/heads/__init__.py @@ -31,8 +31,10 @@ def build_head(config): from .cls_head import ClsHead support_dict = [ 'DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead', 'AttentionHead', - 'SRNHead', 'PGHead'] + 'SRNHead', 'PGHead', 'TableAttentionHead'] + #table head + from .table_att_head import TableAttentionHead module_name = config.pop('name') assert module_name in support_dict, Exception('head only support {}'.format( diff --git a/ppocr/modeling/heads/table_att_head.py b/ppocr/modeling/heads/table_att_head.py new file mode 100644 index 0000000000000000000000000000000000000000..9e5c438a358757f597e31ae8ea84a7ab1c22776b --- /dev/null +++ b/ppocr/modeling/heads/table_att_head.py @@ -0,0 +1,240 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +import numpy as np + +class TableAttentionHead(nn.Layer): + def __init__(self, in_channels, hidden_size, loc_type, in_max_len=488, **kwargs): + super(TableAttentionHead, self).__init__() + self.input_size = in_channels[-1] + self.hidden_size = hidden_size + self.char_num = 280 + self.elem_num = 30 + + self.structure_attention_cell = AttentionGRUCell( + self.input_size, hidden_size, self.elem_num, use_gru=False) + self.structure_generator = nn.Linear(hidden_size, self.elem_num) + self.loc_type = loc_type + self.in_max_len = in_max_len + + if self.loc_type == 1: + self.loc_generator = nn.Linear(hidden_size, 4) + else: + if self.in_max_len == 640: + self.loc_fea_trans = nn.Linear(400, 801) + elif self.in_max_len == 800: + self.loc_fea_trans = nn.Linear(625, 801) + else: + self.loc_fea_trans = nn.Linear(256, 801) + self.loc_generator = nn.Linear(self.input_size + hidden_size, 4) + + def _char_to_onehot(self, input_char, onehot_dim): + input_ont_hot = F.one_hot(input_char, onehot_dim) + return input_ont_hot + + def forward(self, inputs, targets=None, mode='Train'): + # if and else branch are both needed when you want to assign a variable + # if you modify the var in just one branch, then the modification will not work. + fea = inputs[-1] + if len(fea.shape) == 3: + pass + else: + last_shape = int(np.prod(fea.shape[2:])) # gry added + fea = paddle.reshape(fea, [fea.shape[0], fea.shape[1], last_shape]) + fea = fea.transpose([0, 2, 1]) # (NTC)(batch, width, channels) + batch_size = fea.shape[0] + #sp_tokens = targets[2].numpy() + #char_beg_idx, char_end_idx = sp_tokens[0, 0:2] + #elem_beg_idx, elem_end_idx = sp_tokens[0, 2:4] + #elem_char_idx1, elem_char_idx2 = sp_tokens[0, 4:6] + #max_text_length, max_elem_length, max_cell_num = sp_tokens[0, 6:9] + max_text_length, max_elem_length, max_cell_num = 100, 800, 500 + + hidden = paddle.zeros((batch_size, self.hidden_size)) + output_hiddens = [] + if mode == 'Train' and targets is not None: + structure = targets[0] + for i in range(max_elem_length+1): + elem_onehots = self._char_to_onehot( + structure[:, i], onehot_dim=self.elem_num) + (outputs, hidden), alpha = self.structure_attention_cell( + hidden, fea, elem_onehots) + output_hiddens.append(paddle.unsqueeze(outputs, axis=1)) + output = paddle.concat(output_hiddens, axis=1) + structure_probs = self.structure_generator(output) + if self.loc_type == 1: + loc_preds = self.loc_generator(output) + loc_preds = F.sigmoid(loc_preds) + else: + loc_fea = fea.transpose([0, 2, 1]) + loc_fea = self.loc_fea_trans(loc_fea) + loc_fea = loc_fea.transpose([0, 2, 1]) + loc_concat = paddle.concat([output, loc_fea], axis=2) + loc_preds = self.loc_generator(loc_concat) + loc_preds = F.sigmoid(loc_preds) + else: + temp_elem = paddle.zeros(shape=[batch_size], dtype="int32") + structure_probs = None + loc_preds = None + elem_onehots = None + outputs = None + alpha = None + max_elem_length = paddle.to_tensor(max_elem_length) + i = 0 + while i < max_elem_length+1: + elem_onehots = self._char_to_onehot( + temp_elem, onehot_dim=self.elem_num) + (outputs, hidden), alpha = self.structure_attention_cell( + hidden, fea, elem_onehots) + output_hiddens.append(paddle.unsqueeze(outputs, axis=1)) + structure_probs_step = self.structure_generator(outputs) + temp_elem = structure_probs_step.argmax(axis=1, dtype="int32") + i += 1 + + output = paddle.concat(output_hiddens, axis=1) + structure_probs = self.structure_generator(output) + structure_probs = F.softmax(structure_probs) + if self.loc_type == 1: + loc_preds = self.loc_generator(output) + loc_preds = F.sigmoid(loc_preds) + else: + loc_fea = fea.transpose([0, 2, 1]) + loc_fea = self.loc_fea_trans(loc_fea) + loc_fea = loc_fea.transpose([0, 2, 1]) + loc_concat = paddle.concat([output, loc_fea], axis=2) + loc_preds = self.loc_generator(loc_concat) + loc_preds = F.sigmoid(loc_preds) + return {'structure_probs':structure_probs, 'loc_preds':loc_preds} + +class AttentionGRUCell(nn.Layer): + def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False): + super(AttentionGRUCell, self).__init__() + self.i2h = nn.Linear(input_size, hidden_size, bias_attr=False) + self.h2h = nn.Linear(hidden_size, hidden_size) + self.score = nn.Linear(hidden_size, 1, bias_attr=False) + self.rnn = nn.GRUCell( + input_size=input_size + num_embeddings, hidden_size=hidden_size) + self.hidden_size = hidden_size + + def forward(self, prev_hidden, batch_H, char_onehots): + batch_H_proj = self.i2h(batch_H) + prev_hidden_proj = paddle.unsqueeze(self.h2h(prev_hidden), axis=1) + res = paddle.add(batch_H_proj, prev_hidden_proj) + res = paddle.tanh(res) + e = self.score(res) + alpha = F.softmax(e, axis=1) + alpha = paddle.transpose(alpha, [0, 2, 1]) + context = paddle.squeeze(paddle.mm(alpha, batch_H), axis=1) + concat_context = paddle.concat([context, char_onehots], 1) + cur_hidden = self.rnn(concat_context, prev_hidden) + return cur_hidden, alpha + + +class AttentionLSTM(nn.Layer): + def __init__(self, in_channels, out_channels, hidden_size, **kwargs): + super(AttentionLSTM, self).__init__() + self.input_size = in_channels + self.hidden_size = hidden_size + self.num_classes = out_channels + + self.attention_cell = AttentionLSTMCell( + in_channels, hidden_size, out_channels, use_gru=False) + self.generator = nn.Linear(hidden_size, out_channels) + + def _char_to_onehot(self, input_char, onehot_dim): + input_ont_hot = F.one_hot(input_char, onehot_dim) + return input_ont_hot + + def forward(self, inputs, targets=None, batch_max_length=25): + batch_size = inputs.shape[0] + num_steps = batch_max_length + + hidden = (paddle.zeros((batch_size, self.hidden_size)), paddle.zeros( + (batch_size, self.hidden_size))) + output_hiddens = [] + + if targets is not None: + for i in range(num_steps): + # one-hot vectors for a i-th char + char_onehots = self._char_to_onehot( + targets[:, i], onehot_dim=self.num_classes) + hidden, alpha = self.attention_cell(hidden, inputs, + char_onehots) + + hidden = (hidden[1][0], hidden[1][1]) + output_hiddens.append(paddle.unsqueeze(hidden[0], axis=1)) + output = paddle.concat(output_hiddens, axis=1) + probs = self.generator(output) + + else: + targets = paddle.zeros(shape=[batch_size], dtype="int32") + probs = None + + for i in range(num_steps): + char_onehots = self._char_to_onehot( + targets, onehot_dim=self.num_classes) + hidden, alpha = self.attention_cell(hidden, inputs, + char_onehots) + probs_step = self.generator(hidden[0]) + hidden = (hidden[1][0], hidden[1][1]) + if probs is None: + probs = paddle.unsqueeze(probs_step, axis=1) + else: + probs = paddle.concat( + [probs, paddle.unsqueeze( + probs_step, axis=1)], axis=1) + + next_input = probs_step.argmax(axis=1) + + targets = next_input + + return probs + + +class AttentionLSTMCell(nn.Layer): + def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False): + super(AttentionLSTMCell, self).__init__() + self.i2h = nn.Linear(input_size, hidden_size, bias_attr=False) + self.h2h = nn.Linear(hidden_size, hidden_size) + self.score = nn.Linear(hidden_size, 1, bias_attr=False) + if not use_gru: + self.rnn = nn.LSTMCell( + input_size=input_size + num_embeddings, hidden_size=hidden_size) + else: + self.rnn = nn.GRUCell( + input_size=input_size + num_embeddings, hidden_size=hidden_size) + + self.hidden_size = hidden_size + + def forward(self, prev_hidden, batch_H, char_onehots): + batch_H_proj = self.i2h(batch_H) + prev_hidden_proj = paddle.unsqueeze(self.h2h(prev_hidden[0]), axis=1) + res = paddle.add(batch_H_proj, prev_hidden_proj) + res = paddle.tanh(res) + e = self.score(res) + + alpha = F.softmax(e, axis=1) + alpha = paddle.transpose(alpha, [0, 2, 1]) + context = paddle.squeeze(paddle.mm(alpha, batch_H), axis=1) + concat_context = paddle.concat([context, char_onehots], 1) + cur_hidden = self.rnn(concat_context, prev_hidden) + + return cur_hidden, alpha diff --git a/ppocr/modeling/necks/__init__.py b/ppocr/modeling/necks/__init__.py index 37a5cf7863cb386884d82ed88c756c9fc06a541d..e97c4f64bdc9acd6729d67a9c6ff7a7563f6c95e 100644 --- a/ppocr/modeling/necks/__init__.py +++ b/ppocr/modeling/necks/__init__.py @@ -21,7 +21,8 @@ def build_neck(config): from .sast_fpn import SASTFPN from .rnn import SequenceEncoder from .pg_fpn import PGFPN - support_dict = ['DBFPN', 'EASTFPN', 'SASTFPN', 'SequenceEncoder', 'PGFPN'] + from .table_fpn import TableFPN + support_dict = ['DBFPN', 'EASTFPN', 'SASTFPN', 'SequenceEncoder', 'PGFPN', 'TableFPN'] module_name = config.pop('name') assert module_name in support_dict, Exception('neck only support {}'.format( diff --git a/ppocr/modeling/necks/table_fpn.py b/ppocr/modeling/necks/table_fpn.py new file mode 100644 index 0000000000000000000000000000000000000000..d72bff4ff14951fd532f516e3a1a8405cedc8f23 --- /dev/null +++ b/ppocr/modeling/necks/table_fpn.py @@ -0,0 +1,119 @@ +# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +from paddle import nn +import paddle.nn.functional as F +from paddle import ParamAttr + + +class TableFPN(nn.Layer): + def __init__(self, in_channels, out_channels, **kwargs): + super(TableFPN, self).__init__() + self.out_channels = 512 + weight_attr = paddle.nn.initializer.KaimingUniform() + self.in2_conv = nn.Conv2D( + in_channels=in_channels[0], + out_channels=self.out_channels, + kernel_size=1, + weight_attr=ParamAttr( + name='conv2d_51.w_0', initializer=weight_attr), + bias_attr=False) + self.in3_conv = nn.Conv2D( + in_channels=in_channels[1], + out_channels=self.out_channels, + kernel_size=1, + stride = 1, + weight_attr=ParamAttr( + name='conv2d_50.w_0', initializer=weight_attr), + bias_attr=False) + self.in4_conv = nn.Conv2D( + in_channels=in_channels[2], + out_channels=self.out_channels, + kernel_size=1, + weight_attr=ParamAttr( + name='conv2d_49.w_0', initializer=weight_attr), + bias_attr=False) + self.in5_conv = nn.Conv2D( + in_channels=in_channels[3], + out_channels=self.out_channels, + kernel_size=1, + weight_attr=ParamAttr( + name='conv2d_48.w_0', initializer=weight_attr), + bias_attr=False) + self.p5_conv = nn.Conv2D( + in_channels=self.out_channels, + out_channels=self.out_channels // 4, + kernel_size=3, + padding=1, + weight_attr=ParamAttr( + name='conv2d_52.w_0', initializer=weight_attr), + bias_attr=False) + self.p4_conv = nn.Conv2D( + in_channels=self.out_channels, + out_channels=self.out_channels // 4, + kernel_size=3, + padding=1, + weight_attr=ParamAttr( + name='conv2d_53.w_0', initializer=weight_attr), + bias_attr=False) + self.p3_conv = nn.Conv2D( + in_channels=self.out_channels, + out_channels=self.out_channels // 4, + kernel_size=3, + padding=1, + weight_attr=ParamAttr( + name='conv2d_54.w_0', initializer=weight_attr), + bias_attr=False) + self.p2_conv = nn.Conv2D( + in_channels=self.out_channels, + out_channels=self.out_channels // 4, + kernel_size=3, + padding=1, + weight_attr=ParamAttr( + name='conv2d_55.w_0', initializer=weight_attr), + bias_attr=False) + self.fuse_conv = nn.Conv2D( + in_channels=self.out_channels * 4, + out_channels=512, + kernel_size=3, + padding=1, + weight_attr=ParamAttr( + name='conv2d_fuse.w_0', initializer=weight_attr), bias_attr=False) + + def forward(self, x): + c2, c3, c4, c5 = x + + in5 = self.in5_conv(c5) + in4 = self.in4_conv(c4) + in3 = self.in3_conv(c3) + in2 = self.in2_conv(c2) + + out4 = in4 + F.upsample( + in5, size=in4.shape[2:4], mode="nearest", align_mode=1) # 1/16 + out3 = in3 + F.upsample( + out4, size=in3.shape[2:4], mode="nearest", align_mode=1) # 1/8 + out2 = in2 + F.upsample( + out3, size=in2.shape[2:4], mode="nearest", align_mode=1) # 1/4 + + p4 = F.upsample(out4, size=in5.shape[2:4], mode="nearest", align_mode=1) + p3 = F.upsample(out3, size=in5.shape[2:4], mode="nearest", align_mode=1) + p2 = F.upsample(out2, size=in5.shape[2:4], mode="nearest", align_mode=1) + fuse = paddle.concat([in5, p4, p3, p2], axis=1) + fuse_conv = self.fuse_conv(fuse) * 0.005 + return [c5 + fuse_conv] diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py index 8426bcf2b9a71e0293d912e25f1b617fd18c59fc..9429d6b473421cd526ac17823d3198f5ae0921e0 100644 --- a/ppocr/postprocess/rec_postprocess.py +++ b/ppocr/postprocess/rec_postprocess.py @@ -325,8 +325,14 @@ class TableLabelDecode(object): """ """ def __init__(self, + max_text_length, + max_elem_length, + max_cell_num, character_dict_path, **kwargs): + self.max_text_length = max_text_length + self.max_elem_length = max_elem_length + self.max_cell_num = max_cell_num list_character, list_elem = self.load_char_elem_dict(character_dict_path) list_character = self.add_special_char(list_character) list_elem = self.add_special_char(list_elem) @@ -363,6 +369,18 @@ class TableLabelDecode(object): list_character = [self.beg_str] + list_character + [self.end_str] return list_character + def get_sp_tokens(self): + char_beg_idx = self.get_beg_end_flag_idx('beg', 'char') + char_end_idx = self.get_beg_end_flag_idx('end', 'char') + elem_beg_idx = self.get_beg_end_flag_idx('beg', 'elem') + elem_end_idx = self.get_beg_end_flag_idx('end', 'elem') + elem_char_idx1 = self.dict_elem[''] + elem_char_idx2 = self.dict_elem['