Commit 0cdfc525 authored by 风为何不回来

add sr model Text Telescope

Parent 8babfc86
Global:
use_gpu: true
epoch_num: 100
log_smooth_window: 20
print_batch_step: 10
save_model_dir: ./output/sr/sr_telescope/
save_epoch_step: 3
  # evaluation is run every 1000 iterations
eval_batch_step: [0, 1000]
cal_metric_during_train: False
pretrained_model:
checkpoints:
save_inference_dir: ./output/sr/sr_telescope/infer
use_visualdl: False
infer_img: doc/imgs_words_en/word_52.png
# for data or label process
character_dict_path:
max_text_length: 100
infer_mode: False
use_space_char: False
save_res_path: ./output/sr/predicts_telescope.txt
Optimizer:
name: Adam
beta1: 0.5
beta2: 0.999
clip_norm: 0.25
lr:
learning_rate: 0.0001
Architecture:
model_type: sr
algorithm: Telescope
Transform:
name: TBSRN
STN: True
infer_mode: False
Loss:
name: TelescopeLoss
confuse_dict_path: ./ppocr/utils/dict/confuse.pkl
PostProcess:
name: None
Metric:
name: SRMetric
main_indicator: all
Train:
dataset:
name: LMDBDataSetSR
data_dir: ./train_data/TextZoom/train
transforms:
- SRResize:
imgH: 32
imgW: 128
down_sample_scale: 2
- KeepKeys:
keep_keys: ['img_lr', 'img_hr', 'label'] # dataloader will return list in this order
loader:
shuffle: False
batch_size_per_card: 16
drop_last: True
num_workers: 4
Eval:
dataset:
name: LMDBDataSetSR
data_dir: ./train_data/TextZoom/test
transforms:
- SRResize:
imgH: 32
imgW: 128
down_sample_scale: 2
- KeepKeys:
keep_keys: ['img_lr', 'img_hr', 'label'] # dataloader will return list in this order
loader:
shuffle: False
drop_last: False
batch_size_per_card: 16
num_workers: 0
# Text Telescope
- [1. Introduction](#1)
- [2. Environment](#2)
- [3. Model Training, Evaluation, and Prediction](#3)
  - [3.1 Training](#3-1)
  - [3.2 Evaluation](#3-2)
  - [3.3 Prediction](#3-3)
- [4. Inference and Deployment](#4)
  - [4.1 Python Inference](#4-1)
  - [4.2 C++ Inference](#4-2)
  - [4.3 Serving Deployment](#4-3)
  - [4.4 More Deployment Options](#4-4)
- [5. FAQ](#5)
<a name="1"></a>
## 1. Introduction
Paper:
> [Scene Text Telescope: Text-Focused Scene Image Super-Resolution](https://openaccess.thecvf.com/content/CVPR2021/papers/Chen_Scene_Text_Telescope_Text-Focused_Scene_Image_Super-Resolution_CVPR_2021_paper.pdf)
> Chen, Jingye, Bin Li, and Xiangyang Xue
> CVPR, 2021
Referring to the [FudanOCR](https://github.com/FudanVI/FudanOCR/tree/main/scene-text-telescope) data download instructions, the super-resolution algorithm performs as follows on the TextZoom test set:
|Model|Backbone|PSNR_Avg|SSIM_Avg|Config|Download link|
|---|---|---|---|---|---|
|Text Telescope|tbsrn|21.56|0.7411|[configs/sr/sr_telescope.yml](../../configs/sr/sr_telescope.yml)|[trained model](https://paddleocr.bj.bcebos.com/contribution/Telescope_train.tar.gz)|
The [TextZoom dataset](https://paddleocr.bj.bcebos.com/dataset/TextZoom.tar) is derived from two super-resolution datasets, RealSR and SR-RAW; both contain LR-HR pairs. TextZoom provides 17,367 training pairs and 4,373 test pairs.
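Both splits are read by `LMDBDataSetSR`. As a quick sanity check, a record can be inspected directly; the following is a minimal sketch, assuming the standard TextZoom key layout (`num-samples`, `image_hr-%09d`, `image_lr-%09d`, `label-%09d`) and that the `lmdb` and `Pillow` Python packages are installed:
```python
import io
import lmdb
from PIL import Image

# Open the TextZoom training LMDB read-only (path as in the config above).
env = lmdb.open('./train_data/TextZoom/train', readonly=True, lock=False)
with env.begin(write=False) as txn:
    num_samples = int(txn.get(b'num-samples'))
    hr = Image.open(io.BytesIO(txn.get(b'image_hr-%09d' % 1)))
    lr = Image.open(io.BytesIO(txn.get(b'image_lr-%09d' % 1)))
    label = txn.get(b'label-%09d' % 1).decode()
    print(num_samples, hr.size, lr.size, label)
```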
<a name="2"></a>
## 2. Environment
Please refer to [Environment Preparation](./environment.md) to configure the PaddleOCR environment, and to [Project Clone](./clone.md) to clone the project code.
<a name="3"></a>
## 3. Model Training, Evaluation, and Prediction
Please refer to the [text recognition training tutorial](./recognition.md). PaddleOCR modularizes the code, so training a different recognition model only requires **switching the configuration file**.
- Training
After the data preparation is complete, start training with the following commands:
```
# Single-GPU training (long training period, not recommended)
python3 tools/train.py -c configs/sr/sr_telescope.yml
# Multi-GPU training; specify the GPU ids with the --gpus parameter
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/sr/sr_telescope.yml
```
- Evaluation
```
# GPU evaluation; Global.pretrained_model points to the weights to be evaluated
python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/sr/sr_telescope.yml -o Global.pretrained_model={path/to/weights}/best_accuracy
```
- Prediction:
```
# The configuration file used for prediction must match the one used for training
python3 tools/infer_sr.py -c configs/sr/sr_telescope.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.infer_img=doc/imgs_words_en/word_52.png
```
![](../imgs_words_en/word_52.png)
After running the command, the super-resolution result for the image above is:
![](../imgs_results/sr_word_52.png)
<a name="4"></a>
## 4. Inference and Deployment
<a name="4-1"></a>
### 4.1 Python Inference
First, convert the model saved during text super-resolution training into an inference model. Taking the [model](https://paddleocr.bj.bcebos.com/contribution/Telescope_train.tar.gz) trained with Text-Telescope as an example, it can be converted as follows:
```shell
python3 tools/export_model.py -c configs/sr/sr_telescope.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.save_inference_dir=./inference/sr_out
```
To run inference with the Text-Telescope super-resolution model, execute:
```
python3 tools/infer/predict_sr.py --sr_model_dir=./inference/sr_out --image_dir=doc/imgs_words_en/word_52.png --sr_image_shape=3,32,128
```
After running the command, the super-resolution result for the image is:
![](../imgs_results/sr_word_52.png)
<a name="4-2"></a>
### 4.2 C++ Inference
Not supported yet
<a name="4-3"></a>
### 4.3 Serving Deployment
Not supported yet
<a name="4-4"></a>
### 4.4 More Deployment Options
Not supported yet
<a name="5"></a>
## 5. FAQ
## Citation
```bibtex
@INPROCEEDINGS{9578891,
author={Chen, Jingye and Li, Bin and Xue, Xiangyang},
booktitle={2021 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
title={Scene Text Telescope: Text-Focused Scene Image Super-Resolution},
year={2021},
volume={},
number={},
pages={12021-12030},
doi={10.1109/CVPR46437.2021.01185}}
```
# Text Telescope
- [1. Introduction](#1)
- [2. Environment](#2)
- [3. Model Training / Evaluation / Prediction](#3)
- [3.1 Training](#3-1)
- [3.2 Evaluation](#3-2)
- [3.3 Prediction](#3-3)
- [4. Inference and Deployment](#4)
- [4.1 Python Inference](#4-1)
- [4.2 C++ Inference](#4-2)
- [4.3 Serving](#4-3)
- [4.4 More](#4-4)
- [5. FAQ](#5)
<a name="1"></a>
## 1. Introduction
Paper:
> [Scene Text Telescope: Text-Focused Scene Image Super-Resolution](https://openaccess.thecvf.com/content/CVPR2021/papers/Chen_Scene_Text_Telescope_Text-Focused_Scene_Image_Super-Resolution_CVPR_2021_paper.pdf)
> Chen, Jingye, Bin Li, and Xiangyang Xue
> CVPR, 2021
Referring to the [FudanOCR](https://github.com/FudanVI/FudanOCR/tree/main/scene-text-telescope) data download instructions, the super-resolution algorithm performs as follows on the TextZoom test set:
|Model|Backbone|PSNR_Avg|SSIM_Avg|Config|Download link|
|---|---|---|---|---|---|
|Text Telescope|tbsrn|21.56|0.7411|[configs/sr/sr_telescope.yml](../../configs/sr/sr_telescope.yml)|[trained model](https://paddleocr.bj.bcebos.com/contribution/Telescope_train.tar.gz)|
The [TextZoom dataset](https://paddleocr.bj.bcebos.com/dataset/TextZoom.tar) comes from two super-resolution datasets, RealSR and SR-RAW; both contain LR-HR pairs. TextZoom provides 17,367 training pairs and 4,373 test pairs.
<a name="2"></a>
## 2. Environment
Please refer to ["Environment Preparation"](./environment_en.md) to configure the PaddleOCR environment, and refer to ["Project Clone"](./clone_en.md) to clone the project code.
<a name="3"></a>
## 3. Model Training / Evaluation / Prediction
Please refer to [Text Recognition Tutorial](./recognition_en.md). PaddleOCR modularizes the code, and training different models only requires **changing the configuration file**.
Training:
After the data preparation is complete, the training can be started with the following commands:
```
# Single-GPU training (long training period, not recommended)
python3 tools/train.py -c configs/sr/sr_telescope.yml
# Multi-GPU training; specify the GPU ids with the --gpus parameter
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/sr/sr_telescope.yml
```
Evaluation:
```
# GPU evaluation; Global.pretrained_model points to the weights to be evaluated
python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/sr/sr_telescope.yml -o Global.pretrained_model={path/to/weights}/best_accuracy
```
Prediction:
```
# The configuration file used for prediction must match the one used for training
python3 tools/infer_sr.py -c configs/sr/sr_telescope.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.infer_img=doc/imgs_words_en/word_52.png
```
![](../imgs_words_en/word_52.png)
After executing the command, the super-resolution result of the above image is as follows:
![](../imgs_results/sr_word_52.png)
<a name="4"></a>
## 4. Inference and Deployment
<a name="4-1"></a>
### 4.1 Python Inference
First, convert the model saved during training into an inference model. Taking the [trained model](https://paddleocr.bj.bcebos.com/contribution/Telescope_train.tar.gz) as an example, it can be converted as follows:
```shell
python3 tools/export_model.py -c configs/sr/sr_telescope.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.save_inference_dir=./inference/sr_out
```
To run inference with the Text-Telescope super-resolution model, execute:
```
python3 tools/infer/predict_sr.py --sr_model_dir=./inference/sr_out --image_dir=doc/imgs_words_en/word_52.png --sr_image_shape=3,32,128
```
After executing the command, the super-resolution result of the above image is as follows:
![](../imgs_results/sr_word_52.png)
<a name="4-2"></a>
### 4.2 C++ Inference
Not supported
<a name="4-3"></a>
### 4.3 Serving
Not supported
<a name="4-4"></a>
### 4.4 More
Not supported
<a name="5"></a>
## 5. FAQ
## Citation
```bibtex
@INPROCEEDINGS{9578891,
author={Chen, Jingye and Li, Bin and Xue, Xiangyang},
booktitle={2021 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
title={Scene Text Telescope: Text-Focused Scene Image Super-Resolution},
year={2021},
volume={},
number={},
pages={12021-12030},
doi={10.1109/CVPR46437.2021.01185}}
```
@@ -25,8 +25,6 @@ from .det_east_loss import EASTLoss
 from .det_sast_loss import SASTLoss
 from .det_pse_loss import PSELoss
 from .det_fce_loss import FCELoss
-from .det_ct_loss import CTLoss
-from .det_drrg_loss import DRRGLoss

 # rec loss
 from .rec_ctc_loss import CTCLoss
@@ -39,7 +37,6 @@ from .rec_pren_loss import PRENLoss
 from .rec_multi_loss import MultiLoss
 from .rec_vl_loss import VLLoss
 from .rec_spin_att_loss import SPINAttentionLoss
-from .rec_rfl_loss import RFLLoss

 # cls loss
 from .cls_loss import ClsLoss
@@ -62,6 +59,7 @@ from .vqa_token_layoutlm_loss import VQASerTokenLayoutLMLoss
 # sr loss
 from .stroke_focus_loss import StrokeFocusLoss
+from .text_focus_loss import TelescopeLoss


 def build_loss(config):
@@ -71,7 +69,7 @@ def build_loss(config):
         'CELoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss',
         'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss',
         'TableMasterLoss', 'SPINAttentionLoss', 'VLLoss', 'StrokeFocusLoss',
-        'SLALoss', 'CTLoss', 'RFLLoss', 'DRRGLoss'
+        'SLALoss', 'TelescopeLoss'
     ]
     config = copy.deepcopy(config)
     module_name = config.pop('name')
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/FudanVI/FudanOCR/blob/main/scene-text-telescope/loss/text_focus_loss.py
"""
import paddle.nn as nn
import paddle
import numpy as np
import pickle as pkl
standard_alphebet = '-0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
standard_dict = {}
for index in range(len(standard_alphebet)):
standard_dict[standard_alphebet[index]] = index
def load_confuse_matrix(confuse_dict_path):
    # Load the character confusion statistics used to weight the recognition loss.
    with open(confuse_dict_path, 'rb') as f:
        data = pkl.load(f)
    number = data[:10]
    upper = data[10:36]
    lower = data[36:]
    end = np.ones((1, 62))
    pad = np.ones((63, 1))
    # Rearrange the rows into the '-0-9a-zA-Z' order of standard_alphebet,
    # with an all-ones row/column prepended for the '-' token.
    rearrange_data = np.concatenate((end, number, lower, upper), axis=0)
    rearrange_data = np.concatenate((pad, rearrange_data), axis=1)
    # Take reciprocals of the statistics; inf from zero entries is reset to 1.
    rearrange_data = 1 / rearrange_data
    rearrange_data[rearrange_data == np.inf] = 1
    rearrange_data = paddle.to_tensor(rearrange_data)

    lower_alpha = 'abcdefghijklmnopqrstuvwxyz'
    # upper_alpha = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
    # Merge each upper-case column into its lower-case counterpart (keeping the
    # larger weight), then crop to the case-insensitive 37x37 table
    # ('-' + digits + lower-case letters).
    for i in range(63):
        for j in range(63):
            if i != j and standard_alphebet[j] in lower_alpha:
                rearrange_data[i][j] = max(rearrange_data[i][j],
                                           rearrange_data[i][j + 26])
    rearrange_data = rearrange_data[:37, :37]
    return rearrange_data
def weight_cross_entropy(pred, gt, weight_table):
    # pred: [N, C] logits; gt: [N] labels; weight_table: [C, C] confusion weights.
    # For sample i with ground truth y_i this computes
    #   -log(w[y_i][y_i] * exp(pred[i][y_i]) / sum_j(w[y_i][j] * exp(pred[i][j]))),
    # i.e. a softmax cross-entropy whose terms are re-weighted by the
    # ground-truth character's confusion row.
    batch = gt.shape[0]
    weight = weight_table[gt]
    pred_exp = paddle.exp(pred)
    pred_exp_weight = weight * pred_exp
    loss = 0
    for i in range(len(gt)):
        loss -= paddle.log(pred_exp_weight[i][gt[i]] /
                           paddle.sum(pred_exp_weight, 1)[i])
    return loss / batch
class TelescopeLoss(nn.Layer):
def __init__(self, confuse_dict_path):
super(TelescopeLoss, self).__init__()
self.weight_table = load_confuse_matrix(confuse_dict_path)
self.mse_loss = nn.MSELoss()
self.ce_loss = nn.CrossEntropyLoss()
self.l1_loss = nn.L1Loss()
def forward(self, pred, data):
sr_img = pred["sr_img"]
hr_img = pred["hr_img"]
sr_pred = pred["sr_pred"]
text_gt = pred["text_gt"]
word_attention_map_gt = pred["word_attention_map_gt"]
word_attention_map_pred = pred["word_attention_map_pred"]
mse_loss = self.mse_loss(sr_img, hr_img)
attention_loss = self.l1_loss(word_attention_map_gt, word_attention_map_pred)
recognition_loss = weight_cross_entropy(sr_pred, text_gt, self.weight_table)
loss = mse_loss + attention_loss * 10 + recognition_loss * 0.0005
return {
"mse_loss": mse_loss,
"attention_loss": attention_loss,
"loss": loss
}
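

# Minimal usage sketch of the confusion-weighted cross-entropy above
# (illustrative only, not part of the training pipeline). A uniform dummy
# table reduces it to plain softmax cross-entropy; the real 37x37 table is
# produced by load_confuse_matrix from confuse.pkl.
if __name__ == "__main__":
    fake_logits = paddle.rand([4, 37])  # [batch, num_classes]
    fake_labels = paddle.to_tensor([1, 5, 9, 0], dtype='int64')
    uniform_table = paddle.ones([37, 37])
    print(weight_cross_entropy(fake_logits, fake_labels, uniform_table))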
@@ -15,18 +15,12 @@
 This code is refer from:
 https://github.com/FudanVI/FudanOCR/blob/main/text-gestalt/loss/transformer_english_decomposition.py
 """
-import copy
-import math
 import paddle
 import paddle.nn as nn
 import paddle.nn.functional as F
+import math, copy
 import numpy as np

-# stroke-level alphabet
-alphabet = '0123456789'
-
-def get_alphabet_len():
-    return len(alphabet)
-
 def subsequent_mask(size):
@@ -373,10 +367,10 @@ class Encoder(nn.Layer):

 class Transformer(nn.Layer):
-    def __init__(self, in_channels=1):
+    def __init__(self, in_channels=1, alphabet='0123456789'):
         super(Transformer, self).__init__()
-        word_n_class = get_alphabet_len()
+        self.alphabet = alphabet
+        word_n_class = self.get_alphabet_len()
         self.embedding_word_with_upperword = Embeddings(512, word_n_class)
         self.pe = PositionalEncoding(dim=512, dropout=0.1, max_len=5000)
@@ -388,6 +382,9 @@
             if p.dim() > 1:
                 nn.initializer.XavierNormal(p)

+    def get_alphabet_len(self):
+        return len(self.alphabet)
+
     def forward(self, image, text_length, text_input, attention_map=None):
         if image.shape[1] == 3:
             R = image[:, 0:1, :, :]
@@ -415,7 +412,7 @@
         if self.training:
             total_length = paddle.sum(text_length)
-            probs_res = paddle.zeros([total_length, get_alphabet_len()])
+            probs_res = paddle.zeros([total_length, self.get_alphabet_len()])
             start = 0

             for index, length in enumerate(text_length):
@@ -19,9 +19,10 @@ def build_transform(config):
     from .tps import TPS
     from .stn import STN_ON
     from .tsrn import TSRN
+    from .tbsrn import TBSRN
     from .gaspin_transformer import GA_SPIN_Transformer as GA_SPIN

-    support_dict = ['TPS', 'STN_ON', 'GA_SPIN', 'TSRN']
+    support_dict = ['TPS', 'STN_ON', 'GA_SPIN', 'TSRN', 'TBSRN']

     module_name = config.pop('name')
     assert module_name in support_dict, Exception(
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/FudanVI/FudanOCR/blob/main/scene-text-telescope/model/tbsrn.py
"""
import math
import warnings
import numpy as np
import paddle
from paddle import nn
import string
warnings.filterwarnings("ignore")
from .tps_spatial_transformer import TPSSpatialTransformer
from .stn import STN as STNHead
from .tsrn import GruBlock, mish, UpsampleBLock
from ppocr.modeling.heads.sr_rensnet_transformer import Transformer, LayerNorm, \
PositionwiseFeedForward, MultiHeadedAttention
def positionalencoding2d(d_model, height, width):
"""
:param d_model: dimension of the model
:param height: height of the positions
:param width: width of the positions
:return: d_model*height*width position matrix
"""
if d_model % 4 != 0:
raise ValueError("Cannot use sin/cos positional encoding with "
"odd dimension (got dim={:d})".format(d_model))
pe = paddle.zeros([d_model, height, width])
# Each dimension use half of d_model
d_model = int(d_model / 2)
div_term = paddle.exp(paddle.arange(0., d_model, 2) *
-(math.log(10000.0) / d_model))
pos_w = paddle.arange(0., width, dtype='float32').unsqueeze(1)
pos_h = paddle.arange(0., height, dtype='float32').unsqueeze(1)
pe[0:d_model:2, :, :] = paddle.sin(pos_w * div_term).transpose([1, 0]).unsqueeze(1).tile([1, height, 1])
pe[1:d_model:2, :, :] = paddle.cos(pos_w * div_term).transpose([1, 0]).unsqueeze(1).tile([1, height, 1])
pe[d_model::2, :, :] = paddle.sin(pos_h * div_term).transpose([1, 0]).unsqueeze(2).tile([1, 1, width])
pe[d_model + 1::2, :, :] = paddle.cos(pos_h * div_term).transpose([1, 0]).unsqueeze(2).tile([1, 1, width])
return pe
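

# Illustrative note: positionalencoding2d(64, 16, 64) yields a [64, 16, 64]
# tensor. FeatureEnhancer below flattens it to [1, 64, 1024] and concatenates
# it channel-wise with the flattened conv features, so the self-attention
# layer can tell where each feature vector sits in the 16x64 grid.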
class FeatureEnhancer(nn.Layer):
def __init__(self):
super(FeatureEnhancer, self).__init__()
self.multihead = MultiHeadedAttention(h=4, d_model=128, dropout=0.1)
self.mul_layernorm1 = LayerNorm(features=128)
self.pff = PositionwiseFeedForward(128, 128)
self.mul_layernorm3 = LayerNorm(features=128)
self.linear = nn.Linear(128, 64)
def forward(self, conv_feature):
        '''
        conv_feature: (batch, channel, H * W) flattened conv feature map;
        a 2-D positional encoding is concatenated along the channel axis
        (64 + 64 = 128) before the self-attention block
        '''
batch = conv_feature.shape[0]
position2d = positionalencoding2d(64, 16, 64).cast('float32').unsqueeze(0).reshape([1, 64, 1024])
position2d = position2d.tile([batch, 1, 1])
        conv_feature = paddle.concat([conv_feature, position2d], 1)  # [batch, 128 (64 + 64), 1024]
result = conv_feature.transpose([0, 2, 1])
origin_result = result
result = self.mul_layernorm1(origin_result + self.multihead(result, result, result, mask=None)[0])
origin_result = result
result = self.mul_layernorm3(origin_result + self.pff(result))
result = self.linear(result)
return result.transpose([0, 2, 1])
def str_filt(str_, voc_type):
alpha_dict = {
'digit': string.digits,
'lower': string.digits + string.ascii_lowercase,
'upper': string.digits + string.ascii_letters,
'all': string.digits + string.ascii_letters + string.punctuation
}
if voc_type == 'lower':
str_ = str_.lower()
for char in str_:
if char not in alpha_dict[voc_type]:
str_ = str_.replace(char, '')
str_ = str_.lower()
return str_
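

# Illustrative examples: str_filt('Abc,1', 'lower') -> 'abc1' (the comma is
# outside the 'lower' vocabulary and is dropped), while
# str_filt('Abc,1', 'all') -> 'abc,1' (punctuation kept, output lower-cased).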
class TBSRN(nn.Layer):
def __init__(self,
in_channels=3,
scale_factor=2,
width=128,
height=32,
STN=True,
srb_nums=5,
mask=False,
hidden_units=32,
infer_mode=False):
super(TBSRN, self).__init__()
in_planes = 3
if mask:
in_planes = 4
assert math.log(scale_factor, 2) % 1 == 0
upsample_block_num = int(math.log(scale_factor, 2))
self.block1 = nn.Sequential(
nn.Conv2D(in_planes, 2 * hidden_units, kernel_size=9, padding=4),
nn.PReLU()
# nn.ReLU()
)
self.srb_nums = srb_nums
for i in range(srb_nums):
setattr(self, 'block%d' % (i + 2), RecurrentResidualBlock(2 * hidden_units))
setattr(self, 'block%d' % (srb_nums + 2),
nn.Sequential(
nn.Conv2D(2 * hidden_units, 2 * hidden_units, kernel_size=3, padding=1),
nn.BatchNorm2D(2 * hidden_units)
))
# self.non_local = NonLocalBlock2D(64, 64)
block_ = [UpsampleBLock(2 * hidden_units, 2) for _ in range(upsample_block_num)]
block_.append(nn.Conv2D(2 * hidden_units, in_planes, kernel_size=9, padding=4))
setattr(self, 'block%d' % (srb_nums + 3), nn.Sequential(*block_))
self.tps_inputsize = [height // scale_factor, width // scale_factor]
tps_outputsize = [height // scale_factor, width // scale_factor]
num_control_points = 20
tps_margins = [0.05, 0.05]
self.stn = STN
self.out_channels = in_channels
if self.stn:
self.tps = TPSSpatialTransformer(
output_image_size=tuple(tps_outputsize),
num_control_points=num_control_points,
margins=tuple(tps_margins))
self.stn_head = STNHead(
in_channels=in_planes,
num_ctrlpoints=num_control_points,
activation='none')
self.infer_mode = infer_mode
self.english_alphabet = '-0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
self.english_dict = {}
for index in range(len(self.english_alphabet)):
self.english_dict[self.english_alphabet[index]] = index
transformer = Transformer(alphabet='-0123456789abcdefghijklmnopqrstuvwxyz')
self.transformer = transformer
for param in self.transformer.parameters():
param.trainable = False
def label_encoder(self, label):
batch = len(label)
length = [len(i) for i in label]
length_tensor = paddle.to_tensor(length, dtype='int64')
max_length = max(length)
input_tensor = np.zeros((batch, max_length))
for i in range(batch):
for j in range(length[i] - 1):
input_tensor[i][j + 1] = self.english_dict[label[i][j]]
text_gt = []
for i in label:
for j in i:
text_gt.append(self.english_dict[j])
text_gt = paddle.to_tensor(text_gt, dtype='int64')
input_tensor = paddle.to_tensor(input_tensor, dtype='int64')
return length_tensor, input_tensor, text_gt
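
    # Worked example (illustrative): for label = ["it-"] the alphabet indices
    # are '-'->0, 'i'->19, 't'->30, giving length_tensor = [3],
    # input_tensor = [[0, 19, 30]] (right-shifted, with '-'/index 0 acting as
    # the start token) and text_gt = [19, 30, 0].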
def forward(self, x):
output = {}
if self.infer_mode:
output["lr_img"] = x
y = x
else:
output["lr_img"] = x[0]
output["hr_img"] = x[1]
y = x[0]
if self.stn and self.training:
_, ctrl_points_x = self.stn_head(y)
y, _ = self.tps(y, ctrl_points_x)
block = {'1': self.block1(y)}
for i in range(self.srb_nums + 1):
block[str(i + 2)] = getattr(self,
'block%d' % (i + 2))(block[str(i + 1)])
block[str(self.srb_nums + 3)] = getattr(self, 'block%d' % (self.srb_nums + 3)) \
((block['1'] + block[str(self.srb_nums + 2)]))
sr_img = paddle.tanh(block[str(self.srb_nums + 3)])
output["sr_img"] = sr_img
if self.training:
hr_img = x[1]
# add transformer
label = [str_filt(i, 'lower') + '-' for i in x[2]]
length_tensor, input_tensor, text_gt = self.label_encoder(label)
hr_pred, word_attention_map_gt, hr_correct_list = self.transformer(hr_img, length_tensor,
input_tensor)
sr_pred, word_attention_map_pred, sr_correct_list = self.transformer(sr_img, length_tensor,
input_tensor)
output["hr_img"] = hr_img
output["hr_pred"] = hr_pred
output["text_gt"] = text_gt
output["word_attention_map_gt"] = word_attention_map_gt
output["sr_pred"] = sr_pred
output["word_attention_map_pred"] = word_attention_map_pred
return output
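
# Usage note (illustrative): with the default scale_factor=2 the network
# expects a 16x64 low-resolution crop (FeatureEnhancer's positional encoding
# is sized for 16*64 = 1024 positions) and output["sr_img"] has shape
# [N, 3, 32, 128]. In training mode x is the (img_lr, img_hr, label) triple
# from LMDBDataSetSR, and the dict also carries the frozen transformer's
# attention maps and predictions consumed by TelescopeLoss.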
class RecurrentResidualBlock(nn.Layer):
def __init__(self, channels):
super(RecurrentResidualBlock, self).__init__()
self.conv1 = nn.Conv2D(channels, channels, kernel_size=3, padding=1)
self.bn1 = nn.BatchNorm2D(channels)
self.gru1 = GruBlock(channels, channels)
# self.prelu = nn.ReLU()
self.prelu = mish()
self.conv2 = nn.Conv2D(channels, channels, kernel_size=3, padding=1)
self.bn2 = nn.BatchNorm2D(channels)
self.gru2 = GruBlock(channels, channels)
self.feature_enhancer = FeatureEnhancer()
for p in self.parameters():
if p.dim() > 1:
paddle.nn.initializer.XavierUniform(p)
def forward(self, x):
residual = self.conv1(x)
residual = self.bn1(residual)
residual = self.prelu(residual)
residual = self.conv2(residual)
residual = self.bn2(residual)
size = residual.shape
residual = residual.reshape([size[0], size[1], -1])
residual = self.feature_enhancer(residual)
residual = residual.reshape([size[0], size[1], size[2], size[3]])
return x + residual
\ No newline at end of file