Commit 47cb2749 authored by xuyang2233

update pr

Parent a2ef524f
......@@ -15,7 +15,7 @@ Global:
infer_img: ./inference/rec_inference
# for data or label process
character_dict_path: ppocr/utils/dict90.txt
max_text_length: 40
max_text_length: &max_text_length 40
infer_mode: False
use_space_char: False
rm_symbol: True
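The `&max_text_length` syntax above is a plain YAML anchor: every `*max_text_length` alias later in the file resolves to the same value, so the maximum length only needs to be edited in one place. A minimal sketch of how the anchor resolves at parse time (PyYAML is used here purely for illustration; any YAML loader behaves the same):

```python
import yaml  # PyYAML, assumed available for this illustration

snippet = """
Global:
  max_text_length: &max_text_length 40
Head:
  max_text_length: *max_text_length
"""
cfg = yaml.safe_load(snippet)
# Both keys resolve to the scalar defined once at the anchor.
assert cfg["Global"]["max_text_length"] == cfg["Head"]["max_text_length"] == 40
```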
......@@ -38,7 +38,7 @@ Architecture:
algorithm: RobustScanner
Transform:
Backbone:
name: ResNet31V2
name: ResNet31
Head:
name: RobustScannerHead
enc_outchannles: 128
......@@ -49,7 +49,7 @@ Architecture:
mask: True
padding_idx: 92
encode_value: False
max_seq_len: 40
max_text_length: *max_text_length
Loss:
name: SARLoss
......@@ -64,8 +64,9 @@ Metric:
Train:
dataset:
name: LMDBDataSet
data_dir: I:/dataset/OCR/deep_text_recognition/data_lmdb/evaluation/CUTE80
name: SimpleDataSet
label_file_list: ['./train_data/train_list.txt']
data_dir: ./train_data/
transforms:
- DecodeImage: # load image
img_mode: BGR
......@@ -74,20 +75,20 @@ Train:
- RobustScannerRecResizeImg:
image_shape: [3, 48, 48, 160] # h:48 w:[48,160]
width_downsample_ratio: 0.25
max_seq_len: 40
max_text_length: *max_text_length
- KeepKeys:
keep_keys: ['image', 'label', 'valid_ratio', 'word_positons'] # dataloader will return list in this order
loader:
shuffle: True
batch_size_per_card: 4
batch_size_per_card: 64
drop_last: True
num_workers: 0
num_workers: 8
use_shared_memory: False
Eval:
dataset:
name: LMDBDataSet
data_dir: I:/dataset/OCR/deep_text_recognition/data_lmdb/evaluation/CUTE80
data_dir: ./train_data/data_lmdb_release/evaluation/
transforms:
- DecodeImage: # load image
img_mode: BGR
......@@ -95,14 +96,14 @@ Eval:
- SARLabelEncode: # Class handling label
- RobustScannerRecResizeImg:
image_shape: [3, 48, 48, 160]
max_seq_len: 40
max_text_length: *max_text_length
width_downsample_ratio: 0.25
- KeepKeys:
keep_keys: ['image', 'label', 'valid_ratio', 'word_positons'] # dataloader will return list in this order
loader:
shuffle: False
drop_last: False
batch_size_per_card: 1
num_workers: 0
batch_size_per_card: 64
num_workers: 4
use_shared_memory: False
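Because Train now reads from `SimpleDataSet` instead of LMDB, labels come from the plain-text list at `./train_data/train_list.txt`. A sketch of the expected format, with hypothetical file names (SimpleDataSet splits each line on a tab into an image path, relative to `data_dir`, and a transcript by default):

```python
# Hypothetical contents of ./train_data/train_list.txt:
# "<image path relative to data_dir>\t<label>" per line.
sample = "word_001.png\tJOINT\nword_002.png\tHOUSE\n"

for line in sample.splitlines():
    img_path, label = line.split("\t")
    print(img_path, "->", label)
```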
......@@ -85,7 +85,7 @@
|SAR|Resnet31| 87.20% | rec_r31_sar | [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_r31_sar_train.tar) |
|SEED|Aster_Resnet| 85.35% | rec_resnet_stn_bilstm_att | [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_resnet_stn_bilstm_att.tar) |
|SVTR|SVTR-Tiny| 89.25% | rec_svtr_tiny_none_ctc_en | [trained model](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/rec_svtr_tiny_none_ctc_en_train.tar) |
|RobustScanner|ResNet31V2| 87.77% | rec_r31_robustscanner | [trained model]() |
|RobustScanner|ResNet31| 87.77% | rec_r31_robustscanner | coming soon |
<a name="2"></a>
......
......@@ -26,7 +26,7 @@ Zhang
|Model|Backbone|Config|Acc|Download link|
| --- | --- | --- | --- | --- |
|RobustScanner|ResNet31V2|[rec_r31_robustscanner.yml](../../configs/rec/rec_r31_robustscanner.yml)|87.77%|[trained model]()|
|RobustScanner|ResNet31|[rec_r31_robustscanner.yml](../../configs/rec/rec_r31_robustscanner.yml)|87.77%|coming soon|
Note: In addition to the MJSynth and SynthText text recognition datasets, [SynthAdd](https://pan.baidu.com/share/init?surl=uV0LtoNmcxbO-0YA7Ch4dg) data (extraction code: 627x) and some real data are also used in training; refer to the paper for the specific data details.
......@@ -71,7 +71,7 @@ python3 tools/infer_rec.py -c configs/rec/rec_r31_robustscanner.yml -o Global.pr
<a name="4-1"></a>
### 4.1 Python Inference
First, convert the model saved during RobustScanner text recognition training into an inference model. ([Model download link]()) You can use the following command to convert:
First, convert the model saved during RobustScanner text recognition training into an inference model. You can use the following command to convert:
```
python3 tools/export_model.py -c configs/rec/rec_r31_robustscanner.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.save_inference_dir=./inference/rec_r31_robustscanner
......@@ -85,7 +85,7 @@ python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/en/word_1.png"
<a name="4-2"></a>
### 4.2 C++ Inference
Since the C++ pre- and post-processing do not yet support SAR, C++ inference is not supported for now.
Since the C++ pre- and post-processing do not yet support RobustScanner, C++ inference is not supported for now.
<a name="4-3"></a>
### 4.3 Serving Deployment
......@@ -104,11 +104,10 @@ python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/en/word_1.png"
## Citation
```bibtex
@article{Li2019ShowAA,
title={Show, Attend and Read: A Simple and Strong Baseline for Irregular Text Recognition},
author={Hui Li and Peng Wang and Chunhua Shen and Guyu Zhang},
journal={ArXiv},
year={2019},
volume={abs/1811.00751}
@inproceedings{2020RobustScanner,
title={RobustScanner: Dynamically Enhancing Positional Clues for Robust Text Recognition},
author={Xiaoyu Yue and Zhanghui Kuang and Chenhao Lin and Hongbin Sun and Wayne Zhang},
booktitle={European Conference on Computer Vision (ECCV)},
year={2020}
}
```
# SAR
# RobustScanner
- [1. Introduction](#1)
- [2. Environment](#2)
......@@ -26,7 +26,7 @@ Using MJSynth and SynthText two text recognition datasets for training, and eval
|Model|Backbone|config|Acc|Download link|
| --- | --- | --- | --- | --- |
|RobustScanner|ResNet31V2|[rec_r31_robustscanner.yml](../../configs/rec/rec_r31_robustscanner.yml)|87.77%|[train model]()|
|RobustScanner|ResNet31|[rec_r31_robustscanner.yml](../../configs/rec/rec_r31_robustscanner.yml)|87.77%|coming soon|
Note: In addition to the two text recognition datasets MJSynth and SynthText, [SynthAdd](https://pan.baidu.com/share/init?surl=uV0LtoNmcxbO-0YA7Ch4dg) data (extraction code: 627x) and some real data are used in training; the specific data details can be found in the paper.
......@@ -71,7 +71,7 @@ python3 tools/infer_rec.py -c configs/rec/rec_r31_robustscanner.yml -o Global.pr
<a name="4-1"></a>
### 4.1 Python Inference
First, the model saved during the RobustScanner text recognition training process is converted into an inference model. ([Model download link]()) You can use the following command to convert:
First, the model saved during the RobustScanner text recognition training process is converted into an inference model. You can use the following command to convert:
```
python3 tools/export_model.py -c configs/rec/rec_r31_robustscanner.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.save_inference_dir=./inference/rec_r31_robustscanner
......@@ -105,11 +105,10 @@ Not supported
## Citation
```bibtex
@article{Li2019ShowAA,
title={Show, Attend and Read: A Simple and Strong Baseline for Irregular Text Recognition},
author={Hui Li and Peng Wang and Chunhua Shen and Guyu Zhang},
journal={ArXiv},
year={2019},
volume={abs/1811.00751}
@inproceedings{2020RobustScanner,
title={RobustScanner: Dynamically Enhancing Positional Clues for Robust Text Recognition},
author={Xiaoyu Yue and Zhanghui Kuang and Chenhao Lin and Hongbin Sun and Wayne Zhang},
booktitle={European Conference on Computer Vision (ECCV)},
year={2020}
}
```
......@@ -268,16 +268,16 @@ class PRENResizeImg(object):
return data
class RobustScannerRecResizeImg(object):
def __init__(self, image_shape, max_seq_len, width_downsample_ratio=0.25, **kwargs):
def __init__(self, image_shape, max_text_length, width_downsample_ratio=0.25, **kwargs):
self.image_shape = image_shape
self.width_downsample_ratio = width_downsample_ratio
self.max_seq_len = max_seq_len
self.max_text_length = max_text_length
def __call__(self, data):
img = data['image']
norm_img, resize_shape, pad_shape, valid_ratio = resize_norm_img_sar(
img, self.image_shape, self.width_downsample_ratio)
word_positons = robustscanner_other_inputs(self.max_seq_len)
word_positons = np.array(range(0, self.max_text_length)).astype('int64')
data['image'] = norm_img
data['resized_shape'] = resize_shape
data['pad_shape'] = pad_shape
......@@ -429,9 +429,6 @@ def srn_other_inputs(image_shape, num_heads, max_text_length):
gsrm_slf_attn_bias2
]
def robustscanner_other_inputs(max_text_length):
word_pos = np.array(range(0, max_text_length)).astype('int64')
return word_pos
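With `robustscanner_other_inputs` folded into the operator, `RobustScannerRecResizeImg` now computes `word_positons` itself. A hypothetical smoke test of the operator, assuming it is imported from `ppocr/data/imaug/rec_img_aug.py` where `resize_norm_img_sar` also lives:

```python
import numpy as np

# Dummy 32x100 BGR image; expected shapes assume the config above
# (image_shape [3, 48, 48, 160], max_text_length 40).
op = RobustScannerRecResizeImg(image_shape=[3, 48, 48, 160], max_text_length=40)
data = op({"image": np.zeros((32, 100, 3), dtype=np.uint8)})
print(data["image"].shape)        # (3, 48, 160): resized to h=48, padded to w=160
print(data["valid_ratio"])        # fraction of the padded width holding real content
print(data["word_positons"][:5])  # [0 1 2 3 4]
```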
def flag():
"""
......
......@@ -28,7 +28,6 @@ def build_backbone(config, model_type):
from .rec_mv1_enhance import MobileNetV1Enhance
from .rec_nrtr_mtb import MTB
from .rec_resnet_31 import ResNet31
from .rec_resnet_31_v2 import ResNet31V2
from .rec_resnet_aster import ResNet_ASTER
from .rec_micronet import MicroNet
from .rec_efficientb3_pren import EfficientNetb3_PREN
......
......@@ -27,9 +27,12 @@ import paddle.nn as nn
import paddle.nn.functional as F
import numpy as np
__all__ = ["ResNet31"]
__all__ = ["ResNet31V2"]
conv_weight_attr = nn.initializer.KaimingNormal()
bn_weight_attr = ParamAttr(initializer=nn.initializer.Uniform(), learning_rate=1)
def conv3x3(in_channel, out_channel, stride=1):
return nn.Conv2D(
in_channel,
......@@ -37,6 +40,7 @@ def conv3x3(in_channel, out_channel, stride=1):
kernel_size=3,
stride=stride,
padding=1,
weight_attr=conv_weight_attr,
bias_attr=False)
......@@ -46,10 +50,10 @@ class BasicBlock(nn.Layer):
def __init__(self, in_channels, channels, stride=1, downsample=False):
super().__init__()
self.conv1 = conv3x3(in_channels, channels, stride)
self.bn1 = nn.BatchNorm2D(channels)
self.bn1 = nn.BatchNorm2D(channels, weight_attr=bn_weight_attr)
self.relu = nn.ReLU()
self.conv2 = conv3x3(channels, channels)
self.bn2 = nn.BatchNorm2D(channels)
self.bn2 = nn.BatchNorm2D(channels, weight_attr=bn_weight_attr)
self.downsample = downsample
if downsample:
self.downsample = nn.Sequential(
......@@ -58,8 +62,9 @@ class BasicBlock(nn.Layer):
channels * self.expansion,
1,
stride,
weight_attr=conv_weight_attr,
bias_attr=False),
nn.BatchNorm2D(channels * self.expansion), )
nn.BatchNorm2D(channels * self.expansion, weight_attr=bn_weight_attr))
else:
self.downsample = nn.Sequential()
self.stride = stride
......@@ -108,13 +113,13 @@ class ResNet31(nn.Layer):
# conv 1 (Conv Conv)
self.conv1_1 = nn.Conv2D(
in_channels, channels[0], kernel_size=3, stride=1, padding=1)
self.bn1_1 = nn.BatchNorm2D(channels[0])
in_channels, channels[0], kernel_size=3, stride=1, padding=1, weight_attr=conv_weight_attr)
self.bn1_1 = nn.BatchNorm2D(channels[0], weight_attr=bn_weight_attr)
self.relu1_1 = nn.ReLU()
self.conv1_2 = nn.Conv2D(
channels[0], channels[1], kernel_size=3, stride=1, padding=1)
self.bn1_2 = nn.BatchNorm2D(channels[1])
channels[0], channels[1], kernel_size=3, stride=1, padding=1, weight_attr=conv_weight_attr)
self.bn1_2 = nn.BatchNorm2D(channels[1], weight_attr=bn_weight_attr)
self.relu1_2 = nn.ReLU()
# conv 2 (Max-pooling, Residual block, Conv)
......@@ -122,8 +127,8 @@ class ResNet31(nn.Layer):
kernel_size=2, stride=2, padding=0, ceil_mode=True)
self.block2 = self._make_layer(channels[1], channels[2], layers[0])
self.conv2 = nn.Conv2D(
channels[2], channels[2], kernel_size=3, stride=1, padding=1)
self.bn2 = nn.BatchNorm2D(channels[2])
channels[2], channels[2], kernel_size=3, stride=1, padding=1, weight_attr=conv_weight_attr)
self.bn2 = nn.BatchNorm2D(channels[2], weight_attr=bn_weight_attr)
self.relu2 = nn.ReLU()
# conv 3 (Max-pooling, Residual block, Conv)
......@@ -131,8 +136,8 @@ class ResNet31(nn.Layer):
kernel_size=2, stride=2, padding=0, ceil_mode=True)
self.block3 = self._make_layer(channels[2], channels[3], layers[1])
self.conv3 = nn.Conv2D(
channels[3], channels[3], kernel_size=3, stride=1, padding=1)
self.bn3 = nn.BatchNorm2D(channels[3])
channels[3], channels[3], kernel_size=3, stride=1, padding=1, weight_attr=conv_weight_attr)
self.bn3 = nn.BatchNorm2D(channels[3], weight_attr=bn_weight_attr)
self.relu3 = nn.ReLU()
# conv 4 (Max-pooling, Residual block, Conv)
......@@ -140,8 +145,8 @@ class ResNet31(nn.Layer):
kernel_size=(2, 1), stride=(2, 1), padding=0, ceil_mode=True)
self.block4 = self._make_layer(channels[3], channels[4], layers[2])
self.conv4 = nn.Conv2D(
channels[4], channels[4], kernel_size=3, stride=1, padding=1)
self.bn4 = nn.BatchNorm2D(channels[4])
channels[4], channels[4], kernel_size=3, stride=1, padding=1, weight_attr=conv_weight_attr)
self.bn4 = nn.BatchNorm2D(channels[4], weight_attr=bn_weight_attr)
self.relu4 = nn.ReLU()
# conv 5 ((Max-pooling), Residual block, Conv)
......@@ -151,8 +156,8 @@ class ResNet31(nn.Layer):
kernel_size=2, stride=2, padding=0, ceil_mode=True)
self.block5 = self._make_layer(channels[4], channels[5], layers[3])
self.conv5 = nn.Conv2D(
channels[5], channels[5], kernel_size=3, stride=1, padding=1)
self.bn5 = nn.BatchNorm2D(channels[5])
channels[5], channels[5], kernel_size=3, stride=1, padding=1, weight_attr=conv_weight_attr)
self.bn5 = nn.BatchNorm2D(channels[5], weight_attr=bn_weight_attr)
self.relu5 = nn.ReLU()
self.out_channels = channels[-1]
......@@ -168,8 +173,9 @@ class ResNet31(nn.Layer):
output_channels,
kernel_size=1,
stride=1,
weight_attr=conv_weight_attr,
bias_attr=False),
nn.BatchNorm2D(output_channels), )
nn.BatchNorm2D(output_channels, weight_attr=bn_weight_attr))
layers.append(
BasicBlock(
......
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textrecog/layers/conv_layer.py
https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textrecog/backbones/resnet31_ocr.py
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
from paddle import ParamAttr
import paddle.nn as nn
import paddle.nn.functional as F
import numpy as np
__all__ = ["ResNet31V2"]
conv_weight_attr = nn.initializer.KaimingNormal()
bn_weight_attr = ParamAttr(initializer=nn.initializer.Uniform(), learning_rate=1)
def conv3x3(in_channel, out_channel, stride=1):
return nn.Conv2D(
in_channel,
out_channel,
kernel_size=3,
stride=stride,
padding=1,
weight_attr=conv_weight_attr,
bias_attr=False)
class BasicBlock(nn.Layer):
expansion = 1
def __init__(self, in_channels, channels, stride=1, downsample=False):
super().__init__()
self.conv1 = conv3x3(in_channels, channels, stride)
self.bn1 = nn.BatchNorm2D(channels, weight_attr=bn_weight_attr)
self.relu = nn.ReLU()
self.conv2 = conv3x3(channels, channels)
self.bn2 = nn.BatchNorm2D(channels, weight_attr=bn_weight_attr)
self.downsample = downsample
if downsample:
self.downsample = nn.Sequential(
nn.Conv2D(
in_channels,
channels * self.expansion,
1,
stride,
weight_attr=conv_weight_attr,
bias_attr=False),
nn.BatchNorm2D(channels * self.expansion, weight_attr=bn_weight_attr))
else:
self.downsample = nn.Sequential()
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
class ResNet31V2(nn.Layer):
'''
Args:
in_channels (int): Number of channels of input image tensor.
layers (list[int]): List of BasicBlock number for each stage.
channels (list[int]): List of out_channels of Conv2d layer.
out_indices (None | Sequence[int]): Indices of output stages.
last_stage_pool (bool): If True, add `MaxPool2d` layer to last stage.
'''
def __init__(self,
in_channels=3,
layers=[1, 2, 5, 3],
channels=[64, 128, 256, 256, 512, 512, 512],
out_indices=None,
last_stage_pool=False):
super(ResNet31V2, self).__init__()
assert isinstance(in_channels, int)
assert isinstance(last_stage_pool, bool)
self.out_indices = out_indices
self.last_stage_pool = last_stage_pool
# conv 1 (Conv Conv)
self.conv1_1 = nn.Conv2D(
in_channels, channels[0], kernel_size=3, stride=1, padding=1, weight_attr=conv_weight_attr)
self.bn1_1 = nn.BatchNorm2D(channels[0], weight_attr=bn_weight_attr)
self.relu1_1 = nn.ReLU()
self.conv1_2 = nn.Conv2D(
channels[0], channels[1], kernel_size=3, stride=1, padding=1, weight_attr=conv_weight_attr)
self.bn1_2 = nn.BatchNorm2D(channels[1], weight_attr=bn_weight_attr)
self.relu1_2 = nn.ReLU()
# conv 2 (Max-pooling, Residual block, Conv)
self.pool2 = nn.MaxPool2D(
kernel_size=2, stride=2, padding=0, ceil_mode=True)
self.block2 = self._make_layer(channels[1], channels[2], layers[0])
self.conv2 = nn.Conv2D(
channels[2], channels[2], kernel_size=3, stride=1, padding=1, weight_attr=conv_weight_attr)
self.bn2 = nn.BatchNorm2D(channels[2], weight_attr=bn_weight_attr)
self.relu2 = nn.ReLU()
# conv 3 (Max-pooling, Residual block, Conv)
self.pool3 = nn.MaxPool2D(
kernel_size=2, stride=2, padding=0, ceil_mode=True)
self.block3 = self._make_layer(channels[2], channels[3], layers[1])
self.conv3 = nn.Conv2D(
channels[3], channels[3], kernel_size=3, stride=1, padding=1, weight_attr=conv_weight_attr)
self.bn3 = nn.BatchNorm2D(channels[3], weight_attr=bn_weight_attr)
self.relu3 = nn.ReLU()
# conv 4 (Max-pooling, Residual block, Conv)
self.pool4 = nn.MaxPool2D(
kernel_size=(2, 1), stride=(2, 1), padding=0, ceil_mode=True)
self.block4 = self._make_layer(channels[3], channels[4], layers[2])
self.conv4 = nn.Conv2D(
channels[4], channels[4], kernel_size=3, stride=1, padding=1, weight_attr=conv_weight_attr)
self.bn4 = nn.BatchNorm2D(channels[4], weight_attr=bn_weight_attr)
self.relu4 = nn.ReLU()
# conv 5 ((Max-pooling), Residual block, Conv)
self.pool5 = None
if self.last_stage_pool:
self.pool5 = nn.MaxPool2D(
kernel_size=2, stride=2, padding=0, ceil_mode=True)
self.block5 = self._make_layer(channels[4], channels[5], layers[3])
self.conv5 = nn.Conv2D(
channels[5], channels[5], kernel_size=3, stride=1, padding=1, weight_attr=conv_weight_attr)
self.bn5 = nn.BatchNorm2D(channels[5], weight_attr=bn_weight_attr)
self.relu5 = nn.ReLU()
self.out_channels = channels[-1]
def _make_layer(self, input_channels, output_channels, blocks):
layers = []
for _ in range(blocks):
downsample = None
if input_channels != output_channels:
downsample = nn.Sequential(
nn.Conv2D(
input_channels,
output_channels,
kernel_size=1,
stride=1,
weight_attr=conv_weight_attr,
bias_attr=False),
nn.BatchNorm2D(output_channels, weight_attr=bn_weight_attr))
layers.append(
BasicBlock(
input_channels, output_channels, downsample=downsample))
input_channels = output_channels
return nn.Sequential(*layers)
def forward(self, x):
x = self.conv1_1(x)
x = self.bn1_1(x)
x = self.relu1_1(x)
x = self.conv1_2(x)
x = self.bn1_2(x)
x = self.relu1_2(x)
outs = []
for i in range(4):
layer_index = i + 2
pool_layer = getattr(self, f'pool{layer_index}')
block_layer = getattr(self, f'block{layer_index}')
conv_layer = getattr(self, f'conv{layer_index}')
bn_layer = getattr(self, f'bn{layer_index}')
relu_layer = getattr(self, f'relu{layer_index}')
if pool_layer is not None:
x = pool_layer(x)
x = block_layer(x)
x = conv_layer(x)
x = bn_layer(x)
x = relu_layer(x)
outs.append(x)
if self.out_indices is not None:
return tuple([outs[i] for i in self.out_indices])
return x
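With the V2 file above deleted and its KaimingNormal/Uniform initializers merged into `ResNet31`, the surviving backbone keeps the identical stage layout. A hedged shape check (the import path follows this repo's layout; the expected shape follows from the pools: H is halved three times, W twice):

```python
import paddle
from ppocr.modeling.backbones.rec_resnet_31 import ResNet31

x = paddle.randn([1, 3, 48, 160])  # matches image_shape [3, 48, 48, 160]
backbone = ResNet31(in_channels=3)
feat = backbone(x)
# pool2/pool3 halve H and W, pool4 halves H only: 48/8 = 6, 160/4 = 40
print(feat.shape)                  # expected [1, 512, 6, 40]
```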
......@@ -217,18 +217,7 @@ class SequenceAttentionDecoder(BaseDecoder):
else:
value = paddle.reshape(feat, [n, c_feat, h * w])
# mask = None
# if valid_ratios is not None:
# mask = paddle.zeros(shape=[n, len_q, h, w], dtype='bool')
# for i, valid_ratio in enumerate(valid_ratios):
# valid_width = min(w, math.ceil(w * valid_ratio))
# if valid_width < w:
# mask[i, :, :, valid_width:] = True
# # mask = mask.view(n, h * w)
# mask = paddle.reshape(mask, (n, len_q, h * w))
attn_out = self.attention_layer(query, key, value, h, w, valid_ratios)
# attn_out = attn_out.permute(0, 2, 1).contiguous()
attn_out = paddle.transpose(attn_out, (0, 2, 1))
if self.return_feature:
......@@ -253,8 +242,6 @@ class SequenceAttentionDecoder(BaseDecoder):
seq_len = self.max_seq_len
batch_size = feat.shape[0]
# decode_sequence = (feat.new_ones(
# (batch_size, seq_len)) * self.start_idx).long()
decode_sequence = (paddle.ones((batch_size, seq_len), dtype='int64') * self.start_idx)
outputs = []
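`forward_test` decodes greedily from this start-token sequence. A toy illustration of the bookkeeping only, with hypothetical sizes and the per-step decoder call stubbed out by random logits:

```python
import paddle

batch_size, seq_len, num_classes, start_idx = 2, 5, 93, 91
decode_sequence = paddle.ones([batch_size, seq_len], dtype="int64") * start_idx
for i in range(seq_len):
    step_logits = paddle.randn([batch_size, num_classes])  # stand-in for one decode step
    next_token = paddle.argmax(step_logits, axis=1)
    if i + 1 < seq_len:
        decode_sequence[:, i + 1] = next_token             # fed back as the next input
```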
......@@ -303,20 +290,8 @@ class SequenceAttentionDecoder(BaseDecoder):
value = key
else:
value = paddle.reshape(feat, [n, c_feat, h * w])
# len_q = query.shape[2]
# mask = None
# if valid_ratios is not None:
# mask = paddle.zeros(shape=[n, len_q, h, w], dtype='bool')
# for i, valid_ratio in enumerate(valid_ratios):
# valid_width = min(w, math.ceil(w * valid_ratio))
# if valid_width < w:
# mask[i, :, :, valid_width:] = True
# # mask = mask.view(n, h * w)
# mask = paddle.reshape(mask, (n, len_q, h * w))
# [n, c, l]
# attn_out = self.attention_layer(query, key, value, mask)
attn_out = self.attention_layer(query, key, value, h, w, valid_ratios)
out = attn_out[:, :, current_step]
......@@ -445,7 +420,6 @@ class PositionAttentionDecoder(BaseDecoder):
before the prediction projection layer, whose shape is
:math:`(N, T, D_m)`.
"""
#
n, c_enc, h, w = out_enc.shape
assert c_enc == self.dim_model
_, c_feat, _, _ = feat.shape
......@@ -453,8 +427,6 @@ class PositionAttentionDecoder(BaseDecoder):
_, len_q = targets.shape
assert len_q <= self.max_seq_len
# position_index = self._get_position_index(len_q, n)
position_out_enc = self.position_aware_module(out_enc)
query = self.embedding(position_index)
......@@ -465,16 +437,6 @@ class PositionAttentionDecoder(BaseDecoder):
else:
value = paddle.reshape(feat,(n, c_feat, h * w))
# mask = None
# if valid_ratios is not None:
# mask = paddle.zeros(shape=[n, len_q, h, w], dtype='bool')
# for i, valid_ratio in enumerate(valid_ratios):
# valid_width = min(w, math.ceil(w * valid_ratio))
# if valid_width < w:
# mask[i, :, :, valid_width:] = True
# # mask = mask.view(n, h * w)
# mask = paddle.reshape(mask, (n, len_q, h * w))
attn_out = self.attention_layer(query, key, value, h, w, valid_ratios)
attn_out = paddle.transpose(attn_out, (0, 2, 1)) # [n, len_q, dim_v]
......@@ -498,7 +460,6 @@ class PositionAttentionDecoder(BaseDecoder):
before the prediction projection layer, whose shape is
:math:`(N, T, D_m)`.
"""
# seq_len = self.max_seq_len
n, c_enc, h, w = out_enc.shape
assert c_enc == self.dim_model
_, c_feat, _, _ = feat.shape
......@@ -516,16 +477,6 @@ class PositionAttentionDecoder(BaseDecoder):
value = paddle.reshape(out_enc,(n, c_enc, h * w))
else:
value = paddle.reshape(feat,(n, c_feat, h * w))
# len_q = query.shape[2]
# mask = None
# if valid_ratios is not None:
# mask = paddle.zeros(shape=[n, len_q, h, w], dtype='bool')
# for i, valid_ratio in enumerate(valid_ratios):
# valid_width = min(w, math.ceil(w * valid_ratio))
# if valid_width < w:
# mask[i, :, :, valid_width:] = True
# # mask = mask.view(n, h * w)
# mask = paddle.reshape(mask, (n, len_q, h * w))
attn_out = self.attention_layer(query, key, value, h, w, valid_ratios)
attn_out = paddle.transpose(attn_out, (0, 2, 1)) # [n, len_q, dim_v]
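For reference, the padding mask that the deleted comment blocks described can be rebuilt from `valid_ratios` alone. A hypothetical helper, not part of the module (after this change the attention layer is expected to derive the mask internally from `h`, `w`, and `valid_ratios`):

```python
import math
import paddle

def build_width_mask(h, w, valid_ratios, len_q):
    """True marks padded columns beyond each sample's valid width."""
    cols = paddle.arange(w)                                    # [w]
    masks = []
    for valid_ratio in valid_ratios:
        valid_width = min(w, math.ceil(w * float(valid_ratio)))
        pad = cols >= valid_width                              # bool [w]
        pad = paddle.tile(pad.reshape([1, 1, w]), [len_q, h, 1])
        masks.append(pad.reshape([len_q, h * w]))
    return paddle.stack(masks)                                 # [n, len_q, h*w]
```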
......@@ -676,9 +627,6 @@ class RobustScannerDecoder(BaseDecoder):
seq_len = self.max_seq_len
batch_size = feat.shape[0]
# decode_sequence = (feat.new_ones(
# (batch_size, seq_len)) * self.start_idx).long()
decode_sequence = (paddle.ones((batch_size, seq_len), dtype='int64') * self.start_idx)
position_glimpse = self.position_decoder.forward_test(
......@@ -712,7 +660,7 @@ class RobustScannerHead(nn.Layer):
hybrid_dec_dropout=0,
position_dec_rnn_layers=2,
start_idx=0,
max_seq_len=40,
max_text_length=40,
mask=True,
padding_idx=None,
encode_value=False,
......@@ -731,7 +679,7 @@ class RobustScannerHead(nn.Layer):
hybrid_decoder_rnn_layers=hybrid_dec_rnn_layers,
hybrid_decoder_dropout=hybrid_dec_dropout,
position_decoder_rnn_layers=position_dec_rnn_layers,
max_seq_len=max_seq_len,
max_seq_len=max_text_length,
start_idx=start_idx,
mask=mask,
padding_idx=padding_idx,
......
......@@ -15,7 +15,7 @@ Global:
infer_img:
# for data or label process
character_dict_path: ppocr/utils/dict90.txt
max_text_length: 40
max_text_length: &max_text_length 40
infer_mode: False
use_space_char: False
rm_symbol: True
......@@ -38,7 +38,7 @@ Architecture:
algorithm: RobustScanner
Transform:
Backbone:
name: ResNet31V2
name: ResNet31
Head:
name: RobustScannerHead
enc_outchannles: 128
......@@ -75,7 +75,7 @@ Train:
- RobustScannerRecResizeImg:
image_shape: [3, 48, 48, 160] # h:48 w:[48,160]
width_downsample_ratio: 0.25
max_seq_len: 40
max_text_length: *max_text_length
- KeepKeys:
keep_keys: ['image', 'label', 'valid_ratio', 'word_positons'] # dataloader will return list in this order
loader:
......@@ -97,7 +97,7 @@ Eval:
- SARLabelEncode: # Class handling label
- RobustScannerRecResizeImg:
image_shape: [3, 48, 48, 160]
max_seq_len: 40
max_text_length: *max_text_length
width_downsample_ratio: 0.25
- KeepKeys:
keep_keys: ['image', 'label', 'valid_ratio', 'word_positons'] # dataloader will return list in this order
......