Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleSlim
提交
4528c8da
P
PaddleSlim
项目概览
PaddlePaddle
/
PaddleSlim
1 年多 前同步成功
通知
51
Star
1434
Fork
344
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
53
列表
看板
标记
里程碑
合并请求
16
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleSlim
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
53
Issue
53
列表
看板
标记
里程碑
合并请求
16
合并请求
16
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
4528c8da
编写于
6月 24, 2022
作者:
W
whs
提交者:
GitHub
6月 24, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add dkd dist loss (#1189)
上级
50db9490
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
147 addition
and
1 deletion
+147
-1
paddleslim/dist/__init__.py
paddleslim/dist/__init__.py
+1
-1
paddleslim/dist/single_distiller.py
paddleslim/dist/single_distiller.py
+75
-0
tests/test_dkd_loss.py
tests/test_dkd_loss.py
+71
-0
未找到文件。
paddleslim/dist/__init__.py
浏览文件 @
4528c8da
...
...
@@ -12,5 +12,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from
.single_distiller
import
merge
,
fsp
,
l2
,
soft_label
,
loss
from
.single_distiller
import
merge
,
fsp
,
l2
,
soft_label
,
loss
,
dkd
from
.dml
import
DML
paddleslim/dist/single_distiller.py
浏览文件 @
4528c8da
...
...
@@ -230,3 +230,78 @@ def loss(loss_func, program=None, **kwargs):
func_parameters
.
setdefault
(
item
[
0
],
item
[
1
])
loss
=
loss_func
(
**
func_parameters
)
return
loss
def
_top_mask
(
x
):
top_value
,
top_index
=
paddle
.
topk
(
x
,
1
)
return
paddle
.
cast
(
x
==
top_value
,
"int32"
)
def
_cal_tc_nc_pred
(
x
,
top_mask
):
"""Calculate the predictions of target class and non-target class.
The predictions of target class is a binary distribution.
And after removing the target class, the softmax on the remaining
parts produces the non-target predictions.
"""
pred
=
paddle
.
nn
.
functional
.
softmax
(
x
)
fp_mask
=
paddle
.
cast
(
top_mask
,
"float32"
)
top_value
=
paddle
.
sum
(
fp_mask
*
pred
,
axis
=
1
,
keepdim
=
True
)
tc_pred
=
paddle
.
concat
([
top_value
,
1
-
top_value
],
axis
=
1
)
tmp
=
paddle
.
assign
(
x
)
tmp
=
tmp
+
(
-
100000
*
top_mask
)
nc_pred
=
paddle
.
nn
.
functional
.
softmax
(
tmp
)
return
tc_pred
,
nc_pred
def
_dkd_loss
(
student_logits
,
teacher_logits
,
temperature
=
1.0
,
alpha
=
1.0
,
beta
=
1.0
):
mask
=
_top_mask
(
teacher_logits
)
print
(
f
"mask:
{
mask
.
shape
}
"
)
print
(
f
"student_logits:
{
student_logits
.
shape
}
; teacher_logits:
{
teacher_logits
.
shape
}
"
)
s_tc_pred
,
s_nc_pred
=
_cal_tc_nc_pred
(
student_logits
/
temperature
,
mask
)
t_tc_pred
,
t_nc_pred
=
_cal_tc_nc_pred
(
teacher_logits
/
temperature
,
mask
)
tc_loss
=
paddle
.
nn
.
functional
.
kl_div
(
s_tc_pred
,
t_tc_pred
,
reduction
=
'mean'
)
nc_loss
=
paddle
.
nn
.
functional
.
kl_div
(
s_nc_pred
,
t_nc_pred
,
reduction
=
'mean'
)
loss
=
alpha
*
tc_loss
+
beta
*
nc_loss
return
loss
*
temperature
**
2
def
dkd
(
teacher_var_name
,
student_var_name
,
program
=
None
,
temperature
=
1.0
,
alpha
=
1.0
,
beta
=
1.0
):
"""Combine variables from student model and teacher model
by Decoupled Knowledge Distillation loss (aka. dkd-loss).
Reference: https://github.com/megvii-research/mdistiller
Args:
teacher_var_name(str): The name of teacher_var.
student_var_name(str): The name of student_var.
program(Program): The input distiller program. If not specified,
the default program will be used. Default: None
temperature(float): Temperature used to divide
teacher_feature_map before softmax. Default: 1.0
alpha(float): The weight of target class loss. Default: 1.0
beta(float): The weight of none-target class loss. Default: 1.0
Returns:
Variable: dkd distiller loss.
"""
if
program
==
None
:
program
=
paddle
.
static
.
default_main_program
()
student_var
=
program
.
global_block
().
var
(
student_var_name
)
teacher_var
=
program
.
global_block
().
var
(
teacher_var_name
)
return
_dkd_loss
(
student_var
,
teacher_var
,
temperature
=
temperature
,
alpha
=
alpha
,
beta
=
beta
)
tests/test_dkd_loss.py
0 → 100644
浏览文件 @
4528c8da
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
sys
sys
.
path
.
append
(
"../"
)
import
unittest
import
paddle
from
paddleslim.dist
import
merge
,
dkd
from
layers
import
conv_bn_layer
from
static_case
import
StaticCase
class
TestDKDLoss
(
StaticCase
):
def
test_dkd_loss
(
self
):
input
=
paddle
.
static
.
data
(
name
=
"image"
,
shape
=
[
None
,
3
,
224
,
224
])
conv1
=
conv_bn_layer
(
input
,
8
,
3
,
"conv1"
)
conv2
=
conv_bn_layer
(
conv1
,
8
,
3
,
"conv2"
)
student_predict
=
conv1
+
conv2
student_predict
=
paddle
.
fluid
.
layers
.
fc
(
student_predict
,
size
=
10
)
teacher_main
=
paddle
.
static
.
Program
()
teacher_startup
=
paddle
.
static
.
Program
()
with
paddle
.
static
.
program_guard
(
teacher_main
,
teacher_startup
):
input
=
paddle
.
static
.
data
(
name
=
"image"
,
shape
=
[
None
,
3
,
224
,
224
])
conv1
=
conv_bn_layer
(
input
,
8
,
3
,
"conv1"
)
conv2
=
conv_bn_layer
(
conv1
,
8
,
3
,
"conv2"
)
sum1
=
conv1
+
conv2
conv3
=
conv_bn_layer
(
sum1
,
8
,
3
,
"conv3"
)
conv4
=
conv_bn_layer
(
conv3
,
8
,
3
,
"conv4"
)
sum2
=
conv4
+
sum1
conv5
=
conv_bn_layer
(
sum2
,
8
,
3
,
"conv5"
)
teacher_predict
=
conv_bn_layer
(
conv5
,
8
,
3
,
"conv6"
)
teacher_predict
=
paddle
.
fluid
.
layers
.
fc
(
teacher_predict
,
size
=
10
)
place
=
paddle
.
CPUPlace
()
data_name_map
=
{
'image'
:
'image'
}
merge
(
teacher_main
,
paddle
.
static
.
default_main_program
(),
data_name_map
,
place
)
merged_ops
=
[]
for
block
in
paddle
.
static
.
default_main_program
().
blocks
:
for
op
in
block
.
ops
:
merged_ops
.
append
(
op
.
type
)
distill_loss
=
dkd
(
"teacher_"
+
(
teacher_predict
.
name
),
student_predict
.
name
)
loss_ops
=
[]
for
block
in
paddle
.
static
.
default_main_program
().
blocks
:
for
op
in
block
.
ops
:
loss_ops
.
append
(
op
.
type
)
self
.
assertTrue
(
set
(
merged_ops
).
difference
(
set
(
loss_ops
))
==
set
())
self
.
assertTrue
(
set
(
loss_ops
).
difference
(
set
(
merged_ops
))
==
{
'kldiv_loss'
,
'assign'
,
'scale'
,
'concat'
,
'reduce_sum'
,
'equal'
,
'softmax'
,
'reduce_mean'
,
'cast'
,
'elementwise_mul'
,
'top_k_v2'
})
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录