PaddlePaddle / PaddleClas

Commit aea712cc (unverified)
Authored Jan 05, 2022 by littletomatodonkey; committed via GitHub on Jan 05, 2022.

add dist of rec model (#1574)

* add distillation loss func and rec distillation

Parent commit: 0aa85d4f

12 changed files with 564 additions and 24 deletions (+564 −24):
- ppcls/arch/__init__.py (+6 −1)
- ppcls/arch/backbone/legendary_models/mobilenet_v3.py (+6 −2)
- ppcls/configs/GeneralRecognition/GeneralRecognition_PPLCNet_x2_5_dml.yaml (+194 −0)
- ppcls/configs/GeneralRecognition/GeneralRecognition_PPLCNet_x2_5_udml.yaml (+193 −0)
- ppcls/configs/ImageNet/Distillation/mv3_large_x1_0_distill_mv3_small_x1_0.yaml (+5 −3)
- ppcls/engine/engine.py (+5 −2)
- ppcls/engine/evaluation/classification.py (+4 −3)
- ppcls/engine/evaluation/retrieval.py (+2 −0)
- ppcls/loss/__init__.py (+2 −0)
- ppcls/loss/distillationloss.py (+36 −3)
- ppcls/loss/dmlloss.py (+14 −10)
- ppcls/loss/rkdloss.py (+97 −0)
ppcls/arch/__init__.py

```diff
@@ -77,14 +77,19 @@ class RecModel(TheseusLayer):
             self.head = None

     def forward(self, x, label=None):
+        out = dict()
         x = self.backbone(x)
+        out["backbone"] = x
         if self.neck is not None:
             x = self.neck(x)
+        out["features"] = x
         if self.head is not None:
             y = self.head(x, label)
+            out["neck"] = x
         else:
             y = None
-        return {"features": x, "logits": y}
+        out["logits"] = y
+        return out


 class DistillationModel(nn.Layer):
```
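With this change, `RecModel.forward` returns a dict keyed by stage (`"backbone"`, `"features"`, `"neck"`, `"logits"`) instead of a fixed two-key dict, so the distillation losses added later in this commit can pick the tensor they need by key. A minimal sketch of that keyed access, assuming a `DistillationModel`-style output with one entry per sub-model (the tensors below are made up for illustration):

```python
# Hypothetical sketch: "predicts" mimics what DistillationModel returns for a
# Student/Teacher pair, one RecModel output dict per model name.
import paddle

predicts = {
    "Student": {"backbone": paddle.rand([4, 512]), "logits": paddle.rand([4, 10])},
    "Teacher": {"backbone": paddle.rand([4, 512]), "logits": paddle.rand([4, 10])},
}

# A distillation loss configured with key="logits" indexes each model's output
# dict before computing its loss term.
key = "logits"
student_out = predicts["Student"][key]
teacher_out = predicts["Teacher"][key]
print(student_out.shape, teacher_out.shape)
```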
ppcls/arch/backbone/legendary_models/mobilenet_v3.py

```diff
@@ -196,7 +196,10 @@ class MobileNetV3(TheseusLayer):
             bias_attr=False)
         self.hardswish = nn.Hardswish()
-        self.dropout = Dropout(p=dropout_prob, mode="downscale_in_infer")
+        if dropout_prob is not None:
+            self.dropout = Dropout(p=dropout_prob, mode="downscale_in_infer")
+        else:
+            self.dropout = None
         self.flatten = nn.Flatten(start_axis=1, stop_axis=-1)
         self.fc = Linear(self.class_expand, class_num)
@@ -210,7 +213,8 @@ class MobileNetV3(TheseusLayer):
         x = self.avg_pool(x)
         x = self.last_conv(x)
         x = self.hardswish(x)
-        x = self.dropout(x)
+        if self.dropout is not None:
+            x = self.dropout(x)
         x = self.flatten(x)
         x = self.fc(x)
```
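The guard above lets a config disable the final dropout by passing `dropout_prob: null`, as the distillation config later in this commit does. A rough sketch of what that enables, assuming the usual backbone factory signature passes extra keyword arguments through to `MobileNetV3`:

```python
# Sketch only: with dropout_prob=None the model is expected to skip the Dropout
# layer entirely instead of failing when self.dropout(x) is called in forward().
import paddle
from ppcls.arch.backbone.legendary_models.mobilenet_v3 import MobileNetV3_small_x1_0

model = MobileNetV3_small_x1_0(pretrained=False, dropout_prob=None)
out = model(paddle.rand([1, 3, 224, 224]))
print(out.shape)  # expected [1, 1000] with the default class_num
```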
ppcls/configs/GeneralRecognition/GeneralRecognition_PPLCNet_x2_5_dml.yaml (new file, mode 100644)

```yaml
# global configs
Global:
  checkpoints: null
  pretrained_model: null
  output_dir: ./output/
  device: gpu
  save_interval: 1
  eval_during_train: true
  eval_interval: 1
  epochs: 100
  print_batch_step: 10
  use_visualdl: False
  # used for static mode and model export
  image_shape: [3, 224, 224]
  save_inference_dir: ./inference
  eval_mode: retrieval
  use_dali: False
  to_static: False

# model architecture
Arch:
  name: "DistillationModel"
  infer_output_key: features
  infer_add_softmax: False
  is_rec: True
  infer_model_name: "Student"
  # if not null, its lengths should be same as models
  pretrained_list:
  # if not null, its lengths should be same as models
  freeze_params_list:
  - False
  - False
  models:
    - Teacher:
        name: RecModel
        infer_output_key: features
        infer_add_softmax: False
        Backbone:
          name: PPLCNet_x2_5
          pretrained: True
          use_ssld: True
        BackboneStopLayer:
          name: "flatten"
        Neck:
          name: FC
          embedding_size: 1280
          class_num: 512
        Head:
          name: ArcMargin
          embedding_size: 512
          class_num: 185341
          margin: 0.2
          scale: 30
    - Student:
        name: RecModel
        infer_output_key: features
        infer_add_softmax: False
        Backbone:
          name: PPLCNet_x2_5
          pretrained: True
          use_ssld: True
        BackboneStopLayer:
          name: "flatten"
        Neck:
          name: FC
          embedding_size: 1280
          class_num: 512
        Head:
          name: ArcMargin
          embedding_size: 512
          class_num: 185341
          margin: 0.2
          scale: 30

# loss function config for traing/eval process
Loss:
  Train:
    - DistillationGTCELoss:
        weight: 1.0
        key: "logits"
        model_names: ["Student", "Teacher"]
    - DistillationDMLLoss:
        weight: 1.0
        key: "logits"
        model_name_pairs:
        - ["Student", "Teacher"]
    - DistillationDMLLoss:
        weight: 1.0
        key: "logits"
        model_name_pairs:
        - ["Student", "Teacher"]
  Eval:
    - DistillationGTCELoss:
        weight: 1.0
        model_names: ["Student"]

Optimizer:
  name: Momentum
  momentum: 0.9
  lr:
    name: Cosine
    learning_rate: 0.02
    warmup_epoch: 5
  regularizer:
    name: 'L2'
    coeff: 0.00001

# data loader for train and eval
DataLoader:
  Train:
    dataset:
      name: ImageNetDataset
      image_root: ./dataset/
      cls_label_path: ./dataset/train_reg_all_data.txt
      transform_ops:
        - DecodeImage:
            to_rgb: True
            channel_first: False
        - RandCropImage:
            size: 224
        - RandFlipImage:
            flip_code: 1
        - NormalizeImage:
            scale: 1.0/255.0
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
    sampler:
      name: DistributedBatchSampler
      batch_size: 64
      drop_last: False
      shuffle: True
    loader:
      num_workers: 4
      use_shared_memory: True

  Eval:
    Query:
      dataset:
        name: VeriWild
        image_root: ./dataset/Aliproduct/
        cls_label_path: ./dataset/Aliproduct/val_list.txt
        transform_ops:
          - DecodeImage:
              to_rgb: True
              channel_first: False
          - ResizeImage:
              size: 224
          - NormalizeImage:
              scale: 0.00392157
              mean: [0.485, 0.456, 0.406]
              std: [0.229, 0.224, 0.225]
              order: ''
      sampler:
        name: DistributedBatchSampler
        batch_size: 64
        drop_last: False
        shuffle: False
      loader:
        num_workers: 4
        use_shared_memory: True

    Gallery:
      dataset:
        name: VeriWild
        image_root: ./dataset/Aliproduct/
        cls_label_path: ./dataset/Aliproduct/val_list.txt
        transform_ops:
          - DecodeImage:
              to_rgb: True
              channel_first: False
          - ResizeImage:
              size: 224
          - NormalizeImage:
              scale: 0.00392157
              mean: [0.485, 0.456, 0.406]
              std: [0.229, 0.224, 0.225]
              order: ''
      sampler:
        name: DistributedBatchSampler
        batch_size: 64
        drop_last: False
        shuffle: False
      loader:
        num_workers: 4
        use_shared_memory: True

Metric:
  Eval:
    - Recallk:
        topk: [1, 5]
```
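In this DML setup the Teacher and Student are both PPLCNet_x2_5 `RecModel`s trained jointly, and the `Loss.Train` list is combined as a weighted sum of the per-term losses. A generic sketch of that combination (not PaddleClas's own loss-builder code; names here are illustrative):

```python
# Generic sketch of summing weighted loss terms as the Loss.Train list implies.
# Each term maps (predicts, batch) to a dict of named scalar losses.
import paddle

def combine(loss_terms, predicts, batch):
    total = paddle.zeros([1])
    for weight, loss_fn in loss_terms:
        for _, value in loss_fn(predicts, batch).items():
            total += weight * value
    return total

# loss_terms would be built from the config, e.g.
# [(1.0, DistillationGTCELoss(...)), (1.0, DistillationDMLLoss(...)), ...]
```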
ppcls/configs/GeneralRecognition/GeneralRecognition_PPLCNet_x2_5_udml.yaml (new file, mode 100644)

```yaml
# global configs
Global:
  checkpoints: null
  pretrained_model: null
  output_dir: ./output/
  device: gpu
  save_interval: 1
  eval_during_train: true
  eval_interval: 1
  epochs: 100
  print_batch_step: 10
  use_visualdl: False
  # used for static mode and model export
  image_shape: [3, 224, 224]
  save_inference_dir: ./inference
  eval_mode: retrieval
  use_dali: False
  to_static: False

# model architecture
Arch:
  name: "DistillationModel"
  infer_output_key: features
  infer_add_softmax: False
  is_rec: True
  infer_model_name: "Student"
  # if not null, its lengths should be same as models
  pretrained_list:
  # if not null, its lengths should be same as models
  freeze_params_list:
  - False
  - False
  models:
    - Teacher:
        name: RecModel
        infer_output_key: features
        infer_add_softmax: False
        Backbone:
          name: PPLCNet_x2_5
          pretrained: True
          use_ssld: True
        BackboneStopLayer:
          name: "flatten"
        Neck:
          name: FC
          embedding_size: 1280
          class_num: 512
        Head:
          name: ArcMargin
          embedding_size: 512
          class_num: 185341
          margin: 0.2
          scale: 30
    - Student:
        name: RecModel
        infer_output_key: features
        infer_add_softmax: False
        Backbone:
          name: PPLCNet_x2_5
          pretrained: True
          use_ssld: True
        BackboneStopLayer:
          name: "flatten"
        Neck:
          name: FC
          embedding_size: 1280
          class_num: 512
        Head:
          name: ArcMargin
          embedding_size: 512
          class_num: 185341
          margin: 0.2
          scale: 30

# loss function config for traing/eval process
Loss:
  Train:
    - DistillationGTCELoss:
        weight: 1.0
        key: "logits"
        model_names: ["Student", "Teacher"]
    - DistillationDMLLoss:
        weight: 1.0
        key: "logits"
        model_name_pairs:
        - ["Student", "Teacher"]
    - DistillationDistanceLoss:
        weight: 1.0
        key: "backbone"
        model_name_pairs:
        - ["Student", "Teacher"]
  Eval:
    - DistillationGTCELoss:
        weight: 1.0
        model_names: ["Student"]

Optimizer:
  name: Momentum
  momentum: 0.9
  lr:
    name: Cosine
    learning_rate: 0.02
    warmup_epoch: 5
  regularizer:
    name: 'L2'
    coeff: 0.00001

# data loader for train and eval
DataLoader:
  Train:
    dataset:
      name: ImageNetDataset
      image_root: ./dataset/
      cls_label_path: ./dataset/train_reg_all_data.txt
      transform_ops:
        - DecodeImage:
            to_rgb: True
            channel_first: False
        - RandCropImage:
            size: 224
        - RandFlipImage:
            flip_code: 1
        - NormalizeImage:
            scale: 1.0/255.0
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
    sampler:
      name: DistributedBatchSampler
      batch_size: 64
      drop_last: False
      shuffle: True
    loader:
      num_workers: 4
      use_shared_memory: True

  Eval:
    Query:
      dataset:
        name: VeriWild
        image_root: ./dataset/Aliproduct/
        cls_label_path: ./dataset/Aliproduct/val_list.txt
        transform_ops:
          - DecodeImage:
              to_rgb: True
              channel_first: False
          - ResizeImage:
              size: 224
          - NormalizeImage:
              scale: 0.00392157
              mean: [0.485, 0.456, 0.406]
              std: [0.229, 0.224, 0.225]
              order: ''
      sampler:
        name: DistributedBatchSampler
        batch_size: 64
        drop_last: False
        shuffle: False
      loader:
        num_workers: 4
        use_shared_memory: True

    Gallery:
      dataset:
        name: VeriWild
        image_root: ./dataset/Aliproduct/
        cls_label_path: ./dataset/Aliproduct/val_list.txt
        transform_ops:
          - DecodeImage:
              to_rgb: True
              channel_first: False
          - ResizeImage:
              size: 224
          - NormalizeImage:
              scale: 0.00392157
              mean: [0.485, 0.456, 0.406]
              std: [0.229, 0.224, 0.225]
              order: ''
      sampler:
        name: DistributedBatchSampler
        batch_size: 64
        drop_last: False
        shuffle: False
      loader:
        num_workers: 4
        use_shared_memory: True

Metric:
  Eval:
    - Recallk:
        topk: [1, 5]
```
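The UDML config differs from the DML one only in its third training loss: a `DistillationDistanceLoss` applied to the `"backbone"` key, i.e. a feature-level distance penalty between the two models on top of the logit-level DML term. A toy sketch of that feature-distance idea (plain L2 here, not the repository's `DistanceLoss` implementation):

```python
# Toy sketch: an L2-style distance between student and teacher backbone features,
# analogous to what DistillationDistanceLoss with key="backbone" computes.
import paddle
import paddle.nn.functional as F

student_feat = paddle.rand([8, 1280])
teacher_feat = paddle.rand([8, 1280])
loss = F.mse_loss(student_feat, teacher_feat)
print(float(loss))
```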
ppcls/configs/ImageNet/Distillation/mv3_large_x1_0_distill_mv3_small_x1_0.yaml

```diff
@@ -13,6 +13,7 @@ Global:
   # used for static mode and model export
   image_shape: [3, 224, 224]
   save_inference_dir: "./inference"
+  use_dali: false

 # model architecture
 Arch:
@@ -29,9 +30,11 @@ Arch:
         name: MobileNetV3_large_x1_0
         pretrained: True
         use_ssld: True
+        dropout_prob: null
     - Student:
         name: MobileNetV3_small_x1_0
         pretrained: False
+        dropout_prob: null

   infer_model_name: "Student"
@@ -76,7 +79,6 @@ DataLoader:
             size: 224
         - RandFlipImage:
             flip_code: 1
-        - AutoAugment:
         - NormalizeImage:
             scale: 0.00392157
             mean: [0.485, 0.456, 0.406]
@@ -85,7 +87,7 @@ DataLoader:
     sampler:
       name: DistributedBatchSampler
-      batch_size: 512
+      batch_size: 256
       drop_last: False
       shuffle: True
     loader:
@@ -112,7 +114,7 @@ DataLoader:
             order: ''
     sampler:
       name: DistributedBatchSampler
-      batch_size: 64
+      batch_size: 128
       drop_last: False
       shuffle: False
     loader:
```
ppcls/engine/engine.py

```diff
@@ -53,7 +53,8 @@ class Engine(object):
         self.config = config
         self.eval_mode = self.config["Global"].get("eval_mode",
                                                    "classification")
-        if "Head" in self.config["Arch"]:
+        if "Head" in self.config["Arch"] or self.config["Arch"].get("is_rec",
+                                                                    False):
             self.is_rec = True
         else:
             self.is_rec = False
@@ -357,7 +358,9 @@ class Engine(object):
             out = self.model(batch_tensor)
             if isinstance(out, list):
                 out = out[0]
-            if isinstance(out, dict):
+            if isinstance(out, dict) and "logits" in out:
                 out = out["logits"]
+            if isinstance(out, dict) and "output" in out:
+                out = out["output"]
             result = self.postprocess_func(out, image_file_list)
             print(result)
```
ppcls/engine/evaluation/classification.py

```diff
@@ -78,10 +78,10 @@ def classification_eval(engine, epoch_id=0):
             labels = paddle.concat(label_list, 0)

             if isinstance(out, dict):
-                if "logits" in out:
-                    out = out["logits"]
-                elif "Student" in out:
+                if "Student" in out:
                     out = out["Student"]
+                elif "logits" in out:
+                    out = out["logits"]
                 else:
                     msg = "Error: Wrong key in out!"
                     raise Exception(msg)
@@ -106,6 +106,7 @@ def classification_eval(engine, epoch_id=0):
                 metric_dict = engine.eval_metric_func(pred, labels)
             else:
                 metric_dict = engine.eval_metric_func(out, batch[1])
+
             for key in metric_dict:
                 if metric_key is None:
                     metric_key = key
```
ppcls/engine/evaluation/retrieval.py

```diff
@@ -123,6 +123,8 @@ def cal_feature(engine, name='gallery'):
                 has_unique_id = True
                 batch[2] = batch[2].reshape([-1, 1]).astype("int64")
             out = engine.model(batch[0], batch[1])
+            if "Student" in out:
+                out = out["Student"]
             batch_feas = out["features"]

             # do norm
```
ppcls/loss/__init__.py

```diff
@@ -20,6 +20,8 @@ from .distanceloss import DistanceLoss
 from .distillationloss import DistillationCELoss
 from .distillationloss import DistillationGTCELoss
 from .distillationloss import DistillationDMLLoss
+from .distillationloss import DistillationDistanceLoss
+from .distillationloss import DistillationRKDLoss
 from .multilabelloss import MultiLabelLoss
 from .deephashloss import DSHSDLoss, LCDSHLoss
```
ppcls/loss/distillationloss.py

```diff
@@ -18,6 +18,7 @@ import paddle.nn as nn
 from .celoss import CELoss
 from .dmlloss import DMLLoss
 from .distanceloss import DistanceLoss
+from .rkdloss import RKdAngle, RkdDistance


 class DistillationCELoss(CELoss):
@@ -68,7 +69,7 @@ class DistillationGTCELoss(CELoss):
     def forward(self, predicts, batch):
         loss_dict = dict()
-        for idx, name in enumerate(self.model_names):
+        for _, name in enumerate(self.model_names):
             out = predicts[name]
             if self.key is not None:
                 out = out[self.key]
@@ -84,7 +85,7 @@ class DistillationDMLLoss(DMLLoss):
     def __init__(self,
                  model_name_pairs=[],
-                 act=None,
+                 act="softmax",
                  key=None,
                  name="loss_dml"):
         super().__init__(act=act)
@@ -125,7 +126,7 @@ class DistillationDistanceLoss(DistanceLoss):
         assert isinstance(model_name_pairs, list)
         self.key = key
         self.model_name_pairs = model_name_pairs
-        self.name = name + "_l2"
+        self.name = name + mode

     def forward(self, predicts, batch):
         loss_dict = dict()
@@ -139,3 +140,35 @@ class DistillationDistanceLoss(DistanceLoss):
             for key in loss:
                 loss_dict["{}_{}_{}".format(self.name, key, idx)] = loss[key]
         return loss_dict
+
+
+class DistillationRKDLoss(nn.Layer):
+    def __init__(self,
+                 target_size=None,
+                 model_name_pairs=(["Student", "Teacher"], ),
+                 student_keepkeys=[],
+                 teacher_keepkeys=[]):
+        super().__init__()
+        self.student_keepkeys = student_keepkeys
+        self.teacher_keepkeys = teacher_keepkeys
+        self.model_name_pairs = model_name_pairs
+        assert len(self.student_keepkeys) == len(self.teacher_keepkeys)
+
+        self.rkd_angle_loss = RKdAngle(target_size=target_size)
+        self.rkd_dist_loss = RkdDistance(target_size=target_size)
+
+    def __call__(self, predicts, batch):
+        loss_dict = {}
+        for m1, m2 in self.model_name_pairs:
+            for idx, (student_name, teacher_name) in enumerate(
+                    zip(self.student_keepkeys, self.teacher_keepkeys)):
+                student_out = predicts[m1][student_name]
+                teacher_out = predicts[m2][teacher_name]
+
+                loss_dict[f"loss_angle_{idx}_{m1}_{m2}"] = self.rkd_angle_loss(
+                    student_out, teacher_out)
+                loss_dict[f"loss_dist_{idx}_{m1}_{m2}"] = self.rkd_dist_loss(
+                    student_out, teacher_out)
+
+        return loss_dict
```
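A usage sketch for the new `DistillationRKDLoss` wrapper, assuming a `predicts` dict shaped like `DistillationModel`'s output; the tensors and keep-keys below are made up for illustration:

```python
# Hypothetical usage of DistillationRKDLoss on backbone feature maps.
import paddle
from ppcls.loss.distillationloss import DistillationRKDLoss

loss_func = DistillationRKDLoss(
    target_size=1,                             # pool NxCxHxW maps to NxCx1x1 first
    model_name_pairs=[["Student", "Teacher"]],
    student_keepkeys=["backbone"],
    teacher_keepkeys=["backbone"])

predicts = {
    "Student": {"backbone": paddle.rand([4, 64, 7, 7])},
    "Teacher": {"backbone": paddle.rand([4, 128, 7, 7])},
}
loss_dict = loss_func(predicts, batch=None)
print(loss_dict.keys())  # loss_angle_0_Student_Teacher, loss_dist_0_Student_Teacher
```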
ppcls/loss/dmlloss.py

```diff
@@ -22,7 +22,7 @@ class DMLLoss(nn.Layer):
     DMLLoss
     """

-    def __init__(self, act="softmax"):
+    def __init__(self, act="softmax", eps=1e-12):
         super().__init__()
         if act is not None:
             assert act in ["softmax", "sigmoid"]
@@ -32,15 +32,19 @@ class DMLLoss(nn.Layer):
             self.act = nn.Sigmoid()
         else:
             self.act = None
+        self.eps = eps

-    def forward(self, out1, out2):
-        if self.act is not None:
-            out1 = self.act(out1)
-            out2 = self.act(out2)
+    def _kldiv(self, x, target):
+        class_num = x.shape[-1]
+        cost = target * paddle.log(
+            (target + self.eps) / (x + self.eps)) * class_num
+        return cost

-        log_out1 = paddle.log(out1)
-        log_out2 = paddle.log(out2)
-        loss = (F.kl_div(
-            log_out1, out2, reduction='batchmean') + F.kl_div(
-                log_out2, out1, reduction='batchmean')) / 2.0
+    def forward(self, x, target):
+        if self.act is not None:
+            x = F.softmax(x)
+            target = F.softmax(target)
+        loss = self._kldiv(x, target) + self._kldiv(target, x)
+        loss = loss / 2
+        loss = paddle.mean(loss)
         return {"DMLLoss": loss}
```
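The rewritten `DMLLoss` now computes a symmetric KL divergence explicitly through `_kldiv` (with an `eps` term for numerical stability) instead of `F.kl_div`, and averages the two directions. A quick usage sketch on random logits:

```python
# Quick sketch: DMLLoss on two random logit tensors. With act="softmax" both
# inputs are converted to probabilities before the symmetric KL is computed.
import paddle
from ppcls.loss.dmlloss import DMLLoss

dml = DMLLoss(act="softmax")
student_logits = paddle.rand([8, 100])
teacher_logits = paddle.rand([8, 100])
out = dml(student_logits, teacher_logits)
print(out["DMLLoss"])
```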
ppcls/loss/rkdloss.py (new file, mode 100644)

```python
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.nn as nn
import paddle.nn.functional as F


def pdist(e, squared=False, eps=1e-12):
    e_square = e.pow(2).sum(axis=1)
    prod = paddle.mm(e, e.t())
    res = (e_square.unsqueeze(1) + e_square.unsqueeze(0) - 2 * prod).clip(
        min=eps)

    if not squared:
        res = res.sqrt()
    return res


class RKdAngle(nn.Layer):
    # reference: https://github.com/lenscloth/RKD/blob/master/metric/loss.py
    def __init__(self, target_size=None):
        super().__init__()
        if target_size is not None:
            self.avgpool = paddle.nn.AdaptiveAvgPool2D(target_size)
        else:
            self.avgpool = None

    def forward(self, student, teacher):
        # GAP to reduce memory
        if self.avgpool is not None:
            # NxC1xH1xW1 -> NxC1x1x1
            student = self.avgpool(student)
            # NxC2xH2xW2 -> NxC2x1x1
            teacher = self.avgpool(teacher)

        # reshape for feature map distillation
        bs = student.shape[0]
        student = student.reshape([bs, -1])
        teacher = teacher.reshape([bs, -1])

        td = (teacher.unsqueeze(0) - teacher.unsqueeze(1))
        norm_td = F.normalize(td, p=2, axis=2)
        t_angle = paddle.bmm(norm_td, norm_td.transpose([0, 2, 1])).reshape(
            [-1, 1])

        sd = (student.unsqueeze(0) - student.unsqueeze(1))
        norm_sd = F.normalize(sd, p=2, axis=2)
        s_angle = paddle.bmm(norm_sd, norm_sd.transpose([0, 2, 1])).reshape(
            [-1, 1])
        loss = F.smooth_l1_loss(s_angle, t_angle, reduction='mean')
        return loss


class RkdDistance(nn.Layer):
    # reference: https://github.com/lenscloth/RKD/blob/master/metric/loss.py
    def __init__(self, eps=1e-12, target_size=1):
        super().__init__()
        self.eps = eps
        if target_size is not None:
            self.avgpool = paddle.nn.AdaptiveAvgPool2D(target_size)
        else:
            self.avgpool = None

    def forward(self, student, teacher):
        # GAP to reduce memory
        if self.avgpool is not None:
            # NxC1xH1xW1 -> NxC1x1x1
            student = self.avgpool(student)
            # NxC2xH2xW2 -> NxC2x1x1
            teacher = self.avgpool(teacher)

        bs = student.shape[0]
        student = student.reshape([bs, -1])
        teacher = teacher.reshape([bs, -1])

        t_d = pdist(teacher, squared=False)
        mean_td = t_d.mean()
        t_d = t_d / (mean_td + self.eps)

        d = pdist(student, squared=False)
        mean_d = d.mean()
        d = d / (mean_d + self.eps)

        loss = F.smooth_l1_loss(d, t_d, reduction="mean")
        return loss
```
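For reference, the two RKD building blocks can also be exercised directly; a small sketch with arbitrary shapes (the inputs here are random tensors, not real features):

```python
# Small sketch: pdist builds the pairwise Euclidean distance matrix that
# RkdDistance normalizes by its mean before the smooth-L1 comparison.
import paddle
from ppcls.loss.rkdloss import pdist, RkdDistance

emb = paddle.rand([4, 16])
print(pdist(emb).shape)  # [4, 4] pairwise distances

dist_loss = RkdDistance(target_size=None)  # skip pooling for already-flat features
student = paddle.rand([4, 16])
teacher = paddle.rand([4, 16])
print(float(dist_loss(student, teacher)))
```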