提交 aac628b7 编写于 作者: L LDOUBLEV

v2 to v3

上级 c8aa9346
......@@ -22,9 +22,7 @@
### 1. 安装PaddleSlim
```bash
git clone https://github.com/PaddlePaddle/PaddleSlim.git
cd PaddleSlim
python setup.py install
pip3 install paddleslim==2.2.2
```
### 2. 准备训练好的模型
......@@ -43,7 +41,15 @@ python deploy/slim/quantization/quant.py -c configs/det/ch_ppocr_v2.0/ch_det_mv3
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_train.tar
tar -xf ch_ppocr_mobile_v2.0_det_train.tar
python deploy/slim/quantization/quant.py -c configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml -o Global.pretrained_model=./ch_ppocr_mobile_v2.0_det_train/best_accuracy Global.save_model_dir=./output/quant_model
```
模型蒸馏和模型量化可以同时使用,以PPOCRv3检测模型为例:
```
# 下载检测预训练模型:
wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_distill_train.tar
tar xf https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_distill_train.tar
python deploy/slim/quantization/quant.py -c configs/det/ch_PP-OCRv3_det/ch_PP-OCRv3_det_cml.yml -o Global.pretrained_model='./ch_PP-OCRv3_det_distill_train/best_accuracy' Global.save_model_dir=./output/quant_model_distill/
```
如果要训练识别模型的量化,修改配置文件和加载的模型参数即可。
......
......@@ -25,9 +25,7 @@ After training, if you want to further compress the model size and accelerate th
### 1. Install PaddleSlim
```bash
git clone https://github.com/PaddlePaddle/PaddleSlim.git
cd PaddlSlim
python setup.py install
pip3 install paddleslim==2.2.2
```
......@@ -52,6 +50,17 @@ python deploy/slim/quantization/quant.py -c configs/det/ch_ppocr_v2.0/ch_det_mv3
```
Model distillation and model quantization can be used at the same time, taking the PPOCRv3 detection model as an example:
```
# download provided model
wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_distill_train.tar
tar xf https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_distill_train.tar
python deploy/slim/quantization/quant.py -c configs/det/ch_PP-OCRv3_det/ch_PP-OCRv3_det_cml.yml -o Global.pretrained_model='./ch_PP-OCRv3_det_distill_train/best_accuracy' Global.save_model_dir=./output/quant_model_distill/
```
If you want to quantify the text recognition model, you can modify the configuration file and loaded model parameters.
### 4. Export inference model
Once we got the model after pruning and fine-tuning, we can export it as an inference model for the deployment of predictive tasks:
......
......@@ -305,10 +305,9 @@ paddle.save(s_params, "ch_PP-OCRv2_rec_train/student.pdparams")
<a name="22"></a>
### 2.2 检测配置文件解析
检测模型蒸馏的配置文件在PaddleOCR/configs/det/ch_PP-OCRv2/目录下,包含三个蒸馏配置文件:
- ch_PP-OCRv2_det_cml.yml,采用cml蒸馏,采用一个大模型蒸馏两个小模型,且两个小模型互相学习的方法
- ch_PP-OCRv2_det_dml.yml,采用DML的蒸馏,两个Student模型互蒸馏的方法
- ch_PP-OCRv2_det_distill.yml,采用Teacher大模型蒸馏小模型Student的方法
检测模型蒸馏的配置文件在PaddleOCR/configs/det/ch_PP-OCRv3/目录下,包含两个个蒸馏配置文件:
- ch_PP-OCRv3_det_cml.yml,采用cml蒸馏,采用一个大模型蒸馏两个小模型,且两个小模型互相学习的方法
- ch_PP-OCRv3_det_dml.yml,采用DML的蒸馏,两个Student模型互蒸馏的方法
<a name="221"></a>
#### 2.2.1 模型结构
......@@ -321,44 +320,44 @@ Architecture:
algorithm: Distillation # 算法名称
Models: # 模型,包含子网络的配置信息
Student: # 子网络名称,至少需要包含`pretrained`与`freeze_params`信息,其他的参数为子网络的构造参数
pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
freeze_params: false # 是否需要固定参数
return_all_feats: false # 子网络的参数,表示是否需要返回所有的features,如果为False,则只返回最后的输出
model_type: det
algorithm: DB
Backbone:
name: MobileNetV3
scale: 0.5
model_name: large
disable_se: True
name: ResNet
in_channels: 3
layers: 50
Neck:
name: DBFPN
out_channels: 96
name: LKPAN
out_channels: 256
Head:
name: DBHead
kernel_list: [7,2,2]
k: 50
Teacher: # 另外一个子网络,这里给的是普通大模型蒸小模型的蒸馏示例,
pretrained: ./pretrain_models/ch_ppocr_server_v2.0_det_train/best_accuracy
freeze_params: true # Teacher模型是训练好的,不需要参与训练,freeze_params设置为True
Teacher: # 另外一个子网络,这里给的是DML蒸馏示例,
freeze_params: true
return_all_feats: false
model_type: det
algorithm: DB
Transform:
Backbone:
name: ResNet
layers: 18
in_channels: 3
layers: 50
Neck:
name: DBFPN
name: LKPAN
out_channels: 256
Head:
name: DBHead
kernel_list: [7,2,2]
k: 50
```
如果是采用DML,即两个小模型互相学习的方法,上述配置文件里的Teacher网络结构需要设置为Student模型一样的配置,具体参考配置文件[ch_PP-OCRv2_det_dml.yml](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.4/configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_dml.yml)
如果是采用DML,即两个小模型互相学习的方法,上述配置文件里的Teacher网络结构需要设置为Student模型一样的配置,具体参考配置文件[ch_PP-OCRv3_det_dml.yml](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.4/configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_dml.yml)

下面介绍[ch_PP-OCRv2_det_cml.yml](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.4/configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml)的配置文件参数:
下面介绍[ch_PP-OCRv3_det_cml.yml](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.4/configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml)的配置文件参数:
```
Architecture:
......@@ -375,12 +374,14 @@ Architecture:
Transform:
Backbone:
name: ResNet
layers: 18
in_channels: 3
layers: 50
Neck:
name: DBFPN
name: LKPAN
out_channels: 256
Head:
name: DBHead
kernel_list: [7,2,2]
k: 50
Student: # CML蒸馏的Student模型配置
pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
......@@ -392,10 +393,11 @@ Architecture:
name: MobileNetV3
scale: 0.5
model_name: large
disable_se: True
disable_se: true
Neck:
name: DBFPN
name: RSEFPN
out_channels: 96
shortcut: True
Head:
name: DBHead
k: 50
......@@ -410,10 +412,11 @@ Architecture:
name: MobileNetV3
scale: 0.5
model_name: large
disable_se: True
disable_se: true
Neck:
name: DBFPN
name: RSEFPN
out_channels: 96
shortcut: True
Head:
name: DBHead
k: 50
......@@ -445,34 +448,7 @@ Architecture:
<a name="222"></a>
#### 2.2.2 损失函数
知识蒸馏任务中,检测ch_PP-OCRv2_det_distill.yml蒸馏损失函数配置如下所示。
```yaml
Loss:
name: CombinedLoss # 损失函数名称,基于改名称,构建用于损失函数的类
loss_config_list: # 损失函数配置文件列表,为CombinedLoss的必备函数
- DistillationDilaDBLoss: # 基于蒸馏的DB损失函数,继承自标准的DBloss
weight: 1.0 # 损失函数的权重,loss_config_list中,每个损失函数的配置都必须包含该字段
model_name_pairs: # 对于蒸馏模型的预测结果,提取这两个子网络的输出,计算Teacher模型和Student模型输出的loss
- ["Student", "Teacher"]
key: maps # 取子网络输出dict中,该key对应的tensor
balance_loss: true # 以下几个参数为标准DBloss的配置参数
main_loss_type: DiceLoss
alpha: 5
beta: 10
ohem_ratio: 3
- DistillationDBLoss: # 基于蒸馏的DB损失函数,继承自标准的DBloss,用于计算Student和GT之间的loss
weight: 1.0
model_name_list: ["Student"] # 模型名字只有Student,表示计算Student和GT之间的loss
name: DBLoss
balance_loss: true
main_loss_type: DiceLoss
alpha: 5
beta: 10
ohem_ratio: 3
```
同理,检测ch_PP-OCRv2_det_cml.yml蒸馏损失函数配置如下所示。相比较于ch_PP-OCRv2_det_distill.yml的损失函数配置,cml蒸馏的损失函数配置做了3个改动:
检测ch_PP-OCRv3_det_cml.yml蒸馏损失函数配置如下所示。相比较于ch_PP-OCRv3_det_distill.yml的损失函数配置,cml蒸馏的损失函数配置做了3个改动:
```yaml
Loss:
name: CombinedLoss
......@@ -545,26 +521,25 @@ Metric:
<a name="225"></a>
#### 2.2.5 检测蒸馏模型finetune
检测蒸馏有三种方式:
- 采用ch_PP-OCRv2_det_distill.yml,Teacher模型设置为PaddleOCR提供的模型或者您训练好的大模型
- 采用ch_PP-OCRv2_det_cml.yml,采用cml蒸馏,同样Teacher模型设置为PaddleOCR提供的模型或者您训练好的大模型
- 采用ch_PP-OCRv2_det_dml.yml,采用DML的蒸馏,两个Student模型互蒸馏的方法,在PaddleOCR采用的数据集上大约有1.7%的精度提升。
PP-OCRv3检测蒸馏有两种方式:
- 采用ch_PP-OCRv3_det_cml.yml,采用cml蒸馏,同样Teacher模型设置为PaddleOCR提供的模型或者您训练好的大模型
- 采用ch_PP-OCRv3_det_dml.yml,采用DML的蒸馏,两个Student模型互蒸馏的方法,在PaddleOCR采用的数据集上相比单独训练Student模型有1%-2%的提升。
在具体fine-tune时,需要在网络结构的`pretrained`参数中设置要加载的预训练模型。
在精度提升方面,cml的精度>dml的精度>distill蒸馏方法的精度。当数据量不足或者Teacher模型精度与Student精度相差不大的时候,这个结论或许会改变。
在精度提升方面,cml的精度>dml的精度蒸馏方法的精度。当数据量不足或者Teacher模型精度与Student精度相差不大的时候,这个结论或许会改变。
另外,由于PaddleOCR提供的蒸馏预训练模型包含了多个模型的参数,如果您希望提取Student模型的参数,可以参考如下代码:
```
# 下载蒸馏训练模型的参数
wget https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_distill_train.tar
wget https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv3_det_distill_train.tar
```
```python
import paddle
# 加载预训练模型
all_params = paddle.load("ch_PP-OCRv2_det_distill_train/best_accuracy.pdparams")
all_params = paddle.load("ch_PP-OCRv3_det_distill_train/best_accuracy.pdparams")
# 查看权重参数的keys
print(all_params.keys())
# 学生模型的权重提取
......@@ -572,7 +547,7 @@ s_params = {key[len("Student."):]: all_params[key] for key in all_params if "Stu
# 查看学生模型权重参数的keys
print(s_params.keys())
# 保存
paddle.save(s_params, "ch_PP-OCRv2_det_distill_train/student.pdparams")
paddle.save(s_params, "ch_PP-OCRv3_det_distill_train/student.pdparams")
```
最终`Student`模型的参数将会保存在`ch_PP-OCRv2_det_distill_train/student.pdparams`中,用于模型的fine-tune。
最终`Student`模型的参数将会保存在`ch_PP-OCRv3_det_distill_train/student.pdparams`中,用于模型的fine-tune。
......@@ -319,11 +319,10 @@ After the extraction is complete, use [ch_PP-OCRv2_rec.yml](../../configs/rec/ch
<a name="22"></a>
### 2.2 Detection Model Configuration File Analysis
The configuration file of the detection model distillation is in the ```PaddleOCR/configs/det/ch_PP-OCRv2/``` directory, which contains three distillation configuration files:
The configuration file of the detection model distillation is in the ```PaddleOCR/configs/det/ch_PP-OCRv3/``` directory, which contains three distillation configuration files:
- ```ch_PP-OCRv2_det_cml.yml```, Use one large model to distill two small models, and the two small models learn from each other
- ```ch_PP-OCRv2_det_dml.yml```, Method of mutual distillation of two student models
- ```ch_PP-OCRv2_det_distill.yml```, The method of using large teacher model to distill small student model
- ```ch_PP-OCRv3_det_cml.yml```, Use one large model to distill two small models, and the two small models learn from each other
- ```ch_PP-OCRv3_det_dml.yml```, Method of mutual distillation of two student models
<a name="221"></a>
#### 2.2.1 Model Structure
......@@ -341,39 +340,40 @@ Architecture:
model_type: det
algorithm: DB
Backbone:
name: MobileNetV3
scale: 0.5
model_name: large
disable_se: True
name: ResNet
in_channels: 3
layers: 50
Neck:
name: DBFPN
out_channels: 96
name: LKPAN
out_channels: 256
Head:
name: DBHead
kernel_list: [7,2,2]
k: 50
Teacher: # Another sub-network, here is a distillation example of a large model distill a small model
pretrained: ./pretrain_models/ch_ppocr_server_v2.0_det_train/best_accuracy
freeze_params: true # The Teacher model is well-trained and does not need to participate in training
return_all_feats: false
model_type: det
algorithm: DB
Transform:
Backbone:
name: ResNet
layers: 18
in_channels: 3
layers: 50
Neck:
name: DBFPN
name: LKPAN
out_channels: 256
Head:
name: DBHead
kernel_list: [7,2,2]
k: 50
```
If DML is used, that is, the method of two small models learning from each other, the Teacher network structure in the above configuration file needs to be set to the same configuration as the Student model.
Refer to the configuration file for details. [ch_PP-OCRv2_det_dml.yml](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.4/configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_dml.yml)
Refer to the configuration file for details. [ch_PP-OCRv3_det_dml.yml](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.4/configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_dml.yml)
The following describes the configuration file parameters [ch_PP-OCRv2_det_cml.yml](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.4/configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml):
The following describes the configuration file parameters [ch_PP-OCRv3_det_cml.yml](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.4/configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml):
```
Architecture:
......@@ -390,12 +390,14 @@ Architecture:
Transform:
Backbone:
name: ResNet
layers: 18
in_channels: 3
layers: 50
Neck:
name: DBFPN
name: LKPAN
out_channels: 256
Head:
name: DBHead
kernel_list: [7,2,2]
k: 50
Student: # Student model configuration for CML distillation
pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
......@@ -407,10 +409,11 @@ Architecture:
name: MobileNetV3
scale: 0.5
model_name: large
disable_se: True
disable_se: true
Neck:
name: DBFPN
name: RSEFPN
out_channels: 96
shortcut: True
Head:
name: DBHead
k: 50
......@@ -425,10 +428,11 @@ Architecture:
name: MobileNetV3
scale: 0.5
model_name: large
disable_se: True
disable_se: true
Neck:
name: DBFPN
name: RSEFPN
out_channels: 96
shortcut: True
Head:
name: DBHead
k: 50
......@@ -460,34 +464,7 @@ The key contains `backbone_out`, `neck_out`, `head_out`, and `value` is the tens
<a name="222"></a>
#### 2.2.2 Loss Function
In the task of detection knowledge distillation ```ch_PP-OCRv2_det_distill.yml````, the distillation loss function configuration is as follows.
```yaml
Loss:
name: CombinedLoss # Loss function name
loss_config_list: # List of loss function configuration files, mandatory functions for CombinedLoss
- DistillationDilaDBLoss: # DB loss function based on distillation, inherited from standard DBloss
weight: 1.0 # The weight of the loss function. In loss_config_list, each loss function must include this field
model_name_pairs: # Extract the output of these two sub-networks and calculate the loss between them
- ["Student", "Teacher"]
key: maps # In the sub-network output dict, take the corresponding tensor
balance_loss: true # The following parameters are the configuration parameters of standard DBloss
main_loss_type: DiceLoss
alpha: 5
beta: 10
ohem_ratio: 3
- DistillationDBLoss: # Used to calculate the loss between Student and GT
weight: 1.0
model_name_list: ["Student"] # The model name only has Student, which means that the loss between Student and GT is calculated
name: DBLoss
balance_loss: true
main_loss_type: DiceLoss
alpha: 5
beta: 10
ohem_ratio: 3
```
Similarly, distillation loss function configuration(`ch_PP-OCRv2_det_cml.yml`) is shown below. Compared with the loss function configuration of ch_PP-OCRv2_det_distill.yml, there are three changes:
The distillation loss function configuration(`ch_PP-OCRv3_det_cml.yml`) is shown below. Compared with the loss function configuration of ch_PP-OCRv3_det_distill.yml, there are three changes:
```yaml
Loss:
name: CombinedLoss
......@@ -530,7 +507,7 @@ In the task of detecting knowledge distillation, the post-processing configurati
```yaml
PostProcess:
name: DistillationDBPostProcess # The CTC decoding post-processing of the DB detection distillation task, inherited from the standard DBPostProcess class
name: DistillationDBPostProcess # The post-processing of the DB detection distillation task, inherited from the standard DBPostProcess class
model_name: ["Student", "Student2", "Teacher"] # Extract the output of multiple sub-networks and decode them. The network that does not require post-processing is not set in model_name
thresh: 0.3
box_thresh: 0.6
......@@ -561,9 +538,9 @@ Model Structure
#### 2.2.5 Fine-tuning Distillation Model
There are three ways to fine-tune the detection distillation task:
- `ch_PP-OCRv2_det_distill.yml`, The teacher model is set to the model provided by PaddleOCR or the large model you have trained.
- `ch_PP-OCRv2_det_cml.yml`, Use cml distillation. Similarly, the Teacher model is set to the model provided by PaddleOCR or the large model you have trained.
- `ch_PP-OCRv2_det_dml.yml`, Distillation using DML. The method of mutual distillation of the two Student models has an accuracy improvement of about 1.7% on the data set used by PaddleOCR.
- `ch_PP-OCRv3_det_distill.yml`, The teacher model is set to the model provided by PaddleOCR or the large model you have trained.
- `ch_PP-OCRv3_det_cml.yml`, Use cml distillation. Similarly, the Teacher model is set to the model provided by PaddleOCR or the large model you have trained.
- `ch_PP-OCRv3_det_dml.yml`, Distillation using DML. The method of mutual distillation of the two Student models has an accuracy improvement of about 1.7% on the data set used by PaddleOCR.
In fine-tune, you need to set the pre-trained model to be loaded in the `pretrained` parameter of the network structure.
......@@ -572,13 +549,13 @@ In terms of accuracy improvement, `cml` > `dml` > `distill`. When the amount of
In addition, since the distillation pre-training model provided by PaddleOCR contains multiple model parameters, if you want to extract the parameters of the student model, you can refer to the following code:
```sh
# Download the parameters of the distillation training model
wget https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_distill_train.tar
wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_distill_train.tar
```
```python
import paddle
# Load the pre-trained model
all_params = paddle.load("ch_PP-OCRv2_det_distill_train/best_accuracy.pdparams")
all_params = paddle.load("ch_PP-OCRv3_det_distill_train/best_accuracy.pdparams")
# View the keys of the weight parameter
print(all_params.keys())
# Extract the weights of the student model
......@@ -586,7 +563,7 @@ s_params = {key[len("Student."):]: all_params[key] for key in all_params if "Stu
# View the keys of the weight parameters of the student model
print(s_params.keys())
# Save
paddle.save(s_params, "ch_PP-OCRv2_det_distill_train/student.pdparams")
paddle.save(s_params, "ch_PP-OCRv3_det_distill_train/student.pdparams")
```
Finally, the parameters of the student model will be saved in `ch_PP-OCRv2_det_distill_train/student.pdparams` for the fine-tune of the model.
Finally, the parameters of the student model will be saved in `ch_PP-OCRv3_det_distill_train/student.pdparams` for the fine-tune of the model.
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册