未验证 提交 60ea689d 编写于 作者: S smilelite 提交者: GitHub

Merge branch 'dygraph' into robustscanner_branch

Global:
use_gpu: true
epoch_num: 8
log_smooth_window: 200
print_batch_step: 200
save_model_dir: ./output/rec/r45_visionlan
save_epoch_step: 1
# evaluation is run every 2000 iterations
eval_batch_step: [0, 2000]
cal_metric_during_train: True
pretrained_model:
checkpoints:
save_inference_dir:
use_visualdl: True
infer_img: doc/imgs_words/en/word_2.png
# for data or label process
character_dict_path:
max_text_length: &max_text_length 25
training_step: &training_step LA
infer_mode: False
use_space_char: False
save_res_path: ./output/rec/predicts_visionlan.txt
Optimizer:
name: Adam
beta1: 0.9
beta2: 0.999
clip_norm: 20.0
group_lr: true
training_step: *training_step
lr:
name: Piecewise
decay_epochs: [6]
values: [0.0001, 0.00001]
regularizer:
name: 'L2'
factor: 0
Architecture:
model_type: rec
algorithm: VisionLAN
Transform:
Backbone:
name: ResNet45
strides: [2, 2, 2, 1, 1]
Head:
name: VLHead
n_layers: 3
n_position: 256
n_dim: 512
max_text_length: *max_text_length
training_step: *training_step
Loss:
name: VLLoss
mode: *training_step
weight_res: 0.5
weight_mas: 0.5
PostProcess:
name: VLLabelDecode
Metric:
name: RecMetric
is_filter: true
Train:
dataset:
name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/training/
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- ABINetRecAug:
- VLLabelEncode: # Class handling label
- VLRecResizeImg:
image_shape: [3, 64, 256]
- KeepKeys:
keep_keys: ['image', 'label', 'label_res', 'label_sub', 'label_id', 'length'] # dataloader will return list in this order
loader:
shuffle: True
batch_size_per_card: 220
drop_last: True
num_workers: 4
Eval:
dataset:
name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/validation/
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VLLabelEncode: # Class handling label
- VLRecResizeImg:
image_shape: [3, 64, 256]
- KeepKeys:
keep_keys: ['image', 'label', 'label_res', 'label_sub', 'label_id', 'length'] # dataloader will return list in this order
loader:
shuffle: False
drop_last: False
batch_size_per_card: 64
num_workers: 4
......@@ -69,6 +69,7 @@
- [x] [SVTR](./algorithm_rec_svtr.md)
- [x] [ViTSTR](./algorithm_rec_vitstr.md)
- [x] [ABINet](./algorithm_rec_abinet.md)
- [x] [VisionLAN](./algorithm_rec_visionlan.md)
- [x] [SPIN](./algorithm_rec_spin.md)
- [x] [RobustScanner](./algorithm_rec_robustscanner.md)
......@@ -91,6 +92,7 @@
|SVTR|SVTR-Tiny| 89.25% | rec_svtr_tiny_none_ctc_en | [训练模型](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/rec_svtr_tiny_none_ctc_en_train.tar) |
|ViTSTR|ViTSTR| 79.82% | rec_vitstr_none_ce | [训练模型](https://paddleocr.bj.bcebos.com/rec_vitstr_none_ce_train.tar) |
|ABINet|Resnet45| 90.75% | rec_r45_abinet | [训练模型](https://paddleocr.bj.bcebos.com/rec_r45_abinet_train.tar) |
|VisionLAN|Resnet45| 90.30% | rec_r45_visionlan | [训练模型](https://paddleocr.bj.bcebos.com/rec_r45_visionlan_train.tar) |
|SPIN|ResNet32| 90.00% | rec_r32_gaspin_bilstm_att | coming soon |
|RobustScanner|ResNet31| 87.77% | rec_r31_robustscanner | coming soon |
......
# 场景文本识别算法-VisionLAN
- [1. 算法简介](#1)
- [2. 环境配置](#2)
- [3. 模型训练、评估、预测](#3)
- [3.1 训练](#3-1)
- [3.2 评估](#3-2)
- [3.3 预测](#3-3)
- [4. 推理部署](#4)
- [4.1 Python推理](#4-1)
- [4.2 C++推理](#4-2)
- [4.3 Serving服务化部署](#4-3)
- [4.4 更多推理部署](#4-4)
- [5. FAQ](#5)
<a name="1"></a>
## 1. 算法简介
论文信息:
> [From Two to One: A New Scene Text Recognizer with Visual Language Modeling Network](https://arxiv.org/abs/2108.09661)
> Yuxin Wang, Hongtao Xie, Shancheng Fang, Jing Wang, Shenggao Zhu, Yongdong Zhang
> ICCV, 2021
<a name="model"></a>
`VisionLAN`使用MJSynth和SynthText两个文字识别数据集训练,在IIIT, SVT, IC13, IC15, SVTP, CUTE数据集上进行评估,算法复现效果如下:
|模型|骨干网络|配置文件|Acc|下载链接|
| --- | --- | --- | --- | --- |
|VisionLAN|ResNet45|[rec_r45_visionlan.yml](../../configs/rec/rec_r45_visionlan.yml)|90.3%|[预训练、训练模型](https://paddleocr.bj.bcebos.com/rec_r45_visionlan_train.tar)|
<a name="2"></a>
## 2. 环境配置
请先参考[《运行环境准备》](./environment.md)配置PaddleOCR运行环境,参考[《项目克隆》](./clone.md)克隆项目代码。
<a name="3"></a>
## 3. 模型训练、评估、预测
<a name="3-1"></a>
### 3.1 模型训练
请参考[文本识别训练教程](./recognition.md)。PaddleOCR对代码进行了模块化,训练`VisionLAN`识别模型时需要**更换配置文件**`VisionLAN`[配置文件](../../configs/rec/rec_r45_visionlan.yml)
#### 启动训练
具体地,在完成数据准备后,便可以启动训练,训练命令如下:
```shell
#单卡训练(训练周期长,不建议)
python3 tools/train.py -c configs/rec/rec_r45_visionlan.yml
#多卡训练,通过--gpus参数指定卡号
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_r45_visionlan.yml
```
<a name="3-2"></a>
### 3.2 评估
可下载已训练完成的[模型文件](#model),使用如下命令进行评估:
```shell
# 注意将pretrained_model的路径设置为本地路径。
python3 tools/eval.py -c configs/rec/rec_r45_visionlan.yml -o Global.pretrained_model=./rec_r45_visionlan_train/best_accuracy
```
<a name="3-3"></a>
### 3.3 预测
使用如下命令进行单张图片预测:
```shell
# 注意将pretrained_model的路径设置为本地路径。
python3 tools/infer_rec.py -c configs/rec/rec_r45_visionlan.yml -o Global.infer_img='./doc/imgs_words/en/word_2.png' Global.pretrained_model=./rec_r45_visionlan_train/best_accuracy
# 预测文件夹下所有图像时,可修改infer_img为文件夹,如 Global.infer_img='./doc/imgs_words_en/'。
```
<a name="4"></a>
## 4. 推理部署
<a name="4-1"></a>
### 4.1 Python推理
首先将训练得到best模型,转换成inference model。这里以训练完成的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/rec_r45_visionlan_train.tar)),可以使用如下命令进行转换:
```shell
# 注意将pretrained_model的路径设置为本地路径。
python3 tools/export_model.py -c configs/rec/rec_r45_visionlan.yml -o Global.pretrained_model=./rec_r45_visionlan_train/best_accuracy Global.save_inference_dir=./inference/rec_r45_visionlan/
```
**注意:**
- 如果您是在自己的数据集上训练的模型,并且调整了字典文件,请注意修改配置文件中的`character_dict_path`是否是所需要的字典文件。
- 如果您修改了训练时的输入大小,请修改`tools/export_model.py`文件中的对应VisionLAN的`infer_shape`
转换成功后,在目录下有三个文件:
```
./inference/rec_r45_visionlan/
├── inference.pdiparams # 识别inference模型的参数文件
├── inference.pdiparams.info # 识别inference模型的参数信息,可忽略
└── inference.pdmodel # 识别inference模型的program文件
```
执行如下命令进行模型推理:
```shell
python3 tools/infer/predict_rec.py --image_dir='./doc/imgs_words/en/word_2.png' --rec_model_dir='./inference/rec_r45_visionlan/' --rec_algorithm='VisionLAN' --rec_image_shape='3,64,256' --rec_char_dict_path='./ppocr/utils/dict36.txt'
# 预测文件夹下所有图像时,可修改image_dir为文件夹,如 --image_dir='./doc/imgs_words_en/'。
```
![](../imgs_words/en/word_2.png)
执行命令后,上面图像的预测结果(识别的文本和得分)会打印到屏幕上,示例如下:
结果如下:
```shell
Predicts of ./doc/imgs_words/en/word_2.png:('yourself', 0.97076982)
```
**注意**
- 训练上述模型采用的图像分辨率是[3,64,256],需要通过参数`rec_image_shape`设置为您训练时的识别图像形状。
- 在推理时需要设置参数`rec_char_dict_path`指定字典,如果您修改了字典,请修改该参数为您的字典文件。
- 如果您修改了预处理方法,需修改`tools/infer/predict_rec.py`中VisionLAN的预处理为您的预处理方法。
<a name="4-2"></a>
### 4.2 C++推理部署
由于C++预处理后处理还未支持VisionLAN,所以暂未支持
<a name="4-3"></a>
### 4.3 Serving服务化部署
暂不支持
<a name="4-4"></a>
### 4.4 更多推理部署
暂不支持
<a name="5"></a>
## 5. FAQ
1. MJSynth和SynthText两种数据集来自于[VisionLAN源repo](https://github.com/wangyuxin87/VisionLAN)
2. 我们使用VisionLAN作者提供的预训练模型进行finetune训练。
## 引用
```bibtex
@inproceedings{wang2021two,
title={From Two to One: A New Scene Text Recognizer with Visual Language Modeling Network},
author={Wang, Yuxin and Xie, Hongtao and Fang, Shancheng and Wang, Jing and Zhu, Shenggao and Zhang, Yongdong},
booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
pages={14194--14203},
year={2021}
}
```
......@@ -68,6 +68,7 @@ Supported text recognition algorithms (Click the link to get the tutorial):
- [x] [SVTR](./algorithm_rec_svtr_en.md)
- [x] [ViTSTR](./algorithm_rec_vitstr_en.md)
- [x] [ABINet](./algorithm_rec_abinet_en.md)
- [x] [VisionLAN](./algorithm_rec_visionlan_en.md)
- [x] [SPIN](./algorithm_rec_spin_en.md)
- [x] [RobustScanner](./algorithm_rec_robustscanner_en.md)
......@@ -90,6 +91,7 @@ Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation r
|SVTR|SVTR-Tiny| 89.25% | rec_svtr_tiny_none_ctc_en | [trained model](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/rec_svtr_tiny_none_ctc_en_train.tar) |
|ViTSTR|ViTSTR| 79.82% | rec_vitstr_none_ce | [trained model](https://paddleocr.bj.bcebos.com/rec_vitstr_none_none_train.tar) |
|ABINet|Resnet45| 90.75% | rec_r45_abinet | [trained model](https://paddleocr.bj.bcebos.com/rec_r45_abinet_train.tar) |
|VisionLAN|Resnet45| 90.30% | rec_r45_visionlan | [trained model](https://paddleocr.bj.bcebos.com/rec_r45_visionlan_train.tar) |
|SPIN|ResNet32| 90.00% | rec_r32_gaspin_bilstm_att | coming soon |
|RobustScanner|ResNet31| 87.77% | rec_r31_robustscanner | coming soon |
......
# VisionLAN
- [1. Introduction](#1)
- [2. Environment](#2)
- [3. Model Training / Evaluation / Prediction](#3)
- [3.1 Training](#3-1)
- [3.2 Evaluation](#3-2)
- [3.3 Prediction](#3-3)
- [4. Inference and Deployment](#4)
- [4.1 Python Inference](#4-1)
- [4.2 C++ Inference](#4-2)
- [4.3 Serving](#4-3)
- [4.4 More](#4-4)
- [5. FAQ](#5)
<a name="1"></a>
## 1. Introduction
Paper:
> [From Two to One: A New Scene Text Recognizer with Visual Language Modeling Network](https://arxiv.org/abs/2108.09661)
> Yuxin Wang, Hongtao Xie, Shancheng Fang, Jing Wang, Shenggao Zhu, Yongdong Zhang
> ICCV, 2021
Using MJSynth and SynthText two text recognition datasets for training, and evaluating on IIIT, SVT, IC13, IC15, SVTP, CUTE datasets, the algorithm reproduction effect is as follows:
|Model|Backbone|config|Acc|Download link|
| --- | --- | --- | --- | --- |
|VisionLAN|ResNet45|[rec_r45_visionlan.yml](../../configs/rec/rec_r45_visionlan.yml)|90.3%|[预训练、训练模型](https://paddleocr.bj.bcebos.com/rec_r45_visionlan_train.tar)|
<a name="2"></a>
## 2. Environment
Please refer to ["Environment Preparation"](./environment_en.md) to configure the PaddleOCR environment, and refer to ["Project Clone"](./clone_en.md) to clone the project code.
<a name="3"></a>
## 3. Model Training / Evaluation / Prediction
Please refer to [Text Recognition Tutorial](./recognition_en.md). PaddleOCR modularizes the code, and training different recognition models only requires **changing the configuration file**.
Training:
Specifically, after the data preparation is completed, the training can be started. The training command is as follows:
```
#Single GPU training (long training period, not recommended)
python3 tools/train.py -c configs/rec/rec_r45_visionlan.yml
#Multi GPU training, specify the gpu number through the --gpus parameter
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_r45_visionlan.yml
```
Evaluation:
```
# GPU evaluation
python3 tools/eval.py -c configs/rec/rec_r45_visionlan.yml -o Global.pretrained_model={path/to/weights}/best_accuracy
```
Prediction:
```
# The configuration file used for prediction must match the training
python3 tools/infer_rec.py -c configs/rec/rec_r45_visionlan.yml -o Global.infer_img='./doc/imgs_words/en/word_2.png' Global.pretrained_model=./rec_r45_visionlan_train/best_accuracy
```
<a name="4"></a>
## 4. Inference and Deployment
<a name="4-1"></a>
### 4.1 Python Inference
First, the model saved during the VisionLAN text recognition training process is converted into an inference model. ( [Model download link](https://paddleocr.bj.bcebos.com/rec_r45_visionlan_train.tar)) ), you can use the following command to convert:
```
python3 tools/export_model.py -c configs/rec/rec_r45_visionlan.yml -o Global.pretrained_model=./rec_r45_visionlan_train/best_accuracy Global.save_inference_dir=./inference/rec_r45_visionlan/
```
**Note:**
- If you are training the model on your own dataset and have modified the dictionary file, please pay attention to modify the `character_dict_path` in the configuration file to the modified dictionary file.
- If you modified the input size during training, please modify the `infer_shape` corresponding to VisionLAN in the `tools/export_model.py` file.
After the conversion is successful, there are three files in the directory:
```
./inference/rec_r45_visionlan/
├── inference.pdiparams
├── inference.pdiparams.info
└── inference.pdmodel
```
For VisionLAN text recognition model inference, the following commands can be executed:
```
python3 tools/infer/predict_rec.py --image_dir='./doc/imgs_words/en/word_2.png' --rec_model_dir='./inference/rec_r45_visionlan/' --rec_algorithm='VisionLAN' --rec_image_shape='3,64,256' --rec_char_dict_path='./ppocr/utils/dict36.txt'
```
![](../imgs_words/en/word_2.png)
After executing the command, the prediction result (recognized text and score) of the image above is printed to the screen, an example is as follows:
The result is as follows:
```shell
Predicts of ./doc/imgs_words/en/word_2.png:('yourself', 0.97076982)
```
<a name="4-2"></a>
### 4.2 C++ Inference
Not supported
<a name="4-3"></a>
### 4.3 Serving
Not supported
<a name="4-4"></a>
### 4.4 More
Not supported
<a name="5"></a>
## 5. FAQ
1. Note that the MJSynth and SynthText datasets come from [VisionLAN repo](https://github.com/wangyuxin87/VisionLAN).
2. We use the pre-trained model provided by the VisionLAN authors for finetune training.
## Citation
```bibtex
@inproceedings{wang2021two,
title={From Two to One: A New Scene Text Recognizer with Visual Language Modeling Network},
author={Wang, Yuxin and Xie, Hongtao and Fang, Shancheng and Wang, Jing and Zhu, Shenggao and Zhang, Yongdong},
booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
pages={14194--14203},
year={2021}
}
```
......@@ -25,8 +25,8 @@ from .make_pse_gt import MakePseGt
from .rec_img_aug import BaseDataAugmentation, RecAug, RecConAug, RecResizeImg, ClsResizeImg, \
SRNRecResizeImg, GrayRecResizeImg, SARRecResizeImg, PRENResizeImg, \
ABINetRecResizeImg, SVTRRecResizeImg, ABINetRecAug, SPINRecResizeImg, RobustScannerRecResizeImg
SRNRecResizeImg, GrayRecResizeImg, SARRecResizeImg, PRENResizeImg, \
ABINetRecResizeImg, SVTRRecResizeImg, ABINetRecAug, VLRecResizeImg, SPINRecResizeImg, RobustScannerRecResizeImg
from .ssl_img_aug import SSLRotateResize
from .randaugment import RandAugment
from .copy_paste import CopyPaste
......
......@@ -23,6 +23,8 @@ import string
from shapely.geometry import LineString, Point, Polygon
import json
import copy
from random import sample
from ppocr.utils.logging import get_logger
from ppocr.data.imaug.vqa.augment import order_by_tbyx
......@@ -98,12 +100,13 @@ class BaseRecLabelEncode(object):
def __init__(self,
max_text_length,
character_dict_path=None,
use_space_char=False):
use_space_char=False,
lower=False):
self.max_text_len = max_text_length
self.beg_str = "sos"
self.end_str = "eos"
self.lower = False
self.lower = lower
if character_dict_path is None:
logger = get_logger()
......@@ -1266,3 +1269,67 @@ class SPINLabelEncode(AttnLabelEncode):
padded_text[:len(target)] = target
data['label'] = np.array(padded_text)
return data
class VLLabelEncode(BaseRecLabelEncode):
""" Convert between text-label and text-index """
def __init__(self,
max_text_length,
character_dict_path=None,
use_space_char=False,
lower=True,
**kwargs):
super(VLLabelEncode, self).__init__(
max_text_length, character_dict_path, use_space_char, lower)
self.character = self.character[10:] + self.character[
1:10] + [self.character[0]]
self.dict = {}
for i, char in enumerate(self.character):
self.dict[char] = i
def __call__(self, data):
text = data['label'] # original string
# generate occluded text
len_str = len(text)
if len_str <= 0:
return None
change_num = 1
order = list(range(len_str))
change_id = sample(order, change_num)[0]
label_sub = text[change_id]
if change_id == (len_str - 1):
label_res = text[:change_id]
elif change_id == 0:
label_res = text[1:]
else:
label_res = text[:change_id] + text[change_id + 1:]
data['label_res'] = label_res # remaining string
data['label_sub'] = label_sub # occluded character
data['label_id'] = change_id # character index
# encode label
text = self.encode(text)
if text is None:
return None
text = [i + 1 for i in text]
data['length'] = np.array(len(text))
text = text + [0] * (self.max_text_len - len(text))
data['label'] = np.array(text)
label_res = self.encode(label_res)
label_sub = self.encode(label_sub)
if label_res is None:
label_res = []
else:
label_res = [i + 1 for i in label_res]
if label_sub is None:
label_sub = []
else:
label_sub = [i + 1 for i in label_sub]
data['length_res'] = np.array(len(label_res))
data['length_sub'] = np.array(len(label_sub))
label_res = label_res + [0] * (self.max_text_len - len(label_res))
label_sub = label_sub + [0] * (self.max_text_len - len(label_sub))
data['label_res'] = np.array(label_res)
data['label_sub'] = np.array(label_sub)
return data
......@@ -205,6 +205,38 @@ class RecResizeImg(object):
return data
class VLRecResizeImg(object):
def __init__(self,
image_shape,
infer_mode=False,
character_dict_path='./ppocr/utils/ppocr_keys_v1.txt',
padding=True,
**kwargs):
self.image_shape = image_shape
self.infer_mode = infer_mode
self.character_dict_path = character_dict_path
self.padding = padding
def __call__(self, data):
img = data['image']
imgC, imgH, imgW = self.image_shape
resized_image = cv2.resize(
img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
resized_w = imgW
resized_image = resized_image.astype('float32')
if self.image_shape[0] == 1:
resized_image = resized_image / 255
norm_img = resized_image[np.newaxis, :]
else:
norm_img = resized_image.transpose((2, 0, 1)) / 255
valid_ratio = min(1.0, float(resized_w / imgW))
data['image'] = norm_img
data['valid_ratio'] = valid_ratio
return data
class SRNRecResizeImg(object):
def __init__(self, image_shape, num_heads, max_text_length, **kwargs):
self.image_shape = image_shape
......@@ -259,6 +291,7 @@ class PRENResizeImg(object):
data['image'] = resized_img.astype(np.float32)
return data
class SPINRecResizeImg(object):
def __init__(self,
image_shape,
......@@ -267,7 +300,7 @@ class SPINRecResizeImg(object):
std=(127.5, 127.5, 127.5),
**kwargs):
self.image_shape = image_shape
self.mean = np.array(mean, dtype=np.float32)
self.std = np.array(std, dtype=np.float32)
self.interpolation = interpolation
......@@ -303,6 +336,7 @@ class SPINRecResizeImg(object):
data['image'] = img
return data
class GrayRecResizeImg(object):
def __init__(self,
image_shape,
......
......@@ -35,6 +35,7 @@ from .rec_sar_loss import SARLoss
from .rec_aster_loss import AsterLoss
from .rec_pren_loss import PRENLoss
from .rec_multi_loss import MultiLoss
from .rec_vl_loss import VLLoss
from .rec_spin_att_loss import SPINAttentionLoss
# cls loss
......@@ -63,7 +64,7 @@ def build_loss(config):
'ClsLoss', 'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss',
'CELoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss',
'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss',
'TableMasterLoss', 'SPINAttentionLoss'
'TableMasterLoss', 'SPINAttentionLoss', 'VLLoss'
]
config = copy.deepcopy(config)
module_name = config.pop('name')
......
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/wangyuxin87/VisionLAN
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
from paddle import nn
class VLLoss(nn.Layer):
def __init__(self, mode='LF_1', weight_res=0.5, weight_mas=0.5, **kwargs):
super(VLLoss, self).__init__()
self.loss_func = paddle.nn.loss.CrossEntropyLoss(reduction="mean")
assert mode in ['LF_1', 'LF_2', 'LA']
self.mode = mode
self.weight_res = weight_res
self.weight_mas = weight_mas
def flatten_label(self, target):
label_flatten = []
label_length = []
for i in range(0, target.shape[0]):
cur_label = target[i].tolist()
label_flatten += cur_label[:cur_label.index(0) + 1]
label_length.append(cur_label.index(0) + 1)
label_flatten = paddle.to_tensor(label_flatten, dtype='int64')
label_length = paddle.to_tensor(label_length, dtype='int32')
return (label_flatten, label_length)
def _flatten(self, sources, lengths):
return paddle.concat([t[:l] for t, l in zip(sources, lengths)])
def forward(self, predicts, batch):
text_pre = predicts[0]
target = batch[1].astype('int64')
label_flatten, length = self.flatten_label(target)
text_pre = self._flatten(text_pre, length)
if self.mode == 'LF_1':
loss = self.loss_func(text_pre, label_flatten)
else:
text_rem = predicts[1]
text_mas = predicts[2]
target_res = batch[2].astype('int64')
target_sub = batch[3].astype('int64')
label_flatten_res, length_res = self.flatten_label(target_res)
label_flatten_sub, length_sub = self.flatten_label(target_sub)
text_rem = self._flatten(text_rem, length_res)
text_mas = self._flatten(text_mas, length_sub)
loss_ori = self.loss_func(text_pre, label_flatten)
loss_res = self.loss_func(text_rem, label_flatten_res)
loss_mas = self.loss_func(text_mas, label_flatten_sub)
loss = loss_ori + loss_res * self.weight_res + loss_mas * self.weight_mas
return {'loss': loss}
......@@ -84,11 +84,15 @@ class BasicBlock(nn.Layer):
class ResNet45(nn.Layer):
def __init__(self, block=BasicBlock, layers=[3, 4, 6, 6, 3], in_channels=3):
def __init__(self,
in_channels=3,
block=BasicBlock,
layers=[3, 4, 6, 6, 3],
strides=[2, 1, 2, 1, 1]):
self.inplanes = 32
super(ResNet45, self).__init__()
self.conv1 = nn.Conv2D(
3,
in_channels,
32,
kernel_size=3,
stride=1,
......@@ -98,18 +102,13 @@ class ResNet45(nn.Layer):
self.bn1 = nn.BatchNorm2D(32)
self.relu = nn.ReLU()
self.layer1 = self._make_layer(block, 32, layers[0], stride=2)
self.layer2 = self._make_layer(block, 64, layers[1], stride=1)
self.layer3 = self._make_layer(block, 128, layers[2], stride=2)
self.layer4 = self._make_layer(block, 256, layers[3], stride=1)
self.layer5 = self._make_layer(block, 512, layers[4], stride=1)
self.layer1 = self._make_layer(block, 32, layers[0], stride=strides[0])
self.layer2 = self._make_layer(block, 64, layers[1], stride=strides[1])
self.layer3 = self._make_layer(block, 128, layers[2], stride=strides[2])
self.layer4 = self._make_layer(block, 256, layers[3], stride=strides[3])
self.layer5 = self._make_layer(block, 512, layers[4], stride=strides[4])
self.out_channels = 512
# for m in self.modules():
# if isinstance(m, nn.Conv2D):
# n = m._kernel_size[0] * m._kernel_size[1] * m._out_channels
# m.weight.data.normal_(0, math.sqrt(2. / n))
def _make_layer(self, block, planes, blocks, stride=1):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
......@@ -137,11 +136,9 @@ class ResNet45(nn.Layer):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
# print(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
# print(x)
x = self.layer4(x)
x = self.layer5(x)
return x
......@@ -140,4 +140,4 @@ class ResNet_ASTER(nn.Layer):
rnn_feat, _ = self.rnn(cnn_feat)
return rnn_feat
else:
return cnn_feat
return cnn_feat
\ No newline at end of file
......@@ -36,6 +36,7 @@ def build_head(config):
from .rec_spin_att_head import SPINAttentionHead
from .rec_abinet_head import ABINetHead
from .rec_robustscanner_head import RobustScannerHead
from .rec_visionlan_head import VLHead
# cls head
from .cls_head import ClsHead
......@@ -50,7 +51,8 @@ def build_head(config):
'DBHead', 'PSEHead', 'FCEHead', 'EASTHead', 'SASTHead', 'CTCHead',
'ClsHead', 'AttentionHead', 'SRNHead', 'PGHead', 'Transformer',
'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead',
'MultiHead', 'ABINetHead', 'TableMasterHead', 'SPINAttentionHead', 'RobustScannerHead'
'MultiHead', 'ABINetHead', 'TableMasterHead', 'SPINAttentionHead',
'VLHead', 'RobustScannerHead'
]
#table head
......
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/wangyuxin87/VisionLAN
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
from paddle import ParamAttr
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn.initializer import Normal, XavierNormal
import numpy as np
class PositionalEncoding(nn.Layer):
def __init__(self, d_hid, n_position=200):
super(PositionalEncoding, self).__init__()
self.register_buffer(
'pos_table', self._get_sinusoid_encoding_table(n_position, d_hid))
def _get_sinusoid_encoding_table(self, n_position, d_hid):
''' Sinusoid position encoding table '''
def get_position_angle_vec(position):
return [
position / np.power(10000, 2 * (hid_j // 2) / d_hid)
for hid_j in range(d_hid)
]
sinusoid_table = np.array(
[get_position_angle_vec(pos_i) for pos_i in range(n_position)])
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
sinusoid_table = paddle.to_tensor(sinusoid_table, dtype='float32')
sinusoid_table = paddle.unsqueeze(sinusoid_table, axis=0)
return sinusoid_table
def forward(self, x):
return x + self.pos_table[:, :x.shape[1]].clone().detach()
class ScaledDotProductAttention(nn.Layer):
"Scaled Dot-Product Attention"
def __init__(self, temperature, attn_dropout=0.1):
super(ScaledDotProductAttention, self).__init__()
self.temperature = temperature
self.dropout = nn.Dropout(attn_dropout)
self.softmax = nn.Softmax(axis=2)
def forward(self, q, k, v, mask=None):
k = paddle.transpose(k, perm=[0, 2, 1])
attn = paddle.bmm(q, k)
attn = attn / self.temperature
if mask is not None:
attn = attn.masked_fill(mask, -1e9)
if mask.dim() == 3:
mask = paddle.unsqueeze(mask, axis=1)
elif mask.dim() == 2:
mask = paddle.unsqueeze(mask, axis=1)
mask = paddle.unsqueeze(mask, axis=1)
repeat_times = [
attn.shape[1] // mask.shape[1], attn.shape[2] // mask.shape[2]
]
mask = paddle.tile(mask, [1, repeat_times[0], repeat_times[1], 1])
attn[mask == 0] = -1e9
attn = self.softmax(attn)
attn = self.dropout(attn)
output = paddle.bmm(attn, v)
return output
class MultiHeadAttention(nn.Layer):
" Multi-Head Attention module"
def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
super(MultiHeadAttention, self).__init__()
self.n_head = n_head
self.d_k = d_k
self.d_v = d_v
self.w_qs = nn.Linear(
d_model,
n_head * d_k,
weight_attr=ParamAttr(initializer=Normal(
mean=0, std=np.sqrt(2.0 / (d_model + d_k)))))
self.w_ks = nn.Linear(
d_model,
n_head * d_k,
weight_attr=ParamAttr(initializer=Normal(
mean=0, std=np.sqrt(2.0 / (d_model + d_k)))))
self.w_vs = nn.Linear(
d_model,
n_head * d_v,
weight_attr=ParamAttr(initializer=Normal(
mean=0, std=np.sqrt(2.0 / (d_model + d_v)))))
self.attention = ScaledDotProductAttention(temperature=np.power(d_k,
0.5))
self.layer_norm = nn.LayerNorm(d_model)
self.fc = nn.Linear(
n_head * d_v,
d_model,
weight_attr=ParamAttr(initializer=XavierNormal()))
self.dropout = nn.Dropout(dropout)
def forward(self, q, k, v, mask=None):
d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
sz_b, len_q, _ = q.shape
sz_b, len_k, _ = k.shape
sz_b, len_v, _ = v.shape
residual = q
q = self.w_qs(q)
q = paddle.reshape(
q, shape=[-1, len_q, n_head, d_k]) # 4*21*512 ---- 4*21*8*64
k = self.w_ks(k)
k = paddle.reshape(k, shape=[-1, len_k, n_head, d_k])
v = self.w_vs(v)
v = paddle.reshape(v, shape=[-1, len_v, n_head, d_v])
q = paddle.transpose(q, perm=[2, 0, 1, 3])
q = paddle.reshape(q, shape=[-1, len_q, d_k]) # (n*b) x lq x dk
k = paddle.transpose(k, perm=[2, 0, 1, 3])
k = paddle.reshape(k, shape=[-1, len_k, d_k]) # (n*b) x lk x dk
v = paddle.transpose(v, perm=[2, 0, 1, 3])
v = paddle.reshape(v, shape=[-1, len_v, d_v]) # (n*b) x lv x dv
mask = paddle.tile(
mask,
[n_head, 1, 1]) if mask is not None else None # (n*b) x .. x ..
output = self.attention(q, k, v, mask=mask)
output = paddle.reshape(output, shape=[n_head, -1, len_q, d_v])
output = paddle.transpose(output, perm=[1, 2, 0, 3])
output = paddle.reshape(
output, shape=[-1, len_q, n_head * d_v]) # b x lq x (n*dv)
output = self.dropout(self.fc(output))
output = self.layer_norm(output + residual)
return output
class PositionwiseFeedForward(nn.Layer):
def __init__(self, d_in, d_hid, dropout=0.1):
super(PositionwiseFeedForward, self).__init__()
self.w_1 = nn.Conv1D(d_in, d_hid, 1) # position-wise
self.w_2 = nn.Conv1D(d_hid, d_in, 1) # position-wise
self.layer_norm = nn.LayerNorm(d_in)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
residual = x
x = paddle.transpose(x, perm=[0, 2, 1])
x = self.w_2(F.relu(self.w_1(x)))
x = paddle.transpose(x, perm=[0, 2, 1])
x = self.dropout(x)
x = self.layer_norm(x + residual)
return x
class EncoderLayer(nn.Layer):
''' Compose with two layers '''
def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1):
super(EncoderLayer, self).__init__()
self.slf_attn = MultiHeadAttention(
n_head, d_model, d_k, d_v, dropout=dropout)
self.pos_ffn = PositionwiseFeedForward(
d_model, d_inner, dropout=dropout)
def forward(self, enc_input, slf_attn_mask=None):
enc_output = self.slf_attn(
enc_input, enc_input, enc_input, mask=slf_attn_mask)
enc_output = self.pos_ffn(enc_output)
return enc_output
class Transformer_Encoder(nn.Layer):
def __init__(self,
n_layers=2,
n_head=8,
d_word_vec=512,
d_k=64,
d_v=64,
d_model=512,
d_inner=2048,
dropout=0.1,
n_position=256):
super(Transformer_Encoder, self).__init__()
self.position_enc = PositionalEncoding(
d_word_vec, n_position=n_position)
self.dropout = nn.Dropout(p=dropout)
self.layer_stack = nn.LayerList([
EncoderLayer(
d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
for _ in range(n_layers)
])
self.layer_norm = nn.LayerNorm(d_model, epsilon=1e-6)
def forward(self, enc_output, src_mask, return_attns=False):
enc_output = self.dropout(
self.position_enc(enc_output)) # position embeding
for enc_layer in self.layer_stack:
enc_output = enc_layer(enc_output, slf_attn_mask=src_mask)
enc_output = self.layer_norm(enc_output)
return enc_output
class PP_layer(nn.Layer):
def __init__(self, n_dim=512, N_max_character=25, n_position=256):
super(PP_layer, self).__init__()
self.character_len = N_max_character
self.f0_embedding = nn.Embedding(N_max_character, n_dim)
self.w0 = nn.Linear(N_max_character, n_position)
self.wv = nn.Linear(n_dim, n_dim)
self.we = nn.Linear(n_dim, N_max_character)
self.active = nn.Tanh()
self.softmax = nn.Softmax(axis=2)
def forward(self, enc_output):
# enc_output: b,256,512
reading_order = paddle.arange(self.character_len, dtype='int64')
reading_order = reading_order.unsqueeze(0).expand(
[enc_output.shape[0], self.character_len]) # (S,) -> (B, S)
reading_order = self.f0_embedding(reading_order) # b,25,512
# calculate attention
reading_order = paddle.transpose(reading_order, perm=[0, 2, 1])
t = self.w0(reading_order) # b,512,256
t = self.active(
paddle.transpose(
t, perm=[0, 2, 1]) + self.wv(enc_output)) # b,256,512
t = self.we(t) # b,256,25
t = self.softmax(paddle.transpose(t, perm=[0, 2, 1])) # b,25,256
g_output = paddle.bmm(t, enc_output) # b,25,512
return g_output
class Prediction(nn.Layer):
def __init__(self,
n_dim=512,
n_position=256,
N_max_character=25,
n_class=37):
super(Prediction, self).__init__()
self.pp = PP_layer(
n_dim=n_dim, N_max_character=N_max_character, n_position=n_position)
self.pp_share = PP_layer(
n_dim=n_dim, N_max_character=N_max_character, n_position=n_position)
self.w_vrm = nn.Linear(n_dim, n_class) # output layer
self.w_share = nn.Linear(n_dim, n_class) # output layer
self.nclass = n_class
def forward(self, cnn_feature, f_res, f_sub, train_mode=False,
use_mlm=True):
if train_mode:
if not use_mlm:
g_output = self.pp(cnn_feature) # b,25,512
g_output = self.w_vrm(g_output)
f_res = 0
f_sub = 0
return g_output, f_res, f_sub
g_output = self.pp(cnn_feature) # b,25,512
f_res = self.pp_share(f_res)
f_sub = self.pp_share(f_sub)
g_output = self.w_vrm(g_output)
f_res = self.w_share(f_res)
f_sub = self.w_share(f_sub)
return g_output, f_res, f_sub
else:
g_output = self.pp(cnn_feature) # b,25,512
g_output = self.w_vrm(g_output)
return g_output
class MLM(nn.Layer):
"Architecture of MLM"
def __init__(self, n_dim=512, n_position=256, max_text_length=25):
super(MLM, self).__init__()
self.MLM_SequenceModeling_mask = Transformer_Encoder(
n_layers=2, n_position=n_position)
self.MLM_SequenceModeling_WCL = Transformer_Encoder(
n_layers=1, n_position=n_position)
self.pos_embedding = nn.Embedding(max_text_length, n_dim)
self.w0_linear = nn.Linear(1, n_position)
self.wv = nn.Linear(n_dim, n_dim)
self.active = nn.Tanh()
self.we = nn.Linear(n_dim, 1)
self.sigmoid = nn.Sigmoid()
def forward(self, x, label_pos):
# transformer unit for generating mask_c
feature_v_seq = self.MLM_SequenceModeling_mask(x, src_mask=None)
# position embedding layer
label_pos = paddle.to_tensor(label_pos, dtype='int64')
pos_emb = self.pos_embedding(label_pos)
pos_emb = self.w0_linear(paddle.unsqueeze(pos_emb, axis=2))
pos_emb = paddle.transpose(pos_emb, perm=[0, 2, 1])
# fusion position embedding with features V & generate mask_c
att_map_sub = self.active(pos_emb + self.wv(feature_v_seq))
att_map_sub = self.we(att_map_sub) # b,256,1
att_map_sub = paddle.transpose(att_map_sub, perm=[0, 2, 1])
att_map_sub = self.sigmoid(att_map_sub) # b,1,256
# WCL
## generate inputs for WCL
att_map_sub = paddle.transpose(att_map_sub, perm=[0, 2, 1])
f_res = x * (1 - att_map_sub) # second path with remaining string
f_sub = x * att_map_sub # first path with occluded character
## transformer units in WCL
f_res = self.MLM_SequenceModeling_WCL(f_res, src_mask=None)
f_sub = self.MLM_SequenceModeling_WCL(f_sub, src_mask=None)
return f_res, f_sub, att_map_sub
def trans_1d_2d(x):
b, w_h, c = x.shape # b, 256, 512
x = paddle.transpose(x, perm=[0, 2, 1])
x = paddle.reshape(x, [-1, c, 32, 8])
x = paddle.transpose(x, perm=[0, 1, 3, 2]) # [b, c, 8, 32]
return x
class MLM_VRM(nn.Layer):
"""
MLM+VRM, MLM is only used in training.
ratio controls the occluded number in a batch.
The pipeline of VisionLAN in testing is very concise with only a backbone + sequence modeling(transformer unit) + prediction layer(pp layer).
x: input image
label_pos: character index
training_step: LF or LA process
output
text_pre: prediction of VRM
test_rem: prediction of remaining string in MLM
text_mas: prediction of occluded character in MLM
mask_c_show: visualization of Mask_c
"""
def __init__(self,
n_layers=3,
n_position=256,
n_dim=512,
max_text_length=25,
nclass=37):
super(MLM_VRM, self).__init__()
self.MLM = MLM(n_dim=n_dim,
n_position=n_position,
max_text_length=max_text_length)
self.SequenceModeling = Transformer_Encoder(
n_layers=n_layers, n_position=n_position)
self.Prediction = Prediction(
n_dim=n_dim,
n_position=n_position,
N_max_character=max_text_length +
1, # N_max_character = 1 eos + 25 characters
n_class=nclass)
self.nclass = nclass
self.max_text_length = max_text_length
def forward(self, x, label_pos, training_step, train_mode=False):
b, c, h, w = x.shape
nT = self.max_text_length
x = paddle.transpose(x, perm=[0, 1, 3, 2])
x = paddle.reshape(x, [-1, c, h * w])
x = paddle.transpose(x, perm=[0, 2, 1])
if train_mode:
if training_step == 'LF_1':
f_res = 0
f_sub = 0
x = self.SequenceModeling(x, src_mask=None)
text_pre, test_rem, text_mas = self.Prediction(
x, f_res, f_sub, train_mode=True, use_mlm=False)
return text_pre, text_pre, text_pre, text_pre
elif training_step == 'LF_2':
# MLM
f_res, f_sub, mask_c = self.MLM(x, label_pos)
x = self.SequenceModeling(x, src_mask=None)
text_pre, test_rem, text_mas = self.Prediction(
x, f_res, f_sub, train_mode=True)
mask_c_show = trans_1d_2d(mask_c)
return text_pre, test_rem, text_mas, mask_c_show
elif training_step == 'LA':
# MLM
f_res, f_sub, mask_c = self.MLM(x, label_pos)
## use the mask_c (1 for occluded character and 0 for remaining characters) to occlude input
## ratio controls the occluded number in a batch
character_mask = paddle.zeros_like(mask_c)
ratio = b // 2
if ratio >= 1:
with paddle.no_grad():
character_mask[0:ratio, :, :] = mask_c[0:ratio, :, :]
else:
character_mask = mask_c
x = x * (1 - character_mask)
# VRM
## transformer unit for VRM
x = self.SequenceModeling(x, src_mask=None)
## prediction layer for MLM and VSR
text_pre, test_rem, text_mas = self.Prediction(
x, f_res, f_sub, train_mode=True)
mask_c_show = trans_1d_2d(mask_c)
return text_pre, test_rem, text_mas, mask_c_show
else:
raise NotImplementedError
else: # VRM is only used in the testing stage
f_res = 0
f_sub = 0
contextual_feature = self.SequenceModeling(x, src_mask=None)
text_pre = self.Prediction(
contextual_feature,
f_res,
f_sub,
train_mode=False,
use_mlm=False)
text_pre = paddle.transpose(
text_pre, perm=[1, 0, 2]) # (26, b, 37))
return text_pre, x
class VLHead(nn.Layer):
"""
Architecture of VisionLAN
"""
def __init__(self,
in_channels,
out_channels=36,
n_layers=3,
n_position=256,
n_dim=512,
max_text_length=25,
training_step='LA'):
super(VLHead, self).__init__()
self.MLM_VRM = MLM_VRM(
n_layers=n_layers,
n_position=n_position,
n_dim=n_dim,
max_text_length=max_text_length,
nclass=out_channels + 1)
self.training_step = training_step
def forward(self, feat, targets=None):
if self.training:
label_pos = targets[-2]
text_pre, test_rem, text_mas, mask_map = self.MLM_VRM(
feat, label_pos, self.training_step, train_mode=True)
return text_pre, test_rem, text_mas, mask_map
else:
text_pre, x = self.MLM_VRM(
feat, targets, self.training_step, train_mode=False)
return text_pre, x
......@@ -77,11 +77,62 @@ class Adam(object):
self.grad_clip = grad_clip
self.name = name
self.lazy_mode = lazy_mode
self.group_lr = kwargs.get('group_lr', False)
self.training_step = kwargs.get('training_step', None)
def __call__(self, model):
train_params = [
param for param in model.parameters() if param.trainable is True
]
if self.group_lr:
if self.training_step == 'LF_2':
import paddle
if isinstance(model, paddle.fluid.dygraph.parallel.
DataParallel): # multi gpu
mlm = model._layers.head.MLM_VRM.MLM.parameters()
pre_mlm_pp = model._layers.head.MLM_VRM.Prediction.pp_share.parameters(
)
pre_mlm_w = model._layers.head.MLM_VRM.Prediction.w_share.parameters(
)
else: # single gpu
mlm = model.head.MLM_VRM.MLM.parameters()
pre_mlm_pp = model.head.MLM_VRM.Prediction.pp_share.parameters(
)
pre_mlm_w = model.head.MLM_VRM.Prediction.w_share.parameters(
)
total = []
for param in mlm:
total.append(id(param))
for param in pre_mlm_pp:
total.append(id(param))
for param in pre_mlm_w:
total.append(id(param))
group_base_params = [
param for param in model.parameters() if id(param) in total
]
group_small_params = [
param for param in model.parameters()
if id(param) not in total
]
train_params = [{
'params': group_base_params
}, {
'params': group_small_params,
'learning_rate': self.learning_rate.values[0] * 0.1
}]
else:
print(
'group lr currently only support VisionLAN in LF_2 training step'
)
train_params = [
param for param in model.parameters()
if param.trainable is True
]
else:
train_params = [
param for param in model.parameters() if param.trainable is True
]
opt = optim.Adam(
learning_rate=self.learning_rate,
beta1=self.beta1,
......
......@@ -28,41 +28,27 @@ from .fce_postprocess import FCEPostProcess
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, \
DistillationCTCLabelDecode, NRTRLabelDecode, SARLabelDecode, \
SEEDLabelDecode, PRENLabelDecode, ViTSTRLabelDecode, ABINetLabelDecode, \
SPINLabelDecode
SPINLabelDecode, VLLabelDecode
from .cls_postprocess import ClsPostProcess
from .pg_postprocess import PGPostProcess
from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess, DistillationSerPostProcess
from .vqa_token_re_layoutlm_postprocess import VQAReTokenLayoutLMPostProcess, DistillationRePostProcess
from .table_postprocess import TableMasterLabelDecode, TableLabelDecode
from .picodet_postprocess import PicoDetPostProcess
def build_post_process(config, global_config=None):
support_dict = [
'DBPostProcess',
'EASTPostProcess',
'SASTPostProcess',
'FCEPostProcess',
'CTCLabelDecode',
'AttnLabelDecode',
'ClsPostProcess',
'SRNLabelDecode',
'PGPostProcess',
'DistillationCTCLabelDecode',
'TableLabelDecode',
'DistillationDBPostProcess',
'NRTRLabelDecode',
'SARLabelDecode',
'SEEDLabelDecode',
'VQASerTokenLayoutLMPostProcess',
'VQAReTokenLayoutLMPostProcess',
'PRENLabelDecode',
'DistillationSARLabelDecode',
'ViTSTRLabelDecode',
'ABINetLabelDecode',
'TableMasterLabelDecode',
'SPINLabelDecode',
'DistillationSerPostProcess',
'DistillationRePostProcess',
'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'FCEPostProcess',
'CTCLabelDecode', 'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode',
'PGPostProcess', 'DistillationCTCLabelDecode', 'TableLabelDecode',
'DistillationDBPostProcess', 'NRTRLabelDecode', 'SARLabelDecode',
'SEEDLabelDecode', 'VQASerTokenLayoutLMPostProcess',
'VQAReTokenLayoutLMPostProcess', 'PRENLabelDecode',
'DistillationSARLabelDecode', 'ViTSTRLabelDecode', 'ABINetLabelDecode',
'TableMasterLabelDecode', 'SPINLabelDecode',
'DistillationSerPostProcess', 'DistillationRePostProcess',
'VLLabelDecode', 'PicoDetPostProcess'
]
if config['name'] == 'PSEPostProcess':
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from scipy.special import softmax
def hard_nms(box_scores, iou_threshold, top_k=-1, candidate_size=200):
"""
Args:
box_scores (N, 5): boxes in corner-form and probabilities.
iou_threshold: intersection over union threshold.
top_k: keep top_k results. If k <= 0, keep all the results.
candidate_size: only consider the candidates with the highest scores.
Returns:
picked: a list of indexes of the kept boxes
"""
scores = box_scores[:, -1]
boxes = box_scores[:, :-1]
picked = []
indexes = np.argsort(scores)
indexes = indexes[-candidate_size:]
while len(indexes) > 0:
current = indexes[-1]
picked.append(current)
if 0 < top_k == len(picked) or len(indexes) == 1:
break
current_box = boxes[current, :]
indexes = indexes[:-1]
rest_boxes = boxes[indexes, :]
iou = iou_of(
rest_boxes,
np.expand_dims(
current_box, axis=0), )
indexes = indexes[iou <= iou_threshold]
return box_scores[picked, :]
def iou_of(boxes0, boxes1, eps=1e-5):
"""Return intersection-over-union (Jaccard index) of boxes.
Args:
boxes0 (N, 4): ground truth boxes.
boxes1 (N or 1, 4): predicted boxes.
eps: a small number to avoid 0 as denominator.
Returns:
iou (N): IoU values.
"""
overlap_left_top = np.maximum(boxes0[..., :2], boxes1[..., :2])
overlap_right_bottom = np.minimum(boxes0[..., 2:], boxes1[..., 2:])
overlap_area = area_of(overlap_left_top, overlap_right_bottom)
area0 = area_of(boxes0[..., :2], boxes0[..., 2:])
area1 = area_of(boxes1[..., :2], boxes1[..., 2:])
return overlap_area / (area0 + area1 - overlap_area + eps)
def area_of(left_top, right_bottom):
"""Compute the areas of rectangles given two corners.
Args:
left_top (N, 2): left top corner.
right_bottom (N, 2): right bottom corner.
Returns:
area (N): return the area.
"""
hw = np.clip(right_bottom - left_top, 0.0, None)
return hw[..., 0] * hw[..., 1]
class PicoDetPostProcess(object):
"""
Args:
input_shape (int): network input image size
ori_shape (int): ori image shape of before padding
scale_factor (float): scale factor of ori image
enable_mkldnn (bool): whether to open MKLDNN
"""
def __init__(self,
layout_dict_path,
strides=[8, 16, 32, 64],
score_threshold=0.4,
nms_threshold=0.5,
nms_top_k=1000,
keep_top_k=100):
self.labels = self.load_layout_dict(layout_dict_path)
self.strides = strides
self.score_threshold = score_threshold
self.nms_threshold = nms_threshold
self.nms_top_k = nms_top_k
self.keep_top_k = keep_top_k
def load_layout_dict(self, layout_dict_path):
with open(layout_dict_path, 'r', encoding='utf-8') as fp:
labels = fp.readlines()
return [label.strip('\n') for label in labels]
def warp_boxes(self, boxes, ori_shape):
"""Apply transform to boxes
"""
width, height = ori_shape[1], ori_shape[0]
n = len(boxes)
if n:
# warp points
xy = np.ones((n * 4, 3))
xy[:, :2] = boxes[:, [0, 1, 2, 3, 0, 3, 2, 1]].reshape(
n * 4, 2) # x1y1, x2y2, x1y2, x2y1
# xy = xy @ M.T # transform
xy = (xy[:, :2] / xy[:, 2:3]).reshape(n, 8) # rescale
# create new boxes
x = xy[:, [0, 2, 4, 6]]
y = xy[:, [1, 3, 5, 7]]
xy = np.concatenate(
(x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T
# clip boxes
xy[:, [0, 2]] = xy[:, [0, 2]].clip(0, width)
xy[:, [1, 3]] = xy[:, [1, 3]].clip(0, height)
return xy.astype(np.float32)
else:
return boxes
def img_info(self, ori_img, img):
origin_shape = ori_img.shape
resize_shape = img.shape
im_scale_y = resize_shape[2] / float(origin_shape[0])
im_scale_x = resize_shape[3] / float(origin_shape[1])
scale_factor = np.array([im_scale_y, im_scale_x], dtype=np.float32)
img_shape = np.array(img.shape[2:], dtype=np.float32)
input_shape = np.array(img).astype('float32').shape[2:]
ori_shape = np.array((img_shape, )).astype('float32')
scale_factor = np.array((scale_factor, )).astype('float32')
return ori_shape, input_shape, scale_factor
def __call__(self, ori_img, img, preds):
scores, raw_boxes = preds['boxes'], preds['boxes_num']
batch_size = raw_boxes[0].shape[0]
reg_max = int(raw_boxes[0].shape[-1] / 4 - 1)
out_boxes_num = []
out_boxes_list = []
results = []
ori_shape, input_shape, scale_factor = self.img_info(ori_img, img)
for batch_id in range(batch_size):
# generate centers
decode_boxes = []
select_scores = []
for stride, box_distribute, score in zip(self.strides, raw_boxes,
scores):
box_distribute = box_distribute[batch_id]
score = score[batch_id]
# centers
fm_h = input_shape[0] / stride
fm_w = input_shape[1] / stride
h_range = np.arange(fm_h)
w_range = np.arange(fm_w)
ww, hh = np.meshgrid(w_range, h_range)
ct_row = (hh.flatten() + 0.5) * stride
ct_col = (ww.flatten() + 0.5) * stride
center = np.stack((ct_col, ct_row, ct_col, ct_row), axis=1)
# box distribution to distance
reg_range = np.arange(reg_max + 1)
box_distance = box_distribute.reshape((-1, reg_max + 1))
box_distance = softmax(box_distance, axis=1)
box_distance = box_distance * np.expand_dims(reg_range, axis=0)
box_distance = np.sum(box_distance, axis=1).reshape((-1, 4))
box_distance = box_distance * stride
# top K candidate
topk_idx = np.argsort(score.max(axis=1))[::-1]
topk_idx = topk_idx[:self.nms_top_k]
center = center[topk_idx]
score = score[topk_idx]
box_distance = box_distance[topk_idx]
# decode box
decode_box = center + [-1, -1, 1, 1] * box_distance
select_scores.append(score)
decode_boxes.append(decode_box)
# nms
bboxes = np.concatenate(decode_boxes, axis=0)
confidences = np.concatenate(select_scores, axis=0)
picked_box_probs = []
picked_labels = []
for class_index in range(0, confidences.shape[1]):
probs = confidences[:, class_index]
mask = probs > self.score_threshold
probs = probs[mask]
if probs.shape[0] == 0:
continue
subset_boxes = bboxes[mask, :]
box_probs = np.concatenate(
[subset_boxes, probs.reshape(-1, 1)], axis=1)
box_probs = hard_nms(
box_probs,
iou_threshold=self.nms_threshold,
top_k=self.keep_top_k, )
picked_box_probs.append(box_probs)
picked_labels.extend([class_index] * box_probs.shape[0])
if len(picked_box_probs) == 0:
out_boxes_list.append(np.empty((0, 4)))
out_boxes_num.append(0)
else:
picked_box_probs = np.concatenate(picked_box_probs)
# resize output boxes
picked_box_probs[:, :4] = self.warp_boxes(
picked_box_probs[:, :4], ori_shape[batch_id])
im_scale = np.concatenate([
scale_factor[batch_id][::-1], scale_factor[batch_id][::-1]
])
picked_box_probs[:, :4] /= im_scale
# clas score box
out_boxes_list.append(
np.concatenate(
[
np.expand_dims(
np.array(picked_labels),
axis=-1), np.expand_dims(
picked_box_probs[:, 4], axis=-1),
picked_box_probs[:, :4]
],
axis=1))
out_boxes_num.append(len(picked_labels))
out_boxes_list = np.concatenate(out_boxes_list, axis=0)
out_boxes_num = np.asarray(out_boxes_num).astype(np.int32)
for dt in out_boxes_list:
clsid, bbox, score = int(dt[0]), dt[2:], dt[1]
label = self.labels[clsid]
result = {'bbox': bbox, 'label': label}
results.append(result)
return results
......@@ -668,6 +668,7 @@ class ABINetLabelDecode(NRTRLabelDecode):
dict_character = ['</s>'] + dict_character
return dict_character
class SPINLabelDecode(AttnLabelDecode):
""" Convert between text-label and text-index """
......@@ -681,4 +682,106 @@ class SPINLabelDecode(AttnLabelDecode):
self.end_str = "eos"
dict_character = dict_character
dict_character = [self.beg_str] + [self.end_str] + dict_character
return dict_character
\ No newline at end of file
return dict_character
class VLLabelDecode(BaseRecLabelDecode):
""" Convert between text-label and text-index """
def __init__(self, character_dict_path=None, use_space_char=False,
**kwargs):
super(VLLabelDecode, self).__init__(character_dict_path, use_space_char)
self.max_text_length = kwargs.get('max_text_length', 25)
self.nclass = len(self.character) + 1
self.character = self.character[10:] + self.character[
1:10] + [self.character[0]]
def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
""" convert text-index into text-label. """
result_list = []
ignored_tokens = self.get_ignored_tokens()
batch_size = len(text_index)
for batch_idx in range(batch_size):
selection = np.ones(len(text_index[batch_idx]), dtype=bool)
if is_remove_duplicate:
selection[1:] = text_index[batch_idx][1:] != text_index[
batch_idx][:-1]
for ignored_token in ignored_tokens:
selection &= text_index[batch_idx] != ignored_token
char_list = [
self.character[text_id - 1]
for text_id in text_index[batch_idx][selection]
]
if text_prob is not None:
conf_list = text_prob[batch_idx][selection]
else:
conf_list = [1] * len(selection)
if len(conf_list) == 0:
conf_list = [0]
text = ''.join(char_list)
result_list.append((text, np.mean(conf_list).tolist()))
return result_list
def __call__(self, preds, label=None, length=None, *args, **kwargs):
if len(preds) == 2: # eval mode
text_pre, x = preds
b = text_pre.shape[1]
lenText = self.max_text_length
nsteps = self.max_text_length
if not isinstance(text_pre, paddle.Tensor):
text_pre = paddle.to_tensor(text_pre, dtype='float32')
out_res = paddle.zeros(
shape=[lenText, b, self.nclass], dtype=x.dtype)
out_length = paddle.zeros(shape=[b], dtype=x.dtype)
now_step = 0
for _ in range(nsteps):
if 0 in out_length and now_step < nsteps:
tmp_result = text_pre[now_step, :, :]
out_res[now_step] = tmp_result
tmp_result = tmp_result.topk(1)[1].squeeze(axis=1)
for j in range(b):
if out_length[j] == 0 and tmp_result[j] == 0:
out_length[j] = now_step + 1
now_step += 1
for j in range(0, b):
if int(out_length[j]) == 0:
out_length[j] = nsteps
start = 0
output = paddle.zeros(
shape=[int(out_length.sum()), self.nclass], dtype=x.dtype)
for i in range(0, b):
cur_length = int(out_length[i])
output[start:start + cur_length] = out_res[0:cur_length, i, :]
start += cur_length
net_out = output
length = out_length
else: # train mode
net_out = preds[0]
length = length
net_out = paddle.concat([t[:l] for t, l in zip(net_out, length)])
text = []
if not isinstance(net_out, paddle.Tensor):
net_out = paddle.to_tensor(net_out, dtype='float32')
net_out = F.softmax(net_out, axis=1)
for i in range(0, length.shape[0]):
preds_idx = net_out[int(length[:i].sum()):int(length[:i].sum(
) + length[i])].topk(1)[1][:, 0].tolist()
preds_text = ''.join([
self.character[idx - 1]
if idx > 0 and idx <= len(self.character) else ''
for idx in preds_idx
])
preds_prob = net_out[int(length[:i].sum()):int(length[:i].sum(
) + length[i])].topk(1)[0][:, 0]
preds_prob = paddle.exp(
paddle.log(preds_prob).sum() / (preds_prob.shape[0] + 1e-6))
text.append((preds_text, preds_prob))
if label is None:
return text
label = self.decode(label)
return text, label
......@@ -53,6 +53,7 @@ def load_model(config, model, optimizer=None, model_type='det'):
checkpoints = global_config.get('checkpoints')
pretrained_model = global_config.get('pretrained_model')
best_model_dict = {}
is_float16 = False
if model_type == 'vqa':
# NOTE: for vqa model, resume training is not supported now
......@@ -100,6 +101,9 @@ def load_model(config, model, optimizer=None, model_type='det'):
key, params.keys()))
continue
pre_value = params[key]
if pre_value.dtype == paddle.float16:
pre_value = pre_value.astype(paddle.float32)
is_float16 = True
if list(value.shape) == list(pre_value.shape):
new_state_dict[key] = pre_value
else:
......@@ -107,7 +111,10 @@ def load_model(config, model, optimizer=None, model_type='det'):
"The shape of model params {} {} not matched with loaded params shape {} !".
format(key, value.shape, pre_value.shape))
model.set_state_dict(new_state_dict)
if is_float16:
logger.info(
"The parameter type is float16, which is converted to float32 when loading"
)
if optimizer is not None:
if os.path.exists(checkpoints + '.pdopt'):
optim_dict = paddle.load(checkpoints + '.pdopt')
......@@ -126,9 +133,10 @@ def load_model(config, model, optimizer=None, model_type='det'):
best_model_dict['start_epoch'] = states_dict['epoch'] + 1
logger.info("resume from {}".format(checkpoints))
elif pretrained_model:
load_pretrained_params(model, pretrained_model)
is_float16 = load_pretrained_params(model, pretrained_model)
else:
logger.info('train from scratch')
best_model_dict['is_float16'] = is_float16
return best_model_dict
......@@ -142,19 +150,28 @@ def load_pretrained_params(model, path):
params = paddle.load(path + '.pdparams')
state_dict = model.state_dict()
new_state_dict = {}
is_float16 = False
for k1 in params.keys():
if k1 not in state_dict.keys():
logger.warning("The pretrained params {} not in model".format(k1))
else:
if params[k1].dtype == paddle.float16:
params[k1] = params[k1].astype(paddle.float32)
is_float16 = True
if list(state_dict[k1].shape) == list(params[k1].shape):
new_state_dict[k1] = params[k1]
else:
logger.warning(
"The shape of model params {} {} not matched with loaded params {} {} !".
format(k1, state_dict[k1].shape, k1, params[k1].shape))
model.set_state_dict(new_state_dict)
if is_float16:
logger.info(
"The parameter type is float16, which is converted to float32 when loading"
)
logger.info("load pretrain successful from {}".format(path))
return model
return is_float16
def save_model(model,
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..')))
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
import cv2
import numpy as np
import time
import tools.infer.utility as utility
from ppocr.data import create_operators, transform
from ppocr.postprocess import build_post_process
from ppocr.utils.logging import get_logger
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from ppstructure.utility import parse_args
from picodet_postprocess import PicoDetPostProcess
logger = get_logger()
class LayoutPredictor(object):
def __init__(self, args):
pre_process_list = [{
'Resize': {
'size': [800, 608]
}
}, {
'NormalizeImage': {
'std': [0.229, 0.224, 0.225],
'mean': [0.485, 0.456, 0.406],
'scale': '1./255.',
'order': 'hwc'
}
}, {
'ToCHWImage': None
}, {
'KeepKeys': {
'keep_keys': ['image']
}
}]
postprocess_params = {
'name': 'PicoDetPostProcess',
"layout_dict_path": args.layout_dict_path,
"score_threshold": args.layout_score_threshold,
"nms_threshold": args.layout_nms_threshold,
}
self.preprocess_op = create_operators(pre_process_list)
self.postprocess_op = build_post_process(postprocess_params)
self.predictor, self.input_tensor, self.output_tensors, self.config = \
utility.create_predictor(args, 'layout', logger)
def __call__(self, img):
ori_im = img.copy()
data = {'image': img}
data = transform(data, self.preprocess_op)
img = data[0]
if img is None:
return None, 0
img = np.expand_dims(img, axis=0)
img = img.copy()
preds, elapse = 0, 1
starttime = time.time()
self.input_tensor.copy_from_cpu(img)
self.predictor.run()
np_score_list, np_boxes_list = [], []
output_names = self.predictor.get_output_names()
num_outs = int(len(output_names) / 2)
for out_idx in range(num_outs):
np_score_list.append(
self.predictor.get_output_handle(output_names[out_idx])
.copy_to_cpu())
np_boxes_list.append(
self.predictor.get_output_handle(output_names[
out_idx + num_outs]).copy_to_cpu())
preds = dict(boxes=np_score_list, boxes_num=np_boxes_list)
post_preds = self.postprocess_op(ori_im, img, preds)
elapse = time.time() - starttime
return post_preds, elapse
def main(args):
image_file_list = get_image_file_list(args.image_dir)
layout_predictor = LayoutPredictor(args)
count = 0
total_time = 0
repeats = 50
for image_file in image_file_list:
img, flag = check_and_read_gif(image_file)
if not flag:
img = cv2.imread(image_file)
if img is None:
logger.info("error in loading image:{}".format(image_file))
continue
layout_res, elapse = layout_predictor(img)
logger.info("result: {}".format(layout_res))
if count > 0:
total_time += elapse
count += 1
logger.info("Predict time of {}: {}".format(image_file, elapse))
if __name__ == "__main__":
main(parse_args())
......@@ -32,15 +32,18 @@ def init_args():
type=str,
default="../ppocr/utils/dict/table_structure_dict.txt")
# params for layout
parser.add_argument("--layout_model_dir", type=str)
parser.add_argument(
"--layout_path_model",
"--layout_dict_path",
type=str,
default="lp://PubLayNet/ppyolov2_r50vd_dcn_365e_publaynet/config")
default="../ppocr/utils/dict/layout_pubalynet_dict.txt")
parser.add_argument(
"--layout_label_map",
type=ast.literal_eval,
default=None,
help='label map according to ppstructure/layout/README_ch.md')
"--layout_score_threshold",
type=float,
default=0.5,
help="Threshold of score.")
parser.add_argument(
"--layout_nms_threshold", type=float, default=0.5, help="Threshold of nms.")
# params for vqa
parser.add_argument("--vqa_algorithm", type=str, default='LayoutXLM')
parser.add_argument("--ser_model_dir", type=str)
......@@ -87,7 +90,7 @@ def draw_structure_result(image, result, font_path):
image = Image.fromarray(image)
boxes, txts, scores = [], [], []
for region in result:
if region['type'] == 'Table':
if region['type'] == 'table':
pass
else:
for text_result in region['res']:
......
......@@ -8,7 +8,7 @@ Global:
# evaluation is run every 835 iterations
eval_batch_step: [0, 4000]
cal_metric_during_train: False
pretrained_model: ./pretrain_models/ResNet50_vd_ssld_pretrained
pretrained_model: pretrain_models/det_r50_dcn_fce_ctw_v2.0_train/best_accuracy.pdparams
checkpoints:
save_inference_dir:
use_visualdl: False
......
......@@ -6,14 +6,14 @@ Global.use_gpu:True|True
Global.auto_cast:fp32
Global.epoch_num:lite_train_lite_infer=1|whole_train_whole_infer=17
Global.save_model_dir:./output/
Train.loader.batch_size_per_card:lite_train_lite_infer=8|whole_train_whole_infer=8
Train.loader.batch_size_per_card:lite_train_lite_infer=4|whole_train_whole_infer=8
Architecture.Backbone.checkpoints:null
train_model_name:latest
train_infer_img_dir:ppstructure/docs/vqa/input/zh_val_42.jpg
null:null
##
trainer:norm_train
norm_train:tools/train.py -c configs/vqa/ser/layoutxlm_xfund_zh.yml -o Global.print_batch_step=1 Global.eval_batch_step=[1000,1000] Train.loader.shuffle=false
norm_train:tools/train.py -c configs/kie/layoutlm_series/ser_layoutlm_xfund_zh.yml -o Global.print_batch_step=1 Global.eval_batch_step=[1000,1000] Train.loader.shuffle=false
pact_train:null
fpgm_train:null
distill_train:null
......@@ -27,7 +27,7 @@ null:null
===========================infer_params===========================
Global.save_inference_dir:./output/
Architecture.Backbone.checkpoints:
norm_export:tools/export_model.py -c configs/vqa/ser/layoutxlm_xfund_zh.yml -o
norm_export:tools/export_model.py -c configs/kie/layoutlm_series/ser_layoutlm_xfund_zh.yml -o
quant_export:
fpgm_export:
distill_export:null
......
......@@ -108,7 +108,7 @@ if [ ${MODE} = "benchmark_train" ];then
fi
if [ ${model_name} == "layoutxlm_ser" ]; then
pip install -r ppstructure/vqa/requirements.txt
pip install paddlenlp\>=2.3.5 --force-reinstall
pip install paddlenlp\>=2.3.5 --force-reinstall -i https://mirrors.aliyun.com/pypi/simple/
wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/ppstructure/dataset/XFUND.tar --no-check-certificate
cd ./train_data/ && tar xf XFUND.tar
# expand gt.txt 10 times
......@@ -222,7 +222,7 @@ if [ ${MODE} = "lite_train_lite_infer" ];then
fi
if [ ${model_name} == "layoutxlm_ser" ]; then
pip install -r ppstructure/vqa/requirements.txt
pip install paddlenlp\>=2.3.5 --force-reinstall
pip install paddlenlp\>=2.3.5 --force-reinstall -i https://mirrors.aliyun.com/pypi/simple/
wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/ppstructure/dataset/XFUND.tar --no-check-certificate
cd ./train_data/ && tar xf XFUND.tar
cd ../
......
......@@ -73,7 +73,7 @@ def main():
config['Architecture']["Head"]['out_channels'] = char_num
model = build_model(config['Architecture'])
extra_input_models = ["SRN", "NRTR", "SAR", "SEED", "SVTR", "RobustScanner"]
extra_input_models = ["SRN", "NRTR", "SAR", "SEED", "SVTR", "VisionLAN", "RobustScanner"]
extra_input = False
if config['Architecture']['algorithm'] == 'Distillation':
for key in config['Architecture']["Models"]:
......
......@@ -113,6 +113,12 @@ def export_single_model(model,
shape=[None, 1, 32, 100], dtype="float32"),
]
model = to_static(model, input_spec=other_shape)
elif arch_config["algorithm"] == "VisionLAN":
other_shape = [
paddle.static.InputSpec(
shape=[None, 3, 64, 256], dtype="float32"),
]
model = to_static(model, input_spec=other_shape)
elif arch_config["algorithm"] in ["LayoutLM", "LayoutLMv2", "LayoutXLM"]:
input_spec = [
paddle.static.InputSpec(
......@@ -233,4 +239,4 @@ def main():
if __name__ == "__main__":
main()
main()
\ No newline at end of file
......@@ -69,6 +69,12 @@ class TextRecognizer(object):
"character_dict_path": args.rec_char_dict_path,
"use_space_char": args.use_space_char
}
elif self.rec_algorithm == "VisionLAN":
postprocess_params = {
'name': 'VLLabelDecode',
"character_dict_path": args.rec_char_dict_path,
"use_space_char": args.use_space_char
}
elif self.rec_algorithm == 'ViTSTR':
postprocess_params = {
'name': 'ViTSTRLabelDecode',
......@@ -164,6 +170,16 @@ class TextRecognizer(object):
padding_im[:, :, 0:resized_w] = resized_image
return padding_im
def resize_norm_img_vl(self, img, image_shape):
imgC, imgH, imgW = image_shape
img = img[:, :, ::-1] # bgr2rgb
resized_image = cv2.resize(
img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
resized_image = resized_image.astype('float32')
resized_image = resized_image.transpose((2, 0, 1)) / 255
return resized_image
def resize_norm_img_srn(self, img, image_shape):
imgC, imgH, imgW = image_shape
......@@ -287,6 +303,7 @@ class TextRecognizer(object):
img -= mean
img *= stdinv
return img
def resize_norm_img_svtr(self, img, image_shape):
imgC, imgH, imgW = image_shape
......@@ -367,6 +384,11 @@ class TextRecognizer(object):
self.rec_image_shape)
norm_img = norm_img[np.newaxis, :]
norm_img_batch.append(norm_img)
elif self.rec_algorithm == "VisionLAN":
norm_img = self.resize_norm_img_vl(img_list[indices[ino]],
self.rec_image_shape)
norm_img = norm_img[np.newaxis, :]
norm_img_batch.append(norm_img)
elif self.rec_algorithm == 'SPIN':
norm_img = self.resize_norm_img_spin(img_list[indices[ino]])
norm_img = norm_img[np.newaxis, :]
......
......@@ -139,7 +139,6 @@ def main():
img_metas = [paddle.to_tensor(valid_ratio),
paddle.to_tensor(word_positons),
]
images = np.expand_dims(batch[0], axis=0)
images = paddle.to_tensor(images)
if config['Architecture']['algorithm'] == "SRN":
......
......@@ -154,13 +154,14 @@ def check_xpu(use_xpu):
except Exception as e:
pass
def to_float32(preds):
if isinstance(preds, dict):
for k in preds:
if isinstance(preds[k], dict) or isinstance(preds[k], list):
preds[k] = to_float32(preds[k])
else:
preds[k] = preds[k].astype(paddle.float32)
preds[k] = paddle.to_tensor(preds[k], dtype='float32')
elif isinstance(preds, list):
for k in range(len(preds)):
if isinstance(preds[k], dict):
......@@ -168,11 +169,12 @@ def to_float32(preds):
elif isinstance(preds[k], list):
preds[k] = to_float32(preds[k])
else:
preds[k] = preds[k].astype(paddle.float32)
preds[k] = paddle.to_tensor(preds[k], dtype='float32')
else:
preds = preds.astype(paddle.float32)
preds = paddle.to_tensor(preds, dtype='float32')
return preds
def train(config,
train_dataloader,
valid_dataloader,
......@@ -225,7 +227,9 @@ def train(config,
model.train()
use_srn = config['Architecture']['algorithm'] == "SRN"
extra_input_models = ["SRN", "NRTR", "SAR", "SEED", "SVTR", "SPIN", "RobustScanner"]
extra_input_models = [
"SRN", "NRTR", "SAR", "SEED", "SVTR", "SPIN", "VisionLAN", "RobustScanner"
]
extra_input = False
if config['Architecture']['algorithm'] == 'Distillation':
for key in config['Architecture']["Models"]:
......@@ -267,7 +271,6 @@ def train(config,
images = batch[0]
if use_srn:
model_average = True
# use amp
if scaler:
with paddle.amp.auto_cast(level='O2'):
......@@ -308,6 +311,9 @@ def train(config,
]: # for multi head loss
post_result = post_process_class(
preds['ctc'], batch[1]) # for CTC head out
elif config['Loss']['name'] in ['VLLoss']:
post_result = post_process_class(preds, batch[1],
batch[-1])
else:
post_result = post_process_class(preds, batch[1])
eval_class(post_result, batch)
......@@ -370,7 +376,8 @@ def train(config,
post_process_class,
eval_class,
model_type,
extra_input=extra_input)
extra_input=extra_input,
scaler=scaler)
cur_metric_str = 'cur metric, {}'.format(', '.join(
['{}: {}'.format(k, v) for k, v in cur_metric.items()]))
logger.info(cur_metric_str)
......@@ -460,7 +467,8 @@ def eval(model,
post_process_class,
eval_class,
model_type=None,
extra_input=False):
extra_input=False,
scaler=None):
model.eval()
with paddle.no_grad():
total_frame = 0.0
......@@ -477,12 +485,24 @@ def eval(model,
break
images = batch[0]
start = time.time()
if model_type == 'table' or extra_input:
preds = model(images, data=batch[1:])
elif model_type in ["kie", 'vqa']:
preds = model(batch)
# use amp
if scaler:
with paddle.amp.auto_cast(level='O2'):
if model_type == 'table' or extra_input:
preds = model(images, data=batch[1:])
elif model_type in ["kie", 'vqa']:
preds = model(batch)
else:
preds = model(images)
else:
preds = model(images)
if model_type == 'table' or extra_input:
preds = model(images, data=batch[1:])
elif model_type in ["kie", 'vqa']:
preds = model(batch)
else:
preds = model(images)
batch_numpy = []
for item in batch:
if isinstance(item, paddle.Tensor):
......@@ -596,7 +616,7 @@ def preprocess(is_train=False):
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'LayoutLMv2', 'PREN', 'FCE',
'SVTR', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'SPIN', 'RobustScanner'
'SVTR', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'SPIN', 'VisionLAN', 'RobustScanner'
]
if use_xpu:
......@@ -615,7 +635,7 @@ def preprocess(is_train=False):
if 'use_visualdl' in config['Global'] and config['Global']['use_visualdl']:
save_model_dir = config['Global']['save_model_dir']
vdl_writer_path = '{}/vdl/'.format(save_model_dir)
log_writer = VDLLogger(save_model_dir)
log_writer = VDLLogger(vdl_writer_path)
loggers.append(log_writer)
if ('use_wandb' in config['Global'] and
config['Global']['use_wandb']) or 'wandb' in config:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册