未验证 提交 816471aa 编写于 作者: D Double_V 提交者: GitHub

Merge pull request #2027 from LDOUBLEV/trt_cpp

add prune demo
## 介绍
复杂的模型有利于提高模型的性能,但也导致模型中存在一定冗余,模型裁剪通过移出网络模型中的子模型来减少这种冗余,达到减少模型计算复杂度,提高模型推理性能的目的。
本教程将介绍如何使用飞桨模型压缩库PaddleSlim做PaddleOCR模型的压缩。
[PaddleSlim](https://github.com/PaddlePaddle/PaddleSlim)集成了模型剪枝、量化(包括量化训练和离线量化)、蒸馏和神经网络搜索等多种业界常用且领先的模型压缩功能,如果您感兴趣,可以关注并了解。
在开始本教程之前,建议先了解:
1. [PaddleOCR模型的训练方法](../../../doc/doc_ch/quickstart.md)
2. [模型裁剪教程](https://github.com/PaddlePaddle/PaddleSlim/blob/release%2F2.0.0/docs/zh_cn/tutorials/pruning/dygraph/filter_pruning.md)
## 快速开始
模型裁剪主要包括四个步骤:
1. 安装 PaddleSlim
2. 准备训练好的模型
3. 敏感度分析、裁剪训练
4. 导出模型、预测部署
### 1. 安装PaddleSlim
```bash
git clone https://github.com/PaddlePaddle/PaddleSlim.git
git checkout develop
cd Paddleslim
python3 setup.py install
```
### 2. 获取预训练模型
模型裁剪需要加载事先训练好的模型,PaddleOCR也提供了一系列(模型)[../../../doc/doc_ch/models_list.md],开发者可根据需要自行选择模型或使用自己的模型。
### 3. 敏感度分析训练
加载预训练模型后,通过对现有模型的每个网络层进行敏感度分析,得到敏感度文件:sen.pickle,可以通过PaddleSlim提供的[接口](https://github.com/PaddlePaddle/PaddleSlim/blob/9b01b195f0c4bc34a1ab434751cb260e13d64d9e/paddleslim/dygraph/prune/filter_pruner.py#L75)加载文件,获得各网络层在不同裁剪比例下的精度损失。从而了解各网络层冗余度,决定每个网络层的裁剪比例。
敏感度文件内容格式:
sen.pickle(Dict){
'layer_weight_name_0': sens_of_each_ratio(Dict){'pruning_ratio_0': acc_loss, 'pruning_ratio_1': acc_loss}
'layer_weight_name_1': sens_of_each_ratio(Dict){'pruning_ratio_0': acc_loss, 'pruning_ratio_1': acc_loss}
}
例子:
{
'conv10_expand_weights': {0.1: 0.006509952684312718, 0.2: 0.01827734339798862, 0.3: 0.014528405644659832, 0.6: 0.06536008804270439, 0.8: 0.11798612250664964, 0.7: 0.12391408417493704, 0.4: 0.030615754498018757, 0.5: 0.047105205602406594}
'conv10_linear_weights': {0.1: 0.05113190831455035, 0.2: 0.07705573833558801, 0.3: 0.12096721757739311, 0.6: 0.5135061352930738, 0.8: 0.7908166677143281, 0.7: 0.7272187676899062, 0.4: 0.1819252083008504, 0.5: 0.3728054727792405}
}
加载敏感度文件后会返回一个字典,字典中的keys为网络模型参数模型的名字,values为一个字典,里面保存了相应网络层的裁剪敏感度信息。例如在例子中,conv10_expand_weights所对应的网络层在裁掉10%的卷积核后模型性能相较原模型会下降0.65%,详细信息可见[PaddleSlim](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/docs/zh_cn/algo/algo.md#2-%E5%8D%B7%E7%A7%AF%E6%A0%B8%E5%89%AA%E8%A3%81%E5%8E%9F%E7%90%86)
进入PaddleOCR根目录,通过以下命令对模型进行敏感度分析训练:
```bash
python3.7 deploy/slim/prune/sensitivity_anal.py -c configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml -o Global.pretrain_weights="your trained model"
```
### 4. 导出模型、预测部署
在得到裁剪训练保存的模型后,我们可以将其导出为inference_model:
```bash
pytho3.7 deploy/slim/prune/export_prune_model.py -c configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml -o Global.pretrain_weights=./output/det_db/best_accuracy Global.save_inference_dir=inference_model
```
inference model的预测和部署参考:
1. [inference model python端预测](../../../doc/doc_ch/inference.md)
2. [inference model C++预测](../../cpp_infer/readme.md)
## Introduction
Generally, a more complex model would achive better performance in the task, but it also leads to some redundancy in the model. Model Pruning is a technique that reduces this redundancy by removing the sub-models in the neural network model, so as to reduce model calculation complexity and improve model inference performance.
This example uses PaddleSlim provided[APIs of Pruning](https://paddlepaddle.github.io/PaddleSlim/api/prune_api/) to compress the OCR model.
[PaddleSlim](https://github.com/PaddlePaddle/PaddleSlim), an open source library which integrates model pruning, quantization (including quantization training and offline quantization), distillation, neural network architecture search, and many other commonly used and leading model compression technique in the industry.
It is recommended that you could understand following pages before reading this example:
1. [PaddleOCR training methods](../../../doc/doc_ch/quickstart.md)
2. [The demo of prune](https://github.com/PaddlePaddle/PaddleSlim/blob/release%2F2.0.0/docs/zh_cn/tutorials/pruning/dygraph/filter_pruning.md)
## Quick start
Five steps for OCR model prune:
1. Install PaddleSlim
2. Prepare the trained model
3. Sensitivity analysis and tailoring training
4. Export model, predict deployment
### 1. Install PaddleSlim
```bash
git clone https://github.com/PaddlePaddle/PaddleSlim.git
git checkout develop
cd Paddleslim
python3 setup.py install
```
### 2. Download Pretrain Model
Model prune needs to load pre-trained models.
PaddleOCR also provides a series of (models)[../../../doc/doc_en/models_list_en.md]. Developers can choose their own models or use their own models according to their needs.
### 3. Pruning sensitivity analysis
After the pre-training model is loaded, sensitivity analysis is performed on each network layer of the model to understand the redundancy of each network layer, and save a sensitivity file which named: sen.pickle. After that, user could load the sensitivity file via the [methods provided by PaddleSlim](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/prune/sensitive.py#L221) and determining the pruning ratio of each network layer automatically. For specific details of sensitivity analysis, see:[Sensitivity analysis](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/docs/zh_cn/tutorials/image_classification_sensitivity_analysis_tutorial.md)
The data format of sensitivity file:
sen.pickle(Dict){
'layer_weight_name_0': sens_of_each_ratio(Dict){'pruning_ratio_0': acc_loss, 'pruning_ratio_1': acc_loss}
'layer_weight_name_1': sens_of_each_ratio(Dict){'pruning_ratio_0': acc_loss, 'pruning_ratio_1': acc_loss}
}
example:
{
'conv10_expand_weights': {0.1: 0.006509952684312718, 0.2: 0.01827734339798862, 0.3: 0.014528405644659832, 0.6: 0.06536008804270439, 0.8: 0.11798612250664964, 0.7: 0.12391408417493704, 0.4: 0.030615754498018757, 0.5: 0.047105205602406594}
'conv10_linear_weights': {0.1: 0.05113190831455035, 0.2: 0.07705573833558801, 0.3: 0.12096721757739311, 0.6: 0.5135061352930738, 0.8: 0.7908166677143281, 0.7: 0.7272187676899062, 0.4: 0.1819252083008504, 0.5: 0.3728054727792405}
}
The function would return a dict after loading the sensitivity file. The keys of the dict are name of parameters in each layer. And the value of key is the information about pruning sensitivity of correspoding layer. In example, pruning 10% filter of the layer corresponding to conv10_expand_weights would lead to 0.65% degradation of model performance. The details could be seen at: [Sensitivity analysis](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/docs/zh_cn/algo/algo.md#2-%E5%8D%B7%E7%A7%AF%E6%A0%B8%E5%89%AA%E8%A3%81%E5%8E%9F%E7%90%86)
Enter the PaddleOCR root directory,perform sensitivity analysis on the model with the following command:
```bash
python3.7 deploy/slim/prune/sensitivity_anal.py -c configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml -o Global.pretrain_weights="your trained model"
```
### 5. Export inference model and deploy it
We can export the pruned model as inference_model for deployment:
```bash
python deploy/slim/prune/export_prune_model.py -c configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml -o Global.pretrain_weights=./output/det_db/best_accuracy Global.test_batch_size_per_card=1 Global.save_inference_dir=inference_model
```
Reference for prediction and deployment of inference model:
1. [inference model python prediction](../../../doc/doc_en/inference_en.md)
2. [inference model C++ prediction](../../cpp_infer/readme_en.md)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
__dir__ = os.path.dirname(__file__)
sys.path.append(__dir__)
sys.path.append(os.path.join(__dir__, '..', '..', '..'))
sys.path.append(os.path.join(__dir__, '..', '..', '..', 'tools'))
import paddle
from ppocr.data import build_dataloader
from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process
from ppocr.metrics import build_metric
from ppocr.utils.save_load import init_model
import tools.program as program
def main(config, device, logger, vdl_writer):
global_config = config['Global']
# build dataloader
valid_dataloader = build_dataloader(config, 'Eval', device, logger)
# build post process
post_process_class = build_post_process(config['PostProcess'],
global_config)
# build model
# for rec algorithm
if hasattr(post_process_class, 'character'):
char_num = len(getattr(post_process_class, 'character'))
config['Architecture']["Head"]['out_channels'] = char_num
model = build_model(config['Architecture'])
flops = paddle.flops(model, [1, 3, 640, 640])
logger.info(f"FLOPs before pruning: {flops}")
from paddleslim.dygraph import FPGMFilterPruner
model.train()
pruner = FPGMFilterPruner(model, [1, 3, 640, 640])
# build metric
eval_class = build_metric(config['Metric'])
def eval_fn():
metric = program.eval(model, valid_dataloader, post_process_class,
eval_class)
logger.info(f"metric['hmean']: {metric['hmean']}")
return metric['hmean']
params_sensitive = pruner.sensitive(
eval_func=eval_fn,
sen_file="./sen.pickle",
skip_vars=[
"conv2d_57.w_0", "conv2d_transpose_2.w_0", "conv2d_transpose_3.w_0"
])
logger.info(
"The sensitivity analysis results of model parameters saved in sen.pickle"
)
# calculate pruned params's ratio
params_sensitive = pruner._get_ratios_by_loss(params_sensitive, loss=0.02)
for key in params_sensitive.keys():
logger.info(f"{key}, {params_sensitive[key]}")
plan = pruner.prune_vars(params_sensitive, [0])
flops = paddle.flops(model, [1, 3, 640, 640])
logger.info(f"FLOPs after pruning: {flops}")
# load pretrain model
pre_best_model_dict = init_model(config, model, logger, None)
metric = program.eval(model, valid_dataloader, post_process_class,
eval_class)
logger.info(f"metric['hmean']: {metric['hmean']}")
# start export model
from paddle.jit import to_static
infer_shape = [3, -1, -1]
if config['Architecture']['model_type'] == "rec":
infer_shape = [3, 32, -1] # for rec model, H must be 32
if 'Transform' in config['Architecture'] and config['Architecture'][
'Transform'] is not None and config['Architecture'][
'Transform']['name'] == 'TPS':
logger.info(
'When there is tps in the network, variable length input is not supported, and the input size needs to be the same as during training'
)
infer_shape[-1] = 100
model = to_static(
model,
input_spec=[
paddle.static.InputSpec(
shape=[None] + infer_shape, dtype='float32')
])
save_path = '{}/inference'.format(config['Global']['save_inference_dir'])
paddle.jit.save(model, save_path)
logger.info('inference model is saved to {}'.format(save_path))
if __name__ == '__main__':
config, device, logger, vdl_writer = program.preprocess(is_train=True)
main(config, device, logger, vdl_writer)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
__dir__ = os.path.dirname(__file__)
sys.path.append(__dir__)
sys.path.append(os.path.join(__dir__, '..', '..', '..'))
sys.path.append(os.path.join(__dir__, '..', '..', '..', 'tools'))
import paddle
import paddle.distributed as dist
from ppocr.data import build_dataloader
from ppocr.modeling.architectures import build_model
from ppocr.losses import build_loss
from ppocr.optimizer import build_optimizer
from ppocr.postprocess import build_post_process
from ppocr.metrics import build_metric
from ppocr.utils.save_load import init_model
import tools.program as program
dist.get_world_size()
def get_pruned_params(parameters):
params = []
for param in parameters:
if len(
param.shape
) == 4 and 'depthwise' not in param.name and 'transpose' not in param.name and "conv2d_57" not in param.name and "conv2d_56" not in param.name:
params.append(param.name)
return params
def main(config, device, logger, vdl_writer):
# init dist environment
if config['Global']['distributed']:
dist.init_parallel_env()
global_config = config['Global']
# build dataloader
train_dataloader = build_dataloader(config, 'Train', device, logger)
if config['Eval']:
valid_dataloader = build_dataloader(config, 'Eval', device, logger)
else:
valid_dataloader = None
# build post process
post_process_class = build_post_process(config['PostProcess'],
global_config)
# build model
# for rec algorithm
if hasattr(post_process_class, 'character'):
char_num = len(getattr(post_process_class, 'character'))
config['Architecture']["Head"]['out_channels'] = char_num
model = build_model(config['Architecture'])
flops = paddle.flops(model, [1, 3, 640, 640])
logger.info(f"FLOPs before pruning: {flops}")
from paddleslim.dygraph import FPGMFilterPruner
model.train()
pruner = FPGMFilterPruner(model, [1, 3, 640, 640])
# build loss
loss_class = build_loss(config['Loss'])
# build optim
optimizer, lr_scheduler = build_optimizer(
config['Optimizer'],
epochs=config['Global']['epoch_num'],
step_each_epoch=len(train_dataloader),
parameters=model.parameters())
# build metric
eval_class = build_metric(config['Metric'])
# load pretrain model
pre_best_model_dict = init_model(config, model, logger, optimizer)
logger.info('train dataloader has {} iters, valid dataloader has {} iters'.
format(len(train_dataloader), len(valid_dataloader)))
# build metric
eval_class = build_metric(config['Metric'])
logger.info('train dataloader has {} iters, valid dataloader has {} iters'.
format(len(train_dataloader), len(valid_dataloader)))
def eval_fn():
metric = program.eval(model, valid_dataloader, post_process_class,
eval_class)
logger.info(f"metric['hmean']: {metric['hmean']}")
return metric['hmean']
params_sensitive = pruner.sensitive(
eval_func=eval_fn,
sen_file="./sen.pickle",
skip_vars=[
"conv2d_57.w_0", "conv2d_transpose_2.w_0", "conv2d_transpose_3.w_0"
])
logger.info(
"The sensitivity analysis results of model parameters saved in sen.pickle"
)
# calculate pruned params's ratio
params_sensitive = pruner._get_ratios_by_loss(params_sensitive, loss=0.02)
for key in params_sensitive.keys():
logger.info(f"{key}, {params_sensitive[key]}")
plan = pruner.prune_vars(params_sensitive, [0])
for param in model.parameters():
if ("weights" in param.name and "conv" in param.name) or (
"w_0" in param.name and "conv2d" in param.name):
logger.info(f"{param.name}: {param.shape}")
flops = paddle.flops(model, [1, 3, 640, 640])
logger.info(f"FLOPs after pruning: {flops}")
# start train
program.train(config, train_dataloader, valid_dataloader, device, model,
loss_class, optimizer, lr_scheduler, post_process_class,
eval_class, pre_best_model_dict, logger, vdl_writer)
if __name__ == '__main__':
config, device, logger, vdl_writer = program.preprocess(is_train=True)
main(config, device, logger, vdl_writer)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册