From c5884bb24f63d18488d773321b3b2785d2a69579 Mon Sep 17 00:00:00 2001 From: user3984 <2287245853@qq.com> Date: Tue, 8 Nov 2022 11:08:17 +0000 Subject: [PATCH] add skd --- .../advanced/knowledge_distillation.md | 67 ++++++++ .../resnet34_distill_resnet18_skd.yaml | 151 ++++++++++++++++++ ppcls/loss/__init__.py | 1 + ppcls/loss/distillationloss.py | 32 ++++ ppcls/loss/skdloss.py | 72 +++++++++ 5 files changed, 323 insertions(+) create mode 100644 ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_skd.yaml create mode 100644 ppcls/loss/skdloss.py diff --git a/docs/zh_CN/training/advanced/knowledge_distillation.md b/docs/zh_CN/training/advanced/knowledge_distillation.md index 9a6b15ce..be526c97 100644 --- a/docs/zh_CN/training/advanced/knowledge_distillation.md +++ b/docs/zh_CN/training/advanced/knowledge_distillation.md @@ -17,6 +17,7 @@ - [1.2.6 DIST](#1.2.6) - [1.2.7 MGD](#1.2.7) - [1.2.8 WSL](#1.2.8) + - [1.2.9 SKD](#1.2.9) - [2. 使用方法](#2) - [2.1 环境配置](#2.1) - [2.2 数据准备](#2.2) @@ -654,6 +655,72 @@ Loss: weight: 1.0 ``` + + +#### 1.2.9 SKD + +##### 1.2.9.1 SKD 算法介绍 + +论文信息: + + +> [Reducing the Teacher-Student Gap via Spherical Knowledge Disitllation](https://arxiv.org/abs/2010.07485) +> +> Jia Guo, Minghao Chen, Yao Hu, Chen Zhu, Xiaofei He, Deng Cai +> +> 2022, under review + +使用更大、精度更高的教师模型蒸馏学生模型,学生模型的精度往往反而降低。SKD (Spherical Knowledge Disitllation) 方法显式地消除了教师与学生之间的置信度差距,缓解了教师与学生之间的容量差距问题。SKD在ImageNet1k上蒸馏ResNet18的任务上显著超越了SOTA。 + +在ImageNet1k公开数据集上,效果如下所示。 + +| 策略 | 骨干网络 | 配置文件 | Top-1 acc | 下载链接 | +| --- | --- | --- | --- | --- | +| baseline | ResNet18 | [ResNet18.yaml](../../../../ppcls/configs/ImageNet/ResNet/ResNet18.yaml) | 70.8% | - | +| SKD | ResNet18 | [resnet34_distill_resnet18_skd.yaml](../../../../ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_skd.yaml) | 72.84%(**+2.04%**) | - | + + +##### 1.2.9.2 SKD 配置 + +SKD 配置如下所示。在模型构建Arch字段中,需要同时定义学生模型与教师模型,教师模型固定参数,且需要加载预训练模型。在损失函数Loss字段中,需要定义`DistillationSKDLoss`(学生与教师之间的SKD loss),作为训练的损失函数。 + + +```yaml +# model architecture +Arch: + name: "DistillationModel" + # if not null, its lengths should be same as models + pretrained_list: + # if not null, its lengths should be same as models + freeze_params_list: + - True + - False + models: + - Teacher: + name: ResNet34 + pretrained: True + + - Student: + name: ResNet18 + pretrained: False + + infer_model_name: "Student" + + +# loss function config for traing/eval process +Loss: + Train: + - DistillationSKDLoss: + weight: 1.0 + model_name_pairs: [["Student", "Teacher"]] + temperature: 1.0 + multiplier: 2.0 + alpha: 0.9 + Eval: + - CELoss: + weight: 1.0 +``` + ## 2. 模型训练、评估和预测 diff --git a/ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_skd.yaml b/ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_skd.yaml new file mode 100644 index 00000000..798717d2 --- /dev/null +++ b/ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_skd.yaml @@ -0,0 +1,151 @@ +# global configs +Global: + checkpoints: null + pretrained_model: null + output_dir: "./output/" + device: "gpu" + save_interval: 1 + eval_during_train: True + eval_interval: 1 + epochs: 100 + print_batch_step: 10 + use_visualdl: False + # used for static mode and model export + image_shape: [3, 224, 224] + save_inference_dir: "./inference" + +# model architecture +Arch: + name: "DistillationModel" + # if not null, its lengths should be same as models + pretrained_list: + # if not null, its lengths should be same as models + freeze_params_list: + - True + - False + models: + - Teacher: + name: ResNet34 + pretrained: True + + - Student: + name: ResNet18 + pretrained: False + + infer_model_name: "Student" + + +# loss function config for traing/eval process +Loss: + Train: + - DistillationSKDLoss: + weight: 1.0 + model_name_pairs: [["Student", "Teacher"]] + temperature: 1.0 + multiplier: 2.0 + alpha: 0.9 + Eval: + - CELoss: + weight: 1.0 + + +Optimizer: + name: Momentum + momentum: 0.9 + weight_decay: 1e-4 + lr: + name: MultiStepDecay + learning_rate: 0.2 + milestones: [30, 60, 90] + step_each_epoch: 1 + gamma: 0.1 + + +# data loader for train and eval +DataLoader: + Train: + dataset: + name: ImageNetDataset + image_root: "./dataset/ILSVRC2012/" + cls_label_path: "./dataset/ILSVRC2012/train_list.txt" + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + - RandCropImage: + size: 224 + - RandFlipImage: + flip_code: 1 + - NormalizeImage: + scale: 0.00392157 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + + sampler: + name: DistributedBatchSampler + batch_size: 128 + drop_last: False + shuffle: True + loader: + num_workers: 8 + use_shared_memory: True + + Eval: + dataset: + name: ImageNetDataset + image_root: "./dataset/ILSVRC2012/" + cls_label_path: "./dataset/ILSVRC2012/val_list.txt" + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + - ResizeImage: + resize_short: 256 + - CropImage: + size: 224 + - NormalizeImage: + scale: 0.00392157 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + sampler: + name: DistributedBatchSampler + batch_size: 64 + drop_last: False + shuffle: False + loader: + num_workers: 4 + use_shared_memory: True + +Infer: + infer_imgs: "docs/images/inference_deployment/whl_demo.jpg" + batch_size: 10 + transforms: + - DecodeImage: + to_rgb: True + channel_first: False + - ResizeImage: + resize_short: 256 + - CropImage: + size: 224 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - ToCHWImage: + PostProcess: + name: Topk + topk: 5 + class_id_map_file: "ppcls/utils/imagenet1k_label_list.txt" + +Metric: + Train: + - DistillationTopkAcc: + model_key: "Student" + topk: [1, 5] + Eval: + - DistillationTopkAcc: + model_key: "Student" + topk: [1, 5] diff --git a/ppcls/loss/__init__.py b/ppcls/loss/__init__.py index 513ede13..90e9abf0 100644 --- a/ppcls/loss/__init__.py +++ b/ppcls/loss/__init__.py @@ -29,6 +29,7 @@ from .distillationloss import DistillationRKDLoss from .distillationloss import DistillationKLDivLoss from .distillationloss import DistillationDKDLoss from .distillationloss import DistillationWSLLoss +from .distillationloss import DistillationSKDLoss from .distillationloss import DistillationMultiLabelLoss from .distillationloss import DistillationDISTLoss from .distillationloss import DistillationPairLoss diff --git a/ppcls/loss/distillationloss.py b/ppcls/loss/distillationloss.py index af24f830..a5bc70d0 100644 --- a/ppcls/loss/distillationloss.py +++ b/ppcls/loss/distillationloss.py @@ -26,6 +26,7 @@ from .wslloss import WSLLoss from .dist_loss import DISTLoss from .multilabelloss import MultiLabelLoss from .mgd_loss import MGDLoss +from .skdloss import SKDLoss class DistillationCELoss(CELoss): @@ -291,6 +292,37 @@ class DistillationWSLLoss(WSLLoss): return loss_dict +class DistillationSKDLoss(SKDLoss): + """ + DistillationSKDLoss + """ + + def __init__(self, + model_name_pairs=[], + key=None, + temperature=1.0, + multiplier=2.0, + alpha=0.9, + use_target_as_gt=False, + name="skd_loss"): + super().__init__(temperature, multiplier, alpha, use_target_as_gt) + self.model_name_pairs = model_name_pairs + self.key = key + self.name = name + + def forward(self, predicts, batch): + loss_dict = dict() + for idx, pair in enumerate(self.model_name_pairs): + out1 = predicts[pair[0]] + out2 = predicts[pair[1]] + if self.key is not None: + out1 = out1[self.key] + out2 = out2[self.key] + loss = super().forward(out1, out2, batch) + loss_dict[f"{self.name}_{pair[0]}_{pair[1]}"] = loss + return loss_dict + + class DistillationMultiLabelLoss(MultiLabelLoss): """ DistillationMultiLabelLoss diff --git a/ppcls/loss/skdloss.py b/ppcls/loss/skdloss.py new file mode 100644 index 00000000..fe8e8e14 --- /dev/null +++ b/ppcls/loss/skdloss.py @@ -0,0 +1,72 @@ +# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + + +class SKDLoss(nn.Layer): + """ + Spherical Knowledge Distillation + paper: https://arxiv.org/pdf/2010.07485.pdf + code reference: https://github.com/forjiuzhou/Spherical-Knowledge-Distillation + """ + + def __init__(self, + temperature, + multiplier=2.0, + alpha=0.9, + use_target_as_gt=False): + super().__init__() + self.temperature = temperature + self.multiplier = multiplier + self.alpha = alpha + self.use_target_as_gt = use_target_as_gt + + def forward(self, logits_student, logits_teacher, target=None): + """Compute Spherical Knowledge Distillation loss. + Args: + logits_student: student's logits with shape (batch_size, num_classes) + logits_teacher: teacher's logits with shape (batch_size, num_classes) + """ + if target is None or self.use_target_as_gt: + target = logits_teacher.argmax(axis=-1) + + target = F.one_hot( + target.reshape([-1]), num_classes=logits_student[0].shape[0]) + + logits_student = F.layer_norm( + logits_student, + logits_student.shape[1:], + weight=None, + bias=None, + epsilon=1e-7) * self.multiplier + logits_teacher = F.layer_norm( + logits_teacher, + logits_teacher.shape[1:], + weight=None, + bias=None, + epsilon=1e-7) * self.multiplier + + kd_loss = -paddle.sum(F.softmax(logits_teacher / self.temperature) * + F.log_softmax(logits_student / self.temperature), + axis=1) + + kd_loss = paddle.mean(kd_loss) * self.temperature**2 + + ce_loss = paddle.mean(-paddle.sum( + target * F.log_softmax(logits_student), axis=1)) + + return kd_loss * self.alpha + ce_loss * (1 - self.alpha) -- GitLab