From c1fd46641eff86b8e7ee2aa8d1943f7aced5b3a7 Mon Sep 17 00:00:00 2001 From: tink2123 Date: Wed, 30 Dec 2020 16:15:49 +0800 Subject: [PATCH] add srn for dygraph --- configs/rec/rec_mv3_none_bilstm_ctc.yml | 6 +- configs/rec/rec_mv3_none_none_ctc.yml | 4 +- configs/rec/rec_mv3_tps_bilstm_ctc.yml | 4 +- configs/rec/rec_r34_vd_none_bilstm_ctc.yml | 4 +- configs/rec/rec_r34_vd_none_none_ctc.yml | 4 +- configs/rec/rec_r34_vd_tps_bilstm_ctc.yml | 4 +- configs/rec/rec_r50_fpn_srn.yml | 106 ++++++ ppocr/data/__init__.py | 4 +- ppocr/data/imaug/__init__.py | 2 +- ppocr/data/imaug/label_ops.py | 48 +++ ppocr/data/imaug/rec_img_aug.py | 90 ++++- ppocr/data/lmdb_dataset.py | 4 +- ppocr/losses/__init__.py | 5 +- ppocr/losses/rec_srn_loss.py | 47 +++ ppocr/metrics/__init__.py | 1 + ppocr/metrics/rec_metric.py | 4 +- ppocr/modeling/architectures/base_model.py | 7 +- ppocr/modeling/backbones/__init__.py | 3 +- ppocr/modeling/backbones/rec_resnet_fpn.py | 307 ++++++++++++++++ ppocr/modeling/heads/__init__.py | 5 +- ppocr/modeling/heads/rec_srn_head.py | 279 ++++++++++++++ ppocr/modeling/heads/self_attention.py | 408 +++++++++++++++++++++ ppocr/postprocess/__init__.py | 5 +- ppocr/postprocess/rec_postprocess.py | 84 ++++- tools/export_model.py | 41 ++- tools/infer/predict_rec.py | 149 +++++++- tools/infer_rec.py | 25 +- tools/program.py | 14 +- 28 files changed, 1594 insertions(+), 70 deletions(-) create mode 100644 configs/rec/rec_r50_fpn_srn.yml create mode 100644 ppocr/losses/rec_srn_loss.py create mode 100644 ppocr/modeling/backbones/rec_resnet_fpn.py create mode 100644 ppocr/modeling/heads/rec_srn_head.py create mode 100644 ppocr/modeling/heads/self_attention.py diff --git a/configs/rec/rec_mv3_none_bilstm_ctc.yml b/configs/rec/rec_mv3_none_bilstm_ctc.yml index 38f1e869..00c1db88 100644 --- a/configs/rec/rec_mv3_none_bilstm_ctc.yml +++ b/configs/rec/rec_mv3_none_bilstm_ctc.yml @@ -1,5 +1,5 @@ Global: - use_gpu: true + use_gpu: True epoch_num: 72 log_smooth_window: 20 print_batch_step: 10 @@ -59,7 +59,7 @@ Metric: Train: dataset: - name: LMDBDateSet + name: LMDBDataSet data_dir: ./train_data/data_lmdb_release/training/ transforms: - DecodeImage: # load image @@ -78,7 +78,7 @@ Train: Eval: dataset: - name: LMDBDateSet + name: LMDBDataSet data_dir: ./train_data/data_lmdb_release/validation/ transforms: - DecodeImage: # load image diff --git a/configs/rec/rec_mv3_none_none_ctc.yml b/configs/rec/rec_mv3_none_none_ctc.yml index 33079ad4..6711b1d2 100644 --- a/configs/rec/rec_mv3_none_none_ctc.yml +++ b/configs/rec/rec_mv3_none_none_ctc.yml @@ -58,7 +58,7 @@ Metric: Train: dataset: - name: LMDBDateSet + name: LMDBDataSet data_dir: ./train_data/data_lmdb_release/training/ transforms: - DecodeImage: # load image @@ -77,7 +77,7 @@ Train: Eval: dataset: - name: LMDBDateSet + name: LMDBDataSet data_dir: ./train_data/data_lmdb_release/validation/ transforms: - DecodeImage: # load image diff --git a/configs/rec/rec_mv3_tps_bilstm_ctc.yml b/configs/rec/rec_mv3_tps_bilstm_ctc.yml index 08f68939..1b9fb0a0 100644 --- a/configs/rec/rec_mv3_tps_bilstm_ctc.yml +++ b/configs/rec/rec_mv3_tps_bilstm_ctc.yml @@ -63,7 +63,7 @@ Metric: Train: dataset: - name: LMDBDateSet + name: LMDBDataSet data_dir: ./train_data/data_lmdb_release/training/ transforms: - DecodeImage: # load image @@ -82,7 +82,7 @@ Train: Eval: dataset: - name: LMDBDateSet + name: LMDBDataSet data_dir: ./train_data/data_lmdb_release/validation/ transforms: - DecodeImage: # load image diff --git a/configs/rec/rec_r34_vd_none_bilstm_ctc.yml b/configs/rec/rec_r34_vd_none_bilstm_ctc.yml index 4ad2ff89..e4d301a6 100644 --- a/configs/rec/rec_r34_vd_none_bilstm_ctc.yml +++ b/configs/rec/rec_r34_vd_none_bilstm_ctc.yml @@ -58,7 +58,7 @@ Metric: Train: dataset: - name: LMDBDateSet + name: LMDBDataSet data_dir: ./train_data/data_lmdb_release/training/ transforms: - DecodeImage: # load image @@ -77,7 +77,7 @@ Train: Eval: dataset: - name: LMDBDateSet + name: LMDBDataSet data_dir: ./train_data/data_lmdb_release/validation/ transforms: - DecodeImage: # load image diff --git a/configs/rec/rec_r34_vd_none_none_ctc.yml b/configs/rec/rec_r34_vd_none_none_ctc.yml index 9c1eeb30..4a17a004 100644 --- a/configs/rec/rec_r34_vd_none_none_ctc.yml +++ b/configs/rec/rec_r34_vd_none_none_ctc.yml @@ -56,7 +56,7 @@ Metric: Train: dataset: - name: LMDBDateSet + name: LMDBDataSet data_dir: ./train_data/data_lmdb_release/training/ transforms: - DecodeImage: # load image @@ -75,7 +75,7 @@ Train: Eval: dataset: - name: LMDBDateSet + name: LMDBDataSet data_dir: ./train_data/data_lmdb_release/validation/ transforms: - DecodeImage: # load image diff --git a/configs/rec/rec_r34_vd_tps_bilstm_ctc.yml b/configs/rec/rec_r34_vd_tps_bilstm_ctc.yml index aeded492..62edf843 100644 --- a/configs/rec/rec_r34_vd_tps_bilstm_ctc.yml +++ b/configs/rec/rec_r34_vd_tps_bilstm_ctc.yml @@ -62,7 +62,7 @@ Metric: Train: dataset: - name: LMDBDateSet + name: LMDBDataSet data_dir: ./train_data/data_lmdb_release/training/ transforms: - DecodeImage: # load image @@ -81,7 +81,7 @@ Train: Eval: dataset: - name: LMDBDateSet + name: LMDBDataSet data_dir: ./train_data/data_lmdb_release/validation/ transforms: - DecodeImage: # load image diff --git a/configs/rec/rec_r50_fpn_srn.yml b/configs/rec/rec_r50_fpn_srn.yml new file mode 100644 index 00000000..78f8d551 --- /dev/null +++ b/configs/rec/rec_r50_fpn_srn.yml @@ -0,0 +1,106 @@ +Global: + use_gpu: True + epoch_num: 72 + log_smooth_window: 20 + print_batch_step: 5 + save_model_dir: ./output/rec/srn + save_epoch_step: 3 + # evaluation is run every 5000 iterations after the 4000th iteration + eval_batch_step: [0, 5000] + # if pretrained_model is saved in static mode, load_static_weights must set to True + cal_metric_during_train: True + pretrained_model: + checkpoints: + save_inference_dir: + use_visualdl: False + infer_img: doc/imgs_words/ch/word_1.jpg + # for data or label process + character_dict_path: + character_type: en + max_text_length: 25 + num_heads: 8 + infer_mode: False + use_space_char: False + + +Optimizer: + name: Adam + lr: + name: Cosine + learning_rate: 0.0001 + +Architecture: + model_type: rec + algorithm: SRN + in_channels: 1 + Transform: + Backbone: + name: ResNetFPN + Head: + name: SRNHead + max_text_length: 25 + num_heads: 8 + num_encoder_TUs: 2 + num_decoder_TUs: 4 + hidden_dims: 512 + +Loss: + name: SRNLoss + +PostProcess: + name: SRNLabelDecode + +Metric: + name: RecMetric + main_indicator: acc + +Train: + dataset: + name: LMDBDataSet + data_dir: ./train_data/srn_train_data_duiqi + #label_file_list: ["./train_data/ic15_data/1.txt"] + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - SRNLabelEncode: # Class handling label + - SRNRecResizeImg: + image_shape: [1, 64, 256] + - KeepKeys: + keep_keys: ['image', + 'label', + 'length', + 'encoder_word_pos', + 'gsrm_word_pos', + 'gsrm_slf_attn_bias1', + 'gsrm_slf_attn_bias2'] # dataloader will return list in this order + loader: + shuffle: False + batch_size_per_card: 64 + drop_last: True + num_workers: 4 + +Eval: + dataset: + name: LMDBDataSet + data_dir: ./train_data/data_lmdb_release/evaluation + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - SRNLabelEncode: # Class handling label + - SRNRecResizeImg: + image_shape: [1, 64, 256] + - KeepKeys: + keep_keys: ['image', + 'label', + 'length', + 'encoder_word_pos', + 'gsrm_word_pos', + 'gsrm_slf_attn_bias1', + 'gsrm_slf_attn_bias2'] + loader: + shuffle: False + drop_last: False + batch_size_per_card: 32 + num_workers: 4 diff --git a/ppocr/data/__init__.py b/ppocr/data/__init__.py index 7b0faf12..eb461ffa 100644 --- a/ppocr/data/__init__.py +++ b/ppocr/data/__init__.py @@ -33,7 +33,7 @@ import paddle.distributed as dist from ppocr.data.imaug import transform, create_operators from ppocr.data.simple_dataset import SimpleDataSet -from ppocr.data.lmdb_dataset import LMDBDateSet +from ppocr.data.lmdb_dataset import LMDBDataSet __all__ = ['build_dataloader', 'transform', 'create_operators'] @@ -54,7 +54,7 @@ signal.signal(signal.SIGTERM, term_mp) def build_dataloader(config, mode, device, logger): config = copy.deepcopy(config) - support_dict = ['SimpleDataSet', 'LMDBDateSet'] + support_dict = ['SimpleDataSet', 'LMDBDataSet'] module_name = config[mode]['dataset']['name'] assert module_name in support_dict, Exception( 'DataSet only support {}'.format(support_dict)) diff --git a/ppocr/data/imaug/__init__.py b/ppocr/data/imaug/__init__.py index 6ea4dd8e..250ac75e 100644 --- a/ppocr/data/imaug/__init__.py +++ b/ppocr/data/imaug/__init__.py @@ -21,7 +21,7 @@ from .make_border_map import MakeBorderMap from .make_shrink_map import MakeShrinkMap from .random_crop_data import EastRandomCropData, PSERandomCrop -from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg +from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg from .randaugment import RandAugment from .operators import * from .label_ops import * diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py index af3308a5..986cec3d 100644 --- a/ppocr/data/imaug/label_ops.py +++ b/ppocr/data/imaug/label_ops.py @@ -98,6 +98,8 @@ class BaseRecLabelEncode(object): support_character_type, character_type) self.max_text_len = max_text_length + self.beg_str = "sos" + self.end_str = "eos" if character_type == "en": self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz" dict_character = list(self.character_str) @@ -213,3 +215,49 @@ class AttnLabelEncode(BaseRecLabelEncode): assert False, "Unsupport type %s in get_beg_end_flag_idx" \ % beg_or_end return idx + + +class SRNLabelEncode(BaseRecLabelEncode): + """ Convert between text-label and text-index """ + + def __init__(self, + max_text_length=25, + character_dict_path=None, + character_type='en', + use_space_char=False, + **kwargs): + super(SRNLabelEncode, + self).__init__(max_text_length, character_dict_path, + character_type, use_space_char) + + def add_special_char(self, dict_character): + dict_character = dict_character + [self.beg_str, self.end_str] + return dict_character + + def __call__(self, data): + text = data['label'] + text = self.encode(text) + char_num = len(self.character_str) + if text is None: + return None + if len(text) > self.max_text_len: + return None + data['length'] = np.array(len(text)) + text = text + [char_num] * (self.max_text_len - len(text)) + data['label'] = np.array(text) + return data + + def get_ignored_tokens(self): + beg_idx = self.get_beg_end_flag_idx("beg") + end_idx = self.get_beg_end_flag_idx("end") + return [beg_idx, end_idx] + + def get_beg_end_flag_idx(self, beg_or_end): + if beg_or_end == "beg": + idx = np.array(self.dict[self.beg_str]) + elif beg_or_end == "end": + idx = np.array(self.dict[self.end_str]) + else: + assert False, "Unsupport type %s in get_beg_end_flag_idx" \ + % beg_or_end + return idx diff --git a/ppocr/data/imaug/rec_img_aug.py b/ppocr/data/imaug/rec_img_aug.py index 2ccb2d1d..28e6bd0b 100644 --- a/ppocr/data/imaug/rec_img_aug.py +++ b/ppocr/data/imaug/rec_img_aug.py @@ -12,20 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - import math import cv2 import numpy as np @@ -77,6 +63,26 @@ class RecResizeImg(object): return data +class SRNRecResizeImg(object): + def __init__(self, image_shape, num_heads, max_text_length, **kwargs): + self.image_shape = image_shape + self.num_heads = num_heads + self.max_text_length = max_text_length + + def __call__(self, data): + img = data['image'] + norm_img = resize_norm_img_srn(img, self.image_shape) + data['image'] = norm_img + [encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] = \ + srn_other_inputs(self.image_shape, self.num_heads, self.max_text_length) + + data['encoder_word_pos'] = encoder_word_pos + data['gsrm_word_pos'] = gsrm_word_pos + data['gsrm_slf_attn_bias1'] = gsrm_slf_attn_bias1 + data['gsrm_slf_attn_bias2'] = gsrm_slf_attn_bias2 + return data + + def resize_norm_img(img, image_shape): imgC, imgH, imgW = image_shape h = img.shape[0] @@ -103,7 +109,7 @@ def resize_norm_img(img, image_shape): def resize_norm_img_chinese(img, image_shape): imgC, imgH, imgW = image_shape # todo: change to 0 and modified image shape - max_wh_ratio = 0 + max_wh_ratio = imgW * 1.0 / imgH h, w = img.shape[0], img.shape[1] ratio = w * 1.0 / h max_wh_ratio = max(max_wh_ratio, ratio) @@ -126,6 +132,60 @@ def resize_norm_img_chinese(img, image_shape): return padding_im +def resize_norm_img_srn(img, image_shape): + imgC, imgH, imgW = image_shape + + img_black = np.zeros((imgH, imgW)) + im_hei = img.shape[0] + im_wid = img.shape[1] + + if im_wid <= im_hei * 1: + img_new = cv2.resize(img, (imgH * 1, imgH)) + elif im_wid <= im_hei * 2: + img_new = cv2.resize(img, (imgH * 2, imgH)) + elif im_wid <= im_hei * 3: + img_new = cv2.resize(img, (imgH * 3, imgH)) + else: + img_new = cv2.resize(img, (imgW, imgH)) + + img_np = np.asarray(img_new) + img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY) + img_black[:, 0:img_np.shape[1]] = img_np + img_black = img_black[:, :, np.newaxis] + + row, col, c = img_black.shape + c = 1 + + return np.reshape(img_black, (c, row, col)).astype(np.float32) + + +def srn_other_inputs(image_shape, num_heads, max_text_length): + + imgC, imgH, imgW = image_shape + feature_dim = int((imgH / 8) * (imgW / 8)) + + encoder_word_pos = np.array(range(0, feature_dim)).reshape( + (feature_dim, 1)).astype('int64') + gsrm_word_pos = np.array(range(0, max_text_length)).reshape( + (max_text_length, 1)).astype('int64') + + gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length)) + gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape( + [1, max_text_length, max_text_length]) + gsrm_slf_attn_bias1 = np.tile(gsrm_slf_attn_bias1, + [num_heads, 1, 1]) * [-1e9] + + gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape( + [1, max_text_length, max_text_length]) + gsrm_slf_attn_bias2 = np.tile(gsrm_slf_attn_bias2, + [num_heads, 1, 1]) * [-1e9] + + return [ + encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, + gsrm_slf_attn_bias2 + ] + + def flag(): """ flag diff --git a/ppocr/data/lmdb_dataset.py b/ppocr/data/lmdb_dataset.py index e7bb6dd3..515279fb 100644 --- a/ppocr/data/lmdb_dataset.py +++ b/ppocr/data/lmdb_dataset.py @@ -20,9 +20,9 @@ import cv2 from .imaug import transform, create_operators -class LMDBDateSet(Dataset): +class LMDBDataSet(Dataset): def __init__(self, config, mode, logger): - super(LMDBDateSet, self).__init__() + super(LMDBDataSet, self).__init__() global_config = config['Global'] dataset_config = config[mode]['dataset'] diff --git a/ppocr/losses/__init__.py b/ppocr/losses/__init__.py index 4673d35c..b280eb33 100755 --- a/ppocr/losses/__init__.py +++ b/ppocr/losses/__init__.py @@ -23,11 +23,14 @@ def build_loss(config): # rec loss from .rec_ctc_loss import CTCLoss + from .rec_srn_loss import SRNLoss # cls loss from .cls_loss import ClsLoss - support_dict = ['DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss'] + support_dict = [ + 'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'SRNLoss' + ] config = copy.deepcopy(config) module_name = config.pop('name') diff --git a/ppocr/losses/rec_srn_loss.py b/ppocr/losses/rec_srn_loss.py new file mode 100644 index 00000000..d722ee0f --- /dev/null +++ b/ppocr/losses/rec_srn_loss.py @@ -0,0 +1,47 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +from paddle import nn + + +class SRNLoss(nn.Layer): + def __init__(self, **kwargs): + super(SRNLoss, self).__init__() + self.loss_func = paddle.nn.loss.CrossEntropyLoss(reduction="sum") + + def forward(self, predicts, batch): + predict = predicts['predict'] + word_predict = predicts['word_out'] + gsrm_predict = predicts['gsrm_out'] + label = batch[1] + + casted_label = paddle.cast(x=label, dtype='int64') + casted_label = paddle.reshape(x=casted_label, shape=[-1, 1]) + + cost_word = self.loss_func(word_predict, label=casted_label) + cost_gsrm = self.loss_func(gsrm_predict, label=casted_label) + cost_vsfd = self.loss_func(predict, label=casted_label) + + cost_word = paddle.reshape(x=paddle.sum(cost_word), shape=[1]) + cost_gsrm = paddle.reshape(x=paddle.sum(cost_gsrm), shape=[1]) + cost_vsfd = paddle.reshape(x=paddle.sum(cost_vsfd), shape=[1]) + + sum_cost = cost_word + cost_vsfd * 2.0 + cost_gsrm * 0.15 + + return {'loss': sum_cost, 'word_loss': cost_word, 'img_loss': cost_vsfd} diff --git a/ppocr/metrics/__init__.py b/ppocr/metrics/__init__.py index a0e7d912..41828f51 100644 --- a/ppocr/metrics/__init__.py +++ b/ppocr/metrics/__init__.py @@ -26,6 +26,7 @@ def build_metric(config): from .det_metric import DetMetric from .rec_metric import RecMetric from .cls_metric import ClsMetric + from .rec_metric import RecMetric support_dict = ['DetMetric', 'RecMetric', 'ClsMetric'] diff --git a/ppocr/metrics/rec_metric.py b/ppocr/metrics/rec_metric.py index bd0f92e0..459fe8e4 100644 --- a/ppocr/metrics/rec_metric.py +++ b/ppocr/metrics/rec_metric.py @@ -31,8 +31,6 @@ class RecMetric(object): if pred == target: correct_num += 1 all_num += 1 - # if all_num < 10 and kwargs.get('show_str', False): - # print('{} -> {}'.format(pred, target)) self.correct_num += correct_num self.all_num += all_num self.norm_edit_dis += norm_edit_dis @@ -48,7 +46,7 @@ class RecMetric(object): 'norm_edit_dis': 0, } """ - acc = self.correct_num / self.all_num + acc = 1.0 * self.correct_num / self.all_num norm_edit_dis = 1 - self.norm_edit_dis / self.all_num self.reset() return {'acc': acc, 'norm_edit_dis': norm_edit_dis} diff --git a/ppocr/modeling/architectures/base_model.py b/ppocr/modeling/architectures/base_model.py index ab44b53a..09b6e034 100644 --- a/ppocr/modeling/architectures/base_model.py +++ b/ppocr/modeling/architectures/base_model.py @@ -68,11 +68,14 @@ class BaseModel(nn.Layer): config["Head"]['in_channels'] = in_channels self.head = build_head(config["Head"]) - def forward(self, x): + def forward(self, x, data=None): if self.use_transform: x = self.transform(x) x = self.backbone(x) if self.use_neck: x = self.neck(x) - x = self.head(x) + if data is None: + x = self.head(x) + else: + x = self.head(x, data) return x diff --git a/ppocr/modeling/backbones/__init__.py b/ppocr/modeling/backbones/__init__.py index 43103e53..03c15508 100755 --- a/ppocr/modeling/backbones/__init__.py +++ b/ppocr/modeling/backbones/__init__.py @@ -24,7 +24,8 @@ def build_backbone(config, model_type): elif model_type == 'rec' or model_type == 'cls': from .rec_mobilenet_v3 import MobileNetV3 from .rec_resnet_vd import ResNet - support_dict = ['MobileNetV3', 'ResNet', 'ResNet_FPN'] + from .rec_resnet_fpn import ResNetFPN + support_dict = ['MobileNetV3', 'ResNet', 'ResNetFPN'] else: raise NotImplementedError diff --git a/ppocr/modeling/backbones/rec_resnet_fpn.py b/ppocr/modeling/backbones/rec_resnet_fpn.py new file mode 100644 index 00000000..a7e876a2 --- /dev/null +++ b/ppocr/modeling/backbones/rec_resnet_fpn.py @@ -0,0 +1,307 @@ +#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from paddle import nn, ParamAttr +from paddle.nn import functional as F +import paddle.fluid as fluid +import paddle +import numpy as np + +__all__ = ["ResNetFPN"] + + +class ResNetFPN(nn.Layer): + def __init__(self, in_channels=1, layers=50, **kwargs): + super(ResNetFPN, self).__init__() + supported_layers = { + 18: { + 'depth': [2, 2, 2, 2], + 'block_class': BasicBlock + }, + 34: { + 'depth': [3, 4, 6, 3], + 'block_class': BasicBlock + }, + 50: { + 'depth': [3, 4, 6, 3], + 'block_class': BottleneckBlock + }, + 101: { + 'depth': [3, 4, 23, 3], + 'block_class': BottleneckBlock + }, + 152: { + 'depth': [3, 8, 36, 3], + 'block_class': BottleneckBlock + } + } + stride_list = [(2, 2), (2, 2), (1, 1), (1, 1)] + num_filters = [64, 128, 256, 512] + self.depth = supported_layers[layers]['depth'] + self.F = [] + self.conv = ConvBNLayer( + in_channels=in_channels, + out_channels=64, + kernel_size=7, + stride=2, + act="relu", + name="conv1") + self.block_list = [] + in_ch = 64 + if layers >= 50: + for block in range(len(self.depth)): + for i in range(self.depth[block]): + if layers in [101, 152] and block == 2: + if i == 0: + conv_name = "res" + str(block + 2) + "a" + else: + conv_name = "res" + str(block + 2) + "b" + str(i) + else: + conv_name = "res" + str(block + 2) + chr(97 + i) + block_list = self.add_sublayer( + "bottleneckBlock_{}_{}".format(block, i), + BottleneckBlock( + in_channels=in_ch, + out_channels=num_filters[block], + stride=stride_list[block] if i == 0 else 1, + name=conv_name)) + in_ch = num_filters[block] * 4 + self.block_list.append(block_list) + self.F.append(block_list) + else: + for block in range(len(self.depth)): + for i in range(self.depth[block]): + conv_name = "res" + str(block + 2) + chr(97 + i) + if i == 0 and block != 0: + stride = (2, 1) + else: + stride = (1, 1) + basic_block = self.add_sublayer( + conv_name, + BasicBlock( + in_channels=in_ch, + out_channels=num_filters[block], + stride=stride_list[block] if i == 0 else 1, + is_first=block == i == 0, + name=conv_name)) + in_ch = basic_block.out_channels + self.block_list.append(basic_block) + out_ch_list = [in_ch // 4, in_ch // 2, in_ch] + self.base_block = [] + self.conv_trans = [] + self.bn_block = [] + for i in [-2, -3]: + in_channels = out_ch_list[i + 1] + out_ch_list[i] + + self.base_block.append( + self.add_sublayer( + "F_{}_base_block_0".format(i), + nn.Conv2D( + in_channels=in_channels, + out_channels=out_ch_list[i], + kernel_size=1, + weight_attr=ParamAttr(trainable=True), + bias_attr=ParamAttr(trainable=True)))) + self.base_block.append( + self.add_sublayer( + "F_{}_base_block_1".format(i), + nn.Conv2D( + in_channels=out_ch_list[i], + out_channels=out_ch_list[i], + kernel_size=3, + padding=1, + weight_attr=ParamAttr(trainable=True), + bias_attr=ParamAttr(trainable=True)))) + self.base_block.append( + self.add_sublayer( + "F_{}_base_block_2".format(i), + nn.BatchNorm( + num_channels=out_ch_list[i], + act="relu", + param_attr=ParamAttr(trainable=True), + bias_attr=ParamAttr(trainable=True)))) + self.base_block.append( + self.add_sublayer( + "F_{}_base_block_3".format(i), + nn.Conv2D( + in_channels=out_ch_list[i], + out_channels=512, + kernel_size=1, + bias_attr=ParamAttr(trainable=True), + weight_attr=ParamAttr(trainable=True)))) + self.out_channels = 512 + + def __call__(self, x): + x = self.conv(x) + fpn_list = [] + F = [] + for i in range(len(self.depth)): + fpn_list.append(np.sum(self.depth[:i + 1])) + + for i, block in enumerate(self.block_list): + x = block(x) + for number in fpn_list: + if i + 1 == number: + F.append(x) + base = F[-1] + + j = 0 + for i, block in enumerate(self.base_block): + if i % 3 == 0 and i < 6: + j = j + 1 + b, c, w, h = F[-j - 1].shape + if [w, h] == list(base.shape[2:]): + base = base + else: + base = self.conv_trans[j - 1](base) + base = self.bn_block[j - 1](base) + base = paddle.concat([base, F[-j - 1]], axis=1) + base = block(base) + return base + + +class ConvBNLayer(nn.Layer): + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + groups=1, + act=None, + name=None): + super(ConvBNLayer, self).__init__() + self.conv = nn.Conv2D( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=2 if stride == (1, 1) else kernel_size, + dilation=2 if stride == (1, 1) else 1, + stride=stride, + padding=(kernel_size - 1) // 2, + groups=groups, + weight_attr=ParamAttr(name=name + '.conv2d.output.1.w_0'), + bias_attr=False, ) + + if name == "conv1": + bn_name = "bn_" + name + else: + bn_name = "bn" + name[3:] + self.bn = nn.BatchNorm( + num_channels=out_channels, + act=act, + param_attr=ParamAttr(name=name + '.output.1.w_0'), + bias_attr=ParamAttr(name=name + '.output.1.b_0'), + moving_mean_name=bn_name + "_mean", + moving_variance_name=bn_name + "_variance") + + def __call__(self, x): + x = self.conv(x) + x = self.bn(x) + return x + + +class ShortCut(nn.Layer): + def __init__(self, in_channels, out_channels, stride, name, is_first=False): + super(ShortCut, self).__init__() + self.use_conv = True + + if in_channels != out_channels or stride != 1 or is_first == True: + if stride == (1, 1): + self.conv = ConvBNLayer( + in_channels, out_channels, 1, 1, name=name) + else: # stride==(2,2) + self.conv = ConvBNLayer( + in_channels, out_channels, 1, stride, name=name) + else: + self.use_conv = False + + def forward(self, x): + if self.use_conv: + x = self.conv(x) + return x + + +class BottleneckBlock(nn.Layer): + def __init__(self, in_channels, out_channels, stride, name): + super(BottleneckBlock, self).__init__() + self.conv0 = ConvBNLayer( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + act='relu', + name=name + "_branch2a") + self.conv1 = ConvBNLayer( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=3, + stride=stride, + act='relu', + name=name + "_branch2b") + + self.conv2 = ConvBNLayer( + in_channels=out_channels, + out_channels=out_channels * 4, + kernel_size=1, + act=None, + name=name + "_branch2c") + + self.short = ShortCut( + in_channels=in_channels, + out_channels=out_channels * 4, + stride=stride, + is_first=False, + name=name + "_branch1") + self.out_channels = out_channels * 4 + + def forward(self, x): + y = self.conv0(x) + y = self.conv1(y) + y = self.conv2(y) + y = y + self.short(x) + y = F.relu(y) + return y + + +class BasicBlock(nn.Layer): + def __init__(self, in_channels, out_channels, stride, name, is_first): + super(BasicBlock, self).__init__() + self.conv0 = ConvBNLayer( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + act='relu', + stride=stride, + name=name + "_branch2a") + self.conv1 = ConvBNLayer( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=3, + act=None, + name=name + "_branch2b") + self.short = ShortCut( + in_channels=in_channels, + out_channels=out_channels, + stride=stride, + is_first=is_first, + name=name + "_branch1") + self.out_channels = out_channels + + def forward(self, x): + y = self.conv0(x) + y = self.conv1(y) + y = y + self.short(x) + return F.relu(y) diff --git a/ppocr/modeling/heads/__init__.py b/ppocr/modeling/heads/__init__.py index 78074709..1a39ca41 100755 --- a/ppocr/modeling/heads/__init__.py +++ b/ppocr/modeling/heads/__init__.py @@ -23,10 +23,13 @@ def build_head(config): # rec head from .rec_ctc_head import CTCHead + from .rec_srn_head import SRNHead # cls head from .cls_head import ClsHead - support_dict = ['DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead'] + support_dict = [ + 'DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead', 'SRNHead' + ] module_name = config.pop('name') assert module_name in support_dict, Exception('head only support {}'.format( diff --git a/ppocr/modeling/heads/rec_srn_head.py b/ppocr/modeling/heads/rec_srn_head.py new file mode 100644 index 00000000..8aaf65e1 --- /dev/null +++ b/ppocr/modeling/heads/rec_srn_head.py @@ -0,0 +1,279 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math +import paddle +from paddle import nn, ParamAttr +from paddle.nn import functional as F +import paddle.fluid as fluid +import numpy as np +from .self_attention import WrapEncoderForFeature +from .self_attention import WrapEncoder +from paddle.static import Program +from ppocr.modeling.backbones.rec_resnet_fpn import ResNetFPN +import paddle.fluid.framework as framework + +from collections import OrderedDict +gradient_clip = 10 + + +class PVAM(nn.Layer): + def __init__(self, in_channels, char_num, max_text_length, num_heads, + num_encoder_tus, hidden_dims): + super(PVAM, self).__init__() + self.char_num = char_num + self.max_length = max_text_length + self.num_heads = num_heads + self.num_encoder_TUs = num_encoder_tus + self.hidden_dims = hidden_dims + # Transformer encoder + t = 256 + c = 512 + self.wrap_encoder_for_feature = WrapEncoderForFeature( + src_vocab_size=1, + max_length=t, + n_layer=self.num_encoder_TUs, + n_head=self.num_heads, + d_key=int(self.hidden_dims / self.num_heads), + d_value=int(self.hidden_dims / self.num_heads), + d_model=self.hidden_dims, + d_inner_hid=self.hidden_dims, + prepostprocess_dropout=0.1, + attention_dropout=0.1, + relu_dropout=0.1, + preprocess_cmd="n", + postprocess_cmd="da", + weight_sharing=True) + + # PVAM + self.flatten0 = paddle.nn.Flatten(start_axis=0, stop_axis=1) + self.fc0 = paddle.nn.Linear( + in_features=in_channels, + out_features=in_channels, ) + self.emb = paddle.nn.Embedding( + num_embeddings=self.max_length, embedding_dim=in_channels) + self.flatten1 = paddle.nn.Flatten(start_axis=0, stop_axis=2) + self.fc1 = paddle.nn.Linear( + in_features=in_channels, out_features=1, bias_attr=False) + + def forward(self, inputs, encoder_word_pos, gsrm_word_pos): + b, c, h, w = inputs.shape + conv_features = paddle.reshape(inputs, shape=[-1, c, h * w]) + conv_features = paddle.transpose(conv_features, perm=[0, 2, 1]) + # transformer encoder + b, t, c = conv_features.shape + + enc_inputs = [conv_features, encoder_word_pos, None] + word_features = self.wrap_encoder_for_feature(enc_inputs) + + # pvam + b, t, c = word_features.shape + word_features = self.fc0(word_features) + word_features_ = paddle.reshape(word_features, [-1, 1, t, c]) + word_features_ = paddle.tile(word_features_, [1, self.max_length, 1, 1]) + word_pos_feature = self.emb(gsrm_word_pos) + word_pos_feature_ = paddle.reshape(word_pos_feature, + [-1, self.max_length, 1, c]) + word_pos_feature_ = paddle.tile(word_pos_feature_, [1, 1, t, 1]) + y = word_pos_feature_ + word_features_ + y = F.tanh(y) + attention_weight = self.fc1(y) + attention_weight = paddle.reshape( + attention_weight, shape=[-1, self.max_length, t]) + attention_weight = F.softmax(attention_weight, axis=-1) + pvam_features = paddle.matmul(attention_weight, + word_features) #[b, max_length, c] + return pvam_features + + +class GSRM(nn.Layer): + def __init__(self, in_channels, char_num, max_text_length, num_heads, + num_encoder_tus, num_decoder_tus, hidden_dims): + super(GSRM, self).__init__() + self.char_num = char_num + self.max_length = max_text_length + self.num_heads = num_heads + self.num_encoder_TUs = num_encoder_tus + self.num_decoder_TUs = num_decoder_tus + self.hidden_dims = hidden_dims + + self.fc0 = paddle.nn.Linear( + in_features=in_channels, out_features=self.char_num) + self.wrap_encoder0 = WrapEncoder( + src_vocab_size=self.char_num + 1, + max_length=self.max_length, + n_layer=self.num_decoder_TUs, + n_head=self.num_heads, + d_key=int(self.hidden_dims / self.num_heads), + d_value=int(self.hidden_dims / self.num_heads), + d_model=self.hidden_dims, + d_inner_hid=self.hidden_dims, + prepostprocess_dropout=0.1, + attention_dropout=0.1, + relu_dropout=0.1, + preprocess_cmd="n", + postprocess_cmd="da", + weight_sharing=True) + + self.wrap_encoder1 = WrapEncoder( + src_vocab_size=self.char_num + 1, + max_length=self.max_length, + n_layer=self.num_decoder_TUs, + n_head=self.num_heads, + d_key=int(self.hidden_dims / self.num_heads), + d_value=int(self.hidden_dims / self.num_heads), + d_model=self.hidden_dims, + d_inner_hid=self.hidden_dims, + prepostprocess_dropout=0.1, + attention_dropout=0.1, + relu_dropout=0.1, + preprocess_cmd="n", + postprocess_cmd="da", + weight_sharing=True) + + self.mul = lambda x: paddle.matmul(x=x, + y=self.wrap_encoder0.prepare_decoder.emb0.weight, + transpose_y=True) + + def forward(self, inputs, gsrm_word_pos, gsrm_slf_attn_bias1, + gsrm_slf_attn_bias2): + # ===== GSRM Visual-to-semantic embedding block ===== + b, t, c = inputs.shape + pvam_features = paddle.reshape(inputs, [-1, c]) + word_out = self.fc0(pvam_features) + word_ids = paddle.argmax(F.softmax(word_out), axis=1) + word_ids = paddle.reshape(x=word_ids, shape=[-1, t, 1]) + + #===== GSRM Semantic reasoning block ===== + """ + This module is achieved through bi-transformers, + ngram_feature1 is the froward one, ngram_fetaure2 is the backward one + """ + pad_idx = self.char_num + + word1 = paddle.cast(word_ids, "float32") + word1 = F.pad(word1, [1, 0], value=1.0 * pad_idx, data_format="NLC") + word1 = paddle.cast(word1, "int64") + word1 = word1[:, :-1, :] + word2 = word_ids + + enc_inputs_1 = [word1, gsrm_word_pos, gsrm_slf_attn_bias1] + enc_inputs_2 = [word2, gsrm_word_pos, gsrm_slf_attn_bias2] + + gsrm_feature1 = self.wrap_encoder0(enc_inputs_1) + gsrm_feature2 = self.wrap_encoder1(enc_inputs_2) + + gsrm_feature2 = F.pad(gsrm_feature2, [0, 1], + value=0., + data_format="NLC") + gsrm_feature2 = gsrm_feature2[:, 1:, ] + gsrm_features = gsrm_feature1 + gsrm_feature2 + + gsrm_out = self.mul(gsrm_features) + + b, t, c = gsrm_out.shape + gsrm_out = paddle.reshape(gsrm_out, [-1, c]) + + return gsrm_features, word_out, gsrm_out + + +class VSFD(nn.Layer): + def __init__(self, in_channels=512, pvam_ch=512, char_num=38): + super(VSFD, self).__init__() + self.char_num = char_num + self.fc0 = paddle.nn.Linear( + in_features=in_channels * 2, out_features=pvam_ch) + self.fc1 = paddle.nn.Linear( + in_features=pvam_ch, out_features=self.char_num) + + def forward(self, pvam_feature, gsrm_feature): + b, t, c1 = pvam_feature.shape + b, t, c2 = gsrm_feature.shape + combine_feature_ = paddle.concat([pvam_feature, gsrm_feature], axis=2) + img_comb_feature_ = paddle.reshape( + combine_feature_, shape=[-1, c1 + c2]) + img_comb_feature_map = self.fc0(img_comb_feature_) + img_comb_feature_map = F.sigmoid(img_comb_feature_map) + img_comb_feature_map = paddle.reshape( + img_comb_feature_map, shape=[-1, t, c1]) + combine_feature = img_comb_feature_map * pvam_feature + ( + 1.0 - img_comb_feature_map) * gsrm_feature + img_comb_feature = paddle.reshape(combine_feature, shape=[-1, c1]) + + out = self.fc1(img_comb_feature) + return out + + +class SRNHead(nn.Layer): + def __init__(self, in_channels, out_channels, max_text_length, num_heads, + num_encoder_TUs, num_decoder_TUs, hidden_dims, **kwargs): + super(SRNHead, self).__init__() + self.char_num = out_channels + self.max_length = max_text_length + self.num_heads = num_heads + self.num_encoder_TUs = num_encoder_TUs + self.num_decoder_TUs = num_decoder_TUs + self.hidden_dims = hidden_dims + + self.pvam = PVAM( + in_channels=in_channels, + char_num=self.char_num, + max_text_length=self.max_length, + num_heads=self.num_heads, + num_encoder_tus=self.num_encoder_TUs, + hidden_dims=self.hidden_dims) + + self.gsrm = GSRM( + in_channels=in_channels, + char_num=self.char_num, + max_text_length=self.max_length, + num_heads=self.num_heads, + num_encoder_tus=self.num_encoder_TUs, + num_decoder_tus=self.num_decoder_TUs, + hidden_dims=self.hidden_dims) + self.vsfd = VSFD(in_channels=in_channels) + + self.gsrm.wrap_encoder1.prepare_decoder.emb0 = self.gsrm.wrap_encoder0.prepare_decoder.emb0 + + def forward(self, inputs, others): + encoder_word_pos = others[0] + gsrm_word_pos = others[1] + gsrm_slf_attn_bias1 = others[2] + gsrm_slf_attn_bias2 = others[3] + + pvam_feature = self.pvam(inputs, encoder_word_pos, gsrm_word_pos) + + gsrm_feature, word_out, gsrm_out = self.gsrm( + pvam_feature, gsrm_word_pos, gsrm_slf_attn_bias1, + gsrm_slf_attn_bias2) + + final_out = self.vsfd(pvam_feature, gsrm_feature) + if not self.training: + final_out = F.softmax(final_out, axis=1) + + _, decoded_out = paddle.topk(final_out, k=1) + + predicts = OrderedDict([ + ('predict', final_out), + ('pvam_feature', pvam_feature), + ('decoded_out', decoded_out), + ('word_out', word_out), + ('gsrm_out', gsrm_out), + ]) + + return predicts diff --git a/ppocr/modeling/heads/self_attention.py b/ppocr/modeling/heads/self_attention.py new file mode 100644 index 00000000..6aeb8f0c --- /dev/null +++ b/ppocr/modeling/heads/self_attention.py @@ -0,0 +1,408 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math + +import paddle +from paddle import ParamAttr, nn +from paddle import nn, ParamAttr +from paddle.nn import functional as F +import paddle.fluid as fluid +import numpy as np +gradient_clip = 10 + + +class WrapEncoderForFeature(nn.Layer): + def __init__(self, + src_vocab_size, + max_length, + n_layer, + n_head, + d_key, + d_value, + d_model, + d_inner_hid, + prepostprocess_dropout, + attention_dropout, + relu_dropout, + preprocess_cmd, + postprocess_cmd, + weight_sharing, + bos_idx=0): + super(WrapEncoderForFeature, self).__init__() + + self.prepare_encoder = PrepareEncoder( + src_vocab_size, + d_model, + max_length, + prepostprocess_dropout, + bos_idx=bos_idx, + word_emb_param_name="src_word_emb_table") + self.encoder = Encoder(n_layer, n_head, d_key, d_value, d_model, + d_inner_hid, prepostprocess_dropout, + attention_dropout, relu_dropout, preprocess_cmd, + postprocess_cmd) + + def forward(self, enc_inputs): + conv_features, src_pos, src_slf_attn_bias = enc_inputs + enc_input = self.prepare_encoder(conv_features, src_pos) + enc_output = self.encoder(enc_input, src_slf_attn_bias) + return enc_output + + +class WrapEncoder(nn.Layer): + """ + embedder + encoder + """ + + def __init__(self, + src_vocab_size, + max_length, + n_layer, + n_head, + d_key, + d_value, + d_model, + d_inner_hid, + prepostprocess_dropout, + attention_dropout, + relu_dropout, + preprocess_cmd, + postprocess_cmd, + weight_sharing, + bos_idx=0): + super(WrapEncoder, self).__init__() + + self.prepare_decoder = PrepareDecoder( + src_vocab_size, + d_model, + max_length, + prepostprocess_dropout, + bos_idx=bos_idx) + self.encoder = Encoder(n_layer, n_head, d_key, d_value, d_model, + d_inner_hid, prepostprocess_dropout, + attention_dropout, relu_dropout, preprocess_cmd, + postprocess_cmd) + + def forward(self, enc_inputs): + src_word, src_pos, src_slf_attn_bias = enc_inputs + enc_input = self.prepare_decoder(src_word, src_pos) + enc_output = self.encoder(enc_input, src_slf_attn_bias) + return enc_output + + +class Encoder(nn.Layer): + """ + encoder + """ + + def __init__(self, + n_layer, + n_head, + d_key, + d_value, + d_model, + d_inner_hid, + prepostprocess_dropout, + attention_dropout, + relu_dropout, + preprocess_cmd="n", + postprocess_cmd="da"): + + super(Encoder, self).__init__() + + self.encoder_layers = list() + for i in range(n_layer): + self.encoder_layers.append( + self.add_sublayer( + "layer_%d" % i, + EncoderLayer(n_head, d_key, d_value, d_model, d_inner_hid, + prepostprocess_dropout, attention_dropout, + relu_dropout, preprocess_cmd, + postprocess_cmd))) + self.processer = PrePostProcessLayer(preprocess_cmd, d_model, + prepostprocess_dropout) + + def forward(self, enc_input, attn_bias): + for encoder_layer in self.encoder_layers: + enc_output = encoder_layer(enc_input, attn_bias) + enc_input = enc_output + enc_output = self.processer(enc_output) + return enc_output + + +class EncoderLayer(nn.Layer): + """ + EncoderLayer + """ + + def __init__(self, + n_head, + d_key, + d_value, + d_model, + d_inner_hid, + prepostprocess_dropout, + attention_dropout, + relu_dropout, + preprocess_cmd="n", + postprocess_cmd="da"): + + super(EncoderLayer, self).__init__() + self.preprocesser1 = PrePostProcessLayer(preprocess_cmd, d_model, + prepostprocess_dropout) + self.self_attn = MultiHeadAttention(d_key, d_value, d_model, n_head, + attention_dropout) + self.postprocesser1 = PrePostProcessLayer(postprocess_cmd, d_model, + prepostprocess_dropout) + + self.preprocesser2 = PrePostProcessLayer(preprocess_cmd, d_model, + prepostprocess_dropout) + self.ffn = FFN(d_inner_hid, d_model, relu_dropout) + self.postprocesser2 = PrePostProcessLayer(postprocess_cmd, d_model, + prepostprocess_dropout) + + def forward(self, enc_input, attn_bias): + attn_output = self.self_attn( + self.preprocesser1(enc_input), None, None, attn_bias) + attn_output = self.postprocesser1(attn_output, enc_input) + ffn_output = self.ffn(self.preprocesser2(attn_output)) + ffn_output = self.postprocesser2(ffn_output, attn_output) + return ffn_output + + +class MultiHeadAttention(nn.Layer): + """ + Multi-Head Attention + """ + + def __init__(self, d_key, d_value, d_model, n_head=1, dropout_rate=0.): + super(MultiHeadAttention, self).__init__() + self.n_head = n_head + self.d_key = d_key + self.d_value = d_value + self.d_model = d_model + self.dropout_rate = dropout_rate + self.q_fc = paddle.nn.Linear( + in_features=d_model, out_features=d_key * n_head, bias_attr=False) + self.k_fc = paddle.nn.Linear( + in_features=d_model, out_features=d_key * n_head, bias_attr=False) + self.v_fc = paddle.nn.Linear( + in_features=d_model, out_features=d_value * n_head, bias_attr=False) + self.proj_fc = paddle.nn.Linear( + in_features=d_value * n_head, out_features=d_model, bias_attr=False) + + def _prepare_qkv(self, queries, keys, values, cache=None): + if keys is None: # self-attention + keys, values = queries, queries + static_kv = False + else: # cross-attention + static_kv = True + + q = self.q_fc(queries) + q = paddle.reshape(x=q, shape=[0, 0, self.n_head, self.d_key]) + q = paddle.transpose(x=q, perm=[0, 2, 1, 3]) + + if cache is not None and static_kv and "static_k" in cache: + # for encoder-decoder attention in inference and has cached + k = cache["static_k"] + v = cache["static_v"] + else: + k = self.k_fc(keys) + v = self.v_fc(values) + k = paddle.reshape(x=k, shape=[0, 0, self.n_head, self.d_key]) + k = paddle.transpose(x=k, perm=[0, 2, 1, 3]) + v = paddle.reshape(x=v, shape=[0, 0, self.n_head, self.d_value]) + v = paddle.transpose(x=v, perm=[0, 2, 1, 3]) + + if cache is not None: + if static_kv and not "static_k" in cache: + # for encoder-decoder attention in inference and has not cached + cache["static_k"], cache["static_v"] = k, v + elif not static_kv: + # for decoder self-attention in inference + cache_k, cache_v = cache["k"], cache["v"] + k = paddle.concat([cache_k, k], axis=2) + v = paddle.concat([cache_v, v], axis=2) + cache["k"], cache["v"] = k, v + + return q, k, v + + def forward(self, queries, keys, values, attn_bias, cache=None): + # compute q ,k ,v + keys = queries if keys is None else keys + values = keys if values is None else values + q, k, v = self._prepare_qkv(queries, keys, values, cache) + + # scale dot product attention + product = paddle.matmul(x=q, y=k, transpose_y=True) + product = product * self.d_model**-0.5 + if attn_bias is not None: + product += attn_bias + weights = F.softmax(product) + if self.dropout_rate: + weights = F.dropout( + weights, p=self.dropout_rate, mode="downscale_in_infer") + out = paddle.matmul(weights, v) + + # combine heads + out = paddle.transpose(out, perm=[0, 2, 1, 3]) + out = paddle.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]]) + + # project to output + out = self.proj_fc(out) + + return out + + +class PrePostProcessLayer(nn.Layer): + """ + PrePostProcessLayer + """ + + def __init__(self, process_cmd, d_model, dropout_rate): + super(PrePostProcessLayer, self).__init__() + self.process_cmd = process_cmd + self.functors = [] + for cmd in self.process_cmd: + if cmd == "a": # add residual connection + self.functors.append(lambda x, y: x + y if y is not None else x) + elif cmd == "n": # add layer normalization + self.functors.append( + self.add_sublayer( + "layer_norm_%d" % len( + self.sublayers(include_sublayers=False)), + paddle.nn.LayerNorm( + normalized_shape=d_model, + weight_attr=fluid.ParamAttr( + initializer=fluid.initializer.Constant(1.)), + bias_attr=fluid.ParamAttr( + initializer=fluid.initializer.Constant(0.))))) + elif cmd == "d": # add dropout + self.functors.append(lambda x: F.dropout( + x, p=dropout_rate, mode="downscale_in_infer") + if dropout_rate else x) + + def forward(self, x, residual=None): + for i, cmd in enumerate(self.process_cmd): + if cmd == "a": + x = self.functors[i](x, residual) + else: + x = self.functors[i](x) + return x + + +class PrepareEncoder(nn.Layer): + def __init__(self, + src_vocab_size, + src_emb_dim, + src_max_len, + dropout_rate=0, + bos_idx=0, + word_emb_param_name=None, + pos_enc_param_name=None): + super(PrepareEncoder, self).__init__() + self.src_emb_dim = src_emb_dim + self.src_max_len = src_max_len + self.emb = paddle.nn.Embedding( + num_embeddings=self.src_max_len, + embedding_dim=self.src_emb_dim, + sparse=True) + self.dropout_rate = dropout_rate + + def forward(self, src_word, src_pos): + src_word_emb = src_word + src_word_emb = fluid.layers.cast(src_word_emb, 'float32') + src_word_emb = paddle.scale(x=src_word_emb, scale=self.src_emb_dim**0.5) + src_pos = paddle.squeeze(src_pos, axis=-1) + src_pos_enc = self.emb(src_pos) + src_pos_enc.stop_gradient = True + enc_input = src_word_emb + src_pos_enc + if self.dropout_rate: + out = F.dropout( + x=enc_input, p=self.dropout_rate, mode="downscale_in_infer") + else: + out = enc_input + return out + + +class PrepareDecoder(nn.Layer): + def __init__(self, + src_vocab_size, + src_emb_dim, + src_max_len, + dropout_rate=0, + bos_idx=0, + word_emb_param_name=None, + pos_enc_param_name=None): + super(PrepareDecoder, self).__init__() + self.src_emb_dim = src_emb_dim + """ + self.emb0 = Embedding(num_embeddings=src_vocab_size, + embedding_dim=src_emb_dim) + """ + self.emb0 = paddle.nn.Embedding( + num_embeddings=src_vocab_size, + embedding_dim=self.src_emb_dim, + weight_attr=paddle.ParamAttr( + name=word_emb_param_name, + initializer=nn.initializer.Normal(0., src_emb_dim**-0.5))) + self.emb1 = paddle.nn.Embedding( + num_embeddings=src_max_len, + embedding_dim=self.src_emb_dim, + weight_attr=paddle.ParamAttr(name=pos_enc_param_name)) + self.dropout_rate = dropout_rate + + def forward(self, src_word, src_pos): + src_word = fluid.layers.cast(src_word, 'int64') + src_word = paddle.squeeze(src_word, axis=-1) + src_word_emb = self.emb0(src_word) + src_word_emb = paddle.scale(x=src_word_emb, scale=self.src_emb_dim**0.5) + src_pos = paddle.squeeze(src_pos, axis=-1) + src_pos_enc = self.emb1(src_pos) + src_pos_enc.stop_gradient = True + enc_input = src_word_emb + src_pos_enc + if self.dropout_rate: + out = F.dropout( + x=enc_input, p=self.dropout_rate, mode="downscale_in_infer") + else: + out = enc_input + return out + + +class FFN(nn.Layer): + """ + Feed-Forward Network + """ + + def __init__(self, d_inner_hid, d_model, dropout_rate): + super(FFN, self).__init__() + self.dropout_rate = dropout_rate + self.fc1 = paddle.nn.Linear( + in_features=d_model, out_features=d_inner_hid) + self.fc2 = paddle.nn.Linear( + in_features=d_inner_hid, out_features=d_model) + + def forward(self, x): + hidden = self.fc1(x) + hidden = F.relu(hidden) + if self.dropout_rate: + hidden = F.dropout( + hidden, p=self.dropout_rate, mode="downscale_in_infer") + out = self.fc2(hidden) + return out diff --git a/ppocr/postprocess/__init__.py b/ppocr/postprocess/__init__.py index c9b42e08..0156e438 100644 --- a/ppocr/postprocess/__init__.py +++ b/ppocr/postprocess/__init__.py @@ -26,11 +26,12 @@ def build_post_process(config, global_config=None): from .db_postprocess import DBPostProcess from .east_postprocess import EASTPostProcess from .sast_postprocess import SASTPostProcess - from .rec_postprocess import CTCLabelDecode, AttnLabelDecode + from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode from .cls_postprocess import ClsPostProcess support_dict = [ - 'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode', 'AttnLabelDecode', 'ClsPostProcess' + 'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode', + 'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode' ] config = copy.deepcopy(config) diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py index a18e101b..c2303cea 100644 --- a/ppocr/postprocess/rec_postprocess.py +++ b/ppocr/postprocess/rec_postprocess.py @@ -29,6 +29,9 @@ class BaseRecLabelDecode(object): assert character_type in support_character_type, "Only {} are supported now but get {}".format( support_character_type, character_type) + self.beg_str = "sos" + self.end_str = "eos" + if character_type == "en": self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz" dict_character = list(self.character_str) @@ -104,7 +107,6 @@ class CTCLabelDecode(BaseRecLabelDecode): def __call__(self, preds, label=None, *args, **kwargs): if isinstance(preds, paddle.Tensor): preds = preds.numpy() - preds_idx = preds.argmax(axis=2) preds_prob = preds.max(axis=2) text = self.decode(preds_idx, preds_prob) @@ -153,3 +155,83 @@ class AttnLabelDecode(BaseRecLabelDecode): assert False, "unsupport type %s in get_beg_end_flag_idx" \ % beg_or_end return idx + + +class SRNLabelDecode(BaseRecLabelDecode): + """ Convert between text-label and text-index """ + + def __init__(self, + character_dict_path=None, + character_type='en', + use_space_char=False, + **kwargs): + super(SRNLabelDecode, self).__init__(character_dict_path, + character_type, use_space_char) + + def __call__(self, preds, label=None, *args, **kwargs): + pred = preds['predict'] + char_num = len(self.character_str) + 2 + if isinstance(pred, paddle.Tensor): + pred = pred.numpy() + pred = np.reshape(pred, [-1, char_num]) + + preds_idx = np.argmax(pred, axis=1) + preds_prob = np.max(pred, axis=1) + + preds_idx = np.reshape(preds_idx, [-1, 25]) + + preds_prob = np.reshape(preds_prob, [-1, 25]) + + text = self.decode(preds_idx, preds_prob) + + if label is None: + return text + label = self.decode(label, is_remove_duplicate=False) + return text, label + + def decode(self, text_index, text_prob=None, is_remove_duplicate=True): + """ convert text-index into text-label. """ + result_list = [] + ignored_tokens = self.get_ignored_tokens() + batch_size = len(text_index) + + for batch_idx in range(batch_size): + char_list = [] + conf_list = [] + for idx in range(len(text_index[batch_idx])): + if text_index[batch_idx][idx] in ignored_tokens: + continue + if is_remove_duplicate: + # only for predict + if idx > 0 and text_index[batch_idx][idx - 1] == text_index[ + batch_idx][idx]: + continue + char_list.append(self.character[int(text_index[batch_idx][ + idx])]) + if text_prob is not None: + conf_list.append(text_prob[batch_idx][idx]) + else: + conf_list.append(1) + + text = ''.join(char_list) + result_list.append((text, np.mean(conf_list))) + return result_list + + def add_special_char(self, dict_character): + dict_character = dict_character + [self.beg_str, self.end_str] + return dict_character + + def get_ignored_tokens(self): + beg_idx = self.get_beg_end_flag_idx("beg") + end_idx = self.get_beg_end_flag_idx("end") + return [beg_idx, end_idx] + + def get_beg_end_flag_idx(self, beg_or_end): + if beg_or_end == "beg": + idx = np.array(self.dict[self.beg_str]) + elif beg_or_end == "end": + idx = np.array(self.dict[self.end_str]) + else: + assert False, "unsupport type %s in get_beg_end_flag_idx" \ + % beg_or_end + return idx diff --git a/tools/export_model.py b/tools/export_model.py index 74357d58..58dc0def 100755 --- a/tools/export_model.py +++ b/tools/export_model.py @@ -31,6 +31,14 @@ from ppocr.utils.logging import get_logger from tools.program import load_config, merge_config, ArgsParser +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("-c", "--config", help="configuration file to use") + parser.add_argument( + "-o", "--output_path", type=str, default='./output/infer/') + return parser.parse_args() + + def main(): FLAGS = ArgsParser().parse_args() config = load_config(FLAGS.config) @@ -51,14 +59,33 @@ def main(): model.eval() save_path = '{}/inference'.format(config['Global']['save_inference_dir']) - infer_shape = [3, 32, 100] if config['Architecture'][ - 'model_type'] != "det" else [3, 640, 640] - model = to_static( - model, - input_spec=[ + + if config['Architecture']['algorithm'] == "SRN": + other_shape = [ paddle.static.InputSpec( - shape=[None] + infer_shape, dtype='float32') - ]) + shape=[None, 1, 64, 256], dtype='float32'), [ + paddle.static.InputSpec( + shape=[None, 256, 1], + dtype="int64"), paddle.static.InputSpec( + shape=[None, 25, 1], + dtype="int64"), paddle.static.InputSpec( + shape=[None, 8, 25, 25], dtype="int64"), + paddle.static.InputSpec( + shape=[None, 8, 25, 25], dtype="int64") + ] + ] + model = to_static(model, input_spec=other_shape) + + else: + infer_shape = [3, 32, 100] if config['Architecture'][ + 'model_type'] != "det" else [3, 640, 640] + model = to_static( + model, + input_spec=[ + paddle.static.InputSpec( + shape=[None] + infer_shape, dtype='float32') + ]) + paddle.jit.save(model, save_path) logger.info('inference model is saved to {}'.format(save_path)) diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py index 974fdbb6..fd895e50 100755 --- a/tools/infer/predict_rec.py +++ b/tools/infer/predict_rec.py @@ -25,6 +25,7 @@ import numpy as np import math import time import traceback +import paddle import tools.infer.utility as utility from ppocr.postprocess import build_post_process @@ -46,6 +47,13 @@ class TextRecognizer(object): "character_dict_path": args.rec_char_dict_path, "use_space_char": args.use_space_char } + if self.rec_algorithm == "SRN": + postprocess_params = { + 'name': 'SRNLabelDecode', + "character_type": args.rec_char_type, + "character_dict_path": args.rec_char_dict_path, + "use_space_char": args.use_space_char + } self.postprocess_op = build_post_process(postprocess_params) self.predictor, self.input_tensor, self.output_tensors = \ utility.create_predictor(args, 'rec', logger) @@ -70,6 +78,78 @@ class TextRecognizer(object): padding_im[:, :, 0:resized_w] = resized_image return padding_im + def resize_norm_img_srn(self, img, image_shape): + imgC, imgH, imgW = image_shape + + img_black = np.zeros((imgH, imgW)) + im_hei = img.shape[0] + im_wid = img.shape[1] + + if im_wid <= im_hei * 1: + img_new = cv2.resize(img, (imgH * 1, imgH)) + elif im_wid <= im_hei * 2: + img_new = cv2.resize(img, (imgH * 2, imgH)) + elif im_wid <= im_hei * 3: + img_new = cv2.resize(img, (imgH * 3, imgH)) + else: + img_new = cv2.resize(img, (imgW, imgH)) + + img_np = np.asarray(img_new) + img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY) + img_black[:, 0:img_np.shape[1]] = img_np + img_black = img_black[:, :, np.newaxis] + + row, col, c = img_black.shape + c = 1 + + return np.reshape(img_black, (c, row, col)).astype(np.float32) + + def srn_other_inputs(self, image_shape, num_heads, max_text_length): + + imgC, imgH, imgW = image_shape + feature_dim = int((imgH / 8) * (imgW / 8)) + + encoder_word_pos = np.array(range(0, feature_dim)).reshape( + (feature_dim, 1)).astype('int64') + gsrm_word_pos = np.array(range(0, max_text_length)).reshape( + (max_text_length, 1)).astype('int64') + + gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length)) + gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape( + [-1, 1, max_text_length, max_text_length]) + gsrm_slf_attn_bias1 = np.tile( + gsrm_slf_attn_bias1, + [1, num_heads, 1, 1]).astype('float32') * [-1e9] + + gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape( + [-1, 1, max_text_length, max_text_length]) + gsrm_slf_attn_bias2 = np.tile( + gsrm_slf_attn_bias2, + [1, num_heads, 1, 1]).astype('float32') * [-1e9] + + encoder_word_pos = encoder_word_pos[np.newaxis, :] + gsrm_word_pos = gsrm_word_pos[np.newaxis, :] + + return [ + encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, + gsrm_slf_attn_bias2 + ] + + def process_image_srn(self, img, image_shape, num_heads, max_text_length): + norm_img = self.resize_norm_img_srn(img, image_shape) + norm_img = norm_img[np.newaxis, :] + + [encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] = \ + self.srn_other_inputs(image_shape, num_heads, max_text_length) + + gsrm_slf_attn_bias1 = gsrm_slf_attn_bias1.astype(np.float32) + gsrm_slf_attn_bias2 = gsrm_slf_attn_bias2.astype(np.float32) + encoder_word_pos = encoder_word_pos.astype(np.int64) + gsrm_word_pos = gsrm_word_pos.astype(np.int64) + + return (norm_img, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, + gsrm_slf_attn_bias2) + def __call__(self, img_list): img_num = len(img_list) # Calculate the aspect ratio of all text bars @@ -93,21 +173,64 @@ class TextRecognizer(object): wh_ratio = w * 1.0 / h max_wh_ratio = max(max_wh_ratio, wh_ratio) for ino in range(beg_img_no, end_img_no): - # norm_img = self.resize_norm_img(img_list[ino], max_wh_ratio) - norm_img = self.resize_norm_img(img_list[indices[ino]], - max_wh_ratio) - norm_img = norm_img[np.newaxis, :] - norm_img_batch.append(norm_img) + if self.rec_algorithm != "SRN": + norm_img = self.resize_norm_img(img_list[indices[ino]], + max_wh_ratio) + norm_img = norm_img[np.newaxis, :] + norm_img_batch.append(norm_img) + else: + norm_img = self.process_image_srn( + img_list[indices[ino]], self.rec_image_shape, 8, 25) + encoder_word_pos_list = [] + gsrm_word_pos_list = [] + gsrm_slf_attn_bias1_list = [] + gsrm_slf_attn_bias2_list = [] + encoder_word_pos_list.append(norm_img[1]) + gsrm_word_pos_list.append(norm_img[2]) + gsrm_slf_attn_bias1_list.append(norm_img[3]) + gsrm_slf_attn_bias2_list.append(norm_img[4]) + norm_img_batch.append(norm_img[0]) norm_img_batch = np.concatenate(norm_img_batch) norm_img_batch = norm_img_batch.copy() - starttime = time.time() - self.input_tensor.copy_from_cpu(norm_img_batch) - self.predictor.run() - outputs = [] - for output_tensor in self.output_tensors: - output = output_tensor.copy_to_cpu() - outputs.append(output) - preds = outputs[0] + + if self.rec_algorithm == "SRN": + starttime = time.time() + encoder_word_pos_list = np.concatenate(encoder_word_pos_list) + gsrm_word_pos_list = np.concatenate(gsrm_word_pos_list) + gsrm_slf_attn_bias1_list = np.concatenate( + gsrm_slf_attn_bias1_list) + gsrm_slf_attn_bias2_list = np.concatenate( + gsrm_slf_attn_bias2_list) + + inputs = [ + norm_img_batch, + encoder_word_pos_list, + gsrm_word_pos_list, + gsrm_slf_attn_bias1_list, + gsrm_slf_attn_bias2_list, + ] + input_names = self.predictor.get_input_names() + for i in range(len(input_names)): + input_tensor = self.predictor.get_input_handle(input_names[ + i]) + input_tensor.copy_from_cpu(inputs[i]) + self.predictor.run() + outputs = [] + for output_tensor in self.output_tensors: + output = output_tensor.copy_to_cpu() + outputs.append(output) + preds = {"predict": outputs[2]} + else: + starttime = time.time() + self.input_tensor.copy_from_cpu(norm_img_batch) + self.predictor.run() + + outputs = [] + for output_tensor in self.output_tensors: + output = output_tensor.copy_to_cpu() + outputs.append(output) + preds = outputs[0] + rec_result = self.postprocess_op(preds) for rno in range(len(rec_result)): rec_res[indices[beg_img_no + rno]] = rec_result[rno] diff --git a/tools/infer_rec.py b/tools/infer_rec.py index 7e4b0811..075ec261 100755 --- a/tools/infer_rec.py +++ b/tools/infer_rec.py @@ -62,7 +62,13 @@ def main(): elif op_name in ['RecResizeImg']: op[op_name]['infer_mode'] = True elif op_name == 'KeepKeys': - op[op_name]['keep_keys'] = ['image'] + if config['Architecture']['algorithm'] == "SRN": + op[op_name]['keep_keys'] = [ + 'image', 'encoder_word_pos', 'gsrm_word_pos', + 'gsrm_slf_attn_bias1', 'gsrm_slf_attn_bias2' + ] + else: + op[op_name]['keep_keys'] = ['image'] transforms.append(op) global_config['infer_mode'] = True ops = create_operators(transforms, global_config) @@ -74,10 +80,25 @@ def main(): img = f.read() data = {'image': img} batch = transform(data, ops) + if config['Architecture']['algorithm'] == "SRN": + encoder_word_pos_list = np.expand_dims(batch[1], axis=0) + gsrm_word_pos_list = np.expand_dims(batch[2], axis=0) + gsrm_slf_attn_bias1_list = np.expand_dims(batch[3], axis=0) + gsrm_slf_attn_bias2_list = np.expand_dims(batch[4], axis=0) + + others = [ + paddle.to_tensor(encoder_word_pos_list), + paddle.to_tensor(gsrm_word_pos_list), + paddle.to_tensor(gsrm_slf_attn_bias1_list), + paddle.to_tensor(gsrm_slf_attn_bias2_list) + ] images = np.expand_dims(batch[0], axis=0) images = paddle.to_tensor(images) - preds = model(images) + if config['Architecture']['algorithm'] == "SRN": + preds = model(images, others) + else: + preds = model(images) post_result = post_process_class(preds) for rec_reuslt in post_result: logger.info('\t result: {}'.format(rec_reuslt)) diff --git a/tools/program.py b/tools/program.py index c2915426..ce52a610 100755 --- a/tools/program.py +++ b/tools/program.py @@ -179,9 +179,9 @@ def train(config, if 'start_epoch' in best_model_dict: start_epoch = best_model_dict['start_epoch'] else: - start_epoch = 1 + start_epoch = 0 - for epoch in range(start_epoch, epoch_num + 1): + for epoch in range(start_epoch, epoch_num): if epoch > 0: train_dataloader = build_dataloader(config, 'Train', device, logger) train_batch_cost = 0.0 @@ -194,7 +194,11 @@ def train(config, break lr = optimizer.get_lr() images = batch[0] - preds = model(images) + if config['Architecture']['algorithm'] == "SRN": + others = batch[-4:] + preds = model(images, others) + else: + preds = model(images) loss = loss_class(preds, batch) avg_loss = loss['loss'] avg_loss.backward() @@ -212,6 +216,7 @@ def train(config, stats['lr'] = lr train_stats.update(stats) + #cal_metric_during_train = False if cal_metric_during_train: # onlt rec and cls need batch = [item.numpy() for item in batch] post_result = post_process_class(preds, batch[1]) @@ -312,8 +317,9 @@ def eval(model, valid_dataloader, post_process_class, eval_class): if idx >= len(valid_dataloader): break images = batch[0] + others = batch[-4:] start = time.time() - preds = model(images) + preds = model(images, others) batch = [item.numpy() for item in batch] # Obtain usable results from post-processing methods -- GitLab