diff --git a/configs/rec/rec_r31_robustscanner.yml b/configs/rec/rec_r31_robustscanner.yml
index 4d2ae57b79ca9a34b8137d82c0a2293e2273fe74..3e1d3009ce42981a191afbbc9375b9898a9c3004 100644
--- a/configs/rec/rec_r31_robustscanner.yml
+++ b/configs/rec/rec_r31_robustscanner.yml
@@ -15,7 +15,7 @@ Global:
   infer_img: ./inference/rec_inference
   # for data or label process
   character_dict_path: ppocr/utils/dict90.txt
-  max_text_length: 40
+  max_text_length: &max_text_length 40
   infer_mode: False
   use_space_char: False
   rm_symbol: True
@@ -38,7 +38,7 @@ Architecture:
   algorithm: RobustScanner
   Transform:
   Backbone:
-    name: ResNet31V2
+    name: ResNet31
   Head:
     name: RobustScannerHead
     enc_outchannles: 128
@@ -49,7 +49,7 @@ Architecture:
     mask: True
     padding_idx: 92
     encode_value: False
-    max_seq_len: 40
+    max_text_length: *max_text_length
 
 Loss:
   name: SARLoss
@@ -64,8 +64,9 @@ Metric:
 
 Train:
   dataset:
-    name: LMDBDataSet
-    data_dir: I:/dataset/OCR/deep_text_recognition/data_lmdb/evaluation/CUTE80
+    name: SimpleDataSet
+    label_file_list: ['./train_data/train_list.txt']
+    data_dir: ./train_data/
     transforms:
       - DecodeImage: # load image
           img_mode: BGR
@@ -74,20 +75,20 @@ Train:
       - RobustScannerRecResizeImg:
           image_shape: [3, 48, 48, 160] # h:48 w:[48,160]
           width_downsample_ratio: 0.25
-          max_seq_len: 40
+          max_text_length: *max_text_length
       - KeepKeys:
           keep_keys: ['image', 'label', 'valid_ratio', 'word_positons'] # dataloader will return list in this order
   loader:
     shuffle: True
-    batch_size_per_card: 4
+    batch_size_per_card: 64
     drop_last: True
-    num_workers: 0
+    num_workers: 8
     use_shared_memory: False
 
 Eval:
   dataset:
     name: LMDBDataSet
-    data_dir: I:/dataset/OCR/deep_text_recognition/data_lmdb/evaluation/CUTE80
+    data_dir: ./train_data/data_lmdb_release/evaluation/
     transforms:
       - DecodeImage: # load image
           img_mode: BGR
@@ -95,14 +96,14 @@
       - SARLabelEncode: # Class handling label
       - RobustScannerRecResizeImg:
           image_shape: [3, 48, 48, 160]
-          max_seq_len: 40
+          max_text_length: *max_text_length
           width_downsample_ratio: 0.25
       - KeepKeys:
           keep_keys: ['image', 'label', 'valid_ratio', 'word_positons'] # dataloader will return list in this order
   loader:
     shuffle: False
     drop_last: False
-    batch_size_per_card: 1
-    num_workers: 0
+    batch_size_per_card: 64
+    num_workers: 4
     use_shared_memory: False
diff --git a/doc/doc_ch/algorithm_overview.md b/doc/doc_ch/algorithm_overview.md
index 056d05ed413fd54597ba680a5206304eb5a3989a..c32d5a77b64599e7817f620a2dec6b1724162ec0 100755
--- a/doc/doc_ch/algorithm_overview.md
+++ b/doc/doc_ch/algorithm_overview.md
@@ -85,7 +85,7 @@
 |SAR|Resnet31| 87.20% | rec_r31_sar | [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_r31_sar_train.tar) |
 |SEED|Aster_Resnet| 85.35% | rec_resnet_stn_bilstm_att | [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_resnet_stn_bilstm_att.tar) |
 |SVTR|SVTR-Tiny| 89.25% | rec_svtr_tiny_none_ctc_en | [训练模型](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/rec_svtr_tiny_none_ctc_en_train.tar) |
-|RobustScanner|ResNet31V2| 87.77% | rec_r31_robustscanner | [训练模型]() |
+|RobustScanner|ResNet31| 87.77% | rec_r31_robustscanner | coming soon |
diff --git a/doc/doc_ch/algorithm_rec_robustscanner.md b/doc/doc_ch/algorithm_rec_robustscanner.md
index f504cf5f6e2a1e580cf092ac07daee80396097e0..869f9a7c00b617de87ab3c96326e18e536bc18a8 100644
--- a/doc/doc_ch/algorithm_rec_robustscanner.md
+++ b/doc/doc_ch/algorithm_rec_robustscanner.md
@@ -26,7 +26,7 @@ Zhang
 |模型|骨干网络|配置文件|Acc|下载链接|
 | --- | --- | --- | --- | --- |
-|RobustScanner|ResNet31V2|[rec_r31_robustscanner.yml](../../configs/rec/rec_r31_robustscanner.yml)|87.77%|[训练模型]()|
+|RobustScanner|ResNet31|[rec_r31_robustscanner.yml](../../configs/rec/rec_r31_robustscanner.yml)|87.77%|coming soon|
 
 注:除了使用MJSynth和SynthText两个文字识别数据集外,还加入了[SynthAdd](https://pan.baidu.com/share/init?surl=uV0LtoNmcxbO-0YA7Ch4dg)数据(提取码:627x),和部分真实数据,具体数据细节可以参考论文。
 
@@ -71,7 +71,7 @@ python3 tools/infer_rec.py -c configs/rec/rec_r31_robustscanner.yml -o Global.pr
 
 ### 4.1 Python推理
 
-首先将RobustScanner文本识别训练过程中保存的模型,转换成inference model。( [模型下载地址]() ),可以使用如下命令进行转换:
+首先将RobustScanner文本识别训练过程中保存的模型,转换成inference model。可以使用如下命令进行转换:
 ```
 python3 tools/export_model.py -c configs/rec/rec_r31_robustscanner.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.save_inference_dir=./inference/rec_r31_robustscanner
 ```
@@ -85,7 +85,7 @@ python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/en/word_1.png"
 
 ### 4.2 C++推理
 
-由于C++预处理后处理还未支持SAR,所以暂未支持
+由于C++预处理后处理还未支持RobustScanner,所以暂未支持
 
 ### 4.3 Serving服务化部署
@@ -104,11 +104,10 @@ python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/en/word_1.png"
 ## 引用
 
 ```bibtex
-@article{Li2019ShowAA,
-  title={Show, Attend and Read: A Simple and Strong Baseline for Irregular Text Recognition},
-  author={Hui Li and Peng Wang and Chunhua Shen and Guyu Zhang},
-  journal={ArXiv},
-  year={2019},
-  volume={abs/1811.00751}
+@inproceedings{yue2020robustscanner,
+  title={RobustScanner: Dynamically Enhancing Positional Clues for Robust Text Recognition},
+  author={Xiaoyu Yue and Zhanghui Kuang and Chenhao Lin and Hongbin Sun and Wayne Zhang},
+  booktitle={European Conference on Computer Vision (ECCV)},
+  year={2020},
 }
 ```
diff --git a/doc/doc_en/algorithm_rec_robustscanner_en.md b/doc/doc_en/algorithm_rec_robustscanner_en.md
index 9b6c772ac41adfa2c81e0b3f960c858ba1d17d7f..a5454476de9bfdbdb099157d1fa314be09e9bc0a 100644
--- a/doc/doc_en/algorithm_rec_robustscanner_en.md
+++ b/doc/doc_en/algorithm_rec_robustscanner_en.md
@@ -1,4 +1,4 @@
-# SAR
+# RobustScanner
 
 - [1. Introduction](#1)
 - [2. Environment](#2)
@@ -26,7 +26,7 @@ Using MJSynth and SynthText two text recognition datasets for training, and eval
 
 |Model|Backbone|config|Acc|Download link|
 | --- | --- | --- | --- | --- |
-|RobustScanner|ResNet31V2|[rec_r31_robustscanner.yml](../../configs/rec/rec_r31_robustscanner.yml)|87.77%|[train model]()|
+|RobustScanner|ResNet31|[rec_r31_robustscanner.yml](../../configs/rec/rec_r31_robustscanner.yml)|87.77%|coming soon|
 
 Note:In addition to using the two text recognition datasets MJSynth and SynthText, [SynthAdd](https://pan.baidu.com/share/init?surl=uV0LtoNmcxbO-0YA7Ch4dg) data (extraction code: 627x), and some real data are used in training, the specific data details can refer to the paper.
 
@@ -71,7 +71,7 @@ python3 tools/infer_rec.py -c configs/rec/rec_r31_robustscanner.yml -o Global.pr
 
 ### 4.1 Python Inference
 
-First, the model saved during the RobustScanner text recognition training process is converted into an inference model. ( [Model download link]() ), you can use the following command to convert:
+First, the model saved during the RobustScanner text recognition training process needs to be converted into an inference model. You can use the following command to convert it:
 ```
 python3 tools/export_model.py -c configs/rec/rec_r31_robustscanner.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.save_inference_dir=./inference/rec_r31_robustscanner
 ```
@@ -105,11 +105,10 @@ Not supported
 
 ## Citation
 
 ```bibtex
-@article{Li2019ShowAA,
-  title={Show, Attend and Read: A Simple and Strong Baseline for Irregular Text Recognition},
-  author={Hui Li and Peng Wang and Chunhua Shen and Guyu Zhang},
-  journal={ArXiv},
-  year={2019},
-  volume={abs/1811.00751}
+@inproceedings{yue2020robustscanner,
+  title={RobustScanner: Dynamically Enhancing Positional Clues for Robust Text Recognition},
+  author={Xiaoyu Yue and Zhanghui Kuang and Chenhao Lin and Hongbin Sun and Wayne Zhang},
+  booktitle={European Conference on Computer Vision (ECCV)},
+  year={2020},
 }
 ```
diff --git a/ppocr/data/imaug/rec_img_aug.py b/ppocr/data/imaug/rec_img_aug.py
index b97b78ce4402a96fb8cdf2117a68e377e5d7b1ca..aa4523329d3881a4cadb185f00beea38bf109cd3 100644
--- a/ppocr/data/imaug/rec_img_aug.py
+++ b/ppocr/data/imaug/rec_img_aug.py
@@ -268,16 +268,16 @@ class PRENResizeImg(object):
         return data
 
 class RobustScannerRecResizeImg(object):
-    def __init__(self, image_shape, max_seq_len, width_downsample_ratio=0.25, **kwargs):
+    def __init__(self, image_shape, max_text_length, width_downsample_ratio=0.25, **kwargs):
         self.image_shape = image_shape
         self.width_downsample_ratio = width_downsample_ratio
-        self.max_seq_len = max_seq_len
+        self.max_text_length = max_text_length
 
     def __call__(self, data):
         img = data['image']
         norm_img, resize_shape, pad_shape, valid_ratio = resize_norm_img_sar(
             img, self.image_shape, self.width_downsample_ratio)
-        word_positons = robustscanner_other_inputs(self.max_seq_len)
+        word_positons = np.array(range(0, self.max_text_length)).astype('int64')
         data['image'] = norm_img
         data['resized_shape'] = resize_shape
         data['pad_shape'] = pad_shape
@@ -429,9 +429,6 @@ def srn_other_inputs(image_shape, num_heads, max_text_length):
         gsrm_slf_attn_bias2
     ]
 
-def robustscanner_other_inputs(max_text_length):
-    word_pos = np.array(range(0, max_text_length)).astype('int64')
-    return word_pos
 
 def flag():
     """
diff --git a/ppocr/modeling/backbones/__init__.py b/ppocr/modeling/backbones/__init__.py
index a90051f1dc52be8a62bada9ddd8b6e7f6ff6f1a8..0cc894dbfbb44e7433d9e07a41ce2b9f5a6f4bca 100755
--- a/ppocr/modeling/backbones/__init__.py
+++ b/ppocr/modeling/backbones/__init__.py
@@ -28,7 +28,6 @@ def build_backbone(config, model_type):
         from .rec_mv1_enhance import MobileNetV1Enhance
         from .rec_nrtr_mtb import MTB
         from .rec_resnet_31 import ResNet31
-        from .rec_resnet_31_v2 import ResNet31V2
         from .rec_resnet_aster import ResNet_ASTER
         from .rec_micronet import MicroNet
         from .rec_efficientb3_pren import EfficientNetb3_PREN
diff --git a/ppocr/modeling/backbones/rec_resnet_31.py b/ppocr/modeling/backbones/rec_resnet_31.py
index 965170138d00a53fca720b3b5f535a3dd34272d9..e1d77c405dfed1541f6a0197af14d4edeb908803 100644
--- a/ppocr/modeling/backbones/rec_resnet_31.py
+++ b/ppocr/modeling/backbones/rec_resnet_31.py
@@ -27,9 +27,12 @@
 import paddle.nn as nn
 import paddle.nn.functional as F
 import numpy as np
 
 __all__ = ["ResNet31"]
 
+conv_weight_attr = nn.initializer.KaimingNormal()
+bn_weight_attr = ParamAttr(initializer=nn.initializer.Uniform(), learning_rate=1)
+
 def conv3x3(in_channel, out_channel, stride=1):
     return nn.Conv2D(
         in_channel,
@@ -37,6 +40,7 @@ def conv3x3(in_channel, out_channel, stride=1):
         kernel_size=3,
         stride=stride,
         padding=1,
+        weight_attr=conv_weight_attr,
         bias_attr=False)
 
 
@@ -46,10 +50,10 @@ class BasicBlock(nn.Layer):
     def __init__(self, in_channels, channels, stride=1, downsample=False):
         super().__init__()
         self.conv1 = conv3x3(in_channels, channels, stride)
-        self.bn1 = nn.BatchNorm2D(channels)
+        self.bn1 = nn.BatchNorm2D(channels, weight_attr=bn_weight_attr)
         self.relu = nn.ReLU()
         self.conv2 = conv3x3(channels, channels)
-        self.bn2 = nn.BatchNorm2D(channels)
+        self.bn2 = nn.BatchNorm2D(channels, weight_attr=bn_weight_attr)
         self.downsample = downsample
         if downsample:
             self.downsample = nn.Sequential(
@@ -58,8 +62,9 @@ class BasicBlock(nn.Layer):
                     channels * self.expansion,
                     1,
                     stride,
+                    weight_attr=conv_weight_attr,
                     bias_attr=False),
-                nn.BatchNorm2D(channels * self.expansion), )
+                nn.BatchNorm2D(channels * self.expansion, weight_attr=bn_weight_attr))
         else:
             self.downsample = nn.Sequential()
         self.stride = stride
@@ -108,13 +113,13 @@ class ResNet31(nn.Layer):
 
         # conv 1 (Conv Conv)
         self.conv1_1 = nn.Conv2D(
-            in_channels, channels[0], kernel_size=3, stride=1, padding=1)
-        self.bn1_1 = nn.BatchNorm2D(channels[0])
+            in_channels, channels[0], kernel_size=3, stride=1, padding=1, weight_attr=conv_weight_attr)
+        self.bn1_1 = nn.BatchNorm2D(channels[0], weight_attr=bn_weight_attr)
         self.relu1_1 = nn.ReLU()
 
         self.conv1_2 = nn.Conv2D(
-            channels[0], channels[1], kernel_size=3, stride=1, padding=1)
-        self.bn1_2 = nn.BatchNorm2D(channels[1])
+            channels[0], channels[1], kernel_size=3, stride=1, padding=1, weight_attr=conv_weight_attr)
+        self.bn1_2 = nn.BatchNorm2D(channels[1], weight_attr=bn_weight_attr)
         self.relu1_2 = nn.ReLU()
 
         # conv 2 (Max-pooling, Residual block, Conv)
@@ -122,8 +127,8 @@ class ResNet31(nn.Layer):
             kernel_size=2, stride=2, padding=0, ceil_mode=True)
         self.block2 = self._make_layer(channels[1], channels[2], layers[0])
         self.conv2 = nn.Conv2D(
-            channels[2], channels[2], kernel_size=3, stride=1, padding=1)
-        self.bn2 = nn.BatchNorm2D(channels[2])
+            channels[2], channels[2], kernel_size=3, stride=1, padding=1, weight_attr=conv_weight_attr)
+        self.bn2 = nn.BatchNorm2D(channels[2], weight_attr=bn_weight_attr)
         self.relu2 = nn.ReLU()
 
         # conv 3 (Max-pooling, Residual block, Conv)
@@ -131,8 +136,8 @@ class ResNet31(nn.Layer):
             kernel_size=2, stride=2, padding=0, ceil_mode=True)
         self.block3 = self._make_layer(channels[2], channels[3], layers[1])
         self.conv3 = nn.Conv2D(
-            channels[3], channels[3], kernel_size=3, stride=1, padding=1)
-        self.bn3 = nn.BatchNorm2D(channels[3])
+            channels[3], channels[3], kernel_size=3, stride=1, padding=1, weight_attr=conv_weight_attr)
+        self.bn3 = nn.BatchNorm2D(channels[3], weight_attr=bn_weight_attr)
         self.relu3 = nn.ReLU()
 
         # conv 4 (Max-pooling, Residual block, Conv)
@@ -140,8 +145,8 @@ class ResNet31(nn.Layer):
             kernel_size=(2, 1), stride=(2, 1), padding=0, ceil_mode=True)
         self.block4 = self._make_layer(channels[3], channels[4], layers[2])
         self.conv4 = nn.Conv2D(
-            channels[4], channels[4], kernel_size=3, stride=1, padding=1)
-        self.bn4 = nn.BatchNorm2D(channels[4])
+            channels[4], channels[4], kernel_size=3, stride=1, padding=1, weight_attr=conv_weight_attr)
+        self.bn4 = nn.BatchNorm2D(channels[4], weight_attr=bn_weight_attr)
         self.relu4 = nn.ReLU()
 
         # conv 5 ((Max-pooling), Residual block, Conv)
@@ -151,8 +156,8 @@ class ResNet31(nn.Layer):
             kernel_size=2, stride=2, padding=0, ceil_mode=True)
         self.block5 = self._make_layer(channels[4], channels[5], layers[3])
         self.conv5 = nn.Conv2D(
-            channels[5], channels[5], kernel_size=3, stride=1, padding=1)
-        self.bn5 = nn.BatchNorm2D(channels[5])
+            channels[5], channels[5], kernel_size=3, stride=1, padding=1, weight_attr=conv_weight_attr)
+        self.bn5 = nn.BatchNorm2D(channels[5], weight_attr=bn_weight_attr)
         self.relu5 = nn.ReLU()
 
         self.out_channels = channels[-1]
@@ -168,8 +173,9 @@ class ResNet31(nn.Layer):
                     output_channels,
                     kernel_size=1,
                     stride=1,
+                    weight_attr=conv_weight_attr,
                     bias_attr=False),
-                nn.BatchNorm2D(output_channels), )
+                nn.BatchNorm2D(output_channels, weight_attr=bn_weight_attr))
 
         layers.append(
             BasicBlock(
diff --git a/ppocr/modeling/backbones/rec_resnet_31_v2.py b/ppocr/modeling/backbones/rec_resnet_31_v2.py
deleted file mode 100644
index 7812b6296e33fc1f193dab88a7df788a6aa581d3..0000000000000000000000000000000000000000
--- a/ppocr/modeling/backbones/rec_resnet_31_v2.py
+++ /dev/null
@@ -1,216 +0,0 @@
-# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""
-This code is refer from:
-https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textrecog/layers/conv_layer.py
-https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textrecog/backbones/resnet31_ocr.py
-"""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import paddle
-from paddle import ParamAttr
-import paddle.nn as nn
-import paddle.nn.functional as F
-import numpy as np
-
-__all__ = ["ResNet31V2"]
-
-
-conv_weight_attr = nn.initializer.KaimingNormal()
-bn_weight_attr = ParamAttr(initializer=nn.initializer.Uniform(), learning_rate=1)
-
-def conv3x3(in_channel, out_channel, stride=1):
-    return nn.Conv2D(
-        in_channel,
-        out_channel,
-        kernel_size=3,
-        stride=stride,
-        padding=1,
-        weight_attr=conv_weight_attr,
-        bias_attr=False)
-
-
-class BasicBlock(nn.Layer):
-    expansion = 1
-
-    def __init__(self, in_channels, channels, stride=1, downsample=False):
-        super().__init__()
-        self.conv1 = conv3x3(in_channels, channels, stride)
-        self.bn1 = nn.BatchNorm2D(channels, weight_attr=bn_weight_attr)
-        self.relu = nn.ReLU()
-        self.conv2 = conv3x3(channels, channels)
-        self.bn2 = nn.BatchNorm2D(channels, weight_attr=bn_weight_attr)
-        self.downsample = downsample
-        if downsample:
-            self.downsample = nn.Sequential(
-                nn.Conv2D(
-                    in_channels,
-                    channels * self.expansion,
-                    1,
-                    stride,
-                    weight_attr=conv_weight_attr,
-                    bias_attr=False),
-                nn.BatchNorm2D(channels * self.expansion, weight_attr=bn_weight_attr))
-        else:
-            self.downsample = nn.Sequential()
-        self.stride = stride
-
-    def forward(self, x):
-        residual = x
-
-        out = self.conv1(x)
-        out = self.bn1(out)
-        out = self.relu(out)
-
-        out = self.conv2(out)
-        out = self.bn2(out)
-
-        if self.downsample:
-            residual = self.downsample(x)
-
-        out += residual
-        out = self.relu(out)
-
-        return out
-
-
-class ResNet31V2(nn.Layer):
-    '''
-    Args:
-        in_channels (int): Number of channels of input image tensor.
-        layers (list[int]): List of BasicBlock number for each stage.
-        channels (list[int]): List of out_channels of Conv2d layer.
-        out_indices (None | Sequence[int]): Indices of output stages.
-        last_stage_pool (bool): If True, add `MaxPool2d` layer to last stage.
-    '''
-
-    def __init__(self,
-                 in_channels=3,
-                 layers=[1, 2, 5, 3],
-                 channels=[64, 128, 256, 256, 512, 512, 512],
-                 out_indices=None,
-                 last_stage_pool=False):
-        super(ResNet31V2, self).__init__()
-        assert isinstance(in_channels, int)
-        assert isinstance(last_stage_pool, bool)
-
-        self.out_indices = out_indices
-        self.last_stage_pool = last_stage_pool
-
-        # conv 1 (Conv Conv)
-        self.conv1_1 = nn.Conv2D(
-            in_channels, channels[0], kernel_size=3, stride=1, padding=1, weight_attr=conv_weight_attr)
-        self.bn1_1 = nn.BatchNorm2D(channels[0], weight_attr=bn_weight_attr)
-        self.relu1_1 = nn.ReLU()
-
-        self.conv1_2 = nn.Conv2D(
-            channels[0], channels[1], kernel_size=3, stride=1, padding=1, weight_attr=conv_weight_attr)
-        self.bn1_2 = nn.BatchNorm2D(channels[1], weight_attr=bn_weight_attr)
-        self.relu1_2 = nn.ReLU()
-
-        # conv 2 (Max-pooling, Residual block, Conv)
-        self.pool2 = nn.MaxPool2D(
-            kernel_size=2, stride=2, padding=0, ceil_mode=True)
-        self.block2 = self._make_layer(channels[1], channels[2], layers[0])
-        self.conv2 = nn.Conv2D(
-            channels[2], channels[2], kernel_size=3, stride=1, padding=1, weight_attr=conv_weight_attr)
-        self.bn2 = nn.BatchNorm2D(channels[2], weight_attr=bn_weight_attr)
-        self.relu2 = nn.ReLU()
-
-        # conv 3 (Max-pooling, Residual block, Conv)
-        self.pool3 = nn.MaxPool2D(
-            kernel_size=2, stride=2, padding=0, ceil_mode=True)
-        self.block3 = self._make_layer(channels[2], channels[3], layers[1])
-        self.conv3 = nn.Conv2D(
-            channels[3], channels[3], kernel_size=3, stride=1, padding=1, weight_attr=conv_weight_attr)
-        self.bn3 = nn.BatchNorm2D(channels[3], weight_attr=bn_weight_attr)
-        self.relu3 = nn.ReLU()
-
-        # conv 4 (Max-pooling, Residual block, Conv)
-        self.pool4 = nn.MaxPool2D(
-            kernel_size=(2, 1), stride=(2, 1), padding=0, ceil_mode=True)
-        self.block4 = self._make_layer(channels[3], channels[4], layers[2])
-        self.conv4 = nn.Conv2D(
-            channels[4], channels[4], kernel_size=3, stride=1, padding=1, weight_attr=conv_weight_attr)
-        self.bn4 = nn.BatchNorm2D(channels[4], weight_attr=bn_weight_attr)
-        self.relu4 = nn.ReLU()
-
-        # conv 5 ((Max-pooling), Residual block, Conv)
-        self.pool5 = None
-        if self.last_stage_pool:
-            self.pool5 = nn.MaxPool2D(
-                kernel_size=2, stride=2, padding=0, ceil_mode=True)
-        self.block5 = self._make_layer(channels[4], channels[5], layers[3])
-        self.conv5 = nn.Conv2D(
-            channels[5], channels[5], kernel_size=3, stride=1, padding=1, weight_attr=conv_weight_attr)
-        self.bn5 = nn.BatchNorm2D(channels[5], weight_attr=bn_weight_attr)
-        self.relu5 = nn.ReLU()
-
-        self.out_channels = channels[-1]
-
-    def _make_layer(self, input_channels, output_channels, blocks):
-        layers = []
-        for _ in range(blocks):
-            downsample = None
-            if input_channels != output_channels:
-                downsample = nn.Sequential(
-                    nn.Conv2D(
-                        input_channels,
-                        output_channels,
-                        kernel_size=1,
-                        stride=1,
-                        weight_attr=conv_weight_attr,
-                        bias_attr=False),
-                    nn.BatchNorm2D(output_channels, weight_attr=bn_weight_attr))
-
-            layers.append(
-                BasicBlock(
-                    input_channels, output_channels, downsample=downsample))
-            input_channels = output_channels
-        return nn.Sequential(*layers)
-
-    def forward(self, x):
-        x = self.conv1_1(x)
-        x = self.bn1_1(x)
-        x = self.relu1_1(x)
-
-        x = self.conv1_2(x)
-        x = self.bn1_2(x)
-        x = self.relu1_2(x)
-
-        outs = []
-        for i in range(4):
-            layer_index = i + 2
-            pool_layer = getattr(self, f'pool{layer_index}')
-            block_layer = getattr(self, f'block{layer_index}')
-            conv_layer = getattr(self, f'conv{layer_index}')
-            bn_layer = getattr(self, f'bn{layer_index}')
-            relu_layer = getattr(self, f'relu{layer_index}')
-
-            if pool_layer is not None:
-                x = pool_layer(x)
-            x = block_layer(x)
-            x = conv_layer(x)
-            x = bn_layer(x)
-            x = relu_layer(x)
-
-            outs.append(x)
-
-        if self.out_indices is not None:
-            return tuple([outs[i] for i in self.out_indices])
-
-        return x
diff --git a/ppocr/modeling/heads/rec_robustscanner_head.py b/ppocr/modeling/heads/rec_robustscanner_head.py
index b458937978fc4a39f1f14230dff9d8d4fba4cfd9..fc889d59cb558adaf0b422af27fec65db4d8d63f 100644
--- a/ppocr/modeling/heads/rec_robustscanner_head.py
+++ b/ppocr/modeling/heads/rec_robustscanner_head.py
@@ -217,18 +217,7 @@ class SequenceAttentionDecoder(BaseDecoder):
         else:
             value = paddle.reshape(feat, [n, c_feat, h * w])
 
-        # mask = None
-        # if valid_ratios is not None:
-        #     mask = paddle.zeros(shape=[n, len_q, h, w], dtype='bool')
-        #     for i, valid_ratio in enumerate(valid_ratios):
-        #         valid_width = min(w, math.ceil(w * valid_ratio))
-        #         if valid_width < w:
-        #             mask[i, :, :, valid_width:] = True
-        #     # mask = mask.view(n, h * w)
-        #     mask = paddle.reshape(mask, (n, len_q, h * w))
-
         attn_out = self.attention_layer(query, key, value, h, w, valid_ratios)
-        # attn_out = attn_out.permute(0, 2, 1).contiguous()
         attn_out = paddle.transpose(attn_out, (0, 2, 1))
 
         if self.return_feature:
@@ -253,8 +242,6 @@ class SequenceAttentionDecoder(BaseDecoder):
         seq_len = self.max_seq_len
         batch_size = feat.shape[0]
 
-        # decode_sequence = (feat.new_ones(
-        #     (batch_size, seq_len)) * self.start_idx).long()
         decode_sequence = (paddle.ones((batch_size, seq_len), dtype='int64') * self.start_idx)
 
         outputs = []
@@ -303,20 +290,8 @@ class SequenceAttentionDecoder(BaseDecoder):
             value = key
         else:
             value = paddle.reshape(feat, [n, c_feat, h * w])
-        # len_q = query.shape[2]
-        # mask = None
-        # if valid_ratios is not None:
-        #     mask = paddle.zeros(shape=[n, len_q, h, w], dtype='bool')
-        #     for i, valid_ratio in enumerate(valid_ratios):
-        #         valid_width = min(w, math.ceil(w * valid_ratio))
-        #         if valid_width < w:
-        #             mask[i, :, :, valid_width:] = True
-        #     # mask = mask.view(n, h * w)
-        #     mask = paddle.reshape(mask, (n, len_q, h * w))
 
         # [n, c, l]
-        # attn_out = self.attention_layer(query, key, value, mask)
-
         attn_out = self.attention_layer(query, key, value, h, w, valid_ratios)
 
         out = attn_out[:, :, current_step]
@@ -445,7 +420,7 @@ class PositionAttentionDecoder(BaseDecoder):
             before the prediction projection layer, whose shape is
                 :math:`(N, T, D_m)`.
""" - # n, c_enc, h, w = out_enc.shape assert c_enc == self.dim_model _, c_feat, _, _ = feat.shape @@ -453,8 +427,6 @@ class PositionAttentionDecoder(BaseDecoder): _, len_q = targets.shape assert len_q <= self.max_seq_len - # position_index = self._get_position_index(len_q, n) - position_out_enc = self.position_aware_module(out_enc) query = self.embedding(position_index) @@ -465,16 +437,6 @@ class PositionAttentionDecoder(BaseDecoder): else: value = paddle.reshape(feat,(n, c_feat, h * w)) - # mask = None - # if valid_ratios is not None: - # mask = paddle.zeros(shape=[n, len_q, h, w], dtype='bool') - # for i, valid_ratio in enumerate(valid_ratios): - # valid_width = min(w, math.ceil(w * valid_ratio)) - # if valid_width < w: - # mask[i, :, :, valid_width:] = True - # # mask = mask.view(n, h * w) - # mask = paddle.reshape(mask, (n, len_q, h * w)) - attn_out = self.attention_layer(query, key, value, h, w, valid_ratios) attn_out = paddle.transpose(attn_out, (0, 2, 1)) # [n, len_q, dim_v] @@ -498,7 +460,6 @@ class PositionAttentionDecoder(BaseDecoder): before the prediction projection layer, whose shape is :math:`(N, T, D_m)`. """ - # seq_len = self.max_seq_len n, c_enc, h, w = out_enc.shape assert c_enc == self.dim_model _, c_feat, _, _ = feat.shape @@ -516,16 +477,6 @@ class PositionAttentionDecoder(BaseDecoder): value = paddle.reshape(out_enc,(n, c_enc, h * w)) else: value = paddle.reshape(feat,(n, c_feat, h * w)) - # len_q = query.shape[2] - # mask = None - # if valid_ratios is not None: - # mask = paddle.zeros(shape=[n, len_q, h, w], dtype='bool') - # for i, valid_ratio in enumerate(valid_ratios): - # valid_width = min(w, math.ceil(w * valid_ratio)) - # if valid_width < w: - # mask[i, :, :, valid_width:] = True - # # mask = mask.view(n, h * w) - # mask = paddle.reshape(mask, (n, len_q, h * w)) attn_out = self.attention_layer(query, key, value, h, w, valid_ratios) attn_out = paddle.transpose(attn_out, (0, 2, 1)) # [n, len_q, dim_v] @@ -676,9 +627,6 @@ class RobustScannerDecoder(BaseDecoder): seq_len = self.max_seq_len batch_size = feat.shape[0] - # decode_sequence = (feat.new_ones( - # (batch_size, seq_len)) * self.start_idx).long() - decode_sequence = (paddle.ones((batch_size, seq_len), dtype='int64') * self.start_idx) position_glimpse = self.position_decoder.forward_test( @@ -712,7 +660,7 @@ class RobustScannerHead(nn.Layer): hybrid_dec_dropout=0, position_dec_rnn_layers=2, start_idx=0, - max_seq_len=40, + max_text_length=40, mask=True, padding_idx=None, encode_value=False, @@ -731,7 +679,7 @@ class RobustScannerHead(nn.Layer): hybrid_decoder_rnn_layers=hybrid_dec_rnn_layers, hybrid_decoder_dropout=hybrid_dec_dropout, position_decoder_rnn_layers=position_dec_rnn_layers, - max_seq_len=max_seq_len, + max_seq_len=max_text_length, start_idx=start_idx, mask=mask, padding_idx=padding_idx, diff --git a/test_tipc/configs/rec_r31_robustscanner/rec_r31_robustscanner.yml b/test_tipc/configs/rec_r31_robustscanner/rec_r31_robustscanner.yml index 20ec9be9662c1fbf13f20ca5cc6451b8ad8a5da6..a49f332aae0eeb1b7a9d728d37095e5d88c3e6f5 100644 --- a/test_tipc/configs/rec_r31_robustscanner/rec_r31_robustscanner.yml +++ b/test_tipc/configs/rec_r31_robustscanner/rec_r31_robustscanner.yml @@ -15,7 +15,7 @@ Global: infer_img: # for data or label process character_dict_path: ppocr/utils/dict90.txt - max_text_length: 40 + max_text_length: &max_text_length 40 infer_mode: False use_space_char: False rm_symbol: True @@ -38,7 +38,7 @@ Architecture: algorithm: RobustScanner Transform: Backbone: - name: ResNet31V2 + name: 
   Head:
     name: RobustScannerHead
     enc_outchannles: 128
@@ -75,7 +75,7 @@ Train:
       - RobustScannerRecResizeImg:
           image_shape: [3, 48, 48, 160] # h:48 w:[48,160]
           width_downsample_ratio: 0.25
-          max_seq_len: 40
+          max_text_length: *max_text_length
       - KeepKeys:
           keep_keys: ['image', 'label', 'valid_ratio', 'word_positons'] # dataloader will return list in this order
   loader:
@@ -97,7 +97,7 @@ Eval:
       - SARLabelEncode: # Class handling label
       - RobustScannerRecResizeImg:
           image_shape: [3, 48, 48, 160]
-          max_seq_len: 40
+          max_text_length: *max_text_length
           width_downsample_ratio: 0.25
       - KeepKeys:
          keep_keys: ['image', 'label', 'valid_ratio', 'word_positons'] # dataloader will return list in this order
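Two mechanisms in this patch are easy to get wrong and worth sanity-checking in isolation: the YAML anchor that threads one `max_text_length` value from `Global` through the head and the resize transform (every consumer must reference the alias, or the head's positional embedding and the transform's word positions silently drift apart), and the inlined word-position vector that replaces the removed `robustscanner_other_inputs()` helper. The following is a minimal sketch, not part of the patch, assuming only PyYAML and NumPy; the config fragment is an abbreviated stand-in for the real file:

```python
# Standalone sketch (not part of the patch). Assumes PyYAML and NumPy.
import numpy as np
import yaml

# Abbreviated stand-in for configs/rec/rec_r31_robustscanner.yml:
# one anchored value, referenced wherever 40 used to be hard-coded.
cfg = yaml.safe_load("""
Global:
  max_text_length: &max_text_length 40
Architecture:
  Head:
    max_text_length: *max_text_length
Train:
  transforms:
    - RobustScannerRecResizeImg:
        max_text_length: *max_text_length
""")

resize_cfg = cfg["Train"]["transforms"][0]["RobustScannerRecResizeImg"]
head_cfg = cfg["Architecture"]["Head"]

# Every alias resolves to the anchored value, so the decoder length and the
# transform's word positions can no longer disagree.
assert resize_cfg["max_text_length"] == head_cfg["max_text_length"] == 40

# Inlined replacement for the removed robustscanner_other_inputs() helper:
# the word-position input is just the index vector [0, max_text_length).
# (The 'word_positons' spelling follows the repo's existing data key.)
word_positons = np.array(range(0, cfg["Global"]["max_text_length"])).astype("int64")
assert word_positons.shape == (40,) and word_positons.dtype == np.int64
```

The anchor matters because a mismatch between the two lengths only surfaces at runtime, as an out-of-range lookup in the position decoder's embedding; with the alias, changing the single value in `Global` updates every consumer at once.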