PaddlePaddle / PaddleClas
Commit 098bc1d8
Authored on May 10, 2022 by Yang Nie

Merge remote-tracking branch 'upstream/develop' into ConvNeXt

Parents: 7fa948f8, 44112687
Showing 8 changed files with 366 additions and 4 deletions
Files changed:

  ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_dkd.yaml  +155 -0
  ppcls/configs/Pedestrian/strong_baseline_baseline.yaml  +8 -1
  ppcls/loss/__init__.py  +1 -0
  ppcls/loss/distillationloss.py  +31 -0
  ppcls/loss/dkdloss.py  +61 -0
  ppcls/utils/save_load.py  +2 -3
  test_tipc/config/Distillation/resnet34_distill_resnet18_dkd_train_amp_infer_python.txt  +54 -0
  test_tipc/config/Distillation/resnet34_distill_resnet18_dkd_train_infer_python.txt  +54 -0
ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_dkd.yaml
new file mode 100644
# global configs
Global:
  checkpoints: null
  pretrained_model: null
  output_dir: "./output/"
  device: "gpu"
  save_interval: 1
  eval_during_train: True
  eval_interval: 1
  epochs: 100
  print_batch_step: 10
  use_visualdl: False
  # used for static mode and model export
  image_shape: [3, 224, 224]
  save_inference_dir: "./inference"

# model architecture
Arch:
  name: "DistillationModel"
  # if not null, its length should be the same as models
  pretrained_list:
  # if not null, its length should be the same as models
  freeze_params_list:
    - True
    - False
  models:
    - Teacher:
        name: ResNet34
        pretrained: True
    - Student:
        name: ResNet18
        pretrained: False

  infer_model_name: "Student"

# loss function config for training/eval process
Loss:
  Train:
    - DistillationGTCELoss:
        weight: 1.0
        model_names: ["Student"]
    - DistillationDKDLoss:
        weight: 1.0
        model_name_pairs: [["Student", "Teacher"]]
        temperature: 1
        alpha: 1.0
        beta: 1.0
  Eval:
    - CELoss:
        weight: 1.0

Optimizer:
  name: Momentum
  momentum: 0.9
  weight_decay: 1e-4
  lr:
    name: MultiStepDecay
    learning_rate: 0.2
    milestones: [30, 60, 90]
    step_each_epoch: 1
    gamma: 0.1

# data loader for train and eval
DataLoader:
  Train:
    dataset:
      name: ImageNetDataset
      image_root: "./dataset/ILSVRC2012/"
      cls_label_path: "./dataset/ILSVRC2012/train_list.txt"
      transform_ops:
        - DecodeImage:
            to_rgb: True
            channel_first: False
        - RandCropImage:
            size: 224
        - RandFlipImage:
            flip_code: 1
        - NormalizeImage:
            scale: 0.00392157
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
    sampler:
      name: DistributedBatchSampler
      batch_size: 128
      drop_last: False
      shuffle: True
    loader:
      num_workers: 8
      use_shared_memory: True

  Eval:
    dataset:
      name: ImageNetDataset
      image_root: "./dataset/ILSVRC2012/"
      cls_label_path: "./dataset/ILSVRC2012/val_list.txt"
      transform_ops:
        - DecodeImage:
            to_rgb: True
            channel_first: False
        - ResizeImage:
            resize_short: 256
        - CropImage:
            size: 224
        - NormalizeImage:
            scale: 0.00392157
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
    sampler:
      name: DistributedBatchSampler
      batch_size: 64
      drop_last: False
      shuffle: False
    loader:
      num_workers: 4
      use_shared_memory: True

Infer:
  infer_imgs: "docs/images/inference_deployment/whl_demo.jpg"
  batch_size: 10
  transforms:
    - DecodeImage:
        to_rgb: True
        channel_first: False
    - ResizeImage:
        resize_short: 256
    - CropImage:
        size: 224
    - NormalizeImage:
        scale: 1.0/255.0
        mean: [0.485, 0.456, 0.406]
        std: [0.229, 0.224, 0.225]
        order: ''
    - ToCHWImage:
  PostProcess:
    name: DistillationPostProcess
    func: Topk
    topk: 5
    class_id_map_file: "ppcls/utils/imagenet1k_label_list.txt"

Metric:
  Train:
    - DistillationTopkAcc:
        model_key: "Student"
        topk: [1, 5]
  Eval:
    - DistillationTopkAcc:
        model_key: "Student"
        topk: [1, 5]
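The config wires a frozen ResNet34 teacher to a trainable ResNet18 student and attaches the new DKD loss between them; the TIPC files later in this commit launch it with tools/train.py -c ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_dkd.yaml. As a hypothetical sanity check (not part of the commit; it only needs PyYAML and assumes you run it from the repository root, no PaddleClas API is used), the wiring can be inspected like this:

# Hypothetical check: print the teacher/student setup and the DKD loss pairing.
import yaml

with open("ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_dkd.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg["Arch"]["freeze_params_list"])            # [True, False] -> teacher frozen
print([list(m)[0] for m in cfg["Arch"]["models"]])  # ['Teacher', 'Student']
print(cfg["Loss"]["Train"][1]["DistillationDKDLoss"]["model_name_pairs"])
# [['Student', 'Teacher']]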
ppcls/configs/Pedestrian/strong_baseline_baseline.yaml
@@ -31,6 +31,14 @@ Arch:
     name: "FC"
     embedding_size: 2048
     class_num: 751
+    weight_attr:
+      initializer:
+        name: KaimingUniform
+        fan_in: 12288 # 6*embedding_size
+    bias_attr:
+      initializer:
+        name: KaimingUniform
+        fan_in: 12288 # 6*embedding_size

 # loss function config for training/eval process
 Loss:
@@ -52,7 +60,6 @@ Optimizer:
     name: Piecewise
     decay_epochs: [40, 70]
     values: [0.00035, 0.000035, 0.0000035]
-    warmup_epoch: 10
     by_epoch: True
     last_epoch: 0
   regularizer:
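The added block pins the FC head's weight and bias initialization to KaimingUniform with an explicit fan_in of 12288 (6 * embedding_size) instead of the fan-in inferred from the layer shape, and drops the warmup epochs. A rough stand-alone equivalent in plain Paddle, under the assumption that PaddleClas forwards these initializer entries to paddle.nn.initializer through ParamAttr (the actual mapping happens inside the PaddleClas head builder):

# Illustrative sketch only, not the PaddleClas code path.
import paddle
import paddle.nn as nn

embedding_size, class_num = 2048, 751
init = nn.initializer.KaimingUniform(fan_in=6 * embedding_size)  # fan_in = 12288
fc = nn.Linear(
    embedding_size,
    class_num,
    weight_attr=paddle.ParamAttr(initializer=init),
    bias_attr=paddle.ParamAttr(initializer=init))
print(fc.weight.shape, fc.bias.shape)  # [2048, 751] [751]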
ppcls/loss/__init__.py
@@ -23,6 +23,7 @@ from .distillationloss import DistillationDMLLoss
 from .distillationloss import DistillationDistanceLoss
 from .distillationloss import DistillationRKDLoss
 from .distillationloss import DistillationKLDivLoss
+from .distillationloss import DistillationDKDLoss
 from .multilabelloss import MultiLabelLoss
 from .afdloss import AFDLoss
ppcls/loss/distillationloss.py
@@ -21,6 +21,7 @@ from .dmlloss import DMLLoss
 from .distanceloss import DistanceLoss
 from .rkdloss import RKdAngle, RkdDistance
 from .kldivloss import KLDivLoss
+from .dkdloss import DKDLoss


 class DistillationCELoss(CELoss):
@@ -204,3 +205,33 @@ class DistillationKLDivLoss(KLDivLoss):
         for key in loss:
             loss_dict["{}_{}_{}".format(key, pair[0], pair[1])] = loss[key]
         return loss_dict
+
+
+class DistillationDKDLoss(DKDLoss):
+    """
+    DistillationDKDLoss
+    """
+
+    def __init__(self,
+                 model_name_pairs=[],
+                 key=None,
+                 temperature=1.0,
+                 alpha=1.0,
+                 beta=1.0,
+                 name="loss_dkd"):
+        super().__init__(temperature=temperature, alpha=alpha, beta=beta)
+        self.key = key
+        self.model_name_pairs = model_name_pairs
+        self.name = name
+
+    def forward(self, predicts, batch):
+        loss_dict = dict()
+        for idx, pair in enumerate(self.model_name_pairs):
+            out1 = predicts[pair[0]]
+            out2 = predicts[pair[1]]
+            if self.key is not None:
+                out1 = out1[self.key]
+                out2 = out2[self.key]
+            loss = super().forward(out1, out2, batch)
+            loss_dict[f"{self.name}_{pair[0]}_{pair[1]}"] = loss
+        return loss_dict
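DistillationDKDLoss is a thin adapter: for each (student, teacher) pair it pulls the corresponding logits out of the per-model prediction dict, optionally indexes into them with key, and reports the DKD value under a name such as loss_dkd_Student_Teacher. A minimal hypothetical smoke test (not part of the commit; assumes paddle is installed and ppcls is importable) might look like:

import paddle
from ppcls.loss.distillationloss import DistillationDKDLoss

loss_fn = DistillationDKDLoss(
    model_name_pairs=[["Student", "Teacher"]],
    temperature=1.0, alpha=1.0, beta=1.0)

predicts = {
    "Student": paddle.randn([8, 1000]),  # student logits
    "Teacher": paddle.randn([8, 1000]),  # teacher logits
}
labels = paddle.randint(0, 1000, shape=[8])  # ground-truth class ids

out = loss_fn(predicts, labels)
print(out)  # {'loss_dkd_Student_Teacher': Tensor(...)}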
ppcls/loss/dkdloss.py
new file mode 100644
import paddle
import paddle.nn as nn
import paddle.nn.functional as F


class DKDLoss(nn.Layer):
    """
    DKDLoss
    Reference: https://arxiv.org/abs/2203.08679
    Code was heavily based on https://github.com/megvii-research/mdistiller
    """

    def __init__(self, temperature=1.0, alpha=1.0, beta=1.0):
        super().__init__()
        self.temperature = temperature
        self.alpha = alpha
        self.beta = beta

    def forward(self, logits_student, logits_teacher, target):
        gt_mask = _get_gt_mask(logits_student, target)
        other_mask = 1 - gt_mask
        pred_student = F.softmax(logits_student / self.temperature, axis=1)
        pred_teacher = F.softmax(logits_teacher / self.temperature, axis=1)
        pred_student = cat_mask(pred_student, gt_mask, other_mask)
        pred_teacher = cat_mask(pred_teacher, gt_mask, other_mask)
        log_pred_student = paddle.log(pred_student)
        tckd_loss = (F.kl_div(
            log_pred_student, pred_teacher,
            reduction='sum') * (self.temperature**2) / target.shape[0])
        pred_teacher_part2 = F.softmax(
            logits_teacher / self.temperature - 1000.0 * gt_mask, axis=1)
        log_pred_student_part2 = F.log_softmax(
            logits_student / self.temperature - 1000.0 * gt_mask, axis=1)
        nckd_loss = (F.kl_div(
            log_pred_student_part2, pred_teacher_part2,
            reduction='sum') * (self.temperature**2) / target.shape[0])
        return self.alpha * tckd_loss + self.beta * nckd_loss


def _get_gt_mask(logits, target):
    target = target.reshape([-1]).unsqueeze(1)
    updates = paddle.ones_like(target)
    mask = scatter(
        paddle.zeros_like(logits), target, updates.astype('float32'))
    return mask


def cat_mask(t, mask1, mask2):
    t1 = (t * mask1).sum(axis=1, keepdim=True)
    t2 = (t * mask2).sum(axis=1, keepdim=True)
    rt = paddle.concat([t1, t2], axis=1)
    return rt


def scatter(x, index, updates):
    i, j = index.shape
    grid_x, grid_y = paddle.meshgrid(paddle.arange(i), paddle.arange(j))
    index = paddle.stack([grid_x.flatten(), index.flatten()], axis=1)
    updates_index = paddle.stack([grid_x.flatten(), grid_y.flatten()], axis=1)
    updates = paddle.gather_nd(updates, index=updates_index)
    return paddle.scatter_nd_add(x, index, updates)
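For readers new to DKD: _get_gt_mask builds a one-hot mask over the class axis, cat_mask collapses the softmax outputs into a two-bin [target, non-target] distribution used by the TCKD term, and the -1000.0 * gt_mask shift suppresses the target class before the NCKD term is computed. A small hypothetical check of those helpers (not part of the commit; assumes paddle is installed and ppcls is on the path):

import paddle
import paddle.nn.functional as F
from ppcls.loss.dkdloss import DKDLoss, _get_gt_mask, cat_mask

logits = paddle.randn([4, 10])
target = paddle.to_tensor([3, 1, 7, 0])

gt_mask = _get_gt_mask(logits, target)           # [4, 10], 1.0 at each label column
probs = F.softmax(logits, axis=1)
two_bin = cat_mask(probs, gt_mask, 1 - gt_mask)  # [4, 2], each row sums to ~1.0
print(gt_mask.sum(axis=1))                       # all ones
print(two_bin.sum(axis=1))                       # all ~1.0

print(DKDLoss()(logits, paddle.randn([4, 10]), target))  # scalar DKD loss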
ppcls/utils/save_load.py
@@ -125,9 +125,8 @@ def init_model(config,
         load_distillation_model(net, pretrained_model)
     else:  # common load
         load_dygraph_pretrain(net, path=pretrained_model)
-        logger.info(
-            logger.coloring("Finish load pretrained model from {}".format(
-                pretrained_model), "HEADER"))
+        logger.info("Finish load pretrained model from {}".format(
+            pretrained_model))


 def save_model(net,
test_tipc/config/Distillation/resnet34_distill_resnet18_dkd_train_amp_infer_python.txt
new file mode 100644
===========================train_params===========================
model_name:DistillationModel
python:python3.7
gpu_list:0|0,1
-o Global.device:gpu
-o Global.auto_cast:null
-o Global.epochs:lite_train_lite_infer=2|whole_train_whole_infer=100
-o Global.output_dir:./output/
-o DataLoader.Train.sampler.batch_size:8
-o Global.pretrained_model:null
train_model_name:latest
train_infer_img_dir:./dataset/ILSVRC2012/val
null:null
##
trainer:amp_train
amp_train:tools/train.py -c ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_dkd.yaml -o Global.seed=1234 -o DataLoader.Train.sampler.shuffle=False -o DataLoader.Train.loader.num_workers=0 -o DataLoader.Train.loader.use_shared_memory=False -o AMP.scale_loss=128 -o AMP.use_dynamic_loss_scaling=True -o AMP.level=O2
pact_train:null
fpgm_train:null
distill_train:null
null:null
null:null
##
===========================eval_params===========================
eval:tools/eval.py -c ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_dkd.yaml
null:null
##
===========================infer_params==========================
-o Global.save_inference_dir:./inference
-o Global.pretrained_model:
norm_export:tools/export_model.py -c ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_dkd.yaml
quant_export:null
fpgm_export:null
distill_export:null
kl_quant:null
export2:null
pretrained_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet18_pretrained.pdparams
infer_model:../inference/
infer_export:True
infer_quant:False
inference:python/predict_cls.py -c configs/inference_cls.yaml
-o Global.use_gpu:True|False
-o Global.enable_mkldnn:True|False
-o Global.cpu_num_threads:1|6
-o Global.batch_size:1|16
-o Global.use_tensorrt:True|False
-o Global.use_fp16:True|False
-o Global.inference_model_dir:../inference
-o Global.infer_imgs:../dataset/ILSVRC2012/val
-o Global.save_log_path:null
-o Global.benchmark:True
null:null
null:null
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,224,224]}]
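These TIPC files are plain key:value lines grouped by ===...=== section banners; the trainer key selects which command (here amp_train) the TIPC test scripts assemble and run against the new DKD config. A minimal sketch of reading the train_params out of this file (an assumption-laden helper, not the official TIPC parser; assumes the repository root as working directory):

from pathlib import Path

path = Path("test_tipc/config/Distillation/"
            "resnet34_distill_resnet18_dkd_train_amp_infer_python.txt")

params = {}
for line in path.read_text().splitlines():
    if ":" in line and not line.startswith(("=", "#")):
        key, _, value = line.partition(":")
        params[key] = value

print(params["trainer"])    # amp_train
print(params["amp_train"])  # tools/train.py -c ppcls/configs/ImageNet/Distillation/... -o AMP.level=O2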
test_tipc/config/Distillation/resnet34_distill_resnet18_dkd_train_infer_python.txt
new file mode 100644
===========================train_params===========================
model_name:DistillationModel
python:python3.7
gpu_list:0|0,1
-o Global.device:gpu
-o Global.auto_cast:null
-o Global.epochs:lite_train_lite_infer=2|whole_train_whole_infer=100
-o Global.output_dir:./output/
-o DataLoader.Train.sampler.batch_size:8
-o Global.pretrained_model:null
train_model_name:latest
train_infer_img_dir:./dataset/ILSVRC2012/val
null:null
##
trainer:norm_train
norm_train:tools/train.py -c ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_dkd.yaml -o Global.seed=1234 -o DataLoader.Train.sampler.shuffle=False -o DataLoader.Train.loader.num_workers=0 -o DataLoader.Train.loader.use_shared_memory=False
pact_train:null
fpgm_train:null
distill_train:null
null:null
null:null
##
===========================eval_params===========================
eval:tools/eval.py -c ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_dkd.yaml
null:null
##
===========================infer_params==========================
-o Global.save_inference_dir:./inference
-o Global.pretrained_model:
norm_export:tools/export_model.py -c ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_dkd.yaml
quant_export:null
fpgm_export:null
distill_export:null
kl_quant:null
export2:null
pretrained_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet18_pretrained.pdparams
infer_model:../inference/
infer_export:True
infer_quant:False
inference:python/predict_cls.py -c configs/inference_cls.yaml
-o Global.use_gpu:True|False
-o Global.enable_mkldnn:True|False
-o Global.cpu_num_threads:1|6
-o Global.batch_size:1|16
-o Global.use_tensorrt:True|False
-o Global.use_fp16:True|False
-o Global.inference_model_dir:../inference
-o Global.infer_imgs:../dataset/ILSVRC2012/val
-o Global.save_log_path:null
-o Global.benchmark:True
null:null
null:null
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,224,224]}]