Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleSlim
提交
a35619b8
P
PaddleSlim
项目概览
PaddlePaddle
/
PaddleSlim
1 年多 前同步成功
通知
51
Star
1434
Fork
344
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
53
列表
看板
标记
里程碑
合并请求
16
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleSlim
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
53
Issue
53
列表
看板
标记
里程碑
合并请求
16
合并请求
16
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
a35619b8
编写于
12月 27, 2022
作者:
Z
zhouzj
提交者:
GitHub
12月 27, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add skd distillation. (#1587)
* add skd distillation. * update skd's test.
上级
bddce3ea
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
137 addition
and
2 deletion
+137
-2
paddleslim/dist/__init__.py
paddleslim/dist/__init__.py
+1
-1
paddleslim/dist/single_distiller.py
paddleslim/dist/single_distiller.py
+55
-1
tests/test_skd_loss.py
tests/test_skd_loss.py
+81
-0
未找到文件。
paddleslim/dist/__init__.py
浏览文件 @
a35619b8
...
...
@@ -12,5 +12,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from
.single_distiller
import
merge
,
fsp
,
l2
,
soft_label
,
loss
,
dkd
from
.single_distiller
import
merge
,
fsp
,
l2
,
soft_label
,
loss
,
dkd
,
skd
from
.dml
import
DML
paddleslim/dist/single_distiller.py
浏览文件 @
a35619b8
...
...
@@ -15,6 +15,7 @@
import
numpy
as
np
import
paddle
from
paddleslim.core
import
GraphWrapper
import
paddle.nn.functional
as
F
def
_find_var_from_program
(
program
,
var_name
):
...
...
@@ -300,7 +301,10 @@ def soft_label(teacher_var_name,
teacher_temperature
)
soft_label_loss
=
paddle
.
mean
(
paddle
.
nn
.
functional
.
cross_entropy
(
input
=
student_var
,
label
=
teacher_var
,
soft_label
=
True
))
input
=
student_var
,
label
=
teacher_var
,
soft_label
=
True
,
use_softmax
=
False
))
return
soft_label_loss
...
...
@@ -401,3 +405,53 @@ def dkd(teacher_var_name,
temperature
=
temperature
,
alpha
=
alpha
,
beta
=
beta
)
def
skd
(
teacher_var_name
,
student_var_name
,
program
=
None
,
multiplier
=
None
):
"""Combine variables from student model and teacher model
by Spherical Knowledge Distillation loss (aka. skd-loss).
Reference: https://github.com/forjiuzhou/Spherical-Knowledge-Distillation
Args:
teacher_var_name(str): The name of teacher_var.
student_var_name(str): The name of student_var.
program(Program): The input distiller program. If not specified,
the default program will be used. Default: None
multiplier(float): The multiplier to recover its norm to the original
level. When it's None, the appropriate multiplier can be computed by
teacher's logits with paddle.std(output_t, axis=1). Default: None.
Returns:
Variable: skd distiller loss.
"""
if
program
==
None
:
program
=
paddle
.
static
.
default_main_program
()
student_var
=
program
.
global_block
().
var
(
student_var_name
)
teacher_var
=
program
.
global_block
().
var
(
teacher_var_name
)
teacher_var
.
stop_gradient
=
True
if
multiplier
is
None
:
multiplier
=
paddle
.
std
(
teacher_var
,
axis
=
1
,
keepdim
=
True
)
logits_student
=
F
.
layer_norm
(
student_var
,
student_var
.
shape
[
1
:],
weight
=
None
,
bias
=
None
,
epsilon
=
1e-7
)
*
multiplier
logits_teacher
=
F
.
layer_norm
(
teacher_var
,
teacher_var
.
shape
[
1
:],
weight
=
None
,
bias
=
None
,
epsilon
=
1e-7
)
*
multiplier
student_out
=
F
.
softmax
(
logits_student
,
axis
=
1
)
teacher_out
=
F
.
softmax
(
logits_teacher
,
axis
=
1
)
skd_loss
=
paddle
.
mean
(
F
.
cross_entropy
(
input
=
student_out
,
label
=
teacher_out
,
soft_label
=
True
,
use_softmax
=
False
))
return
skd_loss
tests/test_skd_loss.py
0 → 100644
浏览文件 @
a35619b8
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
sys
sys
.
path
.
append
(
"../"
)
import
unittest
import
paddle
from
paddleslim.dist
import
merge
,
skd
from
layers
import
conv_bn_layer
from
static_case
import
StaticCase
class
TestSKDLoss
(
StaticCase
):
def
test_skd_loss
(
self
):
place
=
paddle
.
CPUPlace
()
exe
=
paddle
.
static
.
Executor
(
place
)
student_program
=
paddle
.
static
.
Program
()
student_startup
=
paddle
.
static
.
Program
()
with
paddle
.
static
.
program_guard
(
student_program
,
student_startup
):
with
paddle
.
utils
.
unique_name
.
guard
():
input
=
paddle
.
static
.
data
(
name
=
"image"
,
shape
=
[
None
,
3
,
224
,
224
])
conv1
=
conv_bn_layer
(
input
,
8
,
3
,
"conv1"
)
conv2
=
conv_bn_layer
(
conv1
,
8
,
3
,
"conv2"
)
student_predict
=
conv1
+
conv2
teacher_program
=
paddle
.
static
.
Program
()
teacher_startup
=
paddle
.
static
.
Program
()
with
paddle
.
static
.
program_guard
(
teacher_program
,
teacher_startup
):
with
paddle
.
utils
.
unique_name
.
guard
():
input
=
paddle
.
static
.
data
(
name
=
"image"
,
shape
=
[
None
,
3
,
224
,
224
])
conv1
=
conv_bn_layer
(
input
,
8
,
3
,
"conv1"
)
conv2
=
conv_bn_layer
(
conv1
,
8
,
3
,
"conv2"
)
sum1
=
conv1
+
conv2
conv3
=
conv_bn_layer
(
sum1
,
8
,
3
,
"conv3"
)
conv4
=
conv_bn_layer
(
conv3
,
8
,
3
,
"conv4"
)
sum2
=
conv4
+
sum1
conv5
=
conv_bn_layer
(
sum2
,
8
,
3
,
"conv5"
)
teacher_predict
=
conv_bn_layer
(
conv5
,
8
,
3
,
"conv6"
)
exe
.
run
(
teacher_startup
)
exe
.
run
(
student_startup
)
data_name_map
=
{
'image'
:
'image'
}
merge
(
teacher_program
,
student_program
,
data_name_map
,
place
)
merged_ops
=
[]
for
block
in
student_program
.
blocks
:
for
op
in
block
.
ops
:
merged_ops
.
append
(
op
.
type
)
with
paddle
.
static
.
program_guard
(
student_program
,
student_startup
):
distill_loss
=
skd
(
'teacher_'
+
teacher_predict
.
name
,
student_predict
.
name
,
program
=
None
,
multiplier
=
None
)
loss_ops
=
[]
for
block
in
student_program
.
blocks
:
for
op
in
block
.
ops
:
loss_ops
.
append
(
op
.
type
)
print
(
f
"ret:
{
set
(
loss_ops
).
difference
(
set
(
merged_ops
))
}
"
)
self
.
assertTrue
(
set
(
merged_ops
).
difference
(
set
(
loss_ops
))
==
set
())
self
.
assertTrue
({
'softmax_with_cross_entropy'
,
'softmax'
,
'reduce_mean'
,
'layer_norm'
}.
issubset
(
set
(
loss_ops
).
difference
(
set
(
merged_ops
))))
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录