提交 c503dc2f 编写于 作者: T Topdu

[New Rec] add vitstr and ABINet

上级 9a717433
......@@ -82,7 +82,7 @@ Train:
Eval:
dataset:
name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/evaluation/
data_dir: ./train_data/data_lmdb_release/validation/
transforms:
- DecodeImage: # load image
img_mode: BGR
......@@ -97,5 +97,5 @@ Eval:
shuffle: False
drop_last: False
batch_size_per_card: 256
num_workers: 1
num_workers: 4
use_shared_memory: False
Global:
use_gpu: True
epoch_num: 10
log_smooth_window: 20
print_batch_step: 10
save_model_dir: ./output/rec/r45_abinet/
save_epoch_step: 1
# evaluation is run every 2000 iterations
eval_batch_step: [0, 2000]
cal_metric_during_train: True
pretrained_model:
checkpoints:
save_inference_dir:
use_visualdl: False
infer_img: doc/imgs_words_en/word_10.png
# for data or label process
character_dict_path:
character_type: en
max_text_length: 25
infer_mode: False
use_space_char: False
save_res_path: ./output/rec/predicts_abinet.txt
Optimizer:
name: Adam
beta1: 0.9
beta2: 0.99
clip_norm: 20.0
lr:
name: Piecewise
decay_epochs: [6]
values: [0.0001, 0.00001]
regularizer:
name: 'L2'
factor: 0.
Architecture:
model_type: rec
algorithm: ABINet
in_channels: 3
Transform:
Backbone:
name: ResNet45
Head:
name: ABINetHead
use_lang: True
iter_size: 3
Loss:
name: CELoss
ignore_index: &ignore_index 100 # Must be greater than the number of character classes
PostProcess:
name: ABINetLabelDecode
Metric:
name: RecMetric
main_indicator: acc
Train:
dataset:
name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/training/
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- ABINetLabelEncode: # Class handling label
ignore_index: *ignore_index
- ABINetRecResizeImg:
image_shape: [3, 32, 128]
padding: False
- KeepKeys:
keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
loader:
shuffle: True
batch_size_per_card: 96
drop_last: True
num_workers: 4
Eval:
dataset:
name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/validation/
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- ABINetLabelEncode: # Class handling label
ignore_index: *ignore_index
- ABINetRecResizeImg:
image_shape: [3, 32, 128]
padding: False
- KeepKeys:
keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
loader:
shuffle: False
drop_last: False
batch_size_per_card: 256
num_workers: 4
use_shared_memory: False
......@@ -26,7 +26,7 @@ Optimizer:
name: AdamW
beta1: 0.9
beta2: 0.99
epsilon: 0.00000008
epsilon: 8.e-8
weight_decay: 0.05
no_weight_decay_name: norm pos_embed
one_dim_param_no_weight_decay: true
......
......@@ -6,7 +6,7 @@ Global:
save_model_dir: ./output/rec/vitstr_none_ce/
save_epoch_step: 1
# evaluation is run every 2000 iterations after the 0th iteration#
eval_batch_step: [0, 50]
eval_batch_step: [0, 2000]
cal_metric_during_train: True
pretrained_model:
checkpoints:
......@@ -23,7 +23,7 @@ Global:
Optimizer:
name: Adadelta
epsilon: 0.00000001
epsilon: 1.e-8
rho: 0.95
clip_norm: 5.0
lr:
......@@ -45,8 +45,8 @@ Architecture:
Loss:
name: CELoss
smoothing: False
with_all: True
ignore_index: &ignore_index 0 # Must be zero or greater than the number of character classes
PostProcess:
name: ViTSTRLabelDecode
......@@ -58,12 +58,13 @@ Metric:
Train:
dataset:
name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/training
data_dir: ./train_data/data_lmdb_release/training/
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- ViTSTRLabelEncode: # Class handling label
ignore_index: *ignore_index
- GrayRecResizeImg:
image_shape: [224, 224] # W H
resize_type: PIL # PIL or OpenCV
......@@ -80,12 +81,13 @@ Train:
Eval:
dataset:
name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/validation
data_dir: ./train_data/data_lmdb_release/validation/
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- ViTSTRLabelEncode: # Class handling label
ignore_index: *ignore_index
- GrayRecResizeImg:
image_shape: [224, 224] # W H
resize_type: PIL # PIL or OpenCV
......
......@@ -67,6 +67,7 @@
- [x] [SEED](./algorithm_rec_seed.md)
- [x] [SVTR](./algorithm_rec_svtr.md)
- [x] [ViTSTR](./algorithm_rec_vitstr.md)
- [x] [ABINet](./algorithm_rec_abinet.md)
参考[DTRB](https://arxiv.org/abs/1904.01906)[3]文字识别训练和评估流程,使用MJSynth和SynthText两个文字识别数据集训练,在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上进行评估,算法效果如下:
......@@ -86,6 +87,7 @@
|SEED|Aster_Resnet| 85.35% | rec_resnet_stn_bilstm_att | [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_resnet_stn_bilstm_att.tar) |
|SVTR|SVTR-Tiny| 89.25% | rec_svtr_tiny_none_ctc_en | [训练模型](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/rec_svtr_tiny_none_ctc_en_train.tar) |
|ViTSTR|ViTSTR| 79.82% | rec_vitstr_none_ce_en | [训练模型](https://paddleocr.bj.bcebos.com/rec_vitstr_none_ce_train.tar) |
|ABINet|Resnet45| 90.75% | rec_r45_abinet_en | [训练模型](https://paddleocr.bj.bcebos.com/rec_r45_abinet_train.tar) |
<a name="2"></a>
......
# 场景文本识别算法-ABINet
- [1. 算法简介](#1)
- [2. 环境配置](#2)
- [3. 模型训练、评估、预测](#3)
- [3.1 训练](#3-1)
- [3.2 评估](#3-2)
- [3.3 预测](#3-3)
- [4. 推理部署](#4)
- [4.1 Python推理](#4-1)
- [4.2 C++推理](#4-2)
- [4.3 Serving服务化部署](#4-3)
- [4.4 更多推理部署](#4-4)
- [5. FAQ](#5)
<a name="1"></a>
## 1. 算法简介
论文信息:
> [ABINet: Read Like Humans: Autonomous, Bidirectional and Iterative Language Modeling for Scene Text Recognition](https://openaccess.thecvf.com/content/CVPR2021/papers/Fang_Read_Like_Humans_Autonomous_Bidirectional_and_Iterative_Language_Modeling_for_CVPR_2021_paper.pdf)
> Shancheng Fang and Hongtao Xie and Yuxin Wang and Zhendong Mao and Yongdong Zhang
> CVPR, 2021
<a name="model"></a>
`ABINet`使用MJSynth和SynthText两个文字识别数据集训练,在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上进行评估,算法复现效果如下:
|模型|骨干网络|配置文件|Acc|下载链接|
| --- | --- | --- | --- | --- |
|ABINet|ResNet45|[rec_r45_abinet.yml](../../configs/rec/rec_r45_abinet.yml)|90.75%|[训练模型]()/[预训练模型]|
<a name="2"></a>
## 2. 环境配置
请先参考[《运行环境准备》](./environment.md)配置PaddleOCR运行环境,参考[《项目克隆》](./clone.md)克隆项目代码。
<a name="3"></a>
## 3. 模型训练、评估、预测
<a name="3-1"></a>
### 3.1 模型训练
请参考[文本识别训练教程](./recognition.md)。PaddleOCR对代码进行了模块化,训练`ABINet`识别模型时需要**更换配置文件**`ABINet`[配置文件](../../configs/rec/rec_r45_abinet.yml)
#### 启动训练
具体地,在完成数据准备后,便可以启动训练,训练命令如下:
```shell
#单卡训练(训练周期长,不建议)
python3 tools/train.py -c configs/rec/rec_r45_abinet.yml
#多卡训练,通过--gpus参数指定卡号
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_r45_abinet.yml
```
<a name="3-2"></a>
### 3.2 评估
可下载已训练完成的[模型文件](#model),使用如下命令进行评估:
```shell
# 注意将pretrained_model的路径设置为本地路径。
python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/rec/rec_r45_abinet.yml -o Global.pretrained_model=./rec_r45_abinet_train/best_accuracy
```
<a name="3-3"></a>
### 3.3 预测
使用如下命令进行单张图片预测:
```shell
# 注意将pretrained_model的路径设置为本地路径。
python3 tools/infer_rec.py -c configs/rec/rec_r45_abinet.yml -o Global.infer_img='./doc/imgs_words_en/word_10.png' Global.pretrained_model=./rec_r45_abinet_train/best_accuracy
# 预测文件夹下所有图像时,可修改infer_img为文件夹,如 Global.infer_img='./doc/imgs_words_en/'。
```
<a name="4"></a>
## 4. 推理部署
<a name="4-1"></a>
### 4.1 Python推理
首先将训练得到best模型,转换成inference model。这里以训练完成的模型为例([模型下载地址]() ),可以使用如下命令进行转换:
```shell
# 注意将pretrained_model的路径设置为本地路径。
python3 tools/export_model.py -c configs/rec/rec_r45_abinet.yml -o Global.pretrained_model=./rec_r45_abinet_train/best_accuracy Global.save_inference_dir=./inference/rec_r45_abinet/
```
**注意:**
- 如果您是在自己的数据集上训练的模型,并且调整了字典文件,请注意修改配置文件中的`character_dict_path`是否是所需要的字典文件。
- 如果您修改了训练时的输入大小,请修改`tools/export_model.py`文件中的对应ABINet的`infer_shape`
转换成功后,在目录下有三个文件:
```
/inference/rec_r45_abinet/
├── inference.pdiparams # 识别inference模型的参数文件
├── inference.pdiparams.info # 识别inference模型的参数信息,可忽略
└── inference.pdmodel # 识别inference模型的program文件
```
执行如下命令进行模型推理:
```shell
python3 tools/infer/predict_rec.py --image_dir='./doc/imgs_words_en/word_10.png' --rec_model_dir='./inference/rec_r45_abinet/' --rec_algorithm='ABINet' --rec_image_shape='3,32,128' --rec_char_dict_path='./ppocr/utils/ic15_dict.txt'
# 预测文件夹下所有图像时,可修改image_dir为文件夹,如 --image_dir='./doc/imgs_words_en/'。
```
![](../imgs_words_en/word_10.png)
执行命令后,上面图像的预测结果(识别的文本和得分)会打印到屏幕上,示例如下:
结果如下:
```shell
Predicts of ./doc/imgs_words_en/word_10.png:('pain', 0.9999995231628418)
```
**注意**
- 训练上述模型采用的图像分辨率是[3,32,128],需要通过参数`rec_image_shape`设置为您训练时的识别图像形状。
- 在推理时需要设置参数`rec_char_dict_path`指定字典,如果您修改了字典,请修改该参数为您的字典文件。
- 如果您修改了预处理方法,需修改`tools/infer/predict_rec.py`中ABINet的预处理为您的预处理方法。
<a name="4-2"></a>
### 4.2 C++推理部署
由于C++预处理后处理还未支持ABINet,所以暂未支持
<a name="4-3"></a>
### 4.3 Serving服务化部署
暂不支持
<a name="4-4"></a>
### 4.4 更多推理部署
暂不支持
<a name="5"></a>
## 5. FAQ
1. MJSynth和SynthText两种数据集来自于[ABINet源repo](https://github.com/FangShancheng/ABINet)
2. 我们使用ABINet作者提供的预训练模型进行finetune训练。
## 引用
```bibtex
@article{Fang2021ABINet,
title = {ABINet: Read Like Humans: Autonomous, Bidirectional and Iterative Language Modeling for Scene Text Recognition},
author = {Shancheng Fang and Hongtao Xie and Yuxin Wang and Zhendong Mao and Yongdong Zhang},
booktitle = {CVPR},
year = {2021},
url = {https://arxiv.org/abs/2103.06495},
pages = {7098-7107}
}
```
......@@ -12,6 +12,7 @@
- [4.3 Serving服务化部署](#4-3)
- [4.4 更多推理部署](#4-4)
- [5. FAQ](#5)
- [6. 发行公告](#6)
<a name="1"></a>
## 1. 算法简介
......@@ -110,7 +111,7 @@ python3 tools/infer/predict_rec.py --image_dir='./doc/imgs_words_en/word_10.png'
执行命令后,上面图像的预测结果(识别的文本和得分)会打印到屏幕上,示例如下:
结果如下:
```shell
Predicts of ./doc/imgs_words_en/word_10.png:('pain', 0.9265879392623901)
Predicts of ./doc/imgs_words_en/word_10.png:('pain', 0.9465042352676392)
```
**注意**
......@@ -140,12 +141,147 @@ Predicts of ./doc/imgs_words_en/word_10.png:('pain', 0.9265879392623901)
1. `NRTR`论文中使用Beam搜索进行解码字符,但是速度较慢,这里默认未使用Beam搜索,以贪婪搜索进行解码字符。
<a name="6"></a>
## 6. 发行公告
1. release/2.6更新NRTR代码结构,新版NRTR可加载旧版(release/2.5及之前)模型参数,使用下面示例代码将旧版模型参数转换为新版模型参数:
```python
params = paddle.load('path/' + '.pdparams') # 旧版本参数
state_dict = model.state_dict() # 新版模型参数
new_state_dict = {}
for k1, v1 in state_dict.items():
k = k1
if 'encoder' in k and 'self_attn' in k and 'qkv' in k and 'weight' in k:
k_para = k[:13] + 'layers.' + k[13:]
q = params[k_para.replace('qkv', 'conv1')].transpose((1, 0, 2, 3))
k = params[k_para.replace('qkv', 'conv2')].transpose((1, 0, 2, 3))
v = params[k_para.replace('qkv', 'conv3')].transpose((1, 0, 2, 3))
new_state_dict[k1] = np.concatenate([q[:, :, 0, 0], k[:, :, 0, 0], v[:, :, 0, 0]], -1)
elif 'encoder' in k and 'self_attn' in k and 'qkv' in k and 'bias' in k:
k_para = k[:13] + 'layers.' + k[13:]
q = params[k_para.replace('qkv', 'conv1')]
k = params[k_para.replace('qkv', 'conv2')]
v = params[k_para.replace('qkv', 'conv3')]
new_state_dict[k1] = np.concatenate([q, k, v], -1)
elif 'encoder' in k and 'self_attn' in k and 'out_proj' in k:
k_para = k[:13] + 'layers.' + k[13:]
new_state_dict[k1] = params[k_para]
elif 'encoder' in k and 'norm3' in k:
k_para = k[:13] + 'layers.' + k[13:]
new_state_dict[k1] = params[k_para.replace('norm3', 'norm2')]
elif 'encoder' in k and 'norm1' in k:
k_para = k[:13] + 'layers.' + k[13:]
new_state_dict[k1] = params[k_para]
elif 'decoder' in k and 'self_attn' in k and 'qkv' in k and 'weight' in k:
k_para = k[:13] + 'layers.' + k[13:]
q = params[k_para.replace('qkv', 'conv1')].transpose((1, 0, 2, 3))
k = params[k_para.replace('qkv', 'conv2')].transpose((1, 0, 2, 3))
v = params[k_para.replace('qkv', 'conv3')].transpose((1, 0, 2, 3))
new_state_dict[k1] = np.concatenate([q[:, :, 0, 0], k[:, :, 0, 0], v[:, :, 0, 0]], -1)
elif 'decoder' in k and 'self_attn' in k and 'qkv' in k and 'bias' in k:
k_para = k[:13] + 'layers.' + k[13:]
q = params[k_para.replace('qkv', 'conv1')]
k = params[k_para.replace('qkv', 'conv2')]
v = params[k_para.replace('qkv', 'conv3')]
new_state_dict[k1] = np.concatenate([q, k, v], -1)
elif 'decoder' in k and 'self_attn' in k and 'out_proj' in k:
k_para = k[:13] + 'layers.' + k[13:]
new_state_dict[k1] = params[k_para]
elif 'decoder' in k and 'cross_attn' in k and 'q' in k and 'weight' in k:
k_para = k[:13] + 'layers.' + k[13:]
k_para = k_para.replace('cross_attn', 'multihead_attn')
q = params[k_para.replace('q', 'conv1')].transpose((1, 0, 2, 3))
new_state_dict[k1] = q[:, :, 0, 0]
elif 'decoder' in k and 'cross_attn' in k and 'q' in k and 'bias' in k:
k_para = k[:13] + 'layers.' + k[13:]
k_para = k_para.replace('cross_attn', 'multihead_attn')
q = params[k_para.replace('q', 'conv1')]
new_state_dict[k1] = q
elif 'decoder' in k and 'cross_attn' in k and 'kv' in k and 'weight' in k:
k_para = k[:13] + 'layers.' + k[13:]
k_para = k_para.replace('cross_attn', 'multihead_attn')
k = params[k_para.replace('kv', 'conv2')].transpose((1, 0, 2, 3))
v = params[k_para.replace('kv', 'conv3')].transpose((1, 0, 2, 3))
new_state_dict[k1] = np.concatenate([k[:, :, 0, 0], v[:, :, 0, 0]], -1)
elif 'decoder' in k and 'cross_attn' in k and 'kv' in k and 'bias' in k:
k_para = k[:13] + 'layers.' + k[13:]
k_para = k_para.replace('cross_attn', 'multihead_attn')
k = params[k_para.replace('kv', 'conv2')]
v = params[k_para.replace('kv', 'conv3')]
new_state_dict[k1] = np.concatenate([k, v], -1)
elif 'decoder' in k and 'cross_attn' in k and 'out_proj' in k:
k_para = k[:13] + 'layers.' + k[13:]
k_para = k_para.replace('cross_attn', 'multihead_attn')
new_state_dict[k1] = params[k_para]
elif 'decoder' in k and 'norm' in k:
k_para = k[:13] + 'layers.' + k[13:]
new_state_dict[k1] = params[k_para]
elif 'mlp' in k and 'weight' in k:
k_para = k[:13] + 'layers.' + k[13:]
k_para = k_para.replace('fc', 'conv')
k_para = k_para.replace('mlp.', '')
w = params[k_para].transpose((1, 0, 2, 3))
new_state_dict[k1] = w[:, :, 0, 0]
elif 'mlp' in k and 'bias' in k:
k_para = k[:13] + 'layers.' + k[13:]
k_para = k_para.replace('fc', 'conv')
k_para = k_para.replace('mlp.', '')
w = params[k_para]
new_state_dict[k1] = w
else:
new_state_dict[k1] = params[k1]
if list(new_state_dict[k1].shape) != list(v1.shape):
print(k1)
for k, v1 in state_dict.items():
if k not in new_state_dict.keys():
print(1, k)
elif list(new_state_dict[k].shape) != list(v1.shape):
print(2, k)
model.set_state_dict(new_state_dict)
paddle.save(model.state_dict(), 'nrtrnew_from_old_params.pdparams')
```
2. 新版相比与旧版,代码结构简洁,推理速度有所提高。
## 引用
```bibtex
@article{Sheng2019NRTR,
title = {NRTR: A No-Recurrence Sequence-to-Sequence Model For Scene Text Recognition},
author = {Fenfen Sheng and Zhineng Chen andBo Xu},
author = {Fenfen Sheng and Zhineng Chen and Bo Xu},
booktitle = {ICDAR},
year = {2019},
url = {http://arxiv.org/abs/1806.00926},
......
......@@ -66,6 +66,7 @@ Supported text recognition algorithms (Click the link to get the tutorial):
- [x] [SEED](./algorithm_rec_seed_en.md)
- [x] [SVTR](./algorithm_rec_svtr_en.md)
- [x] [ViTSTR](./algorithm_rec_vitstr_en.md)
- [x] [ABINet](./algorithm_rec_abinet_en.md)
Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation result of these above text recognition (using MJSynth and SynthText for training, evaluate on IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE) is as follow:
......@@ -85,6 +86,7 @@ Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation r
|SEED|Aster_Resnet| 85.35% | rec_resnet_stn_bilstm_att | [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_resnet_stn_bilstm_att.tar) |
|SVTR|SVTR-Tiny| 89.25% | rec_svtr_tiny_none_ctc_en | [trained model](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/rec_svtr_tiny_none_ctc_en_train.tar) |
|ViTSTR|ViTSTR| 79.82% | rec_vitstr_none_ce_en | [trained model](https://paddleocr.bj.bcebos.com/rec_vitstr_none_none_train.tar) |
|ABINet|Resnet45| 90.75% | rec_r45_abinet_en | [trained model](https://paddleocr.bj.bcebos.com/rec_r45_abinet_train.tar) |
<a name="2"></a>
......
# ABINet
- [1. Introduction](#1)
- [2. Environment](#2)
- [3. Model Training / Evaluation / Prediction](#3)
- [3.1 Training](#3-1)
- [3.2 Evaluation](#3-2)
- [3.3 Prediction](#3-3)
- [4. Inference and Deployment](#4)
- [4.1 Python Inference](#4-1)
- [4.2 C++ Inference](#4-2)
- [4.3 Serving](#4-3)
- [4.4 More](#4-4)
- [5. FAQ](#5)
<a name="1"></a>
## 1. Introduction
Paper:
> [ABINet: Read Like Humans: Autonomous, Bidirectional and Iterative Language Modeling for Scene Text Recognition](https://openaccess.thecvf.com/content/CVPR2021/papers/Fang_Read_Like_Humans_Autonomous_Bidirectional_and_Iterative_Language_Modeling_for_CVPR_2021_paper.pdf)
> Shancheng Fang and Hongtao Xie and Yuxin Wang and Zhendong Mao and Yongdong Zhang
> CVPR, 2021
Using MJSynth and SynthText two text recognition datasets for training, and evaluating on IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE datasets, the algorithm reproduction effect is as follows:
|Model|Backbone|config|Acc|Download link|
| --- | --- | --- | --- | --- |
|ABINet|ResNet45|[rec_r45_abinet.yml](../../configs/rec/rec_r45_abinet.yml)|90.75%|[trained model]()/[pretrained model]()|
<a name="2"></a>
## 2. Environment
Please refer to ["Environment Preparation"](./environment_en.md) to configure the PaddleOCR environment, and refer to ["Project Clone"](./clone_en.md) to clone the project code.
<a name="3"></a>
## 3. Model Training / Evaluation / Prediction
Please refer to [Text Recognition Tutorial](./recognition_en.md). PaddleOCR modularizes the code, and training different recognition models only requires **changing the configuration file**.
Training:
Specifically, after the data preparation is completed, the training can be started. The training command is as follows:
```
#Single GPU training (long training period, not recommended)
python3 tools/train.py -c configs/rec/rec_r45_abinet.yml
#Multi GPU training, specify the gpu number through the --gpus parameter
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_r45_abinet.yml
```
Evaluation:
```
# GPU evaluation
python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/rec/rec_r45_abinet.yml -o Global.pretrained_model={path/to/weights}/best_accuracy
```
Prediction:
```
# The configuration file used for prediction must match the training
python3 tools/infer_rec.py -c configs/rec/rec_r45_abinet.yml -o Global.infer_img='./doc/imgs_words_en/word_10.png' Global.pretrained_model=./rec_r45_abinet_train/best_accuracy
```
<a name="4"></a>
## 4. Inference and Deployment
<a name="4-1"></a>
### 4.1 Python Inference
First, the model saved during the ABINet text recognition training process is converted into an inference model. ( [Model download link]()) ), you can use the following command to convert:
```
python3 tools/export_model.py -c configs/rec/rec_r45_abinet.yml -o Global.pretrained_model=./rec_r45_abinet_train/best_accuracy Global.save_inference_dir=./inference/rec_r45_abinet
```
**Note:**
- If you are training the model on your own dataset and have modified the dictionary file, please pay attention to modify the `character_dict_path` in the configuration file to the modified dictionary file.
- If you modified the input size during training, please modify the `infer_shape` corresponding to ABINet in the `tools/export_model.py` file.
After the conversion is successful, there are three files in the directory:
```
/inference/rec_r45_abinet/
├── inference.pdiparams
├── inference.pdiparams.info
└── inference.pdmodel
```
For ABINet text recognition model inference, the following commands can be executed:
```
python3 tools/infer/predict_rec.py --image_dir='./doc/imgs_words_en/word_10.png' --rec_model_dir='./inference/rec_r45_abinet/' --rec_algorithm='ABINet' --rec_image_shape='3,32,128' --rec_char_dict_path='./ppocr/utils/ic15_dict.txt'
```
![](../imgs_words_en/word_10.png)
After executing the command, the prediction result (recognized text and score) of the image above is printed to the screen, an example is as follows:
The result is as follows:
```shell
Predicts of ./doc/imgs_words_en/word_10.png:('pain', 0.9999995231628418)
```
<a name="4-2"></a>
### 4.2 C++ Inference
Not supported
<a name="4-3"></a>
### 4.3 Serving
Not supported
<a name="4-4"></a>
### 4.4 More
Not supported
<a name="5"></a>
## 5. FAQ
1. Note that the MJSynth and SynthText datasets come from [ABINet repo](https://github.com/FangShancheng/ABINet).
2. We use the pre-trained model provided by the ABINet authors for finetune training.
## Citation
```bibtex
@article{Fang2021ABINet,
title = {ABINet: Read Like Humans: Autonomous, Bidirectional and Iterative Language Modeling for Scene Text Recognition},
author = {Shancheng Fang and Hongtao Xie and Yuxin Wang and Zhendong Mao and Yongdong Zhang},
booktitle = {CVPR},
year = {2021},
url = {https://arxiv.org/abs/2103.06495},
pages = {7098-7107}
}
```
......@@ -12,6 +12,7 @@
- [4.3 Serving](#4-3)
- [4.4 More](#4-4)
- [5. FAQ](#5)
- [6. Release Note](#6)
<a name="1"></a>
## 1. Introduction
......@@ -25,7 +26,7 @@ Using MJSynth and SynthText two text recognition datasets for training, and eval
|Model|Backbone|config|Acc|Download link|
| --- | --- | --- | --- | --- |
|NRTR|MTB|[rec_mtb_nrtr.yml](../../configs/rec/rec_mtb_nrtr.yml)|84.21%|[train model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mtb_nrtr_train.tar)|
|NRTR|MTB|[rec_mtb_nrtr.yml](../../configs/rec/rec_mtb_nrtr.yml)|84.21%|[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mtb_nrtr_train.tar)|
<a name="2"></a>
## 2. Environment
......@@ -98,7 +99,7 @@ python3 tools/infer/predict_rec.py --image_dir='./doc/imgs_words_en/word_10.png'
After executing the command, the prediction result (recognized text and score) of the image above is printed to the screen, an example is as follows:
The result is as follows:
```shell
Predicts of ./doc/imgs_words_en/word_10.png:('pain', 0.9265879392623901)
Predicts of ./doc/imgs_words_en/word_10.png:('pain', 0.9465042352676392)
```
<a name="4-2"></a>
......@@ -121,12 +122,146 @@ Not supported
1. In the `NRTR` paper, Beam search is used to decode characters, but the speed is slow. Beam search is not used by default here, and greedy search is used to decode characters.
<a name="6"></a>
## 6. Release Note
1. The release/2.6 version updates the NRTR code structure. The new version of NRTR can load the model parameters of the old version (release/2.5 and before), and you may use the following code to convert the old version model parameters to the new version model parameters:
```python
params = paddle.load('path/' + '.pdparams') # the old version parameters
state_dict = model.state_dict() # the new version model parameters
new_state_dict = {}
for k1, v1 in state_dict.items():
k = k1
if 'encoder' in k and 'self_attn' in k and 'qkv' in k and 'weight' in k:
k_para = k[:13] + 'layers.' + k[13:]
q = params[k_para.replace('qkv', 'conv1')].transpose((1, 0, 2, 3))
k = params[k_para.replace('qkv', 'conv2')].transpose((1, 0, 2, 3))
v = params[k_para.replace('qkv', 'conv3')].transpose((1, 0, 2, 3))
new_state_dict[k1] = np.concatenate([q[:, :, 0, 0], k[:, :, 0, 0], v[:, :, 0, 0]], -1)
elif 'encoder' in k and 'self_attn' in k and 'qkv' in k and 'bias' in k:
k_para = k[:13] + 'layers.' + k[13:]
q = params[k_para.replace('qkv', 'conv1')]
k = params[k_para.replace('qkv', 'conv2')]
v = params[k_para.replace('qkv', 'conv3')]
new_state_dict[k1] = np.concatenate([q, k, v], -1)
elif 'encoder' in k and 'self_attn' in k and 'out_proj' in k:
k_para = k[:13] + 'layers.' + k[13:]
new_state_dict[k1] = params[k_para]
elif 'encoder' in k and 'norm3' in k:
k_para = k[:13] + 'layers.' + k[13:]
new_state_dict[k1] = params[k_para.replace('norm3', 'norm2')]
elif 'encoder' in k and 'norm1' in k:
k_para = k[:13] + 'layers.' + k[13:]
new_state_dict[k1] = params[k_para]
elif 'decoder' in k and 'self_attn' in k and 'qkv' in k and 'weight' in k:
k_para = k[:13] + 'layers.' + k[13:]
q = params[k_para.replace('qkv', 'conv1')].transpose((1, 0, 2, 3))
k = params[k_para.replace('qkv', 'conv2')].transpose((1, 0, 2, 3))
v = params[k_para.replace('qkv', 'conv3')].transpose((1, 0, 2, 3))
new_state_dict[k1] = np.concatenate([q[:, :, 0, 0], k[:, :, 0, 0], v[:, :, 0, 0]], -1)
elif 'decoder' in k and 'self_attn' in k and 'qkv' in k and 'bias' in k:
k_para = k[:13] + 'layers.' + k[13:]
q = params[k_para.replace('qkv', 'conv1')]
k = params[k_para.replace('qkv', 'conv2')]
v = params[k_para.replace('qkv', 'conv3')]
new_state_dict[k1] = np.concatenate([q, k, v], -1)
elif 'decoder' in k and 'self_attn' in k and 'out_proj' in k:
k_para = k[:13] + 'layers.' + k[13:]
new_state_dict[k1] = params[k_para]
elif 'decoder' in k and 'cross_attn' in k and 'q' in k and 'weight' in k:
k_para = k[:13] + 'layers.' + k[13:]
k_para = k_para.replace('cross_attn', 'multihead_attn')
q = params[k_para.replace('q', 'conv1')].transpose((1, 0, 2, 3))
new_state_dict[k1] = q[:, :, 0, 0]
elif 'decoder' in k and 'cross_attn' in k and 'q' in k and 'bias' in k:
k_para = k[:13] + 'layers.' + k[13:]
k_para = k_para.replace('cross_attn', 'multihead_attn')
q = params[k_para.replace('q', 'conv1')]
new_state_dict[k1] = q
elif 'decoder' in k and 'cross_attn' in k and 'kv' in k and 'weight' in k:
k_para = k[:13] + 'layers.' + k[13:]
k_para = k_para.replace('cross_attn', 'multihead_attn')
k = params[k_para.replace('kv', 'conv2')].transpose((1, 0, 2, 3))
v = params[k_para.replace('kv', 'conv3')].transpose((1, 0, 2, 3))
new_state_dict[k1] = np.concatenate([k[:, :, 0, 0], v[:, :, 0, 0]], -1)
elif 'decoder' in k and 'cross_attn' in k and 'kv' in k and 'bias' in k:
k_para = k[:13] + 'layers.' + k[13:]
k_para = k_para.replace('cross_attn', 'multihead_attn')
k = params[k_para.replace('kv', 'conv2')]
v = params[k_para.replace('kv', 'conv3')]
new_state_dict[k1] = np.concatenate([k, v], -1)
elif 'decoder' in k and 'cross_attn' in k and 'out_proj' in k:
k_para = k[:13] + 'layers.' + k[13:]
k_para = k_para.replace('cross_attn', 'multihead_attn')
new_state_dict[k1] = params[k_para]
elif 'decoder' in k and 'norm' in k:
k_para = k[:13] + 'layers.' + k[13:]
new_state_dict[k1] = params[k_para]
elif 'mlp' in k and 'weight' in k:
k_para = k[:13] + 'layers.' + k[13:]
k_para = k_para.replace('fc', 'conv')
k_para = k_para.replace('mlp.', '')
w = params[k_para].transpose((1, 0, 2, 3))
new_state_dict[k1] = w[:, :, 0, 0]
elif 'mlp' in k and 'bias' in k:
k_para = k[:13] + 'layers.' + k[13:]
k_para = k_para.replace('fc', 'conv')
k_para = k_para.replace('mlp.', '')
w = params[k_para]
new_state_dict[k1] = w
else:
new_state_dict[k1] = params[k1]
if list(new_state_dict[k1].shape) != list(v1.shape):
print(k1)
for k, v1 in state_dict.items():
if k not in new_state_dict.keys():
print(1, k)
elif list(new_state_dict[k].shape) != list(v1.shape):
print(2, k)
model.set_state_dict(new_state_dict)
paddle.save(model.state_dict(), 'nrtrnew_from_old_params.pdparams')
```
2. The new version has a clean code structure and improved inference speed compared with the old version.
## Citation
```bibtex
@article{Sheng2019NRTR,
title = {NRTR: A No-Recurrence Sequence-to-Sequence Model For Scene Text Recognition},
author = {Fenfen Sheng and Zhineng Chen andBo Xu},
author = {Fenfen Sheng and Zhineng Chen and Bo Xu},
booktitle = {ICDAR},
year = {2019},
url = {http://arxiv.org/abs/1806.00926},
......
......@@ -25,7 +25,7 @@ Using MJSynth and SynthText two text recognition datasets for training, and eval
|Model|Backbone|config|Acc|Download link|
| --- | --- | --- | --- | --- |
|ViTSTR|ViTSTR|[rec_vitstr_none_ce.yml](../../configs/rec/rec_vitstr_none_ce.yml)|79.82%|[训练模型](https://paddleocr.bj.bcebos.com/rec_vitstr_none_none_train.tar)|
|ViTSTR|ViTSTR|[rec_vitstr_none_ce.yml](../../configs/rec/rec_vitstr_none_ce.yml)|79.82%|[trained model](https://paddleocr.bj.bcebos.com/rec_vitstr_none_none_train.tar)|
<a name="2"></a>
## 2. Environment
......
......@@ -24,7 +24,7 @@ from .make_pse_gt import MakePseGt
from .rec_img_aug import RecAug, RecConAug, RecResizeImg, ClsResizeImg, \
SRNRecResizeImg, GrayRecResizeImg, SARRecResizeImg, PRENResizeImg
SRNRecResizeImg, GrayRecResizeImg, SARRecResizeImg, PRENResizeImg, ABINetRecResizeImg
from .ssl_img_aug import SSLRotateResize
from .randaugment import RandAugment
from .copy_paste import CopyPaste
......
......@@ -157,37 +157,6 @@ class BaseRecLabelEncode(object):
return text_list
class NRTRLabelEncode(BaseRecLabelEncode):
""" Convert between text-label and text-index """
def __init__(self,
max_text_length,
character_dict_path=None,
use_space_char=False,
**kwargs):
super(NRTRLabelEncode, self).__init__(
max_text_length, character_dict_path, use_space_char)
def __call__(self, data):
text = data['label']
text = self.encode(text)
if text is None:
return None
if len(text) >= self.max_text_len - 1:
return None
data['length'] = np.array(len(text))
text.insert(0, 2)
text.append(3)
text = text + [0] * (self.max_text_len - len(text))
data['label'] = np.array(text)
return data
def add_special_char(self, dict_character):
dict_character = ['blank', '<unk>', '<s>', '</s>'] + dict_character
return dict_character
class CTCLabelEncode(BaseRecLabelEncode):
""" Convert between text-label and text-index """
......@@ -840,37 +809,6 @@ class PRENLabelEncode(BaseRecLabelEncode):
return data
class ViTSTRLabelEncode(BaseRecLabelEncode):
""" Convert between text-label and text-index """
def __init__(self,
max_text_length,
character_dict_path=None,
use_space_char=False,
**kwargs):
super(ViTSTRLabelEncode, self).__init__(
max_text_length, character_dict_path, use_space_char)
def __call__(self, data):
text = data['label']
text = self.encode(text)
if text is None:
return None
if len(text) >= self.max_text_len:
return None
data['length'] = np.array(len(text))
text.insert(0, 0)
text.append(1)
text = text + [0] * (self.max_text_len + 2 - len(text))
data['label'] = np.array(text)
return data
def add_special_char(self, dict_character):
dict_character = ['<s>', '</s>'] + dict_character
return dict_character
class VQATokenLabelEncode(object):
"""
Label encode for NLP VQA methods
......@@ -1077,3 +1015,99 @@ class MultiLabelEncode(BaseRecLabelEncode):
data_out['label_sar'] = sar['label']
data_out['length'] = ctc['length']
return data_out
class NRTRLabelEncode(BaseRecLabelEncode):
""" Convert between text-label and text-index """
def __init__(self,
max_text_length,
character_dict_path=None,
use_space_char=False,
**kwargs):
super(NRTRLabelEncode, self).__init__(
max_text_length, character_dict_path, use_space_char)
def __call__(self, data):
text = data['label']
text = self.encode(text)
if text is None:
return None
if len(text) >= self.max_text_len - 1:
return None
data['length'] = np.array(len(text))
text.insert(0, 2)
text.append(3)
text = text + [0] * (self.max_text_len - len(text))
data['label'] = np.array(text)
return data
def add_special_char(self, dict_character):
dict_character = ['blank', '<unk>', '<s>', '</s>'] + dict_character
return dict_character
class ViTSTRLabelEncode(BaseRecLabelEncode):
""" Convert between text-label and text-index """
def __init__(self,
max_text_length,
character_dict_path=None,
use_space_char=False,
ignore_index=0,
**kwargs):
super(ViTSTRLabelEncode, self).__init__(
max_text_length, character_dict_path, use_space_char)
self.ignore_index = ignore_index
def __call__(self, data):
text = data['label']
text = self.encode(text)
if text is None:
return None
if len(text) >= self.max_text_len:
return None
data['length'] = np.array(len(text))
text.insert(0, self.ignore_index)
text.append(1)
text = text + [self.ignore_index] * (self.max_text_len + 2 - len(text))
data['label'] = np.array(text)
return data
def add_special_char(self, dict_character):
dict_character = ['<s>', '</s>'] + dict_character
return dict_character
class ABINetLabelEncode(BaseRecLabelEncode):
""" Convert between text-label and text-index """
def __init__(self,
max_text_length,
character_dict_path=None,
use_space_char=False,
ignore_index=100,
**kwargs):
super(ABINetLabelEncode, self).__init__(
max_text_length, character_dict_path, use_space_char)
self.ignore_index = ignore_index
def __call__(self, data):
text = data['label']
text = self.encode(text)
if text is None:
return None
if len(text) >= self.max_text_len:
return None
data['length'] = np.array(len(text))
text.append(0)
text = text + [self.ignore_index] * (self.max_text_len + 1 - len(text))
data['label'] = np.array(text)
return data
def add_special_char(self, dict_character):
dict_character = ['</s>'] + dict_character
return dict_character
......@@ -67,39 +67,6 @@ class DecodeImage(object):
return data
class NRTRDecodeImage(object):
""" decode image """
def __init__(self, img_mode='RGB', channel_first=False, **kwargs):
self.img_mode = img_mode
self.channel_first = channel_first
def __call__(self, data):
img = data['image']
if six.PY2:
assert type(img) is str and len(
img) > 0, "invalid input 'img' in DecodeImage"
else:
assert type(img) is bytes and len(
img) > 0, "invalid input 'img' in DecodeImage"
img = np.frombuffer(img, dtype='uint8')
img = cv2.imdecode(img, 1)
if img is None:
return None
if self.img_mode == 'GRAY':
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
elif self.img_mode == 'RGB':
assert img.shape[2] == 3, 'invalid shape of image[%s]' % (img.shape)
img = img[:, :, ::-1]
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
if self.channel_first:
img = img.transpose((2, 0, 1))
data['image'] = img
return data
class NormalizeImage(object):
""" normalize image such as substract mean, divide std
"""
......
......@@ -279,6 +279,24 @@ class PRENResizeImg(object):
return data
class ABINetRecResizeImg(object):
def __init__(self,
image_shape,
infer_mode=False,
character_dict_path=None,
**kwargs):
self.image_shape = image_shape
self.infer_mode = infer_mode
self.character_dict_path = character_dict_path
def __call__(self, data):
img = data['image']
norm_img, valid_ratio = resize_norm_img_abinet(img, self.image_shape)
data['image'] = norm_img
data['valid_ratio'] = valid_ratio
return data
def resize_norm_img_sar(img, image_shape, width_downsample_ratio=0.25):
imgC, imgH, imgW_min, imgW_max = image_shape
h = img.shape[0]
......@@ -397,6 +415,26 @@ def resize_norm_img_srn(img, image_shape):
return np.reshape(img_black, (c, row, col)).astype(np.float32)
def resize_norm_img_abinet(img, image_shape):
imgC, imgH, imgW = image_shape
resized_image = cv2.resize(
img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
resized_w = imgW
resized_image = resized_image.astype('float32')
resized_image = resized_image / 255.
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
resized_image = (
resized_image - mean[None, None, ...]) / std[None, None, ...]
resized_image = resized_image.transpose((2, 0, 1))
resized_image = resized_image.astype('float32')
valid_ratio = min(1.0, float(resized_w / imgW))
return resized_image, valid_ratio
def srn_other_inputs(image_shape, num_heads, max_text_length):
imgC, imgH, imgW = image_shape
......
......@@ -4,25 +4,57 @@ import paddle.nn.functional as F
class CELoss(nn.Layer):
def __init__(self, smoothing=True, with_all=False, **kwargs):
def __init__(self,
smoothing=False,
with_all=False,
ignore_index=-1,
**kwargs):
super(CELoss, self).__init__()
self.loss_func = nn.CrossEntropyLoss(reduction='mean', ignore_index=0)
if ignore_index >= 0:
self.loss_func = nn.CrossEntropyLoss(
reduction='mean', ignore_index=ignore_index)
else:
self.loss_func = nn.CrossEntropyLoss(reduction='mean')
self.smoothing = smoothing
self.with_all = with_all
def forward(self, pred, batch):
pred = pred.reshape([-1, pred.shape[2]])
if self.with_all:
tgt = batch[1]
if isinstance(pred, dict): # for ABINet
loss = {}
loss_sum = []
for name, logits in pred.items():
if isinstance(logits, list):
logit_num = len(logits)
all_tgt = paddle.concat([batch[1]] * logit_num, 0)
all_logits = paddle.concat(logits, 0)
flt_logtis = all_logits.reshape([-1, all_logits.shape[2]])
flt_tgt = all_tgt.reshape([-1])
else:
flt_logtis = logits.reshape([-1, logits.shape[2]])
flt_tgt = batch[1].reshape([-1])
loss[name + '_loss'] = self.loss_func(flt_logtis, flt_tgt)
loss_sum.append(loss[name + '_loss'])
loss['loss'] = sum(loss_sum)
return loss
else:
if self.with_all: # for ViTSTR
tgt = batch[1]
pred = pred.reshape([-1, pred.shape[2]])
tgt = tgt.reshape([-1])
loss = self.loss_func(pred, tgt)
return {'loss': loss}
else: # for NRTR
max_len = batch[2].max()
tgt = batch[1][:, 1:2 + max_len]
pred = pred.reshape([-1, pred.shape[2]])
tgt = tgt.reshape([-1])
if self.smoothing:
eps = 0.1
n_class = pred.shape[1]
one_hot = F.one_hot(tgt, pred.shape[1])
one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (
n_class - 1)
log_prb = F.log_softmax(pred, axis=1)
non_pad_mask = paddle.not_equal(
tgt, paddle.zeros(
......
......@@ -27,7 +27,7 @@ def build_backbone(config, model_type):
from .rec_resnet_fpn import ResNetFPN
from .rec_mv1_enhance import MobileNetV1Enhance
from .rec_nrtr_mtb import MTB
from .rec_resnet_31 import ResNet31
from .rec_resnet import ResNet31, ResNet45
from .rec_resnet_aster import ResNet_ASTER
from .rec_micronet import MicroNet
from .rec_efficientb3_pren import EfficientNetb3_PREN
......@@ -35,29 +35,29 @@ def build_backbone(config, model_type):
from .rec_vitstr import ViTSTR
support_dict = [
'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB',
"ResNet31", "ResNet_ASTER", 'MicroNet', 'EfficientNetb3_PREN',
'SVTRNet', 'ViTSTR'
'ResNet31', 'ResNet45', 'ResNet_ASTER', 'MicroNet',
'EfficientNetb3_PREN', 'SVTRNet', 'ViTSTR'
]
elif model_type == "e2e":
elif model_type == 'e2e':
from .e2e_resnet_vd_pg import ResNet
support_dict = ['ResNet']
elif model_type == 'kie':
from .kie_unet_sdmgr import Kie_backbone
support_dict = ['Kie_backbone']
elif model_type == "table":
elif model_type == 'table':
from .table_resnet_vd import ResNet
from .table_mobilenet_v3 import MobileNetV3
support_dict = ["ResNet", "MobileNetV3"]
support_dict = ['ResNet', 'MobileNetV3']
elif model_type == 'vqa':
from .vqa_layoutlm import LayoutLMForSer, LayoutLMv2ForSer, LayoutLMv2ForRe, LayoutXLMForSer, LayoutXLMForRe
support_dict = [
"LayoutLMForSer", "LayoutLMv2ForSer", 'LayoutLMv2ForRe',
"LayoutXLMForSer", 'LayoutXLMForRe'
'LayoutLMForSer', 'LayoutLMv2ForSer', 'LayoutLMv2ForRe',
'LayoutXLMForSer', 'LayoutXLMForRe'
]
else:
raise NotImplementedError
module_name = config.pop("name")
module_name = config.pop('name')
assert module_name in support_dict, Exception(
"when model typs is {}, backbone only support {}".format(model_type,
support_dict))
......
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textrecog/layers/conv_layer.py
https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textrecog/backbones/resnet31_ocr.py
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
from paddle import ParamAttr
from paddle.nn.initializer import KaimingNormal
import paddle.nn as nn
import paddle.nn.functional as F
import numpy as np
import math
__all__ = ["ResNet31", "ResNet45"]
def conv1x1(in_planes, out_planes, stride=1):
return nn.Conv2D(
in_planes,
out_planes,
kernel_size=1,
stride=stride,
weight_attr=ParamAttr(initializer=KaimingNormal()),
bias_attr=False)
def conv3x3(in_channel, out_channel, stride=1):
return nn.Conv2D(
in_channel,
out_channel,
kernel_size=3,
stride=stride,
padding=1,
weight_attr=ParamAttr(initializer=KaimingNormal()),
bias_attr=False)
class BasicBlock(nn.Layer):
expansion = 1
def __init__(self, in_channels, channels, stride=1, downsample=None):
super().__init__()
self.conv1 = conv1x1(in_channels, channels)
self.bn1 = nn.BatchNorm2D(channels)
self.relu = nn.ReLU()
self.conv2 = conv3x3(channels, channels, stride)
self.bn2 = nn.BatchNorm2D(channels)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
class ResNet31(nn.Layer):
'''
Args:
in_channels (int): Number of channels of input image tensor.
layers (list[int]): List of BasicBlock number for each stage.
channels (list[int]): List of out_channels of Conv2d layer.
out_indices (None | Sequence[int]): Indices of output stages.
last_stage_pool (bool): If True, add `MaxPool2d` layer to last stage.
'''
def __init__(self,
in_channels=3,
layers=[1, 2, 5, 3],
channels=[64, 128, 256, 256, 512, 512, 512],
out_indices=None,
last_stage_pool=False):
super(ResNet31, self).__init__()
assert isinstance(in_channels, int)
assert isinstance(last_stage_pool, bool)
self.out_indices = out_indices
self.last_stage_pool = last_stage_pool
# conv 1 (Conv Conv)
self.conv1_1 = nn.Conv2D(
in_channels, channels[0], kernel_size=3, stride=1, padding=1)
self.bn1_1 = nn.BatchNorm2D(channels[0])
self.relu1_1 = nn.ReLU()
self.conv1_2 = nn.Conv2D(
channels[0], channels[1], kernel_size=3, stride=1, padding=1)
self.bn1_2 = nn.BatchNorm2D(channels[1])
self.relu1_2 = nn.ReLU()
# conv 2 (Max-pooling, Residual block, Conv)
self.pool2 = nn.MaxPool2D(
kernel_size=2, stride=2, padding=0, ceil_mode=True)
self.block2 = self._make_layer(channels[1], channels[2], layers[0])
self.conv2 = nn.Conv2D(
channels[2], channels[2], kernel_size=3, stride=1, padding=1)
self.bn2 = nn.BatchNorm2D(channels[2])
self.relu2 = nn.ReLU()
# conv 3 (Max-pooling, Residual block, Conv)
self.pool3 = nn.MaxPool2D(
kernel_size=2, stride=2, padding=0, ceil_mode=True)
self.block3 = self._make_layer(channels[2], channels[3], layers[1])
self.conv3 = nn.Conv2D(
channels[3], channels[3], kernel_size=3, stride=1, padding=1)
self.bn3 = nn.BatchNorm2D(channels[3])
self.relu3 = nn.ReLU()
# conv 4 (Max-pooling, Residual block, Conv)
self.pool4 = nn.MaxPool2D(
kernel_size=(2, 1), stride=(2, 1), padding=0, ceil_mode=True)
self.block4 = self._make_layer(channels[3], channels[4], layers[2])
self.conv4 = nn.Conv2D(
channels[4], channels[4], kernel_size=3, stride=1, padding=1)
self.bn4 = nn.BatchNorm2D(channels[4])
self.relu4 = nn.ReLU()
# conv 5 ((Max-pooling), Residual block, Conv)
self.pool5 = None
if self.last_stage_pool:
self.pool5 = nn.MaxPool2D(
kernel_size=2, stride=2, padding=0, ceil_mode=True)
self.block5 = self._make_layer(channels[4], channels[5], layers[3])
self.conv5 = nn.Conv2D(
channels[5], channels[5], kernel_size=3, stride=1, padding=1)
self.bn5 = nn.BatchNorm2D(channels[5])
self.relu5 = nn.ReLU()
self.out_channels = channels[-1]
def _make_layer(self, input_channels, output_channels, blocks):
layers = []
for _ in range(blocks):
downsample = None
if input_channels != output_channels:
downsample = nn.Sequential(
nn.Conv2D(
input_channels,
output_channels,
kernel_size=1,
stride=1,
weight_attr=ParamAttr(initializer=KaimingNormal()),
bias_attr=False),
nn.BatchNorm2D(output_channels), )
layers.append(
BasicBlock(
input_channels, output_channels, downsample=downsample))
input_channels = output_channels
return nn.Sequential(*layers)
def forward(self, x):
x = self.conv1_1(x)
x = self.bn1_1(x)
x = self.relu1_1(x)
x = self.conv1_2(x)
x = self.bn1_2(x)
x = self.relu1_2(x)
outs = []
for i in range(4):
layer_index = i + 2
pool_layer = getattr(self, f'pool{layer_index}')
block_layer = getattr(self, f'block{layer_index}')
conv_layer = getattr(self, f'conv{layer_index}')
bn_layer = getattr(self, f'bn{layer_index}')
relu_layer = getattr(self, f'relu{layer_index}')
if pool_layer is not None:
x = pool_layer(x)
x = block_layer(x)
x = conv_layer(x)
x = bn_layer(x)
x = relu_layer(x)
outs.append(x)
if self.out_indices is not None:
return tuple([outs[i] for i in self.out_indices])
return x
class ResNet(nn.Layer):
def __init__(self, block, layers, in_channels=3):
self.inplanes = 32
super(ResNet, self).__init__()
self.conv1 = nn.Conv2D(
3,
32,
kernel_size=3,
stride=1,
padding=1,
weight_attr=ParamAttr(initializer=KaimingNormal()),
bias_attr=False)
self.bn1 = nn.BatchNorm2D(32)
self.relu = nn.ReLU()
self.layer1 = self._make_layer(block, 32, layers[0], stride=2)
self.layer2 = self._make_layer(block, 64, layers[1], stride=1)
self.layer3 = self._make_layer(block, 128, layers[2], stride=2)
self.layer4 = self._make_layer(block, 256, layers[3], stride=1)
self.layer5 = self._make_layer(block, 512, layers[4], stride=1)
self.out_channels = 512
# for m in self.modules():
# if isinstance(m, nn.Conv2D):
# n = m._kernel_size[0] * m._kernel_size[1] * m._out_channels
# m.weight.data.normal_(0, math.sqrt(2. / n))
def _make_layer(self, block, planes, blocks, stride=1):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
# downsample = True
downsample = nn.Sequential(
nn.Conv2D(
self.inplanes,
planes * block.expansion,
kernel_size=1,
stride=stride,
weight_attr=ParamAttr(initializer=KaimingNormal()),
bias_attr=False),
nn.BatchNorm2D(planes * block.expansion), )
layers = []
layers.append(block(self.inplanes, planes, stride, downsample))
self.inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(block(self.inplanes, planes))
return nn.Sequential(*layers)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
# print(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
# print(x)
x = self.layer4(x)
x = self.layer5(x)
return x
def ResNet45(in_channels=3):
return ResNet(BasicBlock, [3, 4, 6, 6, 3], in_channels=in_channels)
......@@ -33,6 +33,7 @@ def build_head(config):
from .rec_aster_head import AsterHead
from .rec_pren_head import PRENHead
from .rec_multi_head import MultiHead
from .rec_abinet_head import ABINetHead
# cls head
from .cls_head import ClsHead
......@@ -46,7 +47,7 @@ def build_head(config):
'DBHead', 'PSEHead', 'FCEHead', 'EASTHead', 'SASTHead', 'CTCHead',
'ClsHead', 'AttentionHead', 'SRNHead', 'PGHead', 'Transformer',
'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead',
'MultiHead'
'MultiHead', 'ABINetHead'
]
#table head
......
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle import nn
import paddle.nn.functional as F
from paddle.nn import Linear
from paddle.nn.initializer import XavierUniform as xavier_uniform_
from paddle.nn.initializer import Constant as constant_
from paddle.nn.initializer import XavierNormal as xavier_normal_
zeros_ = constant_(value=0.)
ones_ = constant_(value=1.)
class MultiheadAttention(nn.Layer):
"""Allows the model to jointly attend to information
from different representation subspaces.
See reference: Attention Is All You Need
.. math::
\text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
\text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)
Args:
embed_dim: total dimension of the model
num_heads: parallel attention layers, or heads
"""
def __init__(self,
embed_dim,
num_heads,
dropout=0.,
bias=True,
add_bias_kv=False,
add_zero_attn=False):
super(MultiheadAttention, self).__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
self.scaling = self.head_dim**-0.5
self.out_proj = Linear(embed_dim, embed_dim, bias_attr=bias)
self._reset_parameters()
self.conv1 = paddle.nn.Conv2D(
in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1))
self.conv2 = paddle.nn.Conv2D(
in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1))
self.conv3 = paddle.nn.Conv2D(
in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1))
def _reset_parameters(self):
xavier_uniform_(self.out_proj.weight)
def forward(self,
query,
key,
value,
key_padding_mask=None,
incremental_state=None,
attn_mask=None):
"""
Inputs of forward function
query: [target length, batch size, embed dim]
key: [sequence length, batch size, embed dim]
value: [sequence length, batch size, embed dim]
key_padding_mask: if True, mask padding based on batch size
incremental_state: if provided, previous time steps are cashed
need_weights: output attn_output_weights
static_kv: key and value are static
Outputs of forward function
attn_output: [target length, batch size, embed dim]
attn_output_weights: [batch size, target length, sequence length]
"""
q_shape = paddle.shape(query)
src_shape = paddle.shape(key)
q = self._in_proj_q(query)
k = self._in_proj_k(key)
v = self._in_proj_v(value)
q *= self.scaling
q = paddle.transpose(
paddle.reshape(
q, [q_shape[0], q_shape[1], self.num_heads, self.head_dim]),
[1, 2, 0, 3])
k = paddle.transpose(
paddle.reshape(
k, [src_shape[0], q_shape[1], self.num_heads, self.head_dim]),
[1, 2, 0, 3])
v = paddle.transpose(
paddle.reshape(
v, [src_shape[0], q_shape[1], self.num_heads, self.head_dim]),
[1, 2, 0, 3])
if key_padding_mask is not None:
assert key_padding_mask.shape[0] == q_shape[1]
assert key_padding_mask.shape[1] == src_shape[0]
attn_output_weights = paddle.matmul(q,
paddle.transpose(k, [0, 1, 3, 2]))
if attn_mask is not None:
attn_mask = paddle.unsqueeze(paddle.unsqueeze(attn_mask, 0), 0)
attn_output_weights += attn_mask
if key_padding_mask is not None:
attn_output_weights = paddle.reshape(
attn_output_weights,
[q_shape[1], self.num_heads, q_shape[0], src_shape[0]])
key = paddle.unsqueeze(paddle.unsqueeze(key_padding_mask, 1), 2)
key = paddle.cast(key, 'float32')
y = paddle.full(
shape=paddle.shape(key), dtype='float32', fill_value='-inf')
y = paddle.where(key == 0., key, y)
attn_output_weights += y
attn_output_weights = F.softmax(
attn_output_weights.astype('float32'),
axis=-1,
dtype=paddle.float32 if attn_output_weights.dtype == paddle.float16
else attn_output_weights.dtype)
attn_output_weights = F.dropout(
attn_output_weights, p=self.dropout, training=self.training)
attn_output = paddle.matmul(attn_output_weights, v)
attn_output = paddle.reshape(
paddle.transpose(attn_output, [2, 0, 1, 3]),
[q_shape[0], q_shape[1], self.embed_dim])
attn_output = self.out_proj(attn_output)
return attn_output
def _in_proj_q(self, query):
query = paddle.transpose(query, [1, 2, 0])
query = paddle.unsqueeze(query, axis=2)
res = self.conv1(query)
res = paddle.squeeze(res, axis=2)
res = paddle.transpose(res, [2, 0, 1])
return res
def _in_proj_k(self, key):
key = paddle.transpose(key, [1, 2, 0])
key = paddle.unsqueeze(key, axis=2)
res = self.conv2(key)
res = paddle.squeeze(res, axis=2)
res = paddle.transpose(res, [2, 0, 1])
return res
def _in_proj_v(self, value):
value = paddle.transpose(value, [1, 2, 0]) #(1, 2, 0)
value = paddle.unsqueeze(value, axis=2)
res = self.conv3(value)
res = paddle.squeeze(res, axis=2)
res = paddle.transpose(res, [2, 0, 1])
return res
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/FangShancheng/ABINet/tree/main/modules
"""
import math
import paddle
from paddle import nn
import paddle.nn.functional as F
from paddle.nn import LayerList
from ppocr.modeling.heads.rec_nrtr_head import TransformerBlock, PositionalEncoding
class BCNLanguage(nn.Layer):
def __init__(self,
d_model=512,
nhead=8,
num_layers=4,
dim_feedforward=2048,
dropout=0.,
max_length=25,
detach=True,
num_classes=37):
super().__init__()
self.d_model = d_model
self.detach = detach
self.max_length = max_length + 1 # additional stop token
self.proj = nn.Linear(num_classes, d_model, bias_attr=False)
self.token_encoder = PositionalEncoding(
dropout=0.1, dim=d_model, max_len=self.max_length)
self.pos_encoder = PositionalEncoding(
dropout=0, dim=d_model, max_len=self.max_length)
self.decoder = nn.LayerList([
TransformerBlock(
d_model=d_model,
nhead=nhead,
dim_feedforward=dim_feedforward,
attention_dropout_rate=dropout,
residual_dropout_rate=dropout,
with_self_attn=False,
with_cross_attn=True) for i in range(num_layers)
])
self.cls = nn.Linear(d_model, num_classes)
def forward(self, tokens, lengths):
"""
Args:
tokens: (B, N, C) where N is length, B is batch size and C is classes number
lengths: (B,)
"""
if self.detach: tokens = tokens.detach()
embed = self.proj(tokens) # (B, N, C)
embed = self.token_encoder(embed) # (B, N, C)
padding_mask = _get_mask(lengths, self.max_length)
zeros = paddle.zeros_like(embed) # (B, N, C)
qeury = self.pos_encoder(zeros)
for decoder_layer in self.decoder:
qeury = decoder_layer(qeury, embed, cross_mask=padding_mask)
output = qeury # (B, N, C)
logits = self.cls(output) # (B, N, C)
return output, logits
def encoder_layer(in_c, out_c, k=3, s=2, p=1):
return nn.Sequential(
nn.Conv2D(in_c, out_c, k, s, p), nn.BatchNorm2D(out_c), nn.ReLU())
def decoder_layer(in_c,
out_c,
k=3,
s=1,
p=1,
mode='nearest',
scale_factor=None,
size=None):
align_corners = False if mode == 'nearest' else True
return nn.Sequential(
nn.Upsample(
size=size,
scale_factor=scale_factor,
mode=mode,
align_corners=align_corners),
nn.Conv2D(in_c, out_c, k, s, p),
nn.BatchNorm2D(out_c),
nn.ReLU())
class PositionAttention(nn.Layer):
def __init__(self,
max_length,
in_channels=512,
num_channels=64,
h=8,
w=32,
mode='nearest',
**kwargs):
super().__init__()
self.max_length = max_length
self.k_encoder = nn.Sequential(
encoder_layer(
in_channels, num_channels, s=(1, 2)),
encoder_layer(
num_channels, num_channels, s=(2, 2)),
encoder_layer(
num_channels, num_channels, s=(2, 2)),
encoder_layer(
num_channels, num_channels, s=(2, 2)))
self.k_decoder = nn.Sequential(
decoder_layer(
num_channels, num_channels, scale_factor=2, mode=mode),
decoder_layer(
num_channels, num_channels, scale_factor=2, mode=mode),
decoder_layer(
num_channels, num_channels, scale_factor=2, mode=mode),
decoder_layer(
num_channels, in_channels, size=(h, w), mode=mode))
self.pos_encoder = PositionalEncoding(
dropout=0, dim=in_channels, max_len=max_length)
self.project = nn.Linear(in_channels, in_channels)
def forward(self, x):
B, C, H, W = x.shape
k, v = x, x
# calculate key vector
features = []
for i in range(0, len(self.k_encoder)):
k = self.k_encoder[i](k)
features.append(k)
for i in range(0, len(self.k_decoder) - 1):
k = self.k_decoder[i](k)
# print(k.shape, features[len(self.k_decoder) - 2 - i].shape)
k = k + features[len(self.k_decoder) - 2 - i]
k = self.k_decoder[-1](k)
# calculate query vector
# TODO q=f(q,k)
zeros = paddle.zeros(
(B, self.max_length, C), dtype=x.dtype) # (T, N, C)
q = self.pos_encoder(zeros) # (B, N, C)
q = self.project(q) # (B, N, C)
# calculate attention
attn_scores = q @k.flatten(2) # (B, N, (H*W))
attn_scores = attn_scores / (C**0.5)
attn_scores = F.softmax(attn_scores, axis=-1)
v = v.flatten(2).transpose([0, 2, 1]) # (B, (H*W), C)
attn_vecs = attn_scores @v # (B, N, C)
return attn_vecs, attn_scores.reshape([0, self.max_length, H, W])
class ABINetHead(nn.Layer):
def __init__(self,
in_channels,
out_channels,
d_model=512,
nhead=8,
num_layers=3,
dim_feedforward=2048,
dropout=0.1,
max_length=25,
use_lang=False,
iter_size=1):
super().__init__()
self.max_length = max_length + 1
self.pos_encoder = PositionalEncoding(
dropout=0.1, dim=d_model, max_len=8 * 32)
self.encoder = nn.LayerList([
TransformerBlock(
d_model=d_model,
nhead=nhead,
dim_feedforward=dim_feedforward,
attention_dropout_rate=dropout,
residual_dropout_rate=dropout,
with_self_attn=True,
with_cross_attn=False) for i in range(num_layers)
])
self.decoder = PositionAttention(
max_length=max_length + 1, # additional stop token
mode='nearest', )
self.out_channels = out_channels
self.cls = nn.Linear(d_model, self.out_channels)
self.use_lang = use_lang
if use_lang:
self.iter_size = iter_size
self.language = BCNLanguage(
d_model=d_model,
nhead=nhead,
num_layers=4,
dim_feedforward=dim_feedforward,
dropout=dropout,
max_length=max_length,
num_classes=self.out_channels)
# alignment
self.w_att_align = nn.Linear(2 * d_model, d_model)
self.cls_align = nn.Linear(d_model, self.out_channels)
def forward(self, x, targets=None):
x = x.transpose([0, 2, 3, 1])
_, H, W, C = x.shape
feature = x.flatten(1, 2)
feature = self.pos_encoder(feature)
for encoder_layer in self.encoder:
feature = encoder_layer(feature)
feature = feature.reshape([0, H, W, C]).transpose([0, 3, 1, 2])
v_feature, attn_scores = self.decoder(
feature) # (B, N, C), (B, C, H, W)
vis_logits = self.cls(v_feature) # (B, N, C)
logits = vis_logits
vis_lengths = _get_length(vis_logits)
if self.use_lang:
align_logits = vis_logits
align_lengths = vis_lengths
all_l_res, all_a_res = [], []
for i in range(self.iter_size):
tokens = F.softmax(align_logits, axis=-1)
lengths = align_lengths
lengths = paddle.clip(
lengths, 2, self.max_length) # TODO:move to langauge model
l_feature, l_logits = self.language(tokens, lengths)
# alignment
all_l_res.append(l_logits)
fuse = paddle.concat((l_feature, v_feature), -1)
f_att = F.sigmoid(self.w_att_align(fuse))
output = f_att * v_feature + (1 - f_att) * l_feature
align_logits = self.cls_align(output) # (B, N, C)
align_lengths = _get_length(align_logits)
all_a_res.append(align_logits)
if self.training:
return {
'align': all_a_res,
'lang': all_l_res,
'vision': vis_logits
}
else:
logits = align_logits
if self.training:
return logits
else:
return F.softmax(logits, -1)
def _get_length(logit):
""" Greed decoder to obtain length from logit"""
out = (logit.argmax(-1) == 0)
abn = out.any(-1)
out_int = out.cast('int32')
out = (out_int.cumsum(-1) == 1) & out
out = out.cast('int32')
out = out.argmax(-1)
out = out + 1
out = paddle.where(abn, out, paddle.to_tensor(logit.shape[1]))
return out
def _get_mask(length, max_length):
"""Generate a square mask for the sequence. The masked positions are filled with float('-inf').
Unmasked positions are filled with float(0.0).
"""
length = length.unsqueeze(-1)
B = paddle.shape(length)[0]
grid = paddle.arange(0, max_length).unsqueeze(0).tile([B, 1])
zero_mask = paddle.zeros([B, max_length], dtype='float32')
inf_mask = paddle.full([B, max_length], '-inf', dtype='float32')
diag_mask = paddle.diag(
paddle.full(
[max_length], '-inf', dtype=paddle.float32),
offset=0,
name=None)
mask = paddle.where(grid >= length, inf_mask, zero_mask)
mask = mask.unsqueeze(1) + diag_mask
return mask.unsqueeze(1)
......@@ -27,7 +27,7 @@ from .sast_postprocess import SASTPostProcess
from .fce_postprocess import FCEPostProcess
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, \
DistillationCTCLabelDecode, TableLabelDecode, NRTRLabelDecode, SARLabelDecode, \
SEEDLabelDecode, PRENLabelDecode, ViTSTRLabelDecode
SEEDLabelDecode, PRENLabelDecode, ViTSTRLabelDecode, ABINetLabelDecode
from .cls_postprocess import ClsPostProcess
from .pg_postprocess import PGPostProcess
from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess
......@@ -42,7 +42,7 @@ def build_post_process(config, global_config=None):
'DistillationDBPostProcess', 'NRTRLabelDecode', 'SARLabelDecode',
'SEEDLabelDecode', 'VQASerTokenLayoutLMPostProcess',
'VQAReTokenLayoutLMPostProcess', 'PRENLabelDecode',
'DistillationSARLabelDecode', 'ViTSTRLabelDecode'
'DistillationSARLabelDecode', 'ViTSTRLabelDecode', 'ABINetLabelDecode'
]
if config['name'] == 'PSEPostProcess':
......
......@@ -140,96 +140,6 @@ class DistillationCTCLabelDecode(CTCLabelDecode):
return output
class NRTRLabelDecode(BaseRecLabelDecode):
""" Convert between text-label and text-index """
def __init__(self, character_dict_path=None, use_space_char=True, **kwargs):
super(NRTRLabelDecode, self).__init__(character_dict_path,
use_space_char)
def __call__(self, preds, label=None, *args, **kwargs):
if len(preds) == 2:
preds_id = preds[0]
preds_prob = preds[1]
if isinstance(preds_id, paddle.Tensor):
preds_id = preds_id.numpy()
if isinstance(preds_prob, paddle.Tensor):
preds_prob = preds_prob.numpy()
if preds_id[0][0] == 2:
preds_idx = preds_id[:, 1:]
preds_prob = preds_prob[:, 1:]
else:
preds_idx = preds_id
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
if label is None:
return text
label = self.decode(label[:, 1:])
else:
if isinstance(preds, paddle.Tensor):
preds = preds.numpy()
preds_idx = preds.argmax(axis=2)
preds_prob = preds.max(axis=2)
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
if label is None:
return text
label = self.decode(label[:, 1:])
return text, label
def add_special_char(self, dict_character):
dict_character = ['blank', '<unk>', '<s>', '</s>'] + dict_character
return dict_character
def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
""" convert text-index into text-label. """
result_list = []
batch_size = len(text_index)
for batch_idx in range(batch_size):
char_list = []
conf_list = []
for idx in range(len(text_index[batch_idx])):
try:
char_idx = self.character[int(text_index[batch_idx][idx])]
except:
continue
if char_idx == '</s>': # end
break
char_list.append(char_idx)
if text_prob is not None:
conf_list.append(text_prob[batch_idx][idx])
else:
conf_list.append(1)
text = ''.join(char_list)
result_list.append((text.lower(), np.mean(conf_list).tolist()))
return result_list
class ViTSTRLabelDecode(NRTRLabelDecode):
""" Convert between text-label and text-index """
def __init__(self, character_dict_path=None, use_space_char=False,
**kwargs):
super(ViTSTRLabelDecode, self).__init__(character_dict_path,
use_space_char)
def __call__(self, preds, label=None, *args, **kwargs):
if isinstance(preds, paddle.Tensor):
preds = preds[:, 1:].numpy()
else:
preds = preds[:, 1:]
preds_idx = preds.argmax(axis=2)
preds_prob = preds.max(axis=2)
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
if label is None:
return text
label = self.decode(label[:, 1:])
return text, label
def add_special_char(self, dict_character):
dict_character = ['<s>', '</s>'] + dict_character
return dict_character
class AttnLabelDecode(BaseRecLabelDecode):
""" Convert between text-label and text-index """
......@@ -778,3 +688,122 @@ class PRENLabelDecode(BaseRecLabelDecode):
return text
label = self.decode(label)
return text, label
class NRTRLabelDecode(BaseRecLabelDecode):
""" Convert between text-label and text-index """
def __init__(self, character_dict_path=None, use_space_char=True, **kwargs):
super(NRTRLabelDecode, self).__init__(character_dict_path,
use_space_char)
def __call__(self, preds, label=None, *args, **kwargs):
if len(preds) == 2:
preds_id = preds[0]
preds_prob = preds[1]
if isinstance(preds_id, paddle.Tensor):
preds_id = preds_id.numpy()
if isinstance(preds_prob, paddle.Tensor):
preds_prob = preds_prob.numpy()
if preds_id[0][0] == 2:
preds_idx = preds_id[:, 1:]
preds_prob = preds_prob[:, 1:]
else:
preds_idx = preds_id
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
if label is None:
return text
label = self.decode(label[:, 1:])
else:
if isinstance(preds, paddle.Tensor):
preds = preds.numpy()
preds_idx = preds.argmax(axis=2)
preds_prob = preds.max(axis=2)
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
if label is None:
return text
label = self.decode(label[:, 1:])
return text, label
def add_special_char(self, dict_character):
dict_character = ['blank', '<unk>', '<s>', '</s>'] + dict_character
return dict_character
def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
""" convert text-index into text-label. """
result_list = []
batch_size = len(text_index)
for batch_idx in range(batch_size):
char_list = []
conf_list = []
for idx in range(len(text_index[batch_idx])):
try:
char_idx = self.character[int(text_index[batch_idx][idx])]
except:
continue
if char_idx == '</s>': # end
break
char_list.append(char_idx)
if text_prob is not None:
conf_list.append(text_prob[batch_idx][idx])
else:
conf_list.append(1)
text = ''.join(char_list)
result_list.append((text.lower(), np.mean(conf_list).tolist()))
return result_list
class ViTSTRLabelDecode(NRTRLabelDecode):
""" Convert between text-label and text-index """
def __init__(self, character_dict_path=None, use_space_char=False,
**kwargs):
super(ViTSTRLabelDecode, self).__init__(character_dict_path,
use_space_char)
def __call__(self, preds, label=None, *args, **kwargs):
if isinstance(preds, paddle.Tensor):
preds = preds[:, 1:].numpy()
else:
preds = preds[:, 1:]
preds_idx = preds.argmax(axis=2)
preds_prob = preds.max(axis=2)
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
if label is None:
return text
label = self.decode(label[:, 1:])
return text, label
def add_special_char(self, dict_character):
dict_character = ['<s>', '</s>'] + dict_character
return dict_character
class ABINetLabelDecode(NRTRLabelDecode):
""" Convert between text-label and text-index """
def __init__(self, character_dict_path=None, use_space_char=False,
**kwargs):
super(ABINetLabelDecode, self).__init__(character_dict_path,
use_space_char)
def __call__(self, preds, label=None, *args, **kwargs):
if isinstance(preds, dict):
preds = preds['align'][-1].numpy()
elif isinstance(preds, paddle.Tensor):
preds = preds.numpy()
else:
preds = preds
preds_idx = preds.argmax(axis=2)
preds_prob = preds.max(axis=2)
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
if label is None:
return text
label = self.decode(label)
return text, label
def add_special_char(self, dict_character):
dict_character = ['</s>'] + dict_character
return dict_character
......@@ -99,5 +99,5 @@ Eval:
shuffle: False
drop_last: False
batch_size_per_card: 256
num_workers: 1
num_workers: 4
use_shared_memory: False
Global:
use_gpu: True
epoch_num: 10
log_smooth_window: 20
print_batch_step: 10
save_model_dir: ./output/rec/r45_abinet/
save_epoch_step: 1
# evaluation is run every 2000 iterations
eval_batch_step: [0, 2000]
cal_metric_during_train: True
pretrained_model:
checkpoints:
save_inference_dir:
use_visualdl: False
infer_img: doc/imgs_words_en/word_10.png
# for data or label process
character_dict_path:
character_type: en
max_text_length: 25
infer_mode: False
use_space_char: False
save_res_path: ./output/rec/predicts_abinet.txt
Optimizer:
name: Adam
beta1: 0.9
beta2: 0.99
clip_norm: 20.0
lr:
name: Piecewise
decay_epochs: [6]
values: [0.0001, 0.00001]
regularizer:
name: 'L2'
factor: 0.
Architecture:
model_type: rec
algorithm: ABINet
in_channels: 3
Transform:
Backbone:
name: ResNet45
Head:
name: ABINetHead
use_lang: True
iter_size: 3
Loss:
name: CELoss
ignore_index: &ignore_index 100 # Must be greater than the number of character classes
PostProcess:
name: ABINetLabelDecode
Metric:
name: RecMetric
main_indicator: acc
Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data/ic15_data/
label_file_list: ["./train_data/ic15_data/rec_gt_train.txt"]
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- ABINetLabelEncode: # Class handling label
ignore_index: *ignore_index
- ABINetRecResizeImg:
image_shape: [3, 32, 128]
padding: False
- KeepKeys:
keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
loader:
shuffle: True
batch_size_per_card: 96
drop_last: True
num_workers: 4
Eval:
dataset:
name: SimpleDataSet
data_dir: ./train_data/ic15_data
label_file_list: ["./train_data/ic15_data/rec_gt_test.txt"]
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- ABINetLabelEncode: # Class handling label
ignore_index: *ignore_index
- ABINetRecResizeImg:
image_shape: [3, 32, 128]
padding: False
- KeepKeys:
keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
loader:
shuffle: False
drop_last: False
batch_size_per_card: 256
num_workers: 4
use_shared_memory: False
===========================train_params===========================
model_name:rec_abinet
python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
Global.auto_cast:null
Global.epoch_num:lite_train_lite_infer=2|whole_train_whole_infer=300
Global.save_model_dir:./output/
Train.loader.batch_size_per_card:lite_train_lite_infer=16|whole_train_whole_infer=64
Global.pretrained_model:null
train_model_name:latest
train_infer_img_dir:./inference/rec_inference
null:null
##
trainer:norm_train
norm_train:tools/train.py -c test_tipc/configs/rec_r45_abinet/rec_r45_abinet.yml -o
pact_train:null
fpgm_train:null
distill_train:null
null:null
null:null
##
===========================eval_params===========================
eval:tools/eval.py -c test_tipc/configs/rec_r45_abinet/rec_r45_abinet.yml -o
null:null
##
===========================infer_params===========================
Global.save_inference_dir:./output/
Global.checkpoints:
norm_export:tools/export_model.py -c test_tipc/configs/rec_r45_abinet/rec_r45_abinet.yml -o
quant_export:null
fpgm_export:null
distill_export:null
export1:null
export2:null
##
train_model:./inference/rec_r45_abinet_train/best_accuracy
infer_export:tools/export_model.py -c test_tipc/configs/rec_r45_abinet/rec_r45_abinet.yml -o
infer_quant:False
inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dict.txt --rec_image_shape="3,32,128" --rec_algorithm="ABINet"
--use_gpu:True|False
--enable_mkldnn:False
--cpu_threads:6
--rec_batch_num:1|6
--use_tensorrt:False
--precision:fp32
--rec_model_dir:
--image_dir:./inference/rec_inference
--save_log_path:./test/output/
--benchmark:True
null:null
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,32,128]}]
......@@ -26,7 +26,7 @@ Optimizer:
name: AdamW
beta1: 0.9
beta2: 0.99
epsilon: 0.00000008
epsilon: 8.e-8
weight_decay: 0.05
no_weight_decay_name: norm pos_embed
one_dim_param_no_weight_decay: true
......
......@@ -37,7 +37,7 @@ export2:null
train_model:./inference/rec_svtrnet_train/best_accuracy
infer_export:tools/export_model.py -c test_tipc/configs/rec_svtrnet/rec_svtrnet.yml -o
infer_quant:False
inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/EN_symbol_dict.txt --rec_image_shape="3,64,256" --rec_algorithm="SVTR"
inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dict.txt --rec_image_shape="3,64,256" --rec_algorithm="SVTR"
--use_gpu:True|False
--enable_mkldnn:False
--cpu_threads:6
......@@ -50,4 +50,4 @@ inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/EN_symbo
--benchmark:True
null:null
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[1=3,64,256]}]
random_infer_input:[{float32,[3,64,256]}]
......@@ -23,7 +23,7 @@ Global:
Optimizer:
name: Adadelta
epsilon: 0.00000001
epsilon: 1.e-8
rho: 0.95
clip_norm: 5.0
lr:
......@@ -46,6 +46,7 @@ Loss:
name: CELoss
smoothing: False
with_all: True
ignore_index: &ignore_index 0 # Must be zero or greater than the number of character classes
PostProcess:
name: ViTSTRLabelDecode
......@@ -64,6 +65,7 @@ Train:
img_mode: BGR
channel_first: False
- ViTSTRLabelEncode: # Class handling label
ignore_index: *ignore_index
- GrayRecResizeImg:
image_shape: [224, 224] # W H
resize_type: PIL # PIL or OpenCV
......@@ -87,6 +89,7 @@ Eval:
img_mode: BGR
channel_first: False
- ViTSTRLabelEncode: # Class handling label
ignore_index: *ignore_index
- GrayRecResizeImg:
image_shape: [224, 224] # W H
resize_type: PIL # PIL or OpenCV
......
......@@ -79,6 +79,19 @@ def export_single_model(model, arch_config, save_path, logger, quanter=None):
shape=[None, 1, 224, 224], dtype="float32"),
]
model = to_static(model, input_spec=other_shape)
elif arch_config["algorithm"] == "ABINet":
other_shape = [
paddle.static.InputSpec(
shape=[None, 3, 32, 128], dtype="float32"),
]
# print([None, 3, 32, 128])
model = to_static(model, input_spec=other_shape)
elif arch_config["algorithm"] == "NRTR":
other_shape = [
paddle.static.InputSpec(
shape=[None, 1, 32, 100], dtype="float32"),
]
model = to_static(model, input_spec=other_shape)
else:
infer_shape = [3, -1, -1]
if arch_config["model_type"] == "rec":
......@@ -90,8 +103,6 @@ def export_single_model(model, arch_config, save_path, logger, quanter=None):
"When there is tps in the network, variable length input is not supported, and the input size needs to be the same as during training"
)
infer_shape[-1] = 100
if arch_config["algorithm"] == "NRTR":
infer_shape = [1, 32, 100]
elif arch_config["model_type"] == "table":
infer_shape = [3, 488, 488]
model = to_static(
......
......@@ -75,6 +75,12 @@ class TextRecognizer(object):
"character_dict_path": args.rec_char_dict_path,
"use_space_char": args.use_space_char
}
elif self.rec_algorithm == 'ABINet':
postprocess_params = {
'name': 'ABINetLabelDecode',
"character_dict_path": args.rec_char_dict_path,
"use_space_char": args.use_space_char
}
self.postprocess_op = build_post_process(postprocess_params)
self.predictor, self.input_tensor, self.output_tensors, self.config = \
utility.create_predictor(args, 'rec', logger)
......@@ -145,17 +151,6 @@ class TextRecognizer(object):
padding_im[:, :, 0:resized_w] = resized_image
return padding_im
def resize_norm_img_svtr(self, img, image_shape):
imgC, imgH, imgW = image_shape
resized_image = cv2.resize(
img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
resized_image = resized_image.astype('float32')
resized_image = resized_image.transpose((2, 0, 1)) / 255
resized_image -= 0.5
resized_image /= 0.5
return resized_image
def resize_norm_img_srn(self, img, image_shape):
imgC, imgH, imgW = image_shape
......@@ -263,6 +258,35 @@ class TextRecognizer(object):
return padding_im, resize_shape, pad_shape, valid_ratio
def resize_norm_img_svtr(self, img, image_shape):
imgC, imgH, imgW = image_shape
resized_image = cv2.resize(
img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
resized_image = resized_image.astype('float32')
resized_image = resized_image.transpose((2, 0, 1)) / 255
resized_image -= 0.5
resized_image /= 0.5
return resized_image
def resize_norm_img_abinet(self, img, image_shape):
imgC, imgH, imgW = image_shape
resized_image = cv2.resize(
img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
resized_image = resized_image.astype('float32')
resized_image = resized_image / 255.
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
resized_image = (
resized_image - mean[None, None, ...]) / std[None, None, ...]
resized_image = resized_image.transpose((2, 0, 1))
resized_image = resized_image.astype('float32')
return resized_image
def __call__(self, img_list):
img_num = len(img_list)
# Calculate the aspect ratio of all text bars
......@@ -313,6 +337,11 @@ class TextRecognizer(object):
self.rec_image_shape)
norm_img = norm_img[np.newaxis, :]
norm_img_batch.append(norm_img)
elif self.rec_algorithm == "ABINet":
norm_img = self.resize_norm_img_abinet(
img_list[indices[ino]], self.rec_image_shape)
norm_img = norm_img[np.newaxis, :]
norm_img_batch.append(norm_img)
else:
norm_img = self.resize_norm_img(img_list[indices[ino]],
max_wh_ratio)
......
......@@ -575,7 +575,7 @@ def preprocess(is_train=False):
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'PREN', 'FCE', 'SVTR',
'ViTSTR'
'ViTSTR', 'ABINet'
]
if use_xpu:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册