提交 a1c4bbdb 编写于 作者: littletomatodonkey's avatar littletomatodonkey

fix doc

上级 ff7da1d0
......@@ -46,13 +46,13 @@
|-------|-----------|----------|---------------|---------------|
| Res2Net200_vd_26w_4s | 91.36 | 79.46 | 293 | 使用ImageNet预训练模型 |
| ResNet50 | 89.98 | 12.83 | 92 | 使用ImageNet预训练模型 |
| MobileNetV3_small_x0_35 | 87.41 | - | 2.8 | 使用ImageNet预训练模型 |
| MobileNetV3_small_x0_35 | 87.41 | 2.91 | 2.8 | 使用ImageNet预训练模型 |
| PPLCNet_x1_0 | 89.57 | 2.36 | 8.2 | 使用ImageNet预训练模型 |
| PPLCNet_x1_0 | 90.07 | 2.36 | 8.2 | 使用SSLD预训练模型 |
| PPLCNet_x1_0 | 90.59 | 2.36 | 8.2 | 使用SSLD预训练模型+EDA策略|
| <b>PPLCNet_x1_0<b> | <b>90.81<b> | <b>2.36<b> | <b>8.2<b> | 使用SSLD预训练模型+EDA策略+SKL-UGI知识蒸馏策略|
从表中可以看出,backbone 为 Res2Net200_vd_26w_4s 时精度较高,但是推理速度较慢。将 backbone 替换为轻量级模型 MobileNetV3_small_x0_35 后,速度可以大幅提升,但是精度下降明显。将 backbone 替换为 PPLCNet_x1_0 时,精度提升 2.16%,同时速度也提升 1 倍左右。在此基础上,使用 SSLD 预训练模型后,在不改变推理速度的前提下,精度可以提升约 0.5%,进一步地,当融合EDA策略后,精度可以再提升 0.52%,最后,在使用 SKL-UGI 知识蒸馏后,精度可以继续提升 0.23%。此时,PPLCNet_x1_0 的精度与 Res2Net200_vd_26w_4s 仅相差0.55%,但是速度快32倍。关于 PULC 的训练方法和推理部署方法将在下面详细介绍。
从表中可以看出,backbone 为 Res2Net200_vd_26w_4s 时精度较高,但是推理速度较慢。将 backbone 替换为轻量级模型 MobileNetV3_small_x0_35 后,速度可以大幅提升,但是精度下降明显。将 backbone 替换为 PPLCNet_x1_0 时,精度提升 2.16%,同时速度也提升 23% 左右。在此基础上,使用 SSLD 预训练模型后,在不改变推理速度的前提下,精度可以提升约 0.5%,进一步地,当融合EDA策略后,精度可以再提升 0.52%,最后,在使用 SKL-UGI 知识蒸馏后,精度可以继续提升 0.23%。此时,PPLCNet_x1_0 的精度与 Res2Net200_vd_26w_4s 仅相差0.55%,但是速度快32倍。关于 PULC 的训练方法和推理部署方法将在下面详细介绍。
**备注:**
......
......@@ -11,8 +11,9 @@
- [1.2 PaddleClas支持的知识蒸馏算法](#1.2)
- [1.2.1 SSLD](#1.2.1)
- [1.2.2 DML](#1.2.2)
- [1.2.3 AFD](#1.2.3)
- [1.2.4 DKD](#1.2.4)
- [1.2.3 UDML](#1.2.3)
- [1.2.4 AFD](#1.2.4)
- [1.2.5 DKD](#1.2.5)
- [2. 使用方法](#2)
- [2.1 环境配置](#2.1)
- [2.2 数据准备](#2.2)
......@@ -196,9 +197,80 @@ Loss:
<a name='1.2.3'></a>
#### 1.2.3 AFD
#### 1.2.3 UDML
##### 1.2.3.1 AFD 算法介绍
##### 1.2.3.1 UDML 算法介绍
论文信息:
UDML 是百度飞桨视觉团队提出的无需依赖教师模型的知识蒸馏算法,它基于DML进行改进,在蒸馏的过程中,除了考虑两个模型的输出信息,也考虑两个模型的中间层特征信息,从而进一步提升知识蒸馏的精度。更多关于UDML的说明与应用,请参考[PP-ShiTu论文](https://arxiv.org/abs/2111.00775)以及[PP-OCRv3论文](https://arxiv.org/abs/2109.03144)
在ImageNet1k公开数据集上,效果如下所示。
| 策略 | 骨干网络 | 配置文件 | Top-1 acc | 下载链接 |
| --- | --- | --- | --- | --- |
| baseline | PPLCNet_x2_5 | [PPLCNet_x2_5.yaml](../../../ppcls/configs/ImageNet/PPLCNet/PPLCNet_x2_5.yaml) | 74.93% | - |
| UDML | PPLCNet_x2_5 | [PPLCNet_x2_5_dml.yaml](../../../ppcls/configs/ImageNet/Distillation/PPLCNet_x2_5_udml.yaml) | 76.74%(**+1.81%**) | - |
##### 1.2.3.2 UDML 配置
```yaml
Arch:
name: "DistillationModel"
class_num: &class_num 1000
# if not null, its lengths should be same as models
pretrained_list:
# if not null, its lengths should be same as models
freeze_params_list:
- False
- False
models:
- Teacher:
name: PPLCNet_x2_5
class_num: *class_num
pretrained: False
# return_patterns表示除了返回输出的logits,也会返回对应名称的中间层feature map
return_patterns: ["blocks3", "blocks4", "blocks5", "blocks6"]
- Student:
name: PPLCNet_x2_5
class_num: *class_num
pretrained: False
return_patterns: ["blocks3", "blocks4", "blocks5", "blocks6"]
# loss function config for traing/eval process
Loss:
Train:
- DistillationGTCELoss:
weight: 1.0
key: logits
model_names: ["Student", "Teacher"]
- DistillationDMLLoss:
weight: 1.0
key: logits
model_name_pairs:
- ["Student", "Teacher"]
- DistillationDistanceLoss: # 基于蒸馏结果的距离loss,这里默认使用l2 loss计算block5之间的损失函数
weight: 1.0
key: "blocks5"
model_name_pairs:
- ["Student", "Teacher"]
Eval:
- CELoss:
weight: 1.0
```
**注意(:** 上述在网络中指定`return_patterns`,返回中间层特征的功能是基于TheseusLayer,更多关于TheseusLayer的使用说明,请参考:[TheseusLayer 使用说明](./theseus_layer.md)
<a name='1.2.4'></a>
#### 1.2.4 AFD
##### 1.2.4.1 AFD 算法介绍
论文信息:
......@@ -220,7 +292,7 @@ AFD提出在蒸馏的过程中,利用基于注意力的元网络学习特征
注意:这里为了与论文的训练配置保持对齐,设置训练的迭代轮数为100epoch,因此baseline精度低于PaddleClas中开源出的模型精度(71.0%)
##### 1.2.3.2 AFD 配置
##### 1.2.4.2 AFD 配置
AFD配置如下所示。在模型构建Arch字段中,需要同时定义学生模型与教师模型,固定教师模型的权重。这里需要对从教师模型获取的特征进行变换,进而与学生模型进行损失函数的计算。在损失函数Loss字段中,需要定义`DistillationKLDivLoss`(学生与教师之间的KL-Div loss)、`AFDLoss`(学生与教师之间的AFD loss)以及`DistillationGTCELoss`(学生与教师关于真值标签的CE loss),作为训练的损失函数。
......@@ -305,11 +377,11 @@ Loss:
**注意(:** 上述在网络中指定`return_patterns`,返回中间层特征的功能是基于TheseusLayer,更多关于TheseusLayer的使用说明,请参考:[TheseusLayer 使用说明](./theseus_layer.md)
<a name='1.2.4'></a>
<a name='1.2.5'></a>
#### 1.2.4 DKD
#### 1.2.5 DKD
##### 1.2.4.1 DKD 算法介绍
##### 1.2.5.1 DKD 算法介绍
论文信息:
......@@ -330,7 +402,7 @@ DKD将蒸馏中常用的 KD Loss 进行了解耦成为Target Class Knowledge Dis
| AFD | ResNet18 | [resnet34_distill_resnet18_dkd.yaml](../../../ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_dkd.yaml) | 72.59%(**+1.79%**) | - |
##### 1.2.4.2 DKD 配置
##### 1.2.5.2 DKD 配置
DKD 配置如下所示。在模型构建Arch字段中,需要同时定义学生模型与教师模型,教师模型固定参数,且需要加载预训练模型。在损失函数Loss字段中,需要定义`DistillationDKDLoss`(学生与教师之间的DKD loss)以及`DistillationGTCELoss`(学生与教师关于真值标签的CE loss),作为训练的损失函数。
......
# global configs
Global:
checkpoints: null
pretrained_model: null
output_dir: ./output_lcnet_x2_5_udml
device: gpu
save_interval: 1
eval_during_train: True
eval_interval: 1
epochs: 100
print_batch_step: 10
use_visualdl: False
# used for static mode and model export
image_shape: [3, 224, 224]
save_inference_dir: ./inference
AMP:
scale_loss: 128.0
use_dynamic_loss_scaling: True
# O1: mixed fp16
level: O1
# model architecture
Arch:
name: "DistillationModel"
class_num: &class_num 1000
# if not null, its lengths should be same as models
pretrained_list:
# if not null, its lengths should be same as models
freeze_params_list:
- False
- False
models:
- Teacher:
name: PPLCNet_x2_5
class_num: *class_num
pretrained: False
return_patterns: ["blocks3", "blocks4", "blocks5", "blocks6"]
- Student:
name: PPLCNet_x2_5
class_num: *class_num
pretrained: False
return_patterns: ["blocks3", "blocks4", "blocks5", "blocks6"]
# loss function config for traing/eval process
Loss:
Train:
- DistillationGTCELoss:
weight: 1.0
key: logits
model_names: ["Student", "Teacher"]
- DistillationDMLLoss:
weight: 1.0
key: logits
model_name_pairs:
- ["Student", "Teacher"]
- DistillationDistanceLoss:
weight: 1.0
key: "blocks5"
model_name_pairs:
- ["Student", "Teacher"]
Eval:
- CELoss:
weight: 1.0
Optimizer:
name: Momentum
momentum: 0.9
lr:
name: Cosine
learning_rate: 0.4
warmup_epoch: 5
regularizer:
name: 'L2'
coeff: 0.00004
# data loader for train and eval
DataLoader:
Train:
dataset:
name: ImageNetDataset
image_root: ./dataset/ILSVRC2012/
cls_label_path: ./dataset/ILSVRC2012/train_list.txt
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- RandCropImage:
size: 224
- RandFlipImage:
flip_code: 1
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 128
drop_last: False
shuffle: True
loader:
num_workers: 8
use_shared_memory: True
Eval:
dataset:
name: ImageNetDataset
image_root: ./dataset/ILSVRC2012/
cls_label_path: ./dataset/ILSVRC2012/val_list.txt
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
resize_short: 256
- CropImage:
size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 256
drop_last: False
shuffle: False
loader:
num_workers: 8
use_shared_memory: True
Infer:
infer_imgs: docs/images/inference_deployment/whl_demo.jpg
batch_size: 10
transforms:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
resize_short: 256
- CropImage:
size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- ToCHWImage:
PostProcess:
name: Topk
topk: 5
class_id_map_file: ppcls/utils/imagenet1k_label_list.txt
Metric:
Train:
- DistillationTopkAcc:
model_key: "Student"
topk: [1, 5]
Eval:
- DistillationTopkAcc:
model_key: "Student"
topk: [1, 5]
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册