未验证 提交 14d6b7ef 编写于 作者: littletomatodonkey's avatar littletomatodonkey 提交者: GitHub

add dist algo (#2133)

* add dist_kd

* add doc

* fix some typos
上级 d4e286c3
......@@ -14,6 +14,7 @@
- [1.2.3 UDML](#1.2.3)
- [1.2.4 AFD](#1.2.4)
- [1.2.5 DKD](#1.2.5)
- [1.2.6 DIST](#1.2.6)
- [2. 使用方法](#2)
- [2.1 环境配置](#2.1)
- [2.2 数据准备](#2.2)
......@@ -444,6 +445,74 @@ Loss:
- CELoss:
weight: 1.0
```
<a name='1.2.6'></a>
#### 1.2.6 DIST
##### 1.2.6.1 DIST 算法介绍
论文信息:
> [Knowledge Distillation from A Stronger Teacher](https://arxiv.org/pdf/2205.10536v1.pdf)
>
> Tao Huang, Shan You, Fei Wang, Chen Qian, Chang Xu
>
> 2022, under review
使用KD方法进行模型蒸馏时,教师模型精度提升时,蒸馏的效果往往难以同步提升。本文提出DIST方法,使用皮尔逊相关系数(Pearson correlation coefficient)去表征学生模型与教师模型之间的差异,替代蒸馏过程中默认的KL散度,从而保证模型可以学到更加准确的相关性信息。
在ImageNet1k公开数据集上,效果如下所示。
| 策略 | 骨干网络 | 配置文件 | Top-1 acc | 下载链接 |
| --- | --- | --- | --- | --- |
| baseline | ResNet18 | [ResNet18.yaml](../../../ppcls/configs/ImageNet/ResNet/ResNet18.yaml) | 70.8% | - |
| DIST | ResNet18 | [resnet34_distill_resnet18_dist.yaml](../../../ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_dist.yaml) | 71.99%(**+1.19%**) | - |
##### 1.2.6.2 DIST 配置
DIST 配置如下所示。在模型构建Arch字段中,需要同时定义学生模型与教师模型,教师模型固定参数,且需要加载预训练模型。在损失函数Loss字段中,需要定义`DistillationDISTLoss`(学生与教师之间的DIST loss)以及`DistillationGTCELoss`(学生与教师关于真值标签的CE loss),作为训练的损失函数。
```yaml
Arch:
name: "DistillationModel"
# if not null, its lengths should be same as models
pretrained_list:
# if not null, its lengths should be same as models
freeze_params_list:
- True
- False
models:
- Teacher:
name: ResNet34
pretrained: True
- Student:
name: ResNet18
pretrained: False
infer_model_name: "Student"
# loss function config for traing/eval process
Loss:
Train:
- DistillationGTCELoss:
weight: 1.0
model_names: ["Student"]
- DistillationDISTLoss:
weight: 2.0
model_name_pairs:
- ["Student", "Teacher"]
Eval:
- CELoss:
weight: 1.0
```
<a name="2"></a>
## 2. 模型训练、评估和预测
......@@ -601,3 +670,5 @@ python3 tools/export_model.py \
[11] Zhao B, Cui Q, Song R, et al. Decoupled Knowledge Distillation[J]. arXiv preprint arXiv:2203.08679, 2022.
[12] Ji M, Heo B, Park S. Show, attend and distill: Knowledge distillation via attention-based feature matching[C]//Proceedings of the AAAI Conference on Artificial Intelligence. 2021, 35(9): 7945-7952.
[13] Huang T, You S, Wang F, et al. Knowledge Distillation from A Stronger Teacher[J]. arXiv preprint arXiv:2205.10536, 2022.
# global configs
Global:
checkpoints: null
pretrained_model: null
output_dir: ./output/r34_r18_dist
device: gpu
save_interval: 1
eval_during_train: True
eval_interval: 1
epochs: 100
print_batch_step: 10
use_visualdl: False
# used for static mode and model export
image_shape: [3, 224, 224]
save_inference_dir: ./inference
to_static: False
# model architecture
Arch:
name: "DistillationModel"
class_num: &class_num 1000
# if not null, its lengths should be same as models
pretrained_list:
# if not null, its lengths should be same as models
freeze_params_list:
- True
- False
infer_model_name: "Student"
models:
- Teacher:
name: ResNet34
class_num: *class_num
pretrained: True
- Student:
name: ResNet18
class_num: *class_num
pretrained: False
# loss function config for traing/eval process
Loss:
Train:
- DistillationGTCELoss:
weight: 1.0
model_names: ["Student"]
- DistillationDISTLoss:
weight: 2.0
model_name_pairs:
- ["Student", "Teacher"]
Eval:
- CELoss:
weight: 1.0
Optimizer:
name: Momentum
momentum: 0.9
weight_decay: 1e-4
lr:
name: Piecewise
learning_rate: 0.1
decay_epochs: [30, 60, 90]
values: [0.1, 0.01, 0.001, 0.0001]
# data loader for train and eval
DataLoader:
Train:
dataset:
name: ImageNetDataset
image_root: ./dataset/ILSVRC2012/
cls_label_path: ./dataset/ILSVRC2012/train_list.txt
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- RandCropImage:
size: 224
- RandFlipImage:
flip_code: 1
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 64
drop_last: False
shuffle: True
loader:
num_workers: 8
use_shared_memory: True
Eval:
dataset:
name: ImageNetDataset
image_root: ./dataset/ILSVRC2012/
cls_label_path: ./dataset/ILSVRC2012/val_list.txt
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
resize_short: 256
- CropImage:
size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 256
drop_last: False
shuffle: False
loader:
num_workers: 8
use_shared_memory: True
Infer:
infer_imgs: docs/images/inference_deployment/whl_demo.jpg
batch_size: 10
transforms:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
resize_short: 256
- CropImage:
size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- ToCHWImage:
PostProcess:
name: Topk
topk: 5
class_id_map_file: ppcls/utils/imagenet1k_label_list.txt
Metric:
Train:
- DistillationTopkAcc:
model_key: "Student"
topk: [1, 5]
Eval:
- DistillationTopkAcc:
model_key: "Student"
topk: [1, 5]
......@@ -25,6 +25,8 @@ from .distillationloss import DistillationRKDLoss
from .distillationloss import DistillationKLDivLoss
from .distillationloss import DistillationDKDLoss
from .distillationloss import DistillationMultiLabelLoss
from .distillationloss import DistillationDISTLoss
from .multilabelloss import MultiLabelLoss
from .afdloss import AFDLoss
......
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
def cosine_similarity(a, b, eps=1e-8):
return (a * b).sum(1) / (a.norm(axis=1) * b.norm(axis=1) + eps)
def pearson_correlation(a, b, eps=1e-8):
return cosine_similarity(a - a.mean(1).unsqueeze(1),
b - b.mean(1).unsqueeze(1), eps)
def inter_class_relation(y_s, y_t):
return 1 - pearson_correlation(y_s, y_t).mean()
def intra_class_relation(y_s, y_t):
return inter_class_relation(y_s.transpose([1, 0]), y_t.transpose([1, 0]))
class DISTLoss(nn.Layer):
# DISTLoss
# paper [Knowledge Distillation from A Stronger Teacher](https://arxiv.org/pdf/2205.10536v1.pdf)
# code reference: https://github.com/hunto/image_classification_sota/blob/d4f15a0494/lib/models/losses/dist_kd.py
def __init__(self, beta=1.0, gamma=1.0):
super().__init__()
self.beta = beta
self.gamma = gamma
def forward(self, z_s, z_t):
y_s = F.softmax(z_s, axis=-1)
y_t = F.softmax(z_t, axis=-1)
inter_loss = inter_class_relation(y_s, y_t)
intra_loss = intra_class_relation(y_s, y_t)
kd_loss = self.beta * inter_loss + self.gamma * intra_loss
return kd_loss
......@@ -22,6 +22,7 @@ from .distanceloss import DistanceLoss
from .rkdloss import RKdAngle, RkdDistance
from .kldivloss import KLDivLoss
from .dkdloss import DKDLoss
from .dist_loss import DISTLoss
from .multilabelloss import MultiLabelLoss
......@@ -289,3 +290,32 @@ class DistillationMultiLabelLoss(MultiLabelLoss):
for key in loss:
loss_dict["{}_{}".format(key, name)] = loss[key]
return loss_dict
class DistillationDISTLoss(DISTLoss):
"""
DistillationDISTLoss
"""
def __init__(self,
model_name_pairs=[],
key=None,
beta=1.0,
gamma=1.0,
name="loss_dist"):
super().__init__(beta=beta, gamma=gamma)
self.key = key
self.model_name_pairs = model_name_pairs
self.name = name
def forward(self, predicts, batch):
loss_dict = dict()
for idx, pair in enumerate(self.model_name_pairs):
out1 = predicts[pair[0]]
out2 = predicts[pair[1]]
if self.key is not None:
out1 = out1[self.key]
out2 = out2[self.key]
loss = super().forward(out1, out2)
loss_dict[f"{self.name}_{pair[0]}_{pair[1]}"] = loss
return loss_dict
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册