diff --git a/configs/rec/rec_r31_robustscanner.yml b/configs/rec/rec_r31_robustscanner.yml
index 4d2ae57b79ca9a34b8137d82c0a2293e2273fe74..3e1d3009ce42981a191afbbc9375b9898a9c3004 100644
--- a/configs/rec/rec_r31_robustscanner.yml
+++ b/configs/rec/rec_r31_robustscanner.yml
@@ -15,7 +15,7 @@ Global:
infer_img: ./inference/rec_inference
# for data or label process
character_dict_path: ppocr/utils/dict90.txt
- max_text_length: 40
+ max_text_length: &max_text_length 40
infer_mode: False
use_space_char: False
rm_symbol: True
@@ -38,7 +38,7 @@ Architecture:
algorithm: RobustScanner
Transform:
Backbone:
- name: ResNet31V2
+ name: ResNet31
Head:
name: RobustScannerHead
enc_outchannles: 128
@@ -49,7 +49,7 @@ Architecture:
mask: True
padding_idx: 92
encode_value: False
- max_seq_len: 40
+ max_text_length: *max_text_length
Loss:
name: SARLoss
@@ -64,8 +64,9 @@ Metric:
Train:
dataset:
- name: LMDBDataSet
- data_dir: I:/dataset/OCR/deep_text_recognition/data_lmdb/evaluation/CUTE80
+ name: SimpleDataSet
+ label_file_list: ['./train_data/train_list.txt']
+ data_dir: ./train_data/
transforms:
- DecodeImage: # load image
img_mode: BGR
@@ -74,20 +75,20 @@ Train:
- RobustScannerRecResizeImg:
image_shape: [3, 48, 48, 160] # h:48 w:[48,160]
width_downsample_ratio: 0.25
- max_seq_len: 40
+ max_text_length: *max_text_length
- KeepKeys:
keep_keys: ['image', 'label', 'valid_ratio', 'word_positons'] # dataloader will return list in this order
loader:
shuffle: True
- batch_size_per_card: 4
+ batch_size_per_card: 64
drop_last: True
- num_workers: 0
+ num_workers: 8
use_shared_memory: False
Eval:
dataset:
name: LMDBDataSet
- data_dir: I:/dataset/OCR/deep_text_recognition/data_lmdb/evaluation/CUTE80
+ data_dir: ./train_data/data_lmdb_release/evaluation/
transforms:
- DecodeImage: # load image
img_mode: BGR
@@ -95,14 +96,14 @@ Eval:
- SARLabelEncode: # Class handling label
- RobustScannerRecResizeImg:
image_shape: [3, 48, 48, 160]
- max_seq_len: 40
+ max_seq_len: *max_text_length
width_downsample_ratio: 0.25
- KeepKeys:
keep_keys: ['image', 'label', 'valid_ratio', 'word_positons'] # dataloader will return list in this order
loader:
shuffle: False
drop_last: False
- batch_size_per_card: 1
- num_workers: 0
+ batch_size_per_card: 64
+ num_workers: 4
use_shared_memory: False
diff --git a/doc/doc_ch/algorithm_overview.md b/doc/doc_ch/algorithm_overview.md
index 056d05ed413fd54597ba680a5206304eb5a3989a..c32d5a77b64599e7817f620a2dec6b1724162ec0 100755
--- a/doc/doc_ch/algorithm_overview.md
+++ b/doc/doc_ch/algorithm_overview.md
@@ -85,7 +85,7 @@
|SAR|Resnet31| 87.20% | rec_r31_sar | [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_r31_sar_train.tar) |
|SEED|Aster_Resnet| 85.35% | rec_resnet_stn_bilstm_att | [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_resnet_stn_bilstm_att.tar) |
|SVTR|SVTR-Tiny| 89.25% | rec_svtr_tiny_none_ctc_en | [训练模型](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/rec_svtr_tiny_none_ctc_en_train.tar) |
-|RobustScanner|ResNet31V2| 87.77% | rec_r31_robustscanner | [训练模型]() |
+|RobustScanner|ResNet31V2| 87.77% | rec_r31_robustscanner | coming soon |
diff --git a/doc/doc_ch/algorithm_rec_robustscanner.md b/doc/doc_ch/algorithm_rec_robustscanner.md
index f504cf5f6e2a1e580cf092ac07daee80396097e0..869f9a7c00b617de87ab3c96326e18e536bc18a8 100644
--- a/doc/doc_ch/algorithm_rec_robustscanner.md
+++ b/doc/doc_ch/algorithm_rec_robustscanner.md
@@ -26,7 +26,7 @@ Zhang
|模型|骨干网络|配置文件|Acc|下载链接|
| --- | --- | --- | --- | --- |
-|RobustScanner|ResNet31V2|[rec_r31_robustscanner.yml](../../configs/rec/rec_r31_robustscanner.yml)|87.77%|[训练模型]()|
+|RobustScanner|ResNet31|[rec_r31_robustscanner.yml](../../configs/rec/rec_r31_robustscanner.yml)|87.77%|coming soon|
注:除了使用MJSynth和SynthText两个文字识别数据集外,还加入了[SynthAdd](https://pan.baidu.com/share/init?surl=uV0LtoNmcxbO-0YA7Ch4dg)数据(提取码:627x),和部分真实数据,具体数据细节可以参考论文。
@@ -71,7 +71,7 @@ python3 tools/infer_rec.py -c configs/rec/rec_r31_robustscanner.yml -o Global.pr
### 4.1 Python推理
-首先将RobustScanner文本识别训练过程中保存的模型,转换成inference model。( [模型下载地址]() ),可以使用如下命令进行转换:
+首先将RobustScanner文本识别训练过程中保存的模型,转换成inference model。可以使用如下命令进行转换:
```
python3 tools/export_model.py -c configs/rec/rec_r31_robustscanner.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.save_inference_dir=./inference/rec_r31_robustscanner
@@ -85,7 +85,7 @@ python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/en/word_1.png"
### 4.2 C++推理
-由于C++预处理后处理还未支持SAR,所以暂未支持
+由于C++预处理后处理还未支持RobustScanner,所以暂未支持
### 4.3 Serving服务化部署
@@ -104,11 +104,10 @@ python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/en/word_1.png"
## 引用
```bibtex
-@article{Li2019ShowAA,
- title={Show, Attend and Read: A Simple and Strong Baseline for Irregular Text Recognition},
- author={Hui Li and Peng Wang and Chunhua Shen and Guyu Zhang},
- journal={ArXiv},
- year={2019},
- volume={abs/1811.00751}
+@article{2020RobustScanner,
+ title={RobustScanner: Dynamically Enhancing Positional Clues for Robust Text Recognition},
+ author={Xiaoyu Yue and Zhanghui Kuang and Chenhao Lin and Hongbin Sun and Wayne Zhang},
+ journal={ECCV2020},
+ year={2020},
}
```
diff --git a/doc/doc_en/algorithm_rec_robustscanner_en.md b/doc/doc_en/algorithm_rec_robustscanner_en.md
index 9b6c772ac41adfa2c81e0b3f960c858ba1d17d7f..a5454476de9bfdbdb099157d1fa314be09e9bc0a 100644
--- a/doc/doc_en/algorithm_rec_robustscanner_en.md
+++ b/doc/doc_en/algorithm_rec_robustscanner_en.md
@@ -1,4 +1,4 @@
-# SAR
+# RobustScanner
- [1. Introduction](#1)
- [2. Environment](#2)
@@ -26,7 +26,7 @@ Using MJSynth and SynthText two text recognition datasets for training, and eval
|Model|Backbone|config|Acc|Download link|
| --- | --- | --- | --- | --- |
-|RobustScanner|ResNet31V2|[rec_r31_robustscanner.yml](../../configs/rec/rec_r31_robustscanner.yml)|87.77%|[train model]()|
+|RobustScanner|ResNet31V2|[rec_r31_robustscanner.yml](../../configs/rec/rec_r31_robustscanner.yml)|87.77%|coming soon|
Note:In addition to using the two text recognition datasets MJSynth and SynthText, [SynthAdd](https://pan.baidu.com/share/init?surl=uV0LtoNmcxbO-0YA7Ch4dg) data (extraction code: 627x), and some real data are used in training, the specific data details can refer to the paper.
@@ -71,7 +71,7 @@ python3 tools/infer_rec.py -c configs/rec/rec_r31_robustscanner.yml -o Global.pr
### 4.1 Python Inference
-First, the model saved during the RobustScanner text recognition training process is converted into an inference model. ( [Model download link]() ), you can use the following command to convert:
+First, the model saved during the RobustScanner text recognition training process is converted into an inference model. you can use the following command to convert:
```
python3 tools/export_model.py -c configs/rec/rec_r31_robustscanner.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.save_inference_dir=./inference/rec_r31_robustscanner
@@ -105,11 +105,10 @@ Not supported
## Citation
```bibtex
-@article{Li2019ShowAA,
- title={Show, Attend and Read: A Simple and Strong Baseline for Irregular Text Recognition},
- author={Hui Li and Peng Wang and Chunhua Shen and Guyu Zhang},
- journal={ArXiv},
- year={2019},
- volume={abs/1811.00751}
+@article{2020RobustScanner,
+ title={RobustScanner: Dynamically Enhancing Positional Clues for Robust Text Recognition},
+ author={Xiaoyu Yue and Zhanghui Kuang and Chenhao Lin and Hongbin Sun and Wayne Zhang},
+ journal={ECCV2020},
+ year={2020},
}
```
diff --git a/ppocr/data/imaug/rec_img_aug.py b/ppocr/data/imaug/rec_img_aug.py
index b97b78ce4402a96fb8cdf2117a68e377e5d7b1ca..aa4523329d3881a4cadb185f00beea38bf109cd3 100644
--- a/ppocr/data/imaug/rec_img_aug.py
+++ b/ppocr/data/imaug/rec_img_aug.py
@@ -268,16 +268,16 @@ class PRENResizeImg(object):
return data
class RobustScannerRecResizeImg(object):
- def __init__(self, image_shape, max_seq_len, width_downsample_ratio=0.25, **kwargs):
+ def __init__(self, image_shape, max_text_length, width_downsample_ratio=0.25, **kwargs):
self.image_shape = image_shape
self.width_downsample_ratio = width_downsample_ratio
- self.max_seq_len = max_seq_len
+ self.max_text_length = max_text_length
def __call__(self, data):
img = data['image']
norm_img, resize_shape, pad_shape, valid_ratio = resize_norm_img_sar(
img, self.image_shape, self.width_downsample_ratio)
- word_positons = robustscanner_other_inputs(self.max_seq_len)
+ word_positons = np.array(range(0, self.max_text_length)).astype('int64')
data['image'] = norm_img
data['resized_shape'] = resize_shape
data['pad_shape'] = pad_shape
@@ -429,9 +429,6 @@ def srn_other_inputs(image_shape, num_heads, max_text_length):
gsrm_slf_attn_bias2
]
-def robustscanner_other_inputs(max_text_length):
- word_pos = np.array(range(0, max_text_length)).astype('int64')
- return word_pos
def flag():
"""
diff --git a/ppocr/modeling/backbones/__init__.py b/ppocr/modeling/backbones/__init__.py
index a90051f1dc52be8a62bada9ddd8b6e7f6ff6f1a8..0cc894dbfbb44e7433d9e07a41ce2b9f5a6f4bca 100755
--- a/ppocr/modeling/backbones/__init__.py
+++ b/ppocr/modeling/backbones/__init__.py
@@ -28,7 +28,6 @@ def build_backbone(config, model_type):
from .rec_mv1_enhance import MobileNetV1Enhance
from .rec_nrtr_mtb import MTB
from .rec_resnet_31 import ResNet31
- from .rec_resnet_31_v2 import ResNet31V2
from .rec_resnet_aster import ResNet_ASTER
from .rec_micronet import MicroNet
from .rec_efficientb3_pren import EfficientNetb3_PREN
diff --git a/ppocr/modeling/backbones/rec_resnet_31.py b/ppocr/modeling/backbones/rec_resnet_31.py
index 965170138d00a53fca720b3b5f535a3dd34272d9..e1d77c405dfed1541f6a0197af14d4edeb908803 100644
--- a/ppocr/modeling/backbones/rec_resnet_31.py
+++ b/ppocr/modeling/backbones/rec_resnet_31.py
@@ -27,9 +27,12 @@ import paddle.nn as nn
import paddle.nn.functional as F
import numpy as np
-__all__ = ["ResNet31"]
+__all__ = ["ResNet31V2"]
+conv_weight_attr = nn.initializer.KaimingNormal()
+bn_weight_attr = ParamAttr(initializer=nn.initializer.Uniform(), learning_rate=1)
+
def conv3x3(in_channel, out_channel, stride=1):
return nn.Conv2D(
in_channel,
@@ -37,6 +40,7 @@ def conv3x3(in_channel, out_channel, stride=1):
kernel_size=3,
stride=stride,
padding=1,
+ weight_attr=conv_weight_attr,
bias_attr=False)
@@ -46,10 +50,10 @@ class BasicBlock(nn.Layer):
def __init__(self, in_channels, channels, stride=1, downsample=False):
super().__init__()
self.conv1 = conv3x3(in_channels, channels, stride)
- self.bn1 = nn.BatchNorm2D(channels)
+ self.bn1 = nn.BatchNorm2D(channels, weight_attr=bn_weight_attr)
self.relu = nn.ReLU()
self.conv2 = conv3x3(channels, channels)
- self.bn2 = nn.BatchNorm2D(channels)
+ self.bn2 = nn.BatchNorm2D(channels, weight_attr=bn_weight_attr)
self.downsample = downsample
if downsample:
self.downsample = nn.Sequential(
@@ -58,8 +62,9 @@ class BasicBlock(nn.Layer):
channels * self.expansion,
1,
stride,
+ weight_attr=conv_weight_attr,
bias_attr=False),
- nn.BatchNorm2D(channels * self.expansion), )
+ nn.BatchNorm2D(channels * self.expansion, weight_attr=bn_weight_attr))
else:
self.downsample = nn.Sequential()
self.stride = stride
@@ -108,13 +113,13 @@ class ResNet31(nn.Layer):
# conv 1 (Conv Conv)
self.conv1_1 = nn.Conv2D(
- in_channels, channels[0], kernel_size=3, stride=1, padding=1)
- self.bn1_1 = nn.BatchNorm2D(channels[0])
+ in_channels, channels[0], kernel_size=3, stride=1, padding=1, weight_attr=conv_weight_attr)
+ self.bn1_1 = nn.BatchNorm2D(channels[0], weight_attr=bn_weight_attr)
self.relu1_1 = nn.ReLU()
self.conv1_2 = nn.Conv2D(
- channels[0], channels[1], kernel_size=3, stride=1, padding=1)
- self.bn1_2 = nn.BatchNorm2D(channels[1])
+ channels[0], channels[1], kernel_size=3, stride=1, padding=1, weight_attr=conv_weight_attr)
+ self.bn1_2 = nn.BatchNorm2D(channels[1], weight_attr=bn_weight_attr)
self.relu1_2 = nn.ReLU()
# conv 2 (Max-pooling, Residual block, Conv)
@@ -122,8 +127,8 @@ class ResNet31(nn.Layer):
kernel_size=2, stride=2, padding=0, ceil_mode=True)
self.block2 = self._make_layer(channels[1], channels[2], layers[0])
self.conv2 = nn.Conv2D(
- channels[2], channels[2], kernel_size=3, stride=1, padding=1)
- self.bn2 = nn.BatchNorm2D(channels[2])
+ channels[2], channels[2], kernel_size=3, stride=1, padding=1, weight_attr=conv_weight_attr)
+ self.bn2 = nn.BatchNorm2D(channels[2], weight_attr=bn_weight_attr)
self.relu2 = nn.ReLU()
# conv 3 (Max-pooling, Residual block, Conv)
@@ -131,8 +136,8 @@ class ResNet31(nn.Layer):
kernel_size=2, stride=2, padding=0, ceil_mode=True)
self.block3 = self._make_layer(channels[2], channels[3], layers[1])
self.conv3 = nn.Conv2D(
- channels[3], channels[3], kernel_size=3, stride=1, padding=1)
- self.bn3 = nn.BatchNorm2D(channels[3])
+ channels[3], channels[3], kernel_size=3, stride=1, padding=1, weight_attr=conv_weight_attr)
+ self.bn3 = nn.BatchNorm2D(channels[3], weight_attr=bn_weight_attr)
self.relu3 = nn.ReLU()
# conv 4 (Max-pooling, Residual block, Conv)
@@ -140,8 +145,8 @@ class ResNet31(nn.Layer):
kernel_size=(2, 1), stride=(2, 1), padding=0, ceil_mode=True)
self.block4 = self._make_layer(channels[3], channels[4], layers[2])
self.conv4 = nn.Conv2D(
- channels[4], channels[4], kernel_size=3, stride=1, padding=1)
- self.bn4 = nn.BatchNorm2D(channels[4])
+ channels[4], channels[4], kernel_size=3, stride=1, padding=1, weight_attr=conv_weight_attr)
+ self.bn4 = nn.BatchNorm2D(channels[4], weight_attr=bn_weight_attr)
self.relu4 = nn.ReLU()
# conv 5 ((Max-pooling), Residual block, Conv)
@@ -151,8 +156,8 @@ class ResNet31(nn.Layer):
kernel_size=2, stride=2, padding=0, ceil_mode=True)
self.block5 = self._make_layer(channels[4], channels[5], layers[3])
self.conv5 = nn.Conv2D(
- channels[5], channels[5], kernel_size=3, stride=1, padding=1)
- self.bn5 = nn.BatchNorm2D(channels[5])
+ channels[5], channels[5], kernel_size=3, stride=1, padding=1, weight_attr=conv_weight_attr)
+ self.bn5 = nn.BatchNorm2D(channels[5], weight_attr=bn_weight_attr)
self.relu5 = nn.ReLU()
self.out_channels = channels[-1]
@@ -168,8 +173,9 @@ class ResNet31(nn.Layer):
output_channels,
kernel_size=1,
stride=1,
+ weight_attr=conv_weight_attr,
bias_attr=False),
- nn.BatchNorm2D(output_channels), )
+ nn.BatchNorm2D(output_channels, weight_attr=bn_weight_attr))
layers.append(
BasicBlock(
diff --git a/ppocr/modeling/backbones/rec_resnet_31_v2.py b/ppocr/modeling/backbones/rec_resnet_31_v2.py
deleted file mode 100644
index 7812b6296e33fc1f193dab88a7df788a6aa581d3..0000000000000000000000000000000000000000
--- a/ppocr/modeling/backbones/rec_resnet_31_v2.py
+++ /dev/null
@@ -1,216 +0,0 @@
-# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""
-This code is refer from:
-https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textrecog/layers/conv_layer.py
-https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textrecog/backbones/resnet31_ocr.py
-"""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import paddle
-from paddle import ParamAttr
-import paddle.nn as nn
-import paddle.nn.functional as F
-import numpy as np
-
-__all__ = ["ResNet31V2"]
-
-
-conv_weight_attr = nn.initializer.KaimingNormal()
-bn_weight_attr = ParamAttr(initializer=nn.initializer.Uniform(), learning_rate=1)
-
-def conv3x3(in_channel, out_channel, stride=1):
- return nn.Conv2D(
- in_channel,
- out_channel,
- kernel_size=3,
- stride=stride,
- padding=1,
- weight_attr=conv_weight_attr,
- bias_attr=False)
-
-
-class BasicBlock(nn.Layer):
- expansion = 1
-
- def __init__(self, in_channels, channels, stride=1, downsample=False):
- super().__init__()
- self.conv1 = conv3x3(in_channels, channels, stride)
- self.bn1 = nn.BatchNorm2D(channels, weight_attr=bn_weight_attr)
- self.relu = nn.ReLU()
- self.conv2 = conv3x3(channels, channels)
- self.bn2 = nn.BatchNorm2D(channels, weight_attr=bn_weight_attr)
- self.downsample = downsample
- if downsample:
- self.downsample = nn.Sequential(
- nn.Conv2D(
- in_channels,
- channels * self.expansion,
- 1,
- stride,
- weight_attr=conv_weight_attr,
- bias_attr=False),
- nn.BatchNorm2D(channels * self.expansion, weight_attr=bn_weight_attr))
- else:
- self.downsample = nn.Sequential()
- self.stride = stride
-
- def forward(self, x):
- residual = x
-
- out = self.conv1(x)
- out = self.bn1(out)
- out = self.relu(out)
-
- out = self.conv2(out)
- out = self.bn2(out)
-
- if self.downsample:
- residual = self.downsample(x)
-
- out += residual
- out = self.relu(out)
-
- return out
-
-
-class ResNet31V2(nn.Layer):
- '''
- Args:
- in_channels (int): Number of channels of input image tensor.
- layers (list[int]): List of BasicBlock number for each stage.
- channels (list[int]): List of out_channels of Conv2d layer.
- out_indices (None | Sequence[int]): Indices of output stages.
- last_stage_pool (bool): If True, add `MaxPool2d` layer to last stage.
- '''
-
- def __init__(self,
- in_channels=3,
- layers=[1, 2, 5, 3],
- channels=[64, 128, 256, 256, 512, 512, 512],
- out_indices=None,
- last_stage_pool=False):
- super(ResNet31V2, self).__init__()
- assert isinstance(in_channels, int)
- assert isinstance(last_stage_pool, bool)
-
- self.out_indices = out_indices
- self.last_stage_pool = last_stage_pool
-
- # conv 1 (Conv Conv)
- self.conv1_1 = nn.Conv2D(
- in_channels, channels[0], kernel_size=3, stride=1, padding=1, weight_attr=conv_weight_attr)
- self.bn1_1 = nn.BatchNorm2D(channels[0], weight_attr=bn_weight_attr)
- self.relu1_1 = nn.ReLU()
-
- self.conv1_2 = nn.Conv2D(
- channels[0], channels[1], kernel_size=3, stride=1, padding=1, weight_attr=conv_weight_attr)
- self.bn1_2 = nn.BatchNorm2D(channels[1], weight_attr=bn_weight_attr)
- self.relu1_2 = nn.ReLU()
-
- # conv 2 (Max-pooling, Residual block, Conv)
- self.pool2 = nn.MaxPool2D(
- kernel_size=2, stride=2, padding=0, ceil_mode=True)
- self.block2 = self._make_layer(channels[1], channels[2], layers[0])
- self.conv2 = nn.Conv2D(
- channels[2], channels[2], kernel_size=3, stride=1, padding=1, weight_attr=conv_weight_attr)
- self.bn2 = nn.BatchNorm2D(channels[2], weight_attr=bn_weight_attr)
- self.relu2 = nn.ReLU()
-
- # conv 3 (Max-pooling, Residual block, Conv)
- self.pool3 = nn.MaxPool2D(
- kernel_size=2, stride=2, padding=0, ceil_mode=True)
- self.block3 = self._make_layer(channels[2], channels[3], layers[1])
- self.conv3 = nn.Conv2D(
- channels[3], channels[3], kernel_size=3, stride=1, padding=1, weight_attr=conv_weight_attr)
- self.bn3 = nn.BatchNorm2D(channels[3], weight_attr=bn_weight_attr)
- self.relu3 = nn.ReLU()
-
- # conv 4 (Max-pooling, Residual block, Conv)
- self.pool4 = nn.MaxPool2D(
- kernel_size=(2, 1), stride=(2, 1), padding=0, ceil_mode=True)
- self.block4 = self._make_layer(channels[3], channels[4], layers[2])
- self.conv4 = nn.Conv2D(
- channels[4], channels[4], kernel_size=3, stride=1, padding=1, weight_attr=conv_weight_attr)
- self.bn4 = nn.BatchNorm2D(channels[4], weight_attr=bn_weight_attr)
- self.relu4 = nn.ReLU()
-
- # conv 5 ((Max-pooling), Residual block, Conv)
- self.pool5 = None
- if self.last_stage_pool:
- self.pool5 = nn.MaxPool2D(
- kernel_size=2, stride=2, padding=0, ceil_mode=True)
- self.block5 = self._make_layer(channels[4], channels[5], layers[3])
- self.conv5 = nn.Conv2D(
- channels[5], channels[5], kernel_size=3, stride=1, padding=1, weight_attr=conv_weight_attr)
- self.bn5 = nn.BatchNorm2D(channels[5], weight_attr=bn_weight_attr)
- self.relu5 = nn.ReLU()
-
- self.out_channels = channels[-1]
-
- def _make_layer(self, input_channels, output_channels, blocks):
- layers = []
- for _ in range(blocks):
- downsample = None
- if input_channels != output_channels:
- downsample = nn.Sequential(
- nn.Conv2D(
- input_channels,
- output_channels,
- kernel_size=1,
- stride=1,
- weight_attr=conv_weight_attr,
- bias_attr=False),
- nn.BatchNorm2D(output_channels, weight_attr=bn_weight_attr))
-
- layers.append(
- BasicBlock(
- input_channels, output_channels, downsample=downsample))
- input_channels = output_channels
- return nn.Sequential(*layers)
-
- def forward(self, x):
- x = self.conv1_1(x)
- x = self.bn1_1(x)
- x = self.relu1_1(x)
-
- x = self.conv1_2(x)
- x = self.bn1_2(x)
- x = self.relu1_2(x)
-
- outs = []
- for i in range(4):
- layer_index = i + 2
- pool_layer = getattr(self, f'pool{layer_index}')
- block_layer = getattr(self, f'block{layer_index}')
- conv_layer = getattr(self, f'conv{layer_index}')
- bn_layer = getattr(self, f'bn{layer_index}')
- relu_layer = getattr(self, f'relu{layer_index}')
-
- if pool_layer is not None:
- x = pool_layer(x)
- x = block_layer(x)
- x = conv_layer(x)
- x = bn_layer(x)
- x = relu_layer(x)
-
- outs.append(x)
-
- if self.out_indices is not None:
- return tuple([outs[i] for i in self.out_indices])
-
- return x
diff --git a/ppocr/modeling/heads/rec_robustscanner_head.py b/ppocr/modeling/heads/rec_robustscanner_head.py
index b458937978fc4a39f1f14230dff9d8d4fba4cfd9..fc889d59cb558adaf0b422af27fec65db4d8d63f 100644
--- a/ppocr/modeling/heads/rec_robustscanner_head.py
+++ b/ppocr/modeling/heads/rec_robustscanner_head.py
@@ -217,18 +217,7 @@ class SequenceAttentionDecoder(BaseDecoder):
else:
value = paddle.reshape(feat, [n, c_feat, h * w])
- # mask = None
- # if valid_ratios is not None:
- # mask = paddle.zeros(shape=[n, len_q, h, w], dtype='bool')
- # for i, valid_ratio in enumerate(valid_ratios):
- # valid_width = min(w, math.ceil(w * valid_ratio))
- # if valid_width < w:
- # mask[i, :, :, valid_width:] = True
- # # mask = mask.view(n, h * w)
- # mask = paddle.reshape(mask, (n, len_q, h * w))
-
attn_out = self.attention_layer(query, key, value, h, w, valid_ratios)
- # attn_out = attn_out.permute(0, 2, 1).contiguous()
attn_out = paddle.transpose(attn_out, (0, 2, 1))
if self.return_feature:
@@ -253,8 +242,6 @@ class SequenceAttentionDecoder(BaseDecoder):
seq_len = self.max_seq_len
batch_size = feat.shape[0]
- # decode_sequence = (feat.new_ones(
- # (batch_size, seq_len)) * self.start_idx).long()
decode_sequence = (paddle.ones((batch_size, seq_len), dtype='int64') * self.start_idx)
outputs = []
@@ -303,20 +290,8 @@ class SequenceAttentionDecoder(BaseDecoder):
value = key
else:
value = paddle.reshape(feat, [n, c_feat, h * w])
- # len_q = query.shape[2]
- # mask = None
- # if valid_ratios is not None:
- # mask = paddle.zeros(shape=[n, len_q, h, w], dtype='bool')
- # for i, valid_ratio in enumerate(valid_ratios):
- # valid_width = min(w, math.ceil(w * valid_ratio))
- # if valid_width < w:
- # mask[i, :, :, valid_width:] = True
- # # mask = mask.view(n, h * w)
- # mask = paddle.reshape(mask, (n, len_q, h * w))
# [n, c, l]
- # attn_out = self.attention_layer(query, key, value, mask)
-
attn_out = self.attention_layer(query, key, value, h, w, valid_ratios)
out = attn_out[:, :, current_step]
@@ -445,7 +420,6 @@ class PositionAttentionDecoder(BaseDecoder):
before the prediction projection layer, whose shape is
:math:`(N, T, D_m)`.
"""
- #
n, c_enc, h, w = out_enc.shape
assert c_enc == self.dim_model
_, c_feat, _, _ = feat.shape
@@ -453,8 +427,6 @@ class PositionAttentionDecoder(BaseDecoder):
_, len_q = targets.shape
assert len_q <= self.max_seq_len
- # position_index = self._get_position_index(len_q, n)
-
position_out_enc = self.position_aware_module(out_enc)
query = self.embedding(position_index)
@@ -465,16 +437,6 @@ class PositionAttentionDecoder(BaseDecoder):
else:
value = paddle.reshape(feat,(n, c_feat, h * w))
- # mask = None
- # if valid_ratios is not None:
- # mask = paddle.zeros(shape=[n, len_q, h, w], dtype='bool')
- # for i, valid_ratio in enumerate(valid_ratios):
- # valid_width = min(w, math.ceil(w * valid_ratio))
- # if valid_width < w:
- # mask[i, :, :, valid_width:] = True
- # # mask = mask.view(n, h * w)
- # mask = paddle.reshape(mask, (n, len_q, h * w))
-
attn_out = self.attention_layer(query, key, value, h, w, valid_ratios)
attn_out = paddle.transpose(attn_out, (0, 2, 1)) # [n, len_q, dim_v]
@@ -498,7 +460,6 @@ class PositionAttentionDecoder(BaseDecoder):
before the prediction projection layer, whose shape is
:math:`(N, T, D_m)`.
"""
- # seq_len = self.max_seq_len
n, c_enc, h, w = out_enc.shape
assert c_enc == self.dim_model
_, c_feat, _, _ = feat.shape
@@ -516,16 +477,6 @@ class PositionAttentionDecoder(BaseDecoder):
value = paddle.reshape(out_enc,(n, c_enc, h * w))
else:
value = paddle.reshape(feat,(n, c_feat, h * w))
- # len_q = query.shape[2]
- # mask = None
- # if valid_ratios is not None:
- # mask = paddle.zeros(shape=[n, len_q, h, w], dtype='bool')
- # for i, valid_ratio in enumerate(valid_ratios):
- # valid_width = min(w, math.ceil(w * valid_ratio))
- # if valid_width < w:
- # mask[i, :, :, valid_width:] = True
- # # mask = mask.view(n, h * w)
- # mask = paddle.reshape(mask, (n, len_q, h * w))
attn_out = self.attention_layer(query, key, value, h, w, valid_ratios)
attn_out = paddle.transpose(attn_out, (0, 2, 1)) # [n, len_q, dim_v]
@@ -676,9 +627,6 @@ class RobustScannerDecoder(BaseDecoder):
seq_len = self.max_seq_len
batch_size = feat.shape[0]
- # decode_sequence = (feat.new_ones(
- # (batch_size, seq_len)) * self.start_idx).long()
-
decode_sequence = (paddle.ones((batch_size, seq_len), dtype='int64') * self.start_idx)
position_glimpse = self.position_decoder.forward_test(
@@ -712,7 +660,7 @@ class RobustScannerHead(nn.Layer):
hybrid_dec_dropout=0,
position_dec_rnn_layers=2,
start_idx=0,
- max_seq_len=40,
+ max_text_length=40,
mask=True,
padding_idx=None,
encode_value=False,
@@ -731,7 +679,7 @@ class RobustScannerHead(nn.Layer):
hybrid_decoder_rnn_layers=hybrid_dec_rnn_layers,
hybrid_decoder_dropout=hybrid_dec_dropout,
position_decoder_rnn_layers=position_dec_rnn_layers,
- max_seq_len=max_seq_len,
+ max_seq_len=max_text_length,
start_idx=start_idx,
mask=mask,
padding_idx=padding_idx,
diff --git a/test_tipc/configs/rec_r31_robustscanner/rec_r31_robustscanner.yml b/test_tipc/configs/rec_r31_robustscanner/rec_r31_robustscanner.yml
index 20ec9be9662c1fbf13f20ca5cc6451b8ad8a5da6..a49f332aae0eeb1b7a9d728d37095e5d88c3e6f5 100644
--- a/test_tipc/configs/rec_r31_robustscanner/rec_r31_robustscanner.yml
+++ b/test_tipc/configs/rec_r31_robustscanner/rec_r31_robustscanner.yml
@@ -15,7 +15,7 @@ Global:
infer_img:
# for data or label process
character_dict_path: ppocr/utils/dict90.txt
- max_text_length: 40
+ max_text_length: &max_text_length 40
infer_mode: False
use_space_char: False
rm_symbol: True
@@ -38,7 +38,7 @@ Architecture:
algorithm: RobustScanner
Transform:
Backbone:
- name: ResNet31V2
+ name: ResNet31
Head:
name: RobustScannerHead
enc_outchannles: 128
@@ -75,7 +75,7 @@ Train:
- RobustScannerRecResizeImg:
image_shape: [3, 48, 48, 160] # h:48 w:[48,160]
width_downsample_ratio: 0.25
- max_seq_len: 40
+ max_seq_len: *max_text_length
- KeepKeys:
keep_keys: ['image', 'label', 'valid_ratio', 'word_positons'] # dataloader will return list in this order
loader:
@@ -97,7 +97,7 @@ Eval:
- SARLabelEncode: # Class handling label
- RobustScannerRecResizeImg:
image_shape: [3, 48, 48, 160]
- max_seq_len: 40
+ max_seq_len: *max_text_length
width_downsample_ratio: 0.25
- KeepKeys:
keep_keys: ['image', 'label', 'valid_ratio', 'word_positons'] # dataloader will return list in this order