PaddlePaddle / PaddleClas
Commit 7595ba6d (unverified) · authored Feb 28, 2022 by wc晨曦 · committed by GitHub on Feb 28, 2022

add AFD (#1683)

* add AFD

Parent: b27acf6a
Showing 11 changed files with 554 additions and 2 deletions (+554 -2)
ppcls/arch/__init__.py                                                    +23   -1
ppcls/arch/backbone/base/theseus_layer.py                                  +1   -1
ppcls/arch/distill/afd_attention.py                                      +123   -0
ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_afd.yaml   +202   -0
ppcls/data/postprocess/topk.py                                             +2   -0
ppcls/engine/engine.py                                                     +2   -0
ppcls/engine/evaluation/classification.py                                  +2   -0
ppcls/loss/__init__.py                                                     +2   -0
ppcls/loss/afdloss.py                                                    +132   -0
ppcls/loss/distillationloss.py                                            +32   -0
ppcls/loss/kldivloss.py                                                   +33   -0
ppcls/arch/__init__.py (+23 -1)

```diff
@@ -27,8 +27,9 @@ from ppcls.arch.backbone.base.theseus_layer import TheseusLayer
 from ppcls.utils import logger
 from ppcls.utils.save_load import load_dygraph_pretrain
 from ppcls.arch.slim import prune_model, quantize_model
+from ppcls.arch.distill.afd_attention import LinearTransformStudent, LinearTransformTeacher

-__all__ = ["build_model", "RecModel", "DistillationModel"]
+__all__ = ["build_model", "RecModel", "DistillationModel", "AttentionModel"]


 def build_model(config):
@@ -132,3 +133,24 @@ class DistillationModel(nn.Layer):
             else:
                 result_dict[model_name] = self.model_list[idx](x, label)
         return result_dict
+
+
+class AttentionModel(DistillationModel):
+    def __init__(self,
+                 models=None,
+                 pretrained_list=None,
+                 freeze_params_list=None,
+                 **kargs):
+        super().__init__(models, pretrained_list, freeze_params_list, **kargs)
+
+    def forward(self, x, label=None):
+        result_dict = dict()
+        out = x
+        for idx, model_name in enumerate(self.model_name_list):
+            if label is None:
+                out = self.model_list[idx](out)
+                result_dict.update(out)
+            else:
+                out = self.model_list[idx](out, label)
+                result_dict.update(out)
+        return result_dict
```
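Unlike `DistillationModel`, which feeds the same input `x` to every sub-model, `AttentionModel` chains them: each sub-model receives the previous one's output dict, and all intermediate dicts are merged into the result. A minimal sketch of that dict-in/dict-out protocol is below; `FakeBackbone` and `FakeTransform` are hypothetical stand-ins of my own, not classes from this repo.

```python
import paddle
import paddle.nn as nn

class FakeBackbone(nn.Layer):
    # stands in for a backbone built with return_patterns: it returns a
    # dict of intermediate block features plus the final "logits"
    def forward(self, x):
        return {"blocks[0]": paddle.rand([x.shape[0], 64, 56, 56]),
                "logits": paddle.rand([x.shape[0], 1000])}

class FakeTransform(nn.Layer):
    # stands in for LinearTransformTeacher/Student: it consumes the
    # feature dict produced by the previous model in the chain
    def forward(self, feats):
        return {"query": feats["blocks[0]"].mean(axis=[2, 3])}

# the same chaining loop as AttentionModel.forward (label=None branch)
out = paddle.rand([2, 3, 224, 224])
result_dict = dict()
for model in [FakeBackbone(), FakeTransform()]:
    out = model(out)
    result_dict.update(out)
print(sorted(result_dict))  # ['blocks[0]', 'logits', 'query']
```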
ppcls/arch/backbone/base/theseus_layer.py (+1 -1)

```diff
@@ -35,7 +35,7 @@ class TheseusLayer(nn.Layer):
         self.quanter = None

     def _return_dict_hook(self, layer, input, output):
-        res_dict = {"output": output}
+        res_dict = {"logits": output}
         # 'list' is needed to avoid error raised by popping self.res_dict
         for res_key in list(self.res_dict):
             # clear the res_dict because the forward process may change according to input
```
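This one-line rename standardizes the hook's key for the final output: downstream code can now look up `"logits"` in any feature dict a TheseusLayer-based model returns. The two-line additions to ppcls/data/postprocess/topk.py, ppcls/engine/engine.py, and ppcls/engine/evaluation/classification.py below are the matching consumers, each unwrapping `x["logits"]` when handed such a dict.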
ppcls/arch/distill/afd_attention.py (new file, mode 100644, +123)

```python
#copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.

import paddle.nn as nn
import paddle.nn.functional as F
import paddle
import numpy as np


class LinearBNReLU(nn.Layer):
    def __init__(self, nin, nout):
        super().__init__()
        self.linear = nn.Linear(nin, nout)
        self.bn = nn.BatchNorm1D(nout)
        self.relu = nn.ReLU()

    def forward(self, x, relu=True):
        if relu:
            return self.relu(self.bn(self.linear(x)))
        return self.bn(self.linear(x))


def unique_shape(s_shapes):
    n_s = []
    unique_shapes = []
    n = -1
    for s_shape in s_shapes:
        if s_shape not in unique_shapes:
            unique_shapes.append(s_shape)
            n += 1
        n_s.append(n)
    return n_s, unique_shapes


class LinearTransformTeacher(nn.Layer):
    def __init__(self, qk_dim, t_shapes, keys):
        super().__init__()
        self.teacher_keys = keys
        self.t_shapes = [[1] + t_i for t_i in t_shapes]
        self.query_layer = nn.LayerList(
            [LinearBNReLU(t_shape[1], qk_dim) for t_shape in self.t_shapes])

    def forward(self, t_features_dict):
        g_t = [t_features_dict[key] for key in self.teacher_keys]
        bs = g_t[0].shape[0]
        channel_mean = [f_t.mean(3).mean(2) for f_t in g_t]
        spatial_mean = []
        for i in range(len(g_t)):
            c, h, w = g_t[i].shape[1:]
            spatial_mean.append(g_t[i].pow(2).mean(1).reshape([bs, h * w]))
        query = paddle.stack(
            [
                query_layer(f_t, relu=False)
                for f_t, query_layer in zip(channel_mean, self.query_layer)
            ],
            axis=1)
        value = [F.normalize(f_s, axis=1) for f_s in spatial_mean]
        return {"query": query, "value": value}


class LinearTransformStudent(nn.Layer):
    def __init__(self, qk_dim, t_shapes, s_shapes, keys):
        super().__init__()
        self.student_keys = keys
        self.t_shapes = [[1] + t_i for t_i in t_shapes]
        self.s_shapes = [[1] + s_i for s_i in s_shapes]
        self.t = len(self.t_shapes)
        self.s = len(self.s_shapes)
        self.qk_dim = qk_dim
        self.n_t, self.unique_t_shapes = unique_shape(self.t_shapes)
        self.relu = nn.ReLU()
        self.samplers = nn.LayerList(
            [Sample(t_shape) for t_shape in self.unique_t_shapes])
        self.key_layer = nn.LayerList([
            LinearBNReLU(s_shape[1], self.qk_dim) for s_shape in self.s_shapes
        ])
        self.bilinear = LinearBNReLU(qk_dim, qk_dim * len(self.t_shapes))

    def forward(self, s_features_dict):
        g_s = [s_features_dict[key] for key in self.student_keys]
        bs = g_s[0].shape[0]
        channel_mean = [f_s.mean(3).mean(2) for f_s in g_s]
        spatial_mean = [sampler(g_s, bs) for sampler in self.samplers]
        key = paddle.stack(
            [
                key_layer(f_s)
                for key_layer, f_s in zip(self.key_layer, channel_mean)
            ],
            axis=1).reshape([-1, self.qk_dim])  # Bs x h
        bilinear_key = self.bilinear(
            key, relu=False).reshape([bs, self.s, self.t, self.qk_dim])
        value = [F.normalize(s_m, axis=2) for s_m in spatial_mean]
        return {"bilinear_key": bilinear_key, "value": value}


class Sample(nn.Layer):
    def __init__(self, t_shape):
        super().__init__()
        self.t_N, self.t_C, self.t_H, self.t_W = t_shape
        self.sample = nn.AdaptiveAvgPool2D((self.t_H, self.t_W))

    def forward(self, g_s, bs):
        g_s = paddle.stack(
            [
                self.sample(f_s.pow(2).mean(1, keepdim=True)).reshape(
                    [bs, self.t_H * self.t_W]) for f_s in g_s
            ],
            axis=1)
        return g_s
```
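As a quick shape check (my own sketch, not part of the commit; the feature keys and sizes are invented): feeding two fake teacher feature maps through `LinearTransformTeacher` should yield one `qk_dim`-wide query per teacher layer, plus one L2-normalized spatial map per layer.

```python
import paddle
from ppcls.arch.distill.afd_attention import LinearTransformTeacher

feats = {"blocks[0]": paddle.rand([8, 64, 56, 56]),
         "blocks[1]": paddle.rand([8, 128, 28, 28])}
teacher = LinearTransformTeacher(qk_dim=128,
                                 t_shapes=[[64, 56, 56], [128, 28, 28]],
                                 keys=["blocks[0]", "blocks[1]"])
out = teacher(feats)
print(out["query"].shape)               # [8, 2, 128]: one query per layer
print([v.shape for v in out["value"]])  # [[8, 3136], [8, 784]]: H*W maps
```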
ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_afd.yaml (new file, mode 100644, +202)

```yaml
# global configs
Global:
  checkpoints: null
  pretrained_model: null
  output_dir: "./output/"
  device: "gpu"
  save_interval: 1
  eval_during_train: True
  eval_interval: 1
  epochs: 100
  print_batch_step: 10
  use_visualdl: False
  # used for static mode and model export
  image_shape: [3, 224, 224]
  save_inference_dir: "./inference"

# model architecture
Arch:
  name: "DistillationModel"
  # if not null, its length should be the same as models
  pretrained_list:
  # if not null, its length should be the same as models
  freeze_params_list:
  models:
    - Teacher:
        name: AttentionModel
        pretrained_list:
        freeze_params_list:
          - True
          - False
        models:
          - ResNet34:
              name: ResNet34
              pretrained: True
              return_patterns: &t_keys ["blocks[0]", "blocks[1]", "blocks[2]", "blocks[3]",
                                        "blocks[4]", "blocks[5]", "blocks[6]", "blocks[7]",
                                        "blocks[8]", "blocks[9]", "blocks[10]", "blocks[11]",
                                        "blocks[12]", "blocks[13]", "blocks[14]", "blocks[15]"]
          - LinearTransformTeacher:
              name: LinearTransformTeacher
              qk_dim: 128
              keys: *t_keys
              t_shapes: &t_shapes [[64, 56, 56], [64, 56, 56], [64, 56, 56], [128, 28, 28],
                                   [128, 28, 28], [128, 28, 28], [128, 28, 28], [256, 14, 14],
                                   [256, 14, 14], [256, 14, 14], [256, 14, 14], [256, 14, 14],
                                   [256, 14, 14], [512, 7, 7], [512, 7, 7], [512, 7, 7]]
    - Student:
        name: AttentionModel
        pretrained_list:
        freeze_params_list:
          - False
          - False
        models:
          - ResNet18:
              name: ResNet18
              pretrained: False
              return_patterns: &s_keys ["blocks[0]", "blocks[1]", "blocks[2]", "blocks[3]",
                                        "blocks[4]", "blocks[5]", "blocks[6]", "blocks[7]"]
          - LinearTransformStudent:
              name: LinearTransformStudent
              qk_dim: 128
              keys: *s_keys
              s_shapes: &s_shapes [[64, 56, 56], [64, 56, 56], [128, 28, 28], [128, 28, 28],
                                   [256, 14, 14], [256, 14, 14], [512, 7, 7], [512, 7, 7]]
              t_shapes: *t_shapes

  infer_model_name: "Student"

# loss function config for training/eval process
Loss:
  Train:
    - DistillationGTCELoss:
        weight: 1.0
        model_names: ["Student"]
        key: logits
    - DistillationKLDivLoss:
        weight: 0.9
        model_name_pairs: [["Student", "Teacher"]]
        temperature: 4
        key: logits
    - AFDLoss:
        weight: 50.0
        model_name_pair: ["Student", "Teacher"]
        student_keys: ["bilinear_key", "value"]
        teacher_keys: ["query", "value"]
        s_shapes: *s_shapes
        t_shapes: *t_shapes
  Eval:
    - DistillationGTCELoss:
        weight: 1.0
        model_names: ["Student"]

Optimizer:
  name: Momentum
  momentum: 0.9
  weight_decay: 1e-4
  lr:
    name: MultiStepDecay
    learning_rate: 0.1
    milestones: [30, 60, 90]
    step_each_epoch: 1
    gamma: 0.1

# data loader for train and eval
DataLoader:
  Train:
    dataset:
      name: ImageNetDataset
      image_root: "./dataset/ILSVRC2012/"
      cls_label_path: "./dataset/ILSVRC2012/train_list.txt"
      transform_ops:
        - DecodeImage:
            to_rgb: True
            channel_first: False
        - RandCropImage:
            size: 224
            interpolation: bicubic
            backend: pil
        - RandFlipImage:
            flip_code: 1
        - NormalizeImage:
            scale: 0.00392157
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
    sampler:
      name: DistributedBatchSampler
      batch_size: 64
      drop_last: False
      shuffle: True
    loader:
      num_workers: 8
      use_shared_memory: True

  Eval:
    dataset:
      name: ImageNetDataset
      image_root: "./dataset/ILSVRC2012/"
      cls_label_path: "./dataset/ILSVRC2012/val_list.txt"
      transform_ops:
        - DecodeImage:
            to_rgb: True
            channel_first: False
        - ResizeImage:
            resize_short: 256
            interpolation: bicubic
            backend: pil
        - CropImage:
            size: 224
        - NormalizeImage:
            scale: 0.00392157
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
    sampler:
      name: DistributedBatchSampler
      batch_size: 64
      drop_last: False
      shuffle: False
    loader:
      num_workers: 4
      use_shared_memory: True

Infer:
  infer_imgs: "docs/images/inference_deployment/whl_demo.jpg"
  batch_size: 10
  transforms:
    - DecodeImage:
        to_rgb: True
        channel_first: False
    - ResizeImage:
        resize_short: 256
        interpolation: bicubic
        backend: pil
    - CropImage:
        size: 224
    - NormalizeImage:
        scale: 1.0/255.0
        mean: [0.485, 0.456, 0.406]
        std: [0.229, 0.224, 0.225]
        order: ''
    - ToCHWImage:
  PostProcess:
    name: DistillationPostProcess
    func: Topk
    topk: 5
    class_id_map_file: "ppcls/utils/imagenet1k_label_list.txt"

Metric:
  Train:
    - DistillationTopkAcc:
        model_key: "Student"
        topk: [1, 5]
  Eval:
    - DistillationTopkAcc:
        model_key: "Student"
        topk: [1, 5]
```
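Assuming the standard PaddleClas entry point (not stated in the commit itself), training with this recipe would be launched as something like `python3 tools/train.py -c ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_afd.yaml`. Note that the YAML anchors (`&t_keys`, `&t_shapes`, `&s_keys`, `&s_shapes`) are defined once in the Arch section and reused by the AFDLoss entry, so the per-block feature shapes cannot drift apart between the model and the loss.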
ppcls/data/postprocess/topk.py (+2 -0)

```diff
@@ -46,6 +46,8 @@ class Topk(object):
         return class_id_map

     def __call__(self, x, file_names=None, multilabel=False):
+        if isinstance(x, dict):
+            x = x['logits']
         assert isinstance(x, paddle.Tensor)
         if file_names is not None:
             assert x.shape[0] == len(file_names)
```
ppcls/engine/engine.py (+2 -0)

```diff
@@ -459,5 +459,7 @@ class ExportModel(TheseusLayer):
         if self.infer_output_key is not None:
             x = x[self.infer_output_key]
         if self.out_act is not None:
+            if isinstance(x, dict):
+                x = x["logits"]
             x = self.out_act(x)
         return x
```
ppcls/engine/evaluation/classification.py (+2 -0)

```diff
@@ -99,6 +99,8 @@ def classification_eval(engine, epoch_id=0):
         if isinstance(out, dict):
             if "Student" in out:
                 out = out["Student"]
+                if isinstance(out, dict):
+                    out = out["logits"]
             elif "logits" in out:
                 out = out["logits"]
             else:
```
ppcls/loss/__init__.py (+2 -0)

```diff
@@ -22,7 +22,9 @@ from .distillationloss import DistillationGTCELoss
 from .distillationloss import DistillationDMLLoss
 from .distillationloss import DistillationDistanceLoss
 from .distillationloss import DistillationRKDLoss
+from .distillationloss import DistillationKLDivLoss
 from .multilabelloss import MultiLabelLoss
+from .afdloss import AFDLoss
 from .deephashloss import DSHSDLoss, LCDSHLoss
```
ppcls/loss/afdloss.py (new file, mode 100644, +132)

```python
#copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.

import paddle.nn as nn
import paddle.nn.functional as F
import paddle
import numpy as np
import matplotlib.pyplot as plt
import cv2
import warnings
warnings.filterwarnings('ignore')


class LinearBNReLU(nn.Layer):
    def __init__(self, nin, nout):
        super().__init__()
        self.linear = nn.Linear(nin, nout)
        self.bn = nn.BatchNorm1D(nout)
        self.relu = nn.ReLU()

    def forward(self, x, relu=True):
        if relu:
            return self.relu(self.bn(self.linear(x)))
        return self.bn(self.linear(x))


def unique_shape(s_shapes):
    n_s = []
    unique_shapes = []
    n = -1
    for s_shape in s_shapes:
        if s_shape not in unique_shapes:
            unique_shapes.append(s_shape)
            n += 1
        n_s.append(n)
    return n_s, unique_shapes


class AFDLoss(nn.Layer):
    """
    AFDLoss
    https://www.aaai.org/AAAI21Papers/AAAI-9785.JiM.pdf
    https://github.com/clovaai/attention-feature-distillation
    """

    def __init__(self,
                 model_name_pair=["Student", "Teacher"],
                 student_keys=["bilinear_key", "value"],
                 teacher_keys=["query", "value"],
                 s_shapes=[[64, 16, 160], [128, 8, 160], [256, 4, 160],
                           [512, 2, 160]],
                 t_shapes=[[640, 48], [320, 96], [160, 192]],
                 qk_dim=128,
                 name="loss_afd"):
        super().__init__()
        assert isinstance(model_name_pair, list)
        self.model_name_pair = model_name_pair
        self.student_keys = student_keys
        self.teacher_keys = teacher_keys
        self.s_shapes = [[1] + s_i for s_i in s_shapes]
        self.t_shapes = [[1] + t_i for t_i in t_shapes]
        self.qk_dim = qk_dim
        self.n_t, self.unique_t_shapes = unique_shape(self.t_shapes)
        self.attention = Attention(self.qk_dim, self.t_shapes, self.s_shapes,
                                   self.n_t, self.unique_t_shapes)
        self.name = name

    def forward(self, predicts, batch):
        s_features_dict = predicts[self.model_name_pair[0]]
        t_features_dict = predicts[self.model_name_pair[1]]

        g_s = [s_features_dict[key] for key in self.student_keys]
        g_t = [t_features_dict[key] for key in self.teacher_keys]

        loss = self.attention(g_s, g_t)
        sum_loss = sum(loss)

        loss_dict = dict()
        loss_dict[self.name] = sum_loss

        return loss_dict


class Attention(nn.Layer):
    def __init__(self, qk_dim, t_shapes, s_shapes, n_t, unique_t_shapes):
        super().__init__()
        self.qk_dim = qk_dim
        self.n_t = n_t
        # self.linear_trans_s = LinearTransformStudent(qk_dim, t_shapes, s_shapes, unique_t_shapes)
        # self.linear_trans_t = LinearTransformTeacher(qk_dim, t_shapes)

        self.p_t = self.create_parameter(
            shape=[len(t_shapes), qk_dim],
            default_initializer=nn.initializer.XavierNormal())
        self.p_s = self.create_parameter(
            shape=[len(s_shapes), qk_dim],
            default_initializer=nn.initializer.XavierNormal())

    def forward(self, g_s, g_t):
        bilinear_key, h_hat_s_all = g_s
        query, h_t_all = g_t

        p_logit = paddle.matmul(self.p_t, self.p_s.t())

        logit = paddle.add(
            paddle.einsum('bstq,btq->bts', bilinear_key, query),
            p_logit) / np.sqrt(self.qk_dim)
        atts = F.softmax(logit, axis=2)  # b x t x s

        loss = []
        for i, (n, h_t) in enumerate(zip(self.n_t, h_t_all)):
            h_hat_s = h_hat_s_all[n]
            diff = self.cal_diff(h_hat_s, h_t, atts[:, i])
            loss.append(diff)
        return loss

    def cal_diff(self, v_s, v_t, att):
        diff = (v_s - v_t.unsqueeze(1)).pow(2).mean(2)
        diff = paddle.multiply(diff, att).sum(1).mean()
        return diff
```
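The pieces compose end to end: the two transform modules produce the `query`/`value` and `bilinear_key`/`value` dicts that `AFDLoss` consumes. Below is a smoke test I would expect to work given the modules in this commit (my own sketch; the batch size, feature keys, and shapes are invented, and `afdloss.py` additionally needs matplotlib and cv2 importable).

```python
import paddle
from ppcls.arch.distill.afd_attention import (LinearTransformStudent,
                                              LinearTransformTeacher)
from ppcls.loss.afdloss import AFDLoss

bs, qk_dim = 4, 128
t_shapes = [[16, 8, 8], [32, 4, 4]]  # two fake teacher layers, (C, H, W)
s_shapes = [[8, 8, 8], [16, 4, 4]]   # two fake student layers
t_feats = {f"t{i}": paddle.rand([bs] + s) for i, s in enumerate(t_shapes)}
s_feats = {f"s{i}": paddle.rand([bs] + s) for i, s in enumerate(s_shapes)}

teacher = LinearTransformTeacher(qk_dim, t_shapes, keys=list(t_feats))
student = LinearTransformStudent(qk_dim, t_shapes, s_shapes, keys=list(s_feats))
loss_fn = AFDLoss(s_shapes=s_shapes, t_shapes=t_shapes, qk_dim=qk_dim)

# mimics the {"Teacher": ..., "Student": ...} dict DistillationModel returns
predicts = {"Teacher": teacher(t_feats), "Student": student(s_feats)}
print(loss_fn(predicts, batch=None))  # {'loss_afd': scalar Tensor}
```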
ppcls/loss/distillationloss.py (+32 -0)

```diff
@@ -14,11 +14,13 @@
 import paddle
 import paddle.nn as nn
+import paddle.nn.functional as F

 from .celoss import CELoss
 from .dmlloss import DMLLoss
 from .distanceloss import DistanceLoss
 from .rkdloss import RKdAngle, RkdDistance
+from .kldivloss import KLDivLoss


 class DistillationCELoss(CELoss):
@@ -172,3 +174,33 @@ class DistillationRKDLoss(nn.Layer):
                                                   student_out, teacher_out)
         return loss_dict
+
+
+class DistillationKLDivLoss(KLDivLoss):
+    """
+    DistillationKLDivLoss
+    """
+
+    def __init__(self,
+                 model_name_pairs=[],
+                 temperature=4,
+                 key=None,
+                 name="loss_kl"):
+        super().__init__(temperature=temperature)
+        assert isinstance(model_name_pairs, list)
+        self.key = key
+        self.model_name_pairs = model_name_pairs
+        self.name = name
+
+    def forward(self, predicts, batch):
+        loss_dict = dict()
+        for idx, pair in enumerate(self.model_name_pairs):
+            out1 = predicts[pair[0]]
+            out2 = predicts[pair[1]]
+            if self.key is not None:
+                out1 = out1[self.key]
+                out2 = out2[self.key]
+            loss = super().forward(out1, out2)
+            for key in loss:
+                loss_dict["{}_{}_{}".format(key, pair[0], pair[1])] = loss[key]
+        return loss_dict
```
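A usage sketch for the wrapper (the model names mirror the AFD config above; the tensors are random and the batch argument is unused by this loss):

```python
import paddle
from ppcls.loss.distillationloss import DistillationKLDivLoss

predicts = {"Student": {"logits": paddle.rand([8, 1000])},
            "Teacher": {"logits": paddle.rand([8, 1000])}}
loss_fn = DistillationKLDivLoss(model_name_pairs=[["Student", "Teacher"]],
                                temperature=4, key="logits")
print(loss_fn(predicts, batch=None))
# {'loss_kldiv_Student_Teacher': Tensor(...)}  (key is suffixed per pair)
```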
ppcls/loss/kldivloss.py (new file, mode 100644, +33)

```python
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.nn as nn
import paddle.nn.functional as F


class KLDivLoss(nn.Layer):
    """
    Distilling the Knowledge in a Neural Network
    """

    def __init__(self, temperature=4):
        super(KLDivLoss, self).__init__()
        self.T = temperature

    def forward(self, y_s, y_t):
        p_s = F.log_softmax(y_s / self.T, axis=1)
        p_t = F.softmax(y_t / self.T, axis=1)
        loss = F.kl_div(p_s, p_t, reduction='sum') * (self.T**2) / y_s.shape[0]
        return {"loss_kldiv": loss}
```
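Two properties worth sanity-checking (my own sketch): identical logits give zero KL divergence, and the `T**2 / batch_size` scaling turns the summed divergence into a per-sample value with gradients of a consistent magnitude across temperatures.

```python
import paddle
from ppcls.loss.kldivloss import KLDivLoss

logits = paddle.rand([8, 10])
loss_fn = KLDivLoss(temperature=4)
print(loss_fn(logits, logits)["loss_kldiv"])                # ~0.0
print(loss_fn(logits, paddle.rand([8, 10]))["loss_kldiv"])  # > 0
```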