Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
weixin_41840029
PaddleOCR
提交
4c0cf753
P
PaddleOCR
项目概览
weixin_41840029
/
PaddleOCR
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleOCR
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
4c0cf753
编写于
4月 29, 2022
作者:
littletomatodonkey
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
upgrade kd doc
上级
5a08a408
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
137 addition
and
63 deletion
+137
-63
configs/rec/PP-OCRv3/ch_PP-OCRv3_rec.yml
configs/rec/PP-OCRv3/ch_PP-OCRv3_rec.yml
+1
-1
configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml
configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml
+1
-1
doc/doc_ch/knowledge_distillation.md
doc/doc_ch/knowledge_distillation.md
+127
-58
ppocr/utils/utility.py
ppocr/utils/utility.py
+8
-3
未找到文件。
configs/rec/PP-OCRv3/ch_PP-OCRv3_rec.yml
浏览文件 @
4c0cf753
...
...
@@ -71,7 +71,7 @@ PostProcess:
Metric
:
name
:
RecMetric
main_indicator
:
acc
ignore_space
:
Tru
e
ignore_space
:
Fals
e
Train
:
dataset
:
...
...
configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml
浏览文件 @
4c0cf753
...
...
@@ -145,7 +145,7 @@ Metric:
base_metric_name
:
RecMetric
main_indicator
:
acc
key
:
"
Student"
ignore_space
:
Tru
e
ignore_space
:
Fals
e
Train
:
dataset
:
...
...
doc/doc_ch/knowledge_distillation.md
浏览文件 @
4c0cf753
...
...
@@ -60,7 +60,7 @@ PaddleOCR中集成了知识蒸馏的算法,具体地,有以下几个主要
<a
name=
"21"
></a>
### 2.1 识别配置文件解析
配置文件在
[
ch_PP-OCRv
2_rec_distillation.yml
](
../../configs/rec/ch_PP-OCRv2/ch_PP-OCRv2
_rec_distillation.yml
)
。
配置文件在
[
ch_PP-OCRv
3_rec_distillation.yml
](
../../configs/rec/PP-OCRv3/ch_PP-OCRv3
_rec_distillation.yml
)
。
<a
name=
"211"
></a>
#### 2.1.1 模型结构
...
...
@@ -69,7 +69,7 @@ PaddleOCR中集成了知识蒸馏的算法,具体地,有以下几个主要
```
yaml
Architecture
:
model_type
:
&model_type
"
rec"
# 模型类别,rec、det等,每个子网络的的模型
类别都与
model_type
:
&model_type
"
rec"
# 模型类别,rec、det等,每个子网络的的模型
相同
name
:
DistillationModel
# 结构名称,蒸馏任务中,为DistillationModel,用于构建对应的结构
algorithm
:
Distillation
# 算法名称
Models
:
# 模型,包含子网络的配置信息
...
...
@@ -78,37 +78,55 @@ Architecture:
freeze_params
:
false
# 是否需要固定参数
return_all_feats
:
true
# 子网络的参数,表示是否需要返回所有的features,如果为False,则只返回最后的输出
model_type
:
*model_type
# 模型类别
algorithm
:
CRNN
# 子网络的算法名称,该子网络剩余参与
均为构造参数,与普通的模型训练配置一致
algorithm
:
SVTR
# 子网络的算法名称,该子网络其余参数
均为构造参数,与普通的模型训练配置一致
Transform
:
Backbone
:
name
:
MobileNetV1Enhance
scale
:
0.5
last_conv_stride
:
[
1
,
2
]
last_pool_type
:
avg
Head
:
name
:
MultiHead
head_list
:
-
CTCHead
:
Neck
:
name
:
SequenceEncoder
encoder_type
:
rnn
hidden_size
:
64
name
:
svtr
dims
:
64
depth
:
2
hidden_dims
:
120
use_guide
:
True
Head
:
name
:
CTCHead
mid_channels
:
96
fc_decay
:
0.00002
Student
:
# 另外一个子网络,这里给的是DML的蒸馏示例,两个子网络结构相同,均需要学习参数
pretrained
:
# 下面的组网参数同上
fc_decay
:
0.00001
-
SARHead
:
enc_dim
:
512
max_text_length
:
*max_text_length
Student
:
pretrained
:
freeze_params
:
false
return_all_feats
:
true
model_type
:
*model_type
algorithm
:
CRNN
algorithm
:
SVTR
Transform
:
Backbone
:
name
:
MobileNetV1Enhance
scale
:
0.5
last_conv_stride
:
[
1
,
2
]
last_pool_type
:
avg
Head
:
name
:
MultiHead
head_list
:
-
CTCHead
:
Neck
:
name
:
SequenceEncoder
encoder_type
:
rnn
hidden_size
:
64
name
:
svtr
dims
:
64
depth
:
2
hidden_dims
:
120
use_guide
:
True
Head
:
name
:
CTCHead
mid_channels
:
96
fc_decay
:
0.00002
fc_decay
:
0.00001
-
SARHead
:
enc_dim
:
512
max_text_length
:
*max_text_length
```
当然,这里如果希望添加更多的子网络进行训练,也可以按照
`Student`
与
`Teacher`
的添加方式,在配置文件中添加相应的字段。比如说如果希望有3个模型互相监督,共同训练,那么
`Architecture`
可以写为如下格式。
...
...
@@ -124,55 +142,82 @@ Architecture:
freeze_params
:
false
return_all_feats
:
true
model_type
:
*model_type
algorithm
:
CRNN
algorithm
:
SVTR
Transform
:
Backbone
:
name
:
MobileNetV1Enhance
scale
:
0.5
last_conv_stride
:
[
1
,
2
]
last_pool_type
:
avg
Head
:
name
:
MultiHead
head_list
:
-
CTCHead
:
Neck
:
name
:
SequenceEncoder
encoder_type
:
rnn
hidden_size
:
64
name
:
svtr
dims
:
64
depth
:
2
hidden_dims
:
120
use_guide
:
True
Head
:
name
:
CTCHead
mid_channels
:
96
fc_decay
:
0.00002
fc_decay
:
0.00001
-
SARHead
:
enc_dim
:
512
max_text_length
:
*max_text_length
Student
:
pretrained
:
freeze_params
:
false
return_all_feats
:
true
model_type
:
*model_type
algorithm
:
CRNN
algorithm
:
SVTR
Transform
:
Backbone
:
name
:
MobileNetV1Enhance
scale
:
0.5
last_conv_stride
:
[
1
,
2
]
last_pool_type
:
avg
Head
:
name
:
MultiHead
head_list
:
-
CTCHead
:
Neck
:
name
:
SequenceEncoder
encoder_type
:
rnn
hidden_size
:
64
name
:
svtr
dims
:
64
depth
:
2
hidden_dims
:
120
use_guide
:
True
Head
:
name
:
CTCHead
mid_channels
:
96
fc_decay
:
0.00002
Student2
:
# 知识蒸馏任务中引入的新的子网络,其他部分与上述配置相同
fc_decay
:
0.00001
-
SARHead
:
enc_dim
:
512
max_text_length
:
*max_text_length
Student2
:
pretrained
:
freeze_params
:
false
return_all_feats
:
true
model_type
:
*model_type
algorithm
:
CRNN
algorithm
:
SVTR
Transform
:
Backbone
:
name
:
MobileNetV1Enhance
scale
:
0.5
last_conv_stride
:
[
1
,
2
]
last_pool_type
:
avg
Head
:
name
:
MultiHead
head_list
:
-
CTCHead
:
Neck
:
name
:
SequenceEncoder
encoder_type
:
rnn
hidden_size
:
64
name
:
svtr
dims
:
64
depth
:
2
hidden_dims
:
120
use_guide
:
True
Head
:
name
:
CTCHead
mid_channels
:
96
fc_decay
:
0.00002
fc_decay
:
0.00001
-
SARHead
:
enc_dim
:
512
max_text_length
:
*max_text_length
```
最终该模型训练时,包含3个子网络:
`Teacher`
,
`Student`
,
`Student2`
。
...
...
@@ -205,34 +250,56 @@ Architecture:
```
yaml
Loss
:
name
:
CombinedLoss
# 损失函数名称,基于改名称,构建用于损失函数的类
loss_config_list
:
# 损失函数配置文件列表,为CombinedLoss的必备函数
-
DistillationCTCLoss
:
# 基于蒸馏的CTC损失函数,继承自标准的CTC loss
weight
:
1.0
# 损失函数的权重,loss_config_list中,每个损失函数的配置都必须包含该字段
model_name_list
:
[
"
Student"
,
"
Teacher"
]
# 对于蒸馏模型的预测结果,提取这两个子网络的输出,与gt计算CTC loss
name
:
CombinedLoss
loss_config_list
:
-
DistillationDMLLoss
:
# 蒸馏的DML损失函数,继承自标准的DMLLoss
weight
:
1.0
# 权重
act
:
"
softmax"
# 激活函数,对输入使用激活函数处理,可以为softmax, sigmoid或者为None,默认为None
use_log
:
true
# 对输入计算log,如果函数已经
model_name_pairs
:
# 用于计算DML loss的子网络名称对,如果希望计算其他子网络的DML loss,可以在列表下面继续填充
-
[
"
Student"
,
"
Teacher"
]
key
:
head_out
# 取子网络输出dict中,该key对应的tensor
multi_head
:
True
# 是否为多头结构,我们
dis_head
:
ctc
# 蒸馏
name
:
dml_ctc
# 蒸馏loss的前缀名称,避免不同loss之间的命名冲突
-
DistillationDMLLoss
:
# 蒸馏的DML损失函数,继承自标准的DMLLoss
weight
:
1.0
# 权重
act
:
"
softmax"
# 激活函数,对输入使用激活函数处理,可以为softmax, sigmoid或者为None,默认为None
use_log
:
true
# 对输入计算log,如果函数已经
model_name_pairs
:
# 用于计算DML loss的子网络名称对,如果希望计算其他子网络的DML loss,可以在列表下面继续填充
-
[
"
Student"
,
"
Teacher"
]
key
:
head_out
# 取子网络输出dict中,该key对应的tensor
multi_head
:
True
# 是否为多头结构,我们
dis_head
:
sar
# 蒸馏
name
:
dml_sar
# 蒸馏loss的前缀名称,避免不同loss之间的命名冲突
-
DistillationDistanceLoss
:
# 蒸馏的距离损失函数
weight
:
1.0
# 权重
mode
:
"
l2"
# 距离计算方法,目前支持l1, l2, smooth_l1
model_name_pairs
:
# 用于计算distance loss的子网络名称对
-
[
"
Student"
,
"
Teacher"
]
key
:
backbone_out
# 取子网络输出dict中,该key对应的tensor
-
DistillationCTCLoss
:
# 基于蒸馏的CTC损失函数,继承自标准的CTC loss
weight
:
1.0
# 损失函数的权重,loss_config_list中,每个损失函数的配置都必须包含该字段
model_name_list
:
[
"
Student"
,
"
Teacher"
]
# 对于蒸馏模型的预测结果,提取这两个子网络的输出,与gt计算CTC loss
key
:
head_out
# 取子网络输出dict中,该key对应的tensor
-
DistillationSARLoss
:
# 基于蒸馏的SAR损失函数,继承自标准的SARLoss
weight
:
1.0
# 损失函数的权重,loss_config_list中,每个损失函数的配置都必须包含该字段
model_name_list
:
[
"
Student"
,
"
Teacher"
]
# 对于蒸馏模型的预测结果,提取这两个子网络的输出,与gt计算CTC loss
key
:
head_out
# 取子网络输出dict中,该key对应的tensor
multi_head
:
True
# 是否为多头结构,为true时,取出其中的
```
上述损失函数中,所有的蒸馏损失函数均继承自标准的损失函数类,主要功能为: 对蒸馏模型的输出进行解析,找到用于计算损失的中间节点(tensor),再使用标准的损失函数类去计算。
以上述配置为例,最终蒸馏训练的损失函数包含下面3个部分。
-
`Student`
和
`Teacher`
的最终输出(
`head_out`
)与gt的CTC loss,权重为1。在这里因为2个子网络都需要更新参数,因此2者都需要计算与g的loss。
-
`Student`
和
`Teacher`
的最终输出(
`head_out`
)之间的DML loss,权重为1。
-
`Student`
和
`Teacher`
最终输出(
`head_out`
)的CTC分支与gt的CTC loss,权重为1。在这里因为2个子网络都需要更新参数,因此2者都需要计算与g的loss。
-
`Student`
和
`Teacher`
最终输出(
`head_out`
)的SAR分支与gt的CTC loss,权重为1。在这里因为2个子网络都需要更新参数,因此2者都需要计算与g的loss。
-
`Student`
和
`Teacher`
最终输出(
`head_out`
)的CTC分支之间的DML loss,权重为1。
-
`Student`
和
`Teacher`
最终输出(
`head_out`
)SARC分支之间的DML loss,权重为1。
-
`Student`
和
`Teacher`
的骨干网络输出(
`backbone_out`
)之间的l2 loss,权重为1。
关于
`CombinedLoss`
更加具体的实现可以参考:
[
combined_loss.py
](
../../ppocr/losses/combined_loss.py#L23
)
。关于
`DistillationCTCLoss`
等蒸馏损失函数更加具体的实现可以参考
[
distillation_loss.py
](
../../ppocr/losses/distillation_loss.py
)
。
<a
name=
"213"
></a>
...
...
@@ -245,6 +312,7 @@ PostProcess:
name
:
DistillationCTCLabelDecode
# 蒸馏任务的CTC解码后处理,继承自标准的CTCLabelDecode类
model_name
:
[
"
Student"
,
"
Teacher"
]
# 对于蒸馏模型的预测结果,提取这两个子网络的输出,进行解码
key
:
head_out
# 取子网络输出dict中,该key对应的tensor
multi_head
:
True
# 多头结构时,会取出其中的CTC分支进行计算
```
以上述配置为例,最终会同时计算
`Student`
和
`Teahcer`
2个子网络的CTC解码输出,返回一个
`dict`
,
`key`
为用于处理的子网络名称,
`value`
为用于处理的子网络列表。
...
...
@@ -262,6 +330,7 @@ Metric:
base_metric_name
:
RecMetric
# 指标计算的基类,对于模型的输出,会基于该类,计算指标
main_indicator
:
acc
# 指标的名称
key
:
"
Student"
# 选取该子网络的 main_indicator 作为作为保存保存best model的判断标准
ignore_space
:
False
# 评估时是否忽略空格的影响
```
以上述配置为例,最终会使用
`Student`
子网络的acc指标作为保存best model的判断指标,同时,日志中也会打印出所有子网络的acc指标。
...
...
@@ -273,15 +342,15 @@ Metric:
对蒸馏得到的识别蒸馏进行微调有2种方式。
(1)基于知识蒸馏的微调:这种情况比较简单,下载预训练模型,在
[
ch_PP-OCRv
2_rec_distillation.yml
](
../../configs/rec/ch_PP-OCRv2/ch_PP-OCRv2
_rec_distillation.yml
)
中配置好预训练模型路径以及自己的数据路径,即可进行模型微调训练。
(1)基于知识蒸馏的微调:这种情况比较简单,下载预训练模型,在
[
ch_PP-OCRv
3_rec_distillation.yml
](
../../configs/rec/PP-OCRv3/ch_PP-OCRv3
_rec_distillation.yml
)
中配置好预训练模型路径以及自己的数据路径,即可进行模型微调训练。
(2)微调时不使用知识蒸馏:这种情况,需要首先将预训练模型中的学生模型参数提取出来,具体步骤如下。
*
首先下载预训练模型并解压。
```
shell
# 下面预训练模型并解压
wget https://paddleocr.bj.bcebos.com/PP-OCRv
2/chinese/ch_PP-OCRv2
_rec_train.tar
tar
-xf
ch_PP-OCRv
2
_rec_train.tar
wget https://paddleocr.bj.bcebos.com/PP-OCRv
3/chinese/ch_PP-OCRv3
_rec_train.tar
tar
-xf
ch_PP-OCRv
3
_rec_train.tar
```
*
然后使用python,对其中的学生模型参数进行提取
...
...
@@ -289,7 +358,7 @@ tar -xf ch_PP-OCRv2_rec_train.tar
```
python
import
paddle
# 加载预训练模型
all_params
=
paddle
.
load
(
"ch_PP-OCRv
2
_rec_train/best_accuracy.pdparams"
)
all_params
=
paddle
.
load
(
"ch_PP-OCRv
3
_rec_train/best_accuracy.pdparams"
)
# 查看权重参数的keys
print
(
all_params
.
keys
())
# 学生模型的权重提取
...
...
@@ -297,10 +366,10 @@ s_params = {key[len("Student."):]: all_params[key] for key in all_params if "Stu
# 查看学生模型权重参数的keys
print
(
s_params
.
keys
())
# 保存
paddle
.
save
(
s_params
,
"ch_PP-OCRv
2
_rec_train/student.pdparams"
)
paddle
.
save
(
s_params
,
"ch_PP-OCRv
3
_rec_train/student.pdparams"
)
```
转化完成之后,使用
[
ch_PP-OCRv
2_rec.yml
](
../../configs/rec/ch_PP-OCRv2/ch_PP-OCRv2
_rec.yml
)
,修改预训练模型的路径(为导出的
`student.pdparams`
模型路径)以及自己的数据路径,即可进行模型微调。
转化完成之后,使用
[
ch_PP-OCRv
3_rec.yml
](
../../configs/rec/PP-OCRv3/ch_PP-OCRv3
_rec.yml
)
,修改预训练模型的路径(为导出的
`student.pdparams`
模型路径)以及自己的数据路径,即可进行模型微调。
<a
name=
"22"
></a>
### 2.2 检测配置文件解析
...
...
ppocr/utils/utility.py
浏览文件 @
4c0cf753
...
...
@@ -49,18 +49,23 @@ def get_check_global_params(mode):
return
check_params
def
_check_image_file
(
path
):
img_end
=
{
'jpg'
,
'bmp'
,
'png'
,
'jpeg'
,
'rgb'
,
'tif'
,
'tiff'
,
'gif'
}
return
any
([
path
.
lower
().
endswith
(
e
)
for
e
in
img_end
])
def
get_image_file_list
(
img_file
):
imgs_lists
=
[]
if
img_file
is
None
or
not
os
.
path
.
exists
(
img_file
):
raise
Exception
(
"not found any img file in {}"
.
format
(
img_file
))
img_end
=
{
'jpg'
,
'bmp'
,
'png'
,
'jpeg'
,
'rgb'
,
'tif'
,
'tiff'
,
'gif'
,
'GIF'
}
if
os
.
path
.
isfile
(
img_file
)
and
imghdr
.
what
(
img_file
)
in
img_end
:
img_end
=
{
'jpg'
,
'bmp'
,
'png'
,
'jpeg'
,
'rgb'
,
'tif'
,
'tiff'
,
'gif'
}
if
os
.
path
.
isfile
(
img_file
)
and
_check_image_file
(
file_path
)
:
imgs_lists
.
append
(
img_file
)
elif
os
.
path
.
isdir
(
img_file
):
for
single_file
in
os
.
listdir
(
img_file
):
file_path
=
os
.
path
.
join
(
img_file
,
single_file
)
if
os
.
path
.
isfile
(
file_path
)
and
imghdr
.
what
(
file_path
)
in
img_end
:
if
os
.
path
.
isfile
(
file_path
)
and
_check_image_file
(
file_path
)
:
imgs_lists
.
append
(
file_path
)
if
len
(
imgs_lists
)
==
0
:
raise
Exception
(
"not found any img file in {}"
.
format
(
img_file
))
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录