Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleSlim
提交
d2cc9663
P
PaddleSlim
项目概览
PaddlePaddle
/
PaddleSlim
大约 2 年 前同步成功
通知
51
Star
1434
Fork
344
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
53
列表
看板
标记
里程碑
合并请求
16
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleSlim
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
53
Issue
53
列表
看板
标记
里程碑
合并请求
16
合并请求
16
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
d2cc9663
编写于
7月 28, 2021
作者:
C
ceci3
提交者:
GitHub
7月 28, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[distill] how to get feature map (#799)
上级
6238fd7b
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
358 addition
and
0 deletion
+358
-0
paddleslim/dygraph/dist/__init__.py
paddleslim/dygraph/dist/__init__.py
+5
-0
paddleslim/dygraph/dist/distill.py
paddleslim/dygraph/dist/distill.py
+186
-0
tests/dygraph/test_distill.py
tests/dygraph/test_distill.py
+167
-0
未找到文件。
paddleslim/dygraph/dist/__init__.py
浏览文件 @
d2cc9663
...
@@ -12,4 +12,9 @@
...
@@ -12,4 +12,9 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
from
.
import
distill
from
.distill
import
*
__all__
=
[]
__all__
=
[]
__all__
+=
distill
.
__all__
paddleslim/dygraph/dist/distill.py
0 → 100644
浏览文件 @
d2cc9663
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
numpy
as
np
import
collections
from
collections
import
namedtuple
import
paddle.nn
as
nn
from
.losses
import
*
__all__
=
[
'Distill'
,
'AdaptorBase'
]
class
LayerConfig
:
def
__init__
(
self
,
s_feature_idx
,
t_feature_idx
,
feature_type
,
loss_function
,
weight
=
1.0
,
align
=
False
,
align_shape
=
None
):
self
.
s_feature_idx
=
s_feature_idx
self
.
t_feature_idx
=
t_feature_idx
self
.
feature_type
=
feature_type
if
loss_function
in
[
'l1'
,
'l2'
,
'smooth_l1'
]:
self
.
loss_function
=
'DistillationDistanceLoss'
elif
loss_function
in
[
'dml'
]:
self
.
loss_function
=
'DistillationDMLLoss'
elif
loss_function
in
[
'rkl'
]:
self
.
loss_function
=
'DistillationRKDLoss'
else
:
raise
NotImplementedError
(
"loss function is not support!!!"
)
self
.
weight
=
weight
self
.
align
=
align
self
.
align_shape
=
align_shape
class
AdaptorBase
:
def
__init__
(
self
,
model
):
self
.
model
=
model
self
.
add_tensor
=
False
def
_get_activation
(
self
,
outs
,
name
):
def
get_output_hook
(
layer
,
input
,
output
):
outs
[
name
]
=
output
return
get_output_hook
def
_add_distill_hook
(
self
,
outs
,
mapping_layers_name
,
layers_type
):
"""
Get output by name.
outs(dict): save the middle outputs of model according to the name.
mapping_layers(list): name of middle layers.
layers_type(list): type of the middle layers to calculate distill loss.
"""
### TODO: support DP model
for
idx
,
(
n
,
m
)
in
enumerate
(
self
.
model
.
named_sublayers
()):
if
n
in
mapping_layers_name
:
midx
=
mapping_layers_name
.
index
(
n
)
m
.
register_forward_post_hook
(
self
.
_get_activation
(
outs
,
layers_type
[
midx
]))
def
mapping_layers
(
self
):
raise
NotImplementedError
(
"function mapping_layers is not implemented"
)
class
Distill
(
nn
.
Layer
):
### TODO: support list of student model and teacher model
def
__init__
(
self
,
distill_configs
,
student_models
,
teacher_models
,
adaptors_S
,
adaptors_T
):
super
(
Distill
,
self
).
__init__
()
self
.
_distill_configs
=
distill_configs
self
.
_student_models
=
student_models
self
.
_teacher_models
=
teacher_models
self
.
_adaptors_S
=
adaptors_S
(
self
.
_student_models
)
self
.
_adaptors_T
=
adaptors_T
(
self
.
_teacher_models
)
self
.
stu_outs_dict
,
self
.
tea_outs_dict
=
self
.
_prepare_outputs
()
self
.
configs
=
[]
for
c
in
self
.
_distill_configs
:
self
.
configs
.
append
(
LayerConfig
(
**
c
).
__dict__
)
self
.
distill_idx
=
self
.
_get_distill_idx
()
self
.
_loss_config_list
=
[]
for
c
in
self
.
configs
:
loss_config
=
{}
loss_config
[
str
(
c
[
'loss_function'
])]
=
{}
loss_config
[
str
(
c
[
'loss_function'
])][
'weight'
]
=
c
[
'weight'
]
loss_config
[
str
(
c
[
'loss_function'
])][
'key'
]
=
c
[
'feature_type'
]
+
'_'
+
str
(
c
[
's_feature_idx'
])
+
'_'
+
str
(
c
[
't_feature_idx'
])
### TODO: support list of student models and teacher_models
loss_config
[
str
(
c
[
'loss_function'
])][
'model_name_pairs'
]
=
[[
'student'
,
'teacher'
]]
self
.
_loss_config_list
.
append
(
loss_config
)
self
.
_prepare_loss
()
def
_prepare_hook
(
self
,
adaptors
,
outs_dict
):
mapping_layers
=
adaptors
.
mapping_layers
()
for
layer_type
,
layer
in
mapping_layers
.
items
():
if
isinstance
(
layer
,
str
):
adaptors
.
_add_distill_hook
(
outs_dict
,
[
layer
],
[
layer_type
])
return
outs_dict
def
_get_model_intermediate_output
(
self
,
adaptors
,
outs_dict
):
mapping_layers
=
adaptors
.
mapping_layers
()
for
layer_type
,
layer
in
mapping_layers
.
items
():
if
isinstance
(
layer
,
str
):
continue
outs_dict
[
layer_type
]
=
layer
return
outs_dict
def
_get_distill_idx
(
self
):
distill_idx
=
{}
for
config
in
self
.
_distill_configs
:
if
config
[
'feature_type'
]
not
in
distill_idx
:
distill_idx
[
config
[
'feature_type'
]]
=
[[
int
(
config
[
's_feature_idx'
]),
int
(
config
[
't_feature_idx'
])
]]
else
:
distill_idx
[
config
[
'feature_type'
]].
append
([
int
(
config
[
's_feature_idx'
]),
int
(
config
[
't_feature_idx'
])
])
return
distill_idx
def
_prepare_loss
(
self
):
self
.
distill_loss
=
CombinedLoss
(
self
.
_loss_config_list
)
def
_prepare_outputs
(
self
):
stu_outs_dict
=
collections
.
OrderedDict
()
tea_outs_dict
=
collections
.
OrderedDict
()
stu_outs_dict
=
self
.
_prepare_hook
(
self
.
_adaptors_S
,
stu_outs_dict
)
tea_outs_dict
=
self
.
_prepare_hook
(
self
.
_adaptors_T
,
tea_outs_dict
)
return
stu_outs_dict
,
tea_outs_dict
def
_post_outputs
(
self
):
final_keys
=
[]
for
key
,
value
in
self
.
stu_outs_dict
.
items
():
if
len
(
key
.
split
(
'_'
))
==
1
:
final_keys
.
append
(
key
)
### TODO: support list of student models and teacher_models
final_distill_dict
=
{
"student"
:
collections
.
OrderedDict
(),
"teacher"
:
collections
.
OrderedDict
()
}
for
feature_type
,
dist_idx
in
self
.
distill_idx
.
items
():
for
idx
,
idx_list
in
enumerate
(
dist_idx
):
sidx
,
tidx
=
idx_list
[
0
],
idx_list
[
1
]
final_distill_dict
[
'student'
][
feature_type
+
'_'
+
str
(
sidx
)
+
'_'
+
str
(
tidx
)]
=
self
.
stu_outs_dict
[
feature_type
+
'_'
+
str
(
sidx
)]
final_distill_dict
[
'teacher'
][
feature_type
+
'_'
+
str
(
sidx
)
+
'_'
+
str
(
tidx
)]
=
self
.
tea_outs_dict
[
feature_type
+
'_'
+
str
(
tidx
)]
return
final_distill_dict
def
forward
(
self
,
*
inputs
,
**
kwargs
):
stu_batch_outs
=
self
.
_student_models
.
forward
(
*
inputs
,
**
kwargs
)
tea_batch_outs
=
self
.
_teacher_models
.
forward
(
*
inputs
,
**
kwargs
)
if
self
.
_adaptors_S
.
add_tensor
==
False
:
self
.
_adaptors_S
.
add_tensor
=
True
if
self
.
_adaptors_T
.
add_tensor
==
False
:
self
.
_adaptors_T
.
add_tensor
=
True
self
.
stu_outs_dict
=
self
.
_get_model_intermediate_output
(
self
.
_adaptors_S
,
self
.
stu_outs_dict
)
self
.
tea_outs_dict
=
self
.
_get_model_intermediate_output
(
self
.
_adaptors_T
,
self
.
tea_outs_dict
)
distill_inputs
=
self
.
_post_outputs
()
### batch is None just for now
distill_outputs
=
self
.
distill_loss
(
distill_inputs
,
None
)
distill_loss
=
distill_outputs
[
'loss'
]
return
stu_batch_outs
,
tea_batch_outs
,
distill_loss
tests/dygraph/test_distill.py
0 → 100644
浏览文件 @
d2cc9663
import
sys
sys
.
path
.
append
(
"../../"
)
import
logging
import
numpy
as
np
import
unittest
import
paddle
import
paddle.nn
as
nn
from
paddle.vision.models
import
MobileNetV1
import
paddle.vision.transforms
as
T
from
paddleslim.dygraph.dist
import
Distill
,
AdaptorBase
from
paddleslim.common.log_helper
import
get_logger
_logger
=
get_logger
(
__name__
,
logging
.
INFO
,
fmt
=
'%(asctime)s-%(levelname)s: %(message)s'
)
class
TestImperativeDistill
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
s_model
,
self
.
t_model
=
self
.
prepare_model
()
self
.
t_model
.
eval
()
self
.
distill_configs
=
self
.
prepare_config
()
self
.
adaptor
=
self
.
prepare_adaptor
()
def
prepare_model
(
self
):
return
MobileNetV1
(),
MobileNetV1
()
def
prepare_config
(
self
):
distill_configs
=
[{
's_feature_idx'
:
0
,
't_feature_idx'
:
0
,
'feature_type'
:
'hidden'
,
'loss_function'
:
'l2'
},
{
's_feature_idx'
:
1
,
't_feature_idx'
:
1
,
'feature_type'
:
'hidden'
,
'loss_function'
:
'l2'
},
{
's_feature_idx'
:
0
,
't_feature_idx'
:
0
,
'feature_type'
:
'logits'
,
'loss_function'
:
'l2'
}]
return
distill_configs
def
prepare_adaptor
(
self
):
class
Adaptor
(
AdaptorBase
):
def
mapping_layers
(
self
):
mapping_layers
=
{}
mapping_layers
[
'hidden_0'
]
=
'conv1'
mapping_layers
[
'hidden_1'
]
=
'conv2_2'
mapping_layers
[
'hidden_2'
]
=
'conv3_2'
mapping_layers
[
'logits_0'
]
=
'fc'
return
mapping_layers
return
Adaptor
def
test_distill
(
self
):
transform
=
T
.
Compose
([
T
.
Transpose
(),
T
.
Normalize
([
127.5
],
[
127.5
])])
train_dataset
=
paddle
.
vision
.
datasets
.
Cifar10
(
mode
=
'train'
,
backend
=
'cv2'
,
transform
=
transform
)
val_dataset
=
paddle
.
vision
.
datasets
.
Cifar10
(
mode
=
'test'
,
backend
=
'cv2'
,
transform
=
transform
)
place
=
paddle
.
CUDAPlace
(
0
)
if
paddle
.
is_compiled_with_cuda
(
)
else
paddle
.
CPUPlace
()
train_reader
=
paddle
.
io
.
DataLoader
(
train_dataset
,
drop_last
=
True
,
places
=
place
,
batch_size
=
64
)
test_reader
=
paddle
.
io
.
DataLoader
(
val_dataset
,
places
=
place
,
batch_size
=
64
)
def
test
(
model
):
model
.
eval
()
avg_acc
=
[[],
[]]
for
batch_id
,
data
in
enumerate
(
test_reader
):
img
=
paddle
.
to_tensor
(
data
[
0
])
label
=
paddle
.
to_tensor
(
data
[
1
])
label
=
paddle
.
reshape
(
label
,
[
-
1
,
1
])
out
=
model
(
img
)
acc_top1
=
paddle
.
metric
.
accuracy
(
input
=
out
,
label
=
label
,
k
=
1
)
acc_top5
=
paddle
.
metric
.
accuracy
(
input
=
out
,
label
=
label
,
k
=
5
)
avg_acc
[
0
].
append
(
acc_top1
.
numpy
())
avg_acc
[
1
].
append
(
acc_top5
.
numpy
())
if
batch_id
%
100
==
0
:
_logger
.
info
(
"Test | step {}: acc1 = {:}, acc5 = {:}"
.
format
(
batch_id
,
acc_top1
.
numpy
(),
acc_top5
.
numpy
()))
_logger
.
info
(
"Test |Average: acc_top1 {}, acc_top5 {}"
.
format
(
np
.
mean
(
avg_acc
[
0
]),
np
.
mean
(
avg_acc
[
1
])))
def
train
(
model
):
adam
=
paddle
.
optimizer
.
Adam
(
learning_rate
=
0.001
,
parameters
=
model
.
parameters
())
for
batch_id
,
data
in
enumerate
(
train_reader
):
img
=
paddle
.
to_tensor
(
data
[
0
])
label
=
paddle
.
to_tensor
(
data
[
1
])
student_out
,
teacher_out
,
distill_loss
=
model
(
img
)
loss
=
paddle
.
nn
.
functional
.
loss
.
cross_entropy
(
student_out
,
label
)
avg_loss
=
paddle
.
mean
(
loss
)
all_loss
=
avg_loss
+
distill_loss
all_loss
.
backward
()
adam
.
step
()
adam
.
clear_grad
()
if
batch_id
%
100
==
0
:
_logger
.
info
(
"Train | At epoch {} step {}: loss = {:}"
.
format
(
str
(
0
),
batch_id
,
all_loss
.
numpy
()))
test
(
self
.
s_model
)
self
.
s_model
.
train
()
distill_model
=
Distill
(
self
.
distill_configs
,
self
.
s_model
,
self
.
t_model
,
self
.
adaptor
,
self
.
adaptor
)
train
(
distill_model
)
class
TestImperativeDistillCase1
(
TestImperativeDistill
):
def
prepare_model
(
self
):
class
Model
(
nn
.
Layer
):
def
__init__
(
self
):
super
(
Model
,
self
).
__init__
()
self
.
conv1
=
nn
.
Conv2D
(
3
,
3
,
3
,
padding
=
1
)
self
.
conv2
=
nn
.
Conv2D
(
3
,
3
,
3
,
padding
=
1
)
self
.
conv3
=
nn
.
Conv2D
(
3
,
3
,
3
,
padding
=
1
)
self
.
fc
=
nn
.
Linear
(
3072
,
10
)
def
forward
(
self
,
x
):
self
.
conv1_out
=
self
.
conv1
(
x
)
conv2_out
=
self
.
conv2
(
self
.
conv1_out
)
self
.
conv3_out
=
self
.
conv3
(
conv2_out
)
out
=
paddle
.
reshape
(
self
.
conv3_out
,
shape
=
[
x
.
shape
[
0
],
-
1
])
out
=
self
.
fc
(
out
)
return
out
return
Model
(),
Model
()
def
prepare_adaptor
(
self
):
class
Adaptor
(
AdaptorBase
):
def
mapping_layers
(
self
):
mapping_layers
=
{}
mapping_layers
[
'hidden_1'
]
=
'conv2'
if
self
.
add_tensor
:
mapping_layers
[
'hidden_0'
]
=
self
.
model
.
conv1_out
mapping_layers
[
'hidden_2'
]
=
self
.
model
.
conv3_out
return
mapping_layers
return
Adaptor
def
prepare_config
(
self
):
distill_configs
=
[{
's_feature_idx'
:
0
,
't_feature_idx'
:
0
,
'feature_type'
:
'hidden'
,
'loss_function'
:
'l2'
},
{
's_feature_idx'
:
1
,
't_feature_idx'
:
2
,
'feature_type'
:
'hidden'
,
'loss_function'
:
'l2'
}]
return
distill_configs
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录