Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleSlim
提交
dfe5d3f7
P
PaddleSlim
项目概览
PaddlePaddle
/
PaddleSlim
1 年多 前同步成功
通知
51
Star
1434
Fork
344
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
53
列表
看板
标记
里程碑
合并请求
16
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleSlim
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
53
Issue
53
列表
看板
标记
里程碑
合并请求
16
合并请求
16
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
dfe5d3f7
编写于
7月 26, 2021
作者:
littletomatodonkey
提交者:
GitHub
7月 26, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[distill]add distillation losses (#789)
上级
c9c0e83f
变更
4
展开全部
隐藏空白更改
内联
并排
Showing
4 changed file
with
1110 addition
and
0 deletion
+1110
-0
paddleslim/dygraph/dist/losses/__init__.py
paddleslim/dygraph/dist/losses/__init__.py
+70
-0
paddleslim/dygraph/dist/losses/basic_loss.py
paddleslim/dygraph/dist/losses/basic_loss.py
+207
-0
paddleslim/dygraph/dist/losses/distillation_loss.py
paddleslim/dygraph/dist/losses/distillation_loss.py
+136
-0
tests/dygraph/test_distillation_loss.py
tests/dygraph/test_distillation_loss.py
+697
-0
未找到文件。
paddleslim/dygraph/dist/losses/__init__.py
浏览文件 @
dfe5d3f7
...
...
@@ -11,3 +11,73 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
copy
import
paddle
import
paddle.nn
as
nn
from
.
import
basic_loss
from
.
import
distillation_loss
from
.basic_loss
import
L1Loss
from
.basic_loss
import
L2Loss
from
.basic_loss
import
SmoothL1Loss
from
.basic_loss
import
CELoss
from
.basic_loss
import
DMLLoss
from
.basic_loss
import
DistanceLoss
from
.basic_loss
import
RKdAngle
,
RkdDistance
from
.distillation_loss
import
DistillationDistanceLoss
from
.distillation_loss
import
DistillationDMLLoss
from
.distillation_loss
import
DistillationRKDLoss
class
CombinedLoss
(
nn
.
Layer
):
"""
CombinedLoss: a combination of loss function.
Args:
loss_config_list: a config list used to build loss function. A demo is as follows,
which is used to calculate dml loss between Student output and
Teacher output. Parameter weight is needed for the loss weight.
- DistillationDMLLoss:
weight: 1.0
act: "softmax"
model_name_pairs:
- ["Student", "Teacher"]
"""
def
__init__
(
self
,
loss_config_list
=
None
):
super
().
__init__
()
loss_config_list
=
copy
.
deepcopy
(
loss_config_list
)
self
.
loss_func
=
[]
self
.
loss_weight
=
[]
assert
isinstance
(
loss_config_list
,
list
),
(
'operator config should be a list'
)
supported_loss_list
=
basic_loss
.
__all__
+
distillation_loss
.
__all__
for
config
in
loss_config_list
:
assert
isinstance
(
config
,
dict
)
and
len
(
config
)
==
1
,
"yaml format error"
name
=
list
(
config
)[
0
]
assert
name
in
supported_loss_list
,
\
"loss name must be in {} but got: {}"
.
format
(
name
,
supported_loss_list
)
param
=
config
[
name
]
assert
"weight"
in
param
,
"weight must be in param, but param just contains {}"
.
format
(
param
.
keys
())
self
.
loss_weight
.
append
(
param
.
pop
(
"weight"
))
self
.
loss_func
.
append
(
eval
(
name
)(
**
param
))
def
forward
(
self
,
input
,
batch
,
**
kargs
):
loss_dict
=
{}
for
idx
,
loss_func
in
enumerate
(
self
.
loss_func
):
loss
=
loss_func
(
input
,
batch
,
**
kargs
)
weight
=
self
.
loss_weight
[
idx
]
if
isinstance
(
loss
,
paddle
.
Tensor
):
loss
=
{
"loss_{}_{}"
.
format
(
str
(
loss
),
idx
):
loss
*
weight
}
else
:
loss
=
{
"{}_{}"
.
format
(
key
,
idx
):
loss
[
key
]
*
weight
for
key
in
loss
}
loss_dict
.
update
(
loss
)
loss_dict
[
"loss"
]
=
paddle
.
add_n
(
list
(
loss_dict
.
values
()))
return
loss_dict
paddleslim/dygraph/dist/losses/basic_loss.py
0 → 100644
浏览文件 @
dfe5d3f7
#copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import
paddle
import
paddle.nn
as
nn
import
paddle.nn.functional
as
F
from
paddle.nn
import
L1Loss
from
paddle.nn
import
MSELoss
as
L2Loss
from
paddle.nn
import
SmoothL1Loss
__all__
=
[
"CELoss"
,
"DMLLoss"
,
"DistanceLoss"
,
"RKdAngle"
,
"RkdDistance"
,
]
class
CELoss
(
nn
.
Layer
):
"""
CELoss: cross entropy loss
Args:
epsilon(float | None): label smooth epsilon. If it is None or not in range (0, 1),
then label smooth will not be used.
label_act(string | None): activation function, it works when the label is also the logits
rather than the groundtruth label.
axis(int): axis used to calculate cross entropy loss.
"""
def
__init__
(
self
,
epsilon
=
None
,
label_act
=
"softmax"
,
axis
=-
1
):
super
().
__init__
()
if
epsilon
is
not
None
and
(
epsilon
<=
0
or
epsilon
>=
1
):
epsilon
=
None
assert
label_act
in
[
"softmax"
,
None
]
if
epsilon
is
not
None
and
(
epsilon
>=
1
or
epsilon
<=
0
):
epsilon
=
None
self
.
epsilon
=
epsilon
self
.
label_act
=
label_act
self
.
axis
=
axis
def
_labelsmoothing
(
self
,
target
,
class_num
):
if
target
.
shape
[
-
1
]
!=
class_num
:
one_hot_target
=
F
.
one_hot
(
target
,
class_num
)
else
:
one_hot_target
=
target
soft_target
=
F
.
label_smooth
(
one_hot_target
,
epsilon
=
self
.
epsilon
)
soft_target
=
paddle
.
reshape
(
soft_target
,
shape
=
[
-
1
,
class_num
])
return
soft_target
def
forward
(
self
,
x
,
label
):
assert
len
(
x
.
shape
)
==
len
(
label
.
shape
),
\
"x and label shape length should be same but got {} for x.shape and {} for label.shape"
.
format
(
x
.
shape
,
label
.
shape
)
if
self
.
epsilon
is
not
None
:
class_num
=
x
.
shape
[
-
1
]
label
=
self
.
_labelsmoothing
(
label
,
class_num
)
x
=
-
F
.
log_softmax
(
x
,
axis
=
self
.
axis
)
loss
=
paddle
.
sum
(
x
*
label
,
axis
=
self
.
axis
)
else
:
if
label
.
shape
[
self
.
axis
]
==
x
.
shape
[
self
.
axis
]:
if
self
.
label_act
==
"softmax"
:
label
=
F
.
softmax
(
label
,
axis
=
self
.
axis
)
soft_label
=
True
else
:
soft_label
=
False
loss
=
F
.
cross_entropy
(
x
,
label
=
label
,
soft_label
=
soft_label
,
axis
=
self
.
axis
)
loss
=
loss
.
mean
()
return
loss
class
DMLLoss
(
nn
.
Layer
):
"""
DMLLoss
Args:
act(string | None): activation function used to activate the input tensor
axis(int): axis used to build activation function
"""
def
__init__
(
self
,
act
=
None
,
axis
=-
1
):
super
().
__init__
()
if
act
is
not
None
:
assert
act
in
[
"softmax"
,
"sigmoid"
]
if
act
==
"softmax"
:
self
.
act
=
nn
.
Softmax
(
axis
=
axis
)
elif
act
==
"sigmoid"
:
self
.
act
=
nn
.
Sigmoid
()
else
:
self
.
act
=
None
def
forward
(
self
,
out1
,
out2
):
if
self
.
act
is
not
None
:
out1
=
self
.
act
(
out1
)
out2
=
self
.
act
(
out2
)
log_out1
=
paddle
.
log
(
out1
)
log_out2
=
paddle
.
log
(
out2
)
loss
=
(
F
.
kl_div
(
log_out1
,
out2
,
reduction
=
'batchmean'
)
+
F
.
kl_div
(
log_out2
,
out1
,
reduction
=
'batchmean'
))
/
2.0
return
loss
class
DistanceLoss
(
nn
.
Layer
):
"""
DistanceLoss
Args:
mode: loss mode
kargs(dict): used to build corresponding loss function, for more details, please
refer to:
L1loss: https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/nn/L1Loss_cn.html#l1loss
L2Loss: https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/nn/MSELoss_cn.html#mseloss
SmoothL1Loss: https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/nn/SmoothL1Loss_cn.html#smoothl1loss
"""
def
__init__
(
self
,
mode
=
"l2"
,
**
kargs
):
super
().
__init__
()
assert
mode
in
[
"l1"
,
"l2"
,
"smooth_l1"
]
if
mode
==
"l1"
:
self
.
loss_func
=
nn
.
L1Loss
(
**
kargs
)
elif
mode
==
"l2"
:
self
.
loss_func
=
nn
.
MSELoss
(
**
kargs
)
elif
mode
==
"smooth_l1"
:
self
.
loss_func
=
nn
.
SmoothL1Loss
(
**
kargs
)
def
forward
(
self
,
x
,
y
):
return
self
.
loss_func
(
x
,
y
)
def
pdist
(
e
,
squared
=
False
,
eps
=
1e-12
):
e_square
=
e
.
pow
(
2
).
sum
(
axis
=
1
)
prod
=
paddle
.
mm
(
e
,
e
.
t
())
res
=
(
e_square
.
unsqueeze
(
1
)
+
e_square
.
unsqueeze
(
0
)
-
2
*
prod
).
clip
(
min
=
eps
)
if
not
squared
:
res
=
res
.
sqrt
()
return
res
class
RKdAngle
(
nn
.
Layer
):
"""
RKdAngle loss, see https://arxiv.org/abs/1904.05068
"""
def
__init__
(
self
):
super
().
__init__
()
def
forward
(
self
,
student
,
teacher
):
# reshape for feature map distillation
bs
=
student
.
shape
[
0
]
student
=
student
.
reshape
([
bs
,
-
1
])
teacher
=
teacher
.
reshape
([
bs
,
-
1
])
td
=
(
teacher
.
unsqueeze
(
0
)
-
teacher
.
unsqueeze
(
1
))
norm_td
=
F
.
normalize
(
td
,
p
=
2
,
axis
=
2
)
t_angle
=
paddle
.
bmm
(
norm_td
,
norm_td
.
transpose
([
0
,
2
,
1
])).
reshape
(
[
-
1
,
1
])
sd
=
(
student
.
unsqueeze
(
0
)
-
student
.
unsqueeze
(
1
))
norm_sd
=
F
.
normalize
(
sd
,
p
=
2
,
axis
=
2
)
s_angle
=
paddle
.
bmm
(
norm_sd
,
norm_sd
.
transpose
([
0
,
2
,
1
])).
reshape
(
[
-
1
,
1
])
loss
=
F
.
smooth_l1_loss
(
s_angle
,
t_angle
,
reduction
=
'mean'
)
return
loss
class
RkdDistance
(
nn
.
Layer
):
"""
RkdDistance loss, see https://arxiv.org/abs/1904.05068
Args:
eps(float): epsilon for the pdist function
"""
def
__init__
(
self
,
eps
=
1e-12
):
super
().
__init__
()
self
.
eps
=
eps
def
forward
(
self
,
student
,
teacher
):
bs
=
student
.
shape
[
0
]
student
=
student
.
reshape
([
bs
,
-
1
])
teacher
=
teacher
.
reshape
([
bs
,
-
1
])
t_d
=
pdist
(
teacher
,
squared
=
False
,
eps
=
self
.
eps
)
mean_td
=
t_d
.
mean
()
t_d
=
t_d
/
(
mean_td
+
self
.
eps
)
d
=
pdist
(
student
,
squared
=
False
,
eps
=
self
.
eps
)
mean_d
=
d
.
mean
()
d
=
d
/
(
mean_d
+
self
.
eps
)
loss
=
F
.
smooth_l1_loss
(
d
,
t_d
,
reduction
=
"mean"
)
return
loss
paddleslim/dygraph/dist/losses/distillation_loss.py
0 → 100644
浏览文件 @
dfe5d3f7
#copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import
paddle
import
paddle.nn
as
nn
from
.basic_loss
import
DMLLoss
from
.basic_loss
import
DistanceLoss
from
.basic_loss
import
RkdDistance
from
.basic_loss
import
RKdAngle
__all__
=
[
"DistillationDMLLoss"
,
"DistillationDistanceLoss"
,
"DistillationRKDLoss"
,
]
class
DistillationDMLLoss
(
DMLLoss
):
"""
DistillationDMLLoss
Args:
model_name_pairs(list | tuple): model name pairs to extract submodel output.
act(string | None): activation function used to build dml loss.
axis(int): axis used to build activation function.
key(string | None): key of the tensor used to calculate loss if the submodel
output type is dict.
name(string): loss name.
"""
def
__init__
(
self
,
model_name_pairs
=
[],
act
=
None
,
key
=
None
,
name
=
"loss_dml"
):
super
().
__init__
(
act
=
act
)
assert
isinstance
(
model_name_pairs
,
list
)
self
.
key
=
key
self
.
model_name_pairs
=
model_name_pairs
self
.
name
=
name
def
forward
(
self
,
predicts
,
batch
):
loss_dict
=
dict
()
for
idx
,
pair
in
enumerate
(
self
.
model_name_pairs
):
out1
=
predicts
[
pair
[
0
]]
out2
=
predicts
[
pair
[
1
]]
if
self
.
key
is
not
None
:
out1
=
out1
[
self
.
key
]
out2
=
out2
[
self
.
key
]
loss_dict
[
"{}_{}_{}_{}"
.
format
(
self
.
name
,
pair
[
0
],
pair
[
1
],
idx
)]
=
super
().
forward
(
out1
,
out2
)
return
loss_dict
class
DistillationDistanceLoss
(
DistanceLoss
):
"""
DistillationDistanceLoss
Args:
mode: loss mode
model_name_pairs(list | tuple): model name pairs to extract submodel output.
key(string | None): key of the tensor used to calculate loss if the submodel.
name(string): loss name.
kargs(dict): used to build corresponding loss function.
"""
def
__init__
(
self
,
mode
=
"l2"
,
model_name_pairs
=
[],
key
=
None
,
name
=
"loss_distance"
,
**
kargs
):
super
().
__init__
(
mode
=
mode
,
**
kargs
)
assert
isinstance
(
model_name_pairs
,
list
)
self
.
key
=
key
self
.
model_name_pairs
=
model_name_pairs
self
.
name
=
name
+
"_"
+
mode
def
forward
(
self
,
predicts
,
batch
):
loss_dict
=
dict
()
for
idx
,
pair
in
enumerate
(
self
.
model_name_pairs
):
out1
=
predicts
[
pair
[
0
]]
out2
=
predicts
[
pair
[
1
]]
if
self
.
key
is
not
None
:
out1
=
out1
[
self
.
key
]
out2
=
out2
[
self
.
key
]
loss
=
super
().
forward
(
out1
,
out2
)
loss_dict
[
"{}_{}_{}_{}"
.
format
(
self
.
name
,
pair
[
0
],
pair
[
1
],
idx
)]
=
loss
return
loss_dict
class
DistillationRKDLoss
(
nn
.
Layer
):
"""
DistillationRKDLoss
Args:
model_name_pairs(list | tuple): model name pairs to extract submodel output.
key(string | None): key of the tensor used to calculate loss if the submodel.
eps(float): epsilon for the pdist function for RkdDistance loss.
name(string): loss name.
"""
def
__init__
(
self
,
model_name_pairs
=
[],
key
=
None
,
eps
=
1e-12
,
name
=
"loss_rkd"
):
super
().
__init__
()
self
.
model_name_pairs
=
model_name_pairs
self
.
key
=
key
self
.
rkd_angle_loss_func
=
RKdAngle
()
self
.
rkd_dist_func
=
RkdDistance
(
eps
=
eps
)
self
.
name
=
name
def
forward
(
self
,
predicts
,
batch
):
loss_dict
=
dict
()
for
idx
,
pair
in
enumerate
(
self
.
model_name_pairs
):
out1
=
predicts
[
pair
[
0
]]
out2
=
predicts
[
pair
[
1
]]
if
self
.
key
is
not
None
:
out1
=
out1
[
self
.
key
]
out2
=
out2
[
self
.
key
]
loss_dict
[
"{}_{}_{}_angle_{}"
.
format
(
self
.
name
,
pair
[
0
],
pair
[
1
],
idx
)]
=
self
.
rkd_angle_loss_func
(
out1
,
out2
)
loss_dict
[
"{}_{}_{}_dist_{}"
.
format
(
self
.
name
,
pair
[
0
],
pair
[
1
],
idx
)]
=
self
.
rkd_dist_func
(
out1
,
out2
)
return
loss_dict
tests/dygraph/test_distillation_loss.py
0 → 100644
浏览文件 @
dfe5d3f7
此差异已折叠。
点击以展开。
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录