Unverified commit 9b12161b, authored by Lixin Luo, committed by GitHub

add weighted soft labels loss function (#2288)

* add weighted soft labels loss function

* fix typo in docs/zh_CN/advanced_tutorials/knowledge_distillation.md
Parent 2ed68d5d
docs/zh_CN/advanced_tutorials/knowledge_distillation.md
@@ -16,6 +16,7 @@
- [1.2.5 DKD](#1.2.5)
- [1.2.6 DIST](#1.2.6)
- [1.2.7 MGD](#1.2.7)
- [1.2.8 WSL](#1.2.8)
- [2. Usage](#2)
  - [2.1 Environment Setup](#2.1)
  - [2.2 Data Preparation](#2.2)
@@ -399,7 +400,7 @@ DKD decouples the KD Loss commonly used in distillation into Target Class Knowledge Distillation
| Strategy | Backbone | Config | Top-1 acc | Download |
| --- | --- | --- | --- | --- |
| baseline | ResNet18 | [ResNet18.yaml](../../../ppcls/configs/ImageNet/ResNet/ResNet18.yaml) | 70.8% | - |
| DKD | ResNet18 | [resnet34_distill_resnet18_dkd.yaml](../../../ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_dkd.yaml) | 72.59%(**+1.79%**) | - |
##### 1.2.5.2 DKD Configuration
@@ -583,6 +584,73 @@ Loss:
weight: 1.0
```
<a name='1.2.8'></a>
#### 1.2.8 WSL
##### 1.2.8.1 Introduction to WSL
Paper information:
> [Rethinking Soft Labels For Knowledge Distillation: A Bias-variance Tradeoff Perspective](https://arxiv.org/abs/2102.00650)
>
> Helong Zhou, Liangchen Song, Jiajie Chen, Ye Zhou, Guoli Wang, Junsong Yuan, Qian Zhang
>
> ICLR, 2021
The WSL (Weighted Soft Labels) loss assigns each sample's KD loss an individual weight based on the ratio between the student's and the teacher's CE loss with respect to the ground-truth label. If the student already predicts a sample better than the teacher, that sample receives a smaller weight. The method is simple and effective: the per-sample weights adapt automatically, which improves distillation accuracy.
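Summarizing the implementation added in `ppcls/loss/wslloss.py` further down in this diff (the notation here is ours): with per-sample cross-entropy losses against the ground-truth label for the student and teacher, the weight and total loss are

$$
w_i = 1 - \exp\!\left(-\frac{\mathcal{L}_{CE}\big(p^{S}_i, y_i\big)}{\mathcal{L}_{CE}\big(p^{T}_i, y_i\big)}\right),
\qquad
\mathcal{L}_{WSL} = T^{2} \cdot \frac{1}{N}\sum_{i=1}^{N} w_i \, \mathcal{L}_{KD}\big(p^{S}_i, p^{T}_i\big)
$$

where $T$ is the distillation temperature and $\mathcal{L}_{KD}$ is the per-sample soft-label (KD) loss. The implementation additionally adds a small epsilon to the denominator and clamps the ratio to be non-negative for numerical stability.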
The results on the ImageNet-1k dataset are shown below.
| Strategy | Backbone | Config | Top-1 acc | Download |
| --- | --- | --- | --- | --- |
| baseline | ResNet18 | [ResNet18.yaml](../../../ppcls/configs/ImageNet/ResNet/ResNet18.yaml) | 70.8% | - |
| WSL | ResNet18 | [resnet34_distill_resnet18_wsl.yaml](../../../ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_wsl.yaml) | 72.23%(**+1.43%**) | - |
##### 1.2.8.2 WSL Configuration
The WSL configuration is shown below. In the `Arch` field, both the student and the teacher model must be defined; the teacher's parameters are frozen and its pretrained weights must be loaded. In the `Loss` field, `DistillationGTCELoss` (the CE loss between the student and the ground-truth labels) and `DistillationWSLLoss` (the WSL loss between the student and the teacher) are defined as the training losses.
```yaml
# model architecture
Arch:
  name: "DistillationModel"
  # if not null, its length should be the same as models
  pretrained_list:
  # if not null, its length should be the same as models
  freeze_params_list:
  - True
  - False
  models:
    - Teacher:
        name: ResNet34
        pretrained: True
    - Student:
        name: ResNet18
        pretrained: False
  infer_model_name: "Student"

# loss function config for training/eval process
Loss:
  Train:
    - DistillationGTCELoss:
        weight: 1.0
        model_names: ["Student"]
    - DistillationWSLLoss:
        weight: 2.5
        model_name_pairs: [["Student", "Teacher"]]
        temperature: 2
  Eval:
    - CELoss:
        weight: 1.0
```
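For intuition, the `Loss.Train` entry above amounts to the following combination for a single batch. This is only an illustrative sketch: the `ppcls.loss.wslloss` import path refers to the `WSLLoss` class introduced later in this diff, and the dummy tensors are assumptions for demonstration, not part of the documented training pipeline.

```python
import paddle
import paddle.nn.functional as F

from ppcls.loss.wslloss import WSLLoss  # module added by this PR

# dummy logits for a batch of 4 samples over 1000 classes (illustrative only)
student_logits = paddle.randn([4, 1000])
teacher_logits = paddle.randn([4, 1000])
labels = paddle.randint(0, 1000, [4])

# DistillationGTCELoss: CE between the student and the ground-truth labels
gt_ce = F.cross_entropy(student_logits, labels)

# DistillationWSLLoss: weighted soft-label loss between student and teacher
wsl = WSLLoss(temperature=2.0)(student_logits, teacher_logits, labels)

# total training loss with the weights from the config above
total_loss = 1.0 * gt_ce + 2.5 * wsl
print(float(total_loss))
```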
<a name="2"></a>
## 2. Model Training, Evaluation, and Prediction
ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_wsl.yaml (new file)
# global configs
Global:
  checkpoints: null
  pretrained_model: null
  output_dir: ./output/r34_r18_wsl
  device: "gpu"
  save_interval: 1
  eval_during_train: True
  eval_interval: 1
  epochs: 100
  print_batch_step: 10
  use_visualdl: False
  # used for static mode and model export
  image_shape: [3, 224, 224]
  save_inference_dir: "./inference"

# model architecture
Arch:
  name: "DistillationModel"
  # if not null, its length should be the same as models
  pretrained_list:
  # if not null, its length should be the same as models
  freeze_params_list:
  - True
  - False
  models:
    - Teacher:
        name: ResNet34
        pretrained: True
    - Student:
        name: ResNet18
        pretrained: False
  infer_model_name: "Student"

# loss function config for training/eval process
Loss:
  Train:
    - DistillationGTCELoss:
        weight: 1.0
        model_names: ["Student"]
    - DistillationWSLLoss:
        weight: 2.5
        model_name_pairs: [["Student", "Teacher"]]
        temperature: 2
  Eval:
    - CELoss:
        weight: 1.0

Optimizer:
  name: Momentum
  momentum: 0.9
  weight_decay: 1e-4
  lr:
    name: MultiStepDecay
    learning_rate: 0.1
    milestones: [30, 60, 90]
    step_each_epoch: 1
    gamma: 0.1

# data loader for train and eval
DataLoader:
  Train:
    dataset:
      name: ImageNetDataset
      image_root: "./dataset/ILSVRC2012/"
      cls_label_path: "./dataset/ILSVRC2012/train_list.txt"
      transform_ops:
        - DecodeImage:
            to_rgb: True
            channel_first: False
        - RandCropImage:
            size: 224
        - RandFlipImage:
            flip_code: 1
        - NormalizeImage:
            scale: 0.00392157
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
    sampler:
      name: DistributedBatchSampler
      batch_size: 64
      drop_last: False
      shuffle: True
    loader:
      num_workers: 8
      use_shared_memory: True

  Eval:
    dataset:
      name: ImageNetDataset
      image_root: "./dataset/ILSVRC2012/"
      cls_label_path: "./dataset/ILSVRC2012/val_list.txt"
      transform_ops:
        - DecodeImage:
            to_rgb: True
            channel_first: False
        - ResizeImage:
            resize_short: 256
        - CropImage:
            size: 224
        - NormalizeImage:
            scale: 0.00392157
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
    sampler:
      name: DistributedBatchSampler
      batch_size: 64
      drop_last: False
      shuffle: False
    loader:
      num_workers: 4
      use_shared_memory: True

Infer:
  infer_imgs: "docs/images/inference_deployment/whl_demo.jpg"
  batch_size: 10
  transforms:
    - DecodeImage:
        to_rgb: True
        channel_first: False
    - ResizeImage:
        resize_short: 256
    - CropImage:
        size: 224
    - NormalizeImage:
        scale: 1.0/255.0
        mean: [0.485, 0.456, 0.406]
        std: [0.229, 0.224, 0.225]
        order: ''
    - ToCHWImage:
  PostProcess:
    name: Topk
    topk: 5
    class_id_map_file: "ppcls/utils/imagenet1k_label_list.txt"

Metric:
  Train:
    - DistillationTopkAcc:
        model_key: "Student"
        topk: [1, 5]
  Eval:
    - DistillationTopkAcc:
        model_key: "Student"
        topk: [1, 5]
ppcls/loss/__init__.py
@@ -26,6 +26,7 @@ from .distillationloss import DistillationDistanceLoss
from .distillationloss import DistillationRKDLoss
from .distillationloss import DistillationKLDivLoss
from .distillationloss import DistillationDKDLoss
from .distillationloss import DistillationWSLLoss
from .distillationloss import DistillationMultiLabelLoss
from .distillationloss import DistillationDISTLoss
from .distillationloss import DistillationPairLoss
ppcls/loss/distillationloss.py
@@ -22,6 +22,7 @@ from .distanceloss import DistanceLoss
from .rkdloss import RKdAngle, RkdDistance
from .kldivloss import KLDivLoss
from .dkdloss import DKDLoss
from .wslloss import WSLLoss
from .dist_loss import DISTLoss
from .multilabelloss import MultiLabelLoss
from .mgd_loss import MGDLoss
@@ -262,6 +263,34 @@ class DistillationDKDLoss(DKDLoss):
        return loss_dict


class DistillationWSLLoss(WSLLoss):
    """
    DistillationWSLLoss
    """

    def __init__(self,
                 model_name_pairs=[],
                 key=None,
                 temperature=2.0,
                 name="wsl_loss"):
        super().__init__(temperature)
        self.model_name_pairs = model_name_pairs
        self.key = key
        self.name = name

    def forward(self, predicts, batch):
        loss_dict = dict()
        for idx, pair in enumerate(self.model_name_pairs):
            out1 = predicts[pair[0]]
            out2 = predicts[pair[1]]
            if self.key is not None:
                out1 = out1[self.key]
                out2 = out2[self.key]
            loss = super().forward(out1, out2, batch)
            loss_dict[f"{self.name}_{pair[0]}_{pair[1]}"] = loss
        return loss_dict


class DistillationMultiLabelLoss(MultiLabelLoss):
    """
    DistillationMultiLabelLoss
ppcls/loss/wslloss.py (new file)
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.nn as nn
import paddle.nn.functional as F


class WSLLoss(nn.Layer):
    """
    Weighted Soft Labels Loss
    paper: https://arxiv.org/pdf/2102.00650.pdf
    code reference: https://github.com/bellymonster/Weighted-Soft-Label-Distillation
    """

    def __init__(self, temperature=2.0, use_target_as_gt=False):
        super().__init__()
        self.temperature = temperature
        self.use_target_as_gt = use_target_as_gt

    def forward(self, logits_student, logits_teacher, target=None):
        """Compute weighted soft labels loss.
        Args:
            logits_student: student's logits with shape (batch_size, num_classes)
            logits_teacher: teacher's logits with shape (batch_size, num_classes)
            target: ground truth labels with shape (batch_size,)
        """
        if target is None or self.use_target_as_gt:
            target = logits_teacher.argmax(axis=-1)
        target = F.one_hot(
            target.reshape([-1]), num_classes=logits_student[0].shape[0])

        s_input_for_softmax = logits_student / self.temperature
        t_input_for_softmax = logits_teacher / self.temperature

        ce_loss_s = -paddle.sum(target *
                                F.log_softmax(logits_student.detach()),
                                axis=1)
        ce_loss_t = -paddle.sum(target *
                                F.log_softmax(logits_teacher.detach()),
                                axis=1)
        ratio = ce_loss_s / (ce_loss_t + 1e-7)
        ratio = paddle.maximum(ratio, paddle.zeros_like(ratio))

        kd_loss = -paddle.sum(F.softmax(t_input_for_softmax) *
                              F.log_softmax(s_input_for_softmax),
                              axis=1)
        weight = 1 - paddle.exp(-ratio)
        weighted_kd_loss = (self.temperature**2) * paddle.mean(kd_loss *
                                                               weight)

        return weighted_kd_loss
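A minimal sketch of how the `DistillationWSLLoss` wrapper above consumes per-model outputs. The dictionary keys mirror the `model_name_pairs` entry from the WSL config, and the dummy tensors are assumptions for illustration only, not part of the training pipeline in this PR.

```python
import paddle

from ppcls.loss import DistillationWSLLoss  # exported via ppcls/loss/__init__.py

loss_fn = DistillationWSLLoss(
    model_name_pairs=[["Student", "Teacher"]], temperature=2.0)

# one output per sub-model, keyed by its name as in the Arch config
predicts = {
    "Student": paddle.randn([4, 1000]),
    "Teacher": paddle.randn([4, 1000]),
}
labels = paddle.randint(0, 1000, [4])

# returns one entry per pair, e.g. {"wsl_loss_Student_Teacher": Tensor}
loss_dict = loss_fn(predicts, labels)
print(loss_dict)
```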