From 7054013004a9282b5e0b4354a0b689a680ff08b4 Mon Sep 17 00:00:00 2001 From: xiaoting <31891223+tink2123@users.noreply.github.com> Date: Fri, 12 Aug 2022 10:49:54 +0800 Subject: [PATCH] Submit SR model (#6933) * add sr model * update for eval * submit sr * polish code * polish code * polish code * update sr model * update doc * update doc * update doc * fix typo * format code * update metric * fix export --- configs/sr/sr_tsrn_transformer_strock.yml | 85 ++++ doc/doc_ch/algorithm_sr_gestalt.md | 127 ++++++ doc/doc_en/algorithm_sr_gestalt_en.md | 136 ++++++ ppocr/data/__init__.py | 5 +- ppocr/data/imaug/label_ops.py | 48 ++ ppocr/data/imaug/operators.py | 50 ++ ppocr/data/lmdb_dataset.py | 58 +++ ppocr/losses/__init__.py | 5 +- ppocr/losses/stroke_focus_loss.py | 68 +++ ppocr/metrics/__init__.py | 4 +- ppocr/metrics/rec_metric.py | 1 + ppocr/metrics/sr_metric.py | 155 +++++++ ppocr/modeling/architectures/base_model.py | 16 +- .../modeling/heads/sr_rensnet_transformer.py | 430 ++++++++++++++++++ ppocr/modeling/transforms/__init__.py | 4 +- .../transforms/tps_spatial_transformer.py | 2 +- ppocr/modeling/transforms/tsrn.py | 219 +++++++++ ppocr/utils/save_load.py | 4 + tools/export_model.py | 9 + tools/infer/predict_sr.py | 155 +++++++ tools/infer/utility.py | 7 + tools/infer_sr.py | 100 ++++ tools/program.py | 49 +- tools/train.py | 1 + 24 files changed, 1719 insertions(+), 19 deletions(-) create mode 100644 configs/sr/sr_tsrn_transformer_strock.yml create mode 100644 doc/doc_ch/algorithm_sr_gestalt.md create mode 100644 doc/doc_en/algorithm_sr_gestalt_en.md create mode 100644 ppocr/losses/stroke_focus_loss.py create mode 100644 ppocr/metrics/sr_metric.py create mode 100644 ppocr/modeling/heads/sr_rensnet_transformer.py create mode 100644 ppocr/modeling/transforms/tsrn.py create mode 100755 tools/infer/predict_sr.py create mode 100755 tools/infer_sr.py diff --git a/configs/sr/sr_tsrn_transformer_strock.yml b/configs/sr/sr_tsrn_transformer_strock.yml new file mode 100644 index 00000000..c8c308c4 --- /dev/null +++ b/configs/sr/sr_tsrn_transformer_strock.yml @@ -0,0 +1,85 @@ +Global: + use_gpu: true + epoch_num: 500 + log_smooth_window: 20 + print_batch_step: 10 + save_model_dir: ./output/sr/sr_tsrn_transformer_strock/ + save_epoch_step: 3 + # evaluation is run every 2000 iterations + eval_batch_step: [0, 1000] + cal_metric_during_train: False + pretrained_model: + checkpoints: + save_inference_dir: sr_output + use_visualdl: False + infer_img: doc/imgs_words_en/word_52.png + # for data or label process + character_dict_path: ./train_data/srdata/english_decomposition.txt + max_text_length: 100 + infer_mode: False + use_space_char: False + save_res_path: ./output/sr/predicts_gestalt.txt + +Optimizer: + name: Adam + beta1: 0.5 + beta2: 0.999 + clip_norm: 0.25 + lr: + learning_rate: 0.0001 + +Architecture: + model_type: sr + algorithm: Gestalt + Transform: + name: TSRN + STN: True + infer_mode: False + +Loss: + name: StrokeFocusLoss + character_dict_path: ./train_data/srdata/english_decomposition.txt + +PostProcess: + name: None + +Metric: + name: SRMetric + main_indicator: all + +Train: + dataset: + name: LMDBDataSetSR + data_dir: ./train_data/srdata/train + transforms: + - SRResize: + imgH: 32 + imgW: 128 + down_sample_scale: 2 + - SRLabelEncode: # Class handling label + - KeepKeys: + keep_keys: ['img_lr', 'img_hr', 'length', 'input_tensor', 'label'] # dataloader will return list in this order + loader: + shuffle: False + batch_size_per_card: 16 + drop_last: True + num_workers: 4 + +Eval: + dataset: + name: LMDBDataSetSR + data_dir: ./train_data/srdata/test + transforms: + - SRResize: + imgH: 32 + imgW: 128 + down_sample_scale: 2 + - SRLabelEncode: # Class handling label + - KeepKeys: + keep_keys: ['img_lr', 'img_hr','length', 'input_tensor', 'label'] # dataloader will return list in this order + loader: + shuffle: False + drop_last: False + batch_size_per_card: 16 + num_workers: 4 + diff --git a/doc/doc_ch/algorithm_sr_gestalt.md b/doc/doc_ch/algorithm_sr_gestalt.md new file mode 100644 index 00000000..aac82b1b --- /dev/null +++ b/doc/doc_ch/algorithm_sr_gestalt.md @@ -0,0 +1,127 @@ +# Text Gestalt + +- [1. 算法简介](#1) +- [2. 环境配置](#2) +- [3. 模型训练、评估、预测](#3) + - [3.1 训练](#3-1) + - [3.2 评估](#3-2) + - [3.3 预测](#3-3) +- [4. 推理部署](#4) + - [4.1 Python推理](#4-1) + - [4.2 C++推理](#4-2) + - [4.3 Serving服务化部署](#4-3) + - [4.4 更多推理部署](#4-4) +- [5. FAQ](#5) + + +## 1. 算法简介 + +论文信息: +> [Text Gestalt: Stroke-Aware Scene Text Image Super-Resolution](https://arxiv.org/pdf/2112.08171.pdf) + +> Chen, Jingye and Yu, Haiyang and Ma, Jianqi and Li, Bin and Xue, Xiangyang + +> AAAI, 2022 + +参考[FudanOCR](https://github.com/FudanVI/FudanOCR/tree/main/text-gestalt) 数据下载说明,在TextZoom测试集合上超分算法效果如下: + +|模型|骨干网络|PSNR_Avg|SSIM_Avg|配置文件|下载链接| +|---|---|---|---|---|---| +|Text Gestalt|tsrn|19.28|0.6560| [configs/sr/sr_tsrn_transformer_strock.yml](../../configs/sr/sr_tsrn_transformer_strock.yml)|[训练模型](https://paddleocr.bj.bcebos.com/sr_tsrn_transformer_strock_train.tar)| + + + +## 2. 环境配置 +请先参考[《运行环境准备》](./environment.md)配置PaddleOCR运行环境,参考[《项目克隆》](./clone.md)克隆项目代码。 + + + +## 3. 模型训练、评估、预测 + +请参考[文本识别训练教程](./recognition.md)。PaddleOCR对代码进行了模块化,训练不同的识别模型只需要**更换配置文件**即可。 + +- 训练 + +在完成数据准备后,便可以启动训练,训练命令如下: + +``` +#单卡训练(训练周期长,不建议) +python3 tools/train.py -c configs/sr/sr_tsrn_transformer_strock.yml + +#多卡训练,通过--gpus参数指定卡号 +python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/sr/sr_tsrn_transformer_strock.yml + +``` + +- 评估 + +``` +# GPU 评估, Global.pretrained_model 为待测权重 +python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/sr/sr_tsrn_transformer_strock.yml -o Global.pretrained_model={path/to/weights}/best_accuracy +``` + +- 预测: + +``` +# 预测使用的配置文件必须与训练一致 +python3 tools/infer_sr.py -c configs/sr/sr_tsrn_transformer_strock.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.infer_img=doc/imgs_words_en/word_52.png +``` + +![](../imgs_words_en/word_52.png) + +执行命令后,上面图像的超分结果如下: + +![](../imgs_results/sr_word_52.png) + + +## 4. 推理部署 + + +### 4.1 Python推理 + +首先将文本超分训练过程中保存的模型,转换成inference model。以 Text-Gestalt 训练的[模型](https://paddleocr.bj.bcebos.com/sr_tsrn_transformer_strock_train.tar) 为例,可以使用如下命令进行转换: +```shell +python3 tools/export_model.py -c configs/sr/sr_tsrn_transformer_strock.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.save_inference_dir=./inference/sr_out +``` +Text-Gestalt 文本超分模型推理,可以执行如下命令: +``` +python3 tools/infer/predict_sr.py --sr_model_dir=./inference/sr_out --image_dir=doc/imgs_words_en/word_52.png --sr_image_shape=3,32,128 + +``` + +执行命令后,图像的超分结果如下: + +![](../imgs_results/sr_word_52.png) + + +### 4.2 C++推理 + +暂未支持 + + +### 4.3 Serving服务化部署 + +暂未支持 + + +### 4.4 更多推理部署 + +暂未支持 + + +## 5. FAQ + + +## 引用 + +```bibtex +@inproceedings{chen2022text, + title={Text gestalt: Stroke-aware scene text image super-resolution}, + author={Chen, Jingye and Yu, Haiyang and Ma, Jianqi and Li, Bin and Xue, Xiangyang}, + booktitle={Proceedings of the AAAI Conference on Artificial Intelligence}, + volume={36}, + number={1}, + pages={285--293}, + year={2022} +} +``` diff --git a/doc/doc_en/algorithm_sr_gestalt_en.md b/doc/doc_en/algorithm_sr_gestalt_en.md new file mode 100644 index 00000000..516b90cb --- /dev/null +++ b/doc/doc_en/algorithm_sr_gestalt_en.md @@ -0,0 +1,136 @@ +# Text Gestalt + +- [1. Introduction](#1) +- [2. Environment](#2) +- [3. Model Training / Evaluation / Prediction](#3) + - [3.1 Training](#3-1) + - [3.2 Evaluation](#3-2) + - [3.3 Prediction](#3-3) +- [4. Inference and Deployment](#4) + - [4.1 Python Inference](#4-1) + - [4.2 C++ Inference](#4-2) + - [4.3 Serving](#4-3) + - [4.4 More](#4-4) +- [5. FAQ](#5) + + + +## 1. Introduction + +Paper: +> [Text Gestalt: Stroke-Aware Scene Text Image Super-Resolution](https://arxiv.org/pdf/2112.08171.pdf) + +> Chen, Jingye and Yu, Haiyang and Ma, Jianqi and Li, Bin and Xue, Xiangyang + +> AAAI, 2022 + +Referring to the [FudanOCR](https://github.com/FudanVI/FudanOCR/tree/main/text-gestalt) data download instructions, the effect of the super-score algorithm on the TextZoom test set is as follows: + +|Model|Backbone|config|Acc|Download link| +|---|---|---|---|---|---| +|Text Gestalt|tsrn|19.28|0.6560| [configs/sr/sr_tsrn_transformer_strock.yml](../../configs/sr/sr_tsrn_transformer_strock.yml)|[train model](https://paddleocr.bj.bcebos.com/sr_tsrn_transformer_strock_train.tar)| + + + +## 2. Environment +Please refer to ["Environment Preparation"](./environment_en.md) to configure the PaddleOCR environment, and refer to ["Project Clone"](./clone_en.md) to clone the project code. + + + +## 3. Model Training / Evaluation / Prediction + +Please refer to [Text Recognition Tutorial](./recognition_en.md). PaddleOCR modularizes the code, and training different models only requires **changing the configuration file**. + +Training: + +Specifically, after the data preparation is completed, the training can be started. The training command is as follows: + +``` +#Single GPU training (long training period, not recommended) + +python3 tools/train.py -c configs/sr/sr_tsrn_transformer_strock.yml + +#Multi GPU training, specify the gpu number through the --gpus parameter + +python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/sr/sr_tsrn_transformer_strock.yml + +``` + + +Evaluation: + +``` +# GPU evaluation +python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/sr/sr_tsrn_transformer_strock.yml -o Global.pretrained_model={path/to/weights}/best_accuracy +``` + +Prediction: + +``` +# The configuration file used for prediction must match the training + +python3 tools/infer_sr.py -c configs/sr/sr_tsrn_transformer_strock.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.infer_img=doc/imgs_words_en/word_52.png +``` + +![](../imgs_words_en/word_52.png) + +After executing the command, the super-resolution result of the above image is as follows: + +![](../imgs_results/sr_word_52.png) + + +## 4. Inference and Deployment + + +### 4.1 Python Inference + +First, the model saved during the training process is converted into an inference model. ( [Model download link](https://paddleocr.bj.bcebos.com/sr_tsrn_transformer_strock_train.tar) ), you can use the following command to convert: + +```shell +python3 tools/export_model.py -c configs/sr/sr_tsrn_transformer_strock.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.save_inference_dir=./inference/sr_out +``` + +For Text-Gestalt super-resolution model inference, the following commands can be executed: + +``` +python3 tools/infer/predict_sr.py --sr_model_dir=./inference/sr_out --image_dir=doc/imgs_words_en/word_52.png --sr_image_shape=3,32,128 + +``` + +After executing the command, the super-resolution result of the above image is as follows: + +![](../imgs_results/sr_word_52.png) + + + +### 4.2 C++ Inference + +Not supported + + +### 4.3 Serving + +Not supported + + +### 4.4 More + +Not supported + + +## 5. FAQ + + +## Citation + +```bibtex +@inproceedings{chen2022text, + title={Text gestalt: Stroke-aware scene text image super-resolution}, + author={Chen, Jingye and Yu, Haiyang and Ma, Jianqi and Li, Bin and Xue, Xiangyang}, + booktitle={Proceedings of the AAAI Conference on Artificial Intelligence}, + volume={36}, + number={1}, + pages={285--293}, + year={2022} +} +``` diff --git a/ppocr/data/__init__.py b/ppocr/data/__init__.py index 78c32796..b602a346 100644 --- a/ppocr/data/__init__.py +++ b/ppocr/data/__init__.py @@ -34,7 +34,7 @@ import paddle.distributed as dist from ppocr.data.imaug import transform, create_operators from ppocr.data.simple_dataset import SimpleDataSet -from ppocr.data.lmdb_dataset import LMDBDataSet +from ppocr.data.lmdb_dataset import LMDBDataSet, LMDBDataSetSR from ppocr.data.pgnet_dataset import PGDataSet from ppocr.data.pubtab_dataset import PubTabDataSet @@ -54,7 +54,8 @@ def build_dataloader(config, mode, device, logger, seed=None): config = copy.deepcopy(config) support_dict = [ - 'SimpleDataSet', 'LMDBDataSet', 'PGDataSet', 'PubTabDataSet' + 'SimpleDataSet', 'LMDBDataSet', 'PGDataSet', 'PubTabDataSet', + 'LMDBDataSetSR' ] module_name = config[mode]['dataset']['name'] assert module_name in support_dict, Exception( diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py index 1656c695..68e5f719 100644 --- a/ppocr/data/imaug/label_ops.py +++ b/ppocr/data/imaug/label_ops.py @@ -1236,6 +1236,54 @@ class ABINetLabelEncode(BaseRecLabelEncode): return dict_character +class SRLabelEncode(BaseRecLabelEncode): + def __init__(self, + max_text_length, + character_dict_path=None, + use_space_char=False, + **kwargs): + super(SRLabelEncode, self).__init__(max_text_length, + character_dict_path, use_space_char) + self.dic = {} + with open(character_dict_path, 'r') as fin: + for line in fin.readlines(): + line = line.strip() + character, sequence = line.split() + self.dic[character] = sequence + english_stroke_alphabet = '0123456789' + self.english_stroke_dict = {} + for index in range(len(english_stroke_alphabet)): + self.english_stroke_dict[english_stroke_alphabet[index]] = index + + def encode(self, label): + stroke_sequence = '' + for character in label: + if character not in self.dic: + continue + else: + stroke_sequence += self.dic[character] + stroke_sequence += '0' + label = stroke_sequence + + length = len(label) + + input_tensor = np.zeros(self.max_text_len).astype("int64") + for j in range(length - 1): + input_tensor[j + 1] = self.english_stroke_dict[label[j]] + + return length, input_tensor + + def __call__(self, data): + text = data['label'] + length, input_tensor = self.encode(text) + + data["length"] = length + data["input_tensor"] = input_tensor + if text is None: + return None + return data + + class SPINLabelEncode(AttnLabelEncode): """ Convert between text-label and text-index """ diff --git a/ppocr/data/imaug/operators.py b/ppocr/data/imaug/operators.py index 04cc2848..f8ed2892 100644 --- a/ppocr/data/imaug/operators.py +++ b/ppocr/data/imaug/operators.py @@ -24,6 +24,7 @@ import six import cv2 import numpy as np import math +from PIL import Image class DecodeImage(object): @@ -440,3 +441,52 @@ class KieResize(object): points[:, 0::2] = np.clip(points[:, 0::2], 0, img_shape[1]) points[:, 1::2] = np.clip(points[:, 1::2], 0, img_shape[0]) return points + + +class SRResize(object): + def __init__(self, + imgH=32, + imgW=128, + down_sample_scale=4, + keep_ratio=False, + min_ratio=1, + mask=False, + infer_mode=False, + **kwargs): + self.imgH = imgH + self.imgW = imgW + self.keep_ratio = keep_ratio + self.min_ratio = min_ratio + self.down_sample_scale = down_sample_scale + self.mask = mask + self.infer_mode = infer_mode + + def __call__(self, data): + imgH = self.imgH + imgW = self.imgW + images_lr = data["image_lr"] + transform2 = ResizeNormalize( + (imgW // self.down_sample_scale, imgH // self.down_sample_scale)) + images_lr = transform2(images_lr) + data["img_lr"] = images_lr + if self.infer_mode: + return data + + images_HR = data["image_hr"] + label_strs = data["label"] + transform = ResizeNormalize((imgW, imgH)) + images_HR = transform(images_HR) + data["img_hr"] = images_HR + return data + + +class ResizeNormalize(object): + def __init__(self, size, interpolation=Image.BICUBIC): + self.size = size + self.interpolation = interpolation + + def __call__(self, img): + img = img.resize(self.size, self.interpolation) + img_numpy = np.array(img).astype("float32") + img_numpy = img_numpy.transpose((2, 0, 1)) / 255 + return img_numpy diff --git a/ppocr/data/lmdb_dataset.py b/ppocr/data/lmdb_dataset.py index e1b49809..3a51cefe 100644 --- a/ppocr/data/lmdb_dataset.py +++ b/ppocr/data/lmdb_dataset.py @@ -16,6 +16,9 @@ import os from paddle.io import Dataset import lmdb import cv2 +import string +import six +from PIL import Image from .imaug import transform, create_operators @@ -116,3 +119,58 @@ class LMDBDataSet(Dataset): def __len__(self): return self.data_idx_order_list.shape[0] + + +class LMDBDataSetSR(LMDBDataSet): + def buf2PIL(self, txn, key, type='RGB'): + imgbuf = txn.get(key) + buf = six.BytesIO() + buf.write(imgbuf) + buf.seek(0) + im = Image.open(buf).convert(type) + return im + + def str_filt(self, str_, voc_type): + alpha_dict = { + 'digit': string.digits, + 'lower': string.digits + string.ascii_lowercase, + 'upper': string.digits + string.ascii_letters, + 'all': string.digits + string.ascii_letters + string.punctuation + } + if voc_type == 'lower': + str_ = str_.lower() + for char in str_: + if char not in alpha_dict[voc_type]: + str_ = str_.replace(char, '') + return str_ + + def get_lmdb_sample_info(self, txn, index): + self.voc_type = 'upper' + self.max_len = 100 + self.test = False + label_key = b'label-%09d' % index + word = str(txn.get(label_key).decode()) + img_HR_key = b'image_hr-%09d' % index # 128*32 + img_lr_key = b'image_lr-%09d' % index # 64*16 + try: + img_HR = self.buf2PIL(txn, img_HR_key, 'RGB') + img_lr = self.buf2PIL(txn, img_lr_key, 'RGB') + except IOError or len(word) > self.max_len: + return self[index + 1] + label_str = self.str_filt(word, self.voc_type) + return img_HR, img_lr, label_str + + def __getitem__(self, idx): + lmdb_idx, file_idx = self.data_idx_order_list[idx] + lmdb_idx = int(lmdb_idx) + file_idx = int(file_idx) + sample_info = self.get_lmdb_sample_info(self.lmdb_sets[lmdb_idx]['txn'], + file_idx) + if sample_info is None: + return self.__getitem__(np.random.randint(self.__len__())) + img_HR, img_lr, label_str = sample_info + data = {'image_hr': img_HR, 'image_lr': img_lr, 'label': label_str} + outs = transform(data, self.ops) + if outs is None: + return self.__getitem__(np.random.randint(self.__len__())) + return outs diff --git a/ppocr/losses/__init__.py b/ppocr/losses/__init__.py index bb82c7e0..8986e5e5 100755 --- a/ppocr/losses/__init__.py +++ b/ppocr/losses/__init__.py @@ -57,6 +57,9 @@ from .table_master_loss import TableMasterLoss # vqa token loss from .vqa_token_layoutlm_loss import VQASerTokenLayoutLMLoss +# sr loss +from .stroke_focus_loss import StrokeFocusLoss + def build_loss(config): support_dict = [ @@ -64,7 +67,7 @@ def build_loss(config): 'ClsLoss', 'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss', 'CELoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss', 'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss', - 'TableMasterLoss', 'SPINAttentionLoss', 'VLLoss' + 'TableMasterLoss', 'SPINAttentionLoss', 'VLLoss','StrokeFocusLoss' ] config = copy.deepcopy(config) module_name = config.pop('name') diff --git a/ppocr/losses/stroke_focus_loss.py b/ppocr/losses/stroke_focus_loss.py new file mode 100644 index 00000000..002bbc34 --- /dev/null +++ b/ppocr/losses/stroke_focus_loss.py @@ -0,0 +1,68 @@ +# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This code is refer from: +https://github.com/FudanVI/FudanOCR/blob/main/text-gestalt/loss/stroke_focus_loss.py +""" +import cv2 +import sys +import time +import string +import random +import numpy as np +import paddle.nn as nn +import paddle + + +class StrokeFocusLoss(nn.Layer): + def __init__(self, character_dict_path=None, **kwargs): + super(StrokeFocusLoss, self).__init__(character_dict_path) + self.mse_loss = nn.MSELoss() + self.ce_loss = nn.CrossEntropyLoss() + self.l1_loss = nn.L1Loss() + self.english_stroke_alphabet = '0123456789' + self.english_stroke_dict = {} + for index in range(len(self.english_stroke_alphabet)): + self.english_stroke_dict[self.english_stroke_alphabet[ + index]] = index + + stroke_decompose_lines = open(character_dict_path, 'r').readlines() + self.dic = {} + for line in stroke_decompose_lines: + line = line.strip() + character, sequence = line.split() + self.dic[character] = sequence + + def forward(self, pred, data): + + sr_img = pred["sr_img"] + hr_img = pred["hr_img"] + + mse_loss = self.mse_loss(sr_img, hr_img) + word_attention_map_gt = pred["word_attention_map_gt"] + word_attention_map_pred = pred["word_attention_map_pred"] + + hr_pred = pred["hr_pred"] + sr_pred = pred["sr_pred"] + + attention_loss = paddle.nn.functional.l1_loss(word_attention_map_gt, + word_attention_map_pred) + + loss = (mse_loss + attention_loss * 50) * 100 + + return { + "mse_loss": mse_loss, + "attention_loss": attention_loss, + "loss": loss + } diff --git a/ppocr/metrics/__init__.py b/ppocr/metrics/__init__.py index c244066c..853647c0 100644 --- a/ppocr/metrics/__init__.py +++ b/ppocr/metrics/__init__.py @@ -30,13 +30,13 @@ from .table_metric import TableMetric from .kie_metric import KIEMetric from .vqa_token_ser_metric import VQASerTokenMetric from .vqa_token_re_metric import VQAReTokenMetric - +from .sr_metric import SRMetric def build_metric(config): support_dict = [ "DetMetric", "DetFCEMetric", "RecMetric", "ClsMetric", "E2EMetric", "DistillationMetric", "TableMetric", 'KIEMetric', 'VQASerTokenMetric', - 'VQAReTokenMetric' + 'VQAReTokenMetric', 'SRMetric' ] config = copy.deepcopy(config) diff --git a/ppocr/metrics/rec_metric.py b/ppocr/metrics/rec_metric.py index 515b9372..d858ae28 100644 --- a/ppocr/metrics/rec_metric.py +++ b/ppocr/metrics/rec_metric.py @@ -16,6 +16,7 @@ import Levenshtein import string + class RecMetric(object): def __init__(self, main_indicator='acc', diff --git a/ppocr/metrics/sr_metric.py b/ppocr/metrics/sr_metric.py new file mode 100644 index 00000000..51c3ad66 --- /dev/null +++ b/ppocr/metrics/sr_metric.py @@ -0,0 +1,155 @@ +# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +https://github.com/FudanVI/FudanOCR/blob/main/text-gestalt/utils/ssim_psnr.py +""" + +from math import exp + +import paddle +import paddle.nn.functional as F +import paddle.nn as nn +import string + + +class SSIM(nn.Layer): + def __init__(self, window_size=11, size_average=True): + super(SSIM, self).__init__() + self.window_size = window_size + self.size_average = size_average + self.channel = 1 + self.window = self.create_window(window_size, self.channel) + + def gaussian(self, window_size, sigma): + gauss = paddle.to_tensor([ + exp(-(x - window_size // 2)**2 / float(2 * sigma**2)) + for x in range(window_size) + ]) + return gauss / gauss.sum() + + def create_window(self, window_size, channel): + _1D_window = self.gaussian(window_size, 1.5).unsqueeze(1) + _2D_window = _1D_window.mm(_1D_window.t()).unsqueeze(0).unsqueeze(0) + window = _2D_window.expand([channel, 1, window_size, window_size]) + return window + + def _ssim(self, img1, img2, window, window_size, channel, + size_average=True): + mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) + mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) + + mu1_sq = mu1.pow(2) + mu2_sq = mu2.pow(2) + mu1_mu2 = mu1 * mu2 + + sigma1_sq = F.conv2d( + img1 * img1, window, padding=window_size // 2, + groups=channel) - mu1_sq + sigma2_sq = F.conv2d( + img2 * img2, window, padding=window_size // 2, + groups=channel) - mu2_sq + sigma12 = F.conv2d( + img1 * img2, window, padding=window_size // 2, + groups=channel) - mu1_mu2 + + C1 = 0.01**2 + C2 = 0.03**2 + + ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ( + (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) + + if size_average: + return ssim_map.mean() + else: + return ssim_map.mean([1, 2, 3]) + + def ssim(self, img1, img2, window_size=11, size_average=True): + (_, channel, _, _) = img1.shape + window = self.create_window(window_size, channel) + + return self._ssim(img1, img2, window, window_size, channel, + size_average) + + def forward(self, img1, img2): + (_, channel, _, _) = img1.shape + + if channel == self.channel and self.window.dtype == img1.dtype: + window = self.window + else: + window = self.create_window(self.window_size, channel) + + self.window = window + self.channel = channel + + return self._ssim(img1, img2, window, self.window_size, channel, + self.size_average) + + +class SRMetric(object): + def __init__(self, main_indicator='all', **kwargs): + self.main_indicator = main_indicator + self.eps = 1e-5 + self.psnr_result = [] + self.ssim_result = [] + self.calculate_ssim = SSIM() + self.reset() + + def reset(self): + self.correct_num = 0 + self.all_num = 0 + self.norm_edit_dis = 0 + self.psnr_result = [] + self.ssim_result = [] + + def calculate_psnr(self, img1, img2): + # img1 and img2 have range [0, 1] + mse = ((img1 * 255 - img2 * 255)**2).mean() + if mse == 0: + return float('inf') + return 20 * paddle.log10(255.0 / paddle.sqrt(mse)) + + def _normalize_text(self, text): + text = ''.join( + filter(lambda x: x in (string.digits + string.ascii_letters), text)) + return text.lower() + + def __call__(self, pred_label, *args, **kwargs): + metric = {} + images_sr = pred_label["sr_img"] + images_hr = pred_label["hr_img"] + psnr = self.calculate_psnr(images_sr, images_hr) + ssim = self.calculate_ssim(images_sr, images_hr) + self.psnr_result.append(psnr) + self.ssim_result.append(ssim) + + def get_metric(self): + """ + return metrics { + 'acc': 0, + 'norm_edit_dis': 0, + } + """ + self.psnr_avg = sum(self.psnr_result) / len(self.psnr_result) + self.psnr_avg = round(self.psnr_avg.item(), 6) + self.ssim_avg = sum(self.ssim_result) / len(self.ssim_result) + self.ssim_avg = round(self.ssim_avg.item(), 6) + + self.all_avg = self.psnr_avg + self.ssim_avg + + self.reset() + return { + 'psnr_avg': self.psnr_avg, + "ssim_avg": self.ssim_avg, + "all": self.all_avg + } diff --git a/ppocr/modeling/architectures/base_model.py b/ppocr/modeling/architectures/base_model.py index ed2a909c..5612d366 100644 --- a/ppocr/modeling/architectures/base_model.py +++ b/ppocr/modeling/architectures/base_model.py @@ -14,6 +14,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function + from paddle import nn from ppocr.modeling.transforms import build_transform from ppocr.modeling.backbones import build_backbone @@ -46,9 +47,13 @@ class BaseModel(nn.Layer): in_channels = self.transform.out_channels # build backbone, backbone is need for del, rec and cls - config["Backbone"]['in_channels'] = in_channels - self.backbone = build_backbone(config["Backbone"], model_type) - in_channels = self.backbone.out_channels + if 'Backbone' not in config or config['Backbone'] is None: + self.use_backbone = False + else: + self.use_backbone = True + config["Backbone"]['in_channels'] = in_channels + self.backbone = build_backbone(config["Backbone"], model_type) + in_channels = self.backbone.out_channels # build neck # for rec, neck can be cnn,rnn or reshape(None) @@ -77,7 +82,8 @@ class BaseModel(nn.Layer): y = dict() if self.use_transform: x = self.transform(x) - x = self.backbone(x) + if self.use_backbone: + x = self.backbone(x) if isinstance(x, dict): y.update(x) else: @@ -109,4 +115,4 @@ class BaseModel(nn.Layer): else: return {final_name: x} else: - return x + return x \ No newline at end of file diff --git a/ppocr/modeling/heads/sr_rensnet_transformer.py b/ppocr/modeling/heads/sr_rensnet_transformer.py new file mode 100644 index 00000000..a004a126 --- /dev/null +++ b/ppocr/modeling/heads/sr_rensnet_transformer.py @@ -0,0 +1,430 @@ +# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This code is refer from: +https://github.com/FudanVI/FudanOCR/blob/main/text-gestalt/loss/transformer_english_decomposition.py +""" +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +import math, copy +import numpy as np + +# stroke-level alphabet +alphabet = '0123456789' + + +def get_alphabet_len(): + return len(alphabet) + + +def subsequent_mask(size): + """Generate a square mask for the sequence. The masked positions are filled with float('-inf'). + Unmasked positions are filled with float(0.0). + """ + mask = paddle.ones([1, size, size], dtype='float32') + mask_inf = paddle.triu( + paddle.full( + shape=[1, size, size], dtype='float32', fill_value='-inf'), + diagonal=1) + mask = mask + mask_inf + padding_mask = paddle.equal(mask, paddle.to_tensor(1, dtype=mask.dtype)) + return padding_mask + + +def clones(module, N): + return nn.LayerList([copy.deepcopy(module) for _ in range(N)]) + + +def masked_fill(x, mask, value): + y = paddle.full(x.shape, value, x.dtype) + return paddle.where(mask, y, x) + + +def attention(query, key, value, mask=None, dropout=None, attention_map=None): + d_k = query.shape[-1] + scores = paddle.matmul(query, + paddle.transpose(key, [0, 1, 3, 2])) / math.sqrt(d_k) + + if mask is not None: + scores = masked_fill(scores, mask == 0, float('-inf')) + else: + pass + + p_attn = F.softmax(scores, axis=-1) + + if dropout is not None: + p_attn = dropout(p_attn) + return paddle.matmul(p_attn, value), p_attn + + +class MultiHeadedAttention(nn.Layer): + def __init__(self, h, d_model, dropout=0.1, compress_attention=False): + super(MultiHeadedAttention, self).__init__() + assert d_model % h == 0 + self.d_k = d_model // h + self.h = h + self.linears = clones(nn.Linear(d_model, d_model), 4) + self.attn = None + self.dropout = nn.Dropout(p=dropout, mode="downscale_in_infer") + self.compress_attention = compress_attention + self.compress_attention_linear = nn.Linear(h, 1) + + def forward(self, query, key, value, mask=None, attention_map=None): + if mask is not None: + mask = mask.unsqueeze(1) + nbatches = query.shape[0] + + query, key, value = \ + [paddle.transpose(l(x).reshape([nbatches, -1, self.h, self.d_k]), [0,2,1,3]) + for l, x in zip(self.linears, (query, key, value))] + + x, attention_map = attention( + query, + key, + value, + mask=mask, + dropout=self.dropout, + attention_map=attention_map) + + x = paddle.reshape( + paddle.transpose(x, [0, 2, 1, 3]), + [nbatches, -1, self.h * self.d_k]) + + return self.linears[-1](x), attention_map + + +class ResNet(nn.Layer): + def __init__(self, num_in, block, layers): + super(ResNet, self).__init__() + + self.conv1 = nn.Conv2D(num_in, 64, kernel_size=3, stride=1, padding=1) + self.bn1 = nn.BatchNorm2D(64, use_global_stats=True) + self.relu1 = nn.ReLU() + self.pool = nn.MaxPool2D((2, 2), (2, 2)) + + self.conv2 = nn.Conv2D(64, 128, kernel_size=3, stride=1, padding=1) + self.bn2 = nn.BatchNorm2D(128, use_global_stats=True) + self.relu2 = nn.ReLU() + + self.layer1_pool = nn.MaxPool2D((2, 2), (2, 2)) + self.layer1 = self._make_layer(block, 128, 256, layers[0]) + self.layer1_conv = nn.Conv2D(256, 256, 3, 1, 1) + self.layer1_bn = nn.BatchNorm2D(256, use_global_stats=True) + self.layer1_relu = nn.ReLU() + + self.layer2_pool = nn.MaxPool2D((2, 2), (2, 2)) + self.layer2 = self._make_layer(block, 256, 256, layers[1]) + self.layer2_conv = nn.Conv2D(256, 256, 3, 1, 1) + self.layer2_bn = nn.BatchNorm2D(256, use_global_stats=True) + self.layer2_relu = nn.ReLU() + + self.layer3_pool = nn.MaxPool2D((2, 2), (2, 2)) + self.layer3 = self._make_layer(block, 256, 512, layers[2]) + self.layer3_conv = nn.Conv2D(512, 512, 3, 1, 1) + self.layer3_bn = nn.BatchNorm2D(512, use_global_stats=True) + self.layer3_relu = nn.ReLU() + + self.layer4_pool = nn.MaxPool2D((2, 2), (2, 2)) + self.layer4 = self._make_layer(block, 512, 512, layers[3]) + self.layer4_conv2 = nn.Conv2D(512, 1024, 3, 1, 1) + self.layer4_conv2_bn = nn.BatchNorm2D(1024, use_global_stats=True) + self.layer4_conv2_relu = nn.ReLU() + + def _make_layer(self, block, inplanes, planes, blocks): + + if inplanes != planes: + downsample = nn.Sequential( + nn.Conv2D(inplanes, planes, 3, 1, 1), + nn.BatchNorm2D( + planes, use_global_stats=True), ) + else: + downsample = None + layers = [] + layers.append(block(inplanes, planes, downsample)) + for i in range(1, blocks): + layers.append(block(planes, planes, downsample=None)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu1(x) + x = self.pool(x) + + x = self.conv2(x) + x = self.bn2(x) + x = self.relu2(x) + + x = self.layer1_pool(x) + x = self.layer1(x) + x = self.layer1_conv(x) + x = self.layer1_bn(x) + x = self.layer1_relu(x) + + x = self.layer2(x) + x = self.layer2_conv(x) + x = self.layer2_bn(x) + x = self.layer2_relu(x) + + x = self.layer3(x) + x = self.layer3_conv(x) + x = self.layer3_bn(x) + x = self.layer3_relu(x) + + x = self.layer4(x) + x = self.layer4_conv2(x) + x = self.layer4_conv2_bn(x) + x = self.layer4_conv2_relu(x) + + return x + + +class Bottleneck(nn.Layer): + def __init__(self, input_dim): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2D(input_dim, input_dim, 1) + self.bn1 = nn.BatchNorm2D(input_dim, use_global_stats=True) + self.relu = nn.ReLU() + + self.conv2 = nn.Conv2D(input_dim, input_dim, 3, 1, 1) + self.bn2 = nn.BatchNorm2D(input_dim, use_global_stats=True) + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + out += residual + out = self.relu(out) + + return out + + +class PositionalEncoding(nn.Layer): + "Implement the PE function." + + def __init__(self, dropout, dim, max_len=5000): + super(PositionalEncoding, self).__init__() + self.dropout = nn.Dropout(p=dropout, mode="downscale_in_infer") + + pe = paddle.zeros([max_len, dim]) + position = paddle.arange(0, max_len, dtype=paddle.float32).unsqueeze(1) + div_term = paddle.exp( + paddle.arange(0, dim, 2).astype('float32') * + (-math.log(10000.0) / dim)) + pe[:, 0::2] = paddle.sin(position * div_term) + pe[:, 1::2] = paddle.cos(position * div_term) + pe = paddle.unsqueeze(pe, 0) + self.register_buffer('pe', pe) + + def forward(self, x): + x = x + self.pe[:, :paddle.shape(x)[1]] + return self.dropout(x) + + +class PositionwiseFeedForward(nn.Layer): + "Implements FFN equation." + + def __init__(self, d_model, d_ff, dropout=0.1): + super(PositionwiseFeedForward, self).__init__() + self.w_1 = nn.Linear(d_model, d_ff) + self.w_2 = nn.Linear(d_ff, d_model) + self.dropout = nn.Dropout(dropout, mode="downscale_in_infer") + + def forward(self, x): + return self.w_2(self.dropout(F.relu(self.w_1(x)))) + + +class Generator(nn.Layer): + "Define standard linear + softmax generation step." + + def __init__(self, d_model, vocab): + super(Generator, self).__init__() + self.proj = nn.Linear(d_model, vocab) + self.relu = nn.ReLU() + + def forward(self, x): + out = self.proj(x) + return out + + +class Embeddings(nn.Layer): + def __init__(self, d_model, vocab): + super(Embeddings, self).__init__() + self.lut = nn.Embedding(vocab, d_model) + self.d_model = d_model + + def forward(self, x): + embed = self.lut(x) * math.sqrt(self.d_model) + return embed + + +class LayerNorm(nn.Layer): + "Construct a layernorm module (See citation for details)." + + def __init__(self, features, eps=1e-6): + super(LayerNorm, self).__init__() + self.a_2 = self.create_parameter( + shape=[features], + default_initializer=paddle.nn.initializer.Constant(1.0)) + self.b_2 = self.create_parameter( + shape=[features], + default_initializer=paddle.nn.initializer.Constant(0.0)) + self.eps = eps + + def forward(self, x): + mean = x.mean(-1, keepdim=True) + std = x.std(-1, keepdim=True) + return self.a_2 * (x - mean) / (std + self.eps) + self.b_2 + + +class Decoder(nn.Layer): + def __init__(self): + super(Decoder, self).__init__() + + self.mask_multihead = MultiHeadedAttention( + h=16, d_model=1024, dropout=0.1) + self.mul_layernorm1 = LayerNorm(1024) + + self.multihead = MultiHeadedAttention(h=16, d_model=1024, dropout=0.1) + self.mul_layernorm2 = LayerNorm(1024) + + self.pff = PositionwiseFeedForward(1024, 2048) + self.mul_layernorm3 = LayerNorm(1024) + + def forward(self, text, conv_feature, attention_map=None): + text_max_length = text.shape[1] + mask = subsequent_mask(text_max_length) + result = text + result = self.mul_layernorm1(result + self.mask_multihead( + text, text, text, mask=mask)[0]) + b, c, h, w = conv_feature.shape + conv_feature = paddle.transpose( + conv_feature.reshape([b, c, h * w]), [0, 2, 1]) + word_image_align, attention_map = self.multihead( + result, + conv_feature, + conv_feature, + mask=None, + attention_map=attention_map) + result = self.mul_layernorm2(result + word_image_align) + result = self.mul_layernorm3(result + self.pff(result)) + + return result, attention_map + + +class BasicBlock(nn.Layer): + def __init__(self, inplanes, planes, downsample): + super(BasicBlock, self).__init__() + self.conv1 = nn.Conv2D( + inplanes, planes, kernel_size=3, stride=1, padding=1) + self.bn1 = nn.BatchNorm2D(planes, use_global_stats=True) + self.relu = nn.ReLU() + self.conv2 = nn.Conv2D( + planes, planes, kernel_size=3, stride=1, padding=1) + self.bn2 = nn.BatchNorm2D(planes, use_global_stats=True) + self.downsample = downsample + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample != None: + residual = self.downsample(residual) + + out += residual + out = self.relu(out) + + return out + + +class Encoder(nn.Layer): + def __init__(self): + super(Encoder, self).__init__() + self.cnn = ResNet(num_in=1, block=BasicBlock, layers=[1, 2, 5, 3]) + + def forward(self, input): + conv_result = self.cnn(input) + return conv_result + + +class Transformer(nn.Layer): + def __init__(self, in_channels=1): + super(Transformer, self).__init__() + + word_n_class = get_alphabet_len() + self.embedding_word_with_upperword = Embeddings(512, word_n_class) + self.pe = PositionalEncoding(dim=512, dropout=0.1, max_len=5000) + + self.encoder = Encoder() + self.decoder = Decoder() + self.generator_word_with_upperword = Generator(1024, word_n_class) + + for p in self.parameters(): + if p.dim() > 1: + nn.initializer.XavierNormal(p) + + def forward(self, image, text_length, text_input, attention_map=None): + if image.shape[1] == 3: + R = image[:, 0:1, :, :] + G = image[:, 1:2, :, :] + B = image[:, 2:3, :, :] + image = 0.299 * R + 0.587 * G + 0.114 * B + + conv_feature = self.encoder(image) # batch, 1024, 8, 32 + max_length = max(text_length) + text_input = text_input[:, :max_length] + + text_embedding = self.embedding_word_with_upperword( + text_input) # batch, text_max_length, 512 + postion_embedding = self.pe( + paddle.zeros(text_embedding.shape)) # batch, text_max_length, 512 + text_input_with_pe = paddle.concat([text_embedding, postion_embedding], + 2) # batch, text_max_length, 1024 + batch, seq_len, _ = text_input_with_pe.shape + + text_input_with_pe, word_attention_map = self.decoder( + text_input_with_pe, conv_feature) + + word_decoder_result = self.generator_word_with_upperword( + text_input_with_pe) + + if self.training: + total_length = paddle.sum(text_length) + probs_res = paddle.zeros([total_length, get_alphabet_len()]) + start = 0 + + for index, length in enumerate(text_length): + length = int(length.numpy()) + probs_res[start:start + length, :] = word_decoder_result[ + index, 0:0 + length, :] + + start = start + length + + return probs_res, word_attention_map, None + else: + return word_decoder_result diff --git a/ppocr/modeling/transforms/__init__.py b/ppocr/modeling/transforms/__init__.py index 7e4ffdf4..b22c60bb 100755 --- a/ppocr/modeling/transforms/__init__.py +++ b/ppocr/modeling/transforms/__init__.py @@ -18,10 +18,10 @@ __all__ = ['build_transform'] def build_transform(config): from .tps import TPS from .stn import STN_ON + from .tsrn import TSRN from .gaspin_transformer import GA_SPIN_Transformer as GA_SPIN - - support_dict = ['TPS', 'STN_ON', 'GA_SPIN'] + support_dict = ['TPS', 'STN_ON', 'GA_SPIN', 'TSRN'] module_name = config.pop('name') assert module_name in support_dict, Exception( diff --git a/ppocr/modeling/transforms/tps_spatial_transformer.py b/ppocr/modeling/transforms/tps_spatial_transformer.py index cb1cb10a..e7ec2c84 100644 --- a/ppocr/modeling/transforms/tps_spatial_transformer.py +++ b/ppocr/modeling/transforms/tps_spatial_transformer.py @@ -153,4 +153,4 @@ class TPSSpatialTransformer(nn.Layer): # the input to grid_sample is normalized [-1, 1], but what we get is [0, 1] grid = 2.0 * grid - 1.0 output_maps = grid_sample(input, grid, canvas=None) - return output_maps, source_coordinate + return output_maps, source_coordinate \ No newline at end of file diff --git a/ppocr/modeling/transforms/tsrn.py b/ppocr/modeling/transforms/tsrn.py new file mode 100644 index 00000000..31aa90ea --- /dev/null +++ b/ppocr/modeling/transforms/tsrn.py @@ -0,0 +1,219 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This code is refer from: +https://github.com/FudanVI/FudanOCR/blob/main/text-gestalt/model/tsrn.py +""" + +import math +import paddle +import paddle.nn.functional as F +from paddle import nn +from collections import OrderedDict +import sys +import numpy as np +import warnings +import math, copy +import cv2 + +warnings.filterwarnings("ignore") + +from .tps_spatial_transformer import TPSSpatialTransformer +from .stn import STN as STN_model +from ppocr.modeling.heads.sr_rensnet_transformer import Transformer + + +class TSRN(nn.Layer): + def __init__(self, + in_channels, + scale_factor=2, + width=128, + height=32, + STN=False, + srb_nums=5, + mask=False, + hidden_units=32, + infer_mode=False, + **kwargs): + super(TSRN, self).__init__() + in_planes = 3 + if mask: + in_planes = 4 + assert math.log(scale_factor, 2) % 1 == 0 + upsample_block_num = int(math.log(scale_factor, 2)) + self.block1 = nn.Sequential( + nn.Conv2D( + in_planes, 2 * hidden_units, kernel_size=9, padding=4), + nn.PReLU()) + self.srb_nums = srb_nums + for i in range(srb_nums): + setattr(self, 'block%d' % (i + 2), + RecurrentResidualBlock(2 * hidden_units)) + + setattr( + self, + 'block%d' % (srb_nums + 2), + nn.Sequential( + nn.Conv2D( + 2 * hidden_units, + 2 * hidden_units, + kernel_size=3, + padding=1), + nn.BatchNorm2D(2 * hidden_units))) + + block_ = [ + UpsampleBLock(2 * hidden_units, 2) + for _ in range(upsample_block_num) + ] + block_.append( + nn.Conv2D( + 2 * hidden_units, in_planes, kernel_size=9, padding=4)) + setattr(self, 'block%d' % (srb_nums + 3), nn.Sequential(*block_)) + self.tps_inputsize = [height // scale_factor, width // scale_factor] + tps_outputsize = [height // scale_factor, width // scale_factor] + num_control_points = 20 + tps_margins = [0.05, 0.05] + self.stn = STN + if self.stn: + self.tps = TPSSpatialTransformer( + output_image_size=tuple(tps_outputsize), + num_control_points=num_control_points, + margins=tuple(tps_margins)) + + self.stn_head = STN_model( + in_channels=in_planes, + num_ctrlpoints=num_control_points, + activation='none') + self.out_channels = in_channels + + self.r34_transformer = Transformer() + for param in self.r34_transformer.parameters(): + param.trainable = False + self.infer_mode = infer_mode + + def forward(self, x): + output = {} + if self.infer_mode: + output["lr_img"] = x + y = x + else: + output["lr_img"] = x[0] + output["hr_img"] = x[1] + y = x[0] + if self.stn and self.training: + _, ctrl_points_x = self.stn_head(y) + y, _ = self.tps(y, ctrl_points_x) + block = {'1': self.block1(y)} + for i in range(self.srb_nums + 1): + block[str(i + 2)] = getattr(self, + 'block%d' % (i + 2))(block[str(i + 1)]) + + block[str(self.srb_nums + 3)] = getattr(self, 'block%d' % (self.srb_nums + 3)) \ + ((block['1'] + block[str(self.srb_nums + 2)])) + + sr_img = paddle.tanh(block[str(self.srb_nums + 3)]) + + output["sr_img"] = sr_img + + if self.training: + hr_img = x[1] + length = x[2] + input_tensor = x[3] + + # add transformer + sr_pred, word_attention_map_pred, _ = self.r34_transformer( + sr_img, length, input_tensor) + + hr_pred, word_attention_map_gt, _ = self.r34_transformer( + hr_img, length, input_tensor) + + output["hr_img"] = hr_img + output["hr_pred"] = hr_pred + output["word_attention_map_gt"] = word_attention_map_gt + output["sr_pred"] = sr_pred + output["word_attention_map_pred"] = word_attention_map_pred + + return output + + +class RecurrentResidualBlock(nn.Layer): + def __init__(self, channels): + super(RecurrentResidualBlock, self).__init__() + self.conv1 = nn.Conv2D(channels, channels, kernel_size=3, padding=1) + self.bn1 = nn.BatchNorm2D(channels) + self.gru1 = GruBlock(channels, channels) + self.prelu = mish() + self.conv2 = nn.Conv2D(channels, channels, kernel_size=3, padding=1) + self.bn2 = nn.BatchNorm2D(channels) + self.gru2 = GruBlock(channels, channels) + + def forward(self, x): + residual = self.conv1(x) + residual = self.bn1(residual) + residual = self.prelu(residual) + residual = self.conv2(residual) + residual = self.bn2(residual) + residual = self.gru1(residual.transpose([0, 1, 3, 2])).transpose( + [0, 1, 3, 2]) + + return self.gru2(x + residual) + + +class UpsampleBLock(nn.Layer): + def __init__(self, in_channels, up_scale): + super(UpsampleBLock, self).__init__() + self.conv = nn.Conv2D( + in_channels, in_channels * up_scale**2, kernel_size=3, padding=1) + + self.pixel_shuffle = nn.PixelShuffle(up_scale) + self.prelu = mish() + + def forward(self, x): + x = self.conv(x) + x = self.pixel_shuffle(x) + x = self.prelu(x) + return x + + +class mish(nn.Layer): + def __init__(self, ): + super(mish, self).__init__() + self.activated = True + + def forward(self, x): + if self.activated: + x = x * (paddle.tanh(F.softplus(x))) + return x + + +class GruBlock(nn.Layer): + def __init__(self, in_channels, out_channels): + super(GruBlock, self).__init__() + assert out_channels % 2 == 0 + self.conv1 = nn.Conv2D( + in_channels, out_channels, kernel_size=1, padding=0) + self.gru = nn.GRU(out_channels, + out_channels // 2, + direction='bidirectional') + + def forward(self, x): + # x: b, c, w, h + x = self.conv1(x) + x = x.transpose([0, 2, 3, 1]) # b, w, h, c + batch_size, w, h, c = x.shape + x = x.reshape([-1, h, c]) # b*w, h, c + x, _ = self.gru(x) + x = x.reshape([-1, w, h, c]) + x = x.transpose([0, 3, 1, 2]) + return x diff --git a/ppocr/utils/save_load.py b/ppocr/utils/save_load.py index e77a6ce0..7cd205e8 100644 --- a/ppocr/utils/save_load.py +++ b/ppocr/utils/save_load.py @@ -148,10 +148,14 @@ def load_pretrained_params(model, path): "The {}.pdparams does not exists!".format(path) params = paddle.load(path + '.pdparams') + state_dict = model.state_dict() + new_state_dict = {} is_float16 = False + for k1 in params.keys(): + if k1 not in state_dict.keys(): logger.warning("The pretrained params {} not in model".format(k1)) else: diff --git a/tools/export_model.py b/tools/export_model.py index 78932c98..2443d66c 100755 --- a/tools/export_model.py +++ b/tools/export_model.py @@ -78,6 +78,12 @@ def export_single_model(model, shape=[None, 3, 64, 512], dtype="float32"), ] model = to_static(model, input_spec=other_shape) + elif arch_config["model_type"] == "sr": + other_shape = [ + paddle.static.InputSpec( + shape=[None, 3, 16, 64], dtype="float32") + ] + model = to_static(model, input_spec=other_shape) elif arch_config["algorithm"] == "ViTSTR": other_shape = [ paddle.static.InputSpec( @@ -195,6 +201,9 @@ def main(): else: # base rec model config["Architecture"]["Head"]["out_channels"] = char_num + # for sr algorithm + if config["Architecture"]["model_type"] == "sr": + config['Architecture']["Transform"]['infer_mode'] = True model = build_model(config["Architecture"]) load_model(config, model, model_type=config['Architecture']["model_type"]) model.eval() diff --git a/tools/infer/predict_sr.py b/tools/infer/predict_sr.py new file mode 100755 index 00000000..b10d90bf --- /dev/null +++ b/tools/infer/predict_sr.py @@ -0,0 +1,155 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import sys +from PIL import Image +__dir__ = os.path.dirname(os.path.abspath(__file__)) +sys.path.insert(0, __dir__) +sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..'))) + +os.environ["FLAGS_allocator_strategy"] = 'auto_growth' + +import cv2 +import numpy as np +import math +import time +import traceback +import paddle + +import tools.infer.utility as utility +from ppocr.postprocess import build_post_process +from ppocr.utils.logging import get_logger +from ppocr.utils.utility import get_image_file_list, check_and_read_gif + +logger = get_logger() + + +class TextSR(object): + def __init__(self, args): + self.sr_image_shape = [int(v) for v in args.sr_image_shape.split(",")] + self.sr_batch_num = args.sr_batch_num + + self.predictor, self.input_tensor, self.output_tensors, self.config = \ + utility.create_predictor(args, 'sr', logger) + self.benchmark = args.benchmark + if args.benchmark: + import auto_log + pid = os.getpid() + gpu_id = utility.get_infer_gpuid() + self.autolog = auto_log.AutoLogger( + model_name="sr", + model_precision=args.precision, + batch_size=args.sr_batch_num, + data_shape="dynamic", + save_path=None, #args.save_log_path, + inference_config=self.config, + pids=pid, + process_name=None, + gpu_ids=gpu_id if args.use_gpu else None, + time_keys=[ + 'preprocess_time', 'inference_time', 'postprocess_time' + ], + warmup=0, + logger=logger) + + def resize_norm_img(self, img): + imgC, imgH, imgW = self.sr_image_shape + img = img.resize((imgW // 2, imgH // 2), Image.BICUBIC) + img_numpy = np.array(img).astype("float32") + img_numpy = img_numpy.transpose((2, 0, 1)) / 255 + return img_numpy + + def __call__(self, img_list): + img_num = len(img_list) + batch_num = self.sr_batch_num + st = time.time() + st = time.time() + all_result = [] * img_num + if self.benchmark: + self.autolog.times.start() + for beg_img_no in range(0, img_num, batch_num): + end_img_no = min(img_num, beg_img_no + batch_num) + norm_img_batch = [] + imgC, imgH, imgW = self.sr_image_shape + for ino in range(beg_img_no, end_img_no): + norm_img = self.resize_norm_img(img_list[ino]) + norm_img = norm_img[np.newaxis, :] + norm_img_batch.append(norm_img) + + norm_img_batch = np.concatenate(norm_img_batch) + norm_img_batch = norm_img_batch.copy() + if self.benchmark: + self.autolog.times.stamp() + self.input_tensor.copy_from_cpu(norm_img_batch) + self.predictor.run() + outputs = [] + for output_tensor in self.output_tensors: + output = output_tensor.copy_to_cpu() + outputs.append(output) + if len(outputs) != 1: + preds = outputs + else: + preds = outputs[0] + all_result.append(outputs) + if self.benchmark: + self.autolog.times.end(stamp=True) + return all_result, time.time() - st + + +def main(args): + image_file_list = get_image_file_list(args.image_dir) + text_recognizer = TextSR(args) + valid_image_file_list = [] + img_list = [] + + # warmup 2 times + if args.warmup: + img = np.random.uniform(0, 255, [16, 64, 3]).astype(np.uint8) + for i in range(2): + res = text_recognizer([img] * int(args.sr_batch_num)) + + for image_file in image_file_list: + img, flag = check_and_read_gif(image_file) + if not flag: + img = Image.open(image_file).convert("RGB") + if img is None: + logger.info("error in loading image:{}".format(image_file)) + continue + valid_image_file_list.append(image_file) + img_list.append(img) + try: + preds, _ = text_recognizer(img_list) + for beg_no in range(len(preds)): + sr_img = preds[beg_no][1] + lr_img = preds[beg_no][0] + for i in (range(sr_img.shape[0])): + fm_sr = (sr_img[i] * 255).transpose(1, 2, 0).astype(np.uint8) + fm_lr = (lr_img[i] * 255).transpose(1, 2, 0).astype(np.uint8) + img_name_pure = os.path.split(valid_image_file_list[ + beg_no * args.sr_batch_num + i])[-1] + cv2.imwrite("infer_result/sr_{}".format(img_name_pure), + fm_sr[:, :, ::-1]) + logger.info("The visualized image saved in infer_result/sr_{}". + format(img_name_pure)) + + except Exception as E: + logger.info(traceback.format_exc()) + logger.info(E) + exit() + if args.benchmark: + text_recognizer.autolog.report() + + +if __name__ == "__main__": + main(utility.parse_args()) diff --git a/tools/infer/utility.py b/tools/infer/utility.py index 9345106e..9c89a4e7 100644 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -121,6 +121,11 @@ def init_args(): parser.add_argument("--use_pdserving", type=str2bool, default=False) parser.add_argument("--warmup", type=str2bool, default=False) + # SR parmas + parser.add_argument("--sr_model_dir", type=str) + parser.add_argument("--sr_image_shape", type=str, default="3, 32, 128") + parser.add_argument("--sr_batch_num", type=int, default=1) + # parser.add_argument( "--draw_img_save_dir", type=str, default="./inference_results") @@ -156,6 +161,8 @@ def create_predictor(args, mode, logger): model_dir = args.table_model_dir elif mode == 'ser': model_dir = args.ser_model_dir + elif mode == "sr": + model_dir = args.sr_model_dir else: model_dir = args.e2e_model_dir diff --git a/tools/infer_sr.py b/tools/infer_sr.py new file mode 100755 index 00000000..0bc2f6aa --- /dev/null +++ b/tools/infer_sr.py @@ -0,0 +1,100 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +import os +import sys +import json +from PIL import Image +import cv2 + +__dir__ = os.path.dirname(os.path.abspath(__file__)) +sys.path.insert(0, __dir__) +sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..'))) + +os.environ["FLAGS_allocator_strategy"] = 'auto_growth' + +import paddle + +from ppocr.data import create_operators, transform +from ppocr.modeling.architectures import build_model +from ppocr.postprocess import build_post_process +from ppocr.utils.save_load import load_model +from ppocr.utils.utility import get_image_file_list +import tools.program as program + + +def main(): + global_config = config['Global'] + + # build post process + post_process_class = build_post_process(config['PostProcess'], + global_config) + + # sr transform + config['Architecture']["Transform"]['infer_mode'] = True + + model = build_model(config['Architecture']) + + load_model(config, model) + + # create data ops + transforms = [] + for op in config['Eval']['dataset']['transforms']: + op_name = list(op)[0] + if 'Label' in op_name: + continue + elif op_name in ['SRResize']: + op[op_name]['infer_mode'] = True + elif op_name == 'KeepKeys': + op[op_name]['keep_keys'] = ['imge_lr'] + transforms.append(op) + global_config['infer_mode'] = True + ops = create_operators(transforms, global_config) + + save_res_path = config['Global'].get('save_res_path', "./infer_result") + if not os.path.exists(os.path.dirname(save_res_path)): + os.makedirs(os.path.dirname(save_res_path)) + + model.eval() + for file in get_image_file_list(config['Global']['infer_img']): + logger.info("infer_img: {}".format(file)) + img = Image.open(file).convert("RGB") + data = {'image_lr': img} + batch = transform(data, ops) + images = np.expand_dims(batch[0], axis=0) + images = paddle.to_tensor(images) + + preds = model(images) + sr_img = preds["sr_img"][0] + lr_img = preds["lr_img"][0] + fm_sr = (sr_img.numpy() * 255).transpose(1, 2, 0).astype(np.uint8) + fm_lr = (lr_img.numpy() * 255).transpose(1, 2, 0).astype(np.uint8) + img_name_pure = os.path.split(file)[-1] + cv2.imwrite("infer_result/sr_{}".format(img_name_pure), + fm_sr[:, :, ::-1]) + logger.info("The visualized image saved in infer_result/sr_{}".format( + img_name_pure)) + + logger.info("success!") + + +if __name__ == '__main__': + config, device, logger, vdl_writer = program.preprocess() + main() diff --git a/tools/program.py b/tools/program.py index fd4e662b..34845f00 100755 --- a/tools/program.py +++ b/tools/program.py @@ -25,6 +25,8 @@ import datetime import paddle import paddle.distributed as dist from tqdm import tqdm +import cv2 +import numpy as np from argparse import ArgumentParser, RawDescriptionHelpFormatter from ppocr.utils.stats import TrainingStats @@ -262,6 +264,7 @@ def train(config, config, 'Train', device, logger, seed=epoch) max_iter = len(train_dataloader) - 1 if platform.system( ) == "Windows" else len(train_dataloader) + for idx, batch in enumerate(train_dataloader): profiler.add_profiler_step(profiler_options) train_reader_cost += time.time() - reader_start @@ -289,7 +292,7 @@ def train(config, else: if model_type == 'table' or extra_input: preds = model(images, data=batch[1:]) - elif model_type in ["kie", 'vqa']: + elif model_type in ["kie", 'vqa', 'sr']: preds = model(batch) else: preds = model(images) @@ -297,11 +300,12 @@ def train(config, avg_loss = loss['loss'] avg_loss.backward() optimizer.step() + optimizer.clear_grad() if cal_metric_during_train and epoch % calc_epoch_interval == 0: # only rec and cls need batch = [item.numpy() for item in batch] - if model_type in ['kie']: + if model_type in ['kie', 'sr']: eval_class(preds, batch) elif model_type in ['table']: post_result = post_process_class(preds, batch) @@ -347,8 +351,8 @@ def train(config, len(train_dataloader) - idx - 1) * eta_meter.avg eta_sec_format = str(datetime.timedelta(seconds=int(eta_sec))) strs = 'epoch: [{}/{}], global_step: {}, {}, avg_reader_cost: ' \ - '{:.5f} s, avg_batch_cost: {:.5f} s, avg_samples: {}, ' \ - 'ips: {:.5f} samples/s, eta: {}'.format( + '{:.5f} s, avg_batch_cost: {:.5f} s, avg_samples: {}, ' \ + 'ips: {:.5f} samples/s, eta: {}'.format( epoch, epoch_num, global_step, logs, train_reader_cost / print_batch_step, train_batch_cost / print_batch_step, @@ -480,12 +484,13 @@ def eval(model, leave=True) max_iter = len(valid_dataloader) - 1 if platform.system( ) == "Windows" else len(valid_dataloader) + sum_images = 0 for idx, batch in enumerate(valid_dataloader): if idx >= max_iter: break images = batch[0] start = time.time() - + # use amp if scaler: with paddle.amp.auto_cast(level='O2'): @@ -493,6 +498,20 @@ def eval(model, preds = model(images, data=batch[1:]) elif model_type in ["kie", 'vqa']: preds = model(batch) + elif model_type in ['sr']: + preds = model(batch) + sr_img = preds["sr_img"] + lr_img = preds["lr_img"] + + for i in (range(sr_img.shape[0])): + fm_sr = (sr_img[i].numpy() * 255).transpose( + 1, 2, 0).astype(np.uint8) + fm_lr = (lr_img[i].numpy() * 255).transpose( + 1, 2, 0).astype(np.uint8) + cv2.imwrite("output/images/{}_{}_sr.jpg".format(sum_images, + i), fm_sr) + cv2.imwrite("output/images/{}_{}_lr.jpg".format(sum_images, + i), fm_lr) else: preds = model(images) else: @@ -500,6 +519,20 @@ def eval(model, preds = model(images, data=batch[1:]) elif model_type in ["kie", 'vqa']: preds = model(batch) + elif model_type in ['sr']: + preds = model(batch) + sr_img = preds["sr_img"] + lr_img = preds["lr_img"] + + for i in (range(sr_img.shape[0])): + fm_sr = (sr_img[i].numpy() * 255).transpose( + 1, 2, 0).astype(np.uint8) + fm_lr = (lr_img[i].numpy() * 255).transpose( + 1, 2, 0).astype(np.uint8) + cv2.imwrite("output/images/{}_{}_sr.jpg".format(sum_images, + i), fm_sr) + cv2.imwrite("output/images/{}_{}_lr.jpg".format(sum_images, + i), fm_lr) else: preds = model(images) @@ -517,12 +550,15 @@ def eval(model, elif model_type in ['table', 'vqa']: post_result = post_process_class(preds, batch_numpy) eval_class(post_result, batch_numpy) + elif model_type in ['sr']: + eval_class(preds, batch_numpy) else: post_result = post_process_class(preds, batch_numpy[1]) eval_class(post_result, batch_numpy) pbar.update(1) total_frame += len(images) + sum_images += 1 # Get final metric,eg. acc or hmean metric = eval_class.get_metric() @@ -616,7 +652,8 @@ def preprocess(is_train=False): 'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN', 'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE', 'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'LayoutLMv2', 'PREN', 'FCE', - 'SVTR', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'SPIN', 'VisionLAN' + 'SVTR', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'SPIN', 'VisionLAN', + 'Gestalt' ] if use_xpu: diff --git a/tools/train.py b/tools/train.py index dc8cae8a..b44d76b3 100755 --- a/tools/train.py +++ b/tools/train.py @@ -119,6 +119,7 @@ def main(config, device, logger, vdl_writer): config['Loss']['ignore_index'] = char_num - 1 model = build_model(config['Architecture']) + model = apply_to_static(model, config, logger) # build loss -- GitLab