PaddlePaddle / PaddleDetection
Commit 73c4f2b7 (unverified)
Authored Mar 29, 2019 by whs
Committed by GitHub on Mar 29, 2019
Fix distillation for soft label. (#16538)
test=develop
Parent: 3e6aa498

Showing 2 changed files with 97 additions and 2 deletions:
python/paddle/fluid/contrib/slim/distillation/distiller.py (+89 -1)
python/paddle/fluid/contrib/slim/tests/distillation/compress.yaml (+8 -1)
python/paddle/fluid/contrib/slim/distillation/distiller.py
@@ -19,7 +19,7 @@ from .... import Program
 from .... import program_guard
 from .... import regularizer
 
-__all__ = ['FSPDistiller', 'L2Distiller']
+__all__ = ['FSPDistiller', 'L2Distiller', 'SoftLabelDistiller']
 
 
 class L2Distiller(object):
@@ -186,3 +186,91 @@ class FSPDistillerPass(object):
     def _fsp_matrix(self, fea_map_0, fea_map_1):
         return layers.fsp_matrix(fea_map_0, fea_map_1)
+
+
+class SoftLabelDistiller(object):
+    """
+    Combine two layers from the student net and the teacher net with a
+    softmax_with_cross_entropy loss, and add that loss to the total loss
+    used for distillation training.
+    """
+
+    def __init__(self,
+                 student_feature_map=None,
+                 teacher_feature_map=None,
+                 student_temperature=1.0,
+                 teacher_temperature=1.0,
+                 distillation_loss_weight=1):
+        """
+        Args:
+            student_feature_map(str): The name of the feature map from the student network.
+            teacher_feature_map(str): The name of the feature map from the teacher network.
+                                      Its shape should be the same as that of the student feature map.
+            student_temperature(float): Temperature used to divide student_feature_map before softmax_with_cross_entropy. Default: 1.0
+            teacher_temperature(float): Temperature used to divide teacher_feature_map before softmax_with_cross_entropy. Default: 1.0
+            distillation_loss_weight(float): The weight of the soft-label loss.
+        """
+        self.student_feature_map = student_feature_map
+        self.teacher_feature_map = teacher_feature_map
+        self.distillation_loss_weight = distillation_loss_weight
+        self.student_temperature = student_temperature
+        self.teacher_temperature = teacher_temperature
+
+    def distiller_loss(self, graph):
+        """
+        Modify graph inplace to add softmax_with_cross_entropy loss.
+        Args:
+            graph(GraphWrapper): The graph to be modified.
+        Returns:
+            GraphWrapper: The modified graph.
+        """
+        distiller_pass = SoftLabelDistillerPass(
+            self.student_feature_map, self.teacher_feature_map,
+            self.student_temperature, self.teacher_temperature,
+            self.distillation_loss_weight)
+        dis_graph = distiller_pass.apply(graph)
+        return dis_graph
+
+
+class SoftLabelDistillerPass(object):
+    def __init__(self,
+                 student_feature_map,
+                 teacher_feature_map,
+                 student_temperature,
+                 teacher_temperature,
+                 distillation_loss_weight=1):
+        """
+        Args:
+            student_feature_map(str): The name of the feature map from the student network.
+            teacher_feature_map(str): The name of the feature map from the teacher network.
+                                      Its shape should be the same as that of the student feature map.
+            student_temperature(float): Temperature used to divide student_feature_map before softmax_with_cross_entropy.
+            teacher_temperature(float): Temperature used to divide teacher_feature_map before softmax_with_cross_entropy.
+            distillation_loss_weight(float): The weight of the soft-label loss.
+        """
+        self.student_feature_map = student_feature_map
+        self.teacher_feature_map = teacher_feature_map
+        self.student_temperature = student_temperature
+        self.teacher_temperature = teacher_temperature
+        self.distillation_loss_weight = distillation_loss_weight
+
+    def apply(self, graph):
+        ret_graph = graph
+        with program_guard(ret_graph.program):
+
+            student_feature_map = ret_graph.var(self.student_feature_map)._var
+            teacher_feature_map = ret_graph.var(self.teacher_feature_map)._var
+            # Scale each feature map by its own temperature.
+            s_fea = student_feature_map / self.student_temperature
+            t_fea = teacher_feature_map / self.teacher_temperature
+            # The teacher output is a fixed target; no gradient flows into it.
+            t_fea.stop_gradient = True
+            ce_loss = layers.softmax_with_cross_entropy(
+                s_fea, t_fea, soft_label=True)
+            distillation_loss = ce_loss * self.distillation_loss_weight
+            student_loss = ret_graph.var(ret_graph.out_nodes['loss'])._var
+            loss = distillation_loss + student_loss
+
+            # Register the distillation term and the new total loss by name.
+            ret_graph.out_nodes[
+                'soft_label_loss_' + self.student_feature_map + "_" +
+                self.teacher_feature_map] = distillation_loss.name
+            ret_graph.out_nodes['loss'] = loss.name
+        return ret_graph
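
The arithmetic that `SoftLabelDistillerPass.apply` adds to the graph can be sketched outside the framework. This is a minimal pure-Python sketch, assuming the soft-label form of softmax cross-entropy, loss = -sum_i label_i * log(softmax(logits)_i). The helper names (`log_softmax`, `soft_label_cross_entropy`, `soft_label_distillation_loss`) and the sample numbers are hypothetical and not part of Paddle. As in the diff, the temperature-scaled teacher values are used directly as the soft label, so the example passes a teacher vector that is already a probability distribution.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_sum for x in logits]


def soft_label_cross_entropy(logits, soft_label):
    """Cross-entropy between softmax(logits) and a soft label vector."""
    return -sum(p * lp for p, lp in zip(soft_label, log_softmax(logits)))


def soft_label_distillation_loss(student,
                                 teacher,
                                 student_temperature=1.0,
                                 teacher_temperature=1.0,
                                 distillation_loss_weight=1.0):
    """Hypothetical stand-in for the loss term built in the pass."""
    # Divide each side by its own temperature, as the pass does.
    s_fea = [x / student_temperature for x in student]
    t_fea = [x / teacher_temperature for x in teacher]
    ce_loss = soft_label_cross_entropy(s_fea, t_fea)
    return ce_loss * distillation_loss_weight


# Illustrative values only.
student_logits = [2.0, 1.0, 0.1]
teacher_probs = [0.7, 0.2, 0.1]
student_loss = 0.5  # placeholder for the student's own task loss

distillation_loss = soft_label_distillation_loss(
    student_logits, teacher_probs, distillation_loss_weight=0.001)
total_loss = distillation_loss + student_loss
```

With a weight of 0.001, as in the test config, the distillation term contributes only a small correction on top of the student loss; raising the weight scales that term linearly.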
python/paddle/fluid/contrib/slim/tests/distillation/compress.yaml
@@ -33,10 +33,17 @@ distillers:
         teacher_feature_map: 'teacher.tmp_2'
         student_feature_map: 'student.tmp_2'
         distillation_loss_weight: 1
+    soft_label_distiller:
+        class: 'SoftLabelDistiller'
+        student_temperature: 1.0
+        teacher_temperature: 1.0
+        teacher_feature_map: 'teacher.tmp_1'
+        student_feature_map: 'student.tmp_1'
+        distillation_loss_weight: 0.001
 strategies:
     distillation_strategy:
         class: 'DistillationStrategy'
-        distillers: ['fsp_distiller', 'l2_distiller']
+        distillers: ['fsp_distiller', 'l2_distiller', 'soft_label_distiller']
         start_epoch: 0
         end_epoch: 1
 compressor:
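
One way to read the new `soft_label_distiller` entry is as a class name plus keyword arguments for its constructor. The sketch below assumes a loader that looks up the `class` value in a registry and passes the remaining keys as keyword arguments. The `REGISTRY` dict and `build_distillers` helper are hypothetical stand-ins and not the slim config API, and `SoftLabelDistiller` here is a stub that only stores its arguments.

```python
class SoftLabelDistiller(object):
    """Stub with the same constructor signature as the class in the diff."""

    def __init__(self,
                 student_feature_map=None,
                 teacher_feature_map=None,
                 student_temperature=1.0,
                 teacher_temperature=1.0,
                 distillation_loss_weight=1):
        self.student_feature_map = student_feature_map
        self.teacher_feature_map = teacher_feature_map
        self.student_temperature = student_temperature
        self.teacher_temperature = teacher_temperature
        self.distillation_loss_weight = distillation_loss_weight


# Hypothetical registry mapping config class names to Python classes.
REGISTRY = {'SoftLabelDistiller': SoftLabelDistiller}

# The new YAML entry, written out as the dict a YAML parser would produce.
config = {
    'distillers': {
        'soft_label_distiller': {
            'class': 'SoftLabelDistiller',
            'student_temperature': 1.0,
            'teacher_temperature': 1.0,
            'teacher_feature_map': 'teacher.tmp_1',
            'student_feature_map': 'student.tmp_1',
            'distillation_loss_weight': 0.001,
        },
    },
}


def build_distillers(config, registry):
    """Instantiate every distiller entry from its class name and kwargs."""
    built = {}
    for name, spec in config['distillers'].items():
        kwargs = dict(spec)  # copy so the config itself is not mutated
        cls = registry[kwargs.pop('class')]
        built[name] = cls(**kwargs)
    return built


distillers = build_distillers(config, REGISTRY)
```

Under this reading, every key in the YAML entry other than `class` must match a constructor parameter name, which is why the entry uses `student_temperature` and `teacher_temperature` exactly as they appear in `SoftLabelDistiller.__init__`.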