提交 de6b6ef1 编写于 作者: B baiyfbupt

fix some details in distillation demo

上级 7bc5e3c1
本示例将介绍如何使用PaddleSlim蒸馏接口来对模型进行蒸馏训练 本示例将介绍如何使用PaddleSlim蒸馏接口来对模型进行蒸馏训练
## 接口介绍 ## 接口介绍
...@@ -28,9 +28,9 @@ with fluid.program_guard(student_program, student_startup): ...@@ -28,9 +28,9 @@ with fluid.program_guard(student_program, student_startup):
### 2. 定义teacher_program ### 2. 定义teacher_program
在定义好teacher_program后,可以一并加载训练好的pretrained_model 在定义好`teacher_program`后,可以一并加载训练好的pretrained_model。
teacher_program内需要加上`with fluid.unique_name.guard():`,保证teacher的变量命名不被student_program影响,从而跟能够正确地加载预训练参数 `teacher_program`内需要加上`with fluid.unique_name.guard():`,保证teacher的变量命名不被`student_program`影响,从而能够正确地加载预训练参数。
```python ```python
teacher_program = fluid.Program() teacher_program = fluid.Program()
...@@ -55,7 +55,7 @@ fluid.io.load_vars( ...@@ -55,7 +55,7 @@ fluid.io.load_vars(
### 3.选择特征图 ### 3.选择特征图
定义好student_program和teacher_program后,我们需要从中两两对应地挑选出若干个特征图,留待后续为其添加知识蒸馏损失函数 定义好`student_program``teacher_program`后,我们需要从中两两对应地挑选出若干个特征图,留待后续为其添加知识蒸馏损失函数。
```python ```python
# get all student variables # get all student variables
...@@ -91,12 +91,15 @@ student_program = merge(teacher_program, student_program, data_name_map, place) ...@@ -91,12 +91,15 @@ student_program = merge(teacher_program, student_program, data_name_map, place)
### 5.添加蒸馏loss ### 5.添加蒸馏loss
在添加蒸馏loss的过程中,可能还会引入部分变量(Variable),为了避免命名重复这里可以使用`with fluid.name_scope("distill"):`为新引入的变量加一个命名作用域 在添加蒸馏loss的过程中,可能还会引入部分变量(Variable),为了避免命名重复这里可以使用`with fluid.name_scope("distill"):`为新引入的变量加一个命名作用域。
另外需要注意的是,merge过程为`teacher_program`的变量统一加了名称前缀,默认是`"teacher_"`, 这里在添加`l2_loss`时也要为teacher的变量加上这个前缀。
```python ```python
with fluid.program_guard(student_program, student_startup): with fluid.program_guard(student_program, student_startup):
with fluid.name_scope("distill"): with fluid.name_scope("distill"):
distill_loss = l2_loss('teacher_bn5c_branch2b.output.1.tmp_3', 'depthwise_conv2d_11.tmp_0', main) distill_loss = l2_loss('teacher_bn5c_branch2b.output.1.tmp_3',
'depthwise_conv2d_11.tmp_0', student_program)
distill_weight = 1 distill_weight = 1
loss = avg_cost + distill_loss * distill_weight loss = avg_cost + distill_loss * distill_weight
opt = create_optimizer() opt = create_optimizer()
...@@ -104,4 +107,4 @@ with fluid.program_guard(student_program, student_startup): ...@@ -104,4 +107,4 @@ with fluid.program_guard(student_program, student_startup):
exe.run(student_startup) exe.run(student_startup)
``` ```
至此,我们就得到了用于蒸馏训练的student_program,后面就可以使用一个普通program一样对其开始训练和评估 至此,我们就得到了用于蒸馏训练的`student_program`,后面就可以使用一个普通program一样对其开始训练和评估。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册