diff --git a/configs/table/table_mv3.yml b/configs/table/table_mv3.yml
new file mode 100755
index 0000000000000000000000000000000000000000..32164fe30619e3fa3838f6b021d95925e86708c2
--- /dev/null
+++ b/configs/table/table_mv3.yml
@@ -0,0 +1,116 @@
+Global:
+ use_gpu: true
+ epoch_num: 40
+ log_smooth_window: 20
+ print_batch_step: 5
+ save_model_dir: ./output/table_mv3/
+ save_epoch_step: 3
+ # evaluation is run every 400 iterations after the 0th iteration
+ eval_batch_step: [0, 400]
+ # if pretrained_model is saved in static mode, load_static_weights must be set to True
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ save_inference_dir:
+ use_visualdl: False
+ infer_img: doc/imgs_words/ch/word_1.jpg
+ # for data or label process
+ character_dict_path: ppocr/utils/dict/table_structure_dict.txt
+ character_type: en
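+ # these three limits size the padded label tensors built by TableLabelEncode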
+ max_text_length: 100
+ max_elem_length: 800
+ max_cell_num: 500
+ infer_mode: False
+ process_total_num: 0
+ process_cut_num: 0
+
+Optimizer:
+ name: Adam
+ beta1: 0.9
+ beta2: 0.999
+ clip_norm: 5.0
+ lr:
+ learning_rate: 0.0001
+ regularizer:
+ name: 'L2'
+ factor: 0.00000
+
+Architecture:
+ model_type: table
+ algorithm: TableAttn
+ Backbone:
+ name: MobileNetV3
+ scale: 1.0
+ model_name: large
+ Head:
+ name: TableAttentionHead
+ hidden_size: 256
+ l2_decay: 0.00001
+ # loc_type 1 regresses boxes from the GRU output alone; 2 also fuses encoder features
+ loc_type: 2
+
+Loss:
+ name: TableAttentionLoss
+ structure_weight: 100.0
+ loc_weight: 10000.0
+
+PostProcess:
+ name: TableLabelDecode
+
+Metric:
+ name: TableMetric
+ main_indicator: acc
+
+Train:
+ dataset:
+ name: PubTabDataSet
+ data_dir: train_data/table/pubtabnet/train/
+ label_file_path: train_data/table/pubtabnet/PubTabNet_2.0.0_train.jsonl
+ transforms:
+ - DecodeImage: # load image
+ img_mode: BGR
+ channel_first: False
+ - ResizeTableImage:
+ max_len: 488
+ - TableLabelEncode:
+ - NormalizeImage:
+ scale: 1./255.
+ mean: [0.485, 0.456, 0.406]
+ std: [0.229, 0.224, 0.225]
+ order: 'hwc'
+ - PaddingTableImage:
+ - ToCHWImage:
+ - KeepKeys:
+ keep_keys: ['image', 'structure', 'bbox_list', 'sp_tokens', 'bbox_list_mask']
+ loader:
+ shuffle: True
+ batch_size_per_card: 32
+ drop_last: True
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: PubTabDataSet
+ data_dir: train_data/table/pubtabnet/val/
+ label_file_path: train_data/table/pubtabnet/PubTabNet_2.0.0_val.jsonl
+ transforms:
+ - DecodeImage: # load image
+ img_mode: BGR
+ channel_first: False
+ - ResizeTableImage:
+ max_len: 488
+ - TableLabelEncode:
+ - NormalizeImage:
+ scale: 1./255.
+ mean: [0.485, 0.456, 0.406]
+ std: [0.229, 0.224, 0.225]
+ order: 'hwc'
+ - PaddingTableImage:
+ - ToCHWImage:
+ - KeepKeys:
+ keep_keys: ['image', 'structure', 'bbox_list', 'sp_tokens', 'bbox_list_mask']
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 16
+ num_workers: 4
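
For orientation, a minimal sketch of how this config would be consumed; the yaml/get_logger plumbing is an assumption borrowed from PaddleOCR's tools scripts, and only build_dataloader (extended below) is part of this patch:

    import paddle
    import yaml

    from ppocr.data import build_dataloader
    from ppocr.utils.logging import get_logger  # existing PaddleOCR helper

    config = yaml.safe_load(open('configs/table/table_mv3.yml'))
    device = paddle.set_device('gpu' if config['Global']['use_gpu'] else 'cpu')
    # 'Train' selects the Train.dataset / Train.loader sections above and
    # assumes the PubTabNet files exist at the configured paths
    train_loader = build_dataloader(config, 'Train', device, get_logger())
    image, structure, bbox_list, sp_tokens, bbox_list_mask = next(iter(train_loader))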
diff --git a/ppocr/data/__init__.py b/ppocr/data/__init__.py
index 728b8317f54687ee76b519cba18f4d7807493821..e860c5a6986f495e6384d9df93c24795c04a0d5f 100644
--- a/ppocr/data/__init__.py
+++ b/ppocr/data/__init__.py
@@ -35,6 +35,7 @@ from ppocr.data.imaug import transform, create_operators
from ppocr.data.simple_dataset import SimpleDataSet
from ppocr.data.lmdb_dataset import LMDBDataSet
from ppocr.data.pgnet_dataset import PGDataSet
+from ppocr.data.pubtab_dataset import PubTabDataSet
__all__ = ['build_dataloader', 'transform', 'create_operators']
@@ -55,7 +56,7 @@ signal.signal(signal.SIGTERM, term_mp)
def build_dataloader(config, mode, device, logger, seed=None):
config = copy.deepcopy(config)
- support_dict = ['SimpleDataSet', 'LMDBDataSet', 'PGDataSet']
+ support_dict = ['SimpleDataSet', 'LMDBDataSet', 'PGDataSet', 'PubTabDataSet']
module_name = config[mode]['dataset']['name']
assert module_name in support_dict, Exception(
'DataSet only support {}'.format(support_dict))
diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py
index bba3209f7560f19b74a54c102caf697319814803..cd883d1b433701f27044eb76675b07d9ea234d00 100644
--- a/ppocr/data/imaug/label_ops.py
+++ b/ppocr/data/imaug/label_ops.py
@@ -351,3 +351,182 @@ class SRNLabelEncode(BaseRecLabelEncode):
assert False, "Unsupport type %s in get_beg_end_flag_idx" \
% beg_or_end
return idx
+
+class TableLabelEncode(object):
+ """ Convert between text-label and text-index """
+ def __init__(self,
+ max_text_length,
+ max_elem_length,
+ max_cell_num,
+ character_dict_path,
+ span_weight = 1.0,
+ **kwargs):
+ self.max_text_length = max_text_length
+ self.max_elem_length = max_elem_length
+ self.max_cell_num = max_cell_num
+ list_character, list_elem = self.load_char_elem_dict(character_dict_path)
+ list_character = self.add_special_char(list_character)
+ list_elem = self.add_special_char(list_elem)
+ self.dict_character = {}
+ for i, char in enumerate(list_character):
+ self.dict_character[char] = i
+ self.dict_elem = {}
+ for i, elem in enumerate(list_elem):
+ self.dict_elem[elem] = i
+ self.span_weight = span_weight
+
+ def load_char_elem_dict(self, character_dict_path):
+ list_character = []
+ list_elem = []
+ with open(character_dict_path, "rb") as fin:
+ lines = fin.readlines()
+ substr = lines[0].decode('utf-8').strip("\n").split("\t")
+ character_num = int(substr[0])
+ elem_num = int(substr[1])
+ for cno in range(1, 1+character_num):
+ character = lines[cno].decode('utf-8').strip("\n")
+ list_character.append(character)
+ for eno in range(1+character_num, 1+character_num+elem_num):
+ elem = lines[eno].decode('utf-8').strip("\n")
+ list_elem.append(elem)
+ return list_character, list_elem
+
+ def add_special_char(self, list_character):
+ self.beg_str = "sos"
+ self.end_str = "eos"
+ list_character = [self.beg_str] + list_character + [self.end_str]
+ return list_character
+
+ def get_span_idx_list(self):
+ span_idx_list = []
+ for elem in self.dict_elem:
+ if 'span' in elem:
+ span_idx_list.append(self.dict_elem[elem])
+ return span_idx_list
+
+ def __call__(self, data):
+ cells = data['cells']
+ structure = data['structure']['tokens']
+ structure = self.encode(structure, 'elem')
+ if structure is None:
+ return None
+ elem_num = len(structure)
+ structure = [0] + structure + [len(self.dict_elem) - 1]
+# structure = [0] + structure + [0]
+ structure = structure + [0] * (self.max_elem_length + 2 - len(structure))
+ structure = np.array(structure)
+ data['structure'] = structure
+ elem_char_idx1 = self.dict_elem['<td>']
+ elem_char_idx2 = self.dict_elem['<td']
+ span_idx_list = self.get_span_idx_list()
+ td_idx_list = np.logical_or(structure == elem_char_idx1, structure == elem_char_idx2)
+ td_idx_list = np.where(td_idx_list)[0]
+ structure_mask = np.ones((self.max_elem_length + 2, 1), dtype=np.float32)
+ bbox_list = np.zeros((self.max_elem_length + 2, 4), dtype=np.float32)
+ bbox_list_mask = np.zeros((self.max_elem_length + 2, 1), dtype=np.float32)
+ img_height, img_width, img_ch = data['image'].shape
+ if len(span_idx_list) > 0:
+ span_weight = len(td_idx_list) * 1.0 / len(span_idx_list)
+ span_weight = min(max(span_weight, 1.0), self.span_weight)
+ for cno in range(len(cells)):
+ if 'bbox' in cells[cno]:
+ bbox = cells[cno]['bbox'].copy()
+ bbox[0] = bbox[0] * 1.0 / img_width
+ bbox[1] = bbox[1] * 1.0 / img_height
+ bbox[2] = bbox[2] * 1.0 / img_width
+ bbox[3] = bbox[3] * 1.0 / img_height
+ td_idx = td_idx_list[cno]
+ bbox_list[td_idx] = bbox
+ bbox_list_mask[td_idx] = 1.0
+ cand_span_idx = td_idx + 1
+ if cand_span_idx < (self.max_elem_length + 2):
+ if structure[cand_span_idx] in span_idx_list:
+ structure_mask[cand_span_idx] = span_weight
+# structure_mask[td_idx] = self.span_weight
+# structure_mask[cand_span_idx] = self.span_weight
+
+ data['bbox_list'] = bbox_list
+ data['bbox_list_mask'] = bbox_list_mask
+ data['structure_mask'] = structure_mask
+ char_beg_idx = self.get_beg_end_flag_idx('beg', 'char')
+ char_end_idx = self.get_beg_end_flag_idx('end', 'char')
+ elem_beg_idx = self.get_beg_end_flag_idx('beg', 'elem')
+ elem_end_idx = self.get_beg_end_flag_idx('end', 'elem')
+ data['sp_tokens'] = np.array([char_beg_idx, char_end_idx, elem_beg_idx,
+ elem_end_idx, elem_char_idx1, elem_char_idx2, self.max_text_length,
+ self.max_elem_length, self.max_cell_num, elem_num])
+ return data
+
+ ########
+ # for char decode
+# cell_list = []
+# for cell in cells:
+# char_list = cell['tokens']
+# cell = self.encode(char_list, 'char')
+# if cell is None:
+# return None
+# cell = [0] + cell + [len(self.dict_character) - 1]
+# cell = cell + [0] * (self.max_text_length + 2 - len(cell))
+# cell_list.append(cell)
+# cell_list_padding = np.zeros((self.max_cell_num, self.max_text_length + 2))
+# cell_list = np.array(cell_list)
+# cell_list_padding[0:cell_list.shape[0]] = cell_list
+# data['cells'] = cell_list_padding
+# return data
+
+ def encode(self, text, char_or_elem):
+ """convert text-label into text-index.
+ """
+ if char_or_elem == "char":
+ max_len = self.max_text_length
+ current_dict = self.dict_character
+ else:
+ max_len = self.max_elem_length
+ current_dict = self.dict_elem
+ if len(text) > max_len:
+ return None
+ if len(text) == 0:
+ if char_or_elem == "char":
+ return [self.dict_character['space']]
+ else:
+ return None
+ text_list = []
+ for char in text:
+ if char not in current_dict:
+ return None
+ text_list.append(current_dict[char])
+ if len(text_list) == 0:
+ if char_or_elem == "char":
+ return [self.dict_character['space']]
+ else:
+ return None
+ return text_list
+
+ def get_ignored_tokens(self, char_or_elem):
+ beg_idx = self.get_beg_end_flag_idx("beg", char_or_elem)
+ end_idx = self.get_beg_end_flag_idx("end", char_or_elem)
+ return [beg_idx, end_idx]
+
+ def get_beg_end_flag_idx(self, beg_or_end, char_or_elem):
+ if char_or_elem == "char":
+ if beg_or_end == "beg":
+ idx = np.array(self.dict_character[self.beg_str])
+ elif beg_or_end == "end":
+ idx = np.array(self.dict_character[self.end_str])
+ else:
+ assert False, "Unsupport type %s in get_beg_end_flag_idx of char" \
+ % beg_or_end
+ elif char_or_elem == "elem":
+ if beg_or_end == "beg":
+ idx = np.array(self.dict_elem[self.beg_str])
+ elif beg_or_end == "end":
+ idx = np.array(self.dict_elem[self.end_str])
+ else:
+ assert False, "Unsupport type %s in get_beg_end_flag_idx of elem" \
+ % beg_or_end
+ else:
+ assert False, "Unsupport type %s in char_or_elem" \
+ % char_or_elem
+ return idx
+
\ No newline at end of file
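
A toy walk-through of the encoder above, assuming the shipped dict file (its header line holds the character and element counts that load_char_elem_dict expects) and that the toy tags exist in the element dict:

    import numpy as np
    from ppocr.data.imaug.label_ops import TableLabelEncode

    encoder = TableLabelEncode(
        max_text_length=100, max_elem_length=800, max_cell_num=500,
        character_dict_path='ppocr/utils/dict/table_structure_dict.txt')
    data = {
        'image': np.zeros((488, 488, 3), dtype=np.uint8),  # only its shape is read
        'structure': {'tokens': ['<tr>', '<td>', '</td>', '</tr>']},
        'cells': [{'tokens': ['1'], 'bbox': [10, 10, 50, 30]}],
    }
    data = encoder(data)
    print(data['structure'].shape)  # (802,): sos + tokens + eos, zero-padded to max_elem_length + 2
    print(data['sp_tokens'])        # dict indices and length limits consumed by the head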
diff --git a/ppocr/data/pubtab_dataset.py b/ppocr/data/pubtab_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..a2c3eebf7b480219f71bf7ec04375029aa7db613
--- /dev/null
+++ b/ppocr/data/pubtab_dataset.py
@@ -0,0 +1,125 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np
+import os
+import random
+from paddle.io import Dataset
+import json
+
+from .imaug import transform, create_operators
+
+class PubTabDataSet(Dataset):
+ def __init__(self, config, mode, logger, seed=None):
+ super(PubTabDataSet, self).__init__()
+ self.logger = logger
+
+ global_config = config['Global']
+ dataset_config = config[mode]['dataset']
+ loader_config = config[mode]['loader']
+
+ label_file_path = dataset_config.pop('label_file_path')
+
+ self.data_dir = dataset_config['data_dir']
+ self.do_shuffle = loader_config['shuffle']
+ self.do_hard_select = False
+ if 'hard_select' in loader_config:
+ self.do_hard_select = loader_config['hard_select']
+ self.hard_prob = loader_config['hard_prob']
+ if self.do_hard_select:
+ self.img_select_prob = self.load_hard_select_prob()
+ self.table_select_type = None
+ if 'table_select_type' in loader_config:
+ self.table_select_type = loader_config['table_select_type']
+ self.table_select_prob = loader_config['table_select_prob']
+
+ self.seed = seed
+ logger.info("Initialize indexs of datasets:%s" % label_file_path)
+ with open(label_file_path, "rb") as f:
+ self.data_lines = f.readlines()
+ self.data_idx_order_list = list(range(len(self.data_lines)))
+ if mode.lower() == "train":
+ self.shuffle_data_random()
+ self.ops = create_operators(dataset_config['transforms'], global_config)
+
+ def shuffle_data_random(self):
+ if self.do_shuffle:
+ random.seed(self.seed)
+ random.shuffle(self.data_lines)
+ return
+
+ def load_hard_select_prob(self):
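+ # NOTE: hard-example sampling reads a fixed, experiment-specific TEDS score file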
+ label_path = "./pretrained_model/teds_score_exp5_st2_train.txt"
+ img_select_prob = {}
+ with open(label_path, "rb") as fin:
+ lines = fin.readlines()
+ for lno in range(len(lines)):
+ substr = lines[lno].decode('utf-8').strip("\n").split(" ")
+ img_name = substr[0].strip(":")
+ score = float(substr[1])
+ if score <= 0.8:
+ img_select_prob[img_name] = self.hard_prob[0]
+ elif score <= 0.98:
+ img_select_prob[img_name] = self.hard_prob[1]
+ else:
+ img_select_prob[img_name] = self.hard_prob[2]
+ return img_select_prob
+
+ def __getitem__(self, idx):
+ try:
+ data_line = self.data_lines[idx]
+ data_line = data_line.decode('utf-8').strip("\n")
+ info = json.loads(data_line)
+ file_name = info['filename']
+ select_flag = True
+ if self.do_hard_select:
+ prob = self.img_select_prob[file_name]
+ if prob < random.uniform(0, 1):
+ select_flag = False
+
+ if self.table_select_type:
+ structure = info['html']['structure']['tokens'].copy()
+ structure_str = ''.join(structure)
+ table_type = "simple"
+ if 'colspan' in structure_str or 'rowspan' in structure_str:
+ table_type = "complex"
+# if self.table_select_type != table_type:
+# select_flag = False
+ if table_type == "complex":
+ if self.table_select_prob < random.uniform(0, 1):
+ select_flag = False
+
+ if select_flag:
+ cells = info['html']['cells'].copy()
+ structure = info['html']['structure'].copy()
+ img_path = os.path.join(self.data_dir, file_name)
+ data = {'img_path': img_path, 'cells': cells, 'structure':structure}
+ if not os.path.exists(img_path):
+ raise Exception("{} does not exist!".format(img_path))
+ with open(data['img_path'], 'rb') as f:
+ img = f.read()
+ data['image'] = img
+ outs = transform(data, self.ops)
+ else:
+ outs = None
+ except Exception as e:
+ self.logger.error(
+ "Error parsing line {}: {}".format(data_line, e))
+ outs = None
+ if outs is None:
+ return self.__getitem__(np.random.randint(self.__len__()))
+ return outs
+
+ def __len__(self):
+ return len(self.data_idx_order_list)
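
For reference, each line of the PubTabNet jsonl parsed by __getitem__ is a JSON record of roughly this shape (abridged; real records carry the full token and cell lists):

    {"filename": "PMC1234567_003_00.png",
     "html": {"structure": {"tokens": ["<thead>", "<tr>", "<td>", "</td>", "</tr>", "</thead>"]},
              "cells": [{"tokens": ["F", "o", "o"], "bbox": [10, 4, 59, 13]}]}}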
diff --git a/ppocr/losses/__init__.py b/ppocr/losses/__init__.py
index bf10d2982dcdd36021a7385ab8828398b51af3d3..025ae7ca5cc604eea59423ca7f523c37c1492e35 100755
--- a/ppocr/losses/__init__.py
+++ b/ppocr/losses/__init__.py
@@ -38,11 +38,13 @@ from .basic_loss import DistanceLoss
# combined loss function
from .combined_loss import CombinedLoss
+# table loss
+from .table_att_loss import TableAttentionLoss
def build_loss(config):
support_dict = [
'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss',
- 'SRNLoss', 'PGLoss', 'CombinedLoss'
+ 'SRNLoss', 'PGLoss', 'CombinedLoss', 'TableAttentionLoss'
]
config = copy.deepcopy(config)
module_name = config.pop('name')
diff --git a/ppocr/losses/table_att_loss.py b/ppocr/losses/table_att_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7fd99e6952aacc0182a482ca5ae5ddaf959a026
--- /dev/null
+++ b/ppocr/losses/table_att_loss.py
@@ -0,0 +1,109 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+from paddle import nn
+from paddle.nn import functional as F
+from paddle import fluid
+
+class TableAttentionLoss(nn.Layer):
+ def __init__(self, structure_weight, loc_weight, use_giou=False, giou_weight=1.0, **kwargs):
+ super(TableAttentionLoss, self).__init__()
+ self.loss_func = nn.CrossEntropyLoss(weight=None, reduction='none')
+ self.structure_weight = structure_weight
+ self.loc_weight = loc_weight
+ self.use_giou = use_giou
+ self.giou_weight = giou_weight
+
+ def giou_loss(self, preds, bbox, eps=1e-7, reduction='mean'):
+ '''
+ :param preds:[[x1,y1,x2,y2], [x1,y1,x2,y2],,,]
+ :param bbox:[[x1,y1,x2,y2], [x1,y1,x2,y2],,,]
+ :return: loss
+ '''
+ ix1 = fluid.layers.elementwise_max(preds[:, 0], bbox[:, 0])
+ iy1 = fluid.layers.elementwise_max(preds[:, 1], bbox[:, 1])
+ ix2 = fluid.layers.elementwise_min(preds[:, 2], bbox[:, 2])
+ iy2 = fluid.layers.elementwise_min(preds[:, 3], bbox[:, 3])
+
+ iw = fluid.layers.clip(ix2 - ix1 + 1e-3, 0., 1e10)
+ ih = fluid.layers.clip(iy2 - iy1 + 1e-3, 0., 1e10)
+
+ # overlap
+ inters = iw * ih
+
+ # union
+ uni = (preds[:, 2] - preds[:, 0] + 1e-3) * (preds[:, 3] - preds[:, 1] + 1e-3
+ ) + (bbox[:, 2] - bbox[:, 0] + 1e-3) * (
+ bbox[:, 3] - bbox[:, 1] + 1e-3) - inters + eps
+
+ # ious
+ ious = inters / uni
+
+ ex1 = fluid.layers.elementwise_min(preds[:, 0], bbox[:, 0])
+ ey1 = fluid.layers.elementwise_min(preds[:, 1], bbox[:, 1])
+ ex2 = fluid.layers.elementwise_max(preds[:, 2], bbox[:, 2])
+ ey2 = fluid.layers.elementwise_max(preds[:, 3], bbox[:, 3])
+ ew = fluid.layers.clip(ex2 - ex1 + 1e-3, 0., 1e10)
+ eh = fluid.layers.clip(ey2 - ey1 + 1e-3, 0., 1e10)
+
+ # enclose area
+ enclose = ew * eh + eps
+ giou = ious - (enclose - uni) / enclose
+
+ loss = 1 - giou
+
+ if reduction == 'mean':
+ loss = paddle.mean(loss)
+ elif reduction == 'sum':
+ loss = paddle.sum(loss)
+ else:
+ raise NotImplementedError
+ return loss
+
+ def forward(self, predicts, batch):
+ structure_probs = predicts['structure_probs']
+ structure_targets = batch[1].astype("int64")
+ structure_targets = structure_targets[:, 1:]
+ if len(batch) == 6:
+ structure_mask = batch[5].astype("int64")
+ structure_mask = structure_mask[:, 1:]
+ structure_mask = paddle.reshape(structure_mask, [-1])
+ structure_probs = paddle.reshape(structure_probs, [-1, structure_probs.shape[-1]])
+ structure_targets = paddle.reshape(structure_targets, [-1])
+ structure_loss = self.loss_func(structure_probs, structure_targets)
+
+ if len(batch) == 6:
+ structure_loss = structure_loss * structure_mask
+
+# structure_loss = paddle.sum(structure_loss) * self.structure_weight
+ structure_loss = paddle.mean(structure_loss) * self.structure_weight
+
+ loc_preds = predicts['loc_preds']
+ loc_targets = batch[2].astype("float32")
+ loc_targets_mask = batch[4].astype("float32")
+ loc_targets = loc_targets[:, 1:, :]
+ loc_targets_mask = loc_targets_mask[:, 1:, :]
+ loc_loss = F.mse_loss(loc_preds * loc_targets_mask, loc_targets) * self.loc_weight
+ if self.use_giou:
+ loc_loss_giou = self.giou_loss(loc_preds * loc_targets_mask, loc_targets) * self.giou_weight
+ total_loss = structure_loss + loc_loss + loc_loss_giou
+ return {'loss':total_loss, "structure_loss":structure_loss, "loc_loss":loc_loss, "loc_loss_giou":loc_loss_giou}
+ else:
+ total_loss = structure_loss + loc_loss
+ return {'loss':total_loss, "structure_loss":structure_loss, "loc_loss":loc_loss}
\ No newline at end of file
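
A standalone numeric check of the GIoU term above, with numpy standing in for the fluid ops (same epsilons and clipping):

    import numpy as np

    def giou_loss_np(preds, bbox, eps=1e-7):
        # mirrors TableAttentionLoss.giou_loss with reduction='mean'
        ix1 = np.maximum(preds[:, 0], bbox[:, 0])
        iy1 = np.maximum(preds[:, 1], bbox[:, 1])
        ix2 = np.minimum(preds[:, 2], bbox[:, 2])
        iy2 = np.minimum(preds[:, 3], bbox[:, 3])
        inters = np.clip(ix2 - ix1 + 1e-3, 0., 1e10) * np.clip(iy2 - iy1 + 1e-3, 0., 1e10)
        uni = ((preds[:, 2] - preds[:, 0] + 1e-3) * (preds[:, 3] - preds[:, 1] + 1e-3)
               + (bbox[:, 2] - bbox[:, 0] + 1e-3) * (bbox[:, 3] - bbox[:, 1] + 1e-3)
               - inters + eps)
        ew = np.maximum(preds[:, 2], bbox[:, 2]) - np.minimum(preds[:, 0], bbox[:, 0]) + 1e-3
        eh = np.maximum(preds[:, 3], bbox[:, 3]) - np.minimum(preds[:, 1], bbox[:, 1]) + 1e-3
        enclose = np.clip(ew, 0., 1e10) * np.clip(eh, 0., 1e10) + eps
        giou = inters / uni - (enclose - uni) / enclose
        return (1 - giou).mean()

    a = np.array([[0.1, 0.1, 0.4, 0.4]])
    print(giou_loss_np(a, a))                                 # ~0.0: perfect overlap
    print(giou_loss_np(a, np.array([[0.6, 0.6, 0.9, 0.9]])))  # ~1.72: disjoint boxes approach 2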
diff --git a/ppocr/metrics/__init__.py b/ppocr/metrics/__init__.py
index 9e9060fa999bd3175c31dfc0797cd293d4e7afec..64f62e51cdf922773c03bb784a4edffdc17f506f 100644
--- a/ppocr/metrics/__init__.py
+++ b/ppocr/metrics/__init__.py
@@ -26,11 +26,11 @@ from .rec_metric import RecMetric
from .cls_metric import ClsMetric
from .e2e_metric import E2EMetric
from .distillation_metric import DistillationMetric
-
+from .table_metric import TableMetric
def build_metric(config):
support_dict = [
- "DetMetric", "RecMetric", "ClsMetric", "E2EMetric", "DistillationMetric"
+ "DetMetric", "RecMetric", "ClsMetric", "E2EMetric", "DistillationMetric", "TableMetric"
]
config = copy.deepcopy(config)
diff --git a/ppocr/metrics/table_metric.py b/ppocr/metrics/table_metric.py
new file mode 100644
index 0000000000000000000000000000000000000000..80d1c789ecc3979bd4c33620af91ccd28012f7a8
--- /dev/null
+++ b/ppocr/metrics/table_metric.py
@@ -0,0 +1,50 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np
+class TableMetric(object):
+ def __init__(self, main_indicator='acc', **kwargs):
+ self.main_indicator = main_indicator
+ self.reset()
+
+ def __call__(self, pred, batch, *args, **kwargs):
+ structure_probs = pred['structure_probs'].numpy()
+ structure_labels = batch[1]
+ correct_num = 0
+ all_num = 0
+ structure_probs = np.argmax(structure_probs, axis=2)
+ structure_labels = structure_labels[:, 1:]
+ batch_size = structure_probs.shape[0]
+ for bno in range(batch_size):
+ all_num += 1
+ if (structure_probs[bno] == structure_labels[bno]).all():
+ correct_num += 1
+ self.correct_num += correct_num
+ self.all_num += all_num
+ return {
+ 'acc': correct_num * 1.0 / all_num,
+ }
+
+ def get_metric(self):
+ """
+ return metrics {
+ 'acc': 0,
+ }
+ """
+ acc = 1.0 * self.correct_num / self.all_num
+ self.reset()
+ return {'acc': acc}
+
+ def reset(self):
+ self.correct_num = 0
+ self.all_num = 0
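
A quick sanity check of the accuracy bookkeeping; batch[1] carries the padded structure labels and the leading sos slot is stripped by [:, 1:]:

    import numpy as np
    import paddle
    from ppocr.metrics.table_metric import TableMetric

    metric = TableMetric()
    # (batch=1, steps=3, classes=3); per-step argmax is [0, 1, 2]
    probs = paddle.to_tensor(np.eye(3, dtype='float32')[np.newaxis])
    labels = np.array([[9, 0, 1, 2]])  # index 0 is the sos slot, dropped by [:, 1:]
    print(metric({'structure_probs': probs}, [None, labels]))  # {'acc': 1.0}
    print(metric.get_metric())                                 # {'acc': 1.0}; counters reset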
diff --git a/ppocr/modeling/architectures/base_model.py b/ppocr/modeling/architectures/base_model.py
index 4c941fcf65573d9314c0badda49895d0b6b5c4f9..49160b52898a50984c3036d4dea48513ca53bb0d 100644
--- a/ppocr/modeling/architectures/base_model.py
+++ b/ppocr/modeling/architectures/base_model.py
@@ -69,7 +69,7 @@ class BaseModel(nn.Layer):
self.return_all_feats = config.get("return_all_feats", False)
- def forward(self, x, data=None):
+ def forward(self, x, data=None, mode='Train'):
y = dict()
if self.use_transform:
x = self.transform(x)
@@ -81,7 +81,10 @@ class BaseModel(nn.Layer):
if data is None:
x = self.head(x)
else:
- x = self.head(x, data)
+ if mode == 'Eval' or mode == 'Test':
+ x = self.head(x, targets=data, mode=mode)
+ else:
+ x = self.head(x, targets=data)
y["head_out"] = x
if self.return_all_feats:
return y
diff --git a/ppocr/modeling/backbones/__init__.py b/ppocr/modeling/backbones/__init__.py
index fe2c9bc30a4f2abd1ba7d3d6989b9ef9b20c1f4f..13b70b203371b3be58ee82c6808d744bf6098333 100755
--- a/ppocr/modeling/backbones/__init__.py
+++ b/ppocr/modeling/backbones/__init__.py
@@ -29,6 +29,10 @@ def build_backbone(config, model_type):
elif model_type == 'e2e':
from .e2e_resnet_vd_pg import ResNet
support_dict = ['ResNet']
+ elif model_type == "table":
+ from .table_resnet_vd import ResNet
+ from .table_mobilenet_v3 import MobileNetV3
+ support_dict = ['ResNet', 'MobileNetV3']
else:
raise NotImplementedError
diff --git a/ppocr/modeling/backbones/table_mobilenet_v3.py b/ppocr/modeling/backbones/table_mobilenet_v3.py
new file mode 100644
index 0000000000000000000000000000000000000000..daa87f976038d8d5eeafadceb869b9232ba22cd9
--- /dev/null
+++ b/ppocr/modeling/backbones/table_mobilenet_v3.py
@@ -0,0 +1,287 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+from paddle import nn
+import paddle.nn.functional as F
+from paddle import ParamAttr
+
+__all__ = ['MobileNetV3']
+
+
+def make_divisible(v, divisor=8, min_value=None):
+ if min_value is None:
+ min_value = divisor
+ new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
+ if new_v < 0.9 * v:
+ new_v += divisor
+ return new_v
+
+
+class MobileNetV3(nn.Layer):
+ def __init__(self,
+ in_channels=3,
+ model_name='large',
+ scale=0.5,
+ disable_se=False,
+ **kwargs):
+ """
+ The MobileNetV3 backbone network for the table structure module.
+ Args:
+ params(dict): hyperparameters for building the network
+ """
+ super(MobileNetV3, self).__init__()
+
+ self.disable_se = disable_se
+
+ if model_name == "large":
+ cfg = [
+ # k, exp, c, se, nl, s,
+ [3, 16, 16, False, 'relu', 1],
+ [3, 64, 24, False, 'relu', 2],
+ [3, 72, 24, False, 'relu', 1],
+ [5, 72, 40, True, 'relu', 2],
+ [5, 120, 40, True, 'relu', 1],
+ [5, 120, 40, True, 'relu', 1],
+ [3, 240, 80, False, 'hardswish', 2],
+ [3, 200, 80, False, 'hardswish', 1],
+ [3, 184, 80, False, 'hardswish', 1],
+ [3, 184, 80, False, 'hardswish', 1],
+ [3, 480, 112, True, 'hardswish', 1],
+ [3, 672, 112, True, 'hardswish', 1],
+ [5, 672, 160, True, 'hardswish', 2],
+ [5, 960, 160, True, 'hardswish', 1],
+ [5, 960, 160, True, 'hardswish', 1],
+ ]
+ cls_ch_squeeze = 960
+ elif model_name == "small":
+ cfg = [
+ # k, exp, c, se, nl, s,
+ [3, 16, 16, True, 'relu', 2],
+ [3, 72, 24, False, 'relu', 2],
+ [3, 88, 24, False, 'relu', 1],
+ [5, 96, 40, True, 'hardswish', 2],
+ [5, 240, 40, True, 'hardswish', 1],
+ [5, 240, 40, True, 'hardswish', 1],
+ [5, 120, 48, True, 'hardswish', 1],
+ [5, 144, 48, True, 'hardswish', 1],
+ [5, 288, 96, True, 'hardswish', 2],
+ [5, 576, 96, True, 'hardswish', 1],
+ [5, 576, 96, True, 'hardswish', 1],
+ ]
+ cls_ch_squeeze = 576
+ else:
+ raise NotImplementedError("model_name [" + model_name +
+ "] is not implemented!")
+
+ supported_scale = [0.35, 0.5, 0.75, 1.0, 1.25]
+ assert scale in supported_scale, \
+ "supported scale are {} but input scale is {}".format(supported_scale, scale)
+ inplanes = 16
+ # conv1
+ self.conv = ConvBNLayer(
+ in_channels=in_channels,
+ out_channels=make_divisible(inplanes * scale),
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ groups=1,
+ if_act=True,
+ act='hardswish',
+ name='conv1')
+
+ self.stages = []
+ self.out_channels = []
+ block_list = []
+ i = 0
+ inplanes = make_divisible(inplanes * scale)
+ for (k, exp, c, se, nl, s) in cfg:
+ se = se and not self.disable_se
+ start_idx = 2 if model_name == 'large' else 0
+ if s == 2 and i > start_idx:
+ self.out_channels.append(inplanes)
+ self.stages.append(nn.Sequential(*block_list))
+ block_list = []
+ block_list.append(
+ ResidualUnit(
+ in_channels=inplanes,
+ mid_channels=make_divisible(scale * exp),
+ out_channels=make_divisible(scale * c),
+ kernel_size=k,
+ stride=s,
+ use_se=se,
+ act=nl,
+ name="conv" + str(i + 2)))
+ inplanes = make_divisible(scale * c)
+ i += 1
+ block_list.append(
+ ConvBNLayer(
+ in_channels=inplanes,
+ out_channels=make_divisible(scale * cls_ch_squeeze),
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ groups=1,
+ if_act=True,
+ act='hardswish',
+ name='conv_last'))
+ self.stages.append(nn.Sequential(*block_list))
+ self.out_channels.append(make_divisible(scale * cls_ch_squeeze))
+ for i, stage in enumerate(self.stages):
+ self.add_sublayer(sublayer=stage, name="stage{}".format(i))
+
+ def forward(self, x):
+ x = self.conv(x)
+ out_list = []
+ for stage in self.stages:
+ x = stage(x)
+ out_list.append(x)
+ return out_list
+
+
+class ConvBNLayer(nn.Layer):
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride,
+ padding,
+ groups=1,
+ if_act=True,
+ act=None,
+ name=None):
+ super(ConvBNLayer, self).__init__()
+ self.if_act = if_act
+ self.act = act
+ self.conv = nn.Conv2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ groups=groups,
+ weight_attr=ParamAttr(name=name + '_weights'),
+ bias_attr=False)
+
+ self.bn = nn.BatchNorm(
+ num_channels=out_channels,
+ act=None,
+ param_attr=ParamAttr(name=name + "_bn_scale"),
+ bias_attr=ParamAttr(name=name + "_bn_offset"),
+ moving_mean_name=name + "_bn_mean",
+ moving_variance_name=name + "_bn_variance")
+
+ def forward(self, x):
+ x = self.conv(x)
+ x = self.bn(x)
+ if self.if_act:
+ if self.act == "relu":
+ x = F.relu(x)
+ elif self.act == "hardswish":
+ x = F.hardswish(x)
+ else:
+ print("The activation function({}) is selected incorrectly.".
+ format(self.act))
+ exit()
+ return x
+
+
+class ResidualUnit(nn.Layer):
+ def __init__(self,
+ in_channels,
+ mid_channels,
+ out_channels,
+ kernel_size,
+ stride,
+ use_se,
+ act=None,
+ name=''):
+ super(ResidualUnit, self).__init__()
+ self.if_shortcut = stride == 1 and in_channels == out_channels
+ self.if_se = use_se
+
+ self.expand_conv = ConvBNLayer(
+ in_channels=in_channels,
+ out_channels=mid_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ if_act=True,
+ act=act,
+ name=name + "_expand")
+ self.bottleneck_conv = ConvBNLayer(
+ in_channels=mid_channels,
+ out_channels=mid_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=int((kernel_size - 1) // 2),
+ groups=mid_channels,
+ if_act=True,
+ act=act,
+ name=name + "_depthwise")
+ if self.if_se:
+ self.mid_se = SEModule(mid_channels, name=name + "_se")
+ self.linear_conv = ConvBNLayer(
+ in_channels=mid_channels,
+ out_channels=out_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ if_act=False,
+ act=None,
+ name=name + "_linear")
+
+ def forward(self, inputs):
+ x = self.expand_conv(inputs)
+ x = self.bottleneck_conv(x)
+ if self.if_se:
+ x = self.mid_se(x)
+ x = self.linear_conv(x)
+ if self.if_shortcut:
+ x = paddle.add(inputs, x)
+ return x
+
+
+class SEModule(nn.Layer):
+ def __init__(self, in_channels, reduction=4, name=""):
+ super(SEModule, self).__init__()
+ self.avg_pool = nn.AdaptiveAvgPool2D(1)
+ self.conv1 = nn.Conv2D(
+ in_channels=in_channels,
+ out_channels=in_channels // reduction,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ weight_attr=ParamAttr(name=name + "_1_weights"),
+ bias_attr=ParamAttr(name=name + "_1_offset"))
+ self.conv2 = nn.Conv2D(
+ in_channels=in_channels // reduction,
+ out_channels=in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ weight_attr=ParamAttr(name + "_2_weights"),
+ bias_attr=ParamAttr(name=name + "_2_offset"))
+
+ def forward(self, inputs):
+ outputs = self.avg_pool(inputs)
+ outputs = self.conv1(outputs)
+ outputs = F.relu(outputs)
+ outputs = self.conv2(outputs)
+ outputs = F.hardsigmoid(outputs, slope=0.2, offset=0.5)
+ return inputs * outputs
\ No newline at end of file
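
A shape sketch for this backbone as configured in table_mv3.yml (model_name='large', scale=1.0); the stage cut points above yield 1/4-, 1/8-, 1/16- and 1/32-stride features:

    import paddle
    from ppocr.modeling.backbones.table_mobilenet_v3 import MobileNetV3

    net = MobileNetV3(in_channels=3, model_name='large', scale=1.0)
    feats = net(paddle.rand([1, 3, 488, 488]))
    print(net.out_channels)          # [24, 40, 112, 960]
    print([f.shape for f in feats])  # [1, 24, 122, 122] ... [1, 960, 16, 16]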
diff --git a/ppocr/modeling/backbones/table_resnet_vd.py b/ppocr/modeling/backbones/table_resnet_vd.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c07c2684eec8d0c4a445cc88c543bfe1da9c864
--- /dev/null
+++ b/ppocr/modeling/backbones/table_resnet_vd.py
@@ -0,0 +1,280 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+from paddle import ParamAttr
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+__all__ = ["ResNet"]
+
+
+class ConvBNLayer(nn.Layer):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ groups=1,
+ is_vd_mode=False,
+ act=None,
+ name=None, ):
+ super(ConvBNLayer, self).__init__()
+
+ self.is_vd_mode = is_vd_mode
+ self._pool2d_avg = nn.AvgPool2D(
+ kernel_size=2, stride=2, padding=0, ceil_mode=True)
+ self._conv = nn.Conv2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=(kernel_size - 1) // 2,
+ groups=groups,
+ weight_attr=ParamAttr(name=name + "_weights"),
+ bias_attr=False)
+ if name == "conv1":
+ bn_name = "bn_" + name
+ else:
+ bn_name = "bn" + name[3:]
+ self._batch_norm = nn.BatchNorm(
+ out_channels,
+ act=act,
+ param_attr=ParamAttr(name=bn_name + '_scale'),
+ bias_attr=ParamAttr(bn_name + '_offset'),
+ moving_mean_name=bn_name + '_mean',
+ moving_variance_name=bn_name + '_variance')
+
+ def forward(self, inputs):
+ if self.is_vd_mode:
+ inputs = self._pool2d_avg(inputs)
+ y = self._conv(inputs)
+ y = self._batch_norm(y)
+ return y
+
+
+class BottleneckBlock(nn.Layer):
+ def __init__(self,
+ in_channels,
+ out_channels,
+ stride,
+ shortcut=True,
+ if_first=False,
+ name=None):
+ super(BottleneckBlock, self).__init__()
+
+ self.conv0 = ConvBNLayer(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=1,
+ act='relu',
+ name=name + "_branch2a")
+ self.conv1 = ConvBNLayer(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ kernel_size=3,
+ stride=stride,
+ act='relu',
+ name=name + "_branch2b")
+ self.conv2 = ConvBNLayer(
+ in_channels=out_channels,
+ out_channels=out_channels * 4,
+ kernel_size=1,
+ act=None,
+ name=name + "_branch2c")
+
+ if not shortcut:
+ self.short = ConvBNLayer(
+ in_channels=in_channels,
+ out_channels=out_channels * 4,
+ kernel_size=1,
+ stride=1,
+ is_vd_mode=False if if_first else True,
+ name=name + "_branch1")
+
+ self.shortcut = shortcut
+
+ def forward(self, inputs):
+ y = self.conv0(inputs)
+ conv1 = self.conv1(y)
+ conv2 = self.conv2(conv1)
+
+ if self.shortcut:
+ short = inputs
+ else:
+ short = self.short(inputs)
+ y = paddle.add(x=short, y=conv2)
+ y = F.relu(y)
+ return y
+
+
+class BasicBlock(nn.Layer):
+ def __init__(self,
+ in_channels,
+ out_channels,
+ stride,
+ shortcut=True,
+ if_first=False,
+ name=None):
+ super(BasicBlock, self).__init__()
+ self.stride = stride
+ self.conv0 = ConvBNLayer(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=3,
+ stride=stride,
+ act='relu',
+ name=name + "_branch2a")
+ self.conv1 = ConvBNLayer(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ kernel_size=3,
+ act=None,
+ name=name + "_branch2b")
+
+ if not shortcut:
+ self.short = ConvBNLayer(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=1,
+ stride=1,
+ is_vd_mode=False if if_first else True,
+ name=name + "_branch1")
+
+ self.shortcut = shortcut
+
+ def forward(self, inputs):
+ y = self.conv0(inputs)
+ conv1 = self.conv1(y)
+
+ if self.shortcut:
+ short = inputs
+ else:
+ short = self.short(inputs)
+ y = paddle.add(x=short, y=conv1)
+ y = F.relu(y)
+ return y
+
+
+class ResNet(nn.Layer):
+ def __init__(self, in_channels=3, layers=50, **kwargs):
+ super(ResNet, self).__init__()
+
+ self.layers = layers
+ supported_layers = [18, 34, 50, 101, 152, 200]
+ assert layers in supported_layers, \
+ "supported layers are {} but input layer is {}".format(
+ supported_layers, layers)
+
+ if layers == 18:
+ depth = [2, 2, 2, 2]
+ elif layers == 34 or layers == 50:
+ depth = [3, 4, 6, 3]
+ elif layers == 101:
+ depth = [3, 4, 23, 3]
+ elif layers == 152:
+ depth = [3, 8, 36, 3]
+ elif layers == 200:
+ depth = [3, 12, 48, 3]
+ num_channels = [64, 256, 512,
+ 1024] if layers >= 50 else [64, 64, 128, 256]
+ num_filters = [64, 128, 256, 512]
+
+ self.conv1_1 = ConvBNLayer(
+ in_channels=in_channels,
+ out_channels=32,
+ kernel_size=3,
+ stride=2,
+ act='relu',
+ name="conv1_1")
+ self.conv1_2 = ConvBNLayer(
+ in_channels=32,
+ out_channels=32,
+ kernel_size=3,
+ stride=1,
+ act='relu',
+ name="conv1_2")
+ self.conv1_3 = ConvBNLayer(
+ in_channels=32,
+ out_channels=64,
+ kernel_size=3,
+ stride=1,
+ act='relu',
+ name="conv1_3")
+ self.pool2d_max = nn.MaxPool2D(kernel_size=3, stride=2, padding=1)
+
+ self.stages = []
+ self.out_channels = []
+ if layers >= 50:
+ for block in range(len(depth)):
+ block_list = []
+ shortcut = False
+ for i in range(depth[block]):
+ if layers in [101, 152] and block == 2:
+ if i == 0:
+ conv_name = "res" + str(block + 2) + "a"
+ else:
+ conv_name = "res" + str(block + 2) + "b" + str(i)
+ else:
+ conv_name = "res" + str(block + 2) + chr(97 + i)
+ bottleneck_block = self.add_sublayer(
+ 'bb_%d_%d' % (block, i),
+ BottleneckBlock(
+ in_channels=num_channels[block]
+ if i == 0 else num_filters[block] * 4,
+ out_channels=num_filters[block],
+ stride=2 if i == 0 and block != 0 else 1,
+ shortcut=shortcut,
+ if_first=block == i == 0,
+ name=conv_name))
+ shortcut = True
+ block_list.append(bottleneck_block)
+ self.out_channels.append(num_filters[block] * 4)
+ self.stages.append(nn.Sequential(*block_list))
+ else:
+ for block in range(len(depth)):
+ block_list = []
+ shortcut = False
+ for i in range(depth[block]):
+ conv_name = "res" + str(block + 2) + chr(97 + i)
+ basic_block = self.add_sublayer(
+ 'bb_%d_%d' % (block, i),
+ BasicBlock(
+ in_channels=num_channels[block]
+ if i == 0 else num_filters[block],
+ out_channels=num_filters[block],
+ stride=2 if i == 0 and block != 0 else 1,
+ shortcut=shortcut,
+ if_first=block == i == 0,
+ name=conv_name))
+ shortcut = True
+ block_list.append(basic_block)
+ self.out_channels.append(num_filters[block])
+ self.stages.append(nn.Sequential(*block_list))
+
+ def forward(self, inputs):
+ y = self.conv1_1(inputs)
+ y = self.conv1_2(y)
+ y = self.conv1_3(y)
+ y = self.pool2d_max(y)
+ out = []
+ for block in self.stages:
+ y = block(y)
+ out.append(y)
+ return out
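
The ResNet variant follows the same four-level contract; with layers=50 the stages emit [256, 512, 1024, 2048] channels at strides 1/4 down to 1/32:

    import paddle
    from ppocr.modeling.backbones.table_resnet_vd import ResNet

    net = ResNet(in_channels=3, layers=50)
    feats = net(paddle.rand([1, 3, 488, 488]))
    print(net.out_channels)              # [256, 512, 1024, 2048]
    print([f.shape[2:] for f in feats])  # spatial sizes 122, 61, 31, 16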
diff --git a/ppocr/modeling/heads/__init__.py b/ppocr/modeling/heads/__init__.py
index 4852c7f2d14d72b9e4d59f40532469f7226c966d..5096479415f504aa9f074d55bd9b2e4a31c730b4 100755
--- a/ppocr/modeling/heads/__init__.py
+++ b/ppocr/modeling/heads/__init__.py
@@ -31,8 +31,10 @@ def build_head(config):
from .cls_head import ClsHead
support_dict = [
'DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead', 'AttentionHead',
- 'SRNHead', 'PGHead']
+ 'SRNHead', 'PGHead', 'TableAttentionHead']
+ # table head
+ from .table_att_head import TableAttentionHead
module_name = config.pop('name')
assert module_name in support_dict, Exception('head only support {}'.format(
diff --git a/ppocr/modeling/heads/table_att_head.py b/ppocr/modeling/heads/table_att_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e5c438a358757f597e31ae8ea84a7ab1c22776b
--- /dev/null
+++ b/ppocr/modeling/heads/table_att_head.py
@@ -0,0 +1,240 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+import numpy as np
+
+class TableAttentionHead(nn.Layer):
+ def __init__(self, in_channels, hidden_size, loc_type, in_max_len=488, **kwargs):
+ super(TableAttentionHead, self).__init__()
+ self.input_size = in_channels[-1]
+ self.hidden_size = hidden_size
+ self.char_num = 280 # character vocab size (dict entries plus sos/eos)
+ self.elem_num = 30 # structure-element vocab size (dict entries plus sos/eos)
+
+ self.structure_attention_cell = AttentionGRUCell(
+ self.input_size, hidden_size, self.elem_num, use_gru=False)
+ self.structure_generator = nn.Linear(hidden_size, self.elem_num)
+ self.loc_type = loc_type
+ self.in_max_len = in_max_len
+
+ if self.loc_type == 1:
+ self.loc_generator = nn.Linear(hidden_size, 4)
+ else:
+ if self.in_max_len == 640:
+ self.loc_fea_trans = nn.Linear(400, 801)
+ elif self.in_max_len == 800:
+ self.loc_fea_trans = nn.Linear(625, 801)
+ else:
+ self.loc_fea_trans = nn.Linear(256, 801)
+ self.loc_generator = nn.Linear(self.input_size + hidden_size, 4)
+
+ def _char_to_onehot(self, input_char, onehot_dim):
+ input_one_hot = F.one_hot(input_char, onehot_dim)
+ return input_one_hot
+
+ def forward(self, inputs, targets=None, mode='Train'):
+ # Both branches below must assign the same set of variables: if a variable is
+ # modified in only one branch, the change is lost under static-graph export.
+ fea = inputs[-1]
+ if len(fea.shape) == 3:
+ pass
+ else:
+ last_shape = int(np.prod(fea.shape[2:])) # flatten H and W into one axis
+ fea = paddle.reshape(fea, [fea.shape[0], fea.shape[1], last_shape])
+ fea = fea.transpose([0, 2, 1]) # (NTC)(batch, width, channels)
+ batch_size = fea.shape[0]
+ # sp_tokens (targets[2]) carries these begin/end indices and length limits from
+ # TableLabelEncode; they are hardcoded here to the table_mv3.yml values.
+ max_text_length, max_elem_length, max_cell_num = 100, 800, 500
+
+ hidden = paddle.zeros((batch_size, self.hidden_size))
+ output_hiddens = []
+ if mode == 'Train' and targets is not None:
+ structure = targets[0]
+ for i in range(max_elem_length+1):
+ elem_onehots = self._char_to_onehot(
+ structure[:, i], onehot_dim=self.elem_num)
+ (outputs, hidden), alpha = self.structure_attention_cell(
+ hidden, fea, elem_onehots)
+ output_hiddens.append(paddle.unsqueeze(outputs, axis=1))
+ output = paddle.concat(output_hiddens, axis=1)
+ structure_probs = self.structure_generator(output)
+ if self.loc_type == 1:
+ loc_preds = self.loc_generator(output)
+ loc_preds = F.sigmoid(loc_preds)
+ else:
+ loc_fea = fea.transpose([0, 2, 1])
+ loc_fea = self.loc_fea_trans(loc_fea)
+ loc_fea = loc_fea.transpose([0, 2, 1])
+ loc_concat = paddle.concat([output, loc_fea], axis=2)
+ loc_preds = self.loc_generator(loc_concat)
+ loc_preds = F.sigmoid(loc_preds)
+ else:
+ temp_elem = paddle.zeros(shape=[batch_size], dtype="int32")
+ structure_probs = None
+ loc_preds = None
+ elem_onehots = None
+ outputs = None
+ alpha = None
+ max_elem_length = paddle.to_tensor(max_elem_length)
+ i = 0
+ while i < max_elem_length+1:
+ elem_onehots = self._char_to_onehot(
+ temp_elem, onehot_dim=self.elem_num)
+ (outputs, hidden), alpha = self.structure_attention_cell(
+ hidden, fea, elem_onehots)
+ output_hiddens.append(paddle.unsqueeze(outputs, axis=1))
+ structure_probs_step = self.structure_generator(outputs)
+ temp_elem = structure_probs_step.argmax(axis=1, dtype="int32")
+ i += 1
+
+ output = paddle.concat(output_hiddens, axis=1)
+ structure_probs = self.structure_generator(output)
+ structure_probs = F.softmax(structure_probs)
+ if self.loc_type == 1:
+ loc_preds = self.loc_generator(output)
+ loc_preds = F.sigmoid(loc_preds)
+ else:
+ loc_fea = fea.transpose([0, 2, 1])
+ loc_fea = self.loc_fea_trans(loc_fea)
+ loc_fea = loc_fea.transpose([0, 2, 1])
+ loc_concat = paddle.concat([output, loc_fea], axis=2)
+ loc_preds = self.loc_generator(loc_concat)
+ loc_preds = F.sigmoid(loc_preds)
+ return {'structure_probs':structure_probs, 'loc_preds':loc_preds}
+
+class AttentionGRUCell(nn.Layer):
+ def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False):
+ super(AttentionGRUCell, self).__init__()
+ self.i2h = nn.Linear(input_size, hidden_size, bias_attr=False)
+ self.h2h = nn.Linear(hidden_size, hidden_size)
+ self.score = nn.Linear(hidden_size, 1, bias_attr=False)
+ self.rnn = nn.GRUCell(
+ input_size=input_size + num_embeddings, hidden_size=hidden_size)
+ self.hidden_size = hidden_size
+
+ def forward(self, prev_hidden, batch_H, char_onehots):
+ batch_H_proj = self.i2h(batch_H)
+ prev_hidden_proj = paddle.unsqueeze(self.h2h(prev_hidden), axis=1)
+ res = paddle.add(batch_H_proj, prev_hidden_proj)
+ res = paddle.tanh(res)
+ e = self.score(res)
+ alpha = F.softmax(e, axis=1)
+ alpha = paddle.transpose(alpha, [0, 2, 1])
+ context = paddle.squeeze(paddle.mm(alpha, batch_H), axis=1)
+ concat_context = paddle.concat([context, char_onehots], 1)
+ cur_hidden = self.rnn(concat_context, prev_hidden)
+ return cur_hidden, alpha
+
+
+class AttentionLSTM(nn.Layer):
+ def __init__(self, in_channels, out_channels, hidden_size, **kwargs):
+ super(AttentionLSTM, self).__init__()
+ self.input_size = in_channels
+ self.hidden_size = hidden_size
+ self.num_classes = out_channels
+
+ self.attention_cell = AttentionLSTMCell(
+ in_channels, hidden_size, out_channels, use_gru=False)
+ self.generator = nn.Linear(hidden_size, out_channels)
+
+ def _char_to_onehot(self, input_char, onehot_dim):
+ input_one_hot = F.one_hot(input_char, onehot_dim)
+ return input_one_hot
+
+ def forward(self, inputs, targets=None, batch_max_length=25):
+ batch_size = inputs.shape[0]
+ num_steps = batch_max_length
+
+ hidden = (paddle.zeros((batch_size, self.hidden_size)), paddle.zeros(
+ (batch_size, self.hidden_size)))
+ output_hiddens = []
+
+ if targets is not None:
+ for i in range(num_steps):
+ # one-hot vector for the i-th char
+ char_onehots = self._char_to_onehot(
+ targets[:, i], onehot_dim=self.num_classes)
+ hidden, alpha = self.attention_cell(hidden, inputs,
+ char_onehots)
+
+ hidden = (hidden[1][0], hidden[1][1])
+ output_hiddens.append(paddle.unsqueeze(hidden[0], axis=1))
+ output = paddle.concat(output_hiddens, axis=1)
+ probs = self.generator(output)
+
+ else:
+ targets = paddle.zeros(shape=[batch_size], dtype="int32")
+ probs = None
+
+ for i in range(num_steps):
+ char_onehots = self._char_to_onehot(
+ targets, onehot_dim=self.num_classes)
+ hidden, alpha = self.attention_cell(hidden, inputs,
+ char_onehots)
+ probs_step = self.generator(hidden[0])
+ hidden = (hidden[1][0], hidden[1][1])
+ if probs is None:
+ probs = paddle.unsqueeze(probs_step, axis=1)
+ else:
+ probs = paddle.concat(
+ [probs, paddle.unsqueeze(
+ probs_step, axis=1)], axis=1)
+
+ next_input = probs_step.argmax(axis=1)
+
+ targets = next_input
+
+ return probs
+
+
+class AttentionLSTMCell(nn.Layer):
+ def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False):
+ super(AttentionLSTMCell, self).__init__()
+ self.i2h = nn.Linear(input_size, hidden_size, bias_attr=False)
+ self.h2h = nn.Linear(hidden_size, hidden_size)
+ self.score = nn.Linear(hidden_size, 1, bias_attr=False)
+ if not use_gru:
+ self.rnn = nn.LSTMCell(
+ input_size=input_size + num_embeddings, hidden_size=hidden_size)
+ else:
+ self.rnn = nn.GRUCell(
+ input_size=input_size + num_embeddings, hidden_size=hidden_size)
+
+ self.hidden_size = hidden_size
+
+ def forward(self, prev_hidden, batch_H, char_onehots):
+ batch_H_proj = self.i2h(batch_H)
+ prev_hidden_proj = paddle.unsqueeze(self.h2h(prev_hidden[0]), axis=1)
+ res = paddle.add(batch_H_proj, prev_hidden_proj)
+ res = paddle.tanh(res)
+ e = self.score(res)
+
+ alpha = F.softmax(e, axis=1)
+ alpha = paddle.transpose(alpha, [0, 2, 1])
+ context = paddle.squeeze(paddle.mm(alpha, batch_H), axis=1)
+ concat_context = paddle.concat([context, char_onehots], 1)
+ cur_hidden = self.rnn(concat_context, prev_hidden)
+
+ return cur_hidden, alpha
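
A decode-path sketch for the head under the table_mv3.yml setting (loc_type=2, default in_max_len=488, so the location branch expects a 256-position last feature map, e.g. 16x16); the channel value is a toy assumption matching the MobileNetV3 sketch above:

    import paddle
    from ppocr.modeling.heads.table_att_head import TableAttentionHead

    head = TableAttentionHead(in_channels=[960], hidden_size=256, loc_type=2)
    feat = paddle.rand([1, 960, 16, 16])           # only inputs[-1] is consumed
    out = head([feat], targets=None, mode='Eval')  # greedy decode, max_elem_length + 1 steps
    print(out['structure_probs'].shape)  # [1, 801, 30]
    print(out['loc_preds'].shape)        # [1, 801, 4], one box per structure step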
diff --git a/ppocr/modeling/necks/__init__.py b/ppocr/modeling/necks/__init__.py
index 37a5cf7863cb386884d82ed88c756c9fc06a541d..e97c4f64bdc9acd6729d67a9c6ff7a7563f6c95e 100644
--- a/ppocr/modeling/necks/__init__.py
+++ b/ppocr/modeling/necks/__init__.py
@@ -21,7 +21,8 @@ def build_neck(config):
from .sast_fpn import SASTFPN
from .rnn import SequenceEncoder
from .pg_fpn import PGFPN
- support_dict = ['DBFPN', 'EASTFPN', 'SASTFPN', 'SequenceEncoder', 'PGFPN']
+ from .table_fpn import TableFPN
+ support_dict = ['DBFPN', 'EASTFPN', 'SASTFPN', 'SequenceEncoder', 'PGFPN', 'TableFPN']
module_name = config.pop('name')
assert module_name in support_dict, Exception('neck only support {}'.format(
diff --git a/ppocr/modeling/necks/table_fpn.py b/ppocr/modeling/necks/table_fpn.py
new file mode 100644
index 0000000000000000000000000000000000000000..d72bff4ff14951fd532f516e3a1a8405cedc8f23
--- /dev/null
+++ b/ppocr/modeling/necks/table_fpn.py
@@ -0,0 +1,119 @@
+# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+from paddle import nn
+import paddle.nn.functional as F
+from paddle import ParamAttr
+
+
+class TableFPN(nn.Layer):
+ def __init__(self, in_channels, out_channels, **kwargs):
+ super(TableFPN, self).__init__()
+ self.out_channels = 512
+ weight_attr = paddle.nn.initializer.KaimingUniform()
+ self.in2_conv = nn.Conv2D(
+ in_channels=in_channels[0],
+ out_channels=self.out_channels,
+ kernel_size=1,
+ weight_attr=ParamAttr(
+ name='conv2d_51.w_0', initializer=weight_attr),
+ bias_attr=False)
+ self.in3_conv = nn.Conv2D(
+ in_channels=in_channels[1],
+ out_channels=self.out_channels,
+ kernel_size=1,
+ stride = 1,
+ weight_attr=ParamAttr(
+ name='conv2d_50.w_0', initializer=weight_attr),
+ bias_attr=False)
+ self.in4_conv = nn.Conv2D(
+ in_channels=in_channels[2],
+ out_channels=self.out_channels,
+ kernel_size=1,
+ weight_attr=ParamAttr(
+ name='conv2d_49.w_0', initializer=weight_attr),
+ bias_attr=False)
+ self.in5_conv = nn.Conv2D(
+ in_channels=in_channels[3],
+ out_channels=self.out_channels,
+ kernel_size=1,
+ weight_attr=ParamAttr(
+ name='conv2d_48.w_0', initializer=weight_attr),
+ bias_attr=False)
+ self.p5_conv = nn.Conv2D(
+ in_channels=self.out_channels,
+ out_channels=self.out_channels // 4,
+ kernel_size=3,
+ padding=1,
+ weight_attr=ParamAttr(
+ name='conv2d_52.w_0', initializer=weight_attr),
+ bias_attr=False)
+ self.p4_conv = nn.Conv2D(
+ in_channels=self.out_channels,
+ out_channels=self.out_channels // 4,
+ kernel_size=3,
+ padding=1,
+ weight_attr=ParamAttr(
+ name='conv2d_53.w_0', initializer=weight_attr),
+ bias_attr=False)
+ self.p3_conv = nn.Conv2D(
+ in_channels=self.out_channels,
+ out_channels=self.out_channels // 4,
+ kernel_size=3,
+ padding=1,
+ weight_attr=ParamAttr(
+ name='conv2d_54.w_0', initializer=weight_attr),
+ bias_attr=False)
+ self.p2_conv = nn.Conv2D(
+ in_channels=self.out_channels,
+ out_channels=self.out_channels // 4,
+ kernel_size=3,
+ padding=1,
+ weight_attr=ParamAttr(
+ name='conv2d_55.w_0', initializer=weight_attr),
+ bias_attr=False)
+ self.fuse_conv = nn.Conv2D(
+ in_channels=self.out_channels * 4,
+ out_channels=512,
+ kernel_size=3,
+ padding=1,
+ weight_attr=ParamAttr(
+ name='conv2d_fuse.w_0', initializer=weight_attr), bias_attr=False)
+
+ def forward(self, x):
+ c2, c3, c4, c5 = x
+
+ in5 = self.in5_conv(c5)
+ in4 = self.in4_conv(c4)
+ in3 = self.in3_conv(c3)
+ in2 = self.in2_conv(c2)
+
+ out4 = in4 + F.upsample(
+ in5, size=in4.shape[2:4], mode="nearest", align_mode=1) # 1/16
+ out3 = in3 + F.upsample(
+ out4, size=in3.shape[2:4], mode="nearest", align_mode=1) # 1/8
+ out2 = in2 + F.upsample(
+ out3, size=in2.shape[2:4], mode="nearest", align_mode=1) # 1/4
+
+ p4 = F.upsample(out4, size=in5.shape[2:4], mode="nearest", align_mode=1)
+ p3 = F.upsample(out3, size=in5.shape[2:4], mode="nearest", align_mode=1)
+ p2 = F.upsample(out2, size=in5.shape[2:4], mode="nearest", align_mode=1)
+ fuse = paddle.concat([in5, p4, p3, p2], axis=1)
+ fuse_conv = self.fuse_conv(fuse) * 0.005
+ return [c5 + fuse_conv]
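
A shape check for the fuse path: every level is resized to c5's 1/32 grid, concatenated (4 x 512 = 2048 channels), squeezed back to 512 and added onto c5, so the residual add requires the last backbone level to carry exactly 512 channels (a ResNet-18/34-style [64, 128, 256, 512] contract). Note that p2_conv through p5_conv are defined but never used in forward:

    import paddle
    from ppocr.modeling.necks.table_fpn import TableFPN

    fpn = TableFPN(in_channels=[64, 128, 256, 512], out_channels=512)
    c2, c3, c4, c5 = [paddle.rand([1, c, s, s])
                      for c, s in [(64, 122), (128, 61), (256, 31), (512, 16)]]
    out = fpn([c2, c3, c4, c5])
    print(out[0].shape)  # [1, 512, 16, 16]: c5 plus the 0.005-scaled fuse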
diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py
index 8426bcf2b9a71e0293d912e25f1b617fd18c59fc..9429d6b473421cd526ac17823d3198f5ae0921e0 100644
--- a/ppocr/postprocess/rec_postprocess.py
+++ b/ppocr/postprocess/rec_postprocess.py
@@ -325,8 +325,14 @@ class TableLabelDecode(object):
""" """
def __init__(self,
+ max_text_length,
+ max_elem_length,
+ max_cell_num,
character_dict_path,
**kwargs):
+ self.max_text_length = max_text_length
+ self.max_elem_length = max_elem_length
+ self.max_cell_num = max_cell_num
list_character, list_elem = self.load_char_elem_dict(character_dict_path)
list_character = self.add_special_char(list_character)
list_elem = self.add_special_char(list_elem)
@@ -363,6 +369,18 @@ class TableLabelDecode(object):
list_character = [self.beg_str] + list_character + [self.end_str]
return list_character
+ def get_sp_tokens(self):
+ char_beg_idx = self.get_beg_end_flag_idx('beg', 'char')
+ char_end_idx = self.get_beg_end_flag_idx('end', 'char')
+ elem_beg_idx = self.get_beg_end_flag_idx('beg', 'elem')
+ elem_end_idx = self.get_beg_end_flag_idx('end', 'elem')
+ elem_char_idx1 = self.dict_elem['<td>']
+ elem_char_idx2 = self.dict_elem['<td']