From c5884bb24f63d18488d773321b3b2785d2a69579 Mon Sep 17 00:00:00 2001
From: user3984 <2287245853@qq.com>
Date: Tue, 8 Nov 2022 11:08:17 +0000
Subject: [PATCH] add skd
---
.../advanced/knowledge_distillation.md | 67 ++++++++
.../resnet34_distill_resnet18_skd.yaml | 151 ++++++++++++++++++
ppcls/loss/__init__.py | 1 +
ppcls/loss/distillationloss.py | 32 ++++
ppcls/loss/skdloss.py | 72 +++++++++
5 files changed, 323 insertions(+)
create mode 100644 ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_skd.yaml
create mode 100644 ppcls/loss/skdloss.py
diff --git a/docs/zh_CN/training/advanced/knowledge_distillation.md b/docs/zh_CN/training/advanced/knowledge_distillation.md
index 9a6b15ce..be526c97 100644
--- a/docs/zh_CN/training/advanced/knowledge_distillation.md
+++ b/docs/zh_CN/training/advanced/knowledge_distillation.md
@@ -17,6 +17,7 @@
- [1.2.6 DIST](#1.2.6)
- [1.2.7 MGD](#1.2.7)
- [1.2.8 WSL](#1.2.8)
+ - [1.2.9 SKD](#1.2.9)
- [2. 使用方法](#2)
- [2.1 环境配置](#2.1)
- [2.2 数据准备](#2.2)
@@ -654,6 +655,72 @@ Loss:
weight: 1.0
```
+
+
+#### 1.2.9 SKD
+
+##### 1.2.9.1 SKD 算法介绍
+
+论文信息:
+
+
+> [Reducing the Teacher-Student Gap via Spherical Knowledge Disitllation](https://arxiv.org/abs/2010.07485)
+>
+> Jia Guo, Minghao Chen, Yao Hu, Chen Zhu, Xiaofei He, Deng Cai
+>
+> 2022, under review
+
+使用更大、精度更高的教师模型蒸馏学生模型,学生模型的精度往往反而降低。SKD (Spherical Knowledge Disitllation) 方法显式地消除了教师与学生之间的置信度差距,缓解了教师与学生之间的容量差距问题。SKD在ImageNet1k上蒸馏ResNet18的任务上显著超越了SOTA。
+
+在ImageNet1k公开数据集上,效果如下所示。
+
+| 策略 | 骨干网络 | 配置文件 | Top-1 acc | 下载链接 |
+| --- | --- | --- | --- | --- |
+| baseline | ResNet18 | [ResNet18.yaml](../../../../ppcls/configs/ImageNet/ResNet/ResNet18.yaml) | 70.8% | - |
+| SKD | ResNet18 | [resnet34_distill_resnet18_skd.yaml](../../../../ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_skd.yaml) | 72.84%(**+2.04%**) | - |
+
+
+##### 1.2.9.2 SKD 配置
+
+SKD 配置如下所示。在模型构建Arch字段中,需要同时定义学生模型与教师模型,教师模型固定参数,且需要加载预训练模型。在损失函数Loss字段中,需要定义`DistillationSKDLoss`(学生与教师之间的SKD loss),作为训练的损失函数。
+
+
+```yaml
+# model architecture
+Arch:
+ name: "DistillationModel"
+ # if not null, its lengths should be same as models
+ pretrained_list:
+ # if not null, its lengths should be same as models
+ freeze_params_list:
+ - True
+ - False
+ models:
+ - Teacher:
+ name: ResNet34
+ pretrained: True
+
+ - Student:
+ name: ResNet18
+ pretrained: False
+
+ infer_model_name: "Student"
+
+
+# loss function config for traing/eval process
+Loss:
+ Train:
+ - DistillationSKDLoss:
+ weight: 1.0
+ model_name_pairs: [["Student", "Teacher"]]
+ temperature: 1.0
+ multiplier: 2.0
+ alpha: 0.9
+ Eval:
+ - CELoss:
+ weight: 1.0
+```
+
## 2. 模型训练、评估和预测
diff --git a/ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_skd.yaml b/ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_skd.yaml
new file mode 100644
index 00000000..798717d2
--- /dev/null
+++ b/ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_skd.yaml
@@ -0,0 +1,151 @@
+# global configs
+Global:
+ checkpoints: null
+ pretrained_model: null
+ output_dir: "./output/"
+ device: "gpu"
+ save_interval: 1
+ eval_during_train: True
+ eval_interval: 1
+ epochs: 100
+ print_batch_step: 10
+ use_visualdl: False
+ # used for static mode and model export
+ image_shape: [3, 224, 224]
+ save_inference_dir: "./inference"
+
+# model architecture
+Arch:
+ name: "DistillationModel"
+ # if not null, its lengths should be same as models
+ pretrained_list:
+ # if not null, its lengths should be same as models
+ freeze_params_list:
+ - True
+ - False
+ models:
+ - Teacher:
+ name: ResNet34
+ pretrained: True
+
+ - Student:
+ name: ResNet18
+ pretrained: False
+
+ infer_model_name: "Student"
+
+
+# loss function config for traing/eval process
+Loss:
+ Train:
+ - DistillationSKDLoss:
+ weight: 1.0
+ model_name_pairs: [["Student", "Teacher"]]
+ temperature: 1.0
+ multiplier: 2.0
+ alpha: 0.9
+ Eval:
+ - CELoss:
+ weight: 1.0
+
+
+Optimizer:
+ name: Momentum
+ momentum: 0.9
+ weight_decay: 1e-4
+ lr:
+ name: MultiStepDecay
+ learning_rate: 0.2
+ milestones: [30, 60, 90]
+ step_each_epoch: 1
+ gamma: 0.1
+
+
+# data loader for train and eval
+DataLoader:
+ Train:
+ dataset:
+ name: ImageNetDataset
+ image_root: "./dataset/ILSVRC2012/"
+ cls_label_path: "./dataset/ILSVRC2012/train_list.txt"
+ transform_ops:
+ - DecodeImage:
+ to_rgb: True
+ channel_first: False
+ - RandCropImage:
+ size: 224
+ - RandFlipImage:
+ flip_code: 1
+ - NormalizeImage:
+ scale: 0.00392157
+ mean: [0.485, 0.456, 0.406]
+ std: [0.229, 0.224, 0.225]
+ order: ''
+
+ sampler:
+ name: DistributedBatchSampler
+ batch_size: 128
+ drop_last: False
+ shuffle: True
+ loader:
+ num_workers: 8
+ use_shared_memory: True
+
+ Eval:
+ dataset:
+ name: ImageNetDataset
+ image_root: "./dataset/ILSVRC2012/"
+ cls_label_path: "./dataset/ILSVRC2012/val_list.txt"
+ transform_ops:
+ - DecodeImage:
+ to_rgb: True
+ channel_first: False
+ - ResizeImage:
+ resize_short: 256
+ - CropImage:
+ size: 224
+ - NormalizeImage:
+ scale: 0.00392157
+ mean: [0.485, 0.456, 0.406]
+ std: [0.229, 0.224, 0.225]
+ order: ''
+ sampler:
+ name: DistributedBatchSampler
+ batch_size: 64
+ drop_last: False
+ shuffle: False
+ loader:
+ num_workers: 4
+ use_shared_memory: True
+
+Infer:
+ infer_imgs: "docs/images/inference_deployment/whl_demo.jpg"
+ batch_size: 10
+ transforms:
+ - DecodeImage:
+ to_rgb: True
+ channel_first: False
+ - ResizeImage:
+ resize_short: 256
+ - CropImage:
+ size: 224
+ - NormalizeImage:
+ scale: 1.0/255.0
+ mean: [0.485, 0.456, 0.406]
+ std: [0.229, 0.224, 0.225]
+ order: ''
+ - ToCHWImage:
+ PostProcess:
+ name: Topk
+ topk: 5
+ class_id_map_file: "ppcls/utils/imagenet1k_label_list.txt"
+
+Metric:
+ Train:
+ - DistillationTopkAcc:
+ model_key: "Student"
+ topk: [1, 5]
+ Eval:
+ - DistillationTopkAcc:
+ model_key: "Student"
+ topk: [1, 5]
diff --git a/ppcls/loss/__init__.py b/ppcls/loss/__init__.py
index 513ede13..90e9abf0 100644
--- a/ppcls/loss/__init__.py
+++ b/ppcls/loss/__init__.py
@@ -29,6 +29,7 @@ from .distillationloss import DistillationRKDLoss
from .distillationloss import DistillationKLDivLoss
from .distillationloss import DistillationDKDLoss
from .distillationloss import DistillationWSLLoss
+from .distillationloss import DistillationSKDLoss
from .distillationloss import DistillationMultiLabelLoss
from .distillationloss import DistillationDISTLoss
from .distillationloss import DistillationPairLoss
diff --git a/ppcls/loss/distillationloss.py b/ppcls/loss/distillationloss.py
index af24f830..a5bc70d0 100644
--- a/ppcls/loss/distillationloss.py
+++ b/ppcls/loss/distillationloss.py
@@ -26,6 +26,7 @@ from .wslloss import WSLLoss
from .dist_loss import DISTLoss
from .multilabelloss import MultiLabelLoss
from .mgd_loss import MGDLoss
+from .skdloss import SKDLoss
class DistillationCELoss(CELoss):
@@ -291,6 +292,37 @@ class DistillationWSLLoss(WSLLoss):
return loss_dict
+class DistillationSKDLoss(SKDLoss):
+ """
+ DistillationSKDLoss
+ """
+
+ def __init__(self,
+ model_name_pairs=[],
+ key=None,
+ temperature=1.0,
+ multiplier=2.0,
+ alpha=0.9,
+ use_target_as_gt=False,
+ name="skd_loss"):
+ super().__init__(temperature, multiplier, alpha, use_target_as_gt)
+ self.model_name_pairs = model_name_pairs
+ self.key = key
+ self.name = name
+
+ def forward(self, predicts, batch):
+ loss_dict = dict()
+ for idx, pair in enumerate(self.model_name_pairs):
+ out1 = predicts[pair[0]]
+ out2 = predicts[pair[1]]
+ if self.key is not None:
+ out1 = out1[self.key]
+ out2 = out2[self.key]
+ loss = super().forward(out1, out2, batch)
+ loss_dict[f"{self.name}_{pair[0]}_{pair[1]}"] = loss
+ return loss_dict
+
+
class DistillationMultiLabelLoss(MultiLabelLoss):
"""
DistillationMultiLabelLoss
diff --git a/ppcls/loss/skdloss.py b/ppcls/loss/skdloss.py
new file mode 100644
index 00000000..fe8e8e14
--- /dev/null
+++ b/ppcls/loss/skdloss.py
@@ -0,0 +1,72 @@
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+
+class SKDLoss(nn.Layer):
+ """
+ Spherical Knowledge Distillation
+ paper: https://arxiv.org/pdf/2010.07485.pdf
+ code reference: https://github.com/forjiuzhou/Spherical-Knowledge-Distillation
+ """
+
+ def __init__(self,
+ temperature,
+ multiplier=2.0,
+ alpha=0.9,
+ use_target_as_gt=False):
+ super().__init__()
+ self.temperature = temperature
+ self.multiplier = multiplier
+ self.alpha = alpha
+ self.use_target_as_gt = use_target_as_gt
+
+ def forward(self, logits_student, logits_teacher, target=None):
+ """Compute Spherical Knowledge Distillation loss.
+ Args:
+ logits_student: student's logits with shape (batch_size, num_classes)
+ logits_teacher: teacher's logits with shape (batch_size, num_classes)
+ """
+ if target is None or self.use_target_as_gt:
+ target = logits_teacher.argmax(axis=-1)
+
+ target = F.one_hot(
+ target.reshape([-1]), num_classes=logits_student[0].shape[0])
+
+ logits_student = F.layer_norm(
+ logits_student,
+ logits_student.shape[1:],
+ weight=None,
+ bias=None,
+ epsilon=1e-7) * self.multiplier
+ logits_teacher = F.layer_norm(
+ logits_teacher,
+ logits_teacher.shape[1:],
+ weight=None,
+ bias=None,
+ epsilon=1e-7) * self.multiplier
+
+ kd_loss = -paddle.sum(F.softmax(logits_teacher / self.temperature) *
+ F.log_softmax(logits_student / self.temperature),
+ axis=1)
+
+ kd_loss = paddle.mean(kd_loss) * self.temperature**2
+
+ ce_loss = paddle.mean(-paddle.sum(
+ target * F.log_softmax(logits_student), axis=1))
+
+ return kd_loss * self.alpha + ce_loss * (1 - self.alpha)
--
GitLab