diff --git a/configs/sr/sr_tsrn_transformer_strock.yml b/configs/sr/sr_tsrn_transformer_strock.yml
new file mode 100644
index 0000000000000000000000000000000000000000..c8c308c4337ddbb2933714391762efbfda44bf32
--- /dev/null
+++ b/configs/sr/sr_tsrn_transformer_strock.yml
@@ -0,0 +1,85 @@
+Global:
+ use_gpu: true
+ epoch_num: 500
+ log_smooth_window: 20
+ print_batch_step: 10
+ save_model_dir: ./output/sr/sr_tsrn_transformer_strock/
+ save_epoch_step: 3
+ # evaluation is run every 2000 iterations
+ eval_batch_step: [0, 1000]
+ cal_metric_during_train: False
+ pretrained_model:
+ checkpoints:
+ save_inference_dir: sr_output
+ use_visualdl: False
+ infer_img: doc/imgs_words_en/word_52.png
+ # for data or label process
+ character_dict_path: ./train_data/srdata/english_decomposition.txt
+ max_text_length: 100
+ infer_mode: False
+ use_space_char: False
+ save_res_path: ./output/sr/predicts_gestalt.txt
+
+Optimizer:
+ name: Adam
+ beta1: 0.5
+ beta2: 0.999
+ clip_norm: 0.25
+ lr:
+ learning_rate: 0.0001
+
+Architecture:
+ model_type: sr
+ algorithm: Gestalt
+ Transform:
+ name: TSRN
+ STN: True
+ infer_mode: False
+
+Loss:
+ name: StrokeFocusLoss
+ character_dict_path: ./train_data/srdata/english_decomposition.txt
+
+PostProcess:
+ name: None
+
+Metric:
+ name: SRMetric
+ main_indicator: all
+
+Train:
+ dataset:
+ name: LMDBDataSetSR
+ data_dir: ./train_data/srdata/train
+ transforms:
+ - SRResize:
+ imgH: 32
+ imgW: 128
+ down_sample_scale: 2
+ - SRLabelEncode: # Class handling label
+ - KeepKeys:
+ keep_keys: ['img_lr', 'img_hr', 'length', 'input_tensor', 'label'] # dataloader will return list in this order
+ loader:
+ shuffle: False
+ batch_size_per_card: 16
+ drop_last: True
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: LMDBDataSetSR
+ data_dir: ./train_data/srdata/test
+ transforms:
+ - SRResize:
+ imgH: 32
+ imgW: 128
+ down_sample_scale: 2
+ - SRLabelEncode: # Class handling label
+ - KeepKeys:
+ keep_keys: ['img_lr', 'img_hr','length', 'input_tensor', 'label'] # dataloader will return list in this order
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 16
+ num_workers: 4
+
diff --git a/doc/doc_ch/algorithm_sr_gestalt.md b/doc/doc_ch/algorithm_sr_gestalt.md
new file mode 100644
index 0000000000000000000000000000000000000000..aac82b1b62b10d070b7b67702198f462219acb6c
--- /dev/null
+++ b/doc/doc_ch/algorithm_sr_gestalt.md
@@ -0,0 +1,127 @@
+# Text Gestalt
+
+- [1. 算法简介](#1)
+- [2. 环境配置](#2)
+- [3. 模型训练、评估、预测](#3)
+ - [3.1 训练](#3-1)
+ - [3.2 评估](#3-2)
+ - [3.3 预测](#3-3)
+- [4. 推理部署](#4)
+ - [4.1 Python推理](#4-1)
+ - [4.2 C++推理](#4-2)
+ - [4.3 Serving服务化部署](#4-3)
+ - [4.4 更多推理部署](#4-4)
+- [5. FAQ](#5)
+
+
+## 1. 算法简介
+
+论文信息:
+> [Text Gestalt: Stroke-Aware Scene Text Image Super-Resolution](https://arxiv.org/pdf/2112.08171.pdf)
+
+> Chen, Jingye and Yu, Haiyang and Ma, Jianqi and Li, Bin and Xue, Xiangyang
+
+> AAAI, 2022
+
+参考[FudanOCR](https://github.com/FudanVI/FudanOCR/tree/main/text-gestalt) 数据下载说明,在TextZoom测试集合上超分算法效果如下:
+
+|模型|骨干网络|PSNR_Avg|SSIM_Avg|配置文件|下载链接|
+|---|---|---|---|---|---|
+|Text Gestalt|tsrn|19.28|0.6560| [configs/sr/sr_tsrn_transformer_strock.yml](../../configs/sr/sr_tsrn_transformer_strock.yml)|[训练模型](https://paddleocr.bj.bcebos.com/sr_tsrn_transformer_strock_train.tar)|
+
+
+
+## 2. 环境配置
+请先参考[《运行环境准备》](./environment.md)配置PaddleOCR运行环境,参考[《项目克隆》](./clone.md)克隆项目代码。
+
+
+
+## 3. 模型训练、评估、预测
+
+请参考[文本识别训练教程](./recognition.md)。PaddleOCR对代码进行了模块化,训练不同的识别模型只需要**更换配置文件**即可。
+
+- 训练
+
+在完成数据准备后,便可以启动训练,训练命令如下:
+
+```
+#单卡训练(训练周期长,不建议)
+python3 tools/train.py -c configs/sr/sr_tsrn_transformer_strock.yml
+
+#多卡训练,通过--gpus参数指定卡号
+python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/sr/sr_tsrn_transformer_strock.yml
+
+```
+
+- 评估
+
+```
+# GPU 评估, Global.pretrained_model 为待测权重
+python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/sr/sr_tsrn_transformer_strock.yml -o Global.pretrained_model={path/to/weights}/best_accuracy
+```
+
+- 预测:
+
+```
+# 预测使用的配置文件必须与训练一致
+python3 tools/infer_sr.py -c configs/sr/sr_tsrn_transformer_strock.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.infer_img=doc/imgs_words_en/word_52.png
+```
+
+![](../imgs_words_en/word_52.png)
+
+执行命令后,上面图像的超分结果如下:
+
+![](../imgs_results/sr_word_52.png)
+
+
+## 4. 推理部署
+
+
+### 4.1 Python推理
+
+首先将文本超分训练过程中保存的模型,转换成inference model。以 Text-Gestalt 训练的[模型](https://paddleocr.bj.bcebos.com/sr_tsrn_transformer_strock_train.tar) 为例,可以使用如下命令进行转换:
+```shell
+python3 tools/export_model.py -c configs/sr/sr_tsrn_transformer_strock.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.save_inference_dir=./inference/sr_out
+```
+Text-Gestalt 文本超分模型推理,可以执行如下命令:
+```
+python3 tools/infer/predict_sr.py --sr_model_dir=./inference/sr_out --image_dir=doc/imgs_words_en/word_52.png --sr_image_shape=3,32,128
+
+```
+
+执行命令后,图像的超分结果如下:
+
+![](../imgs_results/sr_word_52.png)
+
+
+### 4.2 C++推理
+
+暂未支持
+
+
+### 4.3 Serving服务化部署
+
+暂未支持
+
+
+### 4.4 更多推理部署
+
+暂未支持
+
+
+## 5. FAQ
+
+
+## 引用
+
+```bibtex
+@inproceedings{chen2022text,
+ title={Text gestalt: Stroke-aware scene text image super-resolution},
+ author={Chen, Jingye and Yu, Haiyang and Ma, Jianqi and Li, Bin and Xue, Xiangyang},
+ booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
+ volume={36},
+ number={1},
+ pages={285--293},
+ year={2022}
+}
+```
diff --git a/doc/doc_en/algorithm_sr_gestalt_en.md b/doc/doc_en/algorithm_sr_gestalt_en.md
new file mode 100644
index 0000000000000000000000000000000000000000..516b90cb3099c0627cf23ef608ffb7da31aacc35
--- /dev/null
+++ b/doc/doc_en/algorithm_sr_gestalt_en.md
@@ -0,0 +1,136 @@
+# Text Gestalt
+
+- [1. Introduction](#1)
+- [2. Environment](#2)
+- [3. Model Training / Evaluation / Prediction](#3)
+ - [3.1 Training](#3-1)
+ - [3.2 Evaluation](#3-2)
+ - [3.3 Prediction](#3-3)
+- [4. Inference and Deployment](#4)
+ - [4.1 Python Inference](#4-1)
+ - [4.2 C++ Inference](#4-2)
+ - [4.3 Serving](#4-3)
+ - [4.4 More](#4-4)
+- [5. FAQ](#5)
+
+
+
+## 1. Introduction
+
+Paper:
+> [Text Gestalt: Stroke-Aware Scene Text Image Super-Resolution](https://arxiv.org/pdf/2112.08171.pdf)
+
+> Chen, Jingye and Yu, Haiyang and Ma, Jianqi and Li, Bin and Xue, Xiangyang
+
+> AAAI, 2022
+
+Referring to the [FudanOCR](https://github.com/FudanVI/FudanOCR/tree/main/text-gestalt) data download instructions, the effect of the super-score algorithm on the TextZoom test set is as follows:
+
+|Model|Backbone|config|Acc|Download link|
+|---|---|---|---|---|---|
+|Text Gestalt|tsrn|19.28|0.6560| [configs/sr/sr_tsrn_transformer_strock.yml](../../configs/sr/sr_tsrn_transformer_strock.yml)|[train model](https://paddleocr.bj.bcebos.com/sr_tsrn_transformer_strock_train.tar)|
+
+
+
+## 2. Environment
+Please refer to ["Environment Preparation"](./environment_en.md) to configure the PaddleOCR environment, and refer to ["Project Clone"](./clone_en.md) to clone the project code.
+
+
+
+## 3. Model Training / Evaluation / Prediction
+
+Please refer to [Text Recognition Tutorial](./recognition_en.md). PaddleOCR modularizes the code, and training different models only requires **changing the configuration file**.
+
+Training:
+
+Specifically, after the data preparation is completed, the training can be started. The training command is as follows:
+
+```
+#Single GPU training (long training period, not recommended)
+
+python3 tools/train.py -c configs/sr/sr_tsrn_transformer_strock.yml
+
+#Multi GPU training, specify the gpu number through the --gpus parameter
+
+python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/sr/sr_tsrn_transformer_strock.yml
+
+```
+
+
+Evaluation:
+
+```
+# GPU evaluation
+python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/sr/sr_tsrn_transformer_strock.yml -o Global.pretrained_model={path/to/weights}/best_accuracy
+```
+
+Prediction:
+
+```
+# The configuration file used for prediction must match the training
+
+python3 tools/infer_sr.py -c configs/sr/sr_tsrn_transformer_strock.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.infer_img=doc/imgs_words_en/word_52.png
+```
+
+![](../imgs_words_en/word_52.png)
+
+After executing the command, the super-resolution result of the above image is as follows:
+
+![](../imgs_results/sr_word_52.png)
+
+
+## 4. Inference and Deployment
+
+
+### 4.1 Python Inference
+
+First, the model saved during the training process is converted into an inference model. ( [Model download link](https://paddleocr.bj.bcebos.com/sr_tsrn_transformer_strock_train.tar) ), you can use the following command to convert:
+
+```shell
+python3 tools/export_model.py -c configs/sr/sr_tsrn_transformer_strock.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.save_inference_dir=./inference/sr_out
+```
+
+For Text-Gestalt super-resolution model inference, the following commands can be executed:
+
+```
+python3 tools/infer/predict_sr.py --sr_model_dir=./inference/sr_out --image_dir=doc/imgs_words_en/word_52.png --sr_image_shape=3,32,128
+
+```
+
+After executing the command, the super-resolution result of the above image is as follows:
+
+![](../imgs_results/sr_word_52.png)
+
+
+
+### 4.2 C++ Inference
+
+Not supported
+
+
+### 4.3 Serving
+
+Not supported
+
+
+### 4.4 More
+
+Not supported
+
+
+## 5. FAQ
+
+
+## Citation
+
+```bibtex
+@inproceedings{chen2022text,
+ title={Text gestalt: Stroke-aware scene text image super-resolution},
+ author={Chen, Jingye and Yu, Haiyang and Ma, Jianqi and Li, Bin and Xue, Xiangyang},
+ booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
+ volume={36},
+ number={1},
+ pages={285--293},
+ year={2022}
+}
+```
diff --git a/ppocr/data/__init__.py b/ppocr/data/__init__.py
index 78c3279656e184a3a34bff3847d3936b5e8977b6..b602a346dbe4b0d45af287f25f05ead0f62daf44 100644
--- a/ppocr/data/__init__.py
+++ b/ppocr/data/__init__.py
@@ -34,7 +34,7 @@ import paddle.distributed as dist
from ppocr.data.imaug import transform, create_operators
from ppocr.data.simple_dataset import SimpleDataSet
-from ppocr.data.lmdb_dataset import LMDBDataSet
+from ppocr.data.lmdb_dataset import LMDBDataSet, LMDBDataSetSR
from ppocr.data.pgnet_dataset import PGDataSet
from ppocr.data.pubtab_dataset import PubTabDataSet
@@ -54,7 +54,8 @@ def build_dataloader(config, mode, device, logger, seed=None):
config = copy.deepcopy(config)
support_dict = [
- 'SimpleDataSet', 'LMDBDataSet', 'PGDataSet', 'PubTabDataSet'
+ 'SimpleDataSet', 'LMDBDataSet', 'PGDataSet', 'PubTabDataSet',
+ 'LMDBDataSetSR'
]
module_name = config[mode]['dataset']['name']
assert module_name in support_dict, Exception(
diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py
index 1656c69529e19ee04fcb4343f28fe742dabb83b0..68e5f719be30f2947d6a67a5cb90d1ba0e357309 100644
--- a/ppocr/data/imaug/label_ops.py
+++ b/ppocr/data/imaug/label_ops.py
@@ -1236,6 +1236,54 @@ class ABINetLabelEncode(BaseRecLabelEncode):
return dict_character
+class SRLabelEncode(BaseRecLabelEncode):
+ def __init__(self,
+ max_text_length,
+ character_dict_path=None,
+ use_space_char=False,
+ **kwargs):
+ super(SRLabelEncode, self).__init__(max_text_length,
+ character_dict_path, use_space_char)
+ self.dic = {}
+ with open(character_dict_path, 'r') as fin:
+ for line in fin.readlines():
+ line = line.strip()
+ character, sequence = line.split()
+ self.dic[character] = sequence
+ english_stroke_alphabet = '0123456789'
+ self.english_stroke_dict = {}
+ for index in range(len(english_stroke_alphabet)):
+ self.english_stroke_dict[english_stroke_alphabet[index]] = index
+
+ def encode(self, label):
+ stroke_sequence = ''
+ for character in label:
+ if character not in self.dic:
+ continue
+ else:
+ stroke_sequence += self.dic[character]
+ stroke_sequence += '0'
+ label = stroke_sequence
+
+ length = len(label)
+
+ input_tensor = np.zeros(self.max_text_len).astype("int64")
+ for j in range(length - 1):
+ input_tensor[j + 1] = self.english_stroke_dict[label[j]]
+
+ return length, input_tensor
+
+ def __call__(self, data):
+ text = data['label']
+ length, input_tensor = self.encode(text)
+
+ data["length"] = length
+ data["input_tensor"] = input_tensor
+ if text is None:
+ return None
+ return data
+
+
class SPINLabelEncode(AttnLabelEncode):
""" Convert between text-label and text-index """
diff --git a/ppocr/data/imaug/operators.py b/ppocr/data/imaug/operators.py
index 04cc2848fb4d25baaf553c6eda235ddb0e86511f..f8ed28929707eb750ad6e8499a73568cae3a8e6b 100644
--- a/ppocr/data/imaug/operators.py
+++ b/ppocr/data/imaug/operators.py
@@ -24,6 +24,7 @@ import six
import cv2
import numpy as np
import math
+from PIL import Image
class DecodeImage(object):
@@ -440,3 +441,52 @@ class KieResize(object):
points[:, 0::2] = np.clip(points[:, 0::2], 0, img_shape[1])
points[:, 1::2] = np.clip(points[:, 1::2], 0, img_shape[0])
return points
+
+
+class SRResize(object):
+ def __init__(self,
+ imgH=32,
+ imgW=128,
+ down_sample_scale=4,
+ keep_ratio=False,
+ min_ratio=1,
+ mask=False,
+ infer_mode=False,
+ **kwargs):
+ self.imgH = imgH
+ self.imgW = imgW
+ self.keep_ratio = keep_ratio
+ self.min_ratio = min_ratio
+ self.down_sample_scale = down_sample_scale
+ self.mask = mask
+ self.infer_mode = infer_mode
+
+ def __call__(self, data):
+ imgH = self.imgH
+ imgW = self.imgW
+ images_lr = data["image_lr"]
+ transform2 = ResizeNormalize(
+ (imgW // self.down_sample_scale, imgH // self.down_sample_scale))
+ images_lr = transform2(images_lr)
+ data["img_lr"] = images_lr
+ if self.infer_mode:
+ return data
+
+ images_HR = data["image_hr"]
+ label_strs = data["label"]
+ transform = ResizeNormalize((imgW, imgH))
+ images_HR = transform(images_HR)
+ data["img_hr"] = images_HR
+ return data
+
+
+class ResizeNormalize(object):
+ def __init__(self, size, interpolation=Image.BICUBIC):
+ self.size = size
+ self.interpolation = interpolation
+
+ def __call__(self, img):
+ img = img.resize(self.size, self.interpolation)
+ img_numpy = np.array(img).astype("float32")
+ img_numpy = img_numpy.transpose((2, 0, 1)) / 255
+ return img_numpy
diff --git a/ppocr/data/lmdb_dataset.py b/ppocr/data/lmdb_dataset.py
index e1b49809d199096ad06b90c4562aa5dbfa634db1..3a51cefec2f1da2c96cceb6482d8303aa136b78a 100644
--- a/ppocr/data/lmdb_dataset.py
+++ b/ppocr/data/lmdb_dataset.py
@@ -16,6 +16,9 @@ import os
from paddle.io import Dataset
import lmdb
import cv2
+import string
+import six
+from PIL import Image
from .imaug import transform, create_operators
@@ -116,3 +119,58 @@ class LMDBDataSet(Dataset):
def __len__(self):
return self.data_idx_order_list.shape[0]
+
+
+class LMDBDataSetSR(LMDBDataSet):
+ def buf2PIL(self, txn, key, type='RGB'):
+ imgbuf = txn.get(key)
+ buf = six.BytesIO()
+ buf.write(imgbuf)
+ buf.seek(0)
+ im = Image.open(buf).convert(type)
+ return im
+
+ def str_filt(self, str_, voc_type):
+ alpha_dict = {
+ 'digit': string.digits,
+ 'lower': string.digits + string.ascii_lowercase,
+ 'upper': string.digits + string.ascii_letters,
+ 'all': string.digits + string.ascii_letters + string.punctuation
+ }
+ if voc_type == 'lower':
+ str_ = str_.lower()
+ for char in str_:
+ if char not in alpha_dict[voc_type]:
+ str_ = str_.replace(char, '')
+ return str_
+
+ def get_lmdb_sample_info(self, txn, index):
+ self.voc_type = 'upper'
+ self.max_len = 100
+ self.test = False
+ label_key = b'label-%09d' % index
+ word = str(txn.get(label_key).decode())
+ img_HR_key = b'image_hr-%09d' % index # 128*32
+ img_lr_key = b'image_lr-%09d' % index # 64*16
+ try:
+ img_HR = self.buf2PIL(txn, img_HR_key, 'RGB')
+ img_lr = self.buf2PIL(txn, img_lr_key, 'RGB')
+ except IOError or len(word) > self.max_len:
+ return self[index + 1]
+ label_str = self.str_filt(word, self.voc_type)
+ return img_HR, img_lr, label_str
+
+ def __getitem__(self, idx):
+ lmdb_idx, file_idx = self.data_idx_order_list[idx]
+ lmdb_idx = int(lmdb_idx)
+ file_idx = int(file_idx)
+ sample_info = self.get_lmdb_sample_info(self.lmdb_sets[lmdb_idx]['txn'],
+ file_idx)
+ if sample_info is None:
+ return self.__getitem__(np.random.randint(self.__len__()))
+ img_HR, img_lr, label_str = sample_info
+ data = {'image_hr': img_HR, 'image_lr': img_lr, 'label': label_str}
+ outs = transform(data, self.ops)
+ if outs is None:
+ return self.__getitem__(np.random.randint(self.__len__()))
+ return outs
diff --git a/ppocr/losses/__init__.py b/ppocr/losses/__init__.py
index bb82c7e0060fc561b6ebd8a71968e4f0ce7003e1..8986e5e5b9f488b023781176011024276c437e11 100755
--- a/ppocr/losses/__init__.py
+++ b/ppocr/losses/__init__.py
@@ -57,6 +57,9 @@ from .table_master_loss import TableMasterLoss
# vqa token loss
from .vqa_token_layoutlm_loss import VQASerTokenLayoutLMLoss
+# sr loss
+from .stroke_focus_loss import StrokeFocusLoss
+
def build_loss(config):
support_dict = [
@@ -64,7 +67,7 @@ def build_loss(config):
'ClsLoss', 'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss',
'CELoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss',
'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss',
- 'TableMasterLoss', 'SPINAttentionLoss', 'VLLoss'
+ 'TableMasterLoss', 'SPINAttentionLoss', 'VLLoss','StrokeFocusLoss'
]
config = copy.deepcopy(config)
module_name = config.pop('name')
diff --git a/ppocr/losses/stroke_focus_loss.py b/ppocr/losses/stroke_focus_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..002bbc34774cc80599015492762ca448f593df0f
--- /dev/null
+++ b/ppocr/losses/stroke_focus_loss.py
@@ -0,0 +1,68 @@
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This code is refer from:
+https://github.com/FudanVI/FudanOCR/blob/main/text-gestalt/loss/stroke_focus_loss.py
+"""
+import cv2
+import sys
+import time
+import string
+import random
+import numpy as np
+import paddle.nn as nn
+import paddle
+
+
+class StrokeFocusLoss(nn.Layer):
+ def __init__(self, character_dict_path=None, **kwargs):
+ super(StrokeFocusLoss, self).__init__(character_dict_path)
+ self.mse_loss = nn.MSELoss()
+ self.ce_loss = nn.CrossEntropyLoss()
+ self.l1_loss = nn.L1Loss()
+ self.english_stroke_alphabet = '0123456789'
+ self.english_stroke_dict = {}
+ for index in range(len(self.english_stroke_alphabet)):
+ self.english_stroke_dict[self.english_stroke_alphabet[
+ index]] = index
+
+ stroke_decompose_lines = open(character_dict_path, 'r').readlines()
+ self.dic = {}
+ for line in stroke_decompose_lines:
+ line = line.strip()
+ character, sequence = line.split()
+ self.dic[character] = sequence
+
+ def forward(self, pred, data):
+
+ sr_img = pred["sr_img"]
+ hr_img = pred["hr_img"]
+
+ mse_loss = self.mse_loss(sr_img, hr_img)
+ word_attention_map_gt = pred["word_attention_map_gt"]
+ word_attention_map_pred = pred["word_attention_map_pred"]
+
+ hr_pred = pred["hr_pred"]
+ sr_pred = pred["sr_pred"]
+
+ attention_loss = paddle.nn.functional.l1_loss(word_attention_map_gt,
+ word_attention_map_pred)
+
+ loss = (mse_loss + attention_loss * 50) * 100
+
+ return {
+ "mse_loss": mse_loss,
+ "attention_loss": attention_loss,
+ "loss": loss
+ }
diff --git a/ppocr/metrics/__init__.py b/ppocr/metrics/__init__.py
index c244066c9f35570143403dd485e3422786711832..853647c06cf0519a0e049e14c16a0d3e26f9845b 100644
--- a/ppocr/metrics/__init__.py
+++ b/ppocr/metrics/__init__.py
@@ -30,13 +30,13 @@ from .table_metric import TableMetric
from .kie_metric import KIEMetric
from .vqa_token_ser_metric import VQASerTokenMetric
from .vqa_token_re_metric import VQAReTokenMetric
-
+from .sr_metric import SRMetric
def build_metric(config):
support_dict = [
"DetMetric", "DetFCEMetric", "RecMetric", "ClsMetric", "E2EMetric",
"DistillationMetric", "TableMetric", 'KIEMetric', 'VQASerTokenMetric',
- 'VQAReTokenMetric'
+ 'VQAReTokenMetric', 'SRMetric'
]
config = copy.deepcopy(config)
diff --git a/ppocr/metrics/rec_metric.py b/ppocr/metrics/rec_metric.py
index 515b9372e38a7213cde29fdc9834ed6df45a0a80..d858ae28e999d546847727243b35f2ac902e1026 100644
--- a/ppocr/metrics/rec_metric.py
+++ b/ppocr/metrics/rec_metric.py
@@ -16,6 +16,7 @@ import Levenshtein
import string
+
class RecMetric(object):
def __init__(self,
main_indicator='acc',
diff --git a/ppocr/metrics/sr_metric.py b/ppocr/metrics/sr_metric.py
new file mode 100644
index 0000000000000000000000000000000000000000..51c3ad66564e61abdd91432e6dc9ea1d8918583b
--- /dev/null
+++ b/ppocr/metrics/sr_metric.py
@@ -0,0 +1,155 @@
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+https://github.com/FudanVI/FudanOCR/blob/main/text-gestalt/utils/ssim_psnr.py
+"""
+
+from math import exp
+
+import paddle
+import paddle.nn.functional as F
+import paddle.nn as nn
+import string
+
+
+class SSIM(nn.Layer):
+ def __init__(self, window_size=11, size_average=True):
+ super(SSIM, self).__init__()
+ self.window_size = window_size
+ self.size_average = size_average
+ self.channel = 1
+ self.window = self.create_window(window_size, self.channel)
+
+ def gaussian(self, window_size, sigma):
+ gauss = paddle.to_tensor([
+ exp(-(x - window_size // 2)**2 / float(2 * sigma**2))
+ for x in range(window_size)
+ ])
+ return gauss / gauss.sum()
+
+ def create_window(self, window_size, channel):
+ _1D_window = self.gaussian(window_size, 1.5).unsqueeze(1)
+ _2D_window = _1D_window.mm(_1D_window.t()).unsqueeze(0).unsqueeze(0)
+ window = _2D_window.expand([channel, 1, window_size, window_size])
+ return window
+
+ def _ssim(self, img1, img2, window, window_size, channel,
+ size_average=True):
+ mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
+ mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
+
+ mu1_sq = mu1.pow(2)
+ mu2_sq = mu2.pow(2)
+ mu1_mu2 = mu1 * mu2
+
+ sigma1_sq = F.conv2d(
+ img1 * img1, window, padding=window_size // 2,
+ groups=channel) - mu1_sq
+ sigma2_sq = F.conv2d(
+ img2 * img2, window, padding=window_size // 2,
+ groups=channel) - mu2_sq
+ sigma12 = F.conv2d(
+ img1 * img2, window, padding=window_size // 2,
+ groups=channel) - mu1_mu2
+
+ C1 = 0.01**2
+ C2 = 0.03**2
+
+ ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / (
+ (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
+
+ if size_average:
+ return ssim_map.mean()
+ else:
+ return ssim_map.mean([1, 2, 3])
+
+ def ssim(self, img1, img2, window_size=11, size_average=True):
+ (_, channel, _, _) = img1.shape
+ window = self.create_window(window_size, channel)
+
+ return self._ssim(img1, img2, window, window_size, channel,
+ size_average)
+
+ def forward(self, img1, img2):
+ (_, channel, _, _) = img1.shape
+
+ if channel == self.channel and self.window.dtype == img1.dtype:
+ window = self.window
+ else:
+ window = self.create_window(self.window_size, channel)
+
+ self.window = window
+ self.channel = channel
+
+ return self._ssim(img1, img2, window, self.window_size, channel,
+ self.size_average)
+
+
+class SRMetric(object):
+ def __init__(self, main_indicator='all', **kwargs):
+ self.main_indicator = main_indicator
+ self.eps = 1e-5
+ self.psnr_result = []
+ self.ssim_result = []
+ self.calculate_ssim = SSIM()
+ self.reset()
+
+ def reset(self):
+ self.correct_num = 0
+ self.all_num = 0
+ self.norm_edit_dis = 0
+ self.psnr_result = []
+ self.ssim_result = []
+
+ def calculate_psnr(self, img1, img2):
+ # img1 and img2 have range [0, 1]
+ mse = ((img1 * 255 - img2 * 255)**2).mean()
+ if mse == 0:
+ return float('inf')
+ return 20 * paddle.log10(255.0 / paddle.sqrt(mse))
+
+ def _normalize_text(self, text):
+ text = ''.join(
+ filter(lambda x: x in (string.digits + string.ascii_letters), text))
+ return text.lower()
+
+ def __call__(self, pred_label, *args, **kwargs):
+ metric = {}
+ images_sr = pred_label["sr_img"]
+ images_hr = pred_label["hr_img"]
+ psnr = self.calculate_psnr(images_sr, images_hr)
+ ssim = self.calculate_ssim(images_sr, images_hr)
+ self.psnr_result.append(psnr)
+ self.ssim_result.append(ssim)
+
+ def get_metric(self):
+ """
+ return metrics {
+ 'acc': 0,
+ 'norm_edit_dis': 0,
+ }
+ """
+ self.psnr_avg = sum(self.psnr_result) / len(self.psnr_result)
+ self.psnr_avg = round(self.psnr_avg.item(), 6)
+ self.ssim_avg = sum(self.ssim_result) / len(self.ssim_result)
+ self.ssim_avg = round(self.ssim_avg.item(), 6)
+
+ self.all_avg = self.psnr_avg + self.ssim_avg
+
+ self.reset()
+ return {
+ 'psnr_avg': self.psnr_avg,
+ "ssim_avg": self.ssim_avg,
+ "all": self.all_avg
+ }
diff --git a/ppocr/modeling/architectures/base_model.py b/ppocr/modeling/architectures/base_model.py
index ed2a909cb58d56ec5a67b897de1a171658228acb..5612d366ea9ccf3f45ab675fbaa374fd4fe5d773 100644
--- a/ppocr/modeling/architectures/base_model.py
+++ b/ppocr/modeling/architectures/base_model.py
@@ -14,6 +14,7 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+
from paddle import nn
from ppocr.modeling.transforms import build_transform
from ppocr.modeling.backbones import build_backbone
@@ -46,9 +47,13 @@ class BaseModel(nn.Layer):
in_channels = self.transform.out_channels
# build backbone, backbone is need for del, rec and cls
- config["Backbone"]['in_channels'] = in_channels
- self.backbone = build_backbone(config["Backbone"], model_type)
- in_channels = self.backbone.out_channels
+ if 'Backbone' not in config or config['Backbone'] is None:
+ self.use_backbone = False
+ else:
+ self.use_backbone = True
+ config["Backbone"]['in_channels'] = in_channels
+ self.backbone = build_backbone(config["Backbone"], model_type)
+ in_channels = self.backbone.out_channels
# build neck
# for rec, neck can be cnn,rnn or reshape(None)
@@ -77,7 +82,8 @@ class BaseModel(nn.Layer):
y = dict()
if self.use_transform:
x = self.transform(x)
- x = self.backbone(x)
+ if self.use_backbone:
+ x = self.backbone(x)
if isinstance(x, dict):
y.update(x)
else:
@@ -109,4 +115,4 @@ class BaseModel(nn.Layer):
else:
return {final_name: x}
else:
- return x
+ return x
\ No newline at end of file
diff --git a/ppocr/modeling/heads/sr_rensnet_transformer.py b/ppocr/modeling/heads/sr_rensnet_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..a004a12663ac2061a329236c58e147a017c80ba6
--- /dev/null
+++ b/ppocr/modeling/heads/sr_rensnet_transformer.py
@@ -0,0 +1,430 @@
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This code is refer from:
+https://github.com/FudanVI/FudanOCR/blob/main/text-gestalt/loss/transformer_english_decomposition.py
+"""
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+import math, copy
+import numpy as np
+
+# stroke-level alphabet
+alphabet = '0123456789'
+
+
+def get_alphabet_len():
+ return len(alphabet)
+
+
+def subsequent_mask(size):
+ """Generate a square mask for the sequence. The masked positions are filled with float('-inf').
+ Unmasked positions are filled with float(0.0).
+ """
+ mask = paddle.ones([1, size, size], dtype='float32')
+ mask_inf = paddle.triu(
+ paddle.full(
+ shape=[1, size, size], dtype='float32', fill_value='-inf'),
+ diagonal=1)
+ mask = mask + mask_inf
+ padding_mask = paddle.equal(mask, paddle.to_tensor(1, dtype=mask.dtype))
+ return padding_mask
+
+
+def clones(module, N):
+ return nn.LayerList([copy.deepcopy(module) for _ in range(N)])
+
+
+def masked_fill(x, mask, value):
+ y = paddle.full(x.shape, value, x.dtype)
+ return paddle.where(mask, y, x)
+
+
+def attention(query, key, value, mask=None, dropout=None, attention_map=None):
+ d_k = query.shape[-1]
+ scores = paddle.matmul(query,
+ paddle.transpose(key, [0, 1, 3, 2])) / math.sqrt(d_k)
+
+ if mask is not None:
+ scores = masked_fill(scores, mask == 0, float('-inf'))
+ else:
+ pass
+
+ p_attn = F.softmax(scores, axis=-1)
+
+ if dropout is not None:
+ p_attn = dropout(p_attn)
+ return paddle.matmul(p_attn, value), p_attn
+
+
+class MultiHeadedAttention(nn.Layer):
+ def __init__(self, h, d_model, dropout=0.1, compress_attention=False):
+ super(MultiHeadedAttention, self).__init__()
+ assert d_model % h == 0
+ self.d_k = d_model // h
+ self.h = h
+ self.linears = clones(nn.Linear(d_model, d_model), 4)
+ self.attn = None
+ self.dropout = nn.Dropout(p=dropout, mode="downscale_in_infer")
+ self.compress_attention = compress_attention
+ self.compress_attention_linear = nn.Linear(h, 1)
+
+ def forward(self, query, key, value, mask=None, attention_map=None):
+ if mask is not None:
+ mask = mask.unsqueeze(1)
+ nbatches = query.shape[0]
+
+ query, key, value = \
+ [paddle.transpose(l(x).reshape([nbatches, -1, self.h, self.d_k]), [0,2,1,3])
+ for l, x in zip(self.linears, (query, key, value))]
+
+ x, attention_map = attention(
+ query,
+ key,
+ value,
+ mask=mask,
+ dropout=self.dropout,
+ attention_map=attention_map)
+
+ x = paddle.reshape(
+ paddle.transpose(x, [0, 2, 1, 3]),
+ [nbatches, -1, self.h * self.d_k])
+
+ return self.linears[-1](x), attention_map
+
+
+class ResNet(nn.Layer):
+ def __init__(self, num_in, block, layers):
+ super(ResNet, self).__init__()
+
+ self.conv1 = nn.Conv2D(num_in, 64, kernel_size=3, stride=1, padding=1)
+ self.bn1 = nn.BatchNorm2D(64, use_global_stats=True)
+ self.relu1 = nn.ReLU()
+ self.pool = nn.MaxPool2D((2, 2), (2, 2))
+
+ self.conv2 = nn.Conv2D(64, 128, kernel_size=3, stride=1, padding=1)
+ self.bn2 = nn.BatchNorm2D(128, use_global_stats=True)
+ self.relu2 = nn.ReLU()
+
+ self.layer1_pool = nn.MaxPool2D((2, 2), (2, 2))
+ self.layer1 = self._make_layer(block, 128, 256, layers[0])
+ self.layer1_conv = nn.Conv2D(256, 256, 3, 1, 1)
+ self.layer1_bn = nn.BatchNorm2D(256, use_global_stats=True)
+ self.layer1_relu = nn.ReLU()
+
+ self.layer2_pool = nn.MaxPool2D((2, 2), (2, 2))
+ self.layer2 = self._make_layer(block, 256, 256, layers[1])
+ self.layer2_conv = nn.Conv2D(256, 256, 3, 1, 1)
+ self.layer2_bn = nn.BatchNorm2D(256, use_global_stats=True)
+ self.layer2_relu = nn.ReLU()
+
+ self.layer3_pool = nn.MaxPool2D((2, 2), (2, 2))
+ self.layer3 = self._make_layer(block, 256, 512, layers[2])
+ self.layer3_conv = nn.Conv2D(512, 512, 3, 1, 1)
+ self.layer3_bn = nn.BatchNorm2D(512, use_global_stats=True)
+ self.layer3_relu = nn.ReLU()
+
+ self.layer4_pool = nn.MaxPool2D((2, 2), (2, 2))
+ self.layer4 = self._make_layer(block, 512, 512, layers[3])
+ self.layer4_conv2 = nn.Conv2D(512, 1024, 3, 1, 1)
+ self.layer4_conv2_bn = nn.BatchNorm2D(1024, use_global_stats=True)
+ self.layer4_conv2_relu = nn.ReLU()
+
+ def _make_layer(self, block, inplanes, planes, blocks):
+
+ if inplanes != planes:
+ downsample = nn.Sequential(
+ nn.Conv2D(inplanes, planes, 3, 1, 1),
+ nn.BatchNorm2D(
+ planes, use_global_stats=True), )
+ else:
+ downsample = None
+ layers = []
+ layers.append(block(inplanes, planes, downsample))
+ for i in range(1, blocks):
+ layers.append(block(planes, planes, downsample=None))
+
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.relu1(x)
+ x = self.pool(x)
+
+ x = self.conv2(x)
+ x = self.bn2(x)
+ x = self.relu2(x)
+
+ x = self.layer1_pool(x)
+ x = self.layer1(x)
+ x = self.layer1_conv(x)
+ x = self.layer1_bn(x)
+ x = self.layer1_relu(x)
+
+ x = self.layer2(x)
+ x = self.layer2_conv(x)
+ x = self.layer2_bn(x)
+ x = self.layer2_relu(x)
+
+ x = self.layer3(x)
+ x = self.layer3_conv(x)
+ x = self.layer3_bn(x)
+ x = self.layer3_relu(x)
+
+ x = self.layer4(x)
+ x = self.layer4_conv2(x)
+ x = self.layer4_conv2_bn(x)
+ x = self.layer4_conv2_relu(x)
+
+ return x
+
+
+class Bottleneck(nn.Layer):
+ def __init__(self, input_dim):
+ super(Bottleneck, self).__init__()
+ self.conv1 = nn.Conv2D(input_dim, input_dim, 1)
+ self.bn1 = nn.BatchNorm2D(input_dim, use_global_stats=True)
+ self.relu = nn.ReLU()
+
+ self.conv2 = nn.Conv2D(input_dim, input_dim, 3, 1, 1)
+ self.bn2 = nn.BatchNorm2D(input_dim, use_global_stats=True)
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+
+class PositionalEncoding(nn.Layer):
+ "Implement the PE function."
+
+ def __init__(self, dropout, dim, max_len=5000):
+ super(PositionalEncoding, self).__init__()
+ self.dropout = nn.Dropout(p=dropout, mode="downscale_in_infer")
+
+ pe = paddle.zeros([max_len, dim])
+ position = paddle.arange(0, max_len, dtype=paddle.float32).unsqueeze(1)
+ div_term = paddle.exp(
+ paddle.arange(0, dim, 2).astype('float32') *
+ (-math.log(10000.0) / dim))
+ pe[:, 0::2] = paddle.sin(position * div_term)
+ pe[:, 1::2] = paddle.cos(position * div_term)
+ pe = paddle.unsqueeze(pe, 0)
+ self.register_buffer('pe', pe)
+
+ def forward(self, x):
+ x = x + self.pe[:, :paddle.shape(x)[1]]
+ return self.dropout(x)
+
+
+class PositionwiseFeedForward(nn.Layer):
+ "Implements FFN equation."
+
+ def __init__(self, d_model, d_ff, dropout=0.1):
+ super(PositionwiseFeedForward, self).__init__()
+ self.w_1 = nn.Linear(d_model, d_ff)
+ self.w_2 = nn.Linear(d_ff, d_model)
+ self.dropout = nn.Dropout(dropout, mode="downscale_in_infer")
+
+ def forward(self, x):
+ return self.w_2(self.dropout(F.relu(self.w_1(x))))
+
+
+class Generator(nn.Layer):
+ "Define standard linear + softmax generation step."
+
+ def __init__(self, d_model, vocab):
+ super(Generator, self).__init__()
+ self.proj = nn.Linear(d_model, vocab)
+ self.relu = nn.ReLU()
+
+ def forward(self, x):
+ out = self.proj(x)
+ return out
+
+
+class Embeddings(nn.Layer):
+ def __init__(self, d_model, vocab):
+ super(Embeddings, self).__init__()
+ self.lut = nn.Embedding(vocab, d_model)
+ self.d_model = d_model
+
+ def forward(self, x):
+ embed = self.lut(x) * math.sqrt(self.d_model)
+ return embed
+
+
+class LayerNorm(nn.Layer):
+ "Construct a layernorm module (See citation for details)."
+
+ def __init__(self, features, eps=1e-6):
+ super(LayerNorm, self).__init__()
+ self.a_2 = self.create_parameter(
+ shape=[features],
+ default_initializer=paddle.nn.initializer.Constant(1.0))
+ self.b_2 = self.create_parameter(
+ shape=[features],
+ default_initializer=paddle.nn.initializer.Constant(0.0))
+ self.eps = eps
+
+ def forward(self, x):
+ mean = x.mean(-1, keepdim=True)
+ std = x.std(-1, keepdim=True)
+ return self.a_2 * (x - mean) / (std + self.eps) + self.b_2
+
+
+class Decoder(nn.Layer):
+ def __init__(self):
+ super(Decoder, self).__init__()
+
+ self.mask_multihead = MultiHeadedAttention(
+ h=16, d_model=1024, dropout=0.1)
+ self.mul_layernorm1 = LayerNorm(1024)
+
+ self.multihead = MultiHeadedAttention(h=16, d_model=1024, dropout=0.1)
+ self.mul_layernorm2 = LayerNorm(1024)
+
+ self.pff = PositionwiseFeedForward(1024, 2048)
+ self.mul_layernorm3 = LayerNorm(1024)
+
+ def forward(self, text, conv_feature, attention_map=None):
+ text_max_length = text.shape[1]
+ mask = subsequent_mask(text_max_length)
+ result = text
+ result = self.mul_layernorm1(result + self.mask_multihead(
+ text, text, text, mask=mask)[0])
+ b, c, h, w = conv_feature.shape
+ conv_feature = paddle.transpose(
+ conv_feature.reshape([b, c, h * w]), [0, 2, 1])
+ word_image_align, attention_map = self.multihead(
+ result,
+ conv_feature,
+ conv_feature,
+ mask=None,
+ attention_map=attention_map)
+ result = self.mul_layernorm2(result + word_image_align)
+ result = self.mul_layernorm3(result + self.pff(result))
+
+ return result, attention_map
+
+
+class BasicBlock(nn.Layer):
+ def __init__(self, inplanes, planes, downsample):
+ super(BasicBlock, self).__init__()
+ self.conv1 = nn.Conv2D(
+ inplanes, planes, kernel_size=3, stride=1, padding=1)
+ self.bn1 = nn.BatchNorm2D(planes, use_global_stats=True)
+ self.relu = nn.ReLU()
+ self.conv2 = nn.Conv2D(
+ planes, planes, kernel_size=3, stride=1, padding=1)
+ self.bn2 = nn.BatchNorm2D(planes, use_global_stats=True)
+ self.downsample = downsample
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+
+ if self.downsample != None:
+ residual = self.downsample(residual)
+
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+
+class Encoder(nn.Layer):
+ def __init__(self):
+ super(Encoder, self).__init__()
+ self.cnn = ResNet(num_in=1, block=BasicBlock, layers=[1, 2, 5, 3])
+
+ def forward(self, input):
+ conv_result = self.cnn(input)
+ return conv_result
+
+
+class Transformer(nn.Layer):
+ def __init__(self, in_channels=1):
+ super(Transformer, self).__init__()
+
+ word_n_class = get_alphabet_len()
+ self.embedding_word_with_upperword = Embeddings(512, word_n_class)
+ self.pe = PositionalEncoding(dim=512, dropout=0.1, max_len=5000)
+
+ self.encoder = Encoder()
+ self.decoder = Decoder()
+ self.generator_word_with_upperword = Generator(1024, word_n_class)
+
+ for p in self.parameters():
+ if p.dim() > 1:
+ nn.initializer.XavierNormal(p)
+
+ def forward(self, image, text_length, text_input, attention_map=None):
+ if image.shape[1] == 3:
+ R = image[:, 0:1, :, :]
+ G = image[:, 1:2, :, :]
+ B = image[:, 2:3, :, :]
+ image = 0.299 * R + 0.587 * G + 0.114 * B
+
+ conv_feature = self.encoder(image) # batch, 1024, 8, 32
+ max_length = max(text_length)
+ text_input = text_input[:, :max_length]
+
+ text_embedding = self.embedding_word_with_upperword(
+ text_input) # batch, text_max_length, 512
+ postion_embedding = self.pe(
+ paddle.zeros(text_embedding.shape)) # batch, text_max_length, 512
+ text_input_with_pe = paddle.concat([text_embedding, postion_embedding],
+ 2) # batch, text_max_length, 1024
+ batch, seq_len, _ = text_input_with_pe.shape
+
+ text_input_with_pe, word_attention_map = self.decoder(
+ text_input_with_pe, conv_feature)
+
+ word_decoder_result = self.generator_word_with_upperword(
+ text_input_with_pe)
+
+ if self.training:
+ total_length = paddle.sum(text_length)
+ probs_res = paddle.zeros([total_length, get_alphabet_len()])
+ start = 0
+
+ for index, length in enumerate(text_length):
+ length = int(length.numpy())
+ probs_res[start:start + length, :] = word_decoder_result[
+ index, 0:0 + length, :]
+
+ start = start + length
+
+ return probs_res, word_attention_map, None
+ else:
+ return word_decoder_result
diff --git a/ppocr/modeling/transforms/__init__.py b/ppocr/modeling/transforms/__init__.py
index 7e4ffdf46854416f71e1c8f4e131d1f0283bb725..b22c60bb3d5e1933056d37bad208f4c311139c8e 100755
--- a/ppocr/modeling/transforms/__init__.py
+++ b/ppocr/modeling/transforms/__init__.py
@@ -18,10 +18,10 @@ __all__ = ['build_transform']
def build_transform(config):
from .tps import TPS
from .stn import STN_ON
+ from .tsrn import TSRN
from .gaspin_transformer import GA_SPIN_Transformer as GA_SPIN
-
- support_dict = ['TPS', 'STN_ON', 'GA_SPIN']
+ support_dict = ['TPS', 'STN_ON', 'GA_SPIN', 'TSRN']
module_name = config.pop('name')
assert module_name in support_dict, Exception(
diff --git a/ppocr/modeling/transforms/tps_spatial_transformer.py b/ppocr/modeling/transforms/tps_spatial_transformer.py
index cb1cb10aaa98dffa2f720dc81afdf82d25e071ca..e7ec2c848f192d766722f824962a7f8d0fed41f9 100644
--- a/ppocr/modeling/transforms/tps_spatial_transformer.py
+++ b/ppocr/modeling/transforms/tps_spatial_transformer.py
@@ -153,4 +153,4 @@ class TPSSpatialTransformer(nn.Layer):
# the input to grid_sample is normalized [-1, 1], but what we get is [0, 1]
grid = 2.0 * grid - 1.0
output_maps = grid_sample(input, grid, canvas=None)
- return output_maps, source_coordinate
+ return output_maps, source_coordinate
\ No newline at end of file
diff --git a/ppocr/modeling/transforms/tsrn.py b/ppocr/modeling/transforms/tsrn.py
new file mode 100644
index 0000000000000000000000000000000000000000..31aa90ea4b5d5e8f071487899b72219f3e5b36f5
--- /dev/null
+++ b/ppocr/modeling/transforms/tsrn.py
@@ -0,0 +1,219 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This code is refer from:
+https://github.com/FudanVI/FudanOCR/blob/main/text-gestalt/model/tsrn.py
+"""
+
+import math
+import paddle
+import paddle.nn.functional as F
+from paddle import nn
+from collections import OrderedDict
+import sys
+import numpy as np
+import warnings
+import math, copy
+import cv2
+
+warnings.filterwarnings("ignore")
+
+from .tps_spatial_transformer import TPSSpatialTransformer
+from .stn import STN as STN_model
+from ppocr.modeling.heads.sr_rensnet_transformer import Transformer
+
+
+class TSRN(nn.Layer):
+ def __init__(self,
+ in_channels,
+ scale_factor=2,
+ width=128,
+ height=32,
+ STN=False,
+ srb_nums=5,
+ mask=False,
+ hidden_units=32,
+ infer_mode=False,
+ **kwargs):
+ super(TSRN, self).__init__()
+ in_planes = 3
+ if mask:
+ in_planes = 4
+ assert math.log(scale_factor, 2) % 1 == 0
+ upsample_block_num = int(math.log(scale_factor, 2))
+ self.block1 = nn.Sequential(
+ nn.Conv2D(
+ in_planes, 2 * hidden_units, kernel_size=9, padding=4),
+ nn.PReLU())
+ self.srb_nums = srb_nums
+ for i in range(srb_nums):
+ setattr(self, 'block%d' % (i + 2),
+ RecurrentResidualBlock(2 * hidden_units))
+
+ setattr(
+ self,
+ 'block%d' % (srb_nums + 2),
+ nn.Sequential(
+ nn.Conv2D(
+ 2 * hidden_units,
+ 2 * hidden_units,
+ kernel_size=3,
+ padding=1),
+ nn.BatchNorm2D(2 * hidden_units)))
+
+ block_ = [
+ UpsampleBLock(2 * hidden_units, 2)
+ for _ in range(upsample_block_num)
+ ]
+ block_.append(
+ nn.Conv2D(
+ 2 * hidden_units, in_planes, kernel_size=9, padding=4))
+ setattr(self, 'block%d' % (srb_nums + 3), nn.Sequential(*block_))
+ self.tps_inputsize = [height // scale_factor, width // scale_factor]
+ tps_outputsize = [height // scale_factor, width // scale_factor]
+ num_control_points = 20
+ tps_margins = [0.05, 0.05]
+ self.stn = STN
+ if self.stn:
+ self.tps = TPSSpatialTransformer(
+ output_image_size=tuple(tps_outputsize),
+ num_control_points=num_control_points,
+ margins=tuple(tps_margins))
+
+ self.stn_head = STN_model(
+ in_channels=in_planes,
+ num_ctrlpoints=num_control_points,
+ activation='none')
+ self.out_channels = in_channels
+
+ self.r34_transformer = Transformer()
+ for param in self.r34_transformer.parameters():
+ param.trainable = False
+ self.infer_mode = infer_mode
+
+ def forward(self, x):
+ output = {}
+ if self.infer_mode:
+ output["lr_img"] = x
+ y = x
+ else:
+ output["lr_img"] = x[0]
+ output["hr_img"] = x[1]
+ y = x[0]
+ if self.stn and self.training:
+ _, ctrl_points_x = self.stn_head(y)
+ y, _ = self.tps(y, ctrl_points_x)
+ block = {'1': self.block1(y)}
+ for i in range(self.srb_nums + 1):
+ block[str(i + 2)] = getattr(self,
+ 'block%d' % (i + 2))(block[str(i + 1)])
+
+ block[str(self.srb_nums + 3)] = getattr(self, 'block%d' % (self.srb_nums + 3)) \
+ ((block['1'] + block[str(self.srb_nums + 2)]))
+
+ sr_img = paddle.tanh(block[str(self.srb_nums + 3)])
+
+ output["sr_img"] = sr_img
+
+ if self.training:
+ hr_img = x[1]
+ length = x[2]
+ input_tensor = x[3]
+
+ # add transformer
+ sr_pred, word_attention_map_pred, _ = self.r34_transformer(
+ sr_img, length, input_tensor)
+
+ hr_pred, word_attention_map_gt, _ = self.r34_transformer(
+ hr_img, length, input_tensor)
+
+ output["hr_img"] = hr_img
+ output["hr_pred"] = hr_pred
+ output["word_attention_map_gt"] = word_attention_map_gt
+ output["sr_pred"] = sr_pred
+ output["word_attention_map_pred"] = word_attention_map_pred
+
+ return output
+
+
+class RecurrentResidualBlock(nn.Layer):
+ def __init__(self, channels):
+ super(RecurrentResidualBlock, self).__init__()
+ self.conv1 = nn.Conv2D(channels, channels, kernel_size=3, padding=1)
+ self.bn1 = nn.BatchNorm2D(channels)
+ self.gru1 = GruBlock(channels, channels)
+ self.prelu = mish()
+ self.conv2 = nn.Conv2D(channels, channels, kernel_size=3, padding=1)
+ self.bn2 = nn.BatchNorm2D(channels)
+ self.gru2 = GruBlock(channels, channels)
+
+ def forward(self, x):
+ residual = self.conv1(x)
+ residual = self.bn1(residual)
+ residual = self.prelu(residual)
+ residual = self.conv2(residual)
+ residual = self.bn2(residual)
+ residual = self.gru1(residual.transpose([0, 1, 3, 2])).transpose(
+ [0, 1, 3, 2])
+
+ return self.gru2(x + residual)
+
+
+class UpsampleBLock(nn.Layer):
+ def __init__(self, in_channels, up_scale):
+ super(UpsampleBLock, self).__init__()
+ self.conv = nn.Conv2D(
+ in_channels, in_channels * up_scale**2, kernel_size=3, padding=1)
+
+ self.pixel_shuffle = nn.PixelShuffle(up_scale)
+ self.prelu = mish()
+
+ def forward(self, x):
+ x = self.conv(x)
+ x = self.pixel_shuffle(x)
+ x = self.prelu(x)
+ return x
+
+
+class mish(nn.Layer):
+ def __init__(self, ):
+ super(mish, self).__init__()
+ self.activated = True
+
+ def forward(self, x):
+ if self.activated:
+ x = x * (paddle.tanh(F.softplus(x)))
+ return x
+
+
+class GruBlock(nn.Layer):
+ def __init__(self, in_channels, out_channels):
+ super(GruBlock, self).__init__()
+ assert out_channels % 2 == 0
+ self.conv1 = nn.Conv2D(
+ in_channels, out_channels, kernel_size=1, padding=0)
+ self.gru = nn.GRU(out_channels,
+ out_channels // 2,
+ direction='bidirectional')
+
+ def forward(self, x):
+ # x: b, c, w, h
+ x = self.conv1(x)
+ x = x.transpose([0, 2, 3, 1]) # b, w, h, c
+ batch_size, w, h, c = x.shape
+ x = x.reshape([-1, h, c]) # b*w, h, c
+ x, _ = self.gru(x)
+ x = x.reshape([-1, w, h, c])
+ x = x.transpose([0, 3, 1, 2])
+ return x
diff --git a/ppocr/utils/save_load.py b/ppocr/utils/save_load.py
index e77a6ce0183611569193e1996e935f4bd30400a0..7cd205e8fd4a9234fb2a67c50b394501e1507bf2 100644
--- a/ppocr/utils/save_load.py
+++ b/ppocr/utils/save_load.py
@@ -148,10 +148,14 @@ def load_pretrained_params(model, path):
"The {}.pdparams does not exists!".format(path)
params = paddle.load(path + '.pdparams')
+
state_dict = model.state_dict()
+
new_state_dict = {}
is_float16 = False
+
for k1 in params.keys():
+
if k1 not in state_dict.keys():
logger.warning("The pretrained params {} not in model".format(k1))
else:
diff --git a/tools/export_model.py b/tools/export_model.py
index 78932c987d8bc57216ef3586c2bdc0cdbd6a9037..2443d66ca2a5d81cdb99964bfb88af29ae0c66e2 100755
--- a/tools/export_model.py
+++ b/tools/export_model.py
@@ -78,6 +78,12 @@ def export_single_model(model,
shape=[None, 3, 64, 512], dtype="float32"),
]
model = to_static(model, input_spec=other_shape)
+ elif arch_config["model_type"] == "sr":
+ other_shape = [
+ paddle.static.InputSpec(
+ shape=[None, 3, 16, 64], dtype="float32")
+ ]
+ model = to_static(model, input_spec=other_shape)
elif arch_config["algorithm"] == "ViTSTR":
other_shape = [
paddle.static.InputSpec(
@@ -195,6 +201,9 @@ def main():
else: # base rec model
config["Architecture"]["Head"]["out_channels"] = char_num
+ # for sr algorithm
+ if config["Architecture"]["model_type"] == "sr":
+ config['Architecture']["Transform"]['infer_mode'] = True
model = build_model(config["Architecture"])
load_model(config, model, model_type=config['Architecture']["model_type"])
model.eval()
diff --git a/tools/infer/predict_sr.py b/tools/infer/predict_sr.py
new file mode 100755
index 0000000000000000000000000000000000000000..b10d90bf1d6ce3de6d2947e9cc1f73443736518d
--- /dev/null
+++ b/tools/infer/predict_sr.py
@@ -0,0 +1,155 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import sys
+from PIL import Image
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.insert(0, __dir__)
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..')))
+
+os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
+
+import cv2
+import numpy as np
+import math
+import time
+import traceback
+import paddle
+
+import tools.infer.utility as utility
+from ppocr.postprocess import build_post_process
+from ppocr.utils.logging import get_logger
+from ppocr.utils.utility import get_image_file_list, check_and_read_gif
+
+logger = get_logger()
+
+
+class TextSR(object):
+ def __init__(self, args):
+ self.sr_image_shape = [int(v) for v in args.sr_image_shape.split(",")]
+ self.sr_batch_num = args.sr_batch_num
+
+ self.predictor, self.input_tensor, self.output_tensors, self.config = \
+ utility.create_predictor(args, 'sr', logger)
+ self.benchmark = args.benchmark
+ if args.benchmark:
+ import auto_log
+ pid = os.getpid()
+ gpu_id = utility.get_infer_gpuid()
+ self.autolog = auto_log.AutoLogger(
+ model_name="sr",
+ model_precision=args.precision,
+ batch_size=args.sr_batch_num,
+ data_shape="dynamic",
+ save_path=None, #args.save_log_path,
+ inference_config=self.config,
+ pids=pid,
+ process_name=None,
+ gpu_ids=gpu_id if args.use_gpu else None,
+ time_keys=[
+ 'preprocess_time', 'inference_time', 'postprocess_time'
+ ],
+ warmup=0,
+ logger=logger)
+
+ def resize_norm_img(self, img):
+ imgC, imgH, imgW = self.sr_image_shape
+ img = img.resize((imgW // 2, imgH // 2), Image.BICUBIC)
+ img_numpy = np.array(img).astype("float32")
+ img_numpy = img_numpy.transpose((2, 0, 1)) / 255
+ return img_numpy
+
+ def __call__(self, img_list):
+ img_num = len(img_list)
+ batch_num = self.sr_batch_num
+ st = time.time()
+ st = time.time()
+ all_result = [] * img_num
+ if self.benchmark:
+ self.autolog.times.start()
+ for beg_img_no in range(0, img_num, batch_num):
+ end_img_no = min(img_num, beg_img_no + batch_num)
+ norm_img_batch = []
+ imgC, imgH, imgW = self.sr_image_shape
+ for ino in range(beg_img_no, end_img_no):
+ norm_img = self.resize_norm_img(img_list[ino])
+ norm_img = norm_img[np.newaxis, :]
+ norm_img_batch.append(norm_img)
+
+ norm_img_batch = np.concatenate(norm_img_batch)
+ norm_img_batch = norm_img_batch.copy()
+ if self.benchmark:
+ self.autolog.times.stamp()
+ self.input_tensor.copy_from_cpu(norm_img_batch)
+ self.predictor.run()
+ outputs = []
+ for output_tensor in self.output_tensors:
+ output = output_tensor.copy_to_cpu()
+ outputs.append(output)
+ if len(outputs) != 1:
+ preds = outputs
+ else:
+ preds = outputs[0]
+ all_result.append(outputs)
+ if self.benchmark:
+ self.autolog.times.end(stamp=True)
+ return all_result, time.time() - st
+
+
+def main(args):
+ image_file_list = get_image_file_list(args.image_dir)
+ text_recognizer = TextSR(args)
+ valid_image_file_list = []
+ img_list = []
+
+ # warmup 2 times
+ if args.warmup:
+ img = np.random.uniform(0, 255, [16, 64, 3]).astype(np.uint8)
+ for i in range(2):
+ res = text_recognizer([img] * int(args.sr_batch_num))
+
+ for image_file in image_file_list:
+ img, flag = check_and_read_gif(image_file)
+ if not flag:
+ img = Image.open(image_file).convert("RGB")
+ if img is None:
+ logger.info("error in loading image:{}".format(image_file))
+ continue
+ valid_image_file_list.append(image_file)
+ img_list.append(img)
+ try:
+ preds, _ = text_recognizer(img_list)
+ for beg_no in range(len(preds)):
+ sr_img = preds[beg_no][1]
+ lr_img = preds[beg_no][0]
+ for i in (range(sr_img.shape[0])):
+ fm_sr = (sr_img[i] * 255).transpose(1, 2, 0).astype(np.uint8)
+ fm_lr = (lr_img[i] * 255).transpose(1, 2, 0).astype(np.uint8)
+ img_name_pure = os.path.split(valid_image_file_list[
+ beg_no * args.sr_batch_num + i])[-1]
+ cv2.imwrite("infer_result/sr_{}".format(img_name_pure),
+ fm_sr[:, :, ::-1])
+ logger.info("The visualized image saved in infer_result/sr_{}".
+ format(img_name_pure))
+
+ except Exception as E:
+ logger.info(traceback.format_exc())
+ logger.info(E)
+ exit()
+ if args.benchmark:
+ text_recognizer.autolog.report()
+
+
+if __name__ == "__main__":
+ main(utility.parse_args())
diff --git a/tools/infer/utility.py b/tools/infer/utility.py
index 9345106e774cfbcf0e87a7cf5d8b6cdabb4cf490..9c89a4e7642e662cf9c370a0071cc87fecf47d55 100644
--- a/tools/infer/utility.py
+++ b/tools/infer/utility.py
@@ -121,6 +121,11 @@ def init_args():
parser.add_argument("--use_pdserving", type=str2bool, default=False)
parser.add_argument("--warmup", type=str2bool, default=False)
+ # SR parmas
+ parser.add_argument("--sr_model_dir", type=str)
+ parser.add_argument("--sr_image_shape", type=str, default="3, 32, 128")
+ parser.add_argument("--sr_batch_num", type=int, default=1)
+
#
parser.add_argument(
"--draw_img_save_dir", type=str, default="./inference_results")
@@ -156,6 +161,8 @@ def create_predictor(args, mode, logger):
model_dir = args.table_model_dir
elif mode == 'ser':
model_dir = args.ser_model_dir
+ elif mode == "sr":
+ model_dir = args.sr_model_dir
else:
model_dir = args.e2e_model_dir
diff --git a/tools/infer_sr.py b/tools/infer_sr.py
new file mode 100755
index 0000000000000000000000000000000000000000..0bc2f6aaa7c4400676268ec64d37e721af0f99c2
--- /dev/null
+++ b/tools/infer_sr.py
@@ -0,0 +1,100 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+import os
+import sys
+import json
+from PIL import Image
+import cv2
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.insert(0, __dir__)
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
+
+os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
+
+import paddle
+
+from ppocr.data import create_operators, transform
+from ppocr.modeling.architectures import build_model
+from ppocr.postprocess import build_post_process
+from ppocr.utils.save_load import load_model
+from ppocr.utils.utility import get_image_file_list
+import tools.program as program
+
+
+def main():
+ global_config = config['Global']
+
+ # build post process
+ post_process_class = build_post_process(config['PostProcess'],
+ global_config)
+
+ # sr transform
+ config['Architecture']["Transform"]['infer_mode'] = True
+
+ model = build_model(config['Architecture'])
+
+ load_model(config, model)
+
+ # create data ops
+ transforms = []
+ for op in config['Eval']['dataset']['transforms']:
+ op_name = list(op)[0]
+ if 'Label' in op_name:
+ continue
+ elif op_name in ['SRResize']:
+ op[op_name]['infer_mode'] = True
+ elif op_name == 'KeepKeys':
+ op[op_name]['keep_keys'] = ['imge_lr']
+ transforms.append(op)
+ global_config['infer_mode'] = True
+ ops = create_operators(transforms, global_config)
+
+ save_res_path = config['Global'].get('save_res_path', "./infer_result")
+ if not os.path.exists(os.path.dirname(save_res_path)):
+ os.makedirs(os.path.dirname(save_res_path))
+
+ model.eval()
+ for file in get_image_file_list(config['Global']['infer_img']):
+ logger.info("infer_img: {}".format(file))
+ img = Image.open(file).convert("RGB")
+ data = {'image_lr': img}
+ batch = transform(data, ops)
+ images = np.expand_dims(batch[0], axis=0)
+ images = paddle.to_tensor(images)
+
+ preds = model(images)
+ sr_img = preds["sr_img"][0]
+ lr_img = preds["lr_img"][0]
+ fm_sr = (sr_img.numpy() * 255).transpose(1, 2, 0).astype(np.uint8)
+ fm_lr = (lr_img.numpy() * 255).transpose(1, 2, 0).astype(np.uint8)
+ img_name_pure = os.path.split(file)[-1]
+ cv2.imwrite("infer_result/sr_{}".format(img_name_pure),
+ fm_sr[:, :, ::-1])
+ logger.info("The visualized image saved in infer_result/sr_{}".format(
+ img_name_pure))
+
+ logger.info("success!")
+
+
+if __name__ == '__main__':
+ config, device, logger, vdl_writer = program.preprocess()
+ main()
diff --git a/tools/program.py b/tools/program.py
index fd4e662bd926c386e23220c1e825ce201e301dec..34845f005f81aa20553cec98d231e8358698cff7 100755
--- a/tools/program.py
+++ b/tools/program.py
@@ -25,6 +25,8 @@ import datetime
import paddle
import paddle.distributed as dist
from tqdm import tqdm
+import cv2
+import numpy as np
from argparse import ArgumentParser, RawDescriptionHelpFormatter
from ppocr.utils.stats import TrainingStats
@@ -262,6 +264,7 @@ def train(config,
config, 'Train', device, logger, seed=epoch)
max_iter = len(train_dataloader) - 1 if platform.system(
) == "Windows" else len(train_dataloader)
+
for idx, batch in enumerate(train_dataloader):
profiler.add_profiler_step(profiler_options)
train_reader_cost += time.time() - reader_start
@@ -289,7 +292,7 @@ def train(config,
else:
if model_type == 'table' or extra_input:
preds = model(images, data=batch[1:])
- elif model_type in ["kie", 'vqa']:
+ elif model_type in ["kie", 'vqa', 'sr']:
preds = model(batch)
else:
preds = model(images)
@@ -297,11 +300,12 @@ def train(config,
avg_loss = loss['loss']
avg_loss.backward()
optimizer.step()
+
optimizer.clear_grad()
if cal_metric_during_train and epoch % calc_epoch_interval == 0: # only rec and cls need
batch = [item.numpy() for item in batch]
- if model_type in ['kie']:
+ if model_type in ['kie', 'sr']:
eval_class(preds, batch)
elif model_type in ['table']:
post_result = post_process_class(preds, batch)
@@ -347,8 +351,8 @@ def train(config,
len(train_dataloader) - idx - 1) * eta_meter.avg
eta_sec_format = str(datetime.timedelta(seconds=int(eta_sec)))
strs = 'epoch: [{}/{}], global_step: {}, {}, avg_reader_cost: ' \
- '{:.5f} s, avg_batch_cost: {:.5f} s, avg_samples: {}, ' \
- 'ips: {:.5f} samples/s, eta: {}'.format(
+ '{:.5f} s, avg_batch_cost: {:.5f} s, avg_samples: {}, ' \
+ 'ips: {:.5f} samples/s, eta: {}'.format(
epoch, epoch_num, global_step, logs,
train_reader_cost / print_batch_step,
train_batch_cost / print_batch_step,
@@ -480,12 +484,13 @@ def eval(model,
leave=True)
max_iter = len(valid_dataloader) - 1 if platform.system(
) == "Windows" else len(valid_dataloader)
+ sum_images = 0
for idx, batch in enumerate(valid_dataloader):
if idx >= max_iter:
break
images = batch[0]
start = time.time()
-
+
# use amp
if scaler:
with paddle.amp.auto_cast(level='O2'):
@@ -493,6 +498,20 @@ def eval(model,
preds = model(images, data=batch[1:])
elif model_type in ["kie", 'vqa']:
preds = model(batch)
+ elif model_type in ['sr']:
+ preds = model(batch)
+ sr_img = preds["sr_img"]
+ lr_img = preds["lr_img"]
+
+ for i in (range(sr_img.shape[0])):
+ fm_sr = (sr_img[i].numpy() * 255).transpose(
+ 1, 2, 0).astype(np.uint8)
+ fm_lr = (lr_img[i].numpy() * 255).transpose(
+ 1, 2, 0).astype(np.uint8)
+ cv2.imwrite("output/images/{}_{}_sr.jpg".format(sum_images,
+ i), fm_sr)
+ cv2.imwrite("output/images/{}_{}_lr.jpg".format(sum_images,
+ i), fm_lr)
else:
preds = model(images)
else:
@@ -500,6 +519,20 @@ def eval(model,
preds = model(images, data=batch[1:])
elif model_type in ["kie", 'vqa']:
preds = model(batch)
+ elif model_type in ['sr']:
+ preds = model(batch)
+ sr_img = preds["sr_img"]
+ lr_img = preds["lr_img"]
+
+ for i in (range(sr_img.shape[0])):
+ fm_sr = (sr_img[i].numpy() * 255).transpose(
+ 1, 2, 0).astype(np.uint8)
+ fm_lr = (lr_img[i].numpy() * 255).transpose(
+ 1, 2, 0).astype(np.uint8)
+ cv2.imwrite("output/images/{}_{}_sr.jpg".format(sum_images,
+ i), fm_sr)
+ cv2.imwrite("output/images/{}_{}_lr.jpg".format(sum_images,
+ i), fm_lr)
else:
preds = model(images)
@@ -517,12 +550,15 @@ def eval(model,
elif model_type in ['table', 'vqa']:
post_result = post_process_class(preds, batch_numpy)
eval_class(post_result, batch_numpy)
+ elif model_type in ['sr']:
+ eval_class(preds, batch_numpy)
else:
post_result = post_process_class(preds, batch_numpy[1])
eval_class(post_result, batch_numpy)
pbar.update(1)
total_frame += len(images)
+ sum_images += 1
# Get final metric,eg. acc or hmean
metric = eval_class.get_metric()
@@ -616,7 +652,8 @@ def preprocess(is_train=False):
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'LayoutLMv2', 'PREN', 'FCE',
- 'SVTR', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'SPIN', 'VisionLAN'
+ 'SVTR', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'SPIN', 'VisionLAN',
+ 'Gestalt'
]
if use_xpu:
diff --git a/tools/train.py b/tools/train.py
index dc8cae8a63744bb9bd486d9899680dbde9da1697..b44d76b3832aba24dd8bcad821fe21d22c8b320b 100755
--- a/tools/train.py
+++ b/tools/train.py
@@ -119,6 +119,7 @@ def main(config, device, logger, vdl_writer):
config['Loss']['ignore_index'] = char_num - 1
model = build_model(config['Architecture'])
+
model = apply_to_static(model, config, logger)
# build loss