Commit 7ee8471d authored by user3984, committed by littletomatodonkey

add pefd

Parent 17a13d00
......@@ -18,6 +18,7 @@
- [1.2.7 MGD](#1.2.7)
- [1.2.8 WSL](#1.2.8)
- [1.2.9 SKD](#1.2.9)
- [1.2.10 PEFD](#1.2.10)
- [2. Usage](#2)
- [2.1 Environment Configuration](#2.1)
- [2.2 Data Preparation](#2.2)
......@@ -580,8 +581,8 @@ Loss:
model_name_pairs: [["Student", "Teacher"]] # calculate mgdloss for Student and Teacher
name: "loss_mgd"
base_loss_name: MGDLoss # MGD loss, the following are parameters of 'MGD loss'
s_keys: ["blocks[7]"] # feature map used to calculate MGD loss in student model
t_keys: ["blocks[15]"] # feature map used to calculate MGD loss in teacher model
s_key: "blocks[7]" # feature map used to calculate MGD loss in student model
t_key: "blocks[15]" # feature map used to calculate MGD loss in teacher model
      student_channels: 512 # channel num for student feature map
teacher_channels: 512 # channel num for teacher feature map
Eval:
......@@ -722,6 +723,80 @@ Loss:
weight: 1.0
```
<a name='1.2.10'></a>
#### 1.2.10 PEFD
##### 1.2.10.1 Introduction to PEFD
Paper:
> [Improved Feature Distillation via Projector Ensemble](https://arxiv.org/pdf/2210.15274.pdf)
>
> Yudong Chen, Sen Wang, Jiajun Liu, Xuwei Xu, Frank de Hoog, Zi Huang
>
> NeurIPS 2022
PEFD transforms the student's features with an ensemble of multiple projectors before applying the feature distillation loss. This prevents the student from overfitting the teacher's features and further improves the performance of feature distillation.
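Concretely, with $q$ projectors $P_1, \dots, P_q$ ($q=3$ by default in this commit's `PEFDLoss`), student feature $f_s$ and teacher feature $f_t$, the distillation term is the cosine distance between the averaged projection and the teacher feature, averaged over the batch:

$$\mathcal{L}_{\mathrm{PEFD}} = 1 - \cos\left(\frac{1}{q}\sum_{i=1}^{q} P_i(f_s),\; f_t\right)$$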
Performance on ImageNet1k is shown below.
| Strategy | Backbone | Config | Top-1 acc | Download Link |
| --- | --- | --- | --- | --- |
| baseline | ResNet18 | [ResNet18.yaml](../../../ppcls/configs/ImageNet/ResNet/ResNet18.yaml) | 70.8% | - |
| PEFD | ResNet18 | [resnet34_distill_resnet18_pefd.yaml](../../../ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_pefd.yaml) | 72.23%(**+1.43%**) | - |
##### 1.2.10.2 Configuration of PEFD
The PEFD configuration is shown below. In the `Arch` field, you need to define both the student model and the teacher model; the teacher model's parameters are frozen and its pretrained weights are loaded. In the `Loss` field, you need to define `DistillationPairLoss` (the PEFD loss between student and teacher) and `DistillationGTCELoss` (the CE loss against ground-truth labels) as the training losses.
```yaml
# model architecture
Arch:
name: "DistillationModel"
class_num: &class_num 1000
  # if not null, its length should be the same as models
  pretrained_list:
  # if not null, its length should be the same as models
freeze_params_list:
- True
- False
infer_model_name: "Student"
models:
- Teacher:
name: ResNet34
class_num: *class_num
pretrained: True
return_patterns: &t_stages ["avg_pool"]
- Student:
name: ResNet18
class_num: *class_num
pretrained: False
return_patterns: &s_stages ["avg_pool"]
# loss function config for training/eval process
Loss:
Train:
- DistillationGTCELoss:
weight: 1.0
model_names: ["Student"]
- DistillationPairLoss:
weight: 25.0
base_loss_name: PEFDLoss
model_name_pairs: [["Student", "Teacher"]]
s_key: "avg_pool"
t_key: "avg_pool"
name: "loss_pefd"
student_channel: 512
teacher_channel: 512
Eval:
- CELoss:
weight: 1.0
```
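A minimal sketch of what the `DistillationPairLoss` entry above resolves to, assuming the import path of this commit's `ppcls/loss/distillationloss.py`: extra kwargs such as `student_channel`/`teacher_channel` are forwarded to the base loss, here `PEFDLoss`.

```python
from ppcls.loss.distillationloss import DistillationPairLoss

# Mirrors the YAML above; the `weight: 25.0` field is consumed by the
# surrounding combined-loss wrapper, not by the loss class itself.
pefd_pair_loss = DistillationPairLoss(
    base_loss_name="PEFDLoss",
    model_name_pairs=[["Student", "Teacher"]],
    s_key="avg_pool",
    t_key="avg_pool",
    name="loss_pefd",
    student_channel=512,   # forwarded to PEFDLoss
    teacher_channel=512)   # forwarded to PEFDLoss
```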
<a name="2"></a>
## 2. Training, Evaluation and Prediction
......
......@@ -18,6 +18,7 @@
- [1.2.7 MGD](#1.2.7)
- [1.2.8 WSL](#1.2.8)
- [1.2.9 SKD](#1.2.9)
- [1.2.10 PEFD](#1.2.10)
- [2. Usage](#2)
- [2.1 Environment Configuration](#2.1)
- [2.2 Data Preparation](#2.2)
......@@ -579,8 +580,8 @@ Loss:
model_name_pairs: [["Student", "Teacher"]] # calculate mgdloss for Student and Teacher
name: "loss_mgd"
      base_loss_name: MGDLoss # MGD loss, the following are parameters of 'MGD loss'
s_keys: ["blocks[7]"] # feature map used to calculate MGD loss in student model
t_keys: ["blocks[15]"] # feature map used to calculate MGD loss in teacher model
s_key: "blocks[7]" # feature map used to calculate MGD loss in student model
t_key: "blocks[15]" # feature map used to calculate MGD loss in teacher model
      student_channels: 512 # channel num for student feature map
teacher_channels: 512 # channel num for teacher feature map
Eval:
......@@ -721,6 +722,80 @@ Loss:
weight: 1.0
```
<a name='1.2.10'></a>
#### 1.2.10 PEFD
##### 1.2.10.1 Introduction to PEFD
Paper:
> [Improved Feature Distillation via Projector Ensemble](https://arxiv.org/pdf/2210.15274.pdf)
>
> Yudong Chen, Sen Wang, Jiajun Liu, Xuwei Xu, Frank de Hoog, Zi Huang
>
> NeurIPS 2022
PEFD projects the student's feature map with multiple projectors and ensembles the projections to fit the teacher's feature map. Compared with using no projector or a single projector, this prevents the student model from overfitting the teacher's features and further improves the performance of feature distillation.
Performance on the public ImageNet1k dataset is shown below.
| Strategy | Backbone | Config | Top-1 acc | Download Link |
| --- | --- | --- | --- | --- |
| baseline | ResNet18 | [ResNet18.yaml](../../../../ppcls/configs/ImageNet/ResNet/ResNet18.yaml) | 70.8% | - |
| PEFD | ResNet18 | [resnet34_distill_resnet18_pefd.yaml](../../../../ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_pefd.yaml) | 72.23%(**+1.43%**) | - |
##### 1.2.10.2 Configuration of PEFD
The PEFD configuration is shown below. In the `Arch` field, you need to define both the student model and the teacher model; the teacher model's parameters are frozen and its pretrained weights are loaded. In the `Loss` field, you need to define `DistillationPairLoss` (the PEFD loss between student and teacher) and `DistillationGTCELoss` (the CE loss against ground-truth labels) as the training losses.
```yaml
# model architecture
Arch:
name: "DistillationModel"
class_num: &class_num 1000
  # if not null, its length should be the same as models
  pretrained_list:
  # if not null, its length should be the same as models
freeze_params_list:
- True
- False
infer_model_name: "Student"
models:
- Teacher:
name: ResNet34
class_num: *class_num
pretrained: True
return_patterns: &t_stages ["avg_pool"]
- Student:
name: ResNet18
class_num: *class_num
pretrained: False
return_patterns: &s_stages ["avg_pool"]
# loss function config for training/eval process
Loss:
Train:
- DistillationGTCELoss:
weight: 1.0
model_names: ["Student"]
- DistillationPairLoss:
weight: 25.0
base_loss_name: PEFDLoss
model_name_pairs: [["Student", "Teacher"]]
s_key: "avg_pool"
t_key: "avg_pool"
name: "loss_pefd"
student_channel: 512
teacher_channel: 512
Eval:
- CELoss:
weight: 1.0
```
<a name="2"></a>
## 2. Training, Evaluation and Prediction
......
......@@ -48,8 +48,8 @@ Loss:
weight: 1.0
base_loss_name: MGDLoss
model_name_pairs: [["Student", "Teacher"]]
s_keys: ["blocks[7]"]
t_keys: ["blocks[15]"]
s_key: "blocks[7]"
t_key: "blocks[15]"
name: "loss_mgd"
student_channels: 512
teacher_channels: 512
......
# global configs
Global:
checkpoints: null
pretrained_model: null
output_dir: ./output/r34_r18_pefd
device: gpu
save_interval: 1
eval_during_train: True
eval_interval: 1
epochs: 100
print_batch_step: 10
use_visualdl: False
# used for static mode and model export
image_shape: [3, 224, 224]
save_inference_dir: ./inference
to_static: False
# model architecture
Arch:
name: "DistillationModel"
class_num: &class_num 1000
  # if not null, its length should be the same as models
  pretrained_list:
  # if not null, its length should be the same as models
freeze_params_list:
- True
- False
infer_model_name: "Student"
models:
- Teacher:
name: ResNet34
class_num: *class_num
pretrained: True
return_patterns: &t_stages ["avg_pool"]
- Student:
name: ResNet18
class_num: *class_num
pretrained: False
return_patterns: &s_stages ["avg_pool"]
# loss function config for training/eval process
Loss:
Train:
- DistillationGTCELoss:
weight: 1.0
model_names: ["Student"]
- DistillationPairLoss:
weight: 25.0
base_loss_name: PEFDLoss
model_name_pairs: [["Student", "Teacher"]]
s_key: "avg_pool"
t_key: "avg_pool"
name: "loss_pefd"
student_channel: 512
teacher_channel: 512
Eval:
- CELoss:
weight: 1.0
Optimizer:
name: Momentum
momentum: 0.9
weight_decay: 1e-4
lr:
name: Piecewise
learning_rate: 0.1
decay_epochs: [30, 60, 90]
values: [0.1, 0.01, 0.001, 0.0001]
# data loader for train and eval
DataLoader:
Train:
dataset:
name: ImageNetDataset
image_root: ./dataset/ILSVRC2012/
cls_label_path: ./dataset/ILSVRC2012/train_list.txt
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- RandCropImage:
size: 224
- RandFlipImage:
flip_code: 1
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 64
drop_last: False
shuffle: True
loader:
num_workers: 8
use_shared_memory: True
Eval:
dataset:
name: ImageNetDataset
image_root: ./dataset/ILSVRC2012/
cls_label_path: ./dataset/ILSVRC2012/val_list.txt
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
resize_short: 256
- CropImage:
size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 256
drop_last: False
shuffle: False
loader:
num_workers: 8
use_shared_memory: True
Infer:
infer_imgs: docs/images/inference_deployment/whl_demo.jpg
batch_size: 10
transforms:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
resize_short: 256
- CropImage:
size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- ToCHWImage:
PostProcess:
name: Topk
topk: 5
class_id_map_file: ppcls/utils/imagenet1k_label_list.txt
Metric:
Train:
- DistillationTopkAcc:
model_key: "Student"
topk: [1, 5]
Eval:
- DistillationTopkAcc:
model_key: "Student"
topk: [1, 5]
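For reference, the `Piecewise` schedule in the config above holds the learning rate constant between the boundary epochs: 0.1 until epoch 30, then 0.01, 0.001, and finally 0.0001 from epoch 90 on. A minimal sketch of that lookup (a hypothetical helper, not PaddleClas API):

```python
def piecewise_lr(epoch, decay_epochs=(30, 60, 90),
                 values=(0.1, 0.01, 0.001, 0.0001)):
    """Return the learning rate the schedule above uses at `epoch`."""
    for boundary, value in zip(decay_epochs, values):
        if epoch < boundary:
            return value
    return values[-1]

assert piecewise_lr(0) == 0.1
assert piecewise_lr(45) == 0.01
assert piecewise_lr(95) == 0.0001
```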
......@@ -27,6 +27,7 @@ from .dist_loss import DISTLoss
from .multilabelloss import MultiLabelLoss
from .mgd_loss import MGDLoss
from .skdloss import SKDLoss
from .pefdloss import PEFDLoss
class DistillationCELoss(CELoss):
......@@ -391,18 +392,15 @@ class DistillationPairLoss(nn.Layer):
def __init__(self,
base_loss_name,
model_name_pairs=[],
s_keys=None,
t_keys=None,
s_key=None,
t_key=None,
name="loss",
**kwargs):
super().__init__()
self.loss_func = eval(base_loss_name)(**kwargs)
if not isinstance(s_keys, list):
s_keys = [s_keys]
if not isinstance(t_keys, list):
t_keys = [t_keys]
self.s_keys = s_keys
self.t_keys = t_keys
assert type(s_key) == type(t_key)
self.s_key = s_key
self.t_key = t_key
self.model_name_pairs = model_name_pairs
self.name = name
......@@ -411,16 +409,18 @@ class DistillationPairLoss(nn.Layer):
for idx, pair in enumerate(self.model_name_pairs):
out1 = predicts[pair[0]]
out2 = predicts[pair[1]]
            out1 = [out1[k] if k is not None else out1 for k in self.s_keys]
            out2 = [out2[k] if k is not None else out2 for k in self.t_keys]
            for feat_idx, (o1, o2) in enumerate(zip(out1, out2)):
                loss = self.loss_func.forward(o1, o2)
                if isinstance(loss, dict):
                    for k in loss:
                        loss_dict[
                            f"{self.name}_{idx}_{feat_idx}_{pair[0]}_{pair[1]}_{k}"] = loss[k]
                else:
                    loss_dict[
                        f"{self.name}_{idx}_{feat_idx}_{pair[0]}_{pair[1]}"] = loss
            # a single string key now selects one feature per model; a list of
            # keys preserves the previous multi-feature behaviour
            if isinstance(self.s_key, str):
                out1 = out1[self.s_key]
                out2 = out2[self.t_key]
            else:
                out1 = [out1[k] if k is not None else out1 for k in self.s_key]
                out2 = [out2[k] if k is not None else out2 for k in self.t_key]
            loss = self.loss_func.forward(out1, out2)
            if isinstance(loss, dict):
                for k in loss:
                    loss_dict[
                        f"{self.name}_{idx}_{pair[0]}_{pair[1]}_{k}"] = loss[k]
            else:
                loss_dict[f"{self.name}_{idx}_{pair[0]}_{pair[1]}"] = loss
        return loss_dict
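For context, a sketch of the inputs the new `forward` consumes. The exact dict layout produced by `return_patterns` is an assumption here; shapes are illustrative only.

```python
import paddle

# Assumed: with return_patterns=["avg_pool"], each model's output dict holds
# the logits plus the requested intermediate features.
predicts = {
    "Student": {"logits": paddle.randn([8, 1000]),
                "avg_pool": paddle.randn([8, 512, 1, 1])},
    "Teacher": {"logits": paddle.randn([8, 1000]),
                "avg_pool": paddle.randn([8, 512, 1, 1])},
}
# Since s_key/t_key are strings ("avg_pool"), forward picks the pooled
# features of both models and passes them to the base loss (PEFDLoss).
```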
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from ppcls.utils.initializer import kaiming_normal_, kaiming_uniform_
class Regressor(nn.Layer):
    """Single projector: a 1x1 convolution followed by ReLU, mapping the
    student channel dimension to the teacher channel dimension."""

    def __init__(self, dim_in=1024, dim_out=1024):
        super(Regressor, self).__init__()
        self.conv = nn.Conv2D(dim_in, dim_out, 1)

    def forward(self, x):
        x = self.conv(x)
        x = F.relu(x)
        return x
class PEFDLoss(nn.Layer):
"""Improved Feature Distillation via Projector Ensemble
Reference: https://arxiv.org/pdf/2210.15274.pdf
Code reference: https://github.com/chenyd7/PEFD
"""
def __init__(self, student_channel, teacher_channel, num_projectors=3):
super().__init__()
if num_projectors <= 0:
raise ValueError("Number of projectors must be greater than 0.")
self.projectors = nn.LayerList()
for _ in range(num_projectors):
self.projectors.append(Regressor(student_channel, teacher_channel))
def forward(self, student_feature, teacher_feature):
if student_feature.shape[2:] != teacher_feature.shape[2:]:
raise ValueError(
"Student feature must have the same H and W as teacher feature."
)
q = len(self.projectors)
f_s = 0.0
for i in range(q):
f_s += self.projectors[i](student_feature)
f_s = (f_s / q).flatten(1)
f_t = teacher_feature.flatten(1)
        # cosine similarity: L2-normalize both features, then take the inner product
normft = f_t.pow(2).sum(1, keepdim=True).pow(1. / 2)
outft = f_t / normft
normfs = f_s.pow(2).sum(1, keepdim=True).pow(1. / 2)
outfs = f_s / normfs
cos_theta = (outft * outfs).sum(1, keepdim=True)
loss = paddle.mean(1 - cos_theta)
return loss
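A quick smoke test of `PEFDLoss` under the avg_pool shapes assumed by the config above:

```python
import paddle

loss_fn = PEFDLoss(student_channel=512, teacher_channel=512, num_projectors=3)
student_feat = paddle.randn([8, 512, 1, 1])  # e.g. ResNet18 avg_pool output
teacher_feat = paddle.randn([8, 512, 1, 1])  # e.g. ResNet34 avg_pool output
loss = loss_fn(student_feat, teacher_feat)   # scalar: mean(1 - cosine similarity)
print(float(loss))
```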