From 9b12161b1d8eb6bf726b56a07eea58844f1443fc Mon Sep 17 00:00:00 2001
From: Lixin Luo <63094060+user3984@users.noreply.github.com>
Date: Wed, 14 Sep 2022 14:23:26 +0800
Subject: [PATCH] add weighted soft labels loss function (#2288)
* add weighted soft labels loss function
add weighted soft labels loss function
* fix typo in docs/zh_CN/advanced_tutorials/knowledge_distillation.md
---
.../knowledge_distillation.md | 70 +++++++-
.../resnet34_distill_resnet18_wsl.yaml | 152 ++++++++++++++++++
ppcls/loss/__init__.py | 1 +
ppcls/loss/distillationloss.py | 29 ++++
ppcls/loss/wslloss.py | 66 ++++++++
5 files changed, 317 insertions(+), 1 deletion(-)
create mode 100644 ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_wsl.yaml
create mode 100644 ppcls/loss/wslloss.py
diff --git a/docs/zh_CN/advanced_tutorials/knowledge_distillation.md b/docs/zh_CN/advanced_tutorials/knowledge_distillation.md
index c7fbef0c..43fa6062 100644
--- a/docs/zh_CN/advanced_tutorials/knowledge_distillation.md
+++ b/docs/zh_CN/advanced_tutorials/knowledge_distillation.md
@@ -16,6 +16,7 @@
- [1.2.5 DKD](#1.2.5)
- [1.2.6 DIST](#1.2.6)
- [1.2.7 MGD](#1.2.7)
+ - [1.2.8 WSL](#1.2.8)
- [2. Usage](#2)
- [2.1 Environment Configuration](#2.1)
- [2.2 Data Preparation](#2.2)
@@ -399,7 +400,7 @@ DKD将蒸馏中常用的 KD Loss 进行了解耦成为Target Class Knowledge Dis
| Strategy | Backbone | Config file | Top-1 acc | Download link |
| --- | --- | --- | --- | --- |
| baseline | ResNet18 | [ResNet18.yaml](../../../ppcls/configs/ImageNet/ResNet/ResNet18.yaml) | 70.8% | - |
-| AFD | ResNet18 | [resnet34_distill_resnet18_dkd.yaml](../../../ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_dkd.yaml) | 72.59%(**+1.79%**) | - |
+| DKD | ResNet18 | [resnet34_distill_resnet18_dkd.yaml](../../../ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_dkd.yaml) | 72.59%(**+1.79%**) | - |
##### 1.2.5.2 DKD Configuration
@@ -583,6 +584,73 @@ Loss:
weight: 1.0
```
+
+
+#### 1.2.8 WSL
+
+##### 1.2.8.1 Introduction to the WSL Algorithm
+
+Paper information:
+
+
+> [Rethinking Soft Labels For Knowledge Distillation: A Bias-variance Tradeoff Perspective](https://arxiv.org/abs/2102.00650)
+>
+> Helong Zhou, Liangchen Song, Jiajie Chen, Ye Zhou, Guoli Wang, Junsong Yuan, Qian Zhang
+>
+> ICLR, 2021
+
+The WSL (Weighted Soft Labels) loss assigns a separate weight to each sample's KD loss according to the ratio of the student's CE loss to the teacher's CE loss with respect to the ground-truth label. If the student already predicts a sample better than the teacher, that sample receives a smaller weight. The method is simple and effective: the per-sample weights are adjusted adaptively, which improves distillation accuracy.
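+
+As implemented in `ppcls/loss/wslloss.py`, the weighting can be summarized as follows (the notation below is introduced here only for readability). Let $L_{CE}^{s}$ and $L_{CE}^{t}$ denote the per-sample cross entropy of the student and of the teacher with respect to the ground-truth label, and let $KD$ denote the per-sample soft-label cross entropy between the teacher and student distributions at temperature $T$. Then
+
+$$
+w = 1 - \exp\left(-\frac{L_{CE}^{s}}{L_{CE}^{t}}\right), \qquad L_{WSL} = T^{2} \cdot \mathrm{mean}\left(w \cdot KD\right)
+$$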
+
+On the public ImageNet1k dataset, the results are shown below.
+
+| Strategy | Backbone | Config file | Top-1 acc | Download link |
+| --- | --- | --- | --- | --- |
+| baseline | ResNet18 | [ResNet18.yaml](../../../ppcls/configs/ImageNet/ResNet/ResNet18.yaml) | 70.8% | - |
+| WSL | ResNet18 | [resnet34_distill_resnet18_wsl.yaml](../../../ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_wsl.yaml) | 72.23%(**+1.43%**) | - |
+
+
+##### 1.2.8.2 WSL Configuration
+
+The WSL configuration is shown below. In the `Arch` field of the model definition, both the student and the teacher model need to be defined; the teacher's parameters are frozen and its pretrained weights need to be loaded. In the `Loss` field, `DistillationGTCELoss` (the CE loss between the student and the ground-truth labels) and `DistillationWSLLoss` (the WSL loss between the student and the teacher) need to be defined as the training loss functions.
+
+
+```yaml
+# model architecture
+Arch:
+ name: "DistillationModel"
+ # if not null, its length should be the same as that of models
+ pretrained_list:
+ # if not null, its length should be the same as that of models
+ freeze_params_list:
+ - True
+ - False
+ models:
+ - Teacher:
+ name: ResNet34
+ pretrained: True
+
+ - Student:
+ name: ResNet18
+ pretrained: False
+
+ infer_model_name: "Student"
+
+
+# loss function config for training/eval process
+Loss:
+ Train:
+ - DistillationGTCELoss:
+ weight: 1.0
+ model_names: ["Student"]
+ - DistillationWSLLoss:
+ weight: 2.5
+ model_name_pairs: [["Student", "Teacher"]]
+ temperature: 2
+ Eval:
+ - CELoss:
+ weight: 1.0
+```
+
## 2. Model Training, Evaluation and Prediction
diff --git a/ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_wsl.yaml b/ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_wsl.yaml
new file mode 100644
index 00000000..7822a2be
--- /dev/null
+++ b/ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_wsl.yaml
@@ -0,0 +1,152 @@
+# global configs
+Global:
+ checkpoints: null
+ pretrained_model: null
+ output_dir: ./output/r34_r18_wsl
+ device: "gpu"
+ save_interval: 1
+ eval_during_train: True
+ eval_interval: 1
+ epochs: 100
+ print_batch_step: 10
+ use_visualdl: False
+ # used for static mode and model export
+ image_shape: [3, 224, 224]
+ save_inference_dir: "./inference"
+
+# model architecture
+Arch:
+ name: "DistillationModel"
+ # if not null, its length should be the same as that of models
+ pretrained_list:
+ # if not null, its length should be the same as that of models
+ freeze_params_list:
+ - True
+ - False
+ models:
+ - Teacher:
+ name: ResNet34
+ pretrained: True
+
+ - Student:
+ name: ResNet18
+ pretrained: False
+
+ infer_model_name: "Student"
+
+
+# loss function config for training/eval process
+Loss:
+ Train:
+ - DistillationGTCELoss:
+ weight: 1.0
+ model_names: ["Student"]
+ - DistillationWSLLoss:
+ weight: 2.5
+ model_name_pairs: [["Student", "Teacher"]]
+ temperature: 2
+ Eval:
+ - CELoss:
+ weight: 1.0
+
+
+Optimizer:
+ name: Momentum
+ momentum: 0.9
+ weight_decay: 1e-4
+ lr:
+ name: MultiStepDecay
+ learning_rate: 0.1
+ milestones: [30, 60, 90]
+ step_each_epoch: 1
+ gamma: 0.1
+
+
+# data loader for train and eval
+DataLoader:
+ Train:
+ dataset:
+ name: ImageNetDataset
+ image_root: "./dataset/ILSVRC2012/"
+ cls_label_path: "./dataset/ILSVRC2012/train_list.txt"
+ transform_ops:
+ - DecodeImage:
+ to_rgb: True
+ channel_first: False
+ - RandCropImage:
+ size: 224
+ - RandFlipImage:
+ flip_code: 1
+ - NormalizeImage:
+ scale: 0.00392157
+ mean: [0.485, 0.456, 0.406]
+ std: [0.229, 0.224, 0.225]
+ order: ''
+
+ sampler:
+ name: DistributedBatchSampler
+ batch_size: 64
+ drop_last: False
+ shuffle: True
+ loader:
+ num_workers: 8
+ use_shared_memory: True
+
+ Eval:
+ dataset:
+ name: ImageNetDataset
+ image_root: "./dataset/ILSVRC2012/"
+ cls_label_path: "./dataset/ILSVRC2012/val_list.txt"
+ transform_ops:
+ - DecodeImage:
+ to_rgb: True
+ channel_first: False
+ - ResizeImage:
+ resize_short: 256
+ - CropImage:
+ size: 224
+ - NormalizeImage:
+ scale: 0.00392157
+ mean: [0.485, 0.456, 0.406]
+ std: [0.229, 0.224, 0.225]
+ order: ''
+ sampler:
+ name: DistributedBatchSampler
+ batch_size: 64
+ drop_last: False
+ shuffle: False
+ loader:
+ num_workers: 4
+ use_shared_memory: True
+
+Infer:
+ infer_imgs: "docs/images/inference_deployment/whl_demo.jpg"
+ batch_size: 10
+ transforms:
+ - DecodeImage:
+ to_rgb: True
+ channel_first: False
+ - ResizeImage:
+ resize_short: 256
+ - CropImage:
+ size: 224
+ - NormalizeImage:
+ scale: 1.0/255.0
+ mean: [0.485, 0.456, 0.406]
+ std: [0.229, 0.224, 0.225]
+ order: ''
+ - ToCHWImage:
+ PostProcess:
+ name: Topk
+ topk: 5
+ class_id_map_file: "ppcls/utils/imagenet1k_label_list.txt"
+
+Metric:
+ Train:
+ - DistillationTopkAcc:
+ model_key: "Student"
+ topk: [1, 5]
+ Eval:
+ - DistillationTopkAcc:
+ model_key: "Student"
+ topk: [1, 5]
diff --git a/ppcls/loss/__init__.py b/ppcls/loss/__init__.py
index cbe7e266..019eff71 100644
--- a/ppcls/loss/__init__.py
+++ b/ppcls/loss/__init__.py
@@ -26,6 +26,7 @@ from .distillationloss import DistillationDistanceLoss
from .distillationloss import DistillationRKDLoss
from .distillationloss import DistillationKLDivLoss
from .distillationloss import DistillationDKDLoss
+from .distillationloss import DistillationWSLLoss
from .distillationloss import DistillationMultiLabelLoss
from .distillationloss import DistillationDISTLoss
from .distillationloss import DistillationPairLoss
diff --git a/ppcls/loss/distillationloss.py b/ppcls/loss/distillationloss.py
index 5a924afe..af24f830 100644
--- a/ppcls/loss/distillationloss.py
+++ b/ppcls/loss/distillationloss.py
@@ -22,6 +22,7 @@ from .distanceloss import DistanceLoss
from .rkdloss import RKdAngle, RkdDistance
from .kldivloss import KLDivLoss
from .dkdloss import DKDLoss
+from .wslloss import WSLLoss
from .dist_loss import DISTLoss
from .multilabelloss import MultiLabelLoss
from .mgd_loss import MGDLoss
@@ -262,6 +263,34 @@ class DistillationDKDLoss(DKDLoss):
return loss_dict
+class DistillationWSLLoss(WSLLoss):
+ """
+ DistillationWSLLoss
+ """
+
+ def __init__(self,
+ model_name_pairs=[],
+ key=None,
+ temperature=2.0,
+ name="wsl_loss"):
+ super().__init__(temperature)
+ self.model_name_pairs = model_name_pairs
+ self.key = key
+ self.name = name
+
+ def forward(self, predicts, batch):
+ loss_dict = dict()
+ for idx, pair in enumerate(self.model_name_pairs):
+ out1 = predicts[pair[0]]
+ out2 = predicts[pair[1]]
+ if self.key is not None:
+ out1 = out1[self.key]
+ out2 = out2[self.key]
+ loss = super().forward(out1, out2, batch)
+ loss_dict[f"{self.name}_{pair[0]}_{pair[1]}"] = loss
+ return loss_dict
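+ # Illustrative example: with model_name_pairs=[["Student", "Teacher"]] and
+ # the default name "wsl_loss", the returned dict contains a single entry
+ # keyed "wsl_loss_Student_Teacher" holding the scalar weighted soft labels loss.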
+
+
class DistillationMultiLabelLoss(MultiLabelLoss):
"""
DistillationMultiLabelLoss
diff --git a/ppcls/loss/wslloss.py b/ppcls/loss/wslloss.py
new file mode 100644
index 00000000..8bdfaf8c
--- /dev/null
+++ b/ppcls/loss/wslloss.py
@@ -0,0 +1,66 @@
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+
+class WSLLoss(nn.Layer):
+ """
+ Weighted Soft Labels Loss
+ paper: https://arxiv.org/pdf/2102.00650.pdf
+ code reference: https://github.com/bellymonster/Weighted-Soft-Label-Distillation
+ """
+
+ def __init__(self, temperature=2.0, use_target_as_gt=False):
+ super().__init__()
+ self.temperature = temperature
+ self.use_target_as_gt = use_target_as_gt
+
+ def forward(self, logits_student, logits_teacher, target=None):
+ """Compute weighted soft labels loss.
+ Args:
+ logits_student: student's logits with shape (batch_size, num_classes)
+ logits_teacher: teacher's logits with shape (batch_size, num_classes)
+ target: ground truth labels with shape (batch_size)
+ """
+ if target is None or self.use_target_as_gt:
+ target = logits_teacher.argmax(axis=-1)
+
+ target = F.one_hot(
+ target.reshape([-1]), num_classes=logits_student[0].shape[0])
+
+ s_input_for_softmax = logits_student / self.temperature
+ t_input_for_softmax = logits_teacher / self.temperature
+
+ ce_loss_s = -paddle.sum(target *
+ F.log_softmax(logits_student.detach()),
+ axis=1)
+ ce_loss_t = -paddle.sum(target *
+ F.log_softmax(logits_teacher.detach()),
+ axis=1)
+
+ ratio = ce_loss_s / (ce_loss_t + 1e-7)
+ ratio = paddle.maximum(ratio, paddle.zeros_like(ratio))
+
+ kd_loss = -paddle.sum(F.softmax(t_input_for_softmax) *
+ F.log_softmax(s_input_for_softmax),
+ axis=1)
+ weight = 1 - paddle.exp(-ratio)
+
+ weighted_kd_loss = (self.temperature**2) * paddle.mean(kd_loss *
+ weight)
+
+ return weighted_kd_loss
--
GitLab