diff --git a/configs/sr/sr_telescope.yml b/configs/sr/sr_telescope.yml
new file mode 100644
index 0000000000000000000000000000000000000000..dc0b195ba6a7e5d3cf34f811bc6bf6e58fd98ea6
--- /dev/null
+++ b/configs/sr/sr_telescope.yml
@@ -0,0 +1,84 @@
+Global:
+ use_gpu: true
+ epoch_num: 100
+ log_smooth_window: 20
+ print_batch_step: 10
+ save_model_dir: ./output/sr/sr_telescope/
+ save_epoch_step: 3
+  # evaluation is run every 1000 iterations
+ eval_batch_step: [0, 1000]
+ cal_metric_during_train: False
+ pretrained_model:
+ checkpoints:
+ save_inference_dir: ./output/sr/sr_telescope/infer
+ use_visualdl: False
+ infer_img: doc/imgs_words_en/word_52.png
+ # for data or label process
+ character_dict_path:
+ max_text_length: 100
+ infer_mode: False
+ use_space_char: False
+ save_res_path: ./output/sr/predicts_telescope.txt
+
+Optimizer:
+ name: Adam
+ beta1: 0.5
+ beta2: 0.999
+ clip_norm: 0.25
+ lr:
+ learning_rate: 0.0001
+
+Architecture:
+ model_type: sr
+ algorithm: Telescope
+ Transform:
+ name: TBSRN
+ STN: True
+ infer_mode: False
+
+Loss:
+ name: TelescopeLoss
+ confuse_dict_path: ./ppocr/utils/dict/confuse.pkl
+
+
+PostProcess:
+ name: None
+
+Metric:
+ name: SRMetric
+ main_indicator: all
+
+Train:
+ dataset:
+ name: LMDBDataSetSR
+ data_dir: ./train_data/TextZoom/train
+ transforms:
+ - SRResize:
+ imgH: 32
+ imgW: 128
+ down_sample_scale: 2
+ - KeepKeys:
+ keep_keys: ['img_lr', 'img_hr', 'label'] # dataloader will return list in this order
+ loader:
+ shuffle: False
+ batch_size_per_card: 16
+ drop_last: True
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: LMDBDataSetSR
+ data_dir: ./train_data/TextZoom/test
+ transforms:
+ - SRResize:
+ imgH: 32
+ imgW: 128
+ down_sample_scale: 2
+ - KeepKeys:
+ keep_keys: ['img_lr', 'img_hr', 'label'] # dataloader will return list in this order
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 16
+ num_workers: 0
+
diff --git a/doc/doc_ch/algorithm_sr_telescope.md b/doc/doc_ch/algorithm_sr_telescope.md
new file mode 100644
index 0000000000000000000000000000000000000000..9a21734b6e84c5e856940f5b2482032864d5ce27
--- /dev/null
+++ b/doc/doc_ch/algorithm_sr_telescope.md
@@ -0,0 +1,128 @@
+# Text Telescope
+
+- [1. Introduction](#1)
+- [2. Environment](#2)
+- [3. Model Training / Evaluation / Prediction](#3)
+  - [3.1 Training](#3-1)
+  - [3.2 Evaluation](#3-2)
+  - [3.3 Prediction](#3-3)
+- [4. Inference and Deployment](#4)
+  - [4.1 Python Inference](#4-1)
+  - [4.2 C++ Inference](#4-2)
+  - [4.3 Serving](#4-3)
+  - [4.4 More](#4-4)
+- [5. FAQ](#5)
+
+
+## 1. Introduction
+
+Paper:
+> [Scene Text Telescope: Text-Focused Scene Image Super-Resolution](https://openaccess.thecvf.com/content/CVPR2021/papers/Chen_Scene_Text_Telescope_Text-Focused_Scene_Image_Super-Resolution_CVPR_2021_paper.pdf)
+
+> Chen, Jingye, Bin Li, and Xiangyang Xue
+
+> CVPR, 2021
+
+Referring to the [FudanOCR](https://github.com/FudanVI/FudanOCR/tree/main/scene-text-telescope) data download instructions, the super-resolution results on the TextZoom test set are as follows:
+
+|Model|Backbone|PSNR_Avg|SSIM_Avg|Config|Download link|
+|---|---|---|---|---|---|
+|Text Telescope|tbsrn|21.56|0.7411|[configs/sr/sr_telescope.yml](../../configs/sr/sr_telescope.yml)|[trained model](https://paddleocr.bj.bcebos.com/contribution/Telescope_train.tar.gz)|
+
+The [TextZoom dataset](https://paddleocr.bj.bcebos.com/dataset/TextZoom.tar) comes from two super-resolution datasets, RealSR and SR-RAW, both of which contain LR-HR pairs. TextZoom has 17,367 pairs of training data and 4,373 pairs of test data.
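+
+For reference, the `data_dir` values in [configs/sr/sr_telescope.yml](../../configs/sr/sr_telescope.yml) expect the extracted dataset under `./train_data/TextZoom/`. A quick check such as the following sketch (illustrative only, not part of the codebase) can confirm the layout before training:
+
+```python
+import os
+
+# Paths expected by Train.dataset.data_dir and Eval.dataset.data_dir in sr_telescope.yml
+for split in ("train", "test"):
+    path = os.path.join("train_data", "TextZoom", split)
+    print(path, "->", "found" if os.path.isdir(path) else "missing")
+```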
+
+
+## 2. Environment
+Please refer to ["Environment Preparation"](./environment.md) to configure the PaddleOCR environment, and refer to ["Project Clone"](./clone.md) to clone the project code.
+
+
+
+## 3. Model Training / Evaluation / Prediction
+
+Please refer to the [text recognition training tutorial](./recognition.md). PaddleOCR modularizes the code, and training a different model only requires **changing the configuration file**.
+
+- Training
+
+After the data preparation is complete, training can be started. The training commands are as follows:
+
+```
+# Single-GPU training (long training time, not recommended)
+python3 tools/train.py -c configs/sr/sr_telescope.yml
+
+# Multi-GPU training, specify the GPU IDs through the --gpus parameter
+python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/sr/sr_telescope.yml
+
+```
+
+- Evaluation
+
+```
+# GPU evaluation; Global.pretrained_model is the weights to be evaluated
+python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/sr/sr_telescope.yml -o Global.pretrained_model={path/to/weights}/best_accuracy
+```
+
+- Prediction
+
+```
+# The configuration file used for prediction must be the same as the one used for training
+python3 tools/infer_sr.py -c configs/sr/sr_telescope.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.infer_img=doc/imgs_words_en/word_52.png
+```
+
+![](../imgs_words_en/word_52.png)
+
+After running the command, the super-resolution result of the image above is as follows:
+
+![](../imgs_results/sr_word_52.png)
+
+
+## 4. Inference and Deployment
+
+
+### 4.1 Python Inference
+
+First, convert the model saved during text super-resolution training into an inference model. Taking the [model](https://paddleocr.bj.bcebos.com/contribution/Telescope_train.tar.gz) trained with Text-Telescope as an example, it can be converted with the following command:
+```shell
+python3 tools/export_model.py -c configs/sr/sr_telescope.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.save_inference_dir=./inference/sr_out
+```
+For Text-Telescope super-resolution model inference, the following command can be executed:
+```
+python3 tools/infer/predict_sr.py --sr_model_dir=./inference/sr_out --image_dir=doc/imgs_words_en/word_52.png --sr_image_shape=3,32,128
+
+```
+
+After running the command, the super-resolution result of the image is as follows:
+
+![](../imgs_results/sr_word_52.png)
+
+
+### 4.2 C++ Inference
+
+Not supported yet
+
+
+### 4.3 Serving
+
+Not supported yet
+
+
+### 4.4 More
+
+Not supported yet
+
+
+## 5. FAQ
+
+
+## Citation
+
+```bibtex
+@INPROCEEDINGS{9578891,
+ author={Chen, Jingye and Li, Bin and Xue, Xiangyang},
+ booktitle={2021 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
+ title={Scene Text Telescope: Text-Focused Scene Image Super-Resolution},
+ year={2021},
+ volume={},
+ number={},
+ pages={12021-12030},
+ doi={10.1109/CVPR46437.2021.01185}}
+```
diff --git a/doc/doc_en/algorithm_sr_telescope_en.md b/doc/doc_en/algorithm_sr_telescope_en.md
new file mode 100644
index 0000000000000000000000000000000000000000..89f3b373ea041aee33841c86727913c5523bc054
--- /dev/null
+++ b/doc/doc_en/algorithm_sr_telescope_en.md
@@ -0,0 +1,137 @@
+# Text Telescope
+
+- [1. Introduction](#1)
+- [2. Environment](#2)
+- [3. Model Training / Evaluation / Prediction](#3)
+ - [3.1 Training](#3-1)
+ - [3.2 Evaluation](#3-2)
+ - [3.3 Prediction](#3-3)
+- [4. Inference and Deployment](#4)
+ - [4.1 Python Inference](#4-1)
+ - [4.2 C++ Inference](#4-2)
+ - [4.3 Serving](#4-3)
+ - [4.4 More](#4-4)
+- [5. FAQ](#5)
+
+
+
+## 1. Introduction
+
+Paper:
+> [Scene Text Telescope: Text-Focused Scene Image Super-Resolution](https://openaccess.thecvf.com/content/CVPR2021/papers/Chen_Scene_Text_Telescope_Text-Focused_Scene_Image_Super-Resolution_CVPR_2021_paper.pdf)
+
+> Chen, Jingye, Bin Li, and Xiangyang Xue
+
+> CVPR, 2021
+
+Referring to the [FudanOCR](https://github.com/FudanVI/FudanOCR/tree/main/scene-text-telescope) data download instructions, the super-resolution results on the TextZoom test set are as follows:
+
+|Model|Backbone|PSNR_Avg|SSIM_Avg|Config|Download link|
+|---|---|---|---|---|---|
+|Text Telescope|tbsrn|21.56|0.7411|[configs/sr/sr_telescope.yml](../../configs/sr/sr_telescope.yml)|[trained model](https://paddleocr.bj.bcebos.com/contribution/Telescope_train.tar.gz)|
+
+The [TextZoom dataset](https://paddleocr.bj.bcebos.com/dataset/TextZoom.tar) comes from two super-resolution datasets, RealSR and SR-RAW, both of which contain LR-HR pairs. TextZoom has 17,367 pairs of training data and 4,373 pairs of test data.
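+
+The PSNR_Avg and SSIM_Avg values above are averaged image-quality scores between the super-resolved output and the HR ground truth (the `SRMetric` configured in [configs/sr/sr_telescope.yml](../../configs/sr/sr_telescope.yml) reports both). As a rough reference, PSNR follows the standard definition sketched below (illustrative only; this helper is not part of the codebase):
+
+```python
+import numpy as np
+
+def psnr(sr, hr, max_val=1.0):
+    """Peak signal-to-noise ratio between two images with values in [0, max_val]."""
+    mse = np.mean((sr.astype(np.float64) - hr.astype(np.float64)) ** 2)
+    if mse == 0:
+        return float("inf")
+    return 10 * np.log10(max_val ** 2 / mse)
+```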
+
+
+## 2. Environment
+Please refer to ["Environment Preparation"](./environment_en.md) to configure the PaddleOCR environment, and refer to ["Project Clone"](./clone_en.md) to clone the project code.
+
+
+
+## 3. Model Training / Evaluation / Prediction
+
+Please refer to [Text Recognition Tutorial](./recognition_en.md). PaddleOCR modularizes the code, and training different models only requires **changing the configuration file**.
+
+Training:
+
+After the data preparation is complete, training can be started. The training commands are as follows:
+
+```
+#Single GPU training (long training period, not recommended)
+
+python3 tools/train.py -c configs/sr/sr_telescope.yml
+
+#Multi GPU training, specify the gpu number through the --gpus parameter
+
+python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/sr/sr_telescope.yml
+
+```
+
+
+Evaluation:
+
+```
+# GPU evaluation
+python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/sr/sr_telescope.yml -o Global.pretrained_model={path/to/weights}/best_accuracy
+```
+
+Prediction:
+
+```
+# The configuration file used for prediction must match the training
+
+python3 tools/infer_sr.py -c configs/sr/sr_telescope.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.infer_img=doc/imgs_words_en/word_52.png
+```
+
+![](../imgs_words_en/word_52.png)
+
+After executing the command, the super-resolution result of the above image is as follows:
+
+![](../imgs_results/sr_word_52.png)
+
+
+## 4. Inference and Deployment
+
+
+### 4.1 Python Inference
+
+First, convert the model saved during training into an inference model. Taking the [trained model](https://paddleocr.bj.bcebos.com/contribution/Telescope_train.tar.gz) as an example, you can use the following command to convert it:
+
+```shell
+python3 tools/export_model.py -c configs/sr/sr_telescope.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.save_inference_dir=./inference/sr_out
+```
+
+For Text-Telescope super-resolution model inference, the following command can be executed:
+
+```
+python3 tools/infer/predict_sr.py --sr_model_dir=./inference/sr_out --image_dir=doc/imgs_words_en/word_52.png --sr_image_shape=3,32,128
+
+```
+
+After executing the command, the super-resolution result of the above image is as follows:
+
+![](../imgs_results/sr_word_52.png)
+
+
+
+### 4.2 C++ Inference
+
+Not supported
+
+
+### 4.3 Serving
+
+Not supported
+
+
+### 4.4 More
+
+Not supported
+
+
+## 5. FAQ
+
+
+## Citation
+
+```bibtex
+@INPROCEEDINGS{9578891,
+ author={Chen, Jingye and Li, Bin and Xue, Xiangyang},
+ booktitle={2021 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
+ title={Scene Text Telescope: Text-Focused Scene Image Super-Resolution},
+ year={2021},
+ volume={},
+ number={},
+ pages={12021-12030},
+ doi={10.1109/CVPR46437.2021.01185}}
+```
diff --git a/ppocr/losses/__init__.py b/ppocr/losses/__init__.py
index 6abaa408b3f6995a0b4c377206e8a1551b48c56b..46d6e81f29bb53fd5bcbfcef6ec183e7e69df76c 100755
--- a/ppocr/losses/__init__.py
+++ b/ppocr/losses/__init__.py
@@ -62,6 +62,7 @@ from .vqa_token_layoutlm_loss import VQASerTokenLayoutLMLoss
# sr loss
from .stroke_focus_loss import StrokeFocusLoss
+from .text_focus_loss import TelescopeLoss
def build_loss(config):
@@ -71,7 +72,7 @@ def build_loss(config):
'CELoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss',
'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss',
'TableMasterLoss', 'SPINAttentionLoss', 'VLLoss', 'StrokeFocusLoss',
-        'SLALoss', 'CTLoss', 'RFLLoss', 'DRRGLoss'
+        'SLALoss', 'CTLoss', 'RFLLoss', 'DRRGLoss', 'TelescopeLoss'
]
config = copy.deepcopy(config)
module_name = config.pop('name')
diff --git a/ppocr/losses/text_focus_loss.py b/ppocr/losses/text_focus_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..b50628405b58c8589719cafb8c0efdaa7db05aa5
--- /dev/null
+++ b/ppocr/losses/text_focus_loss.py
@@ -0,0 +1,91 @@
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This code is refer from:
+https://github.com/FudanVI/FudanOCR/blob/main/scene-text-telescope/loss/text_focus_loss.py
+"""
+
+import paddle.nn as nn
+import paddle
+import numpy as np
+import pickle as pkl
+
+standard_alphebet = '-0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
+standard_dict = {}
+for index in range(len(standard_alphebet)):
+ standard_dict[standard_alphebet[index]] = index
+
+
+def load_confuse_matrix(confuse_dict_path):
+    """Build the 37x37 weight table used by the confusion-weighted cross entropy.
+
+    The pickled data is ordered as digits, uppercase and lowercase (per the
+    slices below) and is rearranged to match standard_alphebet.
+    """
+    with open(confuse_dict_path, 'rb') as f:
+        data = pkl.load(f)
+    number = data[:10]
+    upper = data[10:36]
+    lower = data[36:]
+    # prepend a row and a column of ones for the end token '-' (index 0 in standard_alphebet)
+    end = np.ones((1, 62))
+    pad = np.ones((63, 1))
+    rearrange_data = np.concatenate((end, number, lower, upper), axis=0)
+    rearrange_data = np.concatenate((pad, rearrange_data), axis=1)
+    # reciprocal turns confusion values into weights; zero entries (1 / 0 = inf) are reset to 1
+    rearrange_data = 1 / rearrange_data
+    rearrange_data[rearrange_data == np.inf] = 1
+    rearrange_data = paddle.to_tensor(rearrange_data)
+
+    lower_alpha = 'abcdefghijklmnopqrstuvwxyz'
+    # fold the uppercase columns into the corresponding lowercase columns (labels are case-insensitive)
+    for i in range(63):
+        for j in range(63):
+            if i != j and standard_alphebet[j] in lower_alpha:
+                rearrange_data[i][j] = max(rearrange_data[i][j],
+                                           rearrange_data[i][j + 26])
+    # keep only '-' + digits + lowercase letters (37 classes)
+    rearrange_data = rearrange_data[:37, :37]
+
+    return rearrange_data
+
+
+def weight_cross_entropy(pred, gt, weight_table):
+    """Cross entropy over the recognizer logits, where every class is re-weighted
+    by the confusion-based weights of its ground-truth character."""
+    batch = gt.shape[0]
+    # per-sample weight row selected by the ground-truth character index
+    weight = weight_table[gt]
+    pred_exp = paddle.exp(pred)
+    pred_exp_weight = weight * pred_exp
+    loss = 0
+    for i in range(len(gt)):
+        # weighted softmax probability of the ground-truth class, accumulated as negative log-likelihood
+        loss -= paddle.log(pred_exp_weight[i][gt[i]] / paddle.sum(pred_exp_weight, 1)[i])
+    return loss / batch
+
+
+class TelescopeLoss(nn.Layer):
+ def __init__(self, confuse_dict_path):
+ super(TelescopeLoss, self).__init__()
+ self.weight_table = load_confuse_matrix(confuse_dict_path)
+ self.mse_loss = nn.MSELoss()
+ self.ce_loss = nn.CrossEntropyLoss()
+ self.l1_loss = nn.L1Loss()
+
+ def forward(self, pred, data):
+ sr_img = pred["sr_img"]
+ hr_img = pred["hr_img"]
+ sr_pred = pred["sr_pred"]
+ text_gt = pred["text_gt"]
+
+ word_attention_map_gt = pred["word_attention_map_gt"]
+ word_attention_map_pred = pred["word_attention_map_pred"]
+        # pixel-level reconstruction loss between the SR output and the HR ground truth
+        mse_loss = self.mse_loss(sr_img, hr_img)
+        # position-aware term: L1 distance between the recognizer attention maps on HR and SR images
+        attention_loss = self.l1_loss(word_attention_map_gt, word_attention_map_pred)
+        # content-aware term: confusion-weighted cross entropy on the recognizer predictions for the SR image
+        recognition_loss = weight_cross_entropy(sr_pred, text_gt, self.weight_table)
+        loss = mse_loss + attention_loss * 10 + recognition_loss * 0.0005
+ return {
+ "mse_loss": mse_loss,
+ "attention_loss": attention_loss,
+ "loss": loss
+ }
diff --git a/ppocr/modeling/heads/sr_rensnet_transformer.py b/ppocr/modeling/heads/sr_rensnet_transformer.py
index a004a12663ac2061a329236c58e147a017c80ba6..654f3fca5486229c176246237708c4cf6a8da9ec 100644
--- a/ppocr/modeling/heads/sr_rensnet_transformer.py
+++ b/ppocr/modeling/heads/sr_rensnet_transformer.py
@@ -15,18 +15,12 @@
This code is refer from:
https://github.com/FudanVI/FudanOCR/blob/main/text-gestalt/loss/transformer_english_decomposition.py
"""
+import copy
+import math
+
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
-import math, copy
-import numpy as np
-
-# stroke-level alphabet
-alphabet = '0123456789'
-
-
-def get_alphabet_len():
- return len(alphabet)
def subsequent_mask(size):
@@ -373,10 +367,10 @@ class Encoder(nn.Layer):
class Transformer(nn.Layer):
- def __init__(self, in_channels=1):
+ def __init__(self, in_channels=1, alphabet='0123456789'):
super(Transformer, self).__init__()
-
- word_n_class = get_alphabet_len()
+ self.alphabet = alphabet
+ word_n_class = self.get_alphabet_len()
self.embedding_word_with_upperword = Embeddings(512, word_n_class)
self.pe = PositionalEncoding(dim=512, dropout=0.1, max_len=5000)
@@ -388,6 +382,9 @@ class Transformer(nn.Layer):
if p.dim() > 1:
nn.initializer.XavierNormal(p)
+ def get_alphabet_len(self):
+ return len(self.alphabet)
+
def forward(self, image, text_length, text_input, attention_map=None):
if image.shape[1] == 3:
R = image[:, 0:1, :, :]
@@ -415,7 +412,7 @@ class Transformer(nn.Layer):
if self.training:
total_length = paddle.sum(text_length)
- probs_res = paddle.zeros([total_length, get_alphabet_len()])
+ probs_res = paddle.zeros([total_length, self.get_alphabet_len()])
start = 0
for index, length in enumerate(text_length):
diff --git a/ppocr/modeling/transforms/__init__.py b/ppocr/modeling/transforms/__init__.py
index b22c60bb3d5e1933056d37bad208f4c311139c8e..022ece60a56131a25049547a64bdaf9f94c0e69c 100755
--- a/ppocr/modeling/transforms/__init__.py
+++ b/ppocr/modeling/transforms/__init__.py
@@ -19,9 +19,10 @@ def build_transform(config):
from .tps import TPS
from .stn import STN_ON
from .tsrn import TSRN
+ from .tbsrn import TBSRN
from .gaspin_transformer import GA_SPIN_Transformer as GA_SPIN
- support_dict = ['TPS', 'STN_ON', 'GA_SPIN', 'TSRN']
+ support_dict = ['TPS', 'STN_ON', 'GA_SPIN', 'TSRN', 'TBSRN']
module_name = config.pop('name')
assert module_name in support_dict, Exception(
diff --git a/ppocr/modeling/transforms/tbsrn.py b/ppocr/modeling/transforms/tbsrn.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee119003600b0515feb6fd1049e2c91565528b7d
--- /dev/null
+++ b/ppocr/modeling/transforms/tbsrn.py
@@ -0,0 +1,264 @@
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This code is refer from:
+https://github.com/FudanVI/FudanOCR/blob/main/scene-text-telescope/model/tbsrn.py
+"""
+
+import math
+import warnings
+import numpy as np
+import paddle
+from paddle import nn
+import string
+
+warnings.filterwarnings("ignore")
+
+from .tps_spatial_transformer import TPSSpatialTransformer
+from .stn import STN as STNHead
+from .tsrn import GruBlock, mish, UpsampleBLock
+from ppocr.modeling.heads.sr_rensnet_transformer import Transformer, LayerNorm, \
+ PositionwiseFeedForward, MultiHeadedAttention
+
+
+def positionalencoding2d(d_model, height, width):
+ """
+ :param d_model: dimension of the model
+ :param height: height of the positions
+ :param width: width of the positions
+ :return: d_model*height*width position matrix
+ """
+ if d_model % 4 != 0:
+ raise ValueError("Cannot use sin/cos positional encoding with "
+ "odd dimension (got dim={:d})".format(d_model))
+ pe = paddle.zeros([d_model, height, width])
+ # Each dimension use half of d_model
+ d_model = int(d_model / 2)
+ div_term = paddle.exp(paddle.arange(0., d_model, 2) *
+ -(math.log(10000.0) / d_model))
+ pos_w = paddle.arange(0., width, dtype='float32').unsqueeze(1)
+ pos_h = paddle.arange(0., height, dtype='float32').unsqueeze(1)
+
+ pe[0:d_model:2, :, :] = paddle.sin(pos_w * div_term).transpose([1, 0]).unsqueeze(1).tile([1, height, 1])
+ pe[1:d_model:2, :, :] = paddle.cos(pos_w * div_term).transpose([1, 0]).unsqueeze(1).tile([1, height, 1])
+ pe[d_model::2, :, :] = paddle.sin(pos_h * div_term).transpose([1, 0]).unsqueeze(2).tile([1, 1, width])
+ pe[d_model + 1::2, :, :] = paddle.cos(pos_h * div_term).transpose([1, 0]).unsqueeze(2).tile([1, 1, width])
+
+ return pe
+
+
+class FeatureEnhancer(nn.Layer):
+
+ def __init__(self):
+ super(FeatureEnhancer, self).__init__()
+
+ self.multihead = MultiHeadedAttention(h=4, d_model=128, dropout=0.1)
+ self.mul_layernorm1 = LayerNorm(features=128)
+
+ self.pff = PositionwiseFeedForward(128, 128)
+ self.mul_layernorm3 = LayerNorm(features=128)
+
+ self.linear = nn.Linear(128, 64)
+
+    def forward(self, conv_feature):
+        '''
+        conv_feature: (batch, channel=64, H*W=1024) flattened feature map
+        returns: (batch, 64, 1024) feature enhanced by self-attention
+        '''
+        batch = conv_feature.shape[0]
+        # 2D positional encoding for a 16x64 feature map, flattened to (1, 64, 1024)
+        position2d = positionalencoding2d(64, 16, 64).cast('float32').unsqueeze(0).reshape([1, 64, 1024])
+        position2d = position2d.tile([batch, 1, 1])
+        conv_feature = paddle.concat([conv_feature, position2d], 1)  # (batch, 128 = 64 + 64, 1024)
+        result = conv_feature.transpose([0, 2, 1])
+        origin_result = result
+        result = self.mul_layernorm1(origin_result + self.multihead(result, result, result, mask=None)[0])
+        origin_result = result
+        result = self.mul_layernorm3(origin_result + self.pff(result))
+        result = self.linear(result)
+        return result.transpose([0, 2, 1])
+
+
+def str_filt(str_, voc_type):
+ alpha_dict = {
+ 'digit': string.digits,
+ 'lower': string.digits + string.ascii_lowercase,
+ 'upper': string.digits + string.ascii_letters,
+ 'all': string.digits + string.ascii_letters + string.punctuation
+ }
+ if voc_type == 'lower':
+ str_ = str_.lower()
+ for char in str_:
+ if char not in alpha_dict[voc_type]:
+ str_ = str_.replace(char, '')
+ str_ = str_.lower()
+ return str_
+
+
+class TBSRN(nn.Layer):
+ def __init__(self,
+ in_channels=3,
+ scale_factor=2,
+ width=128,
+ height=32,
+ STN=True,
+ srb_nums=5,
+ mask=False,
+ hidden_units=32,
+ infer_mode=False):
+ super(TBSRN, self).__init__()
+ in_planes = 3
+ if mask:
+ in_planes = 4
+ assert math.log(scale_factor, 2) % 1 == 0
+ upsample_block_num = int(math.log(scale_factor, 2))
+ self.block1 = nn.Sequential(
+ nn.Conv2D(in_planes, 2 * hidden_units, kernel_size=9, padding=4),
+ nn.PReLU()
+ # nn.ReLU()
+ )
+ self.srb_nums = srb_nums
+ for i in range(srb_nums):
+ setattr(self, 'block%d' % (i + 2), RecurrentResidualBlock(2 * hidden_units))
+
+ setattr(self, 'block%d' % (srb_nums + 2),
+ nn.Sequential(
+ nn.Conv2D(2 * hidden_units, 2 * hidden_units, kernel_size=3, padding=1),
+ nn.BatchNorm2D(2 * hidden_units)
+ ))
+
+ # self.non_local = NonLocalBlock2D(64, 64)
+ block_ = [UpsampleBLock(2 * hidden_units, 2) for _ in range(upsample_block_num)]
+ block_.append(nn.Conv2D(2 * hidden_units, in_planes, kernel_size=9, padding=4))
+ setattr(self, 'block%d' % (srb_nums + 3), nn.Sequential(*block_))
+ self.tps_inputsize = [height // scale_factor, width // scale_factor]
+ tps_outputsize = [height // scale_factor, width // scale_factor]
+ num_control_points = 20
+ tps_margins = [0.05, 0.05]
+ self.stn = STN
+ self.out_channels = in_channels
+ if self.stn:
+ self.tps = TPSSpatialTransformer(
+ output_image_size=tuple(tps_outputsize),
+ num_control_points=num_control_points,
+ margins=tuple(tps_margins))
+
+ self.stn_head = STNHead(
+ in_channels=in_planes,
+ num_ctrlpoints=num_control_points,
+ activation='none')
+ self.infer_mode = infer_mode
+
+ self.english_alphabet = '-0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
+ self.english_dict = {}
+ for index in range(len(self.english_alphabet)):
+ self.english_dict[self.english_alphabet[index]] = index
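+        # Frozen auxiliary text recognizer over '-' + digits + lowercase letters (37 classes).
+        # It only supplies attention maps and character predictions for TelescopeLoss
+        # and is never updated during TBSRN training.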
+ transformer = Transformer(alphabet='-0123456789abcdefghijklmnopqrstuvwxyz')
+ self.transformer = transformer
+ for param in self.transformer.parameters():
+ param.trainable = False
+
+ def label_encoder(self, label):
+ batch = len(label)
+
+ length = [len(i) for i in label]
+ length_tensor = paddle.to_tensor(length, dtype='int64')
+
+ max_length = max(length)
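+        # decoder input: position 0 is left as 0 ('-' acts as the start token),
+        # and the label characters are shifted one position to the right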
+ input_tensor = np.zeros((batch, max_length))
+ for i in range(batch):
+ for j in range(length[i] - 1):
+ input_tensor[i][j + 1] = self.english_dict[label[i][j]]
+
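+        # flattened ground-truth character indices (one entry per character, including
+        # the trailing '-' end token) used by the recognition loss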
+ text_gt = []
+ for i in label:
+ for j in i:
+ text_gt.append(self.english_dict[j])
+ text_gt = paddle.to_tensor(text_gt, dtype='int64')
+
+ input_tensor = paddle.to_tensor(input_tensor, dtype='int64')
+ return length_tensor, input_tensor, text_gt
+
+ def forward(self, x):
+ output = {}
+ if self.infer_mode:
+ output["lr_img"] = x
+ y = x
+ else:
+ output["lr_img"] = x[0]
+ output["hr_img"] = x[1]
+ y = x[0]
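+        # TPS-based rectification (STN) is applied to the LR input only during training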
+ if self.stn and self.training:
+ _, ctrl_points_x = self.stn_head(y)
+ y, _ = self.tps(y, ctrl_points_x)
+ block = {'1': self.block1(y)}
+ for i in range(self.srb_nums + 1):
+ block[str(i + 2)] = getattr(self,
+ 'block%d' % (i + 2))(block[str(i + 1)])
+
+ block[str(self.srb_nums + 3)] = getattr(self, 'block%d' % (self.srb_nums + 3)) \
+ ((block['1'] + block[str(self.srb_nums + 2)]))
+
+ sr_img = paddle.tanh(block[str(self.srb_nums + 3)])
+ output["sr_img"] = sr_img
+
+ if self.training:
+ hr_img = x[1]
+
+ # add transformer
+ label = [str_filt(i, 'lower') + '-' for i in x[2]]
+ length_tensor, input_tensor, text_gt = self.label_encoder(label)
+ hr_pred, word_attention_map_gt, hr_correct_list = self.transformer(hr_img, length_tensor,
+ input_tensor)
+ sr_pred, word_attention_map_pred, sr_correct_list = self.transformer(sr_img, length_tensor,
+ input_tensor)
+ output["hr_img"] = hr_img
+ output["hr_pred"] = hr_pred
+ output["text_gt"] = text_gt
+ output["word_attention_map_gt"] = word_attention_map_gt
+ output["sr_pred"] = sr_pred
+ output["word_attention_map_pred"] = word_attention_map_pred
+
+ return output
+
+
+class RecurrentResidualBlock(nn.Layer):
+ def __init__(self, channels):
+ super(RecurrentResidualBlock, self).__init__()
+ self.conv1 = nn.Conv2D(channels, channels, kernel_size=3, padding=1)
+ self.bn1 = nn.BatchNorm2D(channels)
+ self.gru1 = GruBlock(channels, channels)
+ # self.prelu = nn.ReLU()
+ self.prelu = mish()
+ self.conv2 = nn.Conv2D(channels, channels, kernel_size=3, padding=1)
+ self.bn2 = nn.BatchNorm2D(channels)
+ self.gru2 = GruBlock(channels, channels)
+ self.feature_enhancer = FeatureEnhancer()
+
+ for p in self.parameters():
+ if p.dim() > 1:
+ paddle.nn.initializer.XavierUniform(p)
+
+ def forward(self, x):
+ residual = self.conv1(x)
+ residual = self.bn1(residual)
+ residual = self.prelu(residual)
+ residual = self.conv2(residual)
+ residual = self.bn2(residual)
+
+ size = residual.shape
+ residual = residual.reshape([size[0], size[1], -1])
+ residual = self.feature_enhancer(residual)
+ residual = residual.reshape([size[0], size[1], size[2], size[3]])
+ return x + residual
\ No newline at end of file
diff --git a/ppocr/utils/dict/confuse.pkl b/ppocr/utils/dict/confuse.pkl
new file mode 100644
index 0000000000000000000000000000000000000000..e5d485320bcc94b36fa4fd653644f07d1a974369
Binary files /dev/null and b/ppocr/utils/dict/confuse.pkl differ