diff --git a/configs/rec/rec_r31_robustscanner.yml b/configs/rec/rec_r31_robustscanner.yml new file mode 100644 index 0000000000000000000000000000000000000000..4d2ae57b79ca9a34b8137d82c0a2293e2273fe74 --- /dev/null +++ b/configs/rec/rec_r31_robustscanner.yml @@ -0,0 +1,108 @@ +Global: + use_gpu: true + epoch_num: 5 + log_smooth_window: 20 + print_batch_step: 20 + save_model_dir: ./output/rec/rec_r31_robustscanner/ + save_epoch_step: 1 + # evaluation is run every 2000 iterations + eval_batch_step: [0, 2000] + cal_metric_during_train: True + pretrained_model: + checkpoints: + save_inference_dir: + use_visualdl: False + infer_img: ./inference/rec_inference + # for data or label process + character_dict_path: ppocr/utils/dict90.txt + max_text_length: 40 + infer_mode: False + use_space_char: False + rm_symbol: True + save_res_path: ./output/rec/predicts_robustscanner.txt + +Optimizer: + name: Adam + beta1: 0.9 + beta2: 0.999 + lr: + name: Piecewise + decay_epochs: [3, 4] + values: [0.001, 0.0001, 0.00001] + regularizer: + name: 'L2' + factor: 0 + +Architecture: + model_type: rec + algorithm: RobustScanner + Transform: + Backbone: + name: ResNet31V2 + Head: + name: RobustScannerHead + enc_outchannles: 128 + hybrid_dec_rnn_layers: 2 + hybrid_dec_dropout: 0 + position_dec_rnn_layers: 2 + start_idx: 91 + mask: True + padding_idx: 92 + encode_value: False + max_seq_len: 40 + +Loss: + name: SARLoss + +PostProcess: + name: SARLabelDecode + +Metric: + name: RecMetric + is_filter: True + + +Train: + dataset: + name: LMDBDataSet + data_dir: I:/dataset/OCR/deep_text_recognition/data_lmdb/evaluation/CUTE80 + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - SARLabelEncode: # Class handling label + - RobustScannerRecResizeImg: + image_shape: [3, 48, 48, 160] # h:48 w:[48,160] + width_downsample_ratio: 0.25 + max_seq_len: 40 + - KeepKeys: + keep_keys: ['image', 'label', 'valid_ratio', 'word_positons'] # dataloader will return list in this order + loader: + shuffle: True + batch_size_per_card: 4 + drop_last: True + num_workers: 0 + use_shared_memory: False + +Eval: + dataset: + name: LMDBDataSet + data_dir: I:/dataset/OCR/deep_text_recognition/data_lmdb/evaluation/CUTE80 + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - SARLabelEncode: # Class handling label + - RobustScannerRecResizeImg: + image_shape: [3, 48, 48, 160] + max_seq_len: 40 + width_downsample_ratio: 0.25 + - KeepKeys: + keep_keys: ['image', 'label', 'valid_ratio', 'word_positons'] # dataloader will return list in this order + loader: + shuffle: False + drop_last: False + batch_size_per_card: 1 + num_workers: 0 + use_shared_memory: False + diff --git a/doc/doc_ch/algorithm_overview.md b/doc/doc_ch/algorithm_overview.md index 6227a21498eda7d8527e21e7f2567995251d9e47..056d05ed413fd54597ba680a5206304eb5a3989a 100755 --- a/doc/doc_ch/algorithm_overview.md +++ b/doc/doc_ch/algorithm_overview.md @@ -66,6 +66,7 @@ - [x] [SAR](./algorithm_rec_sar.md) - [x] [SEED](./algorithm_rec_seed.md) - [x] [SVTR](./algorithm_rec_svtr.md) +- [x] [RobustScanner](./algorithm_rec_robustscanner.md) 参考[DTRB](https://arxiv.org/abs/1904.01906)[3]文字识别训练和评估流程,使用MJSynth和SynthText两个文字识别数据集训练,在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上进行评估,算法效果如下: @@ -84,6 +85,7 @@ |SAR|Resnet31| 87.20% | rec_r31_sar | [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_r31_sar_train.tar) | |SEED|Aster_Resnet| 85.35% | rec_resnet_stn_bilstm_att | [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_resnet_stn_bilstm_att.tar) | |SVTR|SVTR-Tiny| 89.25% | rec_svtr_tiny_none_ctc_en | [训练模型](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/rec_svtr_tiny_none_ctc_en_train.tar) | +|RobustScanner|ResNet31V2| 87.77% | rec_r31_robustscanner | [训练模型]() | diff --git a/doc/doc_ch/algorithm_rec_robustscanner.md b/doc/doc_ch/algorithm_rec_robustscanner.md new file mode 100644 index 0000000000000000000000000000000000000000..f504cf5f6e2a1e580cf092ac07daee80396097e0 --- /dev/null +++ b/doc/doc_ch/algorithm_rec_robustscanner.md @@ -0,0 +1,114 @@ +# RobustScanner + +- [1. 算法简介](#1) +- [2. 环境配置](#2) +- [3. 模型训练、评估、预测](#3) + - [3.1 训练](#3-1) + - [3.2 评估](#3-2) + - [3.3 预测](#3-3) +- [4. 推理部署](#4) + - [4.1 Python推理](#4-1) + - [4.2 C++推理](#4-2) + - [4.3 Serving服务化部署](#4-3) + - [4.4 更多推理部署](#4-4) +- [5. FAQ](#5) + + +## 1. 算法简介 + +论文信息: +> [RobustScanner: Dynamically Enhancing Positional Clues for Robust Text Recognition](https://arxiv.org/pdf/2007.07542.pdf) +> Xiaoyu Yue, Zhanghui Kuang, Chenhao Lin, Hongbin Sun, Wayne +Zhang +> ECCV, 2020 + +使用MJSynth和SynthText两个合成文字识别数据集训练,在IIIT, SVT, IC13, IC15, SVTP, CUTE数据集上进行评估,算法复现效果如下: + +|模型|骨干网络|配置文件|Acc|下载链接| +| --- | --- | --- | --- | --- | +|RobustScanner|ResNet31V2|[rec_r31_robustscanner.yml](../../configs/rec/rec_r31_robustscanner.yml)|87.77%|[训练模型]()| + +注:除了使用MJSynth和SynthText两个文字识别数据集外,还加入了[SynthAdd](https://pan.baidu.com/share/init?surl=uV0LtoNmcxbO-0YA7Ch4dg)数据(提取码:627x),和部分真实数据,具体数据细节可以参考论文。 + + +## 2. 环境配置 +请先参考[《运行环境准备》](./environment.md)配置PaddleOCR运行环境,参考[《项目克隆》](./clone.md)克隆项目代码。 + + + +## 3. 模型训练、评估、预测 + +请参考[文本识别教程](./recognition.md)。PaddleOCR对代码进行了模块化,训练不同的识别模型只需要**更换配置文件**即可。 + +训练 + +具体地,在完成数据准备后,便可以启动训练,训练命令如下: + +``` +#单卡训练(训练周期长,不建议) +python3 tools/train.py -c configs/rec/rec_r31_robustscanner.yml + +#多卡训练,通过--gpus参数指定卡号 +python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_r31_robustscanner.yml +``` + +评估 + +``` +# GPU 评估, Global.pretrained_model 为待测权重 +python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/rec/rec_r31_robustscanner.yml -o Global.pretrained_model={path/to/weights}/best_accuracy +``` + +预测: + +``` +# 预测使用的配置文件必须与训练一致 +python3 tools/infer_rec.py -c configs/rec/rec_r31_robustscanner.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.infer_img=doc/imgs_words/en/word_1.png +``` + + +## 4. 推理部署 + + +### 4.1 Python推理 +首先将RobustScanner文本识别训练过程中保存的模型,转换成inference model。( [模型下载地址]() ),可以使用如下命令进行转换: + +``` +python3 tools/export_model.py -c configs/rec/rec_r31_robustscanner.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.save_inference_dir=./inference/rec_r31_robustscanner +``` +RobustScanner文本识别模型推理,可以执行如下命令: + +``` +python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/en/word_1.png" --rec_model_dir="./inference/rec_r31_robustscanner/" --rec_image_shape="3, 48, 48, 160" --rec_algorithm="RobustScanner" --rec_char_dict_path="ppocr/utils/dict90.txt" --use_space_char=False +``` + + +### 4.2 C++推理 + +由于C++预处理后处理还未支持SAR,所以暂未支持 + + +### 4.3 Serving服务化部署 + +暂不支持 + + +### 4.4 更多推理部署 + +暂不支持 + + +## 5. FAQ + + +## 引用 + +```bibtex +@article{Li2019ShowAA, + title={Show, Attend and Read: A Simple and Strong Baseline for Irregular Text Recognition}, + author={Hui Li and Peng Wang and Chunhua Shen and Guyu Zhang}, + journal={ArXiv}, + year={2019}, + volume={abs/1811.00751} +} +``` diff --git a/doc/doc_en/algorithm_overview_en.md b/doc/doc_en/algorithm_overview_en.md index 383cbe39bbd2eb8ca85f497888920ce87cb1837e..7579da88e328540ce8794802117d704f7bd6ddab 100755 --- a/doc/doc_en/algorithm_overview_en.md +++ b/doc/doc_en/algorithm_overview_en.md @@ -65,6 +65,7 @@ Supported text recognition algorithms (Click the link to get the tutorial): - [x] [SAR](./algorithm_rec_sar_en.md) - [x] [SEED](./algorithm_rec_seed_en.md) - [x] [SVTR](./algorithm_rec_svtr_en.md) +- [x] [RobustScanner](./algorithm_rec_robustscanner_en.md) Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation result of these above text recognition (using MJSynth and SynthText for training, evaluate on IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE) is as follow: @@ -83,7 +84,7 @@ Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation r |SAR|Resnet31| 87.20% | rec_r31_sar | [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_r31_sar_train.tar) | |SEED|Aster_Resnet| 85.35% | rec_resnet_stn_bilstm_att | [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_resnet_stn_bilstm_att.tar) | |SVTR|SVTR-Tiny| 89.25% | rec_svtr_tiny_none_ctc_en | [trained model](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/rec_svtr_tiny_none_ctc_en_train.tar) | - +|RobustScanner|ResNet31V2| 87.77% | rec_r31_robustscanner | [trained model]() | diff --git a/doc/doc_en/algorithm_rec_robustscanner_en.md b/doc/doc_en/algorithm_rec_robustscanner_en.md new file mode 100644 index 0000000000000000000000000000000000000000..9b6c772ac41adfa2c81e0b3f960c858ba1d17d7f --- /dev/null +++ b/doc/doc_en/algorithm_rec_robustscanner_en.md @@ -0,0 +1,115 @@ +# SAR + +- [1. Introduction](#1) +- [2. Environment](#2) +- [3. Model Training / Evaluation / Prediction](#3) + - [3.1 Training](#3-1) + - [3.2 Evaluation](#3-2) + - [3.3 Prediction](#3-3) +- [4. Inference and Deployment](#4) + - [4.1 Python Inference](#4-1) + - [4.2 C++ Inference](#4-2) + - [4.3 Serving](#4-3) + - [4.4 More](#4-4) +- [5. FAQ](#5) + + +## 1. Introduction + +Paper: +> [RobustScanner: Dynamically Enhancing Positional Clues for Robust Text Recognition](https://arxiv.org/pdf/2007.07542.pdf) +> Xiaoyu Yue, Zhanghui Kuang, Chenhao Lin, Hongbin Sun, Wayne +Zhang +> ECCV, 2020 + +Using MJSynth and SynthText two text recognition datasets for training, and evaluating on IIIT, SVT, IC13, IC15, SVTP, CUTE datasets, the algorithm reproduction effect is as follows: + +|Model|Backbone|config|Acc|Download link| +| --- | --- | --- | --- | --- | +|RobustScanner|ResNet31V2|[rec_r31_robustscanner.yml](../../configs/rec/rec_r31_robustscanner.yml)|87.77%|[train model]()| + +Note:In addition to using the two text recognition datasets MJSynth and SynthText, [SynthAdd](https://pan.baidu.com/share/init?surl=uV0LtoNmcxbO-0YA7Ch4dg) data (extraction code: 627x), and some real data are used in training, the specific data details can refer to the paper. + + +## 2. Environment +Please refer to ["Environment Preparation"](./environment_en.md) to configure the PaddleOCR environment, and refer to ["Project Clone"](./clone_en.md) to clone the project code. + + + +## 3. Model Training / Evaluation / Prediction + +Please refer to [Text Recognition Tutorial](./recognition_en.md). PaddleOCR modularizes the code, and training different recognition models only requires **changing the configuration file**. + +Training: + +Specifically, after the data preparation is completed, the training can be started. The training command is as follows: + +``` +#Single GPU training (long training period, not recommended) +python3 tools/train.py -c configs/rec/rec_r31_robustscanner.yml + +#Multi GPU training, specify the gpu number through the --gpus parameter +python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_r31_robustscanner.yml +``` + +Evaluation: + +``` +# GPU evaluation +python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/rec/rec_r31_robustscanner.yml -o Global.pretrained_model={path/to/weights}/best_accuracy +``` + +Prediction: + +``` +# The configuration file used for prediction must match the training +python3 tools/infer_rec.py -c configs/rec/rec_r31_robustscanner.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.infer_img=doc/imgs_words/en/word_1.png +``` + + +## 4. Inference and Deployment + + +### 4.1 Python Inference +First, the model saved during the RobustScanner text recognition training process is converted into an inference model. ( [Model download link]() ), you can use the following command to convert: + +``` +python3 tools/export_model.py -c configs/rec/rec_r31_robustscanner.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.save_inference_dir=./inference/rec_r31_robustscanner +``` + +For RobustScanner text recognition model inference, the following commands can be executed: + +``` +python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/en/word_1.png" --rec_model_dir="./inference/rec_r31_robustscanner/" --rec_image_shape="3, 48, 48, 160" --rec_algorithm="RobustScanner" --rec_char_dict_path="ppocr/utils/dict90.txt" --use_space_char=False +``` + + +### 4.2 C++ Inference + +Not supported + + +### 4.3 Serving + +Not supported + + +### 4.4 More + +Not supported + + +## 5. FAQ + + +## Citation + +```bibtex +@article{Li2019ShowAA, + title={Show, Attend and Read: A Simple and Strong Baseline for Irregular Text Recognition}, + author={Hui Li and Peng Wang and Chunhua Shen and Guyu Zhang}, + journal={ArXiv}, + year={2019}, + volume={abs/1811.00751} +} +``` diff --git a/ppocr/data/imaug/__init__.py b/ppocr/data/imaug/__init__.py index 548832fb0d116ba2de622bd97562b591d74501d8..c5853dd20c2583fe7296bf7512c7ea0c93d090eb 100644 --- a/ppocr/data/imaug/__init__.py +++ b/ppocr/data/imaug/__init__.py @@ -23,7 +23,8 @@ from .random_crop_data import EastRandomCropData, RandomCropImgMask from .make_pse_gt import MakePseGt from .rec_img_aug import RecAug, RecConAug, RecResizeImg, ClsResizeImg, \ - SRNRecResizeImg, NRTRRecResizeImg, SARRecResizeImg, PRENResizeImg + SRNRecResizeImg, NRTRRecResizeImg, SARRecResizeImg, PRENResizeImg, \ + RobustScannerRecResizeImg from .ssl_img_aug import SSLRotateResize from .randaugment import RandAugment from .copy_paste import CopyPaste diff --git a/ppocr/data/imaug/rec_img_aug.py b/ppocr/data/imaug/rec_img_aug.py index 7483dffe5b6d9a0a2204702757fcb49762a1cc7a..11af19e7ab452783c6cab701f0c20157f5d4ca56 100644 --- a/ppocr/data/imaug/rec_img_aug.py +++ b/ppocr/data/imaug/rec_img_aug.py @@ -206,6 +206,23 @@ class PRENResizeImg(object): data['image'] = resized_img.astype(np.float32) return data +class RobustScannerRecResizeImg(object): + def __init__(self, image_shape, max_seq_len, width_downsample_ratio=0.25, **kwargs): + self.image_shape = image_shape + self.width_downsample_ratio = width_downsample_ratio + self.max_seq_len = max_seq_len + + def __call__(self, data): + img = data['image'] + norm_img, resize_shape, pad_shape, valid_ratio = resize_norm_img_sar( + img, self.image_shape, self.width_downsample_ratio) + word_positons = robustscanner_other_inputs(self.max_seq_len) + data['image'] = norm_img + data['resized_shape'] = resize_shape + data['pad_shape'] = pad_shape + data['valid_ratio'] = valid_ratio + data['word_positons'] = word_positons + return data def resize_norm_img_sar(img, image_shape, width_downsample_ratio=0.25): imgC, imgH, imgW_min, imgW_max = image_shape @@ -351,6 +368,9 @@ def srn_other_inputs(image_shape, num_heads, max_text_length): gsrm_slf_attn_bias2 ] +def robustscanner_other_inputs(max_text_length): + word_pos = np.array(range(0, max_text_length)).astype('int64') + return word_pos def flag(): """ diff --git a/ppocr/modeling/backbones/__init__.py b/ppocr/modeling/backbones/__init__.py index 072d6e0f84d4126d256c26aa5baf17c9dc4e63df..a90051f1dc52be8a62bada9ddd8b6e7f6ff6f1a8 100755 --- a/ppocr/modeling/backbones/__init__.py +++ b/ppocr/modeling/backbones/__init__.py @@ -28,6 +28,7 @@ def build_backbone(config, model_type): from .rec_mv1_enhance import MobileNetV1Enhance from .rec_nrtr_mtb import MTB from .rec_resnet_31 import ResNet31 + from .rec_resnet_31_v2 import ResNet31V2 from .rec_resnet_aster import ResNet_ASTER from .rec_micronet import MicroNet from .rec_efficientb3_pren import EfficientNetb3_PREN @@ -35,7 +36,7 @@ def build_backbone(config, model_type): support_dict = [ 'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB', "ResNet31", "ResNet_ASTER", 'MicroNet', 'EfficientNetb3_PREN', - 'SVTRNet' + 'SVTRNet', "ResNet31V2" ] elif model_type == "e2e": from .e2e_resnet_vd_pg import ResNet diff --git a/ppocr/modeling/backbones/rec_resnet_31_v2.py b/ppocr/modeling/backbones/rec_resnet_31_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..7812b6296e33fc1f193dab88a7df788a6aa581d3 --- /dev/null +++ b/ppocr/modeling/backbones/rec_resnet_31_v2.py @@ -0,0 +1,216 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This code is refer from: +https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textrecog/layers/conv_layer.py +https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textrecog/backbones/resnet31_ocr.py +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +from paddle import ParamAttr +import paddle.nn as nn +import paddle.nn.functional as F +import numpy as np + +__all__ = ["ResNet31V2"] + + +conv_weight_attr = nn.initializer.KaimingNormal() +bn_weight_attr = ParamAttr(initializer=nn.initializer.Uniform(), learning_rate=1) + +def conv3x3(in_channel, out_channel, stride=1): + return nn.Conv2D( + in_channel, + out_channel, + kernel_size=3, + stride=stride, + padding=1, + weight_attr=conv_weight_attr, + bias_attr=False) + + +class BasicBlock(nn.Layer): + expansion = 1 + + def __init__(self, in_channels, channels, stride=1, downsample=False): + super().__init__() + self.conv1 = conv3x3(in_channels, channels, stride) + self.bn1 = nn.BatchNorm2D(channels, weight_attr=bn_weight_attr) + self.relu = nn.ReLU() + self.conv2 = conv3x3(channels, channels) + self.bn2 = nn.BatchNorm2D(channels, weight_attr=bn_weight_attr) + self.downsample = downsample + if downsample: + self.downsample = nn.Sequential( + nn.Conv2D( + in_channels, + channels * self.expansion, + 1, + stride, + weight_attr=conv_weight_attr, + bias_attr=False), + nn.BatchNorm2D(channels * self.expansion, weight_attr=bn_weight_attr)) + else: + self.downsample = nn.Sequential() + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class ResNet31V2(nn.Layer): + ''' + Args: + in_channels (int): Number of channels of input image tensor. + layers (list[int]): List of BasicBlock number for each stage. + channels (list[int]): List of out_channels of Conv2d layer. + out_indices (None | Sequence[int]): Indices of output stages. + last_stage_pool (bool): If True, add `MaxPool2d` layer to last stage. + ''' + + def __init__(self, + in_channels=3, + layers=[1, 2, 5, 3], + channels=[64, 128, 256, 256, 512, 512, 512], + out_indices=None, + last_stage_pool=False): + super(ResNet31V2, self).__init__() + assert isinstance(in_channels, int) + assert isinstance(last_stage_pool, bool) + + self.out_indices = out_indices + self.last_stage_pool = last_stage_pool + + # conv 1 (Conv Conv) + self.conv1_1 = nn.Conv2D( + in_channels, channels[0], kernel_size=3, stride=1, padding=1, weight_attr=conv_weight_attr) + self.bn1_1 = nn.BatchNorm2D(channels[0], weight_attr=bn_weight_attr) + self.relu1_1 = nn.ReLU() + + self.conv1_2 = nn.Conv2D( + channels[0], channels[1], kernel_size=3, stride=1, padding=1, weight_attr=conv_weight_attr) + self.bn1_2 = nn.BatchNorm2D(channels[1], weight_attr=bn_weight_attr) + self.relu1_2 = nn.ReLU() + + # conv 2 (Max-pooling, Residual block, Conv) + self.pool2 = nn.MaxPool2D( + kernel_size=2, stride=2, padding=0, ceil_mode=True) + self.block2 = self._make_layer(channels[1], channels[2], layers[0]) + self.conv2 = nn.Conv2D( + channels[2], channels[2], kernel_size=3, stride=1, padding=1, weight_attr=conv_weight_attr) + self.bn2 = nn.BatchNorm2D(channels[2], weight_attr=bn_weight_attr) + self.relu2 = nn.ReLU() + + # conv 3 (Max-pooling, Residual block, Conv) + self.pool3 = nn.MaxPool2D( + kernel_size=2, stride=2, padding=0, ceil_mode=True) + self.block3 = self._make_layer(channels[2], channels[3], layers[1]) + self.conv3 = nn.Conv2D( + channels[3], channels[3], kernel_size=3, stride=1, padding=1, weight_attr=conv_weight_attr) + self.bn3 = nn.BatchNorm2D(channels[3], weight_attr=bn_weight_attr) + self.relu3 = nn.ReLU() + + # conv 4 (Max-pooling, Residual block, Conv) + self.pool4 = nn.MaxPool2D( + kernel_size=(2, 1), stride=(2, 1), padding=0, ceil_mode=True) + self.block4 = self._make_layer(channels[3], channels[4], layers[2]) + self.conv4 = nn.Conv2D( + channels[4], channels[4], kernel_size=3, stride=1, padding=1, weight_attr=conv_weight_attr) + self.bn4 = nn.BatchNorm2D(channels[4], weight_attr=bn_weight_attr) + self.relu4 = nn.ReLU() + + # conv 5 ((Max-pooling), Residual block, Conv) + self.pool5 = None + if self.last_stage_pool: + self.pool5 = nn.MaxPool2D( + kernel_size=2, stride=2, padding=0, ceil_mode=True) + self.block5 = self._make_layer(channels[4], channels[5], layers[3]) + self.conv5 = nn.Conv2D( + channels[5], channels[5], kernel_size=3, stride=1, padding=1, weight_attr=conv_weight_attr) + self.bn5 = nn.BatchNorm2D(channels[5], weight_attr=bn_weight_attr) + self.relu5 = nn.ReLU() + + self.out_channels = channels[-1] + + def _make_layer(self, input_channels, output_channels, blocks): + layers = [] + for _ in range(blocks): + downsample = None + if input_channels != output_channels: + downsample = nn.Sequential( + nn.Conv2D( + input_channels, + output_channels, + kernel_size=1, + stride=1, + weight_attr=conv_weight_attr, + bias_attr=False), + nn.BatchNorm2D(output_channels, weight_attr=bn_weight_attr)) + + layers.append( + BasicBlock( + input_channels, output_channels, downsample=downsample)) + input_channels = output_channels + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1_1(x) + x = self.bn1_1(x) + x = self.relu1_1(x) + + x = self.conv1_2(x) + x = self.bn1_2(x) + x = self.relu1_2(x) + + outs = [] + for i in range(4): + layer_index = i + 2 + pool_layer = getattr(self, f'pool{layer_index}') + block_layer = getattr(self, f'block{layer_index}') + conv_layer = getattr(self, f'conv{layer_index}') + bn_layer = getattr(self, f'bn{layer_index}') + relu_layer = getattr(self, f'relu{layer_index}') + + if pool_layer is not None: + x = pool_layer(x) + x = block_layer(x) + x = conv_layer(x) + x = bn_layer(x) + x = relu_layer(x) + + outs.append(x) + + if self.out_indices is not None: + return tuple([outs[i] for i in self.out_indices]) + + return x diff --git a/ppocr/modeling/heads/__init__.py b/ppocr/modeling/heads/__init__.py index 1670ea38e66baa683e6faab0ec4b12bc517f3c41..fd2d89315b3c8ef6f9b5edc418b80249bc8d20a0 100755 --- a/ppocr/modeling/heads/__init__.py +++ b/ppocr/modeling/heads/__init__.py @@ -33,6 +33,7 @@ def build_head(config): from .rec_aster_head import AsterHead from .rec_pren_head import PRENHead from .rec_multi_head import MultiHead + from .rec_robustscanner_head import RobustScannerHead # cls head from .cls_head import ClsHead @@ -46,7 +47,7 @@ def build_head(config): 'DBHead', 'PSEHead', 'FCEHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead', 'AttentionHead', 'SRNHead', 'PGHead', 'Transformer', 'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead', - 'MultiHead' + 'MultiHead', 'RobustScannerHead' ] #table head diff --git a/ppocr/modeling/heads/rec_robustscanner_head.py b/ppocr/modeling/heads/rec_robustscanner_head.py new file mode 100644 index 0000000000000000000000000000000000000000..b458937978fc4a39f1f14230dff9d8d4fba4cfd9 --- /dev/null +++ b/ppocr/modeling/heads/rec_robustscanner_head.py @@ -0,0 +1,764 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This code is refer from: +https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textrecog/encoders/channel_reduction_encoder.py +https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textrecog/decoders/robust_scanner_decoder.py +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math +import paddle +from paddle import ParamAttr +import paddle.nn as nn +import paddle.nn.functional as F + +class BaseDecoder(nn.Layer): + def __init__(self, **kwargs): + super().__init__() + + def forward_train(self, feat, out_enc, targets, img_metas): + raise NotImplementedError + + def forward_test(self, feat, out_enc, img_metas): + raise NotImplementedError + + def forward(self, + feat, + out_enc, + label=None, + valid_ratios=None, + word_positions=None, + train_mode=True): + self.train_mode = train_mode + + if train_mode: + return self.forward_train(feat, out_enc, label, valid_ratios, word_positions) + return self.forward_test(feat, out_enc, valid_ratios, word_positions) + +class ChannelReductionEncoder(nn.Layer): + """Change the channel number with a one by one convoluational layer. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + """ + + def __init__(self, + in_channels, + out_channels, + **kwargs): + super(ChannelReductionEncoder, self).__init__() + + self.layer = nn.Conv2D( + in_channels, out_channels, kernel_size=1, stride=1, padding=0, weight_attr=nn.initializer.XavierNormal()) + + def forward(self, feat): + """ + Args: + feat (Tensor): Image features with the shape of + :math:`(N, C_{in}, H, W)`. + + Returns: + Tensor: A tensor of shape :math:`(N, C_{out}, H, W)`. + """ + return self.layer(feat) + + +def masked_fill(x, mask, value): + y = paddle.full(x.shape, value, x.dtype) + return paddle.where(mask, y, x) + +class DotProductAttentionLayer(nn.Layer): + + def __init__(self, dim_model=None): + super().__init__() + + self.scale = dim_model**-0.5 if dim_model is not None else 1. + + def forward(self, query, key, value, h, w, valid_ratios=None): + query = paddle.transpose(query, (0, 2, 1)) + logits = paddle.matmul(query, key) * self.scale + n, c, t = logits.shape + # reshape to (n, c, h, w) + logits = paddle.reshape(logits, [n, c, h, w]) + if valid_ratios is not None: + # cal mask of attention weight + for i, valid_ratio in enumerate(valid_ratios): + valid_width = min(w, int(w * valid_ratio + 0.5)) + if valid_width < w: + logits[i, :, :, valid_width:] = float('-inf') + + # reshape to (n, c, h, w) + logits = paddle.reshape(logits, [n, c, t]) + weights = F.softmax(logits, axis=2) + value = paddle.transpose(value, (0, 2, 1)) + glimpse = paddle.matmul(weights, value) + glimpse = paddle.transpose(glimpse, (0, 2, 1)) + return glimpse + +class SequenceAttentionDecoder(BaseDecoder): + """Sequence attention decoder for RobustScanner. + + RobustScanner: `RobustScanner: Dynamically Enhancing Positional Clues for + Robust Text Recognition `_ + + Args: + num_classes (int): Number of output classes :math:`C`. + rnn_layers (int): Number of RNN layers. + dim_input (int): Dimension :math:`D_i` of input vector ``feat``. + dim_model (int): Dimension :math:`D_m` of the model. Should also be the + same as encoder output vector ``out_enc``. + max_seq_len (int): Maximum output sequence length :math:`T`. + start_idx (int): The index of ``. + mask (bool): Whether to mask input features according to + ``img_meta['valid_ratio']``. + padding_idx (int): The index of ``. + dropout (float): Dropout rate. + return_feature (bool): Return feature or logits as the result. + encode_value (bool): Whether to use the output of encoder ``out_enc`` + as `value` of attention layer. If False, the original feature + ``feat`` will be used. + + Warning: + This decoder will not predict the final class which is assumed to be + ``. Therefore, its output size is always :math:`C - 1`. `` + is also ignored by loss as specified in + :obj:`mmocr.models.textrecog.recognizer.EncodeDecodeRecognizer`. + """ + + def __init__(self, + num_classes=None, + rnn_layers=2, + dim_input=512, + dim_model=128, + max_seq_len=40, + start_idx=0, + mask=True, + padding_idx=None, + dropout=0, + return_feature=False, + encode_value=False): + super().__init__() + + self.num_classes = num_classes + self.dim_input = dim_input + self.dim_model = dim_model + self.return_feature = return_feature + self.encode_value = encode_value + self.max_seq_len = max_seq_len + self.start_idx = start_idx + self.mask = mask + + self.embedding = nn.Embedding( + self.num_classes, self.dim_model, padding_idx=padding_idx) + + self.sequence_layer = nn.LSTM( + input_size=dim_model, + hidden_size=dim_model, + num_layers=rnn_layers, + time_major=False, + dropout=dropout) + + self.attention_layer = DotProductAttentionLayer() + + self.prediction = None + if not self.return_feature: + pred_num_classes = num_classes - 1 + self.prediction = nn.Linear( + dim_model if encode_value else dim_input, pred_num_classes) + + def forward_train(self, feat, out_enc, targets, valid_ratios): + """ + Args: + feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`. + out_enc (Tensor): Encoder output of shape + :math:`(N, D_m, H, W)`. + targets (Tensor): a tensor of shape :math:`(N, T)`. Each element is the index of a + character. + valid_ratios (Tensor): valid length ratio of img. + Returns: + Tensor: A raw logit tensor of shape :math:`(N, T, C-1)` if + ``return_feature=False``. Otherwise it would be the hidden feature + before the prediction projection layer, whose shape is + :math:`(N, T, D_m)`. + """ + + tgt_embedding = self.embedding(targets) + + n, c_enc, h, w = out_enc.shape + assert c_enc == self.dim_model + _, c_feat, _, _ = feat.shape + assert c_feat == self.dim_input + _, len_q, c_q = tgt_embedding.shape + assert c_q == self.dim_model + assert len_q <= self.max_seq_len + + query, _ = self.sequence_layer(tgt_embedding) + query = paddle.transpose(query, (0, 2, 1)) + key = paddle.reshape(out_enc, [n, c_enc, h * w]) + if self.encode_value: + value = key + else: + value = paddle.reshape(feat, [n, c_feat, h * w]) + + # mask = None + # if valid_ratios is not None: + # mask = paddle.zeros(shape=[n, len_q, h, w], dtype='bool') + # for i, valid_ratio in enumerate(valid_ratios): + # valid_width = min(w, math.ceil(w * valid_ratio)) + # if valid_width < w: + # mask[i, :, :, valid_width:] = True + # # mask = mask.view(n, h * w) + # mask = paddle.reshape(mask, (n, len_q, h * w)) + + attn_out = self.attention_layer(query, key, value, h, w, valid_ratios) + # attn_out = attn_out.permute(0, 2, 1).contiguous() + attn_out = paddle.transpose(attn_out, (0, 2, 1)) + + if self.return_feature: + return attn_out + + out = self.prediction(attn_out) + + return out + + def forward_test(self, feat, out_enc, valid_ratios): + """ + Args: + feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`. + out_enc (Tensor): Encoder output of shape + :math:`(N, D_m, H, W)`. + valid_ratios (Tensor): valid length ratio of img. + + Returns: + Tensor: The output logit sequence tensor of shape + :math:`(N, T, C-1)`. + """ + seq_len = self.max_seq_len + batch_size = feat.shape[0] + + # decode_sequence = (feat.new_ones( + # (batch_size, seq_len)) * self.start_idx).long() + decode_sequence = (paddle.ones((batch_size, seq_len), dtype='int64') * self.start_idx) + + outputs = [] + for i in range(seq_len): + step_out = self.forward_test_step(feat, out_enc, decode_sequence, + i, valid_ratios) + outputs.append(step_out) + max_idx = paddle.argmax(step_out, axis=1, keepdim=False) + if i < seq_len - 1: + decode_sequence[:, i + 1] = max_idx + + outputs = paddle.stack(outputs, 1) + + return outputs + + def forward_test_step(self, feat, out_enc, decode_sequence, current_step, + valid_ratios): + """ + Args: + feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`. + out_enc (Tensor): Encoder output of shape + :math:`(N, D_m, H, W)`. + decode_sequence (Tensor): Shape :math:`(N, T)`. The tensor that + stores history decoding result. + current_step (int): Current decoding step. + valid_ratios (Tensor): valid length ratio of img + + Returns: + Tensor: Shape :math:`(N, C-1)`. The logit tensor of predicted + tokens at current time step. + """ + + embed = self.embedding(decode_sequence) + + n, c_enc, h, w = out_enc.shape + assert c_enc == self.dim_model + _, c_feat, _, _ = feat.shape + assert c_feat == self.dim_input + _, _, c_q = embed.shape + assert c_q == self.dim_model + + query, _ = self.sequence_layer(embed) + query = paddle.transpose(query, (0, 2, 1)) + key = paddle.reshape(out_enc, [n, c_enc, h * w]) + if self.encode_value: + value = key + else: + value = paddle.reshape(feat, [n, c_feat, h * w]) + # len_q = query.shape[2] + # mask = None + # if valid_ratios is not None: + # mask = paddle.zeros(shape=[n, len_q, h, w], dtype='bool') + # for i, valid_ratio in enumerate(valid_ratios): + # valid_width = min(w, math.ceil(w * valid_ratio)) + # if valid_width < w: + # mask[i, :, :, valid_width:] = True + # # mask = mask.view(n, h * w) + # mask = paddle.reshape(mask, (n, len_q, h * w)) + + # [n, c, l] + # attn_out = self.attention_layer(query, key, value, mask) + + attn_out = self.attention_layer(query, key, value, h, w, valid_ratios) + out = attn_out[:, :, current_step] + + if self.return_feature: + return out + + out = self.prediction(out) + out = F.softmax(out, dim=-1) + + return out + + +class PositionAwareLayer(nn.Layer): + + def __init__(self, dim_model, rnn_layers=2): + super().__init__() + + self.dim_model = dim_model + + self.rnn = nn.LSTM( + input_size=dim_model, + hidden_size=dim_model, + num_layers=rnn_layers, + time_major=False) + + self.mixer = nn.Sequential( + nn.Conv2D( + dim_model, dim_model, kernel_size=3, stride=1, padding=1), + nn.ReLU(), + nn.Conv2D( + dim_model, dim_model, kernel_size=3, stride=1, padding=1)) + + def forward(self, img_feature): + n, c, h, w = img_feature.shape + rnn_input = paddle.transpose(img_feature, (0, 2, 3, 1)) + rnn_input = paddle.reshape(rnn_input, (n * h, w, c)) + rnn_output, _ = self.rnn(rnn_input) + rnn_output = paddle.reshape(rnn_output, (n, h, w, c)) + rnn_output = paddle.transpose(rnn_output, (0, 3, 1, 2)) + out = self.mixer(rnn_output) + return out + + +class PositionAttentionDecoder(BaseDecoder): + """Position attention decoder for RobustScanner. + + RobustScanner: `RobustScanner: Dynamically Enhancing Positional Clues for + Robust Text Recognition `_ + + Args: + num_classes (int): Number of output classes :math:`C`. + rnn_layers (int): Number of RNN layers. + dim_input (int): Dimension :math:`D_i` of input vector ``feat``. + dim_model (int): Dimension :math:`D_m` of the model. Should also be the + same as encoder output vector ``out_enc``. + max_seq_len (int): Maximum output sequence length :math:`T`. + mask (bool): Whether to mask input features according to + ``img_meta['valid_ratio']``. + return_feature (bool): Return feature or logits as the result. + encode_value (bool): Whether to use the output of encoder ``out_enc`` + as `value` of attention layer. If False, the original feature + ``feat`` will be used. + + Warning: + This decoder will not predict the final class which is assumed to be + ``. Therefore, its output size is always :math:`C - 1`. `` + is also ignored by loss + + """ + + def __init__(self, + num_classes=None, + rnn_layers=2, + dim_input=512, + dim_model=128, + max_seq_len=40, + mask=True, + return_feature=False, + encode_value=False): + super().__init__() + + self.num_classes = num_classes + self.dim_input = dim_input + self.dim_model = dim_model + self.max_seq_len = max_seq_len + self.return_feature = return_feature + self.encode_value = encode_value + self.mask = mask + + self.embedding = nn.Embedding(self.max_seq_len + 1, self.dim_model) + + self.position_aware_module = PositionAwareLayer( + self.dim_model, rnn_layers) + + self.attention_layer = DotProductAttentionLayer() + + self.prediction = None + if not self.return_feature: + pred_num_classes = num_classes - 1 + self.prediction = nn.Linear( + dim_model if encode_value else dim_input, pred_num_classes) + + def _get_position_index(self, length, batch_size): + position_index_list = [] + for i in range(batch_size): + position_index = paddle.arange(0, end=length, step=1, dtype='int64') + position_index_list.append(position_index) + batch_position_index = paddle.stack(position_index_list, axis=0) + return batch_position_index + + def forward_train(self, feat, out_enc, targets, valid_ratios, position_index): + """ + Args: + feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`. + out_enc (Tensor): Encoder output of shape + :math:`(N, D_m, H, W)`. + targets (dict): A dict with the key ``padded_targets``, a + tensor of shape :math:`(N, T)`. Each element is the index of a + character. + valid_ratios (Tensor): valid length ratio of img. + position_index (Tensor): The position of each word. + + Returns: + Tensor: A raw logit tensor of shape :math:`(N, T, C-1)` if + ``return_feature=False``. Otherwise it will be the hidden feature + before the prediction projection layer, whose shape is + :math:`(N, T, D_m)`. + """ + # + n, c_enc, h, w = out_enc.shape + assert c_enc == self.dim_model + _, c_feat, _, _ = feat.shape + assert c_feat == self.dim_input + _, len_q = targets.shape + assert len_q <= self.max_seq_len + + # position_index = self._get_position_index(len_q, n) + + position_out_enc = self.position_aware_module(out_enc) + + query = self.embedding(position_index) + query = paddle.transpose(query, (0, 2, 1)) + key = paddle.reshape(position_out_enc, (n, c_enc, h * w)) + if self.encode_value: + value = paddle.reshape(out_enc,(n, c_enc, h * w)) + else: + value = paddle.reshape(feat,(n, c_feat, h * w)) + + # mask = None + # if valid_ratios is not None: + # mask = paddle.zeros(shape=[n, len_q, h, w], dtype='bool') + # for i, valid_ratio in enumerate(valid_ratios): + # valid_width = min(w, math.ceil(w * valid_ratio)) + # if valid_width < w: + # mask[i, :, :, valid_width:] = True + # # mask = mask.view(n, h * w) + # mask = paddle.reshape(mask, (n, len_q, h * w)) + + attn_out = self.attention_layer(query, key, value, h, w, valid_ratios) + attn_out = paddle.transpose(attn_out, (0, 2, 1)) # [n, len_q, dim_v] + + if self.return_feature: + return attn_out + + return self.prediction(attn_out) + + def forward_test(self, feat, out_enc, valid_ratios, position_index): + """ + Args: + feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`. + out_enc (Tensor): Encoder output of shape + :math:`(N, D_m, H, W)`. + valid_ratios (Tensor): valid length ratio of img + position_index (Tensor): The position of each word. + + Returns: + Tensor: A raw logit tensor of shape :math:`(N, T, C-1)` if + ``return_feature=False``. Otherwise it would be the hidden feature + before the prediction projection layer, whose shape is + :math:`(N, T, D_m)`. + """ + # seq_len = self.max_seq_len + n, c_enc, h, w = out_enc.shape + assert c_enc == self.dim_model + _, c_feat, _, _ = feat.shape + assert c_feat == self.dim_input + + # the _get_position_index is not ok for export_model + # position_index = self._get_position_index(self.max_seq_len, n) + + position_out_enc = self.position_aware_module(out_enc) + + query = self.embedding(position_index) + query = paddle.transpose(query, (0, 2, 1)) + key = paddle.reshape(position_out_enc, (n, c_enc, h * w)) + if self.encode_value: + value = paddle.reshape(out_enc,(n, c_enc, h * w)) + else: + value = paddle.reshape(feat,(n, c_feat, h * w)) + # len_q = query.shape[2] + # mask = None + # if valid_ratios is not None: + # mask = paddle.zeros(shape=[n, len_q, h, w], dtype='bool') + # for i, valid_ratio in enumerate(valid_ratios): + # valid_width = min(w, math.ceil(w * valid_ratio)) + # if valid_width < w: + # mask[i, :, :, valid_width:] = True + # # mask = mask.view(n, h * w) + # mask = paddle.reshape(mask, (n, len_q, h * w)) + + attn_out = self.attention_layer(query, key, value, h, w, valid_ratios) + attn_out = paddle.transpose(attn_out, (0, 2, 1)) # [n, len_q, dim_v] + + if self.return_feature: + return attn_out + + return self.prediction(attn_out) + +class RobustScannerFusionLayer(nn.Layer): + + def __init__(self, dim_model, dim=-1): + super(RobustScannerFusionLayer, self).__init__() + + self.dim_model = dim_model + self.dim = dim + self.linear_layer = nn.Linear(dim_model * 2, dim_model * 2) + + def forward(self, x0, x1): + assert x0.shape == x1.shape + fusion_input = paddle.concat([x0, x1], self.dim) + output = self.linear_layer(fusion_input) + output = F.glu(output, self.dim) + return output + +class RobustScannerDecoder(BaseDecoder): + """Decoder for RobustScanner. + + RobustScanner: `RobustScanner: Dynamically Enhancing Positional Clues for + Robust Text Recognition `_ + + Args: + num_classes (int): Number of output classes :math:`C`. + dim_input (int): Dimension :math:`D_i` of input vector ``feat``. + dim_model (int): Dimension :math:`D_m` of the model. Should also be the + same as encoder output vector ``out_enc``. + max_seq_len (int): Maximum output sequence length :math:`T`. + start_idx (int): The index of ``. + mask (bool): Whether to mask input features according to + ``img_meta['valid_ratio']``. + padding_idx (int): The index of ``. + encode_value (bool): Whether to use the output of encoder ``out_enc`` + as `value` of attention layer. If False, the original feature + ``feat`` will be used. + + Warning: + This decoder will not predict the final class which is assumed to be + ``. Therefore, its output size is always :math:`C - 1`. `` + is also ignored by loss as specified in + :obj:`mmocr.models.textrecog.recognizer.EncodeDecodeRecognizer`. + """ + + def __init__(self, + num_classes=None, + dim_input=512, + dim_model=128, + hybrid_decoder_rnn_layers=2, + hybrid_decoder_dropout=0, + position_decoder_rnn_layers=2, + max_seq_len=40, + start_idx=0, + mask=True, + padding_idx=None, + encode_value=False): + super().__init__() + self.num_classes = num_classes + self.dim_input = dim_input + self.dim_model = dim_model + self.max_seq_len = max_seq_len + self.encode_value = encode_value + self.start_idx = start_idx + self.padding_idx = padding_idx + self.mask = mask + + # init hybrid decoder + self.hybrid_decoder = SequenceAttentionDecoder( + num_classes=num_classes, + rnn_layers=hybrid_decoder_rnn_layers, + dim_input=dim_input, + dim_model=dim_model, + max_seq_len=max_seq_len, + start_idx=start_idx, + mask=mask, + padding_idx=padding_idx, + dropout=hybrid_decoder_dropout, + encode_value=encode_value, + return_feature=True + ) + + # init position decoder + self.position_decoder = PositionAttentionDecoder( + num_classes=num_classes, + rnn_layers=position_decoder_rnn_layers, + dim_input=dim_input, + dim_model=dim_model, + max_seq_len=max_seq_len, + mask=mask, + encode_value=encode_value, + return_feature=True + ) + + + self.fusion_module = RobustScannerFusionLayer( + self.dim_model if encode_value else dim_input) + + pred_num_classes = num_classes - 1 + self.prediction = nn.Linear(dim_model if encode_value else dim_input, + pred_num_classes) + + def forward_train(self, feat, out_enc, target, valid_ratios, word_positions): + """ + Args: + feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`. + out_enc (Tensor): Encoder output of shape + :math:`(N, D_m, H, W)`. + target (dict): A dict with the key ``padded_targets``, a + tensor of shape :math:`(N, T)`. Each element is the index of a + character. + valid_ratios (Tensor): + word_positions (Tensor): The position of each word. + + Returns: + Tensor: A raw logit tensor of shape :math:`(N, T, C-1)`. + """ + hybrid_glimpse = self.hybrid_decoder.forward_train( + feat, out_enc, target, valid_ratios) + position_glimpse = self.position_decoder.forward_train( + feat, out_enc, target, valid_ratios, word_positions) + + fusion_out = self.fusion_module(hybrid_glimpse, position_glimpse) + + out = self.prediction(fusion_out) + + return out + + def forward_test(self, feat, out_enc, valid_ratios, word_positions): + """ + Args: + feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`. + out_enc (Tensor): Encoder output of shape + :math:`(N, D_m, H, W)`. + valid_ratios (Tensor): + word_positions (Tensor): The position of each word. + Returns: + Tensor: The output logit sequence tensor of shape + :math:`(N, T, C-1)`. + """ + seq_len = self.max_seq_len + batch_size = feat.shape[0] + + # decode_sequence = (feat.new_ones( + # (batch_size, seq_len)) * self.start_idx).long() + + decode_sequence = (paddle.ones((batch_size, seq_len), dtype='int64') * self.start_idx) + + position_glimpse = self.position_decoder.forward_test( + feat, out_enc, valid_ratios, word_positions) + + outputs = [] + for i in range(seq_len): + hybrid_glimpse_step = self.hybrid_decoder.forward_test_step( + feat, out_enc, decode_sequence, i, valid_ratios) + + fusion_out = self.fusion_module(hybrid_glimpse_step, + position_glimpse[:, i, :]) + + char_out = self.prediction(fusion_out) + char_out = F.softmax(char_out, -1) + outputs.append(char_out) + max_idx = paddle.argmax(char_out, axis=1, keepdim=False) + if i < seq_len - 1: + decode_sequence[:, i + 1] = max_idx + + outputs = paddle.stack(outputs, 1) + + return outputs + +class RobustScannerHead(nn.Layer): + def __init__(self, + out_channels, # 90 + unknown + start + padding + in_channels, + enc_outchannles=128, + hybrid_dec_rnn_layers=2, + hybrid_dec_dropout=0, + position_dec_rnn_layers=2, + start_idx=0, + max_seq_len=40, + mask=True, + padding_idx=None, + encode_value=False, + **kwargs): + super(RobustScannerHead, self).__init__() + + # encoder module + self.encoder = ChannelReductionEncoder( + in_channels=in_channels, out_channels=enc_outchannles) + + # decoder module + self.decoder =RobustScannerDecoder( + num_classes=out_channels, + dim_input=in_channels, + dim_model=enc_outchannles, + hybrid_decoder_rnn_layers=hybrid_dec_rnn_layers, + hybrid_decoder_dropout=hybrid_dec_dropout, + position_decoder_rnn_layers=position_dec_rnn_layers, + max_seq_len=max_seq_len, + start_idx=start_idx, + mask=mask, + padding_idx=padding_idx, + encode_value=encode_value) + + def forward(self, inputs, targets=None): + ''' + targets: [label, valid_ratio, word_positions] + ''' + out_enc = self.encoder(inputs) + valid_ratios = None + word_positions = targets[-1] + + if len(targets) > 1: + valid_ratios = targets[-2] + + if self.training: + label = targets[0] # label + label = paddle.to_tensor(label, dtype='int64') + final_out = self.decoder( + inputs, out_enc, label, valid_ratios, word_positions) + if not self.training: + final_out = self.decoder( + inputs, + out_enc, + label=None, + valid_ratios=valid_ratios, + word_positions=word_positions, + train_mode=False) + return final_out diff --git a/test_tipc/configs/rec_r31_robustscanner/rec_r31_robustscanner.yml b/test_tipc/configs/rec_r31_robustscanner/rec_r31_robustscanner.yml new file mode 100644 index 0000000000000000000000000000000000000000..20ec9be9662c1fbf13f20ca5cc6451b8ad8a5da6 --- /dev/null +++ b/test_tipc/configs/rec_r31_robustscanner/rec_r31_robustscanner.yml @@ -0,0 +1,110 @@ +Global: + use_gpu: true + epoch_num: 5 + log_smooth_window: 20 + print_batch_step: 20 + save_model_dir: ./output/rec/rec_r31_robustscanner/ + save_epoch_step: 1 + # evaluation is run every 2000 iterations + eval_batch_step: [0, 2000] + cal_metric_during_train: True + pretrained_model: + checkpoints: + save_inference_dir: + use_visualdl: False + infer_img: + # for data or label process + character_dict_path: ppocr/utils/dict90.txt + max_text_length: 40 + infer_mode: False + use_space_char: False + rm_symbol: True + save_res_path: ./output/rec/predicts_robustscanner.txt + +Optimizer: + name: Adam + beta1: 0.9 + beta2: 0.999 + lr: + name: Piecewise + decay_epochs: [3, 4] + values: [0.001, 0.0001, 0.00001] + regularizer: + name: 'L2' + factor: 0 + +Architecture: + model_type: rec + algorithm: RobustScanner + Transform: + Backbone: + name: ResNet31V2 + Head: + name: RobustScannerHead + enc_outchannles: 128 + hybrid_dec_rnn_layers: 2 + hybrid_dec_dropout: 0 + position_dec_rnn_layers: 2 + start_idx: 91 + mask: True + padding_idx: 92 + encode_value: False + max_seq_len: 40 + +Loss: + name: SARLoss + +PostProcess: + name: SARLabelDecode + +Metric: + name: RecMetric + is_filter: True + + +Train: + dataset: + name: SimpleDataSet + data_dir: ./train_data/ic15_data/ + label_file_list: ["./train_data/ic15_data/rec_gt_train.txt"] + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - SARLabelEncode: # Class handling label + - RobustScannerRecResizeImg: + image_shape: [3, 48, 48, 160] # h:48 w:[48,160] + width_downsample_ratio: 0.25 + max_seq_len: 40 + - KeepKeys: + keep_keys: ['image', 'label', 'valid_ratio', 'word_positons'] # dataloader will return list in this order + loader: + shuffle: True + batch_size_per_card: 16 + drop_last: True + num_workers: 0 + use_shared_memory: False + +Eval: + dataset: + name: SimpleDataSet + data_dir: ./train_data/ic15_data + label_file_list: ["./train_data/ic15_data/rec_gt_test.txt"] + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - SARLabelEncode: # Class handling label + - RobustScannerRecResizeImg: + image_shape: [3, 48, 48, 160] + max_seq_len: 40 + width_downsample_ratio: 0.25 + - KeepKeys: + keep_keys: ['image', 'label', 'valid_ratio', 'word_positons'] # dataloader will return list in this order + loader: + shuffle: False + drop_last: False + batch_size_per_card: 16 + num_workers: 0 + use_shared_memory: False + diff --git a/test_tipc/configs/rec_r31_robustscanner/train_infer_python.txt b/test_tipc/configs/rec_r31_robustscanner/train_infer_python.txt new file mode 100644 index 0000000000000000000000000000000000000000..2f58d8f3e407bd3bc4431c7f56a81a42cae44584 --- /dev/null +++ b/test_tipc/configs/rec_r31_robustscanner/train_infer_python.txt @@ -0,0 +1,52 @@ +===========================train_params=========================== +model_name:rec_r31_robustscanner +python:python +gpu_list:0|0,1 +Global.use_gpu:True|True +Global.auto_cast:null +Global.epoch_num:lite_train_lite_infer=2|whole_train_whole_infer=5 +Global.save_model_dir:./output/ +Train.loader.batch_size_per_card:lite_train_lite_infer=16|whole_train_whole_infer=64 +Global.pretrained_model:null +train_model_name:latest +train_infer_img_dir:./inference/rec_inference +null:null +## +trainer:norm_train +norm_train:tools/train.py -c test_tipc/configs/rec_r31_robustscanner/rec_r31_robustscanner.yml -o +pact_train:null +fpgm_train:null +distill_train:null +null:null +null:null +## +===========================eval_params=========================== +eval:tools/eval.py -c test_tipc/configs/rec_r31_robustscanner/rec_r31_robustscanner.yml -o +null:null +## +===========================infer_params=========================== +Global.save_inference_dir:./output/ +Global.checkpoints: +norm_export:tools/export_model.py -c test_tipc/configs/rec_r31_robustscanner/rec_r31_robustscanner.yml -o +quant_export:null +fpgm_export:null +distill_export:null +export1:null +export2:null +## +train_model:./inference/rec_r31_robustscanner/best_accuracy +infer_export:tools/export_model.py -c test_tipc/configs/rec_r31_robustscanner/rec_r31_robustscanner.yml -o +infer_quant:False +inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/dict90.txt --rec_image_shape="3,48,48,160" --use_space_char=False --rec_algorithm="RobustScanner" +--use_gpu:True|False +--enable_mkldnn:True|False +--cpu_threads:1|6 +--rec_batch_num:1|6 +--use_tensorrt:False|False +--precision:fp32|int8 +--rec_model_dir: +--image_dir:./inference/rec_inference +--save_log_path:./test/output/ +--benchmark:True +null:null + diff --git a/tools/eval.py b/tools/eval.py index cab28334396c54f1526f830044de0772b5402a11..6f5189fd68bf3832386f59466c9ccf066fb3b663 100755 --- a/tools/eval.py +++ b/tools/eval.py @@ -73,7 +73,7 @@ def main(): config['Architecture']["Head"]['out_channels'] = char_num model = build_model(config['Architecture']) - extra_input_models = ["SRN", "NRTR", "SAR", "SEED", "SVTR"] + extra_input_models = ["SRN", "NRTR", "SAR", "SEED", "SVTR", "RobustScanner"] extra_input = False if config['Architecture']['algorithm'] == 'Distillation': for key in config['Architecture']["Models"]: diff --git a/tools/export_model.py b/tools/export_model.py index c0cbcd361cec31c51616a7154836c234f076a86e..1a894bdf2c86e94b9fcdae0ae9c1dff32075688f 100755 --- a/tools/export_model.py +++ b/tools/export_model.py @@ -73,6 +73,22 @@ def export_single_model(model, arch_config, save_path, logger, quanter=None): shape=[None, 3, 64, 512], dtype="float32"), ] model = to_static(model, input_spec=other_shape) + elif arch_config["algorithm"] == "RobustScanner": + max_seq_len = arch_config["Head"]["max_seq_len"] + other_shape = [ + paddle.static.InputSpec( + shape=[None, 3, 48, 160], dtype="float32"), + + [ + paddle.static.InputSpec( + shape=[None, ], + dtype="float32"), + paddle.static.InputSpec( + shape=[None, max_seq_len], + dtype="int64") + ] + ] + model = to_static(model, input_spec=other_shape) else: infer_shape = [3, -1, -1] if arch_config["model_type"] == "rec": diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py index 3664ef2caf4b888d6a3918202256c99cc54c5eb1..6647c9ab524bba50f50c47c14ac509f8073b7923 100755 --- a/tools/infer/predict_rec.py +++ b/tools/infer/predict_rec.py @@ -69,6 +69,14 @@ class TextRecognizer(object): "character_dict_path": args.rec_char_dict_path, "use_space_char": args.use_space_char } + elif self.rec_algorithm == "RobustScanner": + postprocess_params = { + 'name': 'SARLabelDecode', + "character_dict_path": args.rec_char_dict_path, + "use_space_char": args.use_space_char, + "rm_symbol": True + + } self.postprocess_op = build_post_process(postprocess_params) self.predictor, self.input_tensor, self.output_tensors, self.config = \ utility.create_predictor(args, 'rec', logger) @@ -266,7 +274,8 @@ class TextRecognizer(object): for beg_img_no in range(0, img_num, batch_num): end_img_no = min(img_num, beg_img_no + batch_num) norm_img_batch = [] - imgC, imgH, imgW = self.rec_image_shape + # imgC, imgH, imgW = self.rec_image_shape + imgH, imgW = self.rec_image_shape[-2:] max_wh_ratio = imgW / imgH # max_wh_ratio = 0 for ino in range(beg_img_no, end_img_no): @@ -300,6 +309,18 @@ class TextRecognizer(object): self.rec_image_shape) norm_img = norm_img[np.newaxis, :] norm_img_batch.append(norm_img) + elif self.rec_algorithm == "RobustScanner": + norm_img, _, _, valid_ratio = self.resize_norm_img_sar( + img_list[indices[ino]], self.rec_image_shape, width_downsample_ratio=0.25) + norm_img = norm_img[np.newaxis, :] + valid_ratio = np.expand_dims(valid_ratio, axis=0) + valid_ratios = [] + valid_ratios.append(valid_ratio) + norm_img_batch.append(norm_img) + word_positions_list = [] + word_positions = np.array(range(0, 40)).astype('int64') + word_positions = np.expand_dims(word_positions, axis=0) + word_positions_list.append(word_positions) else: norm_img = self.resize_norm_img(img_list[indices[ino]], max_wh_ratio) @@ -351,6 +372,35 @@ class TextRecognizer(object): norm_img_batch, valid_ratios, ] + if self.use_onnx: + input_dict = {} + input_dict[self.input_tensor.name] = norm_img_batch + outputs = self.predictor.run(self.output_tensors, + input_dict) + preds = outputs[0] + else: + input_names = self.predictor.get_input_names() + for i in range(len(input_names)): + input_tensor = self.predictor.get_input_handle( + input_names[i]) + input_tensor.copy_from_cpu(inputs[i]) + self.predictor.run() + outputs = [] + for output_tensor in self.output_tensors: + output = output_tensor.copy_to_cpu() + outputs.append(output) + if self.benchmark: + self.autolog.times.stamp() + preds = outputs[0] + elif self.rec_algorithm == "RobustScanner": + valid_ratios = np.concatenate(valid_ratios) + word_positions_list = np.concatenate(word_positions_list) + inputs = [ + norm_img_batch, + valid_ratios, + word_positions_list + ] + if self.use_onnx: input_dict = {} input_dict[self.input_tensor.name] = norm_img_batch diff --git a/tools/infer_rec.py b/tools/infer_rec.py index a08fa25b467482da4a2996912ad2cc8cc7c398da..670733cb98d4a31bf3d27630cfe06e0e9e37114b 100755 --- a/tools/infer_rec.py +++ b/tools/infer_rec.py @@ -96,6 +96,8 @@ def main(): ] elif config['Architecture']['algorithm'] == "SAR": op[op_name]['keep_keys'] = ['image', 'valid_ratio'] + elif config['Architecture']['algorithm'] == "RobustScanner": + op[op_name]['keep_keys'] = ['image', 'valid_ratio', 'word_positons'] else: op[op_name]['keep_keys'] = ['image'] transforms.append(op) @@ -131,6 +133,12 @@ def main(): if config['Architecture']['algorithm'] == "SAR": valid_ratio = np.expand_dims(batch[-1], axis=0) img_metas = [paddle.to_tensor(valid_ratio)] + if config['Architecture']['algorithm'] == "RobustScanner": + valid_ratio = np.expand_dims(batch[1], axis=0) + word_positons = np.expand_dims(batch[2], axis=0) + img_metas = [paddle.to_tensor(valid_ratio), + paddle.to_tensor(word_positons), + ] images = np.expand_dims(batch[0], axis=0) images = paddle.to_tensor(images) @@ -138,6 +146,8 @@ def main(): preds = model(images, others) elif config['Architecture']['algorithm'] == "SAR": preds = model(images, img_metas) + elif config['Architecture']['algorithm'] == "RobustScanner": + preds = model(images, img_metas) else: preds = model(images) post_result = post_process_class(preds) diff --git a/tools/program.py b/tools/program.py index 7c02dc0149f36085ef05ca378b79d27e92d6dd57..d188174992fdb5008998cc42dc86ab1977ab9640 100755 --- a/tools/program.py +++ b/tools/program.py @@ -202,7 +202,7 @@ def train(config, model.train() use_srn = config['Architecture']['algorithm'] == "SRN" - extra_input_models = ["SRN", "NRTR", "SAR", "SEED", "SVTR"] + extra_input_models = ["SRN", "NRTR", "SAR", "SEED", "SVTR", "RobustScanner"] extra_input = False if config['Architecture']['algorithm'] == 'Distillation': for key in config['Architecture']["Models"]: @@ -559,7 +559,8 @@ def preprocess(is_train=False): assert alg in [ 'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN', 'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE', - 'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'PREN', 'FCE', 'SVTR' + 'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'PREN', 'FCE', 'SVTR', + 'RobustScanner' ] device = 'cpu'