Commit c5884bb2 authored by user3984, committed by littletomatodonkey

add skd

Parent 221cbe47
......@@ -17,6 +17,7 @@
- [1.2.6 DIST](#1.2.6)
- [1.2.7 MGD](#1.2.7)
- [1.2.8 WSL](#1.2.8)
- [1.2.9 SKD](#1.2.9)
- [2. Usage](#2)
  - [2.1 Environment Configuration](#2.1)
  - [2.2 Data Preparation](#2.2)
......@@ -654,6 +655,72 @@ Loss:
weight: 1.0
```
<a name="1.2.9"></a>
#### 1.2.9 SKD

##### 1.2.9.1 Introduction to the SKD Algorithm

Paper information:
> [Reducing the Teacher-Student Gap via Spherical Knowledge Distillation](https://arxiv.org/abs/2010.07485)
>
> Jia Guo, Minghao Chen, Yao Hu, Chen Zhu, Xiaofei He, Deng Cai
>
> 2022, under review

When a student model is distilled from a larger, more accurate teacher, the student's accuracy often drops instead of improving. SKD (Spherical Knowledge Distillation) explicitly removes the confidence gap between teacher and student, alleviating the capacity-gap problem between the two models. On the task of distilling ResNet18 on ImageNet1k, SKD significantly outperforms the previous state of the art.
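Concretely, the loss implemented in this commit can be sketched as follows (this is a restatement of the code added below, not a formula quoted from the paper): both logit vectors are layer-normalized, which removes the confidence (magnitude) gap and places them on a common scale set by the multiplier $m$, before the usual KD and CE terms are combined:

$$
\hat{z}_s = m \cdot \mathrm{LN}(z_s), \qquad \hat{z}_t = m \cdot \mathrm{LN}(z_t)
$$

$$
\mathcal{L}_{\mathrm{SKD}} = \alpha \, T^{2} \, \mathrm{H}\big(\sigma(\hat{z}_t / T),\ \sigma(\hat{z}_s / T)\big) + (1 - \alpha) \, \mathrm{H}\big(y,\ \sigma(\hat{z}_s)\big)
$$

where $\mathrm{H}(p, q) = -\sum_j p_j \log q_j$ is cross-entropy averaged over the batch, $\sigma$ is the softmax, $T$ the temperature, and $y$ the one-hot label (the teacher's argmax prediction if no label is provided).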
Results on the public ImageNet1k dataset are shown below.

| Strategy | Backbone | Config file | Top-1 acc | Download link |
| --- | --- | --- | --- | --- |
| baseline | ResNet18 | [ResNet18.yaml](../../../../ppcls/configs/ImageNet/ResNet/ResNet18.yaml) | 70.8% | - |
| SKD | ResNet18 | [resnet34_distill_resnet18_skd.yaml](../../../../ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_skd.yaml) | 72.84% (**+2.04%**) | - |
##### 1.2.9.2 SKD Configuration

The SKD configuration is shown below. In the `Arch` field, both the student and the teacher model need to be defined; the teacher's parameters are frozen, and its pretrained weights need to be loaded. In the `Loss` field, `DistillationSKDLoss` (the SKD loss between the student and the teacher) needs to be defined as the training loss.
```yaml
# model architecture
Arch:
name: "DistillationModel"
  # if not null, its length should be the same as models
pretrained_list:
  # if not null, its length should be the same as models
freeze_params_list:
- True
- False
models:
- Teacher:
name: ResNet34
pretrained: True
- Student:
name: ResNet18
pretrained: False
infer_model_name: "Student"
# loss function config for training/eval process
Loss:
Train:
- DistillationSKDLoss:
weight: 1.0
model_name_pairs: [["Student", "Teacher"]]
temperature: 1.0
multiplier: 2.0
alpha: 0.9
Eval:
- CELoss:
weight: 1.0
```
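As a quick illustration (a minimal sketch, assuming PaddleClas is on the import path; only constructor arguments that appear in the config above are used), the `DistillationSKDLoss` defined by this config consumes the dict of model outputs produced by `DistillationModel` and returns one loss entry per (student, teacher) pair:

```python
import paddle
from ppcls.loss import DistillationSKDLoss  # exported by this commit

loss_fn = DistillationSKDLoss(
    model_name_pairs=[["Student", "Teacher"]],
    temperature=1.0,
    multiplier=2.0,
    alpha=0.9)

# Fake logits standing in for the outputs of DistillationModel.
predicts = {
    "Student": paddle.randn([8, 1000]),
    "Teacher": paddle.randn([8, 1000]),
}
labels = paddle.randint(0, 1000, [8])

loss_dict = loss_fn(predicts, labels)
# One entry per pair, e.g. {"skd_loss_Student_Teacher": Tensor(...)}
```

Note that the `weight: 1.0` in the config is applied by the loss-combination wrapper PaddleClas builds from the `Loss` section, not by `DistillationSKDLoss` itself.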
<a name="2"></a>
## 2. Model Training, Evaluation, and Prediction
......
# global configs
Global:
checkpoints: null
pretrained_model: null
output_dir: "./output/"
device: "gpu"
save_interval: 1
eval_during_train: True
eval_interval: 1
epochs: 100
print_batch_step: 10
use_visualdl: False
# used for static mode and model export
image_shape: [3, 224, 224]
save_inference_dir: "./inference"
# model architecture
Arch:
name: "DistillationModel"
  # if not null, its length should be the same as models
pretrained_list:
  # if not null, its length should be the same as models
freeze_params_list:
- True
- False
models:
- Teacher:
name: ResNet34
pretrained: True
- Student:
name: ResNet18
pretrained: False
infer_model_name: "Student"
# loss function config for training/eval process
Loss:
Train:
- DistillationSKDLoss:
weight: 1.0
model_name_pairs: [["Student", "Teacher"]]
temperature: 1.0
multiplier: 2.0
alpha: 0.9
Eval:
- CELoss:
weight: 1.0
Optimizer:
name: Momentum
momentum: 0.9
weight_decay: 1e-4
lr:
name: MultiStepDecay
learning_rate: 0.2
milestones: [30, 60, 90]
step_each_epoch: 1
gamma: 0.1
# data loader for train and eval
DataLoader:
Train:
dataset:
name: ImageNetDataset
image_root: "./dataset/ILSVRC2012/"
cls_label_path: "./dataset/ILSVRC2012/train_list.txt"
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- RandCropImage:
size: 224
- RandFlipImage:
flip_code: 1
- NormalizeImage:
scale: 0.00392157
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 128
drop_last: False
shuffle: True
loader:
num_workers: 8
use_shared_memory: True
Eval:
dataset:
name: ImageNetDataset
image_root: "./dataset/ILSVRC2012/"
cls_label_path: "./dataset/ILSVRC2012/val_list.txt"
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
resize_short: 256
- CropImage:
size: 224
- NormalizeImage:
scale: 0.00392157
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 64
drop_last: False
shuffle: False
loader:
num_workers: 4
use_shared_memory: True
Infer:
infer_imgs: "docs/images/inference_deployment/whl_demo.jpg"
batch_size: 10
transforms:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
resize_short: 256
- CropImage:
size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- ToCHWImage:
PostProcess:
name: Topk
topk: 5
class_id_map_file: "ppcls/utils/imagenet1k_label_list.txt"
Metric:
Train:
- DistillationTopkAcc:
model_key: "Student"
topk: [1, 5]
Eval:
- DistillationTopkAcc:
model_key: "Student"
topk: [1, 5]
......@@ -29,6 +29,7 @@ from .distillationloss import DistillationRKDLoss
from .distillationloss import DistillationKLDivLoss
from .distillationloss import DistillationDKDLoss
from .distillationloss import DistillationWSLLoss
from .distillationloss import DistillationSKDLoss
from .distillationloss import DistillationMultiLabelLoss
from .distillationloss import DistillationDISTLoss
from .distillationloss import DistillationPairLoss
......
......@@ -26,6 +26,7 @@ from .wslloss import WSLLoss
from .dist_loss import DISTLoss
from .multilabelloss import MultiLabelLoss
from .mgd_loss import MGDLoss
from .skdloss import SKDLoss
class DistillationCELoss(CELoss):
......@@ -291,6 +292,37 @@ class DistillationWSLLoss(WSLLoss):
return loss_dict
class DistillationSKDLoss(SKDLoss):
    """
    DistillationSKDLoss: wraps SKDLoss for the distillation framework.
    For each (student, teacher) name pair it reads the corresponding
    logits from the dict of model outputs and returns a dict with one
    loss entry per pair.
    """
def __init__(self,
model_name_pairs=[],
key=None,
temperature=1.0,
multiplier=2.0,
alpha=0.9,
use_target_as_gt=False,
name="skd_loss"):
super().__init__(temperature, multiplier, alpha, use_target_as_gt)
self.model_name_pairs = model_name_pairs
self.key = key
self.name = name
def forward(self, predicts, batch):
loss_dict = dict()
for idx, pair in enumerate(self.model_name_pairs):
out1 = predicts[pair[0]]
out2 = predicts[pair[1]]
if self.key is not None:
out1 = out1[self.key]
out2 = out2[self.key]
loss = super().forward(out1, out2, batch)
loss_dict[f"{self.name}_{pair[0]}_{pair[1]}"] = loss
return loss_dict
class DistillationMultiLabelLoss(MultiLabelLoss):
"""
DistillationMultiLabelLoss
......
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
class SKDLoss(nn.Layer):
"""
Spherical Knowledge Distillation
paper: https://arxiv.org/pdf/2010.07485.pdf
code reference: https://github.com/forjiuzhou/Spherical-Knowledge-Distillation
"""
def __init__(self,
temperature,
multiplier=2.0,
alpha=0.9,
use_target_as_gt=False):
super().__init__()
self.temperature = temperature
self.multiplier = multiplier
self.alpha = alpha
self.use_target_as_gt = use_target_as_gt
    def forward(self, logits_student, logits_teacher, target=None):
        """Compute the Spherical Knowledge Distillation loss.

        Args:
            logits_student: student logits with shape (batch_size, num_classes)
            logits_teacher: teacher logits with shape (batch_size, num_classes)
            target: optional ground-truth labels; if None, the teacher's
                argmax prediction is used as the label
        """
        if target is None or self.use_target_as_gt:
            target = logits_teacher.argmax(axis=-1)
        target = F.one_hot(
            target.reshape([-1]), num_classes=logits_student[0].shape[0])

        # Project both sets of logits onto a common sphere: LayerNorm removes
        # the mean and rescales to unit variance, eliminating the confidence
        # (magnitude) gap between teacher and student; the multiplier then
        # restores a shared scale.
        logits_student = F.layer_norm(
            logits_student,
            logits_student.shape[1:],
            weight=None,
            bias=None,
            epsilon=1e-7) * self.multiplier
        logits_teacher = F.layer_norm(
            logits_teacher,
            logits_teacher.shape[1:],
            weight=None,
            bias=None,
            epsilon=1e-7) * self.multiplier

        # Standard KD term: cross-entropy between the softened distributions,
        # scaled by T^2 as in Hinton et al.
        kd_loss = -paddle.sum(F.softmax(logits_teacher / self.temperature) *
                              F.log_softmax(logits_student / self.temperature),
                              axis=1)
        kd_loss = paddle.mean(kd_loss) * self.temperature**2

        # Cross-entropy of the normalized student logits against the labels.
        ce_loss = paddle.mean(-paddle.sum(
            target * F.log_softmax(logits_student), axis=1))

        return kd_loss * self.alpha + ce_loss * (1 - self.alpha)
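# Minimal self-check sketch (illustrative only, not part of the original
# file): runs SKDLoss on random logits and labels and prints a scalar loss.
if __name__ == "__main__":
    logits_s = paddle.randn([4, 10])
    logits_t = paddle.randn([4, 10])
    labels = paddle.randint(0, 10, [4])
    criterion = SKDLoss(temperature=1.0, multiplier=2.0, alpha=0.9)
    print(criterion(logits_s, logits_t, labels))  # scalar Tensor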