From 0cdfc525073a4d285c39c780eb4e02266e716aa8 Mon Sep 17 00:00:00 2001 From: chenlizhi Date: Mon, 17 Oct 2022 15:15:37 +0800 Subject: [PATCH] add sr model Text Telescope --- configs/sr/sr_telescope.yml | 84 ++++++ doc/doc_ch/algorithm_sr_telescope.md | 128 +++++++++ doc/doc_en/algorithm_sr_telescope_en.md | 137 +++++++++ ppocr/losses/__init__.py | 6 +- ppocr/losses/text_focus_loss.py | 91 ++++++ .../modeling/heads/sr_rensnet_transformer.py | 23 +- ppocr/modeling/transforms/__init__.py | 3 +- ppocr/modeling/transforms/tbsrn.py | 264 ++++++++++++++++++ ppocr/utils/dict/confuse.pkl | Bin 0 -> 30912 bytes 9 files changed, 718 insertions(+), 18 deletions(-) create mode 100644 configs/sr/sr_telescope.yml create mode 100644 doc/doc_ch/algorithm_sr_telescope.md create mode 100644 doc/doc_en/algorithm_sr_telescope_en.md create mode 100644 ppocr/losses/text_focus_loss.py create mode 100644 ppocr/modeling/transforms/tbsrn.py create mode 100644 ppocr/utils/dict/confuse.pkl diff --git a/configs/sr/sr_telescope.yml b/configs/sr/sr_telescope.yml new file mode 100644 index 00000000..dc0b195b --- /dev/null +++ b/configs/sr/sr_telescope.yml @@ -0,0 +1,84 @@ +Global: + use_gpu: true + epoch_num: 100 + log_smooth_window: 20 + print_batch_step: 10 + save_model_dir: ./output/sr/sr_telescope/ + save_epoch_step: 3 + # evaluation is run every 2000 iterations + eval_batch_step: [0, 1000] + cal_metric_during_train: False + pretrained_model: + checkpoints: + save_inference_dir: ./output/sr/sr_telescope/infer + use_visualdl: False + infer_img: doc/imgs_words_en/word_52.png + # for data or label process + character_dict_path: + max_text_length: 100 + infer_mode: False + use_space_char: False + save_res_path: ./output/sr/predicts_telescope.txt + +Optimizer: + name: Adam + beta1: 0.5 + beta2: 0.999 + clip_norm: 0.25 + lr: + learning_rate: 0.0001 + +Architecture: + model_type: sr + algorithm: Telescope + Transform: + name: TBSRN + STN: True + infer_mode: False + +Loss: + name: TelescopeLoss + confuse_dict_path: ./ppocr/utils/dict/confuse.pkl + + +PostProcess: + name: None + +Metric: + name: SRMetric + main_indicator: all + +Train: + dataset: + name: LMDBDataSetSR + data_dir: ./train_data/TextZoom/train + transforms: + - SRResize: + imgH: 32 + imgW: 128 + down_sample_scale: 2 + - KeepKeys: + keep_keys: ['img_lr', 'img_hr', 'label'] # dataloader will return list in this order + loader: + shuffle: False + batch_size_per_card: 16 + drop_last: True + num_workers: 4 + +Eval: + dataset: + name: LMDBDataSetSR + data_dir: ./train_data/TextZoom/test + transforms: + - SRResize: + imgH: 32 + imgW: 128 + down_sample_scale: 2 + - KeepKeys: + keep_keys: ['img_lr', 'img_hr', 'label'] # dataloader will return list in this order + loader: + shuffle: False + drop_last: False + batch_size_per_card: 16 + num_workers: 0 + diff --git a/doc/doc_ch/algorithm_sr_telescope.md b/doc/doc_ch/algorithm_sr_telescope.md new file mode 100644 index 00000000..9a21734b --- /dev/null +++ b/doc/doc_ch/algorithm_sr_telescope.md @@ -0,0 +1,128 @@ +# Text Telescope + +- [1. 算法简介](#1) +- [2. 环境配置](#2) +- [3. 模型训练、评估、预测](#3) + - [3.1 训练](#3-1) + - [3.2 评估](#3-2) + - [3.3 预测](#3-3) +- [4. 推理部署](#4) + - [4.1 Python推理](#4-1) + - [4.2 C++推理](#4-2) + - [4.3 Serving服务化部署](#4-3) + - [4.4 更多推理部署](#4-4) +- [5. FAQ](#5) + + +## 1. 
算法简介
+
+论文信息:
+> [Scene Text Telescope: Text-Focused Scene Image Super-Resolution](https://openaccess.thecvf.com/content/CVPR2021/papers/Chen_Scene_Text_Telescope_Text-Focused_Scene_Image_Super-Resolution_CVPR_2021_paper.pdf)
+
+> Chen, Jingye, Bin Li, and Xiangyang Xue
+
+> CVPR, 2021
+
+参考[FudanOCR](https://github.com/FudanVI/FudanOCR/tree/main/scene-text-telescope) 数据下载说明,在TextZoom测试集上超分算法效果如下:
+
+|模型|骨干网络|PSNR_Avg|SSIM_Avg|配置文件|下载链接|
+|---|---|---|---|---|---|
+|Text Telescope|tbsrn|21.56|0.7411|[configs/sr/sr_telescope.yml](../../configs/sr/sr_telescope.yml)|[训练模型](https://paddleocr.bj.bcebos.com/contribution/Telescope_train.tar.gz)|
+
+[TextZoom数据集](https://paddleocr.bj.bcebos.com/dataset/TextZoom.tar) 来自两个超分数据集RealSR和SR-RAW,两个数据集都包含LR-HR对,TextZoom有17367对训练数据和4373对测试数据。
+
+
+## 2. 环境配置
+请先参考[《运行环境准备》](./environment.md)配置PaddleOCR运行环境,参考[《项目克隆》](./clone.md)克隆项目代码。
+
+
+## 3. 模型训练、评估、预测
+
+请参考[文本识别训练教程](./recognition.md)。PaddleOCR对代码进行了模块化,训练不同的识别模型只需要**更换配置文件**即可。
+
+- 训练
+
+在完成数据准备后,便可以启动训练,训练命令如下:
+
+```
+#单卡训练(训练周期长,不建议)
+python3 tools/train.py -c configs/sr/sr_telescope.yml
+
+#多卡训练,通过--gpus参数指定卡号
+python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/sr/sr_telescope.yml
+```
+
+- 评估
+
+```
+# GPU 评估, Global.pretrained_model 为待测权重
+python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/sr/sr_telescope.yml -o Global.pretrained_model={path/to/weights}/best_accuracy
+```
+
+- 预测
+
+```
+# 预测使用的配置文件必须与训练一致
+python3 tools/infer_sr.py -c configs/sr/sr_telescope.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.infer_img=doc/imgs_words_en/word_52.png
+```
+
+![](../imgs_words_en/word_52.png)
+
+执行命令后,上面图像的超分结果如下:
+
+![](../imgs_results/sr_word_52.png)
+
+## 4. 推理部署
+
+### 4.1 Python推理
+
+首先将文本超分训练过程中保存的模型,转换成inference model。以 Text-Telescope 训练的[模型](https://paddleocr.bj.bcebos.com/contribution/Telescope_train.tar.gz) 为例,可以使用如下命令进行转换:
+
+```shell
+python3 tools/export_model.py -c configs/sr/sr_telescope.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.save_inference_dir=./inference/sr_out
+```
+
+Text-Telescope 文本超分模型推理,可以执行如下命令:
+
+```
+python3 tools/infer/predict_sr.py --sr_model_dir=./inference/sr_out --image_dir=doc/imgs_words_en/word_52.png --sr_image_shape=3,32,128
+```
+
+执行命令后,图像的超分结果如下:
+
+![](../imgs_results/sr_word_52.png)
+
+### 4.2 C++推理
+
+暂未支持
+
+### 4.3 Serving服务化部署
+
+暂未支持
+
+### 4.4 更多推理部署
+
+暂未支持
+
+## 5. FAQ
+
+
+## 引用
+
+```bibtex
+@INPROCEEDINGS{9578891,
+  author={Chen, Jingye and Li, Bin and Xue, Xiangyang},
+  booktitle={2021 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
+  title={Scene Text Telescope: Text-Focused Scene Image Super-Resolution},
+  year={2021},
+  volume={},
+  number={},
+  pages={12021-12030},
+  doi={10.1109/CVPR46437.2021.01185}}
+```
diff --git a/doc/doc_en/algorithm_sr_telescope_en.md b/doc/doc_en/algorithm_sr_telescope_en.md
new file mode 100644
index 00000000..89f3b373
--- /dev/null
+++ b/doc/doc_en/algorithm_sr_telescope_en.md
@@ -0,0 +1,137 @@
+# Text Telescope
+
+- [1. Introduction](#1)
+- [2. Environment](#2)
+- [3. Model Training / Evaluation / Prediction](#3)
+  - [3.1 Training](#3-1)
+  - [3.2 Evaluation](#3-2)
+  - [3.3 Prediction](#3-3)
+- [4. Inference and Deployment](#4)
+  - [4.1 Python Inference](#4-1)
+  - [4.2 C++ Inference](#4-2)
+  - [4.3 Serving](#4-3)
+  - [4.4 More](#4-4)
+- [5. FAQ](#5)
+
+
+## 1. Introduction
+
+Paper:
+> [Scene Text Telescope: Text-Focused Scene Image Super-Resolution](https://openaccess.thecvf.com/content/CVPR2021/papers/Chen_Scene_Text_Telescope_Text-Focused_Scene_Image_Super-Resolution_CVPR_2021_paper.pdf)
+
+> Chen, Jingye, Bin Li, and Xiangyang Xue
+
+> CVPR, 2021
+
+Following the [FudanOCR](https://github.com/FudanVI/FudanOCR/tree/main/scene-text-telescope) data download instructions, the super-resolution algorithm achieves the following results on the TextZoom test set:
+
+|Model|Backbone|PSNR_Avg|SSIM_Avg|Config|Download link|
+|---|---|---|---|---|---|
+|Text Telescope|tbsrn|21.56|0.7411|[configs/sr/sr_telescope.yml](../../configs/sr/sr_telescope.yml)|[trained model](https://paddleocr.bj.bcebos.com/contribution/Telescope_train.tar.gz)|
+
+The [TextZoom dataset](https://paddleocr.bj.bcebos.com/dataset/TextZoom.tar) comes from two super-resolution datasets, RealSR and SR-RAW, both of which contain LR-HR pairs. TextZoom has 17367 pairs of training data and 4373 pairs of test data.
+
+
+## 2. Environment
+Please refer to ["Environment Preparation"](./environment_en.md) to configure the PaddleOCR environment, and refer to ["Project Clone"](./clone_en.md) to clone the project code.
+
+
+## 3. Model Training / Evaluation / Prediction
+
+Please refer to the [Text Recognition Tutorial](./recognition_en.md). PaddleOCR modularizes the code, and training a different model only requires **changing the configuration file**.
+
+Training:
+
+After the data preparation is completed, training can be started. The training commands are as follows:
+
+```
+# Single GPU training (long training period, not recommended)
+python3 tools/train.py -c configs/sr/sr_telescope.yml
+
+# Multi-GPU training, specify the GPU ids through the --gpus parameter
+python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/sr/sr_telescope.yml
+```
+
+Evaluation:
+
+```
+# GPU evaluation, Global.pretrained_model is the weights to be evaluated
+python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/sr/sr_telescope.yml -o Global.pretrained_model={path/to/weights}/best_accuracy
+```
+
+Prediction:
+
+```
+# The configuration file used for prediction must match the one used for training
+python3 tools/infer_sr.py -c configs/sr/sr_telescope.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.infer_img=doc/imgs_words_en/word_52.png
+```
+
+![](../imgs_words_en/word_52.png)
+
+After executing the command, the super-resolution result of the above image is as follows:
+
+![](../imgs_results/sr_word_52.png)
+
+## 4. Inference and Deployment
+
+### 4.1 Python Inference
+
+First, convert the model saved during training into an inference model. Taking the [trained model](https://paddleocr.bj.bcebos.com/contribution/Telescope_train.tar.gz) as an example, you can use the following command to convert it:
+
+```shell
+python3 tools/export_model.py -c configs/sr/sr_telescope.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.save_inference_dir=./inference/sr_out
+```
+
+For Text-Telescope super-resolution model inference, the following command can be executed:
+
+```
+python3 tools/infer/predict_sr.py --sr_model_dir=./inference/sr_out --image_dir=doc/imgs_words_en/word_52.png --sr_image_shape=3,32,128
+```
+
+After executing the command, the super-resolution result of the above image is as follows:
+
+![](../imgs_results/sr_word_52.png)
+
+### 4.2 C++ Inference
+
+Not supported
+
+### 4.3 Serving
+
+Not supported
+
+### 4.4 More
+
+Not supported
+
+## 5. 
FAQ + + +## Citation + +```bibtex +@INPROCEEDINGS{9578891, + author={Chen, Jingye and Li, Bin and Xue, Xiangyang}, + booktitle={2021 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, + title={Scene Text Telescope: Text-Focused Scene Image Super-Resolution}, + year={2021}, + volume={}, + number={}, + pages={12021-12030}, + doi={10.1109/CVPR46437.2021.01185}} +``` diff --git a/ppocr/losses/__init__.py b/ppocr/losses/__init__.py index 6abaa408..46d6e81f 100755 --- a/ppocr/losses/__init__.py +++ b/ppocr/losses/__init__.py @@ -25,8 +25,6 @@ from .det_east_loss import EASTLoss from .det_sast_loss import SASTLoss from .det_pse_loss import PSELoss from .det_fce_loss import FCELoss -from .det_ct_loss import CTLoss -from .det_drrg_loss import DRRGLoss # rec loss from .rec_ctc_loss import CTCLoss @@ -39,7 +37,6 @@ from .rec_pren_loss import PRENLoss from .rec_multi_loss import MultiLoss from .rec_vl_loss import VLLoss from .rec_spin_att_loss import SPINAttentionLoss -from .rec_rfl_loss import RFLLoss # cls loss from .cls_loss import ClsLoss @@ -62,6 +59,7 @@ from .vqa_token_layoutlm_loss import VQASerTokenLayoutLMLoss # sr loss from .stroke_focus_loss import StrokeFocusLoss +from .text_focus_loss import TelescopeLoss def build_loss(config): @@ -71,7 +69,7 @@ def build_loss(config): 'CELoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss', 'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss', 'TableMasterLoss', 'SPINAttentionLoss', 'VLLoss', 'StrokeFocusLoss', - 'SLALoss', 'CTLoss', 'RFLLoss', 'DRRGLoss' + 'SLALoss', 'TelescopeLoss' ] config = copy.deepcopy(config) module_name = config.pop('name') diff --git a/ppocr/losses/text_focus_loss.py b/ppocr/losses/text_focus_loss.py new file mode 100644 index 00000000..b5062840 --- /dev/null +++ b/ppocr/losses/text_focus_loss.py @@ -0,0 +1,91 @@ +# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +This code is refer from: +https://github.com/FudanVI/FudanOCR/blob/main/scene-text-telescope/loss/text_focus_loss.py +""" + +import paddle.nn as nn +import paddle +import numpy as np +import pickle as pkl + +standard_alphebet = '-0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ' +standard_dict = {} +for index in range(len(standard_alphebet)): + standard_dict[standard_alphebet[index]] = index + + +def load_confuse_matrix(confuse_dict_path): + f = open(confuse_dict_path, 'rb') + data = pkl.load(f) + f.close() + number = data[:10] + upper = data[10:36] + lower = data[36:] + end = np.ones((1, 62)) + pad = np.ones((63, 1)) + rearrange_data = np.concatenate((end, number, lower, upper), axis=0) + rearrange_data = np.concatenate((pad, rearrange_data), axis=1) + rearrange_data = 1 / rearrange_data + rearrange_data[rearrange_data == np.inf] = 1 + rearrange_data = paddle.to_tensor(rearrange_data) + + lower_alpha = 'abcdefghijklmnopqrstuvwxyz' + # upper_alpha = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ' + for i in range(63): + for j in range(63): + if i != j and standard_alphebet[j] in lower_alpha: + rearrange_data[i][j] = max(rearrange_data[i][j], rearrange_data[i][j + 26]) + rearrange_data = rearrange_data[:37, :37] + + return rearrange_data + + +def weight_cross_entropy(pred, gt, weight_table): + batch = gt.shape[0] + weight = weight_table[gt] + pred_exp = paddle.exp(pred) + pred_exp_weight = weight * pred_exp + loss = 0 + for i in range(len(gt)): + loss -= paddle.log(pred_exp_weight[i][gt[i]] / paddle.sum(pred_exp_weight, 1)[i]) + return loss / batch + + +class TelescopeLoss(nn.Layer): + def __init__(self, confuse_dict_path): + super(TelescopeLoss, self).__init__() + self.weight_table = load_confuse_matrix(confuse_dict_path) + self.mse_loss = nn.MSELoss() + self.ce_loss = nn.CrossEntropyLoss() + self.l1_loss = nn.L1Loss() + + def forward(self, pred, data): + sr_img = pred["sr_img"] + hr_img = pred["hr_img"] + sr_pred = pred["sr_pred"] + text_gt = pred["text_gt"] + + word_attention_map_gt = pred["word_attention_map_gt"] + word_attention_map_pred = pred["word_attention_map_pred"] + mse_loss = self.mse_loss(sr_img, hr_img) + attention_loss = self.l1_loss(word_attention_map_gt, word_attention_map_pred) + recognition_loss = weight_cross_entropy(sr_pred, text_gt, self.weight_table) + loss = mse_loss + attention_loss * 10 + recognition_loss * 0.0005 + return { + "mse_loss": mse_loss, + "attention_loss": attention_loss, + "loss": loss + } diff --git a/ppocr/modeling/heads/sr_rensnet_transformer.py b/ppocr/modeling/heads/sr_rensnet_transformer.py index a004a126..654f3fca 100644 --- a/ppocr/modeling/heads/sr_rensnet_transformer.py +++ b/ppocr/modeling/heads/sr_rensnet_transformer.py @@ -15,18 +15,12 @@ This code is refer from: https://github.com/FudanVI/FudanOCR/blob/main/text-gestalt/loss/transformer_english_decomposition.py """ +import copy +import math + import paddle import paddle.nn as nn import paddle.nn.functional as F -import math, copy -import numpy as np - -# stroke-level alphabet -alphabet = '0123456789' - - -def get_alphabet_len(): - return len(alphabet) def subsequent_mask(size): @@ -373,10 +367,10 @@ class Encoder(nn.Layer): class Transformer(nn.Layer): - def __init__(self, in_channels=1): + def __init__(self, in_channels=1, alphabet='0123456789'): super(Transformer, self).__init__() - - word_n_class = get_alphabet_len() + self.alphabet = alphabet + word_n_class = self.get_alphabet_len() self.embedding_word_with_upperword = Embeddings(512, word_n_class) self.pe = 
PositionalEncoding(dim=512, dropout=0.1, max_len=5000) @@ -388,6 +382,9 @@ class Transformer(nn.Layer): if p.dim() > 1: nn.initializer.XavierNormal(p) + def get_alphabet_len(self): + return len(self.alphabet) + def forward(self, image, text_length, text_input, attention_map=None): if image.shape[1] == 3: R = image[:, 0:1, :, :] @@ -415,7 +412,7 @@ class Transformer(nn.Layer): if self.training: total_length = paddle.sum(text_length) - probs_res = paddle.zeros([total_length, get_alphabet_len()]) + probs_res = paddle.zeros([total_length, self.get_alphabet_len()]) start = 0 for index, length in enumerate(text_length): diff --git a/ppocr/modeling/transforms/__init__.py b/ppocr/modeling/transforms/__init__.py index b22c60bb..022ece60 100755 --- a/ppocr/modeling/transforms/__init__.py +++ b/ppocr/modeling/transforms/__init__.py @@ -19,9 +19,10 @@ def build_transform(config): from .tps import TPS from .stn import STN_ON from .tsrn import TSRN + from .tbsrn import TBSRN from .gaspin_transformer import GA_SPIN_Transformer as GA_SPIN - support_dict = ['TPS', 'STN_ON', 'GA_SPIN', 'TSRN'] + support_dict = ['TPS', 'STN_ON', 'GA_SPIN', 'TSRN', 'TBSRN'] module_name = config.pop('name') assert module_name in support_dict, Exception( diff --git a/ppocr/modeling/transforms/tbsrn.py b/ppocr/modeling/transforms/tbsrn.py new file mode 100644 index 00000000..ee119003 --- /dev/null +++ b/ppocr/modeling/transforms/tbsrn.py @@ -0,0 +1,264 @@ +# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +This code is refer from: +https://github.com/FudanVI/FudanOCR/blob/main/scene-text-telescope/model/tbsrn.py +""" + +import math +import warnings +import numpy as np +import paddle +from paddle import nn +import string + +warnings.filterwarnings("ignore") + +from .tps_spatial_transformer import TPSSpatialTransformer +from .stn import STN as STNHead +from .tsrn import GruBlock, mish, UpsampleBLock +from ppocr.modeling.heads.sr_rensnet_transformer import Transformer, LayerNorm, \ + PositionwiseFeedForward, MultiHeadedAttention + + +def positionalencoding2d(d_model, height, width): + """ + :param d_model: dimension of the model + :param height: height of the positions + :param width: width of the positions + :return: d_model*height*width position matrix + """ + if d_model % 4 != 0: + raise ValueError("Cannot use sin/cos positional encoding with " + "odd dimension (got dim={:d})".format(d_model)) + pe = paddle.zeros([d_model, height, width]) + # Each dimension use half of d_model + d_model = int(d_model / 2) + div_term = paddle.exp(paddle.arange(0., d_model, 2) * + -(math.log(10000.0) / d_model)) + pos_w = paddle.arange(0., width, dtype='float32').unsqueeze(1) + pos_h = paddle.arange(0., height, dtype='float32').unsqueeze(1) + + pe[0:d_model:2, :, :] = paddle.sin(pos_w * div_term).transpose([1, 0]).unsqueeze(1).tile([1, height, 1]) + pe[1:d_model:2, :, :] = paddle.cos(pos_w * div_term).transpose([1, 0]).unsqueeze(1).tile([1, height, 1]) + pe[d_model::2, :, :] = paddle.sin(pos_h * div_term).transpose([1, 0]).unsqueeze(2).tile([1, 1, width]) + pe[d_model + 1::2, :, :] = paddle.cos(pos_h * div_term).transpose([1, 0]).unsqueeze(2).tile([1, 1, width]) + + return pe + + +class FeatureEnhancer(nn.Layer): + + def __init__(self): + super(FeatureEnhancer, self).__init__() + + self.multihead = MultiHeadedAttention(h=4, d_model=128, dropout=0.1) + self.mul_layernorm1 = LayerNorm(features=128) + + self.pff = PositionwiseFeedForward(128, 128) + self.mul_layernorm3 = LayerNorm(features=128) + + self.linear = nn.Linear(128, 64) + + def forward(self, conv_feature): + ''' + text : (batch, seq_len, embedding_size) + global_info: (batch, embedding_size, 1, 1) + conv_feature: (batch, channel, H, W) + ''' + batch = conv_feature.shape[0] + position2d = positionalencoding2d(64, 16, 64).cast('float32').unsqueeze(0).reshape([1, 64, 1024]) + position2d = position2d.tile([batch, 1, 1]) + conv_feature = paddle.concat([conv_feature, position2d], 1) # batch, 128(64+64), 32, 128 + result = conv_feature.transpose([0, 2, 1]) + origin_result = result + result = self.mul_layernorm1(origin_result + self.multihead(result, result, result, mask=None)[0]) + origin_result = result + result = self.mul_layernorm3(origin_result + self.pff(result)) + result = self.linear(result) + return result.transpose([0, 2, 1]) + + +def str_filt(str_, voc_type): + alpha_dict = { + 'digit': string.digits, + 'lower': string.digits + string.ascii_lowercase, + 'upper': string.digits + string.ascii_letters, + 'all': string.digits + string.ascii_letters + string.punctuation + } + if voc_type == 'lower': + str_ = str_.lower() + for char in str_: + if char not in alpha_dict[voc_type]: + str_ = str_.replace(char, '') + str_ = str_.lower() + return str_ + + +class TBSRN(nn.Layer): + def __init__(self, + in_channels=3, + scale_factor=2, + width=128, + height=32, + STN=True, + srb_nums=5, + mask=False, + hidden_units=32, + infer_mode=False): + super(TBSRN, self).__init__() + in_planes = 3 + if mask: + in_planes = 4 + assert math.log(scale_factor, 2) % 1 == 0 
+ upsample_block_num = int(math.log(scale_factor, 2)) + self.block1 = nn.Sequential( + nn.Conv2D(in_planes, 2 * hidden_units, kernel_size=9, padding=4), + nn.PReLU() + # nn.ReLU() + ) + self.srb_nums = srb_nums + for i in range(srb_nums): + setattr(self, 'block%d' % (i + 2), RecurrentResidualBlock(2 * hidden_units)) + + setattr(self, 'block%d' % (srb_nums + 2), + nn.Sequential( + nn.Conv2D(2 * hidden_units, 2 * hidden_units, kernel_size=3, padding=1), + nn.BatchNorm2D(2 * hidden_units) + )) + + # self.non_local = NonLocalBlock2D(64, 64) + block_ = [UpsampleBLock(2 * hidden_units, 2) for _ in range(upsample_block_num)] + block_.append(nn.Conv2D(2 * hidden_units, in_planes, kernel_size=9, padding=4)) + setattr(self, 'block%d' % (srb_nums + 3), nn.Sequential(*block_)) + self.tps_inputsize = [height // scale_factor, width // scale_factor] + tps_outputsize = [height // scale_factor, width // scale_factor] + num_control_points = 20 + tps_margins = [0.05, 0.05] + self.stn = STN + self.out_channels = in_channels + if self.stn: + self.tps = TPSSpatialTransformer( + output_image_size=tuple(tps_outputsize), + num_control_points=num_control_points, + margins=tuple(tps_margins)) + + self.stn_head = STNHead( + in_channels=in_planes, + num_ctrlpoints=num_control_points, + activation='none') + self.infer_mode = infer_mode + + self.english_alphabet = '-0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ' + self.english_dict = {} + for index in range(len(self.english_alphabet)): + self.english_dict[self.english_alphabet[index]] = index + transformer = Transformer(alphabet='-0123456789abcdefghijklmnopqrstuvwxyz') + self.transformer = transformer + for param in self.transformer.parameters(): + param.trainable = False + + def label_encoder(self, label): + batch = len(label) + + length = [len(i) for i in label] + length_tensor = paddle.to_tensor(length, dtype='int64') + + max_length = max(length) + input_tensor = np.zeros((batch, max_length)) + for i in range(batch): + for j in range(length[i] - 1): + input_tensor[i][j + 1] = self.english_dict[label[i][j]] + + text_gt = [] + for i in label: + for j in i: + text_gt.append(self.english_dict[j]) + text_gt = paddle.to_tensor(text_gt, dtype='int64') + + input_tensor = paddle.to_tensor(input_tensor, dtype='int64') + return length_tensor, input_tensor, text_gt + + def forward(self, x): + output = {} + if self.infer_mode: + output["lr_img"] = x + y = x + else: + output["lr_img"] = x[0] + output["hr_img"] = x[1] + y = x[0] + if self.stn and self.training: + _, ctrl_points_x = self.stn_head(y) + y, _ = self.tps(y, ctrl_points_x) + block = {'1': self.block1(y)} + for i in range(self.srb_nums + 1): + block[str(i + 2)] = getattr(self, + 'block%d' % (i + 2))(block[str(i + 1)]) + + block[str(self.srb_nums + 3)] = getattr(self, 'block%d' % (self.srb_nums + 3)) \ + ((block['1'] + block[str(self.srb_nums + 2)])) + + sr_img = paddle.tanh(block[str(self.srb_nums + 3)]) + output["sr_img"] = sr_img + + if self.training: + hr_img = x[1] + + # add transformer + label = [str_filt(i, 'lower') + '-' for i in x[2]] + length_tensor, input_tensor, text_gt = self.label_encoder(label) + hr_pred, word_attention_map_gt, hr_correct_list = self.transformer(hr_img, length_tensor, + input_tensor) + sr_pred, word_attention_map_pred, sr_correct_list = self.transformer(sr_img, length_tensor, + input_tensor) + output["hr_img"] = hr_img + output["hr_pred"] = hr_pred + output["text_gt"] = text_gt + output["word_attention_map_gt"] = word_attention_map_gt + output["sr_pred"] = sr_pred + 
output["word_attention_map_pred"] = word_attention_map_pred
+
+        return output
+
+
+class RecurrentResidualBlock(nn.Layer):
+    def __init__(self, channels):
+        super(RecurrentResidualBlock, self).__init__()
+        self.conv1 = nn.Conv2D(channels, channels, kernel_size=3, padding=1)
+        self.bn1 = nn.BatchNorm2D(channels)
+        self.gru1 = GruBlock(channels, channels)
+        # self.prelu = nn.ReLU()
+        self.prelu = mish()
+        self.conv2 = nn.Conv2D(channels, channels, kernel_size=3, padding=1)
+        self.bn2 = nn.BatchNorm2D(channels)
+        self.gru2 = GruBlock(channels, channels)
+        self.feature_enhancer = FeatureEnhancer()
+
+        for p in self.parameters():
+            if p.dim() > 1:
+                paddle.nn.initializer.XavierUniform(p)
+
+    def forward(self, x):
+        residual = self.conv1(x)
+        residual = self.bn1(residual)
+        residual = self.prelu(residual)
+        residual = self.conv2(residual)
+        residual = self.bn2(residual)
+
+        size = residual.shape
+        residual = residual.reshape([size[0], size[1], -1])
+        residual = self.feature_enhancer(residual)
+        residual = residual.reshape([size[0], size[1], size[2], size[3]])
+        return x + residual
\ No newline at end of file
diff --git a/ppocr/utils/dict/confuse.pkl b/ppocr/utils/dict/confuse.pkl
new file mode 100644
index 0000000000000000000000000000000000000000..e5d485320bcc94b36fa4fd653644f07d1a974369
GIT binary patch
literal 30912
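
Note on the binary file above: `confuse.pkl` stores character-confusion statistics that `TelescopeLoss` (ppocr/losses/text_focus_loss.py) converts into a per-character weight table. The recognition branch is then trained with a weighted cross-entropy in which each sample's softmax terms are re-scaled by the weight row of its ground-truth character, and the total loss is `mse_loss + 10 * attention_loss + 0.0005 * recognition_loss`. Below is a minimal NumPy sketch (not part of the patch) mirroring `weight_cross_entropy`; the toy weight table and logits are made-up values for illustration only.

```python
import numpy as np


def weighted_cross_entropy(pred, gt, weight_table):
    # pred: (N, C) logits, gt: (N,) ground-truth class ids, weight_table: (C, C)
    weight = weight_table[gt]                    # weight row selected by each sample's gt character
    weighted_exp = weight * np.exp(pred)         # confusion-weighted, un-normalised softmax terms
    gt_prob = weighted_exp[np.arange(len(gt)), gt] / weighted_exp.sum(axis=1)
    return float(-np.log(gt_prob).mean())        # averaged over the batch, as in the patch


# hypothetical 3-character alphabet; table values are illustrative only
weight_table = np.array([[1.0, 0.5, 1.0],
                         [0.5, 1.0, 0.2],
                         [1.0, 0.2, 1.0]])
logits = np.array([[2.0, 0.1, 0.3],
                   [0.2, 1.5, 1.4]])
labels = np.array([0, 1])
print(weighted_cross_entropy(logits, labels, weight_table))
```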
znvDIBe#}_DUhQkp{_e$})Ov23?weVs>=|_}^0>49en0vzdXrBs_(%Q|JFIn)M=SMr z?YGpw<%e7`9S<#+Uznb|iBGhTy~L>+zioc(njiGpOxjKroCCMDKJsr1me;D%zKT83 zeu(S#q9>9E@&#+bag94=I-kUk6~54yQtD&ir}(jNeJS;C(WPogE|Cx9Xx(^L{C15~ z>fh2!aDUb7!QUXCvKaYA-z0Bc>)TZPN9~We89x;KW?iNJE!?LLNZywG@1gDQLimoH z3-8dcR`g!{nGK)yTxI{}Fn$pBM$eNU9*I7~XY?K1LS7|j&&=;5(|JGgBs$C+pHJfl ziQb=EerGJd&EOmUws85e<@j0Toc%~%)hKg#Jf-P zx%f%lU@SN)eLeNO_+k9NdnV(Eu1B#W=jps3IYl4!{j%luM)<06Z%1z4JYyVw8u1wU0OHLR%Ud^o z782&EuDgXPr<+>U924}Df_$B&1^}coM^|9$Q!Ckw z$nh)rq~pk+=!;)TU+%uwkK8R72RfEN;;VM_gZNAK=D>JOpEP+f`um{^@@NKr^dFL+Bj4LIPW58oLX;7zn0_&yu{xmKSh5EaTWE4s)EkQ6Z$WCCT~JN=w=^ogFE2df%Ox6$^5T5 zuX$iOrGJt526+It$Y*h0f^!Mr?Xu;l^nr-4^h2##ZYS(VAm`Xs>Om#HCVNPo0e`L3 z7bM?P#^<&0Pxg`XB&((m{rL3v&`(0$5x-^4`iC5nPiQ7?MP9*6>;>{9{*fmt_E~aT z&IiC}a9-ol4}jgEf4lVIG7s~M4rM-YU3R7!J;8pF{{Z;vb0a zu$K4eGkk{ck{|T6O1&MvYW*c|ihPOA z*x5VYFXtfWpCE2%nLcgHcgbgg@50As?7aM(Me7xHyglO`dBSe|0qyV9a}m4 zZLIBhmHIh$l6)lno#p(OJqLO9A-UQ_n_sT z{C3UHmGgia_aylR?Q6s98g%~w@#4O50y}rf{SwsShth{4y{Fz!A8s#lhyB%j*eU!T z>;(M}^f8q_fIdFE@8Zz-fgeSE8Tlb@q7MtXAntp^{%d8qdENBew*JuPK;H=QIeppG zv-q4mJob3d^y1vcj`97W{Y2`byT+X(f4`kLN_txOC)~%+;ap_N*Gdk09)I5ZKz}9@ z@1n=#X|Q+t{MhkLl|BR6LHwV6k6ZF<*hBmcjq`cpP~i^wsA=mR^%DH)w*7-e{P)N+ z`mS-w$ME`j{1fmE9M<#nc~m+7!8$cA{sjH9)MMxi-bh@DJ<|B>hrZRL@C`px^HWEo z4|mJYePjK7p7FJwk?=`$oiaVZ5$fOgjnvI=T8<7ZN3(u^H+}_i9rhDD0Dko2r1Tp5 z!+EDs>ow=-I1ho})wG?wmv~e3D1Hn$i+$F-#4FgB8RWqJJbj^arr!hGgL1A?_8kA5 zdN=(nmHqX3{|-A(KW#auDm(Gq_LlL7*K0p>9(O%{gXE+cdm(y`CqJRMa65dE9No2Cah?Ugm^uORJMj>C)p5^D9jwI9 z+E2-QB43&hxuk!PzT_3J3;WA?ubDnQs4D3x@z6cv4*dnCZ(QAKap~(3{pt679Da*WOOapwj`}}U)TODb(x0(sJ5lglICf|{6UWiNNj|XDQS|&{^P4zjrO!_#zumF_ z3qLu>f9l_(?_Zi9B~Pw>-VUE6&y!v!`mtpC&>v1e4(AqmPn}@G{O3FdeeILRdHjWe z^#2eaiN53+N*phor*2*DiNKCZ&hJN#u@j63juTh$z2?Qf7P%)bhR^U_zXSj1S3nL2 zb6!yEJ+<8*H=f{sy&9Kk9p!WIP=4h@`;qvi-bU;H2Z zLg7}a8)@Exr+Ux%)MKx6BI^amv6Ju_zH5KiyzdfciLRwzP0vxEEVzvw1yA%k>hTtg z7n{}_`rGkisgr@5_~G=ot=s+oF*uYEtoJHF?{DflPMfX|T zY3|kFJkDhFS~&8R=~4U_t(SAP+z*1Bjd_2ZpP;Xtc$GX3aU^|ErS2@fSc$*MezhL@ zYdI%AZn>gw62G1N6XTP=e%dd>^dAKlNN(w88+hd;%+ z1?tz-$vAJr`3LH%&0z82ClM z4d*||KV2|AsOMaCUnKX=TuNS0>nBd&JQDP;%ulSuetvBOd2Y6fj zYT1E;N8}4N{#I~L`t#Iygk7fZx2k0K!4dinO8ptU72R(;z9w# zyG5^W(DF*Z5BC93kHF6(uZMlEIe$ytn7jvdN${8YAm{d8Pd-_Es(enFd8r?s_W3o1 zKJXttpZDvy_klVgeLLjE$R|J_{MDM*G2r!WTmJ@Aua_KL^uFm&IkG&{&p{sf$o$|O z5&f0B=0E-#c^~X3{SWm2l=~6z%cU3i7t{7zrad3$m%ta!`y;nA#&PahA)e=)?t1(M z`HT2*GxkqrJRkQtV5c|YPveg$?j>(ay^wyOsrZrjS@_%7cYVGT+(r-e9C>N%JMr|A z{qM(#qZtRek=~Q<sQ$sgiZ>y`5(vx!THLq!k#6wZ&%`8n=KR)H2p8%B(EiT zq%KAOOSz{*>*U-;$&1L3AwSym`iU3ukFQ(q$kS5S=KebJozJX?R}+5_|A{`}DRq(6 zK71j*FD)n>S&5@ z@k=;AM&CI2!nrK^k;(gTPHx10>sO%{c>?@Pz3w``U-kU6#s~VQz|~RfTUoDu|H$;? z+zjzPaRUA3r&)*SOy3UrpE^wiE@}LE`x~DHpOAm)4d=*8-tl~1W#9Wd4E_0D<4h$# zBE6;0vD{09UP_;@$L=Hd_%FnfqBHV9zjTo|@KE!?U(RE2pJOX=uISJTpCtFUtPkV^ zHo`x}O~iAY!)%*>n|V$+!Mz~#A>hy9kD-smMZ}r(-Ie}}^Klh*P@W^cFa7H98T*bt z!(ZXx*U5LXFMVF<13&dS=Qrs0zZ<)y=eQS@ek}amxzLyAz)|J}PqBMEC;X?s09@sM z1Nu$MeH2uEalEZ|C+}JCT)!i4Mf{In#d);F;1HkGGh- zh}QAo75fLCYd_=*xYv`ucJ7(woD6xsM~Oeh*HPo+<5%#U`UW^8KAa}568*sq@QA*P z&l1PMXZTKhgua94`ksCk`oXwo0Dqf)jRlVnjz2Tb-LGrDr?DI0E`Aj_ia&}z^BzAA zIhWpWj+6M8b19rt#7=X*i9TcEB=Y|o_OHmB-!v{`hw)#^xhbs=JZ;BsVSVD$qU8@i zm3WT0blH9Z_tJ6hkh*Be^NK#;+KJ^7`(N@d=(FUM`?J4}9>C}GeCHf1_o0#xquxOu z8TVt6XXiXC@i+Bf?jPkG6ZV(>^3v}hy(xXbTGxp6f%A39@eSh<@_o~B80TEbD-Ajy zMxLMk4eo*BUOmonZKggW`YlJFgd@MVzeHap_r<_J`jd8&7n41u&%N{+;MahI!i&4s z_i}DidI4?^IUhm3yPPk<-$rl6-<9BY|L5>s`ghaz{5bU#={@&8uBIPC>nwH<9Mt&K zPsm%|uv}yJskd_P1$hzbX4E(7OQc`9_~F94M~O2<-&*dW(R#=~P!}g(Mc%UH_4GXb zu%-W2^OW;(dXDo4==+}Of<2#Aj>JI2L>!~W? 
zEcXafN2HH*GIbWwxAarO2gy&VpNh}qPdPtBKM?0o=!?dmrSGW7ukeb#P4Y?9bGZ-c zg8NZfH~RgZ=}SM>w*6}QCOLOk?mq($^@`tsJY2IJk-q~^N_-4%$e#hHxc9K+8{{7p zoWswT-duG5BYhf!zHghpCGr6Dt5SERf0=!e2jsjv_w;c-8?EQQr)lFBeah$~=jQRJ zxF5AVhnxzp$V-ve!=FNKXY4<5fA>`USkZ-h6FJw${g3E7buRjcu@l_iUGz}$N)aT;;q?f4}#b{cq~toO3Ss;mbbUGaf8C&H*Rr zKj9ofId>_0gTHl0Otz|zDgd* zpHk-?@I6oy!3){jQ~O=)=b`;K`WNXVCI8IdhuJiK6IXD)jynFO{9Os*%tia7)aQGS zhd577{l3(jz!~Bo@ddw{`V@AXd=d3H`aj#Ahrjbfy^r%r)alpb=SiN&O()I|;O~~d zA3&TTy9YjU?gRefcM)e(C!~*%d(nyiv4h0T^!G#WDdRJL4}o!xtq1foPMfdTbMk)N zKgD^=fykZo3VINyktgLI1H<82yku&B!oG$)=^mo#FxRSVte8qV`ko*$&kS`}cDLQRg{-FzXK+a7Oj}m{Aw*r4l9hrQa z#--0-IrJqzik=ex6W@uB+}lN5Phax9*EttIjCf1>HI_U#e8#flS7=@MPu#CgUa|BS zAm`#|vDb=MIJd!lZTvl<(&t6ILYz#VUHe5Z>60n`uhvPw2YvkG&ST&wfT!QZU%*d; z{+gfj1myF$Uky8oybAse{75;?e4SeRF-SYQp^c;DeA~)hceF)$k z=i=xmAudAysjt#^wBmKq&&|EK1?QoM=rd)%;ZFPx@}JTd>gLQ}{5$dSmiN#63H34Rcbv!O-rZ8SlAPX-z1BYH7cS>wrB^43KQ$ly zThvn~>2%qU= zo%Oun2K^}X1JfTvA1U>ld&v`vU)*zdKlp+^ihk5Hh->I;BrimtLz$1?!)N#|x>G+O zU&eWv{&TBi#sTmH|BAkulCKjU@^?sxW4Tw4`yuHwAg(WUT=WvW(fX+~vo7dH-vRk$ z{%!#GRS-YGZ{kbz51br{e#stk{{eWhVfp9$H}~2Rx8g4n|B=TaUZ5^JX_J&FMl^i^Wa~>56%tVGQX&sanC%@@pmVn2mNoz5r5Bc#PREc z@L#z1dF;5>fgU4AWOr7?g<%T}S!{EQh zFMm%__KSNTiBpK*H{ocgdFRrfFMZ>FQTl2)XUsWy@Pl*M zr7w;+m$-)bg3qNV#0T_y5GT=Z#OqSxXW}aCJGh8_*7vi