Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleSlim
提交
0f9c3ceb
P
PaddleSlim
项目概览
PaddlePaddle
/
PaddleSlim
大约 2 年 前同步成功
通知
51
Star
1434
Fork
344
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
53
列表
看板
标记
里程碑
合并请求
16
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleSlim
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
53
Issue
53
列表
看板
标记
里程碑
合并请求
16
合并请求
16
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
0f9c3ceb
编写于
12月 30, 2019
作者:
B
baiyfbupt
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add distillation tutorial
上级
0d337bce
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
109 addition
and
1 deletion
+109
-1
docs/docs/tutorials/distillation_demo.md
docs/docs/tutorials/distillation_demo.md
+107
-0
docs/mkdocs.yml
docs/mkdocs.yml
+2
-1
未找到文件。
docs/docs/tutorials/distillation_demo.md
0 → 100644
浏览文件 @
0f9c3ceb
本示例将介绍如何使用PaddleSlim蒸馏接口来对模型进行蒸馏训练
## 接口介绍
请参考
[
蒸馏API文档
](
https://paddlepaddle.github.io/PaddleSlim/api/single_distiller_api/
)
。
## PaddleSlim蒸馏训练流程
一般情况下,模型参数量越多,结构越复杂,其性能越好,但运算量和资源消耗也越大。
**知识蒸馏**
就是一种将大模型学习到的有用信息(Dark Knowledge)压缩进更小更快的模型,而获得可以匹敌大模型结果的方法。
在本示例中精度较高的大模型被称为teacher,精度稍逊但速度更快的小模型被称为student。
### 1. 定义student_program
```
python
student_program
=
fluid
.
Program
()
student_startup
=
fluid
.
Program
()
with
fluid
.
program_guard
(
student_program
,
student_startup
):
image
=
fluid
.
data
(
name
=
'image'
,
shape
=
[
None
]
+
[
3
,
224
,
224
],
dtype
=
'float32'
)
label
=
fluid
.
data
(
name
=
'label'
,
shape
=
[
None
,
1
],
dtype
=
'int64'
)
# student model definition
model
=
MobileNet
()
out
=
model
.
net
(
input
=
image
,
class_dim
=
1000
)
cost
=
fluid
.
layers
.
cross_entropy
(
input
=
out
,
label
=
label
)
avg_cost
=
fluid
.
layers
.
mean
(
x
=
cost
)
```
### 2. 定义teacher_program
在定义好teacher_program后,可以一并加载训练好的pretrained_model
在teacher_program内需要加上
`with fluid.unique_name.guard():`
,保证teacher的变量命名不被student_program影响,从而跟能够正确地加载预训练参数
```
python
teacher_program
=
fluid
.
Program
()
teacher_startup
=
fluid
.
Program
()
with
fluid
.
program_guard
(
teacher_program
,
teacher_startup
):
with
fluid
.
unique_name
.
guard
():
image
=
fluid
.
data
(
name
=
'data'
,
shape
=
[
None
]
+
[
3
,
224
,
224
],
dtype
=
'float32'
)
# teacher model definition
teacher_model
=
ResNet
()
predict
=
teacher_model
.
net
(
image
,
class_dim
=
1000
)
exe
.
run
(
teacher_startup
)
def
if_exist
(
var
):
return
os
.
path
.
exists
(
os
.
path
.
join
(
"./pretrained"
,
var
.
name
)
fluid
.
io
.
load_vars
(
exe
,
"./pretrained"
,
main_program
=
teacher_program
,
predicate
=
if_exist
)
```
### 3.选择特征图
定义好student_program和teacher_program后,我们需要从中两两对应地挑选出若干个特征图,留待后续为其添加知识蒸馏损失函数
```
python
# get all student variables
student_vars
=
[]
for
v
in
student_program
.
list_vars
():
try
:
student_vars
.
append
((
v
.
name
,
v
.
shape
))
except
:
pass
print
(
"="
*
50
+
"student_model_vars"
+
"="
*
50
)
print
(
student_vars
)
# get all teacher variables
teacher_vars
=
[]
for
v
in
teacher_program
.
list_vars
():
try
:
teacher_vars
.
append
((
v
.
name
,
v
.
shape
))
except
:
pass
print
(
"="
*
50
+
"teacher_model_vars"
+
"="
*
50
)
print
(
teacher_vars
)
```
### 4. 合并特征图(merge)
PaddlePaddle使用Program来描述计算图,为了同时计算student和teacher两个Program,这里需要将其两者合并(merge)为一个Program。
merge过程操作较多,具体细节请参考
[
merge API文档
](
https://paddlepaddle.github.io/PaddleSlim/api/single_distiller_api/#merge
)
。
```
python
data_name_map
=
{
'data'
:
'image'
}
student_program
=
merge
(
teacher_program
,
student_program
,
data_name_map
,
place
)
```
### 5.添加蒸馏loss
在添加蒸馏loss的过程中,可能还会引入部分变量(Variable),为了避免命名重复这里可以使用
`with fluid.name_scope("distill"):`
为新引入的变量加一个命名作用域
```
python
with
fluid
.
program_guard
(
student_program
,
student_startup
):
with
fluid
.
name_scope
(
"distill"
):
distill_loss
=
l2_loss
(
'teacher_bn5c_branch2b.output.1.tmp_3'
,
'depthwise_conv2d_11.tmp_0'
,
main
)
distill_weight
=
1
loss
=
avg_cost
+
distill_loss
*
distill_weight
opt
=
create_optimizer
()
opt
.
minimize
(
loss
)
exe
.
run
(
student_startup
)
```
至此,我们就得到了用于蒸馏训练的student_program,后面就可以使用一个普通program一样对其开始训练和评估
\ No newline at end of file
docs/mkdocs.yml
浏览文件 @
0f9c3ceb
...
...
@@ -7,11 +7,12 @@ nav:
-
量化训练
:
tutorials/quant_aware_demo.md
-
Embedding量化
:
tutorials/quant_embedding_demo.md
-
SA搜索
:
tutorials/nas_demo.md
-
知识蒸馏
:
tutorials/distillation_demo.md
-
API
:
-
量化
:
api/quantization_api.md
-
剪枝与敏感度
:
api/prune_api.md
-
模型分析
:
api/analysis_api.md
-
蒸馏
:
api/single_distiller_api.md
-
知识
蒸馏
:
api/single_distiller_api.md
-
SA搜索
:
api/nas_api.md
-
搜索空间
:
api/search_space.md
-
硬件延时评估表
:
table_latency.md
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录