BaiXuePrincess / Paddle (forked from PaddlePaddle / Paddle)
Unverified commit 73c4f2b7
Authored on Mar 29, 2019 by whs; committed via GitHub on Mar 29, 2019
Parent: 3e6aa498

Fix distillation for soft label. (#16538)

test=develop
Showing 2 changed files with 97 additions and 2 deletions (+97 −2)
python/paddle/fluid/contrib/slim/distillation/distiller.py (+89 −1)
python/paddle/fluid/contrib/slim/tests/distillation/compress.yaml (+8 −1)
python/paddle/fluid/contrib/slim/distillation/distiller.py
@@ -19,7 +19,7 @@ from .... import Program
 from .... import program_guard
 from .... import regularizer
 
-__all__ = ['FSPDistiller', 'L2Distiller']
+__all__ = ['FSPDistiller', 'L2Distiller', 'SoftLabelDistiller']
 
 
 class L2Distiller(object):
@@ -186,3 +186,91 @@ class FSPDistillerPass(object):
     def _fsp_matrix(self, fea_map_0, fea_map_1):
         return layers.fsp_matrix(fea_map_0, fea_map_1)
+
+
+class SoftLabelDistiller(object):
+    """
+    Combines a layer from the student net with a layer from the teacher net
+    through a softmax_with_cross_entropy loss, and adds that loss to the
+    total loss used for distillation training.
+    """
+
+    def __init__(self,
+                 student_feature_map=None,
+                 teacher_feature_map=None,
+                 student_temperature=1.0,
+                 teacher_temperature=1.0,
+                 distillation_loss_weight=1):
+        """
+        Args:
+            student_feature_map(str): The name of the feature map from the student network.
+            teacher_feature_map(str): The name of the feature map from the teacher network.
+                                      Its shape should match the student feature map's.
+            student_temperature(float): Temperature by which student_feature_map is divided
+                                        before softmax_with_cross_entropy. Default: 1.0
+            teacher_temperature(float): Temperature by which teacher_feature_map is divided
+                                        before softmax_with_cross_entropy. Default: 1.0
+            distillation_loss_weight(float): The weight of the soft-label loss.
+        """
+        self.student_feature_map = student_feature_map
+        self.teacher_feature_map = teacher_feature_map
+        self.distillation_loss_weight = distillation_loss_weight
+        self.student_temperature = student_temperature
+        self.teacher_temperature = teacher_temperature
+
+    def distiller_loss(self, graph):
+        """
+        Modify the graph in place to add a softmax_with_cross_entropy loss.
+        Args:
+            graph(GraphWrapper): The graph to be modified.
+        Returns:
+            GraphWrapper: The modified graph.
+        """
+        distiller_pass = SoftLabelDistillerPass(
+            self.student_feature_map, self.teacher_feature_map,
+            self.student_temperature, self.teacher_temperature,
+            self.distillation_loss_weight)
+        dis_graph = distiller_pass.apply(graph)
+        return dis_graph
+
+
+class SoftLabelDistillerPass(object):
+    def __init__(self,
+                 student_feature_map,
+                 teacher_feature_map,
+                 student_temperature,
+                 teacher_temperature,
+                 distillation_loss_weight=1):
+        """
+        Args:
+            student_feature_map(str): The name of the feature map from the student network.
+            teacher_feature_map(str): The name of the feature map from the teacher network.
+                                      Its shape should match the student feature map's.
+            student_temperature(float): Temperature by which student_feature_map is divided
+                                        before softmax_with_cross_entropy.
+            teacher_temperature(float): Temperature by which teacher_feature_map is divided
+                                        before softmax_with_cross_entropy.
+            distillation_loss_weight(float): The weight of the soft-label loss.
+        """
+        self.student_feature_map = student_feature_map
+        self.teacher_feature_map = teacher_feature_map
+        self.student_temperature = student_temperature
+        self.teacher_temperature = teacher_temperature
+        self.distillation_loss_weight = distillation_loss_weight
+
+    def apply(self, graph):
+        ret_graph = graph
+        with program_guard(ret_graph.program):
+            student_feature_map = ret_graph.var(self.student_feature_map)._var
+            teacher_feature_map = ret_graph.var(self.teacher_feature_map)._var
+            s_fea = student_feature_map / self.student_temperature
+            # Scale the teacher map by its temperature and freeze it so no
+            # gradient flows back into the teacher network.
+            t_fea = teacher_feature_map / self.teacher_temperature
+            t_fea.stop_gradient = True
+            ce_loss = layers.softmax_with_cross_entropy(
+                s_fea, t_fea, soft_label=True)
+            distillation_loss = ce_loss * self.distillation_loss_weight
+            student_loss = ret_graph.var(ret_graph.out_nodes['loss'])._var
+            loss = distillation_loss + student_loss
+            ret_graph.out_nodes[
+                'soft_label_loss_' + self.student_feature_map + "_" +
+                self.teacher_feature_map] = distillation_loss.name
+            ret_graph.out_nodes['loss'] = loss.name
+        return ret_graph
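For reference, here is a minimal NumPy sketch (not part of the patch; the sample values are invented) of the per-sample quantity that layers.softmax_with_cross_entropy(s_fea, t_fea, soft_label=True) contributes above: the cross entropy between the temperature-scaled teacher map, used directly as the soft-label distribution, and the softmax of the temperature-scaled student map.

    import numpy as np

    def soft_label_cross_entropy(student_logits, soft_labels):
        # Numerically stable log-softmax over the last axis.
        shifted = student_logits - student_logits.max(axis=-1, keepdims=True)
        log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
        # Cross entropy against the soft labels: -sum(p_teacher * log p_student).
        return -(soft_labels * log_probs).sum(axis=-1)

    student = np.array([[2.0, 1.0, 0.1]])  # hypothetical student feature map (logits)
    teacher = np.array([[0.6, 0.3, 0.1]])  # hypothetical teacher map, already a distribution
    t_student, t_teacher = 1.0, 1.0        # temperatures, as in the config below
    print(soft_label_cross_entropy(student / t_student, teacher / t_teacher))

Note that, as in the pass above, only the student side goes through a softmax; the teacher map is consumed as-is after temperature scaling, so it is expected to already sum to 1 along the class axis.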
python/paddle/fluid/contrib/slim/tests/distillation/compress.yaml
@@ -33,10 +33,17 @@ distillers:
         teacher_feature_map: 'teacher.tmp_2'
         student_feature_map: 'student.tmp_2'
         distillation_loss_weight: 1
+    soft_label_distiller:
+        class: 'SoftLabelDistiller'
+        student_temperature: 1.0
+        teacher_temperature: 1.0
+        teacher_feature_map: 'teacher.tmp_1'
+        student_feature_map: 'student.tmp_1'
+        distillation_loss_weight: 0.001
 strategies:
     distillation_strategy:
         class: 'DistillationStrategy'
-        distillers: ['fsp_distiller', 'l2_distiller']
+        distillers: ['fsp_distiller', 'l2_distiller', 'soft_label_distiller']
         start_epoch: 0
         end_epoch: 1
 compressor:
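The new soft_label_distiller block maps one-to-one onto the SoftLabelDistiller constructor added in distiller.py. As a hypothetical usage sketch (not part of the patch), the same configuration written directly in Python would be:

    from paddle.fluid.contrib.slim.distillation.distiller import SoftLabelDistiller

    # Values copied from the soft_label_distiller entry above; the small
    # distillation_loss_weight presumably keeps the soft-label term from
    # dominating the fsp and l2 losses it is combined with.
    distiller = SoftLabelDistiller(
        student_feature_map='student.tmp_1',
        teacher_feature_map='teacher.tmp_1',
        student_temperature=1.0,
        teacher_temperature=1.0,
        distillation_loss_weight=0.001)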