PaddlePaddle / PaddleClas
Commit c5884bb2: add skd

Authored by user3984 on Nov 08, 2022; committed by littletomatodonkey on Nov 11, 2022.
Parent: 221cbe47
Showing 5 changed files with 323 additions and 0 deletions:
- docs/zh_CN/training/advanced/knowledge_distillation.md (+67, -0)
- ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_skd.yaml (+151, -0)
- ppcls/loss/__init__.py (+1, -0)
- ppcls/loss/distillationloss.py (+32, -0)
- ppcls/loss/skdloss.py (+72, -0)
docs/zh_CN/training/advanced/knowledge_distillation.md
@@ -17,6 +17,7 @@

- [1.2.6 DIST](#1.2.6)
- [1.2.7 MGD](#1.2.7)
- [1.2.8 WSL](#1.2.8)
- [1.2.9 SKD](#1.2.9)
- [2. Usage](#2)
- [2.1 Environment Setup](#2.1)
- [2.2 Data Preparation](#2.2)
@@ -654,6 +655,72 @@ Loss:

<a name='1.2.9'></a>

#### 1.2.9 SKD

##### 1.2.9.1 SKD Algorithm Introduction

Paper information:

> [Reducing the Teacher-Student Gap via Spherical Knowledge Distillation](https://arxiv.org/abs/2010.07485)
>
> Jia Guo, Minghao Chen, Yao Hu, Chen Zhu, Xiaofei He, Deng Cai
>
> 2022, under review

Distilling a student with a larger, more accurate teacher often lowers the student's accuracy instead of improving it. SKD (Spherical Knowledge Distillation) explicitly removes the confidence gap between teacher and student, alleviating the capacity-gap problem between the two models. On the task of distilling ResNet18 on ImageNet1k, SKD clearly surpasses the previous state of the art.

On the public ImageNet1k dataset, the results are as follows.

| Strategy | Backbone | Config file | Top-1 acc | Download |
| --- | --- | --- | --- | --- |
| baseline | ResNet18 | [ResNet18.yaml](../../../../ppcls/configs/ImageNet/ResNet/ResNet18.yaml) | 70.8% | - |
| SKD | ResNet18 | [resnet34_distill_resnet18_skd.yaml](../../../../ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_skd.yaml) | 72.84% (**+2.04%**) | - |
##### 1.2.9.2 SKD Configuration

The SKD configuration is shown below. In the `Arch` field, both the student and the teacher model must be defined; the teacher's parameters are frozen and its pretrained weights must be loaded. In the `Loss` field, `DistillationSKDLoss` (the SKD loss between student and teacher) must be defined as the training loss.
```yaml
# model architecture
Arch:
  name: "DistillationModel"
  # if not null, its length should be the same as models
  pretrained_list:
  # if not null, its length should be the same as models
  freeze_params_list:
  - True
  - False
  models:
    - Teacher:
        name: ResNet34
        pretrained: True
    - Student:
        name: ResNet18
        pretrained: False
  infer_model_name: "Student"

# loss function config for training/eval process
Loss:
  Train:
    - DistillationSKDLoss:
        weight: 1.0
        model_name_pairs: [["Student", "Teacher"]]
        temperature: 1.0
        multiplier: 2.0
        alpha: 0.9
  Eval:
    - CELoss:
        weight: 1.0
```
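For reference, the `DistillationSKDLoss` entry above maps directly onto the loss class this commit adds in `ppcls/loss/distillationloss.py`. A minimal construction sketch, assuming the package is importable as `ppcls.loss` and with the hyper-parameters copied from the `Loss.Train` entry above (the `weight` key is handled by the training framework, not by the class constructor):

```python
from ppcls.loss import DistillationSKDLoss

# Same hyper-parameters as the Loss.Train entry in the YAML above.
skd_loss = DistillationSKDLoss(
    model_name_pairs=[["Student", "Teacher"]],
    temperature=1.0,
    multiplier=2.0,
    alpha=0.9)
```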
<a name="2"></a>

## 2. Model Training, Evaluation, and Prediction
ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_skd.yaml (new file, mode 100644)
```yaml
# global configs
Global:
  checkpoints: null
  pretrained_model: null
  output_dir: "./output/"
  device: "gpu"
  save_interval: 1
  eval_during_train: True
  eval_interval: 1
  epochs: 100
  print_batch_step: 10
  use_visualdl: False
  # used for static mode and model export
  image_shape: [3, 224, 224]
  save_inference_dir: "./inference"

# model architecture
Arch:
  name: "DistillationModel"
  # if not null, its length should be the same as models
  pretrained_list:
  # if not null, its length should be the same as models
  freeze_params_list:
  - True
  - False
  models:
    - Teacher:
        name: ResNet34
        pretrained: True
    - Student:
        name: ResNet18
        pretrained: False
  infer_model_name: "Student"

# loss function config for training/eval process
Loss:
  Train:
    - DistillationSKDLoss:
        weight: 1.0
        model_name_pairs: [["Student", "Teacher"]]
        temperature: 1.0
        multiplier: 2.0
        alpha: 0.9
  Eval:
    - CELoss:
        weight: 1.0

Optimizer:
  name: Momentum
  momentum: 0.9
  weight_decay: 1e-4
  lr:
    name: MultiStepDecay
    learning_rate: 0.2
    milestones: [30, 60, 90]
    step_each_epoch: 1
    gamma: 0.1

# data loader for train and eval
DataLoader:
  Train:
    dataset:
      name: ImageNetDataset
      image_root: "./dataset/ILSVRC2012/"
      cls_label_path: "./dataset/ILSVRC2012/train_list.txt"
      transform_ops:
        - DecodeImage:
            to_rgb: True
            channel_first: False
        - RandCropImage:
            size: 224
        - RandFlipImage:
            flip_code: 1
        - NormalizeImage:
            scale: 0.00392157
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
    sampler:
      name: DistributedBatchSampler
      batch_size: 128
      drop_last: False
      shuffle: True
    loader:
      num_workers: 8
      use_shared_memory: True

  Eval:
    dataset:
      name: ImageNetDataset
      image_root: "./dataset/ILSVRC2012/"
      cls_label_path: "./dataset/ILSVRC2012/val_list.txt"
      transform_ops:
        - DecodeImage:
            to_rgb: True
            channel_first: False
        - ResizeImage:
            resize_short: 256
        - CropImage:
            size: 224
        - NormalizeImage:
            scale: 0.00392157
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
    sampler:
      name: DistributedBatchSampler
      batch_size: 64
      drop_last: False
      shuffle: False
    loader:
      num_workers: 4
      use_shared_memory: True

Infer:
  infer_imgs: "docs/images/inference_deployment/whl_demo.jpg"
  batch_size: 10
  transforms:
    - DecodeImage:
        to_rgb: True
        channel_first: False
    - ResizeImage:
        resize_short: 256
    - CropImage:
        size: 224
    - NormalizeImage:
        scale: 1.0/255.0
        mean: [0.485, 0.456, 0.406]
        std: [0.229, 0.224, 0.225]
        order: ''
    - ToCHWImage:
  PostProcess:
    name: Topk
    topk: 5
    class_id_map_file: "ppcls/utils/imagenet1k_label_list.txt"

Metric:
  Train:
    - DistillationTopkAcc:
        model_key: "Student"
        topk: [1, 5]
  Eval:
    - DistillationTopkAcc:
        model_key: "Student"
        topk: [1, 5]
```
ppcls/loss/__init__.py

@@ -29,6 +29,7 @@ from .distillationloss import DistillationRKDLoss

```python
from .distillationloss import DistillationKLDivLoss
from .distillationloss import DistillationDKDLoss
from .distillationloss import DistillationWSLLoss
from .distillationloss import DistillationSKDLoss
from .distillationloss import DistillationMultiLabelLoss
from .distillationloss import DistillationDISTLoss
from .distillationloss import DistillationPairLoss
```
ppcls/loss/distillationloss.py

@@ -26,6 +26,7 @@ from .wslloss import WSLLoss

```python
from .dist_loss import DISTLoss
from .multilabelloss import MultiLabelLoss
from .mgd_loss import MGDLoss
from .skdloss import SKDLoss


class DistillationCELoss(CELoss):
```

@@ -291,6 +292,37 @@ class DistillationWSLLoss(WSLLoss):

```python
        return loss_dict


class DistillationSKDLoss(SKDLoss):
    """
    DistillationSKDLoss
    """

    def __init__(self,
                 model_name_pairs=[],
                 key=None,
                 temperature=1.0,
                 multiplier=2.0,
                 alpha=0.9,
                 use_target_as_gt=False,
                 name="skd_loss"):
        super().__init__(temperature, multiplier, alpha, use_target_as_gt)
        self.model_name_pairs = model_name_pairs
        self.key = key
        self.name = name

    def forward(self, predicts, batch):
        loss_dict = dict()
        for idx, pair in enumerate(self.model_name_pairs):
            out1 = predicts[pair[0]]
            out2 = predicts[pair[1]]
            if self.key is not None:
                out1 = out1[self.key]
                out2 = out2[self.key]
            loss = super().forward(out1, out2, batch)
            loss_dict[f"{self.name}_{pair[0]}_{pair[1]}"] = loss
        return loss_dict


class DistillationMultiLabelLoss(MultiLabelLoss):
    """
    DistillationMultiLabelLoss
```
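A minimal sketch of how the `DistillationSKDLoss` defined above is invoked. The batch size, class count, random logits, and labels are illustrative stand-ins; in training, `predicts` is the output dict of `DistillationModel`, keyed by the sub-model names declared in the config, and `batch` carries the ground-truth labels.

```python
import paddle
from ppcls.loss import DistillationSKDLoss

# Stand-in for the DistillationModel output: one logits tensor per sub-model.
predicts = {
    "Student": paddle.randn([8, 1000]),
    "Teacher": paddle.randn([8, 1000]),
}
labels = paddle.randint(0, 1000, shape=[8, 1])  # illustrative ground-truth labels

loss_fn = DistillationSKDLoss(
    model_name_pairs=[["Student", "Teacher"]],
    temperature=1.0, multiplier=2.0, alpha=0.9)

loss_dict = loss_fn(predicts, labels)
# One entry per (student, teacher) pair, e.g. {"skd_loss_Student_Teacher": Tensor}
print(loss_dict)
```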
ppcls/loss/skdloss.py (new file, mode 100644)
```python
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.nn as nn
import paddle.nn.functional as F


class SKDLoss(nn.Layer):
    """
    Spherical Knowledge Distillation
    paper: https://arxiv.org/pdf/2010.07485.pdf
    code reference: https://github.com/forjiuzhou/Spherical-Knowledge-Distillation
    """

    def __init__(self, temperature, multiplier=2.0, alpha=0.9,
                 use_target_as_gt=False):
        super().__init__()
        self.temperature = temperature
        self.multiplier = multiplier
        self.alpha = alpha
        self.use_target_as_gt = use_target_as_gt

    def forward(self, logits_student, logits_teacher, target=None):
        """Compute Spherical Knowledge Distillation loss.
        Args:
            logits_student: student's logits with shape (batch_size, num_classes)
            logits_teacher: teacher's logits with shape (batch_size, num_classes)
        """
        if target is None or self.use_target_as_gt:
            target = logits_teacher.argmax(axis=-1)

        target = F.one_hot(
            target.reshape([-1]), num_classes=logits_student[0].shape[0])

        logits_student = F.layer_norm(
            logits_student,
            logits_student.shape[1:],
            weight=None,
            bias=None,
            epsilon=1e-7) * self.multiplier
        logits_teacher = F.layer_norm(
            logits_teacher,
            logits_teacher.shape[1:],
            weight=None,
            bias=None,
            epsilon=1e-7) * self.multiplier

        kd_loss = -paddle.sum(
            F.softmax(logits_teacher / self.temperature) * F.log_softmax(
                logits_student / self.temperature),
            axis=1)
        kd_loss = paddle.mean(kd_loss) * self.temperature**2

        ce_loss = paddle.mean(-paddle.sum(
            target * F.log_softmax(logits_student), axis=1))

        return kd_loss * self.alpha + ce_loss * (1 - self.alpha)
```
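For reference, the value returned by `SKDLoss.forward` above can be written as

$$
\mathcal{L}_{\text{SKD}} = \alpha\, T^{2}\,\frac{1}{N}\sum_{i=1}^{N}\Bigl(-\sum_{c}\operatorname{softmax}\bigl(\hat{z}^{t}_{i}/T\bigr)_{c}\,\log\operatorname{softmax}\bigl(\hat{z}^{s}_{i}/T\bigr)_{c}\Bigr)
\;+\;(1-\alpha)\,\frac{1}{N}\sum_{i=1}^{N}\Bigl(-\sum_{c} y_{i,c}\,\log\operatorname{softmax}\bigl(\hat{z}^{s}_{i}\bigr)_{c}\Bigr)
$$

where $\hat{z} = \text{multiplier}\cdot\operatorname{LayerNorm}(z)$ are the spherically normalized logits, $T$ is the temperature, $N$ the batch size, and $y$ the one-hot target (the teacher's argmax when no label is passed). The first term is the temperature-scaled soft-target cross-entropy between teacher and student (a KL divergence up to the constant teacher entropy); the second is the hard-label cross-entropy on the normalized student logits.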