Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleOCR
提交
ed8c2afa
P
PaddleOCR
项目概览
PaddlePaddle
/
PaddleOCR
1 年多 前同步成功
通知
1534
Star
32963
Fork
6643
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
108
列表
看板
标记
里程碑
合并请求
7
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
108
Issue
108
列表
看板
标记
里程碑
合并请求
7
合并请求
7
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
ed8c2afa
编写于
3年前
作者:
D
Double_V
提交者:
GitHub
3年前
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #5191 from WZMIAOMIAO/dygraph
update knowledge_distillation to dygraph branch
上级
a031e333
e77b1f0f
变更
2
展开全部
隐藏空白更改
内联
并排
Showing
2 changed file
with
631 addition
and
15 deletion
+631
-15
doc/doc_ch/knowledge_distillation.md
doc/doc_ch/knowledge_distillation.md
+39
-15
doc/doc_en/knowledge_distillation_en.md
doc/doc_en/knowledge_distillation_en.md
+592
-0
未找到文件。
doc/doc_ch/knowledge_distillation.md
浏览文件 @
ed8c2afa
<a
name=
"0"
></a>
# 知识蒸馏
+
[
知识蒸馏
](
#0
)
+
[
1. 简介
](
#1
)
-
[
1.1 知识蒸馏介绍
](
#11
)
-
[
1.2 PaddleOCR知识蒸馏简介
](
#12
)
+
[
2. 配置文件解析
](
#2
)
+
[
2.1 识别配置文件解析
](
#21
)
-
[
2.1.1 模型结构
](
#211
)
-
[
2.1.2 损失函数
](
#212
)
-
[
2.1.3 后处理
](
#213
)
-
[
2.1.4 指标计算
](
#214
)
-
[
2.1.5 蒸馏模型微调
](
#215
)
+
[
2.2 检测配置文件解析
](
#22
)
-
[
2.2.1 模型结构
](
#221
)
-
[
2.2.2 损失函数
](
#222
)
-
[
2.2.3 后处理
](
#223
)
-
[
2.2.4 蒸馏指标计算
](
#224
)
-
[
2.2.5 检测蒸馏模型Fine-tune
](
#225
)
<a
name=
"1"
></a>
## 1. 简介
<a
name=
"11"
></a>
### 1.1 知识蒸馏介绍
近年来,深度神经网络在计算机视觉、自然语言处理等领域被验证是一种极其有效的解决问题的方法。通过构建合适的神经网络,加以训练,最终网络模型的性能指标基本上都会超过传统算法。
...
...
@@ -13,6 +32,7 @@
此外,在知识蒸馏任务中,也衍生出了互学习的模型训练方法,论文
[
Deep Mutual Learning
](
https://arxiv.org/abs/1706.00384
)
中指出,使用两个完全相同的模型在训练的过程中互相监督,可以达到比单个模型训练更好的效果。
<a
name=
"12"
></a>
### 1.2 PaddleOCR知识蒸馏简介
无论是大模型蒸馏小模型,还是小模型之间互相学习,更新参数,他们本质上是都是不同模型之间输出或者特征图(feature map)之间的相互监督,区别仅在于 (1) 模型是否需要固定参数。(2) 模型是否需要加载预训练模型。
...
...
@@ -30,17 +50,19 @@ PaddleOCR中集成了知识蒸馏的算法,具体地,有以下几个主要
通过知识蒸馏,在中英文通用文字识别任务中,不增加任何预测耗时的情况下,可以给模型带来3%以上的精度提升,结合学习率调整策略以及模型结构微调策略,最终提升提升超过5%。
<a
name=
"2"
></a>
## 2. 配置文件解析
在知识蒸馏训练的过程中,数据预处理、优化器、学习率、全局的一些属性没有任何变化。模型结构、损失函数、后处理、指标计算等模块的配置文件需要进行微调。
下面以识别与检测的知识蒸馏配置文件为例,对知识蒸馏的训练与配置进行解析。
<a
name=
"21"
></a>
### 2.1 识别配置文件解析
配置文件在
[
ch_PP-OCRv2_rec_distillation.yml
](
../../configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec_distillation.yml
)
。
<a
name=
"211"
></a>
#### 2.1.1 模型结构
知识蒸馏任务中,模型结构配置如下所示。
...
...
@@ -176,6 +198,7 @@ Architecture:
}
```
<a
name=
"212"
></a>
#### 2.1.2 损失函数
知识蒸馏任务中,损失函数配置如下所示。
...
...
@@ -212,7 +235,7 @@ Loss:
关于
`CombinedLoss`
更加具体的实现可以参考:
[
combined_loss.py
](
../../ppocr/losses/combined_loss.py#L23
)
。关于
`DistillationCTCLoss`
等蒸馏损失函数更加具体的实现可以参考
[
distillation_loss.py
](
../../ppocr/losses/distillation_loss.py
)
。
<a
name=
"213"
></a>
#### 2.1.3 后处理
知识蒸馏任务中,后处理配置如下所示。
...
...
@@ -228,7 +251,7 @@ PostProcess:
关于
`DistillationCTCLabelDecode`
更加具体的实现可以参考:
[
rec_postprocess.py
](
../../ppocr/postprocess/rec_postprocess.py#L128
)
<a
name=
"214"
></a>
#### 2.1.4 指标计算
知识蒸馏任务中,指标计算配置如下所示。
...
...
@@ -245,7 +268,7 @@ Metric:
关于
`DistillationMetric`
更加具体的实现可以参考:
[
distillation_metric.py
](
../../ppocr/metrics/distillation_metric.py#L24
)
。
<a
name=
"215"
></a>
#### 2.1.5 蒸馏模型微调
对蒸馏得到的识别蒸馏进行微调有2种方式。
...
...
@@ -279,15 +302,15 @@ paddle.save(s_params, "ch_PP-OCRv2_rec_train/student.pdparams")
转化完成之后,使用
[
ch_PP-OCRv2_rec.yml
](
../../configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec.yml
)
,修改预训练模型的路径(为导出的
`student.pdparams`
模型路径)以及自己的数据路径,即可进行模型微调。
<a
name=
"22"
></a>
### 2.2 检测配置文件解析
检测模型蒸馏的配置文件在PaddleOCR/configs/det/ch_PP-OCRv2/目录下,包含三个蒸馏配置文件:
-
ch_PP-OCRv2_det_cml.yml,采用cml蒸馏,采用一个大模型蒸馏两个小模型,且两个小模型互相学习的方法
-
ch_PP-OCRv2_det_dml.yml,采用DML的蒸馏,两个Student模型互蒸馏的方法
-
ch_PP-OCRv2_det_distill.yml,采用Teacher大模型蒸馏小模型Student的方法
<a
name=
"221"
></a>
#### 2.2.1 模型结构
知识蒸馏任务中,模型结构配置如下所示:
...
...
@@ -419,7 +442,8 @@ Architecture:
}
```
#### 2.1.2 损失函数
<a
name=
"222"
></a>
#### 2.2.2 损失函数
知识蒸馏任务中,检测ch_PP-OCRv2_det_distill.yml蒸馏损失函数配置如下所示。
...
...
@@ -484,8 +508,8 @@ Loss:
关于
`DistillationDilaDBLoss`
更加具体的实现可以参考:
[
distillation_loss.py
](
https://github.com/PaddlePaddle/PaddleOCR/blob/release%2F2.4/ppocr/losses/distillation_loss.py#L185
)
。关于
`DistillationDBLoss`
等蒸馏损失函数更加具体的实现可以参考
[
distillation_loss.py
](
https://github.com/PaddlePaddle/PaddleOCR/blob/04c44974b13163450dfb6bd2c327863f8a194b3c/ppocr/losses/distillation_loss.py?_pjax=%23js-repo-pjax-container%2C%20div%5Bitemtype%3D%22http%3A%2F%2Fschema.org%2FSoftwareSourceCode%22%5D%20main%2C%20%5Bdata-pjax-container%5D#L148
)
。
#### 2.
1
.3 后处理
<a
name=
"223"
></a>
#### 2.
2
.3 后处理
知识蒸馏任务中,检测蒸馏后处理配置如下所示。
...
...
@@ -503,8 +527,8 @@ PostProcess:
关于
`DistillationDBPostProcess`
更加具体的实现可以参考:
[
db_postprocess.py
](
../../ppocr/postprocess/db_postprocess.py#L195
)
#### 2.
1
.4 蒸馏指标计算
<a
name=
"224"
></a>
#### 2.
2
.4 蒸馏指标计算
知识蒸馏任务中,检测蒸馏指标计算配置如下所示。
...
...
@@ -518,8 +542,8 @@ Metric:
由于蒸馏需要包含多个网络,甚至多个Student网络,在计算指标的时候只需要计算一个Student网络的指标即可,
`key`
字段设置为
`Student`
则表示只计算
`Student`
网络的精度。
#### 2.
1
.5 检测蒸馏模型finetune
<a
name=
"225"
></a>
#### 2.
2
.5 检测蒸馏模型finetune
检测蒸馏有三种方式:
-
采用ch_PP-OCRv2_det_distill.yml,Teacher模型设置为PaddleOCR提供的模型或者您训练好的大模型
...
...
This diff is collapsed.
Click to expand it.
doc/doc_en/knowledge_distillation_en.md
0 → 100755
浏览文件 @
ed8c2afa
此差异已折叠。
点击以展开。
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录
新手
引导
客服
返回
顶部