README.md 5.2 KB
Newer Older
Y
yukavio 已提交
1

L
LDOUBLEV 已提交
2
## 介绍
Y
yukavio 已提交
3

L
LDOUBLEV 已提交
4
复杂的模型有利于提高模型的性能,但也导致模型中存在一定冗余,模型裁剪通过移出网络模型中的子模型来减少这种冗余,达到减少模型计算复杂度,提高模型推理性能的目的。
Y
yukavio 已提交
5

Y
yukavio 已提交
6 7
本教程将介绍如何使用飞桨模型压缩库PaddleSlim做PaddleOCR模型的压缩。
PaddleSlim(项目链接:https://github.com/PaddlePaddle/PaddleSlim)集成了模型剪枝、量化(包括量化训练和离线量化)、蒸馏和神经网络搜索等多种业界常用且领先的模型压缩功能,如果您感兴趣,可以关注并了解。
Y
yukavio 已提交
8

L
LDOUBLEV 已提交
9 10 11 12
在开始本教程之前,建议先了解
1. [PaddleOCR模型的训练方法](../../../doc/doc_ch/quickstart.md)
2. [分类模型裁剪教程](https://paddlepaddle.github.io/PaddleSlim/tutorials/pruning_tutorial/)
3. [PaddleSlim 裁剪压缩API](https://paddlepaddle.github.io/PaddleSlim/api/prune_api/)
Y
yukavio 已提交
13 14


L
LDOUBLEV 已提交
15
## 快速开始
Y
yukavio 已提交
16

L
LDOUBLEV 已提交
17 18 19 20 21 22
模型裁剪主要包括五个步骤:
1. 安装 PaddleSlim
2. 准备训练好的模型
3. 敏感度分析、训练
4. 模型裁剪训练
5. 导出模型、预测部署
Y
yukavio 已提交
23

L
LDOUBLEV 已提交
24
### 1. 安装PaddleSlim
Y
yukavio 已提交
25

Y
yukavio 已提交
26
```bash
Y
yukavio 已提交
27 28 29
git clone https://github.com/PaddlePaddle/PaddleSlim.git
cd Paddleslim
python setup.py install
Y
yukavio 已提交
30
```
Y
yukavio 已提交
31

L
LDOUBLEV 已提交
32 33
### 2. 获取预训练模型
模型裁剪需要加载事先训练好的模型,PaddleOCR也提供了一系列模型[../../../doc/doc_ch/models_list.md],开发者可根据需要自行选择模型或使用自己的模型。
Y
yukavio 已提交
34

L
LDOUBLEV 已提交
35
### 3. 敏感度分析训练
Y
yukavio 已提交
36

Y
yukavio 已提交
37
加载预训练模型后,通过对现有模型的每个网络层进行敏感度分析,得到敏感度文件:sensitivities_0.data,可以通过PaddleSlim提供的[接口](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/prune/sensitive.py#L221)加载文件,获得各网络层在不同裁剪比例下的精度损失。从而了解各网络层冗余度,决定每个网络层的裁剪比例。
L
LDOUBLEV 已提交
38
敏感度分析的具体细节见:[敏感度分析](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/docs/zh_cn/tutorials/image_classification_sensitivity_analysis_tutorial.md)
39 40 41 42 43 44 45 46 47 48 49
敏感度文件内容格式:
    sensitivities_0.data(Dict){
            'layer_weight_name_0': sens_of_each_ratio(Dict){'pruning_ratio_0': acc_loss, 'pruning_ratio_1': acc_loss}
            'layer_weight_name_1': sens_of_each_ratio(Dict){'pruning_ratio_0': acc_loss, 'pruning_ratio_1': acc_loss}
        }

    例子:
        {
            'conv10_expand_weights': {0.1: 0.006509952684312718, 0.2: 0.01827734339798862, 0.3: 0.014528405644659832, 0.6: 0.06536008804270439, 0.8: 0.11798612250664964, 0.7: 0.12391408417493704, 0.4: 0.030615754498018757, 0.5: 0.047105205602406594}
            'conv10_linear_weights': {0.1: 0.05113190831455035, 0.2: 0.07705573833558801, 0.3: 0.12096721757739311, 0.6: 0.5135061352930738, 0.8: 0.7908166677143281, 0.7: 0.7272187676899062, 0.4: 0.1819252083008504, 0.5: 0.3728054727792405}
        }
Y
yukavio 已提交
50
加载敏感度文件后会返回一个字典,字典中的keys为网络模型参数模型的名字,values为一个字典,里面保存了相应网络层的裁剪敏感度信息。例如在例子中,conv10_expand_weights所对应的网络层在裁掉10%的卷积核后模型性能相较原模型会下降0.65%,详细信息可见[PaddleSlim](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/docs/zh_cn/algo/algo.md#2-%E5%8D%B7%E7%A7%AF%E6%A0%B8%E5%89%AA%E8%A3%81%E5%8E%9F%E7%90%86)
Y
yukavio 已提交
51

L
LDOUBLEV 已提交
52
进入PaddleOCR根目录,通过以下命令对模型进行敏感度分析训练:
Y
yukavio 已提交
53
```bash
L
LDOUBLEV 已提交
54
python deploy/slim/prune/sensitivity_anal.py -c configs/det/det_mv3_db.yml -o Global.pretrain_weights="your trained model" Global.test_batch_size_per_card=1
Y
yukavio 已提交
55
```
Y
yukavio 已提交
56

L
LDOUBLEV 已提交
57
### 4. 模型裁剪训练
58
裁剪时通过之前的敏感度分析文件决定每个网络层的裁剪比例。在具体实现时,为了尽可能多的保留从图像中提取的低阶特征,我们跳过了backbone中靠近输入的4个卷积层。同样,为了减少由于裁剪导致的模型性能损失,我们通过之前敏感度分析所获得的敏感度表,人工挑选出了一些冗余较少,对裁剪较为敏感的[网络层](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/deploy/slim/prune/pruning_and_finetune.py#L41)(指在较低的裁剪比例下就导致很高性能损失的网络层),并在之后的裁剪过程中选择避开这些网络层。裁剪过后finetune的过程沿用OCR检测模型原始的训练策略。
Y
yukavio 已提交
59

Y
yukavio 已提交
60
```bash
Y
yukavio 已提交
61
python deploy/slim/prune/pruning_and_finetune.py -c configs/det/det_mv3_db.yml -o Global.pretrain_weights=./deploy/slim/prune/pretrain_models/det_mv3_db/best_accuracy Global.test_batch_size_per_card=1
Y
yukavio 已提交
62
```
L
LDOUBLEV 已提交
63
通过对比可以发现,经过裁剪训练保存的模型更小。
Y
yukavio 已提交
64

L
LDOUBLEV 已提交
65
### 5. 导出模型、预测部署
Y
yukavio 已提交
66

L
LDOUBLEV 已提交
67
在得到裁剪训练保存的模型后,我们可以将其导出为inference_model:
Y
yukavio 已提交
68
```bash
Y
yukavio 已提交
69
python deploy/slim/prune/export_prune_model.py -c configs/det/det_mv3_db.yml -o Global.pretrain_weights=./output/det_db/best_accuracy Global.test_batch_size_per_card=1 Global.save_inference_dir=inference_model
Y
yukavio 已提交
70
```
L
LDOUBLEV 已提交
71 72 73 74 75

inference model的预测和部署参考:
1. [inference model python端预测](../../../doc/doc_ch/inference.md)
2. [inference model C++预测](../../cpp_infer/readme.md)
3. [inference model在移动端部署](../../lite/readme.md)