Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleOCR
提交
d9b2b7ad
P
PaddleOCR
项目概览
s920243400
/
PaddleOCR
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleOCR
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
d9b2b7ad
编写于
12月 27, 2021
作者:
E
Evezerest
提交者:
GitHub
12月 27, 2021
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #5078 from LDOUBLEV/24_doc
add det distill doc
上级
04c44974
fb33880e
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
276 addition
and
1 deletion
+276
-1
configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml
configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml
+3
-0
configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_distill.yml
configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_distill.yml
+1
-0
configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_dml.yml
configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_dml.yml
+1
-0
doc/doc_ch/knowledge_distillation.md
doc/doc_ch/knowledge_distillation.md
+271
-1
未找到文件。
configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml
浏览文件 @
d9b2b7ad
...
@@ -21,6 +21,7 @@ Architecture:
...
@@ -21,6 +21,7 @@ Architecture:
model_type
:
det
model_type
:
det
Models
:
Models
:
Teacher
:
Teacher
:
pretrained
:
./pretrain_models/ch_ppocr_server_v2.0_det_train/best_accuracy
freeze_params
:
true
freeze_params
:
true
return_all_feats
:
false
return_all_feats
:
false
model_type
:
det
model_type
:
det
...
@@ -36,6 +37,7 @@ Architecture:
...
@@ -36,6 +37,7 @@ Architecture:
name
:
DBHead
name
:
DBHead
k
:
50
k
:
50
Student
:
Student
:
pretrained
:
freeze_params
:
false
freeze_params
:
false
return_all_feats
:
false
return_all_feats
:
false
model_type
:
det
model_type
:
det
...
@@ -52,6 +54,7 @@ Architecture:
...
@@ -52,6 +54,7 @@ Architecture:
name
:
DBHead
name
:
DBHead
k
:
50
k
:
50
Student2
:
Student2
:
pretrained
:
freeze_params
:
false
freeze_params
:
false
return_all_feats
:
false
return_all_feats
:
false
model_type
:
det
model_type
:
det
...
...
configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_distill.yml
浏览文件 @
d9b2b7ad
...
@@ -18,6 +18,7 @@ Global:
...
@@ -18,6 +18,7 @@ Global:
Architecture
:
Architecture
:
name
:
DistillationModel
name
:
DistillationModel
algorithm
:
Distillation
algorithm
:
Distillation
model_type
:
det
Models
:
Models
:
Student
:
Student
:
pretrained
:
./pretrain_models/MobileNetV3_large_x0_5_pretrained
pretrained
:
./pretrain_models/MobileNetV3_large_x0_5_pretrained
...
...
configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_dml.yml
浏览文件 @
d9b2b7ad
...
@@ -18,6 +18,7 @@ Global:
...
@@ -18,6 +18,7 @@ Global:
Architecture
:
Architecture
:
name
:
DistillationModel
name
:
DistillationModel
algorithm
:
Distillation
algorithm
:
Distillation
model_type
:
det
Models
:
Models
:
Student
:
Student
:
pretrained
:
./pretrain_models/MobileNetV3_large_x0_5_pretrained
pretrained
:
./pretrain_models/MobileNetV3_large_x0_5_pretrained
...
...
doc/doc_ch/knowledge_distillation.md
浏览文件 @
d9b2b7ad
...
@@ -279,6 +279,276 @@ paddle.save(s_params, "ch_PP-OCRv2_rec_train/student.pdparams")
...
@@ -279,6 +279,276 @@ paddle.save(s_params, "ch_PP-OCRv2_rec_train/student.pdparams")
转化完成之后,使用
[
ch_PP-OCRv2_rec.yml
](
../../configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec.yml
)
,修改预训练模型的路径(为导出的
`student.pdparams`
模型路径)以及自己的数据路径,即可进行模型微调。
转化完成之后,使用
[
ch_PP-OCRv2_rec.yml
](
../../configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec.yml
)
,修改预训练模型的路径(为导出的
`student.pdparams`
模型路径)以及自己的数据路径,即可进行模型微调。
### 2.2 检测配置文件解析
### 2.2 检测配置文件解析
*
coming soon!
检测模型蒸馏的配置文件在PaddleOCR/configs/det/ch_PP-OCRv2/目录下,包含三个蒸馏配置文件:
-
ch_PP-OCRv2_det_cml.yml,采用cml蒸馏,采用一个大模型蒸馏两个小模型,且两个小模型互相学习的方法
-
ch_PP-OCRv2_det_dml.yml,采用DML的蒸馏,两个Student模型互蒸馏的方法
-
ch_PP-OCRv2_det_distill.yml,采用Teacher大模型蒸馏小模型Student的方法
#### 2.2.1 模型结构
知识蒸馏任务中,模型结构配置如下所示:
```
Architecture:
name: DistillationModel # 结构名称,蒸馏任务中,为DistillationModel,用于构建对应的结构
algorithm: Distillation # 算法名称
Models: # 模型,包含子网络的配置信息
Student: # 子网络名称,至少需要包含`pretrained`与`freeze_params`信息,其他的参数为子网络的构造参数
pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
freeze_params: false # 是否需要固定参数
return_all_feats: false # 子网络的参数,表示是否需要返回所有的features,如果为False,则只返回最后的输出
model_type: det
algorithm: DB
Backbone:
name: MobileNetV3
scale: 0.5
model_name: large
disable_se: True
Neck:
name: DBFPN
out_channels: 96
Head:
name: DBHead
k: 50
Teacher: # 另外一个子网络,这里给的是普通大模型蒸小模型的蒸馏示例,
pretrained: ./pretrain_models/ch_ppocr_server_v2.0_det_train/best_accuracy
freeze_params: true # Teacher模型是训练好的,不需要参与训练,freeze_params设置为True
return_all_feats: false
model_type: det
algorithm: DB
Transform:
Backbone:
name: ResNet
layers: 18
Neck:
name: DBFPN
out_channels: 256
Head:
name: DBHead
k: 50
```
如果是采用DML,即两个小模型互相学习的方法,上述配置文件里的Teacher网络结构需要设置为Student模型一样的配置,具体参考配置文件
[
ch_PP-OCRv2_det_dml.yml
](
https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.4/configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_dml.yml
)
。
下面介绍
[
ch_PP-OCRv2_det_cml.yml
](
https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.4/configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml
)
的配置文件参数:
```
Architecture:
name: DistillationModel
algorithm: Distillation
model_type: det
Models:
Teacher: # CML蒸馏的Teacher模型配置
pretrained: ./pretrain_models/ch_ppocr_server_v2.0_det_train/best_accuracy
freeze_params: true # Teacher 不训练
return_all_feats: false
model_type: det
algorithm: DB
Transform:
Backbone:
name: ResNet
layers: 18
Neck:
name: DBFPN
out_channels: 256
Head:
name: DBHead
k: 50
Student: # CML蒸馏的Student模型配置
pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
freeze_params: false
return_all_feats: false
model_type: det
algorithm: DB
Backbone:
name: MobileNetV3
scale: 0.5
model_name: large
disable_se: True
Neck:
name: DBFPN
out_channels: 96
Head:
name: DBHead
k: 50
Student2: # CML蒸馏的Student2模型配置
pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
freeze_params: false
return_all_feats: false
model_type: det
algorithm: DB
Transform:
Backbone:
name: MobileNetV3
scale: 0.5
model_name: large
disable_se: True
Neck:
name: DBFPN
out_channels: 96
Head:
name: DBHead
k: 50
```
蒸馏模型
`DistillationModel`
类的具体实现代码可以参考
[
distillation_model.py
](
../../ppocr/modeling/architectures/distillation_model.py
)
。
最终模型
`forward`
输出为一个字典,key为所有的子网络名称,例如这里为
`Student`
与
`Teacher`
,value为对应子网络的输出,可以为
`Tensor`
(只返回该网络的最后一层)和
`dict`
(也返回了中间的特征信息)。
在蒸馏任务中,为了方便添加蒸馏损失函数,每个网络的输出保存为
`dict`
,其中包含子模块输出。每个子网络的输出结果均为
`dict`
,key包含
`backbone_out`
,
`neck_out`
,
`head_out`
,
`value`
为对应模块的tensor,最终对于上述配置文件,
`DistillationModel`
的输出格式如下。
```
json
{
"Teacher"
:
{
"backbone_out"
:
tensor
,
"neck_out"
:
tensor
,
"head_out"
:
tensor
,
},
"Student"
:
{
"backbone_out"
:
tensor
,
"neck_out"
:
tensor
,
"head_out"
:
tensor
,
}
}
```
#### 2.1.2 损失函数
知识蒸馏任务中,检测ch_PP-OCRv2_det_distill.yml蒸馏损失函数配置如下所示。
```
yaml
Loss
:
name
:
CombinedLoss
# 损失函数名称,基于改名称,构建用于损失函数的类
loss_config_list
:
# 损失函数配置文件列表,为CombinedLoss的必备函数
-
DistillationDilaDBLoss
:
# 基于蒸馏的DB损失函数,继承自标准的DBloss
weight
:
1.0
# 损失函数的权重,loss_config_list中,每个损失函数的配置都必须包含该字段
model_name_pairs
:
# 对于蒸馏模型的预测结果,提取这两个子网络的输出,计算Teacher模型和Student模型输出的loss
-
[
"
Student"
,
"
Teacher"
]
key
:
maps
# 取子网络输出dict中,该key对应的tensor
balance_loss
:
true
# 以下几个参数为标准DBloss的配置参数
main_loss_type
:
DiceLoss
alpha
:
5
beta
:
10
ohem_ratio
:
3
-
DistillationDBLoss
:
# 基于蒸馏的DB损失函数,继承自标准的DBloss,用于计算Student和GT之间的loss
weight
:
1.0
model_name_list
:
[
"
Student"
]
# 模型名字只有Student,表示计算Student和GT之间的loss
name
:
DBLoss
balance_loss
:
true
main_loss_type
:
DiceLoss
alpha
:
5
beta
:
10
ohem_ratio
:
3
```
同理,检测ch_PP-OCRv2_det_cml.yml蒸馏损失函数配置如下所示。相比较于ch_PP-OCRv2_det_distill.yml的损失函数配置,cml蒸馏的损失函数配置做了3个改动:
```
yaml
Loss
:
name
:
CombinedLoss
loss_config_list
:
-
DistillationDilaDBLoss
:
weight
:
1.0
model_name_pairs
:
-
[
"
Student"
,
"
Teacher"
]
-
[
"
Student2"
,
"
Teacher"
]
# 改动1,计算两个Student和Teacher的损失
key
:
maps
balance_loss
:
true
main_loss_type
:
DiceLoss
alpha
:
5
beta
:
10
ohem_ratio
:
3
-
DistillationDMLLoss
:
# 改动2,增加计算两个Student之间的损失
model_name_pairs
:
-
[
"
Student"
,
"
Student2"
]
maps_name
:
"
thrink_maps"
weight
:
1.0
# act: None
key
:
maps
-
DistillationDBLoss
:
weight
:
1.0
model_name_list
:
[
"
Student"
,
"
Student2"
]
# 改动3,计算两个Student和GT之间的损失
balance_loss
:
true
main_loss_type
:
DiceLoss
alpha
:
5
beta
:
10
ohem_ratio
:
3
```
关于
`DistillationDilaDBLoss`
更加具体的实现可以参考:
[
distillation_loss.py
](
https://github.com/PaddlePaddle/PaddleOCR/blob/release%2F2.4/ppocr/losses/distillation_loss.py#L185
)
。关于
`DistillationDBLoss`
等蒸馏损失函数更加具体的实现可以参考
[
distillation_loss.py
](
https://github.com/PaddlePaddle/PaddleOCR/blob/04c44974b13163450dfb6bd2c327863f8a194b3c/ppocr/losses/distillation_loss.py?_pjax=%23js-repo-pjax-container%2C%20div%5Bitemtype%3D%22http%3A%2F%2Fschema.org%2FSoftwareSourceCode%22%5D%20main%2C%20%5Bdata-pjax-container%5D#L148
)
。
#### 2.1.3 后处理
知识蒸馏任务中,检测蒸馏后处理配置如下所示。
```
yaml
PostProcess
:
name
:
DistillationDBPostProcess
# DB检测蒸馏任务的CTC解码后处理,继承自标准的DBPostProcess类
model_name
:
[
"
Student"
,
"
Student2"
,
"
Teacher"
]
# 对于蒸馏模型的预测结果,提取多个子网络的输出,进行解码,不需要后处理的网络可以不在model_name中设置
thresh
:
0.3
box_thresh
:
0.6
max_candidates
:
1000
unclip_ratio
:
1.5
```
以上述配置为例,最终会同时计算
`Student`
,
`Student2`
和
`Teacher`
3个子网络的输出做后处理计算。同时,由于有多个输入,后处理返回的输出也有多个,
关于
`DistillationDBPostProcess`
更加具体的实现可以参考:
[
db_postprocess.py
](
../../ppocr/postprocess/db_postprocess.py#L195
)
#### 2.1.4 蒸馏指标计算
知识蒸馏任务中,检测蒸馏指标计算配置如下所示。
```
yaml
Metric
:
name
:
DistillationMetric
base_metric_name
:
DetMetric
main_indicator
:
hmean
key
:
"
Student"
```
由于蒸馏需要包含多个网络,甚至多个Student网络,在计算指标的时候只需要计算一个Student网络的指标即可,
`key`
字段设置为
`Student`
则表示只计算
`Student`
网络的精度。
#### 2.1.5 检测蒸馏模型finetune
检测蒸馏有三种方式:
-
采用ch_PP-OCRv2_det_distill.yml,Teacher模型设置为PaddleOCR提供的模型或者您训练好的大模型
-
采用ch_PP-OCRv2_det_cml.yml,采用cml蒸馏,同样Teacher模型设置为PaddleOCR提供的模型或者您训练好的大模型
-
采用ch_PP-OCRv2_det_dml.yml,采用DML的蒸馏,两个Student模型互蒸馏的方法,在PaddleOCR采用的数据集上大约有1.7%的精度提升。
在具体finetune时,需要在网络结构的
`pretrained`
参数中设置要加载的预训练模型。
在精度提升方面,cml的精度>dml的精度>distill蒸馏方法的精度。当数据量不足或者Teacher模型精度与Student精度相差不大的时候,这个结论或许会改变。
另外,由于PaddleOCR提供的蒸馏预训练模型包含了多个模型的参数,如果您希望提取Student模型的参数,可以参考如下代码:
```
# 下载蒸馏训练模型的参数
wget https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_distill_train.tar
```
```
python
import
paddle
# 加载预训练模型
all_params
=
paddle
.
load
(
"ch_PP-OCRv2_det_distill_train/best_accuracy.pdparams"
)
# 查看权重参数的keys
print
(
all_params
.
keys
())
# 学生模型的权重提取
s_params
=
{
key
[
len
(
"Student."
):]:
all_params
[
key
]
for
key
in
all_params
if
"Student."
in
key
}
# 查看学生模型权重参数的keys
print
(
s_params
.
keys
())
# 保存
paddle
.
save
(
s_params
,
"ch_PP-OCRv2_det_distill_train/student.pdparams"
)
```
最终
`Student`
模型的参数将会保存在
`ch_PP-OCRv2_det_distill_train/student.pdparams`
中,用于模型的fine-tune。
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录