diff --git a/configs/table/table_mv3.yml b/configs/table/table_mv3.yml
new file mode 100755
index 0000000000000000000000000000000000000000..a74e18d318699685400cc48430c04db3fef70b60
--- /dev/null
+++ b/configs/table/table_mv3.yml
@@ -0,0 +1,116 @@
+Global:
+ use_gpu: true
+ epoch_num: 50
+ log_smooth_window: 20
+ print_batch_step: 5
+ save_model_dir: ./output/table_mv3/
+ save_epoch_step: 5
+ # evaluation is run every 400 iterations after the 0th iteration
+ eval_batch_step: [0, 400]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ save_inference_dir:
+ use_visualdl: False
+ infer_img: doc/imgs_words/ch/word_1.jpg
+ # for data or label process
+ character_dict_path: ppocr/utils/dict/table_structure_dict.txt
+ character_type: en
+ max_text_length: 100
+ max_elem_length: 500
+ max_cell_num: 500
+ infer_mode: False
+ process_total_num: 0
+ process_cut_num: 0
+
+
+Optimizer:
+ name: Adam
+ beta1: 0.9
+ beta2: 0.999
+ clip_norm: 5.0
+ lr:
+ learning_rate: 0.001
+ regularizer:
+ name: 'L2'
+ factor: 0.00000
+
+Architecture:
+ model_type: table
+ algorithm: TableAttn
+ Backbone:
+ name: MobileNetV3
+ scale: 1.0
+ model_name: small
+ disable_se: True
+ Head:
+ name: TableAttentionHead
+ hidden_size: 256
+ l2_decay: 0.00001
+ loc_type: 2
+
+Loss:
+ name: TableAttentionLoss
+ structure_weight: 100.0
+ loc_weight: 10000.0
+
+PostProcess:
+ name: TableLabelDecode
+
+Metric:
+ name: TableMetric
+ main_indicator: acc
+
+Train:
+ dataset:
+ name: PubTabDataSet
+ data_dir: train_data/table/pubtabnet/train/
+ label_file_path: train_data/table/pubtabnet/PubTabNet_2.0.0_train.jsonl
+ transforms:
+ - DecodeImage: # load image
+ img_mode: BGR
+ channel_first: False
+ - ResizeTableImage:
+ max_len: 488
+ - TableLabelEncode:
+ - NormalizeImage:
+ scale: 1./255.
+ mean: [0.485, 0.456, 0.406]
+ std: [0.229, 0.224, 0.225]
+ order: 'hwc'
+ - PaddingTableImage:
+ - ToCHWImage:
+ - KeepKeys:
+ keep_keys: ['image', 'structure', 'bbox_list', 'sp_tokens', 'bbox_list_mask']
+ loader:
+ shuffle: True
+ batch_size_per_card: 32
+ drop_last: True
+ num_workers: 1
+
+Eval:
+ dataset:
+ name: PubTabDataSet
+ data_dir: train_data/table/pubtabnet/val/
+ label_file_path: train_data/table/pubtabnet/PubTabNet_2.0.0_val.jsonl
+ transforms:
+ - DecodeImage: # load image
+ img_mode: BGR
+ channel_first: False
+ - ResizeTableImage:
+ max_len: 488
+ - TableLabelEncode:
+ - NormalizeImage:
+ scale: 1./255.
+ mean: [0.485, 0.456, 0.406]
+ std: [0.229, 0.224, 0.225]
+ order: 'hwc'
+ - PaddingTableImage:
+ - ToCHWImage:
+ - KeepKeys:
+ keep_keys: ['image', 'structure', 'bbox_list', 'sp_tokens', 'bbox_list_mask']
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 16
+ num_workers: 1
diff --git a/ppocr/data/__init__.py b/ppocr/data/__init__.py
index 728b8317f54687ee76b519cba18f4d7807493821..e860c5a6986f495e6384d9df93c24795c04a0d5f 100644
--- a/ppocr/data/__init__.py
+++ b/ppocr/data/__init__.py
@@ -35,6 +35,7 @@ from ppocr.data.imaug import transform, create_operators
from ppocr.data.simple_dataset import SimpleDataSet
from ppocr.data.lmdb_dataset import LMDBDataSet
from ppocr.data.pgnet_dataset import PGDataSet
+from ppocr.data.pubtab_dataset import PubTabDataSet
__all__ = ['build_dataloader', 'transform', 'create_operators']
@@ -55,7 +56,7 @@ signal.signal(signal.SIGTERM, term_mp)
def build_dataloader(config, mode, device, logger, seed=None):
config = copy.deepcopy(config)
- support_dict = ['SimpleDataSet', 'LMDBDataSet', 'PGDataSet']
+ support_dict = ['SimpleDataSet', 'LMDBDataSet', 'PGDataSet', 'PubTabDataSet']
module_name = config[mode]['dataset']['name']
assert module_name in support_dict, Exception(
'DataSet only support {}'.format(support_dict))
diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py
index bba3209f7560f19b74a54c102caf697319814803..e25cce79b553f127afc0167f18b6f663ceb617d7 100644
--- a/ppocr/data/imaug/label_ops.py
+++ b/ppocr/data/imaug/label_ops.py
@@ -351,3 +351,162 @@ class SRNLabelEncode(BaseRecLabelEncode):
assert False, "Unsupport type %s in get_beg_end_flag_idx" \
% beg_or_end
return idx
+
+class TableLabelEncode(object):
+ """ Convert between text-label and text-index """
+ def __init__(self,
+ max_text_length,
+ max_elem_length,
+ max_cell_num,
+ character_dict_path,
+ span_weight = 1.0,
+ **kwargs):
+ self.max_text_length = max_text_length
+ self.max_elem_length = max_elem_length
+ self.max_cell_num = max_cell_num
+ list_character, list_elem = self.load_char_elem_dict(character_dict_path)
+ list_character = self.add_special_char(list_character)
+ list_elem = self.add_special_char(list_elem)
+ self.dict_character = {}
+ for i, char in enumerate(list_character):
+ self.dict_character[char] = i
+ self.dict_elem = {}
+ for i, elem in enumerate(list_elem):
+ self.dict_elem[elem] = i
+ self.span_weight = span_weight
+
+ def load_char_elem_dict(self, character_dict_path):
+ list_character = []
+ list_elem = []
+ with open(character_dict_path, "rb") as fin:
+ lines = fin.readlines()
+ substr = lines[0].decode('utf-8').strip("\n").split("\t")
+ character_num = int(substr[0])
+ elem_num = int(substr[1])
+ for cno in range(1, 1+character_num):
+ character = lines[cno].decode('utf-8').strip("\n")
+ list_character.append(character)
+ for eno in range(1+character_num, 1+character_num+elem_num):
+ elem = lines[eno].decode('utf-8').strip("\n")
+ list_elem.append(elem)
+ return list_character, list_elem
+
+ def add_special_char(self, list_character):
+ self.beg_str = "sos"
+ self.end_str = "eos"
+ list_character = [self.beg_str] + list_character + [self.end_str]
+ return list_character
+
+ def get_span_idx_list(self):
+ span_idx_list = []
+ for elem in self.dict_elem:
+ if 'span' in elem:
+ span_idx_list.append(self.dict_elem[elem])
+ return span_idx_list
+
+ def __call__(self, data):
+ cells = data['cells']
+ structure = data['structure']['tokens']
+ structure = self.encode(structure, 'elem')
+ if structure is None:
+ return None
+ elem_num = len(structure)
+ structure = [0] + structure + [len(self.dict_elem) - 1]
+ structure = structure + [0] * (self.max_elem_length + 2 - len(structure))
+ structure = np.array(structure)
+ data['structure'] = structure
+ elem_char_idx1 = self.dict_elem['
']
+ elem_char_idx2 = self.dict_elem[' | 0:
+ span_weight = len(td_idx_list) * 1.0 / len(span_idx_list)
+ span_weight = min(max(span_weight, 1.0), self.span_weight)
+ for cno in range(len(cells)):
+ if 'bbox' in cells[cno]:
+ bbox = cells[cno]['bbox'].copy()
+ bbox[0] = bbox[0] * 1.0 / img_width
+ bbox[1] = bbox[1] * 1.0 / img_height
+ bbox[2] = bbox[2] * 1.0 / img_width
+ bbox[3] = bbox[3] * 1.0 / img_height
+ td_idx = td_idx_list[cno]
+ bbox_list[td_idx] = bbox
+ bbox_list_mask[td_idx] = 1.0
+ cand_span_idx = td_idx + 1
+ if cand_span_idx < (self.max_elem_length + 2):
+ if structure[cand_span_idx] in span_idx_list:
+ structure_mask[cand_span_idx] = span_weight
+
+ data['bbox_list'] = bbox_list
+ data['bbox_list_mask'] = bbox_list_mask
+ data['structure_mask'] = structure_mask
+ char_beg_idx = self.get_beg_end_flag_idx('beg', 'char')
+ char_end_idx = self.get_beg_end_flag_idx('end', 'char')
+ elem_beg_idx = self.get_beg_end_flag_idx('beg', 'elem')
+ elem_end_idx = self.get_beg_end_flag_idx('end', 'elem')
+ data['sp_tokens'] = np.array([char_beg_idx, char_end_idx, elem_beg_idx,
+ elem_end_idx, elem_char_idx1, elem_char_idx2, self.max_text_length,
+ self.max_elem_length, self.max_cell_num, elem_num])
+ return data
+
+ def encode(self, text, char_or_elem):
+ """convert text-label into text-index.
+ """
+ if char_or_elem == "char":
+ max_len = self.max_text_length
+ current_dict = self.dict_character
+ else:
+ max_len = self.max_elem_length
+ current_dict = self.dict_elem
+ if len(text) > max_len:
+ return None
+ if len(text) == 0:
+ if char_or_elem == "char":
+ return [self.dict_character['space']]
+ else:
+ return None
+ text_list = []
+ for char in text:
+ if char not in current_dict:
+ return None
+ text_list.append(current_dict[char])
+ if len(text_list) == 0:
+ if char_or_elem == "char":
+ return [self.dict_character['space']]
+ else:
+ return None
+ return text_list
+
+ def get_ignored_tokens(self, char_or_elem):
+ beg_idx = self.get_beg_end_flag_idx("beg", char_or_elem)
+ end_idx = self.get_beg_end_flag_idx("end", char_or_elem)
+ return [beg_idx, end_idx]
+
+ def get_beg_end_flag_idx(self, beg_or_end, char_or_elem):
+ if char_or_elem == "char":
+ if beg_or_end == "beg":
+ idx = np.array(self.dict_character[self.beg_str])
+ elif beg_or_end == "end":
+ idx = np.array(self.dict_character[self.end_str])
+ else:
+ assert False, "Unsupport type %s in get_beg_end_flag_idx of char" \
+ % beg_or_end
+ elif char_or_elem == "elem":
+ if beg_or_end == "beg":
+ idx = np.array(self.dict_elem[self.beg_str])
+ elif beg_or_end == "end":
+ idx = np.array(self.dict_elem[self.end_str])
+ else:
+ assert False, "Unsupport type %s in get_beg_end_flag_idx of elem" \
+ % beg_or_end
+ else:
+ assert False, "Unsupport type %s in char_or_elem" \
+ % char_or_elem
+ return idx
+
\ No newline at end of file
diff --git a/ppocr/data/pubtab_dataset.py b/ppocr/data/pubtab_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..78b76c5afb8c96bc96730c7b8ad76b4bafa31c67
--- /dev/null
+++ b/ppocr/data/pubtab_dataset.py
@@ -0,0 +1,107 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np
+import os
+import random
+from paddle.io import Dataset
+import json
+
+from .imaug import transform, create_operators
+
+
+class PubTabDataSet(Dataset):
+ def __init__(self, config, mode, logger, seed=None):
+ super(PubTabDataSet, self).__init__()
+ self.logger = logger
+
+ global_config = config['Global']
+ dataset_config = config[mode]['dataset']
+ loader_config = config[mode]['loader']
+
+ label_file_path = dataset_config.pop('label_file_path')
+
+ self.data_dir = dataset_config['data_dir']
+ self.do_shuffle = loader_config['shuffle']
+ self.do_hard_select = False
+ if 'hard_select' in loader_config:
+ self.do_hard_select = loader_config['hard_select']
+ self.hard_prob = loader_config['hard_prob']
+ if self.do_hard_select:
+ self.img_select_prob = self.load_hard_select_prob()
+ self.table_select_type = None
+ if 'table_select_type' in loader_config:
+ self.table_select_type = loader_config['table_select_type']
+ self.table_select_prob = loader_config['table_select_prob']
+
+ self.seed = seed
+ logger.info("Initialize indexs of datasets:%s" % label_file_path)
+ with open(label_file_path, "rb") as f:
+ self.data_lines = f.readlines()
+ self.data_idx_order_list = list(range(len(self.data_lines)))
+ if mode.lower() == "train":
+ self.shuffle_data_random()
+ self.ops = create_operators(dataset_config['transforms'], global_config)
+
+ def shuffle_data_random(self):
+ if self.do_shuffle:
+ random.seed(self.seed)
+ random.shuffle(self.data_lines)
+ return
+
+ def __getitem__(self, idx):
+ try:
+ data_line = self.data_lines[idx]
+ data_line = data_line.decode('utf-8').strip("\n")
+ info = json.loads(data_line)
+ file_name = info['filename']
+ select_flag = True
+ if self.do_hard_select:
+ prob = self.img_select_prob[file_name]
+ if prob < random.uniform(0, 1):
+ select_flag = False
+
+ if self.table_select_type:
+ structure = info['html']['structure']['tokens'].copy()
+ structure_str = ''.join(structure)
+ table_type = "simple"
+ if 'colspan' in structure_str or 'rowspan' in structure_str:
+ table_type = "complex"
+ if table_type == "complex":
+ if self.table_select_prob < random.uniform(0, 1):
+ select_flag = False
+
+ if select_flag:
+ cells = info['html']['cells'].copy()
+ structure = info['html']['structure'].copy()
+ img_path = os.path.join(self.data_dir, file_name)
+ data = {'img_path': img_path, 'cells': cells, 'structure':structure}
+ if not os.path.exists(img_path):
+ raise Exception("{} does not exist!".format(img_path))
+ with open(data['img_path'], 'rb') as f:
+ img = f.read()
+ data['image'] = img
+ outs = transform(data, self.ops)
+ else:
+ outs = None
+ except Exception as e:
+ self.logger.error(
+ "When parsing line {}, error happened with msg: {}".format(
+ data_line, e))
+ outs = None
+ if outs is None:
+ return self.__getitem__(np.random.randint(self.__len__()))
+ return outs
+
+ def __len__(self):
+ return len(self.data_idx_order_list)
diff --git a/ppocr/losses/__init__.py b/ppocr/losses/__init__.py
index bf10d2982dcdd36021a7385ab8828398b51af3d3..025ae7ca5cc604eea59423ca7f523c37c1492e35 100755
--- a/ppocr/losses/__init__.py
+++ b/ppocr/losses/__init__.py
@@ -38,11 +38,13 @@ from .basic_loss import DistanceLoss
# combined loss function
from .combined_loss import CombinedLoss
+# table loss
+from .table_att_loss import TableAttentionLoss
def build_loss(config):
support_dict = [
'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss',
- 'SRNLoss', 'PGLoss', 'CombinedLoss'
+ 'SRNLoss', 'PGLoss', 'CombinedLoss', 'TableAttentionLoss'
]
config = copy.deepcopy(config)
module_name = config.pop('name')
diff --git a/ppocr/losses/table_att_loss.py b/ppocr/losses/table_att_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7fd99e6952aacc0182a482ca5ae5ddaf959a026
--- /dev/null
+++ b/ppocr/losses/table_att_loss.py
@@ -0,0 +1,109 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+from paddle import nn
+from paddle.nn import functional as F
+from paddle import fluid
+
+class TableAttentionLoss(nn.Layer):
+ def __init__(self, structure_weight, loc_weight, use_giou=False, giou_weight=1.0, **kwargs):
+ super(TableAttentionLoss, self).__init__()
+ self.loss_func = nn.CrossEntropyLoss(weight=None, reduction='none')
+ self.structure_weight = structure_weight
+ self.loc_weight = loc_weight
+ self.use_giou = use_giou
+ self.giou_weight = giou_weight
+
+ def giou_loss(self, preds, bbox, eps=1e-7, reduction='mean'):
+ '''
+ :param preds:[[x1,y1,x2,y2], [x1,y1,x2,y2],,,]
+ :param bbox:[[x1,y1,x2,y2], [x1,y1,x2,y2],,,]
+ :return: loss
+ '''
+ ix1 = fluid.layers.elementwise_max(preds[:, 0], bbox[:, 0])
+ iy1 = fluid.layers.elementwise_max(preds[:, 1], bbox[:, 1])
+ ix2 = fluid.layers.elementwise_min(preds[:, 2], bbox[:, 2])
+ iy2 = fluid.layers.elementwise_min(preds[:, 3], bbox[:, 3])
+
+ iw = fluid.layers.clip(ix2 - ix1 + 1e-3, 0., 1e10)
+ ih = fluid.layers.clip(iy2 - iy1 + 1e-3, 0., 1e10)
+
+ # overlap
+ inters = iw * ih
+
+ # union
+ uni = (preds[:, 2] - preds[:, 0] + 1e-3) * (preds[:, 3] - preds[:, 1] + 1e-3
+ ) + (bbox[:, 2] - bbox[:, 0] + 1e-3) * (
+ bbox[:, 3] - bbox[:, 1] + 1e-3) - inters + eps
+
+ # ious
+ ious = inters / uni
+
+ ex1 = fluid.layers.elementwise_min(preds[:, 0], bbox[:, 0])
+ ey1 = fluid.layers.elementwise_min(preds[:, 1], bbox[:, 1])
+ ex2 = fluid.layers.elementwise_max(preds[:, 2], bbox[:, 2])
+ ey2 = fluid.layers.elementwise_max(preds[:, 3], bbox[:, 3])
+ ew = fluid.layers.clip(ex2 - ex1 + 1e-3, 0., 1e10)
+ eh = fluid.layers.clip(ey2 - ey1 + 1e-3, 0., 1e10)
+
+ # enclose erea
+ enclose = ew * eh + eps
+ giou = ious - (enclose - uni) / enclose
+
+ loss = 1 - giou
+
+ if reduction == 'mean':
+ loss = paddle.mean(loss)
+ elif reduction == 'sum':
+ loss = paddle.sum(loss)
+ else:
+ raise NotImplementedError
+ return loss
+
+ def forward(self, predicts, batch):
+ structure_probs = predicts['structure_probs']
+ structure_targets = batch[1].astype("int64")
+ structure_targets = structure_targets[:, 1:]
+ if len(batch) == 6:
+ structure_mask = batch[5].astype("int64")
+ structure_mask = structure_mask[:, 1:]
+ structure_mask = paddle.reshape(structure_mask, [-1])
+ structure_probs = paddle.reshape(structure_probs, [-1, structure_probs.shape[-1]])
+ structure_targets = paddle.reshape(structure_targets, [-1])
+ structure_loss = self.loss_func(structure_probs, structure_targets)
+
+ if len(batch) == 6:
+ structure_loss = structure_loss * structure_mask
+
+# structure_loss = paddle.sum(structure_loss) * self.structure_weight
+ structure_loss = paddle.mean(structure_loss) * self.structure_weight
+
+ loc_preds = predicts['loc_preds']
+ loc_targets = batch[2].astype("float32")
+ loc_targets_mask = batch[4].astype("float32")
+ loc_targets = loc_targets[:, 1:, :]
+ loc_targets_mask = loc_targets_mask[:, 1:, :]
+ loc_loss = F.mse_loss(loc_preds * loc_targets_mask, loc_targets) * self.loc_weight
+ if self.use_giou:
+ loc_loss_giou = self.giou_loss(loc_preds * loc_targets_mask, loc_targets) * self.giou_weight
+ total_loss = structure_loss + loc_loss + loc_loss_giou
+ return {'loss':total_loss, "structure_loss":structure_loss, "loc_loss":loc_loss, "loc_loss_giou":loc_loss_giou}
+ else:
+ total_loss = structure_loss + loc_loss
+ return {'loss':total_loss, "structure_loss":structure_loss, "loc_loss":loc_loss}
\ No newline at end of file
diff --git a/ppocr/metrics/__init__.py b/ppocr/metrics/__init__.py
index 9e9060fa999bd3175c31dfc0797cd293d4e7afec..64f62e51cdf922773c03bb784a4edffdc17f506f 100644
--- a/ppocr/metrics/__init__.py
+++ b/ppocr/metrics/__init__.py
@@ -26,11 +26,11 @@ from .rec_metric import RecMetric
from .cls_metric import ClsMetric
from .e2e_metric import E2EMetric
from .distillation_metric import DistillationMetric
-
+from .table_metric import TableMetric
def build_metric(config):
support_dict = [
- "DetMetric", "RecMetric", "ClsMetric", "E2EMetric", "DistillationMetric"
+ "DetMetric", "RecMetric", "ClsMetric", "E2EMetric", "DistillationMetric", "TableMetric"
]
config = copy.deepcopy(config)
diff --git a/ppocr/metrics/table_metric.py b/ppocr/metrics/table_metric.py
new file mode 100644
index 0000000000000000000000000000000000000000..80d1c789ecc3979bd4c33620af91ccd28012f7a8
--- /dev/null
+++ b/ppocr/metrics/table_metric.py
@@ -0,0 +1,50 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np
+class TableMetric(object):
+ def __init__(self, main_indicator='acc', **kwargs):
+ self.main_indicator = main_indicator
+ self.reset()
+
+ def __call__(self, pred, batch, *args, **kwargs):
+ structure_probs = pred['structure_probs'].numpy()
+ structure_labels = batch[1]
+ correct_num = 0
+ all_num = 0
+ structure_probs = np.argmax(structure_probs, axis=2)
+ structure_labels = structure_labels[:, 1:]
+ batch_size = structure_probs.shape[0]
+ for bno in range(batch_size):
+ all_num += 1
+ if (structure_probs[bno] == structure_labels[bno]).all():
+ correct_num += 1
+ self.correct_num += correct_num
+ self.all_num += all_num
+ return {
+ 'acc': correct_num * 1.0 / all_num,
+ }
+
+ def get_metric(self):
+ """
+ return metrics {
+ 'acc': 0,
+ }
+ """
+ acc = 1.0 * self.correct_num / self.all_num
+ self.reset()
+ return {'acc': acc}
+
+ def reset(self):
+ self.correct_num = 0
+ self.all_num = 0
diff --git a/ppocr/modeling/architectures/base_model.py b/ppocr/modeling/architectures/base_model.py
index 4c941fcf65573d9314c0badda49895d0b6b5c4f9..03fbcee8465df9c8bb7845ea62fc0ac04917caa0 100644
--- a/ppocr/modeling/architectures/base_model.py
+++ b/ppocr/modeling/architectures/base_model.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -78,10 +78,7 @@ class BaseModel(nn.Layer):
if self.use_neck:
x = self.neck(x)
y["neck_out"] = x
- if data is None:
- x = self.head(x)
- else:
- x = self.head(x, data)
+ x = self.head(x, targets=data)
y["head_out"] = x
if self.return_all_feats:
return y
diff --git a/ppocr/modeling/backbones/__init__.py b/ppocr/modeling/backbones/__init__.py
index fe2c9bc30a4f2abd1ba7d3d6989b9ef9b20c1f4f..13b70b203371b3be58ee82c6808d744bf6098333 100755
--- a/ppocr/modeling/backbones/__init__.py
+++ b/ppocr/modeling/backbones/__init__.py
@@ -29,6 +29,10 @@ def build_backbone(config, model_type):
elif model_type == 'e2e':
from .e2e_resnet_vd_pg import ResNet
support_dict = ['ResNet']
+ elif model_type == "table":
+ from .table_resnet_vd import ResNet
+ from .table_mobilenet_v3 import MobileNetV3
+ support_dict = ['ResNet', 'MobileNetV3']
else:
raise NotImplementedError
diff --git a/ppocr/modeling/backbones/table_mobilenet_v3.py b/ppocr/modeling/backbones/table_mobilenet_v3.py
new file mode 100644
index 0000000000000000000000000000000000000000..daa87f976038d8d5eeafadceb869b9232ba22cd9
--- /dev/null
+++ b/ppocr/modeling/backbones/table_mobilenet_v3.py
@@ -0,0 +1,287 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+from paddle import nn
+import paddle.nn.functional as F
+from paddle import ParamAttr
+
+__all__ = ['MobileNetV3']
+
+
+def make_divisible(v, divisor=8, min_value=None):
+ if min_value is None:
+ min_value = divisor
+ new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
+ if new_v < 0.9 * v:
+ new_v += divisor
+ return new_v
+
+
+class MobileNetV3(nn.Layer):
+ def __init__(self,
+ in_channels=3,
+ model_name='large',
+ scale=0.5,
+ disable_se=False,
+ **kwargs):
+ """
+ the MobilenetV3 backbone network for detection module.
+ Args:
+ params(dict): the super parameters for build network
+ """
+ super(MobileNetV3, self).__init__()
+
+ self.disable_se = disable_se
+
+ if model_name == "large":
+ cfg = [
+ # k, exp, c, se, nl, s,
+ [3, 16, 16, False, 'relu', 1],
+ [3, 64, 24, False, 'relu', 2],
+ [3, 72, 24, False, 'relu', 1],
+ [5, 72, 40, True, 'relu', 2],
+ [5, 120, 40, True, 'relu', 1],
+ [5, 120, 40, True, 'relu', 1],
+ [3, 240, 80, False, 'hardswish', 2],
+ [3, 200, 80, False, 'hardswish', 1],
+ [3, 184, 80, False, 'hardswish', 1],
+ [3, 184, 80, False, 'hardswish', 1],
+ [3, 480, 112, True, 'hardswish', 1],
+ [3, 672, 112, True, 'hardswish', 1],
+ [5, 672, 160, True, 'hardswish', 2],
+ [5, 960, 160, True, 'hardswish', 1],
+ [5, 960, 160, True, 'hardswish', 1],
+ ]
+ cls_ch_squeeze = 960
+ elif model_name == "small":
+ cfg = [
+ # k, exp, c, se, nl, s,
+ [3, 16, 16, True, 'relu', 2],
+ [3, 72, 24, False, 'relu', 2],
+ [3, 88, 24, False, 'relu', 1],
+ [5, 96, 40, True, 'hardswish', 2],
+ [5, 240, 40, True, 'hardswish', 1],
+ [5, 240, 40, True, 'hardswish', 1],
+ [5, 120, 48, True, 'hardswish', 1],
+ [5, 144, 48, True, 'hardswish', 1],
+ [5, 288, 96, True, 'hardswish', 2],
+ [5, 576, 96, True, 'hardswish', 1],
+ [5, 576, 96, True, 'hardswish', 1],
+ ]
+ cls_ch_squeeze = 576
+ else:
+ raise NotImplementedError("mode[" + model_name +
+ "_model] is not implemented!")
+
+ supported_scale = [0.35, 0.5, 0.75, 1.0, 1.25]
+ assert scale in supported_scale, \
+ "supported scale are {} but input scale is {}".format(supported_scale, scale)
+ inplanes = 16
+ # conv1
+ self.conv = ConvBNLayer(
+ in_channels=in_channels,
+ out_channels=make_divisible(inplanes * scale),
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ groups=1,
+ if_act=True,
+ act='hardswish',
+ name='conv1')
+
+ self.stages = []
+ self.out_channels = []
+ block_list = []
+ i = 0
+ inplanes = make_divisible(inplanes * scale)
+ for (k, exp, c, se, nl, s) in cfg:
+ se = se and not self.disable_se
+ start_idx = 2 if model_name == 'large' else 0
+ if s == 2 and i > start_idx:
+ self.out_channels.append(inplanes)
+ self.stages.append(nn.Sequential(*block_list))
+ block_list = []
+ block_list.append(
+ ResidualUnit(
+ in_channels=inplanes,
+ mid_channels=make_divisible(scale * exp),
+ out_channels=make_divisible(scale * c),
+ kernel_size=k,
+ stride=s,
+ use_se=se,
+ act=nl,
+ name="conv" + str(i + 2)))
+ inplanes = make_divisible(scale * c)
+ i += 1
+ block_list.append(
+ ConvBNLayer(
+ in_channels=inplanes,
+ out_channels=make_divisible(scale * cls_ch_squeeze),
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ groups=1,
+ if_act=True,
+ act='hardswish',
+ name='conv_last'))
+ self.stages.append(nn.Sequential(*block_list))
+ self.out_channels.append(make_divisible(scale * cls_ch_squeeze))
+ for i, stage in enumerate(self.stages):
+ self.add_sublayer(sublayer=stage, name="stage{}".format(i))
+
+ def forward(self, x):
+ x = self.conv(x)
+ out_list = []
+ for stage in self.stages:
+ x = stage(x)
+ out_list.append(x)
+ return out_list
+
+
+class ConvBNLayer(nn.Layer):
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride,
+ padding,
+ groups=1,
+ if_act=True,
+ act=None,
+ name=None):
+ super(ConvBNLayer, self).__init__()
+ self.if_act = if_act
+ self.act = act
+ self.conv = nn.Conv2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ groups=groups,
+ weight_attr=ParamAttr(name=name + '_weights'),
+ bias_attr=False)
+
+ self.bn = nn.BatchNorm(
+ num_channels=out_channels,
+ act=None,
+ param_attr=ParamAttr(name=name + "_bn_scale"),
+ bias_attr=ParamAttr(name=name + "_bn_offset"),
+ moving_mean_name=name + "_bn_mean",
+ moving_variance_name=name + "_bn_variance")
+
+ def forward(self, x):
+ x = self.conv(x)
+ x = self.bn(x)
+ if self.if_act:
+ if self.act == "relu":
+ x = F.relu(x)
+ elif self.act == "hardswish":
+ x = F.hardswish(x)
+ else:
+ print("The activation function({}) is selected incorrectly.".
+ format(self.act))
+ exit()
+ return x
+
+
+class ResidualUnit(nn.Layer):
+ def __init__(self,
+ in_channels,
+ mid_channels,
+ out_channels,
+ kernel_size,
+ stride,
+ use_se,
+ act=None,
+ name=''):
+ super(ResidualUnit, self).__init__()
+ self.if_shortcut = stride == 1 and in_channels == out_channels
+ self.if_se = use_se
+
+ self.expand_conv = ConvBNLayer(
+ in_channels=in_channels,
+ out_channels=mid_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ if_act=True,
+ act=act,
+ name=name + "_expand")
+ self.bottleneck_conv = ConvBNLayer(
+ in_channels=mid_channels,
+ out_channels=mid_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=int((kernel_size - 1) // 2),
+ groups=mid_channels,
+ if_act=True,
+ act=act,
+ name=name + "_depthwise")
+ if self.if_se:
+ self.mid_se = SEModule(mid_channels, name=name + "_se")
+ self.linear_conv = ConvBNLayer(
+ in_channels=mid_channels,
+ out_channels=out_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ if_act=False,
+ act=None,
+ name=name + "_linear")
+
+ def forward(self, inputs):
+ x = self.expand_conv(inputs)
+ x = self.bottleneck_conv(x)
+ if self.if_se:
+ x = self.mid_se(x)
+ x = self.linear_conv(x)
+ if self.if_shortcut:
+ x = paddle.add(inputs, x)
+ return x
+
+
+class SEModule(nn.Layer):
+ def __init__(self, in_channels, reduction=4, name=""):
+ super(SEModule, self).__init__()
+ self.avg_pool = nn.AdaptiveAvgPool2D(1)
+ self.conv1 = nn.Conv2D(
+ in_channels=in_channels,
+ out_channels=in_channels // reduction,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ weight_attr=ParamAttr(name=name + "_1_weights"),
+ bias_attr=ParamAttr(name=name + "_1_offset"))
+ self.conv2 = nn.Conv2D(
+ in_channels=in_channels // reduction,
+ out_channels=in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ weight_attr=ParamAttr(name + "_2_weights"),
+ bias_attr=ParamAttr(name=name + "_2_offset"))
+
+ def forward(self, inputs):
+ outputs = self.avg_pool(inputs)
+ outputs = self.conv1(outputs)
+ outputs = F.relu(outputs)
+ outputs = self.conv2(outputs)
+ outputs = F.hardsigmoid(outputs, slope=0.2, offset=0.5)
+ return inputs * outputs
\ No newline at end of file
diff --git a/ppocr/modeling/backbones/table_resnet_vd.py b/ppocr/modeling/backbones/table_resnet_vd.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c07c2684eec8d0c4a445cc88c543bfe1da9c864
--- /dev/null
+++ b/ppocr/modeling/backbones/table_resnet_vd.py
@@ -0,0 +1,280 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+from paddle import ParamAttr
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+__all__ = ["ResNet"]
+
+
+class ConvBNLayer(nn.Layer):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ groups=1,
+ is_vd_mode=False,
+ act=None,
+ name=None, ):
+ super(ConvBNLayer, self).__init__()
+
+ self.is_vd_mode = is_vd_mode
+ self._pool2d_avg = nn.AvgPool2D(
+ kernel_size=2, stride=2, padding=0, ceil_mode=True)
+ self._conv = nn.Conv2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=(kernel_size - 1) // 2,
+ groups=groups,
+ weight_attr=ParamAttr(name=name + "_weights"),
+ bias_attr=False)
+ if name == "conv1":
+ bn_name = "bn_" + name
+ else:
+ bn_name = "bn" + name[3:]
+ self._batch_norm = nn.BatchNorm(
+ out_channels,
+ act=act,
+ param_attr=ParamAttr(name=bn_name + '_scale'),
+ bias_attr=ParamAttr(bn_name + '_offset'),
+ moving_mean_name=bn_name + '_mean',
+ moving_variance_name=bn_name + '_variance')
+
+ def forward(self, inputs):
+ if self.is_vd_mode:
+ inputs = self._pool2d_avg(inputs)
+ y = self._conv(inputs)
+ y = self._batch_norm(y)
+ return y
+
+
+class BottleneckBlock(nn.Layer):
+ def __init__(self,
+ in_channels,
+ out_channels,
+ stride,
+ shortcut=True,
+ if_first=False,
+ name=None):
+ super(BottleneckBlock, self).__init__()
+
+ self.conv0 = ConvBNLayer(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=1,
+ act='relu',
+ name=name + "_branch2a")
+ self.conv1 = ConvBNLayer(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ kernel_size=3,
+ stride=stride,
+ act='relu',
+ name=name + "_branch2b")
+ self.conv2 = ConvBNLayer(
+ in_channels=out_channels,
+ out_channels=out_channels * 4,
+ kernel_size=1,
+ act=None,
+ name=name + "_branch2c")
+
+ if not shortcut:
+ self.short = ConvBNLayer(
+ in_channels=in_channels,
+ out_channels=out_channels * 4,
+ kernel_size=1,
+ stride=1,
+ is_vd_mode=False if if_first else True,
+ name=name + "_branch1")
+
+ self.shortcut = shortcut
+
+ def forward(self, inputs):
+ y = self.conv0(inputs)
+ conv1 = self.conv1(y)
+ conv2 = self.conv2(conv1)
+
+ if self.shortcut:
+ short = inputs
+ else:
+ short = self.short(inputs)
+ y = paddle.add(x=short, y=conv2)
+ y = F.relu(y)
+ return y
+
+
+class BasicBlock(nn.Layer):
+ def __init__(self,
+ in_channels,
+ out_channels,
+ stride,
+ shortcut=True,
+ if_first=False,
+ name=None):
+ super(BasicBlock, self).__init__()
+ self.stride = stride
+ self.conv0 = ConvBNLayer(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=3,
+ stride=stride,
+ act='relu',
+ name=name + "_branch2a")
+ self.conv1 = ConvBNLayer(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ kernel_size=3,
+ act=None,
+ name=name + "_branch2b")
+
+ if not shortcut:
+ self.short = ConvBNLayer(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=1,
+ stride=1,
+ is_vd_mode=False if if_first else True,
+ name=name + "_branch1")
+
+ self.shortcut = shortcut
+
+ def forward(self, inputs):
+ y = self.conv0(inputs)
+ conv1 = self.conv1(y)
+
+ if self.shortcut:
+ short = inputs
+ else:
+ short = self.short(inputs)
+ y = paddle.add(x=short, y=conv1)
+ y = F.relu(y)
+ return y
+
+
+class ResNet(nn.Layer):
+ def __init__(self, in_channels=3, layers=50, **kwargs):
+ super(ResNet, self).__init__()
+
+ self.layers = layers
+ supported_layers = [18, 34, 50, 101, 152, 200]
+ assert layers in supported_layers, \
+ "supported layers are {} but input layer is {}".format(
+ supported_layers, layers)
+
+ if layers == 18:
+ depth = [2, 2, 2, 2]
+ elif layers == 34 or layers == 50:
+ depth = [3, 4, 6, 3]
+ elif layers == 101:
+ depth = [3, 4, 23, 3]
+ elif layers == 152:
+ depth = [3, 8, 36, 3]
+ elif layers == 200:
+ depth = [3, 12, 48, 3]
+ num_channels = [64, 256, 512,
+ 1024] if layers >= 50 else [64, 64, 128, 256]
+ num_filters = [64, 128, 256, 512]
+
+ self.conv1_1 = ConvBNLayer(
+ in_channels=in_channels,
+ out_channels=32,
+ kernel_size=3,
+ stride=2,
+ act='relu',
+ name="conv1_1")
+ self.conv1_2 = ConvBNLayer(
+ in_channels=32,
+ out_channels=32,
+ kernel_size=3,
+ stride=1,
+ act='relu',
+ name="conv1_2")
+ self.conv1_3 = ConvBNLayer(
+ in_channels=32,
+ out_channels=64,
+ kernel_size=3,
+ stride=1,
+ act='relu',
+ name="conv1_3")
+ self.pool2d_max = nn.MaxPool2D(kernel_size=3, stride=2, padding=1)
+
+ self.stages = []
+ self.out_channels = []
+ if layers >= 50:
+ for block in range(len(depth)):
+ block_list = []
+ shortcut = False
+ for i in range(depth[block]):
+ if layers in [101, 152] and block == 2:
+ if i == 0:
+ conv_name = "res" + str(block + 2) + "a"
+ else:
+ conv_name = "res" + str(block + 2) + "b" + str(i)
+ else:
+ conv_name = "res" + str(block + 2) + chr(97 + i)
+ bottleneck_block = self.add_sublayer(
+ 'bb_%d_%d' % (block, i),
+ BottleneckBlock(
+ in_channels=num_channels[block]
+ if i == 0 else num_filters[block] * 4,
+ out_channels=num_filters[block],
+ stride=2 if i == 0 and block != 0 else 1,
+ shortcut=shortcut,
+ if_first=block == i == 0,
+ name=conv_name))
+ shortcut = True
+ block_list.append(bottleneck_block)
+ self.out_channels.append(num_filters[block] * 4)
+ self.stages.append(nn.Sequential(*block_list))
+ else:
+ for block in range(len(depth)):
+ block_list = []
+ shortcut = False
+ for i in range(depth[block]):
+ conv_name = "res" + str(block + 2) + chr(97 + i)
+ basic_block = self.add_sublayer(
+ 'bb_%d_%d' % (block, i),
+ BasicBlock(
+ in_channels=num_channels[block]
+ if i == 0 else num_filters[block],
+ out_channels=num_filters[block],
+ stride=2 if i == 0 and block != 0 else 1,
+ shortcut=shortcut,
+ if_first=block == i == 0,
+ name=conv_name))
+ shortcut = True
+ block_list.append(basic_block)
+ self.out_channels.append(num_filters[block])
+ self.stages.append(nn.Sequential(*block_list))
+
+ def forward(self, inputs):
+ y = self.conv1_1(inputs)
+ y = self.conv1_2(y)
+ y = self.conv1_3(y)
+ y = self.pool2d_max(y)
+ out = []
+ for block in self.stages:
+ y = block(y)
+ out.append(y)
+ return out
diff --git a/ppocr/modeling/heads/__init__.py b/ppocr/modeling/heads/__init__.py
index 4852c7f2d14d72b9e4d59f40532469f7226c966d..5096479415f504aa9f074d55bd9b2e4a31c730b4 100755
--- a/ppocr/modeling/heads/__init__.py
+++ b/ppocr/modeling/heads/__init__.py
@@ -31,8 +31,10 @@ def build_head(config):
from .cls_head import ClsHead
support_dict = [
'DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead', 'AttentionHead',
- 'SRNHead', 'PGHead']
+ 'SRNHead', 'PGHead', 'TableAttentionHead']
+ #table head
+ from .table_att_head import TableAttentionHead
module_name = config.pop('name')
assert module_name in support_dict, Exception('head only support {}'.format(
diff --git a/ppocr/modeling/heads/cls_head.py b/ppocr/modeling/heads/cls_head.py
index d9b78b841b3c31ea349cfbf4e767328b12f39aa7..91bfa615a8206b5ec0f993429ccae990a05d0b9b 100644
--- a/ppocr/modeling/heads/cls_head.py
+++ b/ppocr/modeling/heads/cls_head.py
@@ -43,7 +43,7 @@ class ClsHead(nn.Layer):
initializer=nn.initializer.Uniform(-stdv, stdv)),
bias_attr=ParamAttr(name="fc_0.b_0"), )
- def forward(self, x):
+ def forward(self, x, targets=None):
x = self.pool(x)
x = paddle.reshape(x, shape=[x.shape[0], x.shape[1]])
x = self.fc(x)
diff --git a/ppocr/modeling/heads/det_db_head.py b/ppocr/modeling/heads/det_db_head.py
index 83e7a5ebfe131ed209b7dd2d4b5a324605be8370..f76cb34d37af7d81b5e628d06c1a4cfe126f8bb4 100644
--- a/ppocr/modeling/heads/det_db_head.py
+++ b/ppocr/modeling/heads/det_db_head.py
@@ -106,7 +106,7 @@ class DBHead(nn.Layer):
def step_function(self, x, y):
return paddle.reciprocal(1 + paddle.exp(-self.k * (x - y)))
- def forward(self, x):
+ def forward(self, x, targets=None):
shrink_maps = self.binarize(x)
if not self.training:
return {'maps': shrink_maps}
diff --git a/ppocr/modeling/heads/det_east_head.py b/ppocr/modeling/heads/det_east_head.py
index 9d0c3c4cf83adb018fcc368374cbe305658e07a9..004eb5d7bb9a134d1a84f980e37e5336dc43a29a 100644
--- a/ppocr/modeling/heads/det_east_head.py
+++ b/ppocr/modeling/heads/det_east_head.py
@@ -109,7 +109,7 @@ class EASTHead(nn.Layer):
act=None,
name="f_geo")
- def forward(self, x):
+ def forward(self, x, targets=None):
f_det = self.det_conv1(x)
f_det = self.det_conv2(f_det)
f_score = self.score_conv(f_det)
diff --git a/ppocr/modeling/heads/det_sast_head.py b/ppocr/modeling/heads/det_sast_head.py
index 263b28672299e733369938fa03952dca7685fabe..7a88a2db6c29c8c4fa1ee94d27bd0701cdbc90f8 100644
--- a/ppocr/modeling/heads/det_sast_head.py
+++ b/ppocr/modeling/heads/det_sast_head.py
@@ -116,7 +116,7 @@ class SASTHead(nn.Layer):
self.head1 = SAST_Header1(in_channels)
self.head2 = SAST_Header2(in_channels)
- def forward(self, x):
+ def forward(self, x, targets=None):
f_score, f_border = self.head1(x)
f_tvo, f_tco = self.head2(x)
diff --git a/ppocr/modeling/heads/e2e_pg_head.py b/ppocr/modeling/heads/e2e_pg_head.py
index 0da9de7580a0ceb473f971b2246c966497026a5d..274e1cdac5172f45590c9f7d7b50522c74db6750 100644
--- a/ppocr/modeling/heads/e2e_pg_head.py
+++ b/ppocr/modeling/heads/e2e_pg_head.py
@@ -220,7 +220,7 @@ class PGHead(nn.Layer):
weight_attr=ParamAttr(name="conv_f_direc{}".format(4)),
bias_attr=False)
- def forward(self, x):
+ def forward(self, x, targets=None):
f_score = self.conv_f_score1(x)
f_score = self.conv_f_score2(f_score)
f_score = self.conv_f_score3(f_score)
diff --git a/ppocr/modeling/heads/rec_ctc_head.py b/ppocr/modeling/heads/rec_ctc_head.py
index b54322da01cebff1034a2b89e33015ff120fc727..9c38d31fa0abcf39a583e5edcebfc8f336f41c46 100755
--- a/ppocr/modeling/heads/rec_ctc_head.py
+++ b/ppocr/modeling/heads/rec_ctc_head.py
@@ -67,13 +67,13 @@ class CTCHead(nn.Layer):
self.out_channels = out_channels
self.mid_channels = mid_channels
- def forward(self, x, labels=None):
+ def forward(self, x, targets=None):
if self.mid_channels is None:
predicts = self.fc(x)
else:
predicts = self.fc1(x)
predicts = self.fc2(predicts)
-
+
if not self.training:
predicts = F.softmax(predicts, axis=2)
return predicts
diff --git a/ppocr/modeling/heads/rec_srn_head.py b/ppocr/modeling/heads/rec_srn_head.py
index d2c7fc028d28c79057708d4e6f306c417ba6306a..8d59e4711a043afd9234f430a62c9876c0a8f6f4 100644
--- a/ppocr/modeling/heads/rec_srn_head.py
+++ b/ppocr/modeling/heads/rec_srn_head.py
@@ -250,7 +250,8 @@ class SRNHead(nn.Layer):
self.gsrm.wrap_encoder1.prepare_decoder.emb0 = self.gsrm.wrap_encoder0.prepare_decoder.emb0
- def forward(self, inputs, others):
+ def forward(self, inputs, targets=None):
+ others = targets[-4:]
encoder_word_pos = others[0]
gsrm_word_pos = others[1]
gsrm_slf_attn_bias1 = others[2]
diff --git a/ppocr/modeling/heads/table_att_head.py b/ppocr/modeling/heads/table_att_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..155f036d15673135eae9e5ee493648603609535d
--- /dev/null
+++ b/ppocr/modeling/heads/table_att_head.py
@@ -0,0 +1,238 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+import numpy as np
+
+
+class TableAttentionHead(nn.Layer):
+ def __init__(self, in_channels, hidden_size, loc_type, in_max_len=488, **kwargs):
+ super(TableAttentionHead, self).__init__()
+ self.input_size = in_channels[-1]
+ self.hidden_size = hidden_size
+ self.elem_num = 30
+ self.max_text_length = 100
+ self.max_elem_length = 500
+ self.max_cell_num = 500
+
+ self.structure_attention_cell = AttentionGRUCell(
+ self.input_size, hidden_size, self.elem_num, use_gru=False)
+ self.structure_generator = nn.Linear(hidden_size, self.elem_num)
+ self.loc_type = loc_type
+ self.in_max_len = in_max_len
+
+ if self.loc_type == 1:
+ self.loc_generator = nn.Linear(hidden_size, 4)
+ else:
+ if self.in_max_len == 640:
+ self.loc_fea_trans = nn.Linear(400, self.max_elem_length+1)
+ elif self.in_max_len == 800:
+ self.loc_fea_trans = nn.Linear(625, self.max_elem_length+1)
+ else:
+ self.loc_fea_trans = nn.Linear(256, self.max_elem_length+1)
+ self.loc_generator = nn.Linear(self.input_size + hidden_size, 4)
+
+ def _char_to_onehot(self, input_char, onehot_dim):
+ input_ont_hot = F.one_hot(input_char, onehot_dim)
+ return input_ont_hot
+
+ def forward(self, inputs, targets=None):
+ # if and else branch are both needed when you want to assign a variable
+ # if you modify the var in just one branch, then the modification will not work.
+ fea = inputs[-1]
+ if len(fea.shape) == 3:
+ pass
+ else:
+ last_shape = int(np.prod(fea.shape[2:])) # gry added
+ fea = paddle.reshape(fea, [fea.shape[0], fea.shape[1], last_shape])
+ fea = fea.transpose([0, 2, 1]) # (NTC)(batch, width, channels)
+ batch_size = fea.shape[0]
+
+ hidden = paddle.zeros((batch_size, self.hidden_size))
+ output_hiddens = []
+ if self.training and targets is not None:
+ structure = targets[0]
+ for i in range(self.max_elem_length+1):
+ elem_onehots = self._char_to_onehot(
+ structure[:, i], onehot_dim=self.elem_num)
+ (outputs, hidden), alpha = self.structure_attention_cell(
+ hidden, fea, elem_onehots)
+ output_hiddens.append(paddle.unsqueeze(outputs, axis=1))
+ output = paddle.concat(output_hiddens, axis=1)
+ structure_probs = self.structure_generator(output)
+ if self.loc_type == 1:
+ loc_preds = self.loc_generator(output)
+ loc_preds = F.sigmoid(loc_preds)
+ else:
+ loc_fea = fea.transpose([0, 2, 1])
+ loc_fea = self.loc_fea_trans(loc_fea)
+ loc_fea = loc_fea.transpose([0, 2, 1])
+ loc_concat = paddle.concat([output, loc_fea], axis=2)
+ loc_preds = self.loc_generator(loc_concat)
+ loc_preds = F.sigmoid(loc_preds)
+ else:
+ temp_elem = paddle.zeros(shape=[batch_size], dtype="int32")
+ structure_probs = None
+ loc_preds = None
+ elem_onehots = None
+ outputs = None
+ alpha = None
+ max_elem_length = paddle.to_tensor(self.max_elem_length)
+ i = 0
+ while i < max_elem_length+1:
+ elem_onehots = self._char_to_onehot(
+ temp_elem, onehot_dim=self.elem_num)
+ (outputs, hidden), alpha = self.structure_attention_cell(
+ hidden, fea, elem_onehots)
+ output_hiddens.append(paddle.unsqueeze(outputs, axis=1))
+ structure_probs_step = self.structure_generator(outputs)
+ temp_elem = structure_probs_step.argmax(axis=1, dtype="int32")
+ i += 1
+
+ output = paddle.concat(output_hiddens, axis=1)
+ structure_probs = self.structure_generator(output)
+ structure_probs = F.softmax(structure_probs)
+ if self.loc_type == 1:
+ loc_preds = self.loc_generator(output)
+ loc_preds = F.sigmoid(loc_preds)
+ else:
+ loc_fea = fea.transpose([0, 2, 1])
+ loc_fea = self.loc_fea_trans(loc_fea)
+ loc_fea = loc_fea.transpose([0, 2, 1])
+ loc_concat = paddle.concat([output, loc_fea], axis=2)
+ loc_preds = self.loc_generator(loc_concat)
+ loc_preds = F.sigmoid(loc_preds)
+ return {'structure_probs':structure_probs, 'loc_preds':loc_preds}
+
+
+class AttentionGRUCell(nn.Layer):
+ def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False):
+ super(AttentionGRUCell, self).__init__()
+ self.i2h = nn.Linear(input_size, hidden_size, bias_attr=False)
+ self.h2h = nn.Linear(hidden_size, hidden_size)
+ self.score = nn.Linear(hidden_size, 1, bias_attr=False)
+ self.rnn = nn.GRUCell(
+ input_size=input_size + num_embeddings, hidden_size=hidden_size)
+ self.hidden_size = hidden_size
+
+ def forward(self, prev_hidden, batch_H, char_onehots):
+ batch_H_proj = self.i2h(batch_H)
+ prev_hidden_proj = paddle.unsqueeze(self.h2h(prev_hidden), axis=1)
+ res = paddle.add(batch_H_proj, prev_hidden_proj)
+ res = paddle.tanh(res)
+ e = self.score(res)
+ alpha = F.softmax(e, axis=1)
+ alpha = paddle.transpose(alpha, [0, 2, 1])
+ context = paddle.squeeze(paddle.mm(alpha, batch_H), axis=1)
+ concat_context = paddle.concat([context, char_onehots], 1)
+ cur_hidden = self.rnn(concat_context, prev_hidden)
+ return cur_hidden, alpha
+
+
+class AttentionLSTM(nn.Layer):
+ def __init__(self, in_channels, out_channels, hidden_size, **kwargs):
+ super(AttentionLSTM, self).__init__()
+ self.input_size = in_channels
+ self.hidden_size = hidden_size
+ self.num_classes = out_channels
+
+ self.attention_cell = AttentionLSTMCell(
+ in_channels, hidden_size, out_channels, use_gru=False)
+ self.generator = nn.Linear(hidden_size, out_channels)
+
+ def _char_to_onehot(self, input_char, onehot_dim):
+ input_ont_hot = F.one_hot(input_char, onehot_dim)
+ return input_ont_hot
+
+ def forward(self, inputs, targets=None, batch_max_length=25):
+ batch_size = inputs.shape[0]
+ num_steps = batch_max_length
+
+ hidden = (paddle.zeros((batch_size, self.hidden_size)), paddle.zeros(
+ (batch_size, self.hidden_size)))
+ output_hiddens = []
+
+ if targets is not None:
+ for i in range(num_steps):
+ # one-hot vectors for a i-th char
+ char_onehots = self._char_to_onehot(
+ targets[:, i], onehot_dim=self.num_classes)
+ hidden, alpha = self.attention_cell(hidden, inputs,
+ char_onehots)
+
+ hidden = (hidden[1][0], hidden[1][1])
+ output_hiddens.append(paddle.unsqueeze(hidden[0], axis=1))
+ output = paddle.concat(output_hiddens, axis=1)
+ probs = self.generator(output)
+
+ else:
+ targets = paddle.zeros(shape=[batch_size], dtype="int32")
+ probs = None
+
+ for i in range(num_steps):
+ char_onehots = self._char_to_onehot(
+ targets, onehot_dim=self.num_classes)
+ hidden, alpha = self.attention_cell(hidden, inputs,
+ char_onehots)
+ probs_step = self.generator(hidden[0])
+ hidden = (hidden[1][0], hidden[1][1])
+ if probs is None:
+ probs = paddle.unsqueeze(probs_step, axis=1)
+ else:
+ probs = paddle.concat(
+ [probs, paddle.unsqueeze(
+ probs_step, axis=1)], axis=1)
+
+ next_input = probs_step.argmax(axis=1)
+
+ targets = next_input
+
+ return probs
+
+
+class AttentionLSTMCell(nn.Layer):
+ def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False):
+ super(AttentionLSTMCell, self).__init__()
+ self.i2h = nn.Linear(input_size, hidden_size, bias_attr=False)
+ self.h2h = nn.Linear(hidden_size, hidden_size)
+ self.score = nn.Linear(hidden_size, 1, bias_attr=False)
+ if not use_gru:
+ self.rnn = nn.LSTMCell(
+ input_size=input_size + num_embeddings, hidden_size=hidden_size)
+ else:
+ self.rnn = nn.GRUCell(
+ input_size=input_size + num_embeddings, hidden_size=hidden_size)
+
+ self.hidden_size = hidden_size
+
+ def forward(self, prev_hidden, batch_H, char_onehots):
+ batch_H_proj = self.i2h(batch_H)
+ prev_hidden_proj = paddle.unsqueeze(self.h2h(prev_hidden[0]), axis=1)
+ res = paddle.add(batch_H_proj, prev_hidden_proj)
+ res = paddle.tanh(res)
+ e = self.score(res)
+
+ alpha = F.softmax(e, axis=1)
+ alpha = paddle.transpose(alpha, [0, 2, 1])
+ context = paddle.squeeze(paddle.mm(alpha, batch_H), axis=1)
+ concat_context = paddle.concat([context, char_onehots], 1)
+ cur_hidden = self.rnn(concat_context, prev_hidden)
+
+ return cur_hidden, alpha
diff --git a/ppocr/modeling/necks/__init__.py b/ppocr/modeling/necks/__init__.py
index 37a5cf7863cb386884d82ed88c756c9fc06a541d..e97c4f64bdc9acd6729d67a9c6ff7a7563f6c95e 100644
--- a/ppocr/modeling/necks/__init__.py
+++ b/ppocr/modeling/necks/__init__.py
@@ -21,7 +21,8 @@ def build_neck(config):
from .sast_fpn import SASTFPN
from .rnn import SequenceEncoder
from .pg_fpn import PGFPN
- support_dict = ['DBFPN', 'EASTFPN', 'SASTFPN', 'SequenceEncoder', 'PGFPN']
+ from .table_fpn import TableFPN
+ support_dict = ['DBFPN', 'EASTFPN', 'SASTFPN', 'SequenceEncoder', 'PGFPN', 'TableFPN']
module_name = config.pop('name')
assert module_name in support_dict, Exception('neck only support {}'.format(
diff --git a/ppocr/modeling/necks/table_fpn.py b/ppocr/modeling/necks/table_fpn.py
new file mode 100644
index 0000000000000000000000000000000000000000..734f15af65e4e15a7ddb4004954a61bfa1934246
--- /dev/null
+++ b/ppocr/modeling/necks/table_fpn.py
@@ -0,0 +1,110 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+from paddle import nn
+import paddle.nn.functional as F
+from paddle import ParamAttr
+
+
+class TableFPN(nn.Layer):
+ def __init__(self, in_channels, out_channels, **kwargs):
+ super(TableFPN, self).__init__()
+ self.out_channels = 512
+ weight_attr = paddle.nn.initializer.KaimingUniform()
+ self.in2_conv = nn.Conv2D(
+ in_channels=in_channels[0],
+ out_channels=self.out_channels,
+ kernel_size=1,
+ weight_attr=ParamAttr(initializer=weight_attr),
+ bias_attr=False)
+ self.in3_conv = nn.Conv2D(
+ in_channels=in_channels[1],
+ out_channels=self.out_channels,
+ kernel_size=1,
+ stride = 1,
+ weight_attr=ParamAttr(initializer=weight_attr),
+ bias_attr=False)
+ self.in4_conv = nn.Conv2D(
+ in_channels=in_channels[2],
+ out_channels=self.out_channels,
+ kernel_size=1,
+ weight_attr=ParamAttr(initializer=weight_attr),
+ bias_attr=False)
+ self.in5_conv = nn.Conv2D(
+ in_channels=in_channels[3],
+ out_channels=self.out_channels,
+ kernel_size=1,
+ weight_attr=ParamAttr(initializer=weight_attr),
+ bias_attr=False)
+ self.p5_conv = nn.Conv2D(
+ in_channels=self.out_channels,
+ out_channels=self.out_channels // 4,
+ kernel_size=3,
+ padding=1,
+ weight_attr=ParamAttr(initializer=weight_attr),
+ bias_attr=False)
+ self.p4_conv = nn.Conv2D(
+ in_channels=self.out_channels,
+ out_channels=self.out_channels // 4,
+ kernel_size=3,
+ padding=1,
+ weight_attr=ParamAttr(initializer=weight_attr),
+ bias_attr=False)
+ self.p3_conv = nn.Conv2D(
+ in_channels=self.out_channels,
+ out_channels=self.out_channels // 4,
+ kernel_size=3,
+ padding=1,
+ weight_attr=ParamAttr(initializer=weight_attr),
+ bias_attr=False)
+ self.p2_conv = nn.Conv2D(
+ in_channels=self.out_channels,
+ out_channels=self.out_channels // 4,
+ kernel_size=3,
+ padding=1,
+ weight_attr=ParamAttr(initializer=weight_attr),
+ bias_attr=False)
+ self.fuse_conv = nn.Conv2D(
+ in_channels=self.out_channels * 4,
+ out_channels=512,
+ kernel_size=3,
+ padding=1,
+ weight_attr=ParamAttr(initializer=weight_attr), bias_attr=False)
+
+ def forward(self, x):
+ c2, c3, c4, c5 = x
+
+ in5 = self.in5_conv(c5)
+ in4 = self.in4_conv(c4)
+ in3 = self.in3_conv(c3)
+ in2 = self.in2_conv(c2)
+
+ out4 = in4 + F.upsample(
+ in5, size=in4.shape[2:4], mode="nearest", align_mode=1) # 1/16
+ out3 = in3 + F.upsample(
+ out4, size=in3.shape[2:4], mode="nearest", align_mode=1) # 1/8
+ out2 = in2 + F.upsample(
+ out3, size=in2.shape[2:4], mode="nearest", align_mode=1) # 1/4
+
+ p4 = F.upsample(out4, size=in5.shape[2:4], mode="nearest", align_mode=1)
+ p3 = F.upsample(out3, size=in5.shape[2:4], mode="nearest", align_mode=1)
+ p2 = F.upsample(out2, size=in5.shape[2:4], mode="nearest", align_mode=1)
+ fuse = paddle.concat([in5, p4, p3, p2], axis=1)
+ fuse_conv = self.fuse_conv(fuse) * 0.005
+ return [c5 + fuse_conv]
diff --git a/tools/eval.py b/tools/eval.py
index f5e8cd5593a0eb12247300b9f52f152655a49c59..c1315805b5ff9bf29dee87a21688a145b4662b9a 100755
--- a/tools/eval.py
+++ b/tools/eval.py
@@ -55,6 +55,7 @@ def main():
model = build_model(config['Architecture'])
use_srn = config['Architecture']['algorithm'] == "SRN"
+ model_type = config['Architecture']['model_type']
best_model_dict = init_model(config, model)
if len(best_model_dict):
@@ -67,7 +68,7 @@ def main():
# start eval
metric = program.eval(model, valid_dataloader, post_process_class,
- eval_class, use_srn)
+ eval_class, model_type, use_srn)
logger.info('metric eval ***************')
for k, v in metric.items():
logger.info('{}:{}'.format(k, v))
diff --git a/tools/export_model.py b/tools/export_model.py
index 625c82468edff7c3eeb787422bdef07b4b274460..785aca10e46200bda49bdff2b89ba00cafbe7a20 100755
--- a/tools/export_model.py
+++ b/tools/export_model.py
@@ -60,7 +60,8 @@ def export_single_model(model, arch_config, save_path, logger):
"When there is tps in the network, variable length input is not supported, and the input size needs to be the same as during training"
)
infer_shape[-1] = 100
-
+ elif arch_config["model_type"] == "table":
+ infer_shape = [3, 488, 488]
model = to_static(
model,
input_spec=[
diff --git a/tools/infer/utility.py b/tools/infer/utility.py
index d491d6013869da5cc5e7cc7975a3324a460182a2..1d652e7d352da90bcc08b701f332f586ebb9339c 100755
--- a/tools/infer/utility.py
+++ b/tools/infer/utility.py
@@ -331,10 +331,11 @@ def create_predictor(args, mode, logger):
config.disable_glog_info()
config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
+ if mode == 'structure':
+ config.delete_pass("fc_fuse_pass") # not supported for table
config.switch_use_feed_fetch_ops(False)
config.switch_ir_optim(True)
- if mode == 'structure':
- config.switch_ir_optim(False)
+
# create predictor
predictor = inference.create_predictor(config)
input_names = predictor.get_input_names()
diff --git a/tools/infer_table.py b/tools/infer_table.py
new file mode 100644
index 0000000000000000000000000000000000000000..f743d87540f7fd64157a808db156c9f62a042d9c
--- /dev/null
+++ b/tools/infer_table.py
@@ -0,0 +1,107 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+import os
+import sys
+import json
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(__dir__)
+sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
+
+os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
+
+import paddle
+from paddle.jit import to_static
+
+from ppocr.data import create_operators, transform
+from ppocr.modeling.architectures import build_model
+from ppocr.postprocess import build_post_process
+from ppocr.utils.save_load import init_model
+from ppocr.utils.utility import get_image_file_list
+import tools.program as program
+import cv2
+
+def main(config, device, logger, vdl_writer):
+ global_config = config['Global']
+
+ # build post process
+ post_process_class = build_post_process(config['PostProcess'],
+ global_config)
+
+ # build model
+ if hasattr(post_process_class, 'character'):
+ config['Architecture']["Head"]['out_channels'] = len(
+ getattr(post_process_class, 'character'))
+
+ model = build_model(config['Architecture'])
+
+ init_model(config, model, logger)
+
+ # create data ops
+ transforms = []
+ use_padding = False
+ for op in config['Eval']['dataset']['transforms']:
+ op_name = list(op)[0]
+ if 'Label' in op_name:
+ continue
+ if op_name == 'KeepKeys':
+ op[op_name]['keep_keys'] = ['image']
+ if op_name == "ResizeTableImage":
+ use_padding = True
+ padding_max_len = op['ResizeTableImage']['max_len']
+ transforms.append(op)
+
+ global_config['infer_mode'] = True
+ ops = create_operators(transforms, global_config)
+
+ model.eval()
+ for file in get_image_file_list(config['Global']['infer_img']):
+ logger.info("infer_img: {}".format(file))
+ with open(file, 'rb') as f:
+ img = f.read()
+ data = {'image': img}
+ batch = transform(data, ops)
+ images = np.expand_dims(batch[0], axis=0)
+ images = paddle.to_tensor(images)
+ preds = model(images)
+ post_result = post_process_class(preds)
+ res_html_code = post_result['res_html_code']
+ res_loc = post_result['res_loc']
+ img = cv2.imread(file)
+ imgh, imgw = img.shape[0:2]
+ res_loc_final = []
+ for rno in range(len(res_loc[0])):
+ x0, y0, x1, y1 = res_loc[0][rno]
+ left = max(int(imgw * x0), 0)
+ top = max(int(imgh * y0), 0)
+ right = min(int(imgw * x1), imgw - 1)
+ bottom = min(int(imgh * y1), imgh - 1)
+ cv2.rectangle(img, (left, top), (right, bottom), (0, 0, 255), 2)
+ res_loc_final.append([left, top, right, bottom])
+ res_loc_str = json.dumps(res_loc_final)
+ logger.info("result: {}, {}".format(res_html_code, res_loc_final))
+ logger.info("success!")
+
+
+if __name__ == '__main__':
+ config, device, logger, vdl_writer = program.preprocess()
+ main(config, device, logger, vdl_writer)
+
diff --git a/tools/program.py b/tools/program.py
index 7641bed749ff4bb0d58712a9f50c6a119a4f25ee..2d99f2968a3f0c8acc359ed0fbb199650bd7010c 100755
--- a/tools/program.py
+++ b/tools/program.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -186,6 +186,7 @@ def train(config,
model.train()
use_srn = config['Architecture']['algorithm'] == "SRN"
+ model_type = config['Architecture']['model_type']
if 'start_epoch' in best_model_dict:
start_epoch = best_model_dict['start_epoch']
@@ -208,9 +209,9 @@ def train(config,
lr = optimizer.get_lr()
images = batch[0]
if use_srn:
- others = batch[-4:]
- preds = model(images, others)
model_average = True
+ if use_srn or model_type == 'table':
+ preds = model(images, data=batch[1:])
else:
preds = model(images)
loss = loss_class(preds, batch)
@@ -232,8 +233,11 @@ def train(config,
if cal_metric_during_train: # only rec and cls need
batch = [item.numpy() for item in batch]
- post_result = post_process_class(preds, batch[1])
- eval_class(post_result, batch)
+ if model_type == 'table':
+ eval_class(preds, batch)
+ else:
+ post_result = post_process_class(preds, batch[1])
+ eval_class(post_result, batch)
metric = eval_class.get_metric()
train_stats.update(metric)
@@ -269,6 +273,7 @@ def train(config,
valid_dataloader,
post_process_class,
eval_class,
+ model_type,
use_srn=use_srn)
cur_metric_str = 'cur metric, {}'.format(', '.join(
['{}: {}'.format(k, v) for k, v in cur_metric.items()]))
@@ -336,7 +341,11 @@ def train(config,
return
-def eval(model, valid_dataloader, post_process_class, eval_class,
+def eval(model,
+ valid_dataloader,
+ post_process_class,
+ eval_class,
+ model_type,
use_srn=False):
model.eval()
with paddle.no_grad():
@@ -350,19 +359,19 @@ def eval(model, valid_dataloader, post_process_class, eval_class,
break
images = batch[0]
start = time.time()
-
- if use_srn:
- others = batch[-4:]
- preds = model(images, others)
+ if use_srn or model_type == 'table':
+ preds = model(images, data=batch[1:])
else:
preds = model(images)
-
batch = [item.numpy() for item in batch]
# Obtain usable results from post-processing methods
- post_result = post_process_class(preds, batch[1])
total_time += time.time() - start
# Evaluate the results of the current batch
- eval_class(post_result, batch)
+ if model_type == 'table':
+ eval_class(preds, batch)
+ else:
+ post_result = post_process_class(preds, batch[1])
+ eval_class(post_result, batch)
pbar.update(1)
total_frame += len(images)
# Get final metric,eg. acc or hmean
@@ -386,7 +395,7 @@ def preprocess(is_train=False):
alg = config['Architecture']['algorithm']
assert alg in [
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
- 'CLS', 'PGNet', 'Distillation'
+ 'CLS', 'PGNet', 'Distillation', 'TableAttn'
]
device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
|