未验证 提交 70540130 编写于 作者: X xiaoting 提交者: GitHub

Submit SR model (#6933)

* add sr model

* update for eval

* submit sr

* polish code

* polish code

* polish code

* update sr model

* update doc

* update doc

* update doc

* fix typo

* format code

* update metric

* fix export
上级 f74f897f
Global:
use_gpu: true
epoch_num: 500
log_smooth_window: 20
print_batch_step: 10
save_model_dir: ./output/sr/sr_tsrn_transformer_strock/
save_epoch_step: 3
# evaluation is run every 2000 iterations
eval_batch_step: [0, 1000]
cal_metric_during_train: False
pretrained_model:
checkpoints:
save_inference_dir: sr_output
use_visualdl: False
infer_img: doc/imgs_words_en/word_52.png
# for data or label process
character_dict_path: ./train_data/srdata/english_decomposition.txt
max_text_length: 100
infer_mode: False
use_space_char: False
save_res_path: ./output/sr/predicts_gestalt.txt
Optimizer:
name: Adam
beta1: 0.5
beta2: 0.999
clip_norm: 0.25
lr:
learning_rate: 0.0001
Architecture:
model_type: sr
algorithm: Gestalt
Transform:
name: TSRN
STN: True
infer_mode: False
Loss:
name: StrokeFocusLoss
character_dict_path: ./train_data/srdata/english_decomposition.txt
PostProcess:
name: None
Metric:
name: SRMetric
main_indicator: all
Train:
dataset:
name: LMDBDataSetSR
data_dir: ./train_data/srdata/train
transforms:
- SRResize:
imgH: 32
imgW: 128
down_sample_scale: 2
- SRLabelEncode: # Class handling label
- KeepKeys:
keep_keys: ['img_lr', 'img_hr', 'length', 'input_tensor', 'label'] # dataloader will return list in this order
loader:
shuffle: False
batch_size_per_card: 16
drop_last: True
num_workers: 4
Eval:
dataset:
name: LMDBDataSetSR
data_dir: ./train_data/srdata/test
transforms:
- SRResize:
imgH: 32
imgW: 128
down_sample_scale: 2
- SRLabelEncode: # Class handling label
- KeepKeys:
keep_keys: ['img_lr', 'img_hr','length', 'input_tensor', 'label'] # dataloader will return list in this order
loader:
shuffle: False
drop_last: False
batch_size_per_card: 16
num_workers: 4
# Text Gestalt
- [1. 算法简介](#1)
- [2. 环境配置](#2)
- [3. 模型训练、评估、预测](#3)
- [3.1 训练](#3-1)
- [3.2 评估](#3-2)
- [3.3 预测](#3-3)
- [4. 推理部署](#4)
- [4.1 Python推理](#4-1)
- [4.2 C++推理](#4-2)
- [4.3 Serving服务化部署](#4-3)
- [4.4 更多推理部署](#4-4)
- [5. FAQ](#5)
<a name="1"></a>
## 1. 算法简介
论文信息:
> [Text Gestalt: Stroke-Aware Scene Text Image Super-Resolution](https://arxiv.org/pdf/2112.08171.pdf)
> Chen, Jingye and Yu, Haiyang and Ma, Jianqi and Li, Bin and Xue, Xiangyang
> AAAI, 2022
参考[FudanOCR](https://github.com/FudanVI/FudanOCR/tree/main/text-gestalt) 数据下载说明,在TextZoom测试集合上超分算法效果如下:
|模型|骨干网络|PSNR_Avg|SSIM_Avg|配置文件|下载链接|
|---|---|---|---|---|---|
|Text Gestalt|tsrn|19.28|0.6560| [configs/sr/sr_tsrn_transformer_strock.yml](../../configs/sr/sr_tsrn_transformer_strock.yml)|[训练模型](https://paddleocr.bj.bcebos.com/sr_tsrn_transformer_strock_train.tar)|
<a name="2"></a>
## 2. 环境配置
请先参考[《运行环境准备》](./environment.md)配置PaddleOCR运行环境,参考[《项目克隆》](./clone.md)克隆项目代码。
<a name="3"></a>
## 3. 模型训练、评估、预测
请参考[文本识别训练教程](./recognition.md)。PaddleOCR对代码进行了模块化,训练不同的识别模型只需要**更换配置文件**即可。
- 训练
在完成数据准备后,便可以启动训练,训练命令如下:
```
#单卡训练(训练周期长,不建议)
python3 tools/train.py -c configs/sr/sr_tsrn_transformer_strock.yml
#多卡训练,通过--gpus参数指定卡号
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/sr/sr_tsrn_transformer_strock.yml
```
- 评估
```
# GPU 评估, Global.pretrained_model 为待测权重
python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/sr/sr_tsrn_transformer_strock.yml -o Global.pretrained_model={path/to/weights}/best_accuracy
```
- 预测:
```
# 预测使用的配置文件必须与训练一致
python3 tools/infer_sr.py -c configs/sr/sr_tsrn_transformer_strock.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.infer_img=doc/imgs_words_en/word_52.png
```
![](../imgs_words_en/word_52.png)
执行命令后,上面图像的超分结果如下:
![](../imgs_results/sr_word_52.png)
<a name="4"></a>
## 4. 推理部署
<a name="4-1"></a>
### 4.1 Python推理
首先将文本超分训练过程中保存的模型,转换成inference model。以 Text-Gestalt 训练的[模型](https://paddleocr.bj.bcebos.com/sr_tsrn_transformer_strock_train.tar) 为例,可以使用如下命令进行转换:
```shell
python3 tools/export_model.py -c configs/sr/sr_tsrn_transformer_strock.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.save_inference_dir=./inference/sr_out
```
Text-Gestalt 文本超分模型推理,可以执行如下命令:
```
python3 tools/infer/predict_sr.py --sr_model_dir=./inference/sr_out --image_dir=doc/imgs_words_en/word_52.png --sr_image_shape=3,32,128
```
执行命令后,图像的超分结果如下:
![](../imgs_results/sr_word_52.png)
<a name="4-2"></a>
### 4.2 C++推理
暂未支持
<a name="4-3"></a>
### 4.3 Serving服务化部署
暂未支持
<a name="4-4"></a>
### 4.4 更多推理部署
暂未支持
<a name="5"></a>
## 5. FAQ
## 引用
```bibtex
@inproceedings{chen2022text,
title={Text gestalt: Stroke-aware scene text image super-resolution},
author={Chen, Jingye and Yu, Haiyang and Ma, Jianqi and Li, Bin and Xue, Xiangyang},
booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
volume={36},
number={1},
pages={285--293},
year={2022}
}
```
# Text Gestalt
- [1. Introduction](#1)
- [2. Environment](#2)
- [3. Model Training / Evaluation / Prediction](#3)
- [3.1 Training](#3-1)
- [3.2 Evaluation](#3-2)
- [3.3 Prediction](#3-3)
- [4. Inference and Deployment](#4)
- [4.1 Python Inference](#4-1)
- [4.2 C++ Inference](#4-2)
- [4.3 Serving](#4-3)
- [4.4 More](#4-4)
- [5. FAQ](#5)
<a name="1"></a>
## 1. Introduction
Paper:
> [Text Gestalt: Stroke-Aware Scene Text Image Super-Resolution](https://arxiv.org/pdf/2112.08171.pdf)
> Chen, Jingye and Yu, Haiyang and Ma, Jianqi and Li, Bin and Xue, Xiangyang
> AAAI, 2022
Referring to the [FudanOCR](https://github.com/FudanVI/FudanOCR/tree/main/text-gestalt) data download instructions, the effect of the super-score algorithm on the TextZoom test set is as follows:
|Model|Backbone|config|Acc|Download link|
|---|---|---|---|---|---|
|Text Gestalt|tsrn|19.28|0.6560| [configs/sr/sr_tsrn_transformer_strock.yml](../../configs/sr/sr_tsrn_transformer_strock.yml)|[train model](https://paddleocr.bj.bcebos.com/sr_tsrn_transformer_strock_train.tar)|
<a name="2"></a>
## 2. Environment
Please refer to ["Environment Preparation"](./environment_en.md) to configure the PaddleOCR environment, and refer to ["Project Clone"](./clone_en.md) to clone the project code.
<a name="3"></a>
## 3. Model Training / Evaluation / Prediction
Please refer to [Text Recognition Tutorial](./recognition_en.md). PaddleOCR modularizes the code, and training different models only requires **changing the configuration file**.
Training:
Specifically, after the data preparation is completed, the training can be started. The training command is as follows:
```
#Single GPU training (long training period, not recommended)
python3 tools/train.py -c configs/sr/sr_tsrn_transformer_strock.yml
#Multi GPU training, specify the gpu number through the --gpus parameter
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/sr/sr_tsrn_transformer_strock.yml
```
Evaluation:
```
# GPU evaluation
python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/sr/sr_tsrn_transformer_strock.yml -o Global.pretrained_model={path/to/weights}/best_accuracy
```
Prediction:
```
# The configuration file used for prediction must match the training
python3 tools/infer_sr.py -c configs/sr/sr_tsrn_transformer_strock.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.infer_img=doc/imgs_words_en/word_52.png
```
![](../imgs_words_en/word_52.png)
After executing the command, the super-resolution result of the above image is as follows:
![](../imgs_results/sr_word_52.png)
<a name="4"></a>
## 4. Inference and Deployment
<a name="4-1"></a>
### 4.1 Python Inference
First, the model saved during the training process is converted into an inference model. ( [Model download link](https://paddleocr.bj.bcebos.com/sr_tsrn_transformer_strock_train.tar) ), you can use the following command to convert:
```shell
python3 tools/export_model.py -c configs/sr/sr_tsrn_transformer_strock.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.save_inference_dir=./inference/sr_out
```
For Text-Gestalt super-resolution model inference, the following commands can be executed:
```
python3 tools/infer/predict_sr.py --sr_model_dir=./inference/sr_out --image_dir=doc/imgs_words_en/word_52.png --sr_image_shape=3,32,128
```
After executing the command, the super-resolution result of the above image is as follows:
![](../imgs_results/sr_word_52.png)
<a name="4-2"></a>
### 4.2 C++ Inference
Not supported
<a name="4-3"></a>
### 4.3 Serving
Not supported
<a name="4-4"></a>
### 4.4 More
Not supported
<a name="5"></a>
## 5. FAQ
## Citation
```bibtex
@inproceedings{chen2022text,
title={Text gestalt: Stroke-aware scene text image super-resolution},
author={Chen, Jingye and Yu, Haiyang and Ma, Jianqi and Li, Bin and Xue, Xiangyang},
booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
volume={36},
number={1},
pages={285--293},
year={2022}
}
```
...@@ -34,7 +34,7 @@ import paddle.distributed as dist ...@@ -34,7 +34,7 @@ import paddle.distributed as dist
from ppocr.data.imaug import transform, create_operators from ppocr.data.imaug import transform, create_operators
from ppocr.data.simple_dataset import SimpleDataSet from ppocr.data.simple_dataset import SimpleDataSet
from ppocr.data.lmdb_dataset import LMDBDataSet from ppocr.data.lmdb_dataset import LMDBDataSet, LMDBDataSetSR
from ppocr.data.pgnet_dataset import PGDataSet from ppocr.data.pgnet_dataset import PGDataSet
from ppocr.data.pubtab_dataset import PubTabDataSet from ppocr.data.pubtab_dataset import PubTabDataSet
...@@ -54,7 +54,8 @@ def build_dataloader(config, mode, device, logger, seed=None): ...@@ -54,7 +54,8 @@ def build_dataloader(config, mode, device, logger, seed=None):
config = copy.deepcopy(config) config = copy.deepcopy(config)
support_dict = [ support_dict = [
'SimpleDataSet', 'LMDBDataSet', 'PGDataSet', 'PubTabDataSet' 'SimpleDataSet', 'LMDBDataSet', 'PGDataSet', 'PubTabDataSet',
'LMDBDataSetSR'
] ]
module_name = config[mode]['dataset']['name'] module_name = config[mode]['dataset']['name']
assert module_name in support_dict, Exception( assert module_name in support_dict, Exception(
......
...@@ -1236,6 +1236,54 @@ class ABINetLabelEncode(BaseRecLabelEncode): ...@@ -1236,6 +1236,54 @@ class ABINetLabelEncode(BaseRecLabelEncode):
return dict_character return dict_character
class SRLabelEncode(BaseRecLabelEncode):
def __init__(self,
max_text_length,
character_dict_path=None,
use_space_char=False,
**kwargs):
super(SRLabelEncode, self).__init__(max_text_length,
character_dict_path, use_space_char)
self.dic = {}
with open(character_dict_path, 'r') as fin:
for line in fin.readlines():
line = line.strip()
character, sequence = line.split()
self.dic[character] = sequence
english_stroke_alphabet = '0123456789'
self.english_stroke_dict = {}
for index in range(len(english_stroke_alphabet)):
self.english_stroke_dict[english_stroke_alphabet[index]] = index
def encode(self, label):
stroke_sequence = ''
for character in label:
if character not in self.dic:
continue
else:
stroke_sequence += self.dic[character]
stroke_sequence += '0'
label = stroke_sequence
length = len(label)
input_tensor = np.zeros(self.max_text_len).astype("int64")
for j in range(length - 1):
input_tensor[j + 1] = self.english_stroke_dict[label[j]]
return length, input_tensor
def __call__(self, data):
text = data['label']
length, input_tensor = self.encode(text)
data["length"] = length
data["input_tensor"] = input_tensor
if text is None:
return None
return data
class SPINLabelEncode(AttnLabelEncode): class SPINLabelEncode(AttnLabelEncode):
""" Convert between text-label and text-index """ """ Convert between text-label and text-index """
......
...@@ -24,6 +24,7 @@ import six ...@@ -24,6 +24,7 @@ import six
import cv2 import cv2
import numpy as np import numpy as np
import math import math
from PIL import Image
class DecodeImage(object): class DecodeImage(object):
...@@ -440,3 +441,52 @@ class KieResize(object): ...@@ -440,3 +441,52 @@ class KieResize(object):
points[:, 0::2] = np.clip(points[:, 0::2], 0, img_shape[1]) points[:, 0::2] = np.clip(points[:, 0::2], 0, img_shape[1])
points[:, 1::2] = np.clip(points[:, 1::2], 0, img_shape[0]) points[:, 1::2] = np.clip(points[:, 1::2], 0, img_shape[0])
return points return points
class SRResize(object):
def __init__(self,
imgH=32,
imgW=128,
down_sample_scale=4,
keep_ratio=False,
min_ratio=1,
mask=False,
infer_mode=False,
**kwargs):
self.imgH = imgH
self.imgW = imgW
self.keep_ratio = keep_ratio
self.min_ratio = min_ratio
self.down_sample_scale = down_sample_scale
self.mask = mask
self.infer_mode = infer_mode
def __call__(self, data):
imgH = self.imgH
imgW = self.imgW
images_lr = data["image_lr"]
transform2 = ResizeNormalize(
(imgW // self.down_sample_scale, imgH // self.down_sample_scale))
images_lr = transform2(images_lr)
data["img_lr"] = images_lr
if self.infer_mode:
return data
images_HR = data["image_hr"]
label_strs = data["label"]
transform = ResizeNormalize((imgW, imgH))
images_HR = transform(images_HR)
data["img_hr"] = images_HR
return data
class ResizeNormalize(object):
def __init__(self, size, interpolation=Image.BICUBIC):
self.size = size
self.interpolation = interpolation
def __call__(self, img):
img = img.resize(self.size, self.interpolation)
img_numpy = np.array(img).astype("float32")
img_numpy = img_numpy.transpose((2, 0, 1)) / 255
return img_numpy
...@@ -16,6 +16,9 @@ import os ...@@ -16,6 +16,9 @@ import os
from paddle.io import Dataset from paddle.io import Dataset
import lmdb import lmdb
import cv2 import cv2
import string
import six
from PIL import Image
from .imaug import transform, create_operators from .imaug import transform, create_operators
...@@ -116,3 +119,58 @@ class LMDBDataSet(Dataset): ...@@ -116,3 +119,58 @@ class LMDBDataSet(Dataset):
def __len__(self): def __len__(self):
return self.data_idx_order_list.shape[0] return self.data_idx_order_list.shape[0]
class LMDBDataSetSR(LMDBDataSet):
def buf2PIL(self, txn, key, type='RGB'):
imgbuf = txn.get(key)
buf = six.BytesIO()
buf.write(imgbuf)
buf.seek(0)
im = Image.open(buf).convert(type)
return im
def str_filt(self, str_, voc_type):
alpha_dict = {
'digit': string.digits,
'lower': string.digits + string.ascii_lowercase,
'upper': string.digits + string.ascii_letters,
'all': string.digits + string.ascii_letters + string.punctuation
}
if voc_type == 'lower':
str_ = str_.lower()
for char in str_:
if char not in alpha_dict[voc_type]:
str_ = str_.replace(char, '')
return str_
def get_lmdb_sample_info(self, txn, index):
self.voc_type = 'upper'
self.max_len = 100
self.test = False
label_key = b'label-%09d' % index
word = str(txn.get(label_key).decode())
img_HR_key = b'image_hr-%09d' % index # 128*32
img_lr_key = b'image_lr-%09d' % index # 64*16
try:
img_HR = self.buf2PIL(txn, img_HR_key, 'RGB')
img_lr = self.buf2PIL(txn, img_lr_key, 'RGB')
except IOError or len(word) > self.max_len:
return self[index + 1]
label_str = self.str_filt(word, self.voc_type)
return img_HR, img_lr, label_str
def __getitem__(self, idx):
lmdb_idx, file_idx = self.data_idx_order_list[idx]
lmdb_idx = int(lmdb_idx)
file_idx = int(file_idx)
sample_info = self.get_lmdb_sample_info(self.lmdb_sets[lmdb_idx]['txn'],
file_idx)
if sample_info is None:
return self.__getitem__(np.random.randint(self.__len__()))
img_HR, img_lr, label_str = sample_info
data = {'image_hr': img_HR, 'image_lr': img_lr, 'label': label_str}
outs = transform(data, self.ops)
if outs is None:
return self.__getitem__(np.random.randint(self.__len__()))
return outs
...@@ -57,6 +57,9 @@ from .table_master_loss import TableMasterLoss ...@@ -57,6 +57,9 @@ from .table_master_loss import TableMasterLoss
# vqa token loss # vqa token loss
from .vqa_token_layoutlm_loss import VQASerTokenLayoutLMLoss from .vqa_token_layoutlm_loss import VQASerTokenLayoutLMLoss
# sr loss
from .stroke_focus_loss import StrokeFocusLoss
def build_loss(config): def build_loss(config):
support_dict = [ support_dict = [
...@@ -64,7 +67,7 @@ def build_loss(config): ...@@ -64,7 +67,7 @@ def build_loss(config):
'ClsLoss', 'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss', 'ClsLoss', 'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss',
'CELoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss', 'CELoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss',
'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss', 'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss',
'TableMasterLoss', 'SPINAttentionLoss', 'VLLoss' 'TableMasterLoss', 'SPINAttentionLoss', 'VLLoss','StrokeFocusLoss'
] ]
config = copy.deepcopy(config) config = copy.deepcopy(config)
module_name = config.pop('name') module_name = config.pop('name')
......
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/FudanVI/FudanOCR/blob/main/text-gestalt/loss/stroke_focus_loss.py
"""
import cv2
import sys
import time
import string
import random
import numpy as np
import paddle.nn as nn
import paddle
class StrokeFocusLoss(nn.Layer):
def __init__(self, character_dict_path=None, **kwargs):
super(StrokeFocusLoss, self).__init__(character_dict_path)
self.mse_loss = nn.MSELoss()
self.ce_loss = nn.CrossEntropyLoss()
self.l1_loss = nn.L1Loss()
self.english_stroke_alphabet = '0123456789'
self.english_stroke_dict = {}
for index in range(len(self.english_stroke_alphabet)):
self.english_stroke_dict[self.english_stroke_alphabet[
index]] = index
stroke_decompose_lines = open(character_dict_path, 'r').readlines()
self.dic = {}
for line in stroke_decompose_lines:
line = line.strip()
character, sequence = line.split()
self.dic[character] = sequence
def forward(self, pred, data):
sr_img = pred["sr_img"]
hr_img = pred["hr_img"]
mse_loss = self.mse_loss(sr_img, hr_img)
word_attention_map_gt = pred["word_attention_map_gt"]
word_attention_map_pred = pred["word_attention_map_pred"]
hr_pred = pred["hr_pred"]
sr_pred = pred["sr_pred"]
attention_loss = paddle.nn.functional.l1_loss(word_attention_map_gt,
word_attention_map_pred)
loss = (mse_loss + attention_loss * 50) * 100
return {
"mse_loss": mse_loss,
"attention_loss": attention_loss,
"loss": loss
}
...@@ -30,13 +30,13 @@ from .table_metric import TableMetric ...@@ -30,13 +30,13 @@ from .table_metric import TableMetric
from .kie_metric import KIEMetric from .kie_metric import KIEMetric
from .vqa_token_ser_metric import VQASerTokenMetric from .vqa_token_ser_metric import VQASerTokenMetric
from .vqa_token_re_metric import VQAReTokenMetric from .vqa_token_re_metric import VQAReTokenMetric
from .sr_metric import SRMetric
def build_metric(config): def build_metric(config):
support_dict = [ support_dict = [
"DetMetric", "DetFCEMetric", "RecMetric", "ClsMetric", "E2EMetric", "DetMetric", "DetFCEMetric", "RecMetric", "ClsMetric", "E2EMetric",
"DistillationMetric", "TableMetric", 'KIEMetric', 'VQASerTokenMetric', "DistillationMetric", "TableMetric", 'KIEMetric', 'VQASerTokenMetric',
'VQAReTokenMetric' 'VQAReTokenMetric', 'SRMetric'
] ]
config = copy.deepcopy(config) config = copy.deepcopy(config)
......
...@@ -16,6 +16,7 @@ import Levenshtein ...@@ -16,6 +16,7 @@ import Levenshtein
import string import string
class RecMetric(object): class RecMetric(object):
def __init__(self, def __init__(self,
main_indicator='acc', main_indicator='acc',
......
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
https://github.com/FudanVI/FudanOCR/blob/main/text-gestalt/utils/ssim_psnr.py
"""
from math import exp
import paddle
import paddle.nn.functional as F
import paddle.nn as nn
import string
class SSIM(nn.Layer):
def __init__(self, window_size=11, size_average=True):
super(SSIM, self).__init__()
self.window_size = window_size
self.size_average = size_average
self.channel = 1
self.window = self.create_window(window_size, self.channel)
def gaussian(self, window_size, sigma):
gauss = paddle.to_tensor([
exp(-(x - window_size // 2)**2 / float(2 * sigma**2))
for x in range(window_size)
])
return gauss / gauss.sum()
def create_window(self, window_size, channel):
_1D_window = self.gaussian(window_size, 1.5).unsqueeze(1)
_2D_window = _1D_window.mm(_1D_window.t()).unsqueeze(0).unsqueeze(0)
window = _2D_window.expand([channel, 1, window_size, window_size])
return window
def _ssim(self, img1, img2, window, window_size, channel,
size_average=True):
mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
mu1_sq = mu1.pow(2)
mu2_sq = mu2.pow(2)
mu1_mu2 = mu1 * mu2
sigma1_sq = F.conv2d(
img1 * img1, window, padding=window_size // 2,
groups=channel) - mu1_sq
sigma2_sq = F.conv2d(
img2 * img2, window, padding=window_size // 2,
groups=channel) - mu2_sq
sigma12 = F.conv2d(
img1 * img2, window, padding=window_size // 2,
groups=channel) - mu1_mu2
C1 = 0.01**2
C2 = 0.03**2
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / (
(mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
if size_average:
return ssim_map.mean()
else:
return ssim_map.mean([1, 2, 3])
def ssim(self, img1, img2, window_size=11, size_average=True):
(_, channel, _, _) = img1.shape
window = self.create_window(window_size, channel)
return self._ssim(img1, img2, window, window_size, channel,
size_average)
def forward(self, img1, img2):
(_, channel, _, _) = img1.shape
if channel == self.channel and self.window.dtype == img1.dtype:
window = self.window
else:
window = self.create_window(self.window_size, channel)
self.window = window
self.channel = channel
return self._ssim(img1, img2, window, self.window_size, channel,
self.size_average)
class SRMetric(object):
def __init__(self, main_indicator='all', **kwargs):
self.main_indicator = main_indicator
self.eps = 1e-5
self.psnr_result = []
self.ssim_result = []
self.calculate_ssim = SSIM()
self.reset()
def reset(self):
self.correct_num = 0
self.all_num = 0
self.norm_edit_dis = 0
self.psnr_result = []
self.ssim_result = []
def calculate_psnr(self, img1, img2):
# img1 and img2 have range [0, 1]
mse = ((img1 * 255 - img2 * 255)**2).mean()
if mse == 0:
return float('inf')
return 20 * paddle.log10(255.0 / paddle.sqrt(mse))
def _normalize_text(self, text):
text = ''.join(
filter(lambda x: x in (string.digits + string.ascii_letters), text))
return text.lower()
def __call__(self, pred_label, *args, **kwargs):
metric = {}
images_sr = pred_label["sr_img"]
images_hr = pred_label["hr_img"]
psnr = self.calculate_psnr(images_sr, images_hr)
ssim = self.calculate_ssim(images_sr, images_hr)
self.psnr_result.append(psnr)
self.ssim_result.append(ssim)
def get_metric(self):
"""
return metrics {
'acc': 0,
'norm_edit_dis': 0,
}
"""
self.psnr_avg = sum(self.psnr_result) / len(self.psnr_result)
self.psnr_avg = round(self.psnr_avg.item(), 6)
self.ssim_avg = sum(self.ssim_result) / len(self.ssim_result)
self.ssim_avg = round(self.ssim_avg.item(), 6)
self.all_avg = self.psnr_avg + self.ssim_avg
self.reset()
return {
'psnr_avg': self.psnr_avg,
"ssim_avg": self.ssim_avg,
"all": self.all_avg
}
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from paddle import nn from paddle import nn
from ppocr.modeling.transforms import build_transform from ppocr.modeling.transforms import build_transform
from ppocr.modeling.backbones import build_backbone from ppocr.modeling.backbones import build_backbone
...@@ -46,9 +47,13 @@ class BaseModel(nn.Layer): ...@@ -46,9 +47,13 @@ class BaseModel(nn.Layer):
in_channels = self.transform.out_channels in_channels = self.transform.out_channels
# build backbone, backbone is need for del, rec and cls # build backbone, backbone is need for del, rec and cls
config["Backbone"]['in_channels'] = in_channels if 'Backbone' not in config or config['Backbone'] is None:
self.backbone = build_backbone(config["Backbone"], model_type) self.use_backbone = False
in_channels = self.backbone.out_channels else:
self.use_backbone = True
config["Backbone"]['in_channels'] = in_channels
self.backbone = build_backbone(config["Backbone"], model_type)
in_channels = self.backbone.out_channels
# build neck # build neck
# for rec, neck can be cnn,rnn or reshape(None) # for rec, neck can be cnn,rnn or reshape(None)
...@@ -77,7 +82,8 @@ class BaseModel(nn.Layer): ...@@ -77,7 +82,8 @@ class BaseModel(nn.Layer):
y = dict() y = dict()
if self.use_transform: if self.use_transform:
x = self.transform(x) x = self.transform(x)
x = self.backbone(x) if self.use_backbone:
x = self.backbone(x)
if isinstance(x, dict): if isinstance(x, dict):
y.update(x) y.update(x)
else: else:
...@@ -109,4 +115,4 @@ class BaseModel(nn.Layer): ...@@ -109,4 +115,4 @@ class BaseModel(nn.Layer):
else: else:
return {final_name: x} return {final_name: x}
else: else:
return x return x
\ No newline at end of file
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/FudanVI/FudanOCR/blob/main/text-gestalt/loss/transformer_english_decomposition.py
"""
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import math, copy
import numpy as np
# stroke-level alphabet
alphabet = '0123456789'
def get_alphabet_len():
return len(alphabet)
def subsequent_mask(size):
"""Generate a square mask for the sequence. The masked positions are filled with float('-inf').
Unmasked positions are filled with float(0.0).
"""
mask = paddle.ones([1, size, size], dtype='float32')
mask_inf = paddle.triu(
paddle.full(
shape=[1, size, size], dtype='float32', fill_value='-inf'),
diagonal=1)
mask = mask + mask_inf
padding_mask = paddle.equal(mask, paddle.to_tensor(1, dtype=mask.dtype))
return padding_mask
def clones(module, N):
return nn.LayerList([copy.deepcopy(module) for _ in range(N)])
def masked_fill(x, mask, value):
y = paddle.full(x.shape, value, x.dtype)
return paddle.where(mask, y, x)
def attention(query, key, value, mask=None, dropout=None, attention_map=None):
d_k = query.shape[-1]
scores = paddle.matmul(query,
paddle.transpose(key, [0, 1, 3, 2])) / math.sqrt(d_k)
if mask is not None:
scores = masked_fill(scores, mask == 0, float('-inf'))
else:
pass
p_attn = F.softmax(scores, axis=-1)
if dropout is not None:
p_attn = dropout(p_attn)
return paddle.matmul(p_attn, value), p_attn
class MultiHeadedAttention(nn.Layer):
def __init__(self, h, d_model, dropout=0.1, compress_attention=False):
super(MultiHeadedAttention, self).__init__()
assert d_model % h == 0
self.d_k = d_model // h
self.h = h
self.linears = clones(nn.Linear(d_model, d_model), 4)
self.attn = None
self.dropout = nn.Dropout(p=dropout, mode="downscale_in_infer")
self.compress_attention = compress_attention
self.compress_attention_linear = nn.Linear(h, 1)
def forward(self, query, key, value, mask=None, attention_map=None):
if mask is not None:
mask = mask.unsqueeze(1)
nbatches = query.shape[0]
query, key, value = \
[paddle.transpose(l(x).reshape([nbatches, -1, self.h, self.d_k]), [0,2,1,3])
for l, x in zip(self.linears, (query, key, value))]
x, attention_map = attention(
query,
key,
value,
mask=mask,
dropout=self.dropout,
attention_map=attention_map)
x = paddle.reshape(
paddle.transpose(x, [0, 2, 1, 3]),
[nbatches, -1, self.h * self.d_k])
return self.linears[-1](x), attention_map
class ResNet(nn.Layer):
def __init__(self, num_in, block, layers):
super(ResNet, self).__init__()
self.conv1 = nn.Conv2D(num_in, 64, kernel_size=3, stride=1, padding=1)
self.bn1 = nn.BatchNorm2D(64, use_global_stats=True)
self.relu1 = nn.ReLU()
self.pool = nn.MaxPool2D((2, 2), (2, 2))
self.conv2 = nn.Conv2D(64, 128, kernel_size=3, stride=1, padding=1)
self.bn2 = nn.BatchNorm2D(128, use_global_stats=True)
self.relu2 = nn.ReLU()
self.layer1_pool = nn.MaxPool2D((2, 2), (2, 2))
self.layer1 = self._make_layer(block, 128, 256, layers[0])
self.layer1_conv = nn.Conv2D(256, 256, 3, 1, 1)
self.layer1_bn = nn.BatchNorm2D(256, use_global_stats=True)
self.layer1_relu = nn.ReLU()
self.layer2_pool = nn.MaxPool2D((2, 2), (2, 2))
self.layer2 = self._make_layer(block, 256, 256, layers[1])
self.layer2_conv = nn.Conv2D(256, 256, 3, 1, 1)
self.layer2_bn = nn.BatchNorm2D(256, use_global_stats=True)
self.layer2_relu = nn.ReLU()
self.layer3_pool = nn.MaxPool2D((2, 2), (2, 2))
self.layer3 = self._make_layer(block, 256, 512, layers[2])
self.layer3_conv = nn.Conv2D(512, 512, 3, 1, 1)
self.layer3_bn = nn.BatchNorm2D(512, use_global_stats=True)
self.layer3_relu = nn.ReLU()
self.layer4_pool = nn.MaxPool2D((2, 2), (2, 2))
self.layer4 = self._make_layer(block, 512, 512, layers[3])
self.layer4_conv2 = nn.Conv2D(512, 1024, 3, 1, 1)
self.layer4_conv2_bn = nn.BatchNorm2D(1024, use_global_stats=True)
self.layer4_conv2_relu = nn.ReLU()
def _make_layer(self, block, inplanes, planes, blocks):
if inplanes != planes:
downsample = nn.Sequential(
nn.Conv2D(inplanes, planes, 3, 1, 1),
nn.BatchNorm2D(
planes, use_global_stats=True), )
else:
downsample = None
layers = []
layers.append(block(inplanes, planes, downsample))
for i in range(1, blocks):
layers.append(block(planes, planes, downsample=None))
return nn.Sequential(*layers)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu1(x)
x = self.pool(x)
x = self.conv2(x)
x = self.bn2(x)
x = self.relu2(x)
x = self.layer1_pool(x)
x = self.layer1(x)
x = self.layer1_conv(x)
x = self.layer1_bn(x)
x = self.layer1_relu(x)
x = self.layer2(x)
x = self.layer2_conv(x)
x = self.layer2_bn(x)
x = self.layer2_relu(x)
x = self.layer3(x)
x = self.layer3_conv(x)
x = self.layer3_bn(x)
x = self.layer3_relu(x)
x = self.layer4(x)
x = self.layer4_conv2(x)
x = self.layer4_conv2_bn(x)
x = self.layer4_conv2_relu(x)
return x
class Bottleneck(nn.Layer):
def __init__(self, input_dim):
super(Bottleneck, self).__init__()
self.conv1 = nn.Conv2D(input_dim, input_dim, 1)
self.bn1 = nn.BatchNorm2D(input_dim, use_global_stats=True)
self.relu = nn.ReLU()
self.conv2 = nn.Conv2D(input_dim, input_dim, 3, 1, 1)
self.bn2 = nn.BatchNorm2D(input_dim, use_global_stats=True)
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out += residual
out = self.relu(out)
return out
class PositionalEncoding(nn.Layer):
"Implement the PE function."
def __init__(self, dropout, dim, max_len=5000):
super(PositionalEncoding, self).__init__()
self.dropout = nn.Dropout(p=dropout, mode="downscale_in_infer")
pe = paddle.zeros([max_len, dim])
position = paddle.arange(0, max_len, dtype=paddle.float32).unsqueeze(1)
div_term = paddle.exp(
paddle.arange(0, dim, 2).astype('float32') *
(-math.log(10000.0) / dim))
pe[:, 0::2] = paddle.sin(position * div_term)
pe[:, 1::2] = paddle.cos(position * div_term)
pe = paddle.unsqueeze(pe, 0)
self.register_buffer('pe', pe)
def forward(self, x):
x = x + self.pe[:, :paddle.shape(x)[1]]
return self.dropout(x)
class PositionwiseFeedForward(nn.Layer):
"Implements FFN equation."
def __init__(self, d_model, d_ff, dropout=0.1):
super(PositionwiseFeedForward, self).__init__()
self.w_1 = nn.Linear(d_model, d_ff)
self.w_2 = nn.Linear(d_ff, d_model)
self.dropout = nn.Dropout(dropout, mode="downscale_in_infer")
def forward(self, x):
return self.w_2(self.dropout(F.relu(self.w_1(x))))
class Generator(nn.Layer):
"Define standard linear + softmax generation step."
def __init__(self, d_model, vocab):
super(Generator, self).__init__()
self.proj = nn.Linear(d_model, vocab)
self.relu = nn.ReLU()
def forward(self, x):
out = self.proj(x)
return out
class Embeddings(nn.Layer):
def __init__(self, d_model, vocab):
super(Embeddings, self).__init__()
self.lut = nn.Embedding(vocab, d_model)
self.d_model = d_model
def forward(self, x):
embed = self.lut(x) * math.sqrt(self.d_model)
return embed
class LayerNorm(nn.Layer):
"Construct a layernorm module (See citation for details)."
def __init__(self, features, eps=1e-6):
super(LayerNorm, self).__init__()
self.a_2 = self.create_parameter(
shape=[features],
default_initializer=paddle.nn.initializer.Constant(1.0))
self.b_2 = self.create_parameter(
shape=[features],
default_initializer=paddle.nn.initializer.Constant(0.0))
self.eps = eps
def forward(self, x):
mean = x.mean(-1, keepdim=True)
std = x.std(-1, keepdim=True)
return self.a_2 * (x - mean) / (std + self.eps) + self.b_2
class Decoder(nn.Layer):
def __init__(self):
super(Decoder, self).__init__()
self.mask_multihead = MultiHeadedAttention(
h=16, d_model=1024, dropout=0.1)
self.mul_layernorm1 = LayerNorm(1024)
self.multihead = MultiHeadedAttention(h=16, d_model=1024, dropout=0.1)
self.mul_layernorm2 = LayerNorm(1024)
self.pff = PositionwiseFeedForward(1024, 2048)
self.mul_layernorm3 = LayerNorm(1024)
def forward(self, text, conv_feature, attention_map=None):
text_max_length = text.shape[1]
mask = subsequent_mask(text_max_length)
result = text
result = self.mul_layernorm1(result + self.mask_multihead(
text, text, text, mask=mask)[0])
b, c, h, w = conv_feature.shape
conv_feature = paddle.transpose(
conv_feature.reshape([b, c, h * w]), [0, 2, 1])
word_image_align, attention_map = self.multihead(
result,
conv_feature,
conv_feature,
mask=None,
attention_map=attention_map)
result = self.mul_layernorm2(result + word_image_align)
result = self.mul_layernorm3(result + self.pff(result))
return result, attention_map
class BasicBlock(nn.Layer):
def __init__(self, inplanes, planes, downsample):
super(BasicBlock, self).__init__()
self.conv1 = nn.Conv2D(
inplanes, planes, kernel_size=3, stride=1, padding=1)
self.bn1 = nn.BatchNorm2D(planes, use_global_stats=True)
self.relu = nn.ReLU()
self.conv2 = nn.Conv2D(
planes, planes, kernel_size=3, stride=1, padding=1)
self.bn2 = nn.BatchNorm2D(planes, use_global_stats=True)
self.downsample = downsample
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample != None:
residual = self.downsample(residual)
out += residual
out = self.relu(out)
return out
class Encoder(nn.Layer):
def __init__(self):
super(Encoder, self).__init__()
self.cnn = ResNet(num_in=1, block=BasicBlock, layers=[1, 2, 5, 3])
def forward(self, input):
conv_result = self.cnn(input)
return conv_result
class Transformer(nn.Layer):
def __init__(self, in_channels=1):
super(Transformer, self).__init__()
word_n_class = get_alphabet_len()
self.embedding_word_with_upperword = Embeddings(512, word_n_class)
self.pe = PositionalEncoding(dim=512, dropout=0.1, max_len=5000)
self.encoder = Encoder()
self.decoder = Decoder()
self.generator_word_with_upperword = Generator(1024, word_n_class)
for p in self.parameters():
if p.dim() > 1:
nn.initializer.XavierNormal(p)
def forward(self, image, text_length, text_input, attention_map=None):
if image.shape[1] == 3:
R = image[:, 0:1, :, :]
G = image[:, 1:2, :, :]
B = image[:, 2:3, :, :]
image = 0.299 * R + 0.587 * G + 0.114 * B
conv_feature = self.encoder(image) # batch, 1024, 8, 32
max_length = max(text_length)
text_input = text_input[:, :max_length]
text_embedding = self.embedding_word_with_upperword(
text_input) # batch, text_max_length, 512
postion_embedding = self.pe(
paddle.zeros(text_embedding.shape)) # batch, text_max_length, 512
text_input_with_pe = paddle.concat([text_embedding, postion_embedding],
2) # batch, text_max_length, 1024
batch, seq_len, _ = text_input_with_pe.shape
text_input_with_pe, word_attention_map = self.decoder(
text_input_with_pe, conv_feature)
word_decoder_result = self.generator_word_with_upperword(
text_input_with_pe)
if self.training:
total_length = paddle.sum(text_length)
probs_res = paddle.zeros([total_length, get_alphabet_len()])
start = 0
for index, length in enumerate(text_length):
length = int(length.numpy())
probs_res[start:start + length, :] = word_decoder_result[
index, 0:0 + length, :]
start = start + length
return probs_res, word_attention_map, None
else:
return word_decoder_result
...@@ -18,10 +18,10 @@ __all__ = ['build_transform'] ...@@ -18,10 +18,10 @@ __all__ = ['build_transform']
def build_transform(config): def build_transform(config):
from .tps import TPS from .tps import TPS
from .stn import STN_ON from .stn import STN_ON
from .tsrn import TSRN
from .gaspin_transformer import GA_SPIN_Transformer as GA_SPIN from .gaspin_transformer import GA_SPIN_Transformer as GA_SPIN
support_dict = ['TPS', 'STN_ON', 'GA_SPIN', 'TSRN']
support_dict = ['TPS', 'STN_ON', 'GA_SPIN']
module_name = config.pop('name') module_name = config.pop('name')
assert module_name in support_dict, Exception( assert module_name in support_dict, Exception(
......
...@@ -153,4 +153,4 @@ class TPSSpatialTransformer(nn.Layer): ...@@ -153,4 +153,4 @@ class TPSSpatialTransformer(nn.Layer):
# the input to grid_sample is normalized [-1, 1], but what we get is [0, 1] # the input to grid_sample is normalized [-1, 1], but what we get is [0, 1]
grid = 2.0 * grid - 1.0 grid = 2.0 * grid - 1.0
output_maps = grid_sample(input, grid, canvas=None) output_maps = grid_sample(input, grid, canvas=None)
return output_maps, source_coordinate return output_maps, source_coordinate
\ No newline at end of file
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/FudanVI/FudanOCR/blob/main/text-gestalt/model/tsrn.py
"""
import math
import paddle
import paddle.nn.functional as F
from paddle import nn
from collections import OrderedDict
import sys
import numpy as np
import warnings
import math, copy
import cv2
warnings.filterwarnings("ignore")
from .tps_spatial_transformer import TPSSpatialTransformer
from .stn import STN as STN_model
from ppocr.modeling.heads.sr_rensnet_transformer import Transformer
class TSRN(nn.Layer):
def __init__(self,
in_channels,
scale_factor=2,
width=128,
height=32,
STN=False,
srb_nums=5,
mask=False,
hidden_units=32,
infer_mode=False,
**kwargs):
super(TSRN, self).__init__()
in_planes = 3
if mask:
in_planes = 4
assert math.log(scale_factor, 2) % 1 == 0
upsample_block_num = int(math.log(scale_factor, 2))
self.block1 = nn.Sequential(
nn.Conv2D(
in_planes, 2 * hidden_units, kernel_size=9, padding=4),
nn.PReLU())
self.srb_nums = srb_nums
for i in range(srb_nums):
setattr(self, 'block%d' % (i + 2),
RecurrentResidualBlock(2 * hidden_units))
setattr(
self,
'block%d' % (srb_nums + 2),
nn.Sequential(
nn.Conv2D(
2 * hidden_units,
2 * hidden_units,
kernel_size=3,
padding=1),
nn.BatchNorm2D(2 * hidden_units)))
block_ = [
UpsampleBLock(2 * hidden_units, 2)
for _ in range(upsample_block_num)
]
block_.append(
nn.Conv2D(
2 * hidden_units, in_planes, kernel_size=9, padding=4))
setattr(self, 'block%d' % (srb_nums + 3), nn.Sequential(*block_))
self.tps_inputsize = [height // scale_factor, width // scale_factor]
tps_outputsize = [height // scale_factor, width // scale_factor]
num_control_points = 20
tps_margins = [0.05, 0.05]
self.stn = STN
if self.stn:
self.tps = TPSSpatialTransformer(
output_image_size=tuple(tps_outputsize),
num_control_points=num_control_points,
margins=tuple(tps_margins))
self.stn_head = STN_model(
in_channels=in_planes,
num_ctrlpoints=num_control_points,
activation='none')
self.out_channels = in_channels
self.r34_transformer = Transformer()
for param in self.r34_transformer.parameters():
param.trainable = False
self.infer_mode = infer_mode
def forward(self, x):
output = {}
if self.infer_mode:
output["lr_img"] = x
y = x
else:
output["lr_img"] = x[0]
output["hr_img"] = x[1]
y = x[0]
if self.stn and self.training:
_, ctrl_points_x = self.stn_head(y)
y, _ = self.tps(y, ctrl_points_x)
block = {'1': self.block1(y)}
for i in range(self.srb_nums + 1):
block[str(i + 2)] = getattr(self,
'block%d' % (i + 2))(block[str(i + 1)])
block[str(self.srb_nums + 3)] = getattr(self, 'block%d' % (self.srb_nums + 3)) \
((block['1'] + block[str(self.srb_nums + 2)]))
sr_img = paddle.tanh(block[str(self.srb_nums + 3)])
output["sr_img"] = sr_img
if self.training:
hr_img = x[1]
length = x[2]
input_tensor = x[3]
# add transformer
sr_pred, word_attention_map_pred, _ = self.r34_transformer(
sr_img, length, input_tensor)
hr_pred, word_attention_map_gt, _ = self.r34_transformer(
hr_img, length, input_tensor)
output["hr_img"] = hr_img
output["hr_pred"] = hr_pred
output["word_attention_map_gt"] = word_attention_map_gt
output["sr_pred"] = sr_pred
output["word_attention_map_pred"] = word_attention_map_pred
return output
class RecurrentResidualBlock(nn.Layer):
def __init__(self, channels):
super(RecurrentResidualBlock, self).__init__()
self.conv1 = nn.Conv2D(channels, channels, kernel_size=3, padding=1)
self.bn1 = nn.BatchNorm2D(channels)
self.gru1 = GruBlock(channels, channels)
self.prelu = mish()
self.conv2 = nn.Conv2D(channels, channels, kernel_size=3, padding=1)
self.bn2 = nn.BatchNorm2D(channels)
self.gru2 = GruBlock(channels, channels)
def forward(self, x):
residual = self.conv1(x)
residual = self.bn1(residual)
residual = self.prelu(residual)
residual = self.conv2(residual)
residual = self.bn2(residual)
residual = self.gru1(residual.transpose([0, 1, 3, 2])).transpose(
[0, 1, 3, 2])
return self.gru2(x + residual)
class UpsampleBLock(nn.Layer):
def __init__(self, in_channels, up_scale):
super(UpsampleBLock, self).__init__()
self.conv = nn.Conv2D(
in_channels, in_channels * up_scale**2, kernel_size=3, padding=1)
self.pixel_shuffle = nn.PixelShuffle(up_scale)
self.prelu = mish()
def forward(self, x):
x = self.conv(x)
x = self.pixel_shuffle(x)
x = self.prelu(x)
return x
class mish(nn.Layer):
def __init__(self, ):
super(mish, self).__init__()
self.activated = True
def forward(self, x):
if self.activated:
x = x * (paddle.tanh(F.softplus(x)))
return x
class GruBlock(nn.Layer):
def __init__(self, in_channels, out_channels):
super(GruBlock, self).__init__()
assert out_channels % 2 == 0
self.conv1 = nn.Conv2D(
in_channels, out_channels, kernel_size=1, padding=0)
self.gru = nn.GRU(out_channels,
out_channels // 2,
direction='bidirectional')
def forward(self, x):
# x: b, c, w, h
x = self.conv1(x)
x = x.transpose([0, 2, 3, 1]) # b, w, h, c
batch_size, w, h, c = x.shape
x = x.reshape([-1, h, c]) # b*w, h, c
x, _ = self.gru(x)
x = x.reshape([-1, w, h, c])
x = x.transpose([0, 3, 1, 2])
return x
...@@ -148,10 +148,14 @@ def load_pretrained_params(model, path): ...@@ -148,10 +148,14 @@ def load_pretrained_params(model, path):
"The {}.pdparams does not exists!".format(path) "The {}.pdparams does not exists!".format(path)
params = paddle.load(path + '.pdparams') params = paddle.load(path + '.pdparams')
state_dict = model.state_dict() state_dict = model.state_dict()
new_state_dict = {} new_state_dict = {}
is_float16 = False is_float16 = False
for k1 in params.keys(): for k1 in params.keys():
if k1 not in state_dict.keys(): if k1 not in state_dict.keys():
logger.warning("The pretrained params {} not in model".format(k1)) logger.warning("The pretrained params {} not in model".format(k1))
else: else:
......
...@@ -78,6 +78,12 @@ def export_single_model(model, ...@@ -78,6 +78,12 @@ def export_single_model(model,
shape=[None, 3, 64, 512], dtype="float32"), shape=[None, 3, 64, 512], dtype="float32"),
] ]
model = to_static(model, input_spec=other_shape) model = to_static(model, input_spec=other_shape)
elif arch_config["model_type"] == "sr":
other_shape = [
paddle.static.InputSpec(
shape=[None, 3, 16, 64], dtype="float32")
]
model = to_static(model, input_spec=other_shape)
elif arch_config["algorithm"] == "ViTSTR": elif arch_config["algorithm"] == "ViTSTR":
other_shape = [ other_shape = [
paddle.static.InputSpec( paddle.static.InputSpec(
...@@ -195,6 +201,9 @@ def main(): ...@@ -195,6 +201,9 @@ def main():
else: # base rec model else: # base rec model
config["Architecture"]["Head"]["out_channels"] = char_num config["Architecture"]["Head"]["out_channels"] = char_num
# for sr algorithm
if config["Architecture"]["model_type"] == "sr":
config['Architecture']["Transform"]['infer_mode'] = True
model = build_model(config["Architecture"]) model = build_model(config["Architecture"])
load_model(config, model, model_type=config['Architecture']["model_type"]) load_model(config, model, model_type=config['Architecture']["model_type"])
model.eval() model.eval()
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
from PIL import Image
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, __dir__)
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..')))
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
import cv2
import numpy as np
import math
import time
import traceback
import paddle
import tools.infer.utility as utility
from ppocr.postprocess import build_post_process
from ppocr.utils.logging import get_logger
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
logger = get_logger()
class TextSR(object):
def __init__(self, args):
self.sr_image_shape = [int(v) for v in args.sr_image_shape.split(",")]
self.sr_batch_num = args.sr_batch_num
self.predictor, self.input_tensor, self.output_tensors, self.config = \
utility.create_predictor(args, 'sr', logger)
self.benchmark = args.benchmark
if args.benchmark:
import auto_log
pid = os.getpid()
gpu_id = utility.get_infer_gpuid()
self.autolog = auto_log.AutoLogger(
model_name="sr",
model_precision=args.precision,
batch_size=args.sr_batch_num,
data_shape="dynamic",
save_path=None, #args.save_log_path,
inference_config=self.config,
pids=pid,
process_name=None,
gpu_ids=gpu_id if args.use_gpu else None,
time_keys=[
'preprocess_time', 'inference_time', 'postprocess_time'
],
warmup=0,
logger=logger)
def resize_norm_img(self, img):
imgC, imgH, imgW = self.sr_image_shape
img = img.resize((imgW // 2, imgH // 2), Image.BICUBIC)
img_numpy = np.array(img).astype("float32")
img_numpy = img_numpy.transpose((2, 0, 1)) / 255
return img_numpy
def __call__(self, img_list):
img_num = len(img_list)
batch_num = self.sr_batch_num
st = time.time()
st = time.time()
all_result = [] * img_num
if self.benchmark:
self.autolog.times.start()
for beg_img_no in range(0, img_num, batch_num):
end_img_no = min(img_num, beg_img_no + batch_num)
norm_img_batch = []
imgC, imgH, imgW = self.sr_image_shape
for ino in range(beg_img_no, end_img_no):
norm_img = self.resize_norm_img(img_list[ino])
norm_img = norm_img[np.newaxis, :]
norm_img_batch.append(norm_img)
norm_img_batch = np.concatenate(norm_img_batch)
norm_img_batch = norm_img_batch.copy()
if self.benchmark:
self.autolog.times.stamp()
self.input_tensor.copy_from_cpu(norm_img_batch)
self.predictor.run()
outputs = []
for output_tensor in self.output_tensors:
output = output_tensor.copy_to_cpu()
outputs.append(output)
if len(outputs) != 1:
preds = outputs
else:
preds = outputs[0]
all_result.append(outputs)
if self.benchmark:
self.autolog.times.end(stamp=True)
return all_result, time.time() - st
def main(args):
image_file_list = get_image_file_list(args.image_dir)
text_recognizer = TextSR(args)
valid_image_file_list = []
img_list = []
# warmup 2 times
if args.warmup:
img = np.random.uniform(0, 255, [16, 64, 3]).astype(np.uint8)
for i in range(2):
res = text_recognizer([img] * int(args.sr_batch_num))
for image_file in image_file_list:
img, flag = check_and_read_gif(image_file)
if not flag:
img = Image.open(image_file).convert("RGB")
if img is None:
logger.info("error in loading image:{}".format(image_file))
continue
valid_image_file_list.append(image_file)
img_list.append(img)
try:
preds, _ = text_recognizer(img_list)
for beg_no in range(len(preds)):
sr_img = preds[beg_no][1]
lr_img = preds[beg_no][0]
for i in (range(sr_img.shape[0])):
fm_sr = (sr_img[i] * 255).transpose(1, 2, 0).astype(np.uint8)
fm_lr = (lr_img[i] * 255).transpose(1, 2, 0).astype(np.uint8)
img_name_pure = os.path.split(valid_image_file_list[
beg_no * args.sr_batch_num + i])[-1]
cv2.imwrite("infer_result/sr_{}".format(img_name_pure),
fm_sr[:, :, ::-1])
logger.info("The visualized image saved in infer_result/sr_{}".
format(img_name_pure))
except Exception as E:
logger.info(traceback.format_exc())
logger.info(E)
exit()
if args.benchmark:
text_recognizer.autolog.report()
if __name__ == "__main__":
main(utility.parse_args())
...@@ -121,6 +121,11 @@ def init_args(): ...@@ -121,6 +121,11 @@ def init_args():
parser.add_argument("--use_pdserving", type=str2bool, default=False) parser.add_argument("--use_pdserving", type=str2bool, default=False)
parser.add_argument("--warmup", type=str2bool, default=False) parser.add_argument("--warmup", type=str2bool, default=False)
# SR parmas
parser.add_argument("--sr_model_dir", type=str)
parser.add_argument("--sr_image_shape", type=str, default="3, 32, 128")
parser.add_argument("--sr_batch_num", type=int, default=1)
# #
parser.add_argument( parser.add_argument(
"--draw_img_save_dir", type=str, default="./inference_results") "--draw_img_save_dir", type=str, default="./inference_results")
...@@ -156,6 +161,8 @@ def create_predictor(args, mode, logger): ...@@ -156,6 +161,8 @@ def create_predictor(args, mode, logger):
model_dir = args.table_model_dir model_dir = args.table_model_dir
elif mode == 'ser': elif mode == 'ser':
model_dir = args.ser_model_dir model_dir = args.ser_model_dir
elif mode == "sr":
model_dir = args.sr_model_dir
else: else:
model_dir = args.e2e_model_dir model_dir = args.e2e_model_dir
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import os
import sys
import json
from PIL import Image
import cv2
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, __dir__)
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
import paddle
from ppocr.data import create_operators, transform
from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process
from ppocr.utils.save_load import load_model
from ppocr.utils.utility import get_image_file_list
import tools.program as program
def main():
global_config = config['Global']
# build post process
post_process_class = build_post_process(config['PostProcess'],
global_config)
# sr transform
config['Architecture']["Transform"]['infer_mode'] = True
model = build_model(config['Architecture'])
load_model(config, model)
# create data ops
transforms = []
for op in config['Eval']['dataset']['transforms']:
op_name = list(op)[0]
if 'Label' in op_name:
continue
elif op_name in ['SRResize']:
op[op_name]['infer_mode'] = True
elif op_name == 'KeepKeys':
op[op_name]['keep_keys'] = ['imge_lr']
transforms.append(op)
global_config['infer_mode'] = True
ops = create_operators(transforms, global_config)
save_res_path = config['Global'].get('save_res_path', "./infer_result")
if not os.path.exists(os.path.dirname(save_res_path)):
os.makedirs(os.path.dirname(save_res_path))
model.eval()
for file in get_image_file_list(config['Global']['infer_img']):
logger.info("infer_img: {}".format(file))
img = Image.open(file).convert("RGB")
data = {'image_lr': img}
batch = transform(data, ops)
images = np.expand_dims(batch[0], axis=0)
images = paddle.to_tensor(images)
preds = model(images)
sr_img = preds["sr_img"][0]
lr_img = preds["lr_img"][0]
fm_sr = (sr_img.numpy() * 255).transpose(1, 2, 0).astype(np.uint8)
fm_lr = (lr_img.numpy() * 255).transpose(1, 2, 0).astype(np.uint8)
img_name_pure = os.path.split(file)[-1]
cv2.imwrite("infer_result/sr_{}".format(img_name_pure),
fm_sr[:, :, ::-1])
logger.info("The visualized image saved in infer_result/sr_{}".format(
img_name_pure))
logger.info("success!")
if __name__ == '__main__':
config, device, logger, vdl_writer = program.preprocess()
main()
...@@ -25,6 +25,8 @@ import datetime ...@@ -25,6 +25,8 @@ import datetime
import paddle import paddle
import paddle.distributed as dist import paddle.distributed as dist
from tqdm import tqdm from tqdm import tqdm
import cv2
import numpy as np
from argparse import ArgumentParser, RawDescriptionHelpFormatter from argparse import ArgumentParser, RawDescriptionHelpFormatter
from ppocr.utils.stats import TrainingStats from ppocr.utils.stats import TrainingStats
...@@ -262,6 +264,7 @@ def train(config, ...@@ -262,6 +264,7 @@ def train(config,
config, 'Train', device, logger, seed=epoch) config, 'Train', device, logger, seed=epoch)
max_iter = len(train_dataloader) - 1 if platform.system( max_iter = len(train_dataloader) - 1 if platform.system(
) == "Windows" else len(train_dataloader) ) == "Windows" else len(train_dataloader)
for idx, batch in enumerate(train_dataloader): for idx, batch in enumerate(train_dataloader):
profiler.add_profiler_step(profiler_options) profiler.add_profiler_step(profiler_options)
train_reader_cost += time.time() - reader_start train_reader_cost += time.time() - reader_start
...@@ -289,7 +292,7 @@ def train(config, ...@@ -289,7 +292,7 @@ def train(config,
else: else:
if model_type == 'table' or extra_input: if model_type == 'table' or extra_input:
preds = model(images, data=batch[1:]) preds = model(images, data=batch[1:])
elif model_type in ["kie", 'vqa']: elif model_type in ["kie", 'vqa', 'sr']:
preds = model(batch) preds = model(batch)
else: else:
preds = model(images) preds = model(images)
...@@ -297,11 +300,12 @@ def train(config, ...@@ -297,11 +300,12 @@ def train(config,
avg_loss = loss['loss'] avg_loss = loss['loss']
avg_loss.backward() avg_loss.backward()
optimizer.step() optimizer.step()
optimizer.clear_grad() optimizer.clear_grad()
if cal_metric_during_train and epoch % calc_epoch_interval == 0: # only rec and cls need if cal_metric_during_train and epoch % calc_epoch_interval == 0: # only rec and cls need
batch = [item.numpy() for item in batch] batch = [item.numpy() for item in batch]
if model_type in ['kie']: if model_type in ['kie', 'sr']:
eval_class(preds, batch) eval_class(preds, batch)
elif model_type in ['table']: elif model_type in ['table']:
post_result = post_process_class(preds, batch) post_result = post_process_class(preds, batch)
...@@ -347,8 +351,8 @@ def train(config, ...@@ -347,8 +351,8 @@ def train(config,
len(train_dataloader) - idx - 1) * eta_meter.avg len(train_dataloader) - idx - 1) * eta_meter.avg
eta_sec_format = str(datetime.timedelta(seconds=int(eta_sec))) eta_sec_format = str(datetime.timedelta(seconds=int(eta_sec)))
strs = 'epoch: [{}/{}], global_step: {}, {}, avg_reader_cost: ' \ strs = 'epoch: [{}/{}], global_step: {}, {}, avg_reader_cost: ' \
'{:.5f} s, avg_batch_cost: {:.5f} s, avg_samples: {}, ' \ '{:.5f} s, avg_batch_cost: {:.5f} s, avg_samples: {}, ' \
'ips: {:.5f} samples/s, eta: {}'.format( 'ips: {:.5f} samples/s, eta: {}'.format(
epoch, epoch_num, global_step, logs, epoch, epoch_num, global_step, logs,
train_reader_cost / print_batch_step, train_reader_cost / print_batch_step,
train_batch_cost / print_batch_step, train_batch_cost / print_batch_step,
...@@ -480,12 +484,13 @@ def eval(model, ...@@ -480,12 +484,13 @@ def eval(model,
leave=True) leave=True)
max_iter = len(valid_dataloader) - 1 if platform.system( max_iter = len(valid_dataloader) - 1 if platform.system(
) == "Windows" else len(valid_dataloader) ) == "Windows" else len(valid_dataloader)
sum_images = 0
for idx, batch in enumerate(valid_dataloader): for idx, batch in enumerate(valid_dataloader):
if idx >= max_iter: if idx >= max_iter:
break break
images = batch[0] images = batch[0]
start = time.time() start = time.time()
# use amp # use amp
if scaler: if scaler:
with paddle.amp.auto_cast(level='O2'): with paddle.amp.auto_cast(level='O2'):
...@@ -493,6 +498,20 @@ def eval(model, ...@@ -493,6 +498,20 @@ def eval(model,
preds = model(images, data=batch[1:]) preds = model(images, data=batch[1:])
elif model_type in ["kie", 'vqa']: elif model_type in ["kie", 'vqa']:
preds = model(batch) preds = model(batch)
elif model_type in ['sr']:
preds = model(batch)
sr_img = preds["sr_img"]
lr_img = preds["lr_img"]
for i in (range(sr_img.shape[0])):
fm_sr = (sr_img[i].numpy() * 255).transpose(
1, 2, 0).astype(np.uint8)
fm_lr = (lr_img[i].numpy() * 255).transpose(
1, 2, 0).astype(np.uint8)
cv2.imwrite("output/images/{}_{}_sr.jpg".format(sum_images,
i), fm_sr)
cv2.imwrite("output/images/{}_{}_lr.jpg".format(sum_images,
i), fm_lr)
else: else:
preds = model(images) preds = model(images)
else: else:
...@@ -500,6 +519,20 @@ def eval(model, ...@@ -500,6 +519,20 @@ def eval(model,
preds = model(images, data=batch[1:]) preds = model(images, data=batch[1:])
elif model_type in ["kie", 'vqa']: elif model_type in ["kie", 'vqa']:
preds = model(batch) preds = model(batch)
elif model_type in ['sr']:
preds = model(batch)
sr_img = preds["sr_img"]
lr_img = preds["lr_img"]
for i in (range(sr_img.shape[0])):
fm_sr = (sr_img[i].numpy() * 255).transpose(
1, 2, 0).astype(np.uint8)
fm_lr = (lr_img[i].numpy() * 255).transpose(
1, 2, 0).astype(np.uint8)
cv2.imwrite("output/images/{}_{}_sr.jpg".format(sum_images,
i), fm_sr)
cv2.imwrite("output/images/{}_{}_lr.jpg".format(sum_images,
i), fm_lr)
else: else:
preds = model(images) preds = model(images)
...@@ -517,12 +550,15 @@ def eval(model, ...@@ -517,12 +550,15 @@ def eval(model,
elif model_type in ['table', 'vqa']: elif model_type in ['table', 'vqa']:
post_result = post_process_class(preds, batch_numpy) post_result = post_process_class(preds, batch_numpy)
eval_class(post_result, batch_numpy) eval_class(post_result, batch_numpy)
elif model_type in ['sr']:
eval_class(preds, batch_numpy)
else: else:
post_result = post_process_class(preds, batch_numpy[1]) post_result = post_process_class(preds, batch_numpy[1])
eval_class(post_result, batch_numpy) eval_class(post_result, batch_numpy)
pbar.update(1) pbar.update(1)
total_frame += len(images) total_frame += len(images)
sum_images += 1
# Get final metric,eg. acc or hmean # Get final metric,eg. acc or hmean
metric = eval_class.get_metric() metric = eval_class.get_metric()
...@@ -616,7 +652,8 @@ def preprocess(is_train=False): ...@@ -616,7 +652,8 @@ def preprocess(is_train=False):
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN', 'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE', 'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'LayoutLMv2', 'PREN', 'FCE', 'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'LayoutLMv2', 'PREN', 'FCE',
'SVTR', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'SPIN', 'VisionLAN' 'SVTR', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'SPIN', 'VisionLAN',
'Gestalt'
] ]
if use_xpu: if use_xpu:
......
...@@ -119,6 +119,7 @@ def main(config, device, logger, vdl_writer): ...@@ -119,6 +119,7 @@ def main(config, device, logger, vdl_writer):
config['Loss']['ignore_index'] = char_num - 1 config['Loss']['ignore_index'] = char_num - 1
model = build_model(config['Architecture']) model = build_model(config['Architecture'])
model = apply_to_static(model, config, logger) model = apply_to_static(model, config, logger)
# build loss # build loss
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册