未验证 提交 dd544b87 编写于 作者: C Chang Xu 提交者: GitHub

Add OCR Demo in ACT (#1430)

上级 9910aa5a
# OCR模型自动压缩示例
目录:
- [1. 简介](#1简介)
- [2. Benchmark](#2Benchmark)
- [3. 自动压缩流程](#自动压缩流程)
- [3.1 准备环境](#31-准备准备)
- [3.2 准备数据集](#32-准备数据集)
- [3.3 准备预测模型](#33-准备预测模型)
- [3.4 自动压缩并产出模型](#34-自动压缩并产出模型)
- [4. 预测部署](#4预测部署)
- [4.1 Python预测推理](#41-Python预测推理)
- [4.2 PaddleLite端侧部署](#42-PaddleLite端侧部署)
- [5. FAQ](5FAQ)
## 1. 简介
本示例将以图像分类模型PPOCRV3为例,介绍如何使用PaddleOCR中Inference部署模型进行自动压缩。本示例使用的自动压缩策略为量化训练和蒸馏。
## 2. Benchmark
| 模型 | 策略 | Metric | GPU 耗时(ms) | ARM CPU 耗时(ms) | 配置文件 | Inference模型 |
|:------:|:------:|:------:|:------:|:------:|:------:|:------:|
| 中文PPOCRV3-det | Baseline | 84.57 | - | - | - | [Model](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_infer.tar) |
| 中文PPOCRV3-det | 量化+蒸馏 | 83.4 | - | - | [Config](./configs/ppocrv3_det_qat_dist.yaml) | [Model](https://bj.bcebos.com/v1/paddle-slim-models/act/PPOCRV3_det_QAT.tar) |
## 3. 自动压缩流程
#### 3.1 准备环境
- python >= 3.6
- PaddlePaddle >= 2.3 (可从[Paddle官网](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html)下载安装)
- PaddleSlim >= 2.3
安装paddlepaddle:
```shell
# CPU
pip install paddlepaddle
# GPU
pip install paddlepaddle-gpu
```
安装paddleslim:
```shell
pip install paddleslim
```
#### 3.2 准备数据集
公开数据集可参考[OCR数据集](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.6/doc/doc_ch/dataset/ocr_datasets.md)
注意:使用不同的数据集需要修改配置文件中`dataset`中数据路径和数据处理部分。
#### 3.3 准备预测模型
预测模型的格式为:`model.pdmodel``model.pdiparams`两个,带`pdmodel`的是模型文件,带`pdiparams`后缀的是权重文件。
注:其他像`__model__``__params__`分别对应`model.pdmodel``model.pdiparams`文件。
可在[PaddleOCR模型库](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.6/doc/doc_ch/models_list.md)中直接获取Inference模型,具体可参考下方获取中文PPOCRV3检测模型示例:
```shell
wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_infer.tar
tar -xf ch_PP-OCRv3_det_infer.tar
```
蒸馏量化自动压缩示例通过run.py脚本启动,会使用接口 ```paddleslim.auto_compression.AutoCompression``` 对模型进行量化训练和蒸馏。配置config文件中模型路径、数据集路径、蒸馏、量化和训练等部分的参数,配置完成后便可开始自动压缩。
**单卡启动**
```shell
export CUDA_VISIBLE_DEVICES=0
python run.py --save_dir='./save_quant_ppocr_det/' --config_path='./configs/ppocrv3_det_qat_dist.yaml'
```
**多卡启动**
若训练任务中包含大量训练数据,如果使用单卡训练,会非常耗时,使用分布式训练可以达到几乎线性的加速比。
```shell
export CUDA_VISIBLE_DEVICES=0,1,2,3
python -m paddle.distributed.launch run.py --save_dir='./save_quant_ppocr_det/' --config_path='./configs/ppocrv3_det_qat_dist.yaml'
```
多卡训练指的是将训练任务按照一定方法拆分到多个训练节点完成数据读取、前向计算、反向梯度计算等过程,并将计算出的梯度上传至服务节点。服务节点在收到所有训练节点传来的梯度后,会将梯度聚合并更新参数。最后将参数发送给训练节点,开始新一轮的训练。多卡训练一轮训练能训练```batch size * num gpus```的数据,比如单卡的```batch size```为32,单轮训练的数据量即32,而四卡训练的```batch size```为32,单轮训练的数据量为128。
注意 ```learning rate``````batch size``` 呈线性关系,这里单卡 ```batch size``` 8,对应的 ```learning rate``` 为0.00005,那么如果 ```batch size``` 增大4倍改为32,```learning rate``` 也需乘以4;多卡时 ```batch size``` 为8,```learning rate``` 需乘上卡数。所以改变 ```batch size``` 或改变训练卡数都需要对应修改 ```learning rate```
**验证精度**
根据训练log可以看到模型验证的精度,若需再次验证精度,修改配置文件```./configs/ppocrv3_det_qat_dist.yaml```中所需验证模型的文件夹路径及模型和参数名称```model_dir, model_filename, params_filename```,然后使用以下命令进行验证:
```shell
export CUDA_VISIBLE_DEVICES=0
python eval.py --config_path='./configs/ppocrv3_det_qat_dist.yaml'
```
## 4.预测部署
#### 4.1 Python预测推理
环境配置:若使用 TesorRT 预测引擎,需安装 ```WITH_TRT=ON``` 的Paddle,下载地址:[Python预测库](https://paddleinference.paddlepaddle.org.cn/master/user_guides/download_lib.html#python)
Python预测引擎推理可参考[基于Python预测引擎推理](https://github.com/PaddlePaddle/PaddleOCR/blob/9cdab61d909eb595af849db885c257ca8c74cb57/doc/doc_ch/inference_ppocr.md)
#### 4.2 PaddleLite端侧部署
PaddleLite端侧部署可参考:
- [Paddle Lite部署](https://github.com/PaddlePaddle/PaddleOCR/tree/9cdab61d909eb595af849db885c257ca8c74cb57/deploy/lite)
## 5.FAQ
Global:
model_type: det
model_dir: ch_PP-OCRv3_det_infer
model_filename: inference.pdmodel
params_filename: inference.pdiparams
algorithm: DB
input_name: 'x'
Distillation:
alpha: 1.0
loss: l2
Quantization:
use_pact: true
activation_bits: 8
is_full_quantize: false
onnx_format: True
activation_quantize_type: moving_average_abs_max
weight_quantize_type: channel_wise_abs_max
not_quant_pattern:
- skip_quant
quantize_op_types:
- conv2d
- depthwise_conv2d
weight_bits: 8
TrainConfig:
epochs: 3
eval_iter: 200
learning_rate:
type: CosineAnnealingDecay
learning_rate: 0.00005
optimizer_builder:
optimizer:
type: Adam
weight_decay: 5.0e-05
PostProcess:
name: DBPostProcess
thresh: 0.3
box_thresh: 0.6
max_candidates: 1000
unclip_ratio: 1.5
Metric:
name: DetMetric
main_indicator: hmean
Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data/icdar2015/text_localization/
label_file_list:
- ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
ratio_list: [1.0]
transforms:
- DecodeImage:
img_mode: BGR
channel_first: false
- DetLabelEncode: null
- IaaAugment:
augmenter_args:
- type: Fliplr
args:
p: 0.5
- type: Affine
args:
rotate:
- -10
- 10
- type: Resize
args:
size:
- 0.5
- 3
- EastRandomCropData:
size:
- 960
- 960
max_tries: 50
keep_ratio: true
- MakeBorderMap:
shrink_ratio: 0.4
thresh_min: 0.3
thresh_max: 0.7
- MakeShrinkMap:
shrink_ratio: 0.4
min_text_size: 8
- NormalizeImage:
scale: 1./255.
mean:
- 0.485
- 0.456
- 0.406
std:
- 0.229
- 0.224
- 0.225
order: hwc
- ToCHWImage: null
- KeepKeys:
keep_keys:
- image
- threshold_map
- threshold_mask
- shrink_map
- shrink_mask
loader:
shuffle: true
drop_last: false
batch_size_per_card: 8
num_workers: 4
Eval:
dataset:
name: SimpleDataSet
data_dir: ./train_data/icdar2015/text_localization/
label_file_list:
- ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
transforms:
- DecodeImage:
img_mode: BGR
channel_first: false
- DetLabelEncode: null
- DetResizeForTest: null
- NormalizeImage:
scale: 1./255.
mean:
- 0.485
- 0.456
- 0.406
std:
- 0.229
- 0.224
- 0.225
order: hwc
- ToCHWImage: null
- KeepKeys:
keep_keys:
- image
- shape
- polys
- ignore_tags
loader:
shuffle: false
drop_last: false
batch_size_per_card: 1
num_workers: 2
\ No newline at end of file
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import logging
import numpy as np
import argparse
from tqdm import tqdm
import paddle
from paddleslim.common import load_config as load_slim_config
from paddleslim.common import get_logger
from paddleslim.auto_compression import AutoCompression
from ppocr.data import build_dataloader
from ppocr.modeling.architectures import build_model
from ppocr.losses import build_loss
from ppocr.optimizer import build_optimizer
from ppocr.postprocess import build_post_process
from ppocr.metrics import build_metric
logger = get_logger(__name__, level=logging.INFO)
def argsparser():
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
'--config_path',
type=str,
default='./image_classification/configs/eval.yaml',
help="path of compression strategy config.")
parser.add_argument(
'--model_dir',
type=str,
default='./ch_PP-OCRv3_det_infer',
help='model directory')
return parser
extra_input_models = [
"SRN", "NRTR", "SAR", "SEED", "SVTR", "VisionLAN", "RobustScanner"
]
def sample_generator(loader):
def __reader__():
for indx, data in enumerate(loader):
images = np.array(data[0])
yield images
return __reader__
def eval():
devices = paddle.device.get_device().split(':')[0]
places = paddle.device._convert_to_place(devices)
exe = paddle.static.Executor(places)
val_program, feed_target_names, fetch_targets = paddle.static.load_inference_model(
global_config["model_dir"],
exe,
model_filename=global_config["model_filename"],
params_filename=global_config["params_filename"])
print('Loaded model from: {}'.format(global_config["model_dir"]))
val_loader = build_dataloader(all_config, 'Eval', devices, logger)
post_process_class = build_post_process(all_config['PostProcess'],
global_config)
eval_class = build_metric(all_config['Metric'])
model_type = global_config['model_type']
extra_input = True if global_config[
'algorithm'] in extra_input_models else False
with tqdm(
total=len(val_loader),
bar_format='Evaluation stage, Run batch:|{bar}| {n_fmt}/{total_fmt}',
ncols=80) as t:
for batch_id, batch in enumerate(val_loader):
images = batch[0]
if extra_input:
preds = exe.run(
val_program,
feed={feed_target_names[0]: images,
'data': batch[1:]},
fetch_list=fetch_targets)
else:
preds = exe.run(val_program,
feed={feed_target_names[0]: images},
fetch_list=fetch_targets)
batch_numpy = []
for item in batch:
batch_numpy.append(np.array(item))
if model_type == 'det':
preds_map = {'maps': preds[0]}
post_result = post_process_class(preds_map, batch_numpy[1])
eval_class(post_result, batch_numpy)
elif model_type == 'rec':
post_result = post_process_class(preds[0], batch_numpy[1])
eval_class(post_result, batch_numpy)
t.update()
metric = eval_class.get_metric()
logger.info('metric eval ***************')
for k, v in metric.items():
logger.info('{}:{}'.format(k, v))
return metric
def main():
global all_config, global_config
all_config = load_slim_config(args.config_path)
global_config = all_config["Global"]
eval()
if __name__ == '__main__':
paddle.enable_static()
parser = argsparser()
args = parser.parse_args()
main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import logging
from tqdm import tqdm
import numpy as np
import argparse
import paddle
from paddleslim.common import load_config as load_slim_config
from paddleslim.common import get_logger
from paddleslim.auto_compression import AutoCompression
from ppocr.data import build_dataloader
from ppocr.modeling.architectures import build_model
from ppocr.losses import build_loss
from ppocr.optimizer import build_optimizer
from ppocr.postprocess import build_post_process
from ppocr.metrics import build_metric
logger = get_logger(__name__, level=logging.INFO)
def argsparser():
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
'--config_path',
type=str,
default=None,
help="path of compression strategy config.",
required=True)
parser.add_argument(
'--save_dir',
type=str,
default='output',
help="directory to save compressed model.")
parser.add_argument(
'--devices',
type=str,
default='gpu',
help="which device used to compress.")
return parser
def reader_wrapper(reader, input_name):
def gen():
for i, batch in enumerate(reader()):
yield {input_name: batch[0]}
return gen
def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
post_process_class = build_post_process(all_config['PostProcess'],
global_config)
eval_class = build_metric(all_config['Metric'])
model_type = global_config['model_type']
with tqdm(
total=len(val_loader),
bar_format='Evaluation stage, Run batch:|{bar}| {n_fmt}/{total_fmt}',
ncols=80) as t:
for batch_id, batch in enumerate(val_loader):
images = batch[0]
preds = exe.run(compiled_test_program,
feed={test_feed_names[0]: images},
fetch_list=test_fetch_list)
batch_numpy = []
for item in batch:
batch_numpy.append(np.array(item))
if model_type == 'det':
preds_map = {'maps': preds[0]}
post_result = post_process_class(preds_map, batch_numpy[1])
eval_class(post_result, batch_numpy)
elif model_type == 'rec':
post_result = post_process_class(preds[0], batch_numpy[1])
eval_class(post_result, batch_numpy)
t.update()
metric = eval_class.get_metric()
logger.info('metric eval ***************')
for k, v in metric.items():
logger.info('{}:{}'.format(k, v))
if model_type == 'det':
return metric['hmean']
elif model_type == 'rec':
return metric['acc']
return metric
def main():
rank_id = paddle.distributed.get_rank()
if args.devices == 'gpu':
place = paddle.CUDAPlace(rank_id)
paddle.set_device('gpu')
else:
place = paddle.CPUPlace()
paddle.set_device('cpu')
global all_config, global_config
all_config = load_slim_config(args.config_path)
assert "Global" in all_config, f"Key 'Global' not found in config file. \n{all_config}"
global_config = all_config["Global"]
gpu_num = paddle.distributed.get_world_size()
train_dataloader = build_dataloader(all_config, 'Train', args.devices,
logger)
global val_loader
val_loader = build_dataloader(all_config, 'Eval', args.devices, logger)
if isinstance(all_config['TrainConfig']['learning_rate'],
dict) and all_config['TrainConfig']['learning_rate'][
'type'] == 'CosineAnnealingDecay':
steps = len(train_dataloader) * all_config['TrainConfig']['epochs']
all_config['TrainConfig']['learning_rate']['T_max'] = steps
print('total training steps:', steps)
ac = AutoCompression(
model_dir=global_config['model_dir'],
model_filename=global_config['model_filename'],
params_filename=global_config['params_filename'],
save_dir=args.save_dir,
config=all_config,
train_dataloader=reader_wrapper(train_dataloader,
global_config['input_name']),
eval_callback=eval_function if rank_id == 0 else None,
eval_dataloader=reader_wrapper(val_loader, global_config['input_name']))
ac.compress()
if __name__ == '__main__':
paddle.enable_static()
parser = argsparser()
args = parser.parse_args()
main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册