Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleSlim
提交
15ef0c7d
P
PaddleSlim
项目概览
PaddlePaddle
/
PaddleSlim
1 年多 前同步成功
通知
51
Star
1434
Fork
344
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
53
列表
看板
标记
里程碑
合并请求
16
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleSlim
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
53
Issue
53
列表
看板
标记
里程碑
合并请求
16
合并请求
16
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
15ef0c7d
编写于
8月 17, 2021
作者:
C
cc
提交者:
GitHub
8月 17, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Refine distillation for Segmentation (#879)
* refine distillation * up * add test * fix coverage * fix unit test error
上级
30fd1248
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
314 addition
and
51 deletion
+314
-51
paddleslim/dygraph/dist/distill.py
paddleslim/dygraph/dist/distill.py
+82
-45
paddleslim/dygraph/dist/losses/__init__.py
paddleslim/dygraph/dist/losses/__init__.py
+7
-1
paddleslim/dygraph/dist/losses/basic_loss.py
paddleslim/dygraph/dist/losses/basic_loss.py
+44
-5
paddleslim/dygraph/dist/losses/distillation_loss.py
paddleslim/dygraph/dist/losses/distillation_loss.py
+88
-0
tests/dygraph/test_distillation_loss.py
tests/dygraph/test_distillation_loss.py
+93
-0
未找到文件。
paddleslim/dygraph/dist/distill.py
浏览文件 @
15ef0c7d
...
...
@@ -16,7 +16,7 @@ import numpy as np
import
collections
from
collections
import
namedtuple
import
paddle.nn
as
nn
from
.
losses
import
*
from
.
import
losses
__all__
=
[
'Distill'
,
'AdaptorBase'
]
...
...
@@ -39,6 +39,8 @@ class LayerConfig:
self
.
loss_function
=
'DistillationDMLLoss'
elif
loss_function
in
[
'rkl'
]:
self
.
loss_function
=
'DistillationRKDLoss'
elif
hasattr
(
losses
,
loss_function
):
self
.
loss_function
=
loss_function
else
:
raise
NotImplementedError
(
"loss function is not support!!!"
)
self
.
weight
=
weight
...
...
@@ -59,11 +61,12 @@ class AdaptorBase:
def
_add_distill_hook
(
self
,
outs
,
mapping_layers_name
,
layers_type
):
"""
Get output by name.
Get output by
layer
name.
outs(dict): save the middle outputs of model according to the name.
mapping_layers(list): name of middle layers.
layers_type(list): type of the middle layers to calculate distill loss.
"""
### TODO: support DP model
for
idx
,
(
n
,
m
)
in
enumerate
(
self
.
model
.
named_sublayers
()):
if
n
in
mapping_layers_name
:
...
...
@@ -80,6 +83,8 @@ class Distill(nn.Layer):
def
__init__
(
self
,
distill_configs
,
student_models
,
teacher_models
,
adaptors_S
,
adaptors_T
):
super
(
Distill
,
self
).
__init__
()
assert
student_models
.
training
,
"The student model should be eval mode."
self
.
_distill_configs
=
distill_configs
self
.
_student_models
=
student_models
self
.
_teacher_models
=
teacher_models
...
...
@@ -93,6 +98,7 @@ class Distill(nn.Layer):
self
.
configs
.
append
(
LayerConfig
(
**
c
).
__dict__
)
self
.
distill_idx
=
self
.
_get_distill_idx
()
self
.
_loss_config_list
=
[]
for
c
in
self
.
configs
:
loss_config
=
{}
...
...
@@ -105,24 +111,42 @@ class Distill(nn.Layer):
loss_config
[
str
(
c
[
'loss_function'
])][
'model_name_pairs'
]
=
[[
'student'
,
'teacher'
]]
self
.
_loss_config_list
.
append
(
loss_config
)
self
.
_prepare_loss
()
# use self._loss_config_list to create all loss object
self
.
distill_loss
=
losses
.
CombinedLoss
(
self
.
_loss_config_list
)
def
_prepare_outputs
(
self
):
"""
Add hook to get the output tensor of target layer.
Returns:
stu_outs_dict(dict): the name and tensor for the student model,
such as {'hidden_0': tensor_0, ..}
tea_outs_dict(dict): the name and tensor for the teather model,
such as {'hidden_0': tensor_0, ..}
"""
stu_outs_dict
=
collections
.
OrderedDict
()
tea_outs_dict
=
collections
.
OrderedDict
()
stu_outs_dict
=
self
.
_prepare_hook
(
self
.
_adaptors_S
,
stu_outs_dict
)
tea_outs_dict
=
self
.
_prepare_hook
(
self
.
_adaptors_T
,
tea_outs_dict
)
return
stu_outs_dict
,
tea_outs_dict
def
_prepare_hook
(
self
,
adaptors
,
outs_dict
):
"""
Add hook.
"""
mapping_layers
=
adaptors
.
mapping_layers
()
for
layer_type
,
layer
in
mapping_layers
.
items
():
if
isinstance
(
layer
,
str
):
adaptors
.
_add_distill_hook
(
outs_dict
,
[
layer
],
[
layer_type
])
return
outs_dict
def
_get_model_intermediate_output
(
self
,
adaptors
,
outs_dict
):
mapping_layers
=
adaptors
.
mapping_layers
()
for
layer_type
,
layer
in
mapping_layers
.
items
():
if
isinstance
(
layer
,
str
):
continue
outs_dict
[
layer_type
]
=
layer
return
outs_dict
def
_get_distill_idx
(
self
):
"""
For each feature_type, get the feature index in the student and teacher models.
Returns:
distill_idx(dict): the feature index for each feature_type,
such as {'hidden': [[0, 0], [1, 1]], 'out': [[0, 0]]}
"""
distill_idx
=
{}
for
config
in
self
.
_distill_configs
:
if
config
[
'feature_type'
]
not
in
distill_idx
:
...
...
@@ -135,42 +159,13 @@ class Distill(nn.Layer):
])
return
distill_idx
def
_prepare_loss
(
self
):
self
.
distill_loss
=
CombinedLoss
(
self
.
_loss_config_list
)
def
_prepare_outputs
(
self
):
stu_outs_dict
=
collections
.
OrderedDict
()
tea_outs_dict
=
collections
.
OrderedDict
()
stu_outs_dict
=
self
.
_prepare_hook
(
self
.
_adaptors_S
,
stu_outs_dict
)
tea_outs_dict
=
self
.
_prepare_hook
(
self
.
_adaptors_T
,
tea_outs_dict
)
return
stu_outs_dict
,
tea_outs_dict
def
_post_outputs
(
self
):
final_keys
=
[]
for
key
,
value
in
self
.
stu_outs_dict
.
items
():
if
len
(
key
.
split
(
'_'
))
==
1
:
final_keys
.
append
(
key
)
### TODO: support list of student models and teacher_models
final_distill_dict
=
{
"student"
:
collections
.
OrderedDict
(),
"teacher"
:
collections
.
OrderedDict
()
}
for
feature_type
,
dist_idx
in
self
.
distill_idx
.
items
():
for
idx
,
idx_list
in
enumerate
(
dist_idx
):
sidx
,
tidx
=
idx_list
[
0
],
idx_list
[
1
]
final_distill_dict
[
'student'
][
feature_type
+
'_'
+
str
(
sidx
)
+
'_'
+
str
(
tidx
)]
=
self
.
stu_outs_dict
[
feature_type
+
'_'
+
str
(
sidx
)]
final_distill_dict
[
'teacher'
][
feature_type
+
'_'
+
str
(
sidx
)
+
'_'
+
str
(
tidx
)]
=
self
.
tea_outs_dict
[
feature_type
+
'_'
+
str
(
tidx
)]
return
final_distill_dict
def
forward
(
self
,
*
inputs
,
**
kwargs
):
stu_batch_outs
=
self
.
_student_models
.
forward
(
*
inputs
,
**
kwargs
)
tea_batch_outs
=
self
.
_teacher_models
.
forward
(
*
inputs
,
**
kwargs
)
if
not
self
.
_teacher_models
.
training
:
tea_batch_outs
=
[
i
.
detach
()
for
i
in
tea_batch_outs
]
# get all target tensor
if
self
.
_adaptors_S
.
add_tensor
==
False
:
self
.
_adaptors_S
.
add_tensor
=
True
if
self
.
_adaptors_T
.
add_tensor
==
False
:
...
...
@@ -179,8 +174,50 @@ class Distill(nn.Layer):
self
.
_adaptors_S
,
self
.
stu_outs_dict
)
self
.
tea_outs_dict
=
self
.
_get_model_intermediate_output
(
self
.
_adaptors_T
,
self
.
tea_outs_dict
)
distill_inputs
=
self
.
_post_outputs
()
distill_inputs
=
self
.
_process_outputs
()
### batch is None just for now
distill_outputs
=
self
.
distill_loss
(
distill_inputs
,
None
)
distill_loss
=
distill_outputs
[
'loss'
]
return
stu_batch_outs
,
tea_batch_outs
,
distill_loss
def
_get_model_intermediate_output
(
self
,
adaptors
,
outs_dict
):
"""
Use the adaptor get the target tensor.
Returns:
outs_dict(dict): the name and tensor for the target model,
such as {'hidden_0': tensor_0, ..}
"""
mapping_layers
=
adaptors
.
mapping_layers
()
for
layer_type
,
layer
in
mapping_layers
.
items
():
if
isinstance
(
layer
,
str
):
continue
outs_dict
[
layer_type
]
=
layer
return
outs_dict
def
_process_outputs
(
self
):
"""
Process the target tensor to adapt for loss.
"""
### TODO: support list of student models and teacher_models
final_distill_dict
=
{
"student"
:
collections
.
OrderedDict
(),
"teacher"
:
collections
.
OrderedDict
()
}
for
feature_type
,
dist_idx
in
self
.
distill_idx
.
items
():
for
idx
,
idx_list
in
enumerate
(
dist_idx
):
sidx
,
tidx
=
idx_list
[
0
],
idx_list
[
1
]
stu_out
=
self
.
stu_outs_dict
[
feature_type
+
'_'
+
str
(
sidx
)]
tea_out
=
self
.
tea_outs_dict
[
feature_type
+
'_'
+
str
(
tidx
)]
if
not
self
.
_student_models
.
training
:
stu_out
=
stu_out
.
detach
()
if
not
self
.
_teacher_models
.
training
:
tea_out
=
tea_out
.
detach
()
name_str
=
feature_type
+
'_'
+
str
(
sidx
)
+
'_'
+
str
(
tidx
)
final_distill_dict
[
'student'
][
name_str
]
=
stu_out
final_distill_dict
[
'teacher'
][
name_str
]
=
tea_out
return
final_distill_dict
paddleslim/dygraph/dist/losses/__init__.py
浏览文件 @
15ef0c7d
...
...
@@ -30,6 +30,7 @@ from .basic_loss import RKdAngle, RkdDistance
from
.distillation_loss
import
DistillationDistanceLoss
from
.distillation_loss
import
DistillationDMLLoss
from
.distillation_loss
import
DistillationRKDLoss
from
.distillation_loss
import
SegPairWiseLoss
,
SegChannelwiseLoss
class
CombinedLoss
(
nn
.
Layer
):
...
...
@@ -44,6 +45,8 @@ class CombinedLoss(nn.Layer):
act: "softmax"
model_name_pairs:
- ["Student", "Teacher"]
Another example is {'DistillationDistanceLoss': {'weight': 1.0,
'key': 'hidden_0_0', 'model_name_pairs': [['student', 'teacher']]}
"""
def
__init__
(
self
,
loss_config_list
=
None
):
...
...
@@ -79,5 +82,8 @@ class CombinedLoss(nn.Layer):
for
key
in
loss
}
loss_dict
.
update
(
loss
)
loss_dict
[
"loss"
]
=
paddle
.
add_n
(
list
(
loss_dict
.
values
()))
if
loss_dict
==
{}:
loss_dict
[
"loss"
]
=
paddle
.
to_tensor
(
0.
)
else
:
loss_dict
[
"loss"
]
=
paddle
.
add_n
(
list
(
loss_dict
.
values
()))
return
loss_dict
paddleslim/dygraph/dist/losses/basic_loss.py
浏览文件 @
15ef0c7d
...
...
@@ -21,11 +21,7 @@ from paddle.nn import MSELoss as L2Loss
from
paddle.nn
import
SmoothL1Loss
__all__
=
[
"CELoss"
,
"DMLLoss"
,
"DistanceLoss"
,
"RKdAngle"
,
"RkdDistance"
,
"CELoss"
,
"DMLLoss"
,
"DistanceLoss"
,
"RKdAngle"
,
"RkdDistance"
,
"KLLoss"
]
...
...
@@ -114,6 +110,49 @@ class DMLLoss(nn.Layer):
return
loss
class
KLLoss
(
nn
.
Layer
):
"""
KLLoss.
Args:
act(string | None): activation function used for the input and label tensor.
It supports None, softmax and sigmoid. Default: softmax.
axis(int): the axis for the act. Default: -1.
reduction(str): the reduction params for F.kl_div. Default: mean.
"""
def
__init__
(
self
,
act
=
'softmax'
,
axis
=-
1
,
reduction
=
'mean'
):
super
().
__init__
()
assert
act
in
[
'softmax'
,
'sigmoid'
,
None
]
self
.
reduction
=
reduction
if
act
==
'softmax'
:
self
.
act
=
nn
.
Softmax
(
axis
=
axis
)
elif
act
==
'sigmoid'
:
self
.
act
=
nn
.
Sigmoid
()
else
:
self
.
act
=
None
def
forward
(
self
,
input
,
label
):
"""
Args:
input(Tensor): The input tensor.
label(Tensor): The label tensor. The shape of label is the same as input.
Returns:
Tensor: The kl loss.
"""
assert
input
.
shape
==
label
.
shape
,
\
"The shape of label should be the same as input."
if
self
.
act
is
not
None
:
input
=
self
.
act
(
input
)
label
=
self
.
act
(
label
)
log_input
=
paddle
.
log
(
input
)
loss
=
F
.
kl_div
(
log_input
,
label
,
reduction
=
self
.
reduction
)
return
loss
class
DistanceLoss
(
nn
.
Layer
):
"""
DistanceLoss
...
...
paddleslim/dygraph/dist/losses/distillation_loss.py
浏览文件 @
15ef0c7d
...
...
@@ -19,11 +19,14 @@ from .basic_loss import DMLLoss
from
.basic_loss
import
DistanceLoss
from
.basic_loss
import
RkdDistance
from
.basic_loss
import
RKdAngle
from
.basic_loss
import
KLLoss
__all__
=
[
"DistillationDMLLoss"
,
"DistillationDistanceLoss"
,
"DistillationRKDLoss"
,
"SegPairWiseLoss"
,
"SegChannelwiseLoss"
,
]
...
...
@@ -66,7 +69,9 @@ class DistillationDistanceLoss(DistanceLoss):
Args:
mode: loss mode
model_name_pairs(list | tuple): model name pairs to extract submodel output.
such as [['student', 'teacher']]
key(string | None): key of the tensor used to calculate loss if the submodel.
such as 'hidden_0_0'
name(string): loss name.
kargs(dict): used to build corresponding loss function.
"""
...
...
@@ -134,3 +139,86 @@ class DistillationRKDLoss(nn.Layer):
loss_dict
[
"{}_{}_{}_dist_{}"
.
format
(
self
.
name
,
pair
[
0
],
pair
[
1
],
idx
)]
=
self
.
rkd_dist_func
(
out1
,
out2
)
return
loss_dict
class
SegPairWiseLoss
(
DistanceLoss
):
"""
Segmentation pairwise loss, see https://arxiv.org/pdf/1903.04197.pdf
Args:
model_name_pairs(list | tuple): model name pairs to extract submodel output.
key(string): key of the tensor used to calculate loss if the submodel
output type is dict.
mode(string, optional): loss mode. It supports l1, l2 and smooth_l1. Default: l2.
reduction(string, optional): the reduction params for F.kl_div. Default: mean.
name(string, optional): loss name. Default: seg_pair_wise_loss.
"""
def
__init__
(
self
,
model_name_pairs
=
[],
key
=
None
,
mode
=
"l2"
,
reduction
=
"mean"
,
name
=
"seg_pair_wise_loss"
):
super
().
__init__
(
mode
=
mode
,
reduction
=
reduction
)
assert
isinstance
(
model_name_pairs
,
list
)
assert
key
is
not
None
self
.
key
=
key
self
.
model_name_pairs
=
model_name_pairs
self
.
name
=
name
self
.
pool1
=
nn
.
AdaptiveAvgPool2D
(
output_size
=
[
2
,
2
])
self
.
pool2
=
nn
.
AdaptiveAvgPool2D
(
output_size
=
[
2
,
2
])
def
forward
(
self
,
predicts
,
batch
):
loss_dict
=
dict
()
for
idx
,
pair
in
enumerate
(
self
.
model_name_pairs
):
out1
=
predicts
[
pair
[
0
]][
self
.
key
]
out2
=
predicts
[
pair
[
1
]][
self
.
key
]
pool1
=
self
.
pool1
(
out1
)
pool2
=
self
.
pool2
(
out2
)
loss_name
=
"{}_{}_{}_{}"
.
format
(
self
.
name
,
pair
[
0
],
pair
[
1
],
idx
)
loss_dict
[
loss_name
]
=
super
().
forward
(
pool1
,
pool2
)
return
loss_dict
class
SegChannelwiseLoss
(
KLLoss
):
"""
Segmentation channel wise loss, see `Channel-wise Distillation for Semantic Segmentation`.
Args:
model_name_pairs(list | tuple): model name pairs to extract submodel output.
key(string): key of the tensor used to calculate loss if the submodel
output type is dict.
act(string, optional): activation function used for the input and label tensor.
Default: softmax.
axis(int, optional): the axis for the act. Default: -1.
reduction(str, optional): the reduction params for F.kl_div. Default: mean.
name(string, optional): loss name. Default: seg_ch_wise_loss.
"""
def
__init__
(
self
,
model_name_pairs
=
[],
key
=
None
,
act
=
'softmax'
,
axis
=-
1
,
reduction
=
"mean"
,
name
=
"seg_ch_wise_loss"
):
super
().
__init__
(
act
,
axis
,
reduction
)
assert
isinstance
(
model_name_pairs
,
list
)
assert
key
is
not
None
self
.
model_name_pairs
=
model_name_pairs
self
.
key
=
key
self
.
name
=
name
def
forward
(
self
,
predicts
,
batch
):
loss_dict
=
dict
()
for
idx
,
pair
in
enumerate
(
self
.
model_name_pairs
):
out1
=
predicts
[
pair
[
0
]][
self
.
key
]
out2
=
predicts
[
pair
[
1
]][
self
.
key
]
loss_name
=
"{}_{}_{}_{}"
.
format
(
self
.
name
,
pair
[
0
],
pair
[
1
],
idx
)
loss_dict
[
loss_name
]
=
super
().
forward
(
out1
,
out2
)
return
loss_dict
tests/dygraph/test_distillation_loss.py
浏览文件 @
15ef0c7d
...
...
@@ -18,6 +18,7 @@ import copy
import
unittest
import
paddle
import
paddle.nn.functional
as
F
# basic loss
from
paddleslim.dygraph.dist.losses
import
CombinedLoss
...
...
@@ -33,6 +34,8 @@ from paddleslim.dygraph.dist.losses import RKdAngle
from
paddleslim.dygraph.dist.losses
import
DistillationDistanceLoss
from
paddleslim.dygraph.dist.losses
import
DistillationRKDLoss
from
paddleslim.dygraph.dist.losses
import
DistillationDMLLoss
from
paddleslim.dygraph.dist.losses
import
SegPairWiseLoss
from
paddleslim.dygraph.dist.losses
import
SegChannelwiseLoss
import
numpy
as
np
...
...
@@ -693,5 +696,95 @@ class TestCombinedLoss(unittest.TestCase):
self
.
assertTrue
(
np
.
allclose
(
np_result
,
pd_result
))
class
TestSegPairWiseLoss
(
unittest
.
TestCase
):
def
calculate_gt_loss
(
self
,
x
,
y
):
pool_x
=
F
.
adaptive_avg_pool2d
(
x
,
[
2
,
2
])
pool_y
=
F
.
adaptive_avg_pool2d
(
y
,
[
2
,
2
])
loss
=
F
.
mse_loss
(
pool_x
,
pool_y
)
return
loss
def
test_seg_pair_wise_loss
(
self
):
shape
=
[
1
,
3
,
10
,
10
]
x
=
paddle
.
rand
(
shape
)
y
=
paddle
.
rand
(
shape
)
model_name_pairs
=
[[
'student'
,
'teacher'
]]
key
=
'hidden_0_0'
inputs
=
{
model_name_pairs
[
0
][
0
]:
{
key
:
x
},
model_name_pairs
[
0
][
1
]:
{
key
:
y
}
}
devices
=
[
"cpu"
]
if
paddle
.
is_compiled_with_cuda
():
devices
.
append
(
"gpu"
)
for
device
in
devices
:
paddle
.
set_device
(
device
)
loss_func
=
SegPairWiseLoss
(
model_name_pairs
,
key
)
pd_loss_dict
=
loss_func
(
inputs
,
None
)
pd_loss
=
pd_loss_dict
[
'seg_pair_wise_loss_student_teacher_0'
]
gt_loss
=
self
.
calculate_gt_loss
(
x
,
y
)
self
.
assertTrue
(
np
.
allclose
(
pd_loss
.
numpy
(),
gt_loss
.
numpy
()))
class
TestSegChannelWiseLoss
(
unittest
.
TestCase
):
def
init
(
self
):
self
.
act_name
=
None
self
.
act_func
=
None
def
calculate_gt_loss
(
self
,
x
,
y
,
act
=
None
):
if
act
is
not
None
:
x
=
act
(
x
)
y
=
act
(
y
)
x
=
paddle
.
log
(
x
)
loss
=
F
.
kl_div
(
x
,
y
)
return
loss
def
test_seg_pair_wise_loss
(
self
):
self
.
init
()
shape
=
[
1
,
3
,
10
,
10
]
x
=
paddle
.
rand
(
shape
)
y
=
paddle
.
rand
(
shape
)
model_name_pairs
=
[[
'student'
,
'teacher'
]]
key
=
'hidden_0_0'
inputs
=
{
model_name_pairs
[
0
][
0
]:
{
key
:
x
},
model_name_pairs
[
0
][
1
]:
{
key
:
y
}
}
devices
=
[
"cpu"
]
if
paddle
.
is_compiled_with_cuda
():
devices
.
append
(
"gpu"
)
for
device
in
devices
:
paddle
.
set_device
(
device
)
loss_func
=
SegChannelwiseLoss
(
model_name_pairs
,
key
,
self
.
act_name
)
pd_loss_dict
=
loss_func
(
inputs
,
None
)
pd_loss
=
pd_loss_dict
[
'seg_ch_wise_loss_student_teacher_0'
]
gt_loss
=
self
.
calculate_gt_loss
(
x
,
y
,
self
.
act_func
)
self
.
assertTrue
(
np
.
allclose
(
pd_loss
.
numpy
(),
gt_loss
.
numpy
()))
class
TestSegChannelWiseLoss1
(
TestSegChannelWiseLoss
):
def
init
(
self
):
self
.
act_name
=
"softmax"
self
.
act_func
=
F
.
softmax
class
TestSegChannelWiseLoss1
(
TestSegChannelWiseLoss
):
def
init
(
self
):
self
.
act_name
=
"sigmoid"
self
.
act_func
=
F
.
sigmoid
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录