diff --git a/deploy/configs/cls_demo/person/inference_person_cls.yaml b/deploy/configs/cls_demo/person/inference_person_cls.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..cf7a75659dbfe5690fb903e91b6df775bc506814
--- /dev/null
+++ b/deploy/configs/cls_demo/person/inference_person_cls.yaml
@@ -0,0 +1,35 @@
+Global:
+ infer_imgs: "./images/cls_demo/person/objects365_02035329.jpg"
+ inference_model_dir: "./models/person_cls_infer"
+ batch_size: 1
+ use_gpu: True
+ enable_mkldnn: True
+ cpu_num_threads: 10
+ enable_benchmark: True
+ use_fp16: False
+ ir_optim: True
+ use_tensorrt: False
+ gpu_mem: 8000
+ enable_profile: False
+
+PreProcess:
+ transform_ops:
+ - ResizeImage:
+ resize_short: 256
+ - CropImage:
+ size: 224
+ - NormalizeImage:
+ scale: 0.00392157
+ mean: [0.485, 0.456, 0.406]
+ std: [0.229, 0.224, 0.225]
+ order: ''
+ channel_num: 3
+ - ToCHWImage:
+
+PostProcess:
+ main_indicator: Topk
+ Topk:
+ topk: 5
+ class_id_map_file: "../ppcls/utils/cls_demo/person_label_list.txt"
+ SavePreLabel:
+ save_dir: ./pre_label/
diff --git a/deploy/images/cls_demo/person/objects365_01780782.jpg b/deploy/images/cls_demo/person/objects365_01780782.jpg
new file mode 100755
index 0000000000000000000000000000000000000000..a0dd0df59ae5a6386a04a8e0cf9cdbc529139c16
Binary files /dev/null and b/deploy/images/cls_demo/person/objects365_01780782.jpg differ
diff --git a/deploy/images/cls_demo/person/objects365_02035329.jpg b/deploy/images/cls_demo/person/objects365_02035329.jpg
new file mode 100755
index 0000000000000000000000000000000000000000..16d7f2d08cd87bda1b67d21655f00f94a0c6e4e4
Binary files /dev/null and b/deploy/images/cls_demo/person/objects365_02035329.jpg differ
diff --git a/docs/zh_CN/cls_demo/quick_start_cls_demo.md b/docs/zh_CN/cls_demo/quick_start_cls_demo.md
new file mode 100644
index 0000000000000000000000000000000000000000..dc1f8cd34be2ce9808380e75b1f8652d73cb11a6
--- /dev/null
+++ b/docs/zh_CN/cls_demo/quick_start_cls_demo.md
@@ -0,0 +1,299 @@
+# PaddleClas构建有人/无人分类案例
+
+此处提供了用户使用 PaddleClas 快速构建轻量级、高精度、可落地的有人/无人的分类模型教程,主要基于有人/无人场景的数据,融合了轻量级骨干网络PPLCNet、SSLD预训练权重、EDA数据增强策略、KL-JS-UGI知识蒸馏策略、SHAS超参数搜索策略,得到精度高、速度快、易于部署的二分类模型。
+
+------
+
+
+## 目录
+
+- [1. 环境配置](#1)
+- [2. 有人/无人场景推理预测](#2)
+ - [2.1 下载模型](#2.1)
+ - [2.2 模型推理预测](#2.2)
+ - [2.2.1 预测单张图像](#2.2.1)
+ - [2.2.2 基于文件夹的批量预测](#2.2.2)
+- [3.有人/无人场景训练](#3)
+ - [3.1 数据准备](#3.1)
+ - [3.2 模型训练](#3.2)
+ - [3.2.1 基于默认超参数训练](#3.2.1)
+ - [3.2.1.1 基于默认超参数训练轻量级模型](#3.2.1.1)
+ - [3.2.1.2 基于默认超参数训练教师模型](#3.2.1.2)
+ - [3.2.1.3 基于默认超参数进行蒸馏训练](#3.2.1.3)
+ - [3.2.2 超参数搜索训练](#3.2)
+- [4. 模型评估与推理](#4)
+ - [4.1 模型评估](#3.1)
+ - [4.2 模型预测](#3.2)
+ - [4.3 使用 inference 模型进行推理](#4.3)
+ - [4.3.1 导出 inference 模型](#4.3.1)
+ - [4.3.2 模型推理预测](#4.3.2)
+
+
+
+
+## 1. 环境配置
+
+* 安装:请先参考 [Paddle 安装教程](../installation/install_paddle.md) 以及 [PaddleClas 安装教程](../installation/install_paddleclas.md) 配置 PaddleClas 运行环境。
+
+
+
+## 2. 有人/无人场景推理预测
+
+
+
+### 2.1 下载模型
+
+* 进入 `deploy` 运行目录。
+
+```
+cd deploy
+```
+
+下载有人/无人分类的模型。
+
+```
+mkdir models
+cd models
+# 下载inference 模型并解压
+wget https://paddleclas.bj.bcebos.com/models/cls_demo/person_cls_infer.tar && tar -xf person_cls_infer.tar
+```
+
+解压完毕后,`models` 文件夹下应有如下文件结构:
+
+```
+├── person_cls_infer
+│ ├── inference.pdiparams
+│ ├── inference.pdiparams.info
+│ └── inference.pdmodel
+```
+
+
+
+### 2.2 模型推理预测
+
+
+
+#### 2.2.1 预测单张图像
+
+运行下面的命令,对图像 `./images/cls_demo/person/objects365_02035329.jpg` 进行有人/无人分类。
+
+```shell
+# 使用下面的命令使用 GPU 进行预测
+python3.7 python/predict_cls.py -c configs/cls_demo/person/inference_person_cls.yaml
+# 使用下面的命令使用 CPU 进行预测
+python3.7 python/predict_system.py -c configs/inference_general.yaml -o Global.use_gpu=False
+```
+
+输出结果如下。
+
+```
+objects365_02035329.jpg: class id(s): [1, 0], score(s): [1.00, 0.00], label_name(s): ['someone', 'nobody']
+```
+
+其中,`someone` 表示该图里存在人,`nobody` 表示该图里不存在人。
+
+
+
+
+#### 2.2.2 基于文件夹的批量预测
+
+如果希望预测文件夹内的图像,可以直接修改配置文件中的 `Global.infer_imgs` 字段,也可以通过下面的 `-o` 参数修改对应的配置。
+
+```shell
+# 使用下面的命令使用 GPU 进行预测,如果希望使用 CPU 预测,可以在命令后面添加 -o Global.use_gpu=False
+python3.7 python/predict_system.py -c configs/inference_general.yaml -o Global.infer_imgs="./images/cls_demo/person/"
+```
+
+终端中会输出该文件夹内所有图像的分类结果,如下所示。
+
+```
+objects365_01780782.jpg: class id(s): [0, 1], score(s): [1.00, 0.00], label_name(s): ['nobody', 'someone']
+objects365_02035329.jpg: class id(s): [1, 0], score(s): [1.00, 0.00], label_name(s): ['someone', 'nobody']
+```
+
+
+
+## 3.有人/无人场景训练
+
+
+
+### 3.1 数据准备
+
+进入 PaddleClas 目录。
+
+```
+cd path_to_PaddleClas
+```
+
+进入 `dataset/` 目录,下载并解压有人/无人场景的数据。
+
+```shell
+cd dataset
+wget https://paddleclas.bj.bcebos.com/data/cls_demo/person.tar
+tar -xf person.tar
+cd ../
+```
+
+执行上述命令后,`dataset/`下存在`person`目录,该目录中具有以下数据:
+
+```
+
+├── train
+│ ├── 000000000009.jpg
+│ ├── 000000000025.jpg
+...
+├── val
+│ ├── objects365_01780637.jpg
+│ ├── objects365_01780640.jpg
+...
+├── ImageNet_val
+│ ├── ILSVRC2012_val_00000001.JPEG
+│ ├── ILSVRC2012_val_00000002.JPEG
+...
+├── train_list.txt
+├── train_list.txt.debug
+├── train_list_for_distill.txt
+├── val_list.txt
+└── val_list.txt.debug
+```
+
+其中`train/`和`val/`分别为训练集和验证集。`train_list.txt`和`val_list.txt`分别为训练集和验证集的标签文件,`train_list.txt.debug`和`val_list.txt.debug`分别为训练集和验证集的`debug`标签文件,其分别是`train_list.txt`和`val_list.txt`的子集,用该文件可以快速体验本案例的流程。`ImageNet_val/`是ImageNet的验证集,该集合和`train`集合的混合数据用于本案例的`KL-JS-UGI知识蒸馏策略`,对应的训练标签文件为`train_list_for_distill.txt`。
+
+* **注意**:
+
+* 本案例中所使用的所有数据集均为开源数据,`train`集合为[MS-COCO数据](https://cocodataset.org/#overview)的训练集的子集,`val`集合为[Object365数据](https://www.objects365.org/overview.html)的训练集的子集,`ImageNet_val`为[ImageNet数据](https://www.image-net.org/)的验证集。
+
+
+
+### 3.2 模型训练
+
+
+
+#### 3.2.1 基于默认超参数训练
+
+
+
+##### 3.2.1.1 基于默认超参数训练轻量级模型
+
+在`ppcls/configs/cls_demo/person/PPLCNet/PPLCNet_x1_0.yaml`中提供了基于该场景中已经搜索好的超参数,可以通过如下脚本启动训练:
+
+```shell
+export CUDA_VISIBLE_DEVICES=0,1,2,3
+python3 -m paddle.distributed.launch \
+ --gpus="0,1,2,3" \
+ tools/train.py \
+ -c ./ppcls/configs/cls_demo/person/PPLCNet/PPLCNet_x1_0.yaml
+```
+
+验证集的最佳 metric 在0.94-0.95之间(数据集较小,容易造成波动)。
+
+
+
+##### 3.2.1.2 基于默认超参数训练教师模型
+
+复用`ppcls/configs/cls_demo/person/PPLCNet/PPLCNet_x1_0.yaml`中的超参数,训练教师模型,训练脚本如下:
+
+```shell
+export CUDA_VISIBLE_DEVICES=0,1,2,3
+python3 -m paddle.distributed.launch \
+ --gpus="0,1,2,3" \
+ tools/train.py \
+ -c ./ppcls/configs/cls_demo/person/PPLCNet/PPLCNet_x1_0.yaml \
+ -o Arch.name=ResNet101_vd
+```
+
+验证集的最佳 metric 为0.97-0.98之间,当前教师模型最好的权重保存在`output/ResNet101_vd/best_model.pdparams`。
+
+
+
+##### 3.2.1.3 基于默认超参数进行蒸馏训练
+
+配置文件`ppcls/configs/cls_demo/person/Distillation/PPLCNet_x1_0_distillation.yaml`提供了`KL-JS-UGI知识蒸馏策略`的配置。该配置将`ResNet101_vd`当作教师模型,`PPLCNet_x1_0`当作学生模型,使用ImageNet数据集的验证集作为新增的无标签数据。训练脚本如下:
+
+```shell
+export CUDA_VISIBLE_DEVICES=0,1,2,3
+python3 -m paddle.distributed.launch \
+ --gpus="0,1,2,3" \
+ tools/train.py \
+ -c .ppcls/configs/cls_demo/person/Distillation/PPLCNet_x1_0_distillation.yaml \
+ -o Arch.models.0.Teacher.pretrained=output/ResNet101_vd/best_model
+```
+
+
+
+#### 3.2.2 超参数搜索训练
+
+[3.2 小节](#3.2) 提供了在已经搜索并得到的超参数上进行了训练,此部分内容提供了搜索的过程,此过程是为了得到更好的训练超参数。
+
+* 搜索运行脚本如下:
+
+```shell
+python tools/search_strategy.py -c ppcls/configs/cls_demo/person/PPLCNet/PPLCNet_x1_0_search.yaml
+```
+
+* **注意**:
+
+* 此过程基于当前数据集在 V100 4 卡上大概需要耗时 6 小时,如果缺少机器资源,希望体验搜索过程,可以将`ppcls/configs/cls_demo/person/PPLCNet/PPLCNet_x1_0_search.yaml`中的`train_list.txt`和`val_list.txt`分别替换为`train_list.txt.debug`和`val_list.txt.debug`。替换list只是为了加速跑通整个搜索过程,由于数据量较小,其搜素的结果没有参考性。
+
+* 如果此过程搜索的得到的超参数与3.2.1小节提供的超参数不一致,主要是由于训练数据较小造成的波动导致,可以忽略。
+
+
+
+## 4. 模型评估与推理
+
+
+
+
+### 4.1 模型评估
+
+训练好模型之后,可以通过以下命令实现对模型精度的评估。
+
+```bash
+python3 tools/eval.py \
+ -c ./ppcls/configs/cls_demo/person/PPLCNet/PPLCNet_x1_0.yaml \
+ -o Global.pretrained_model="output/PPLCNet_x1_0/best_model"
+```
+
+
+
+### 4.2 模型预测
+
+模型训练完成之后,可以加载训练得到的预训练模型,进行模型预测。在模型库的 `tools/infer.py` 中提供了完整的示例,只需执行下述命令即可完成模型预测:
+
+```python
+python3 tools/infer.py \
+ -c ./ppcls/configs/cls_demo/person/PPLCNet/PPLCNet_x1_0.yaml \
+ -o Infer.infer_imgs=./dataset/person/val/objects365_01780637.jpg \
+ -o Global.pretrained_model=output/PPLCNet_x1_0/best_model
+```
+
+
+
+### 4.3 使用 inference 模型进行推理
+
+
+
+### 4.3.1 导出 inference 模型
+
+通过导出 inference 模型,PaddlePaddle 支持使用预测引擎进行预测推理。接下来介绍如何用预测引擎进行推理:
+首先,对训练好的模型进行转换:
+
+```bash
+python3 tools/export_model.py \
+ -c ./ppcls/configs/cls_demo/person/PPLCNet/PPLCNet_x1_0.yaml \
+ -o Global.pretrained_model=output/PPLCNet_x1_0/best_model \
+ -o Global.save_inference_dir=deploy/models/PPLCNet_x1_0_person
+```
+执行完该脚本后会在`deploy/models/`下生成`PPLCNet_x1_0_person`文件夹,该文件夹中的模型与 2.2 节下载的推理预测模型格式一致。
+
+
+
+### 4.3.2 基于 inference 模型推理预测
+推理预测的脚本为:
+
+```
+python3.7 python/predict_cls.py -c configs/cls_demo/person/inference_person_cls.yaml -o Global.inference_model_dir="models/PPLCNet_x1_0_person"
+```
+
+更多关于推理的细节,可以参考[2.2节](#2.2)。
+
diff --git a/ppcls/configs/cls_demo/person/Distillation/PPLCNet_x1_0_distillation.yaml b/ppcls/configs/cls_demo/person/Distillation/PPLCNet_x1_0_distillation.yaml
index 54e1a313c69afeebb2ec69a8b0257f6554a4ea61..a9c3db29682933f19cf93ef000d5b6ec83007aa7 100644
--- a/ppcls/configs/cls_demo/person/Distillation/PPLCNet_x1_0_distillation.yaml
+++ b/ppcls/configs/cls_demo/person/Distillation/PPLCNet_x1_0_distillation.yaml
@@ -33,11 +33,14 @@ Arch:
- Teacher:
name: ResNet101_vd
class_num: *class_num
+ use_sync_bn: True
- Student:
name: PPLCNet_x1_0
class_num: *class_num
pretrained: True
use_ssld: True
+ use_sync_bn: True
+ lr_mult_list: [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]
infer_model_name: "Student"
@@ -77,21 +80,21 @@ DataLoader:
to_rgb: True
channel_first: False
- RandCropImage:
- size: 224
+ size: 192
- RandFlipImage:
flip_code: 1
- TimmAutoAugment:
prob: 0.0
config_str: rand-m9-mstd0.5-inc1
interpolation: bicubic
- img_size: 224
+ img_size: 192
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- RandomErasing:
- EPSILON: 0.0
+ EPSILON: 0.1
sl: 0.02
sh: 1.0/3.0
r1: 0.3
diff --git a/ppcls/configs/cls_demo/person/Distillation/PPLCNet_x1_0_distillation_search.yaml b/ppcls/configs/cls_demo/person/Distillation/PPLCNet_x1_0_distillation_search.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..231766d846459e5157a48fcde110b391f7f5cd5a
--- /dev/null
+++ b/ppcls/configs/cls_demo/person/Distillation/PPLCNet_x1_0_distillation_search.yaml
@@ -0,0 +1,169 @@
+# global configs
+Global:
+ checkpoints: null
+ pretrained_model: null
+ output_dir: ./output
+ device: gpu
+ save_interval: 1
+ eval_during_train: True
+ start_eval_epoch: 10
+ eval_interval: 1
+ epochs: 20
+ print_batch_step: 10
+ use_visualdl: False
+ # used for static mode and model export
+ image_shape: [3, 224, 224]
+ save_inference_dir: ./inference
+ # training model under @to_static
+ to_static: False
+ use_dali: False
+
+# model architecture
+Arch:
+ name: "DistillationModel"
+ class_num: &class_num 2
+ # if not null, its lengths should be same as models
+ pretrained_list:
+ # if not null, its lengths should be same as models
+ freeze_params_list:
+ - True
+ - False
+ use_sync_bn: True
+ models:
+ - Teacher:
+ name: ResNet101_vd
+ class_num: *class_num
+ use_sync_bn: True
+ - Student:
+ name: PPLCNet_x1_0
+ class_num: *class_num
+ pretrained: True
+ use_ssld: True
+ use_sync_bn: True
+
+ infer_model_name: "Student"
+
+# loss function config for traing/eval process
+Loss:
+ Train:
+ - DistillationDMLLoss:
+ weight: 1.0
+ model_name_pairs:
+ - ["Student", "Teacher"]
+ Eval:
+ - CELoss:
+ weight: 1.0
+
+
+Optimizer:
+ name: Momentum
+ momentum: 0.9
+ lr:
+ name: Cosine
+ learning_rate: 0.01
+ warmup_epoch: 5
+ regularizer:
+ name: 'L2'
+ coeff: 0.00004
+
+
+# data loader for train and eval
+DataLoader:
+ Train:
+ dataset:
+ name: ImageNetDataset
+ image_root: ./dataset/person/
+ cls_label_path: ./dataset/person/train_list_for_distill.txt
+ transform_ops:
+ - DecodeImage:
+ to_rgb: True
+ channel_first: False
+ - RandCropImage:
+ size: 224
+ - RandFlipImage:
+ flip_code: 1
+ - TimmAutoAugment:
+ prob: 0.0
+ config_str: rand-m9-mstd0.5-inc1
+ interpolation: bicubic
+ img_size: 224
+ - NormalizeImage:
+ scale: 1.0/255.0
+ mean: [0.485, 0.456, 0.406]
+ std: [0.229, 0.224, 0.225]
+ order: ''
+ - RandomErasing:
+ EPSILON: 0.0
+ sl: 0.02
+ sh: 1.0/3.0
+ r1: 0.3
+ attempt: 10
+ use_log_aspect: True
+ mode: pixel
+ sampler:
+ name: DistributedBatchSampler
+ batch_size: 64
+ drop_last: False
+ shuffle: True
+ loader:
+ num_workers: 16
+ use_shared_memory: True
+
+ Eval:
+ dataset:
+ name: ImageNetDataset
+ image_root: ./dataset/person/
+ cls_label_path: ./dataset/person/val_list.txt
+ transform_ops:
+ - DecodeImage:
+ to_rgb: True
+ channel_first: False
+ - ResizeImage:
+ resize_short: 256
+ - CropImage:
+ size: 224
+ - NormalizeImage:
+ scale: 1.0/255.0
+ mean: [0.485, 0.456, 0.406]
+ std: [0.229, 0.224, 0.225]
+ order: ''
+ sampler:
+ name: DistributedBatchSampler
+ batch_size: 64
+ drop_last: False
+ shuffle: False
+ loader:
+ num_workers: 4
+ use_shared_memory: True
+
+Infer:
+ infer_imgs: docs/images/inference_deployment/whl_demo.jpg
+ batch_size: 10
+ transforms:
+ - DecodeImage:
+ to_rgb: True
+ channel_first: False
+ - ResizeImage:
+ resize_short: 256
+ - CropImage:
+ size: 224
+ - NormalizeImage:
+ scale: 1.0/255.0
+ mean: [0.485, 0.456, 0.406]
+ std: [0.229, 0.224, 0.225]
+ order: ''
+ - ToCHWImage:
+ PostProcess:
+ name: Topk
+ topk: 5
+ class_id_map_file: ppcls/utils/imagenet1k_label_list.txt
+
+Metric:
+ Train:
+ - DistillationTopkAcc:
+ model_key: "Student"
+ topk: [1, 2]
+ Eval:
+ - TprAtFpr:
+ - TopkAcc:
+ topk: [1, 2]
diff --git a/ppcls/configs/cls_demo/person/PPLCNet/PPLCNet_x1_0.yaml b/ppcls/configs/cls_demo/person/PPLCNet/PPLCNet_x1_0.yaml
index 3c9bfb5dbb10fdea9c1209ecaace08c2fb59ac6a..97ae1ba73b0c499a5e2b80f5d32c62964b061a40 100644
--- a/ppcls/configs/cls_demo/person/PPLCNet/PPLCNet_x1_0.yaml
+++ b/ppcls/configs/cls_demo/person/PPLCNet/PPLCNet_x1_0.yaml
@@ -26,6 +26,7 @@ Arch:
pretrained: True
use_ssld: True
use_sync_bn: True
+ lr_mult_list: [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]
# loss function config for traing/eval process
Loss:
@@ -61,21 +62,21 @@ DataLoader:
to_rgb: True
channel_first: False
- RandCropImage:
- size: 224
+ size: 192
- RandFlipImage:
flip_code: 1
- TimmAutoAugment:
prob: 0.0
config_str: rand-m9-mstd0.5-inc1
interpolation: bicubic
- img_size: 224
+ img_size: 192
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- RandomErasing:
- EPSILON: 0.0
+ EPSILON: 0.1
sl: 0.02
sh: 1.0/3.0
r1: 0.3
diff --git a/ppcls/configs/cls_demo/person/PPLCNet/PPLCNet_x1_0_search.yaml b/ppcls/configs/cls_demo/person/PPLCNet/PPLCNet_x1_0_search.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..3c9bfb5dbb10fdea9c1209ecaace08c2fb59ac6a
--- /dev/null
+++ b/ppcls/configs/cls_demo/person/PPLCNet/PPLCNet_x1_0_search.yaml
@@ -0,0 +1,150 @@
+# global configs
+Global:
+ checkpoints: null
+ pretrained_model: null
+ output_dir: ./output/
+ device: gpu
+ save_interval: 1
+ eval_during_train: True
+ eval_interval: 1
+ start_eval_epoch: 10
+ epochs: 20
+ print_batch_step: 10
+ use_visualdl: False
+ # used for static mode and model export
+ image_shape: [3, 224, 224]
+ save_inference_dir: ./inference
+ # training model under @to_static
+ to_static: False
+ use_dali: False
+
+
+# model architecture
+Arch:
+ name: PPLCNet_x1_0
+ class_num: 2
+ pretrained: True
+ use_ssld: True
+ use_sync_bn: True
+
+# loss function config for traing/eval process
+Loss:
+ Train:
+ - CELoss:
+ weight: 1.0
+ Eval:
+ - CELoss:
+ weight: 1.0
+
+
+Optimizer:
+ name: Momentum
+ momentum: 0.9
+ lr:
+ name: Cosine
+ learning_rate: 0.01
+ warmup_epoch: 5
+ regularizer:
+ name: 'L2'
+ coeff: 0.00004
+
+
+# data loader for train and eval
+DataLoader:
+ Train:
+ dataset:
+ name: ImageNetDataset
+ image_root: ./dataset/person/
+ cls_label_path: ./dataset/person/train_list.txt
+ transform_ops:
+ - DecodeImage:
+ to_rgb: True
+ channel_first: False
+ - RandCropImage:
+ size: 224
+ - RandFlipImage:
+ flip_code: 1
+ - TimmAutoAugment:
+ prob: 0.0
+ config_str: rand-m9-mstd0.5-inc1
+ interpolation: bicubic
+ img_size: 224
+ - NormalizeImage:
+ scale: 1.0/255.0
+ mean: [0.485, 0.456, 0.406]
+ std: [0.229, 0.224, 0.225]
+ order: ''
+ - RandomErasing:
+ EPSILON: 0.0
+ sl: 0.02
+ sh: 1.0/3.0
+ r1: 0.3
+ attempt: 10
+ use_log_aspect: True
+ mode: pixel
+ sampler:
+ name: DistributedBatchSampler
+ batch_size: 64
+ drop_last: False
+ shuffle: True
+ loader:
+ num_workers: 8
+ use_shared_memory: True
+
+ Eval:
+ dataset:
+ name: ImageNetDataset
+ image_root: ./dataset/person/
+ cls_label_path: ./dataset/person/val_list.txt
+ transform_ops:
+ - DecodeImage:
+ to_rgb: True
+ channel_first: False
+ - ResizeImage:
+ resize_short: 256
+ - CropImage:
+ size: 224
+ - NormalizeImage:
+ scale: 1.0/255.0
+ mean: [0.485, 0.456, 0.406]
+ std: [0.229, 0.224, 0.225]
+ order: ''
+ sampler:
+ name: DistributedBatchSampler
+ batch_size: 64
+ drop_last: False
+ shuffle: False
+ loader:
+ num_workers: 4
+ use_shared_memory: True
+
+Infer:
+ infer_imgs: docs/images/inference_deployment/whl_demo.jpg
+ batch_size: 10
+ transforms:
+ - DecodeImage:
+ to_rgb: True
+ channel_first: False
+ - ResizeImage:
+ resize_short: 256
+ - CropImage:
+ size: 224
+ - NormalizeImage:
+ scale: 1.0/255.0
+ mean: [0.485, 0.456, 0.406]
+ std: [0.229, 0.224, 0.225]
+ order: ''
+ - ToCHWImage:
+ PostProcess:
+ name: Topk
+ topk: 5
+ class_id_map_file: ppcls/utils/imagenet1k_label_list.txt
+
+Metric:
+ Train:
+ - TopkAcc:
+ topk: [1, 2]
+ Eval:
+ - TprAtFpr:
+ - TopkAcc:
+ topk: [1, 2]
diff --git a/ppcls/engine/engine.py b/ppcls/engine/engine.py
index 99dd5f3b83bc9e47608b934afb7ee0c9c167a535..9c2dc2e66fa55120eebf1271758a2d62d09c92f4 100644
--- a/ppcls/engine/engine.py
+++ b/ppcls/engine/engine.py
@@ -344,15 +344,15 @@ class Engine(object):
if self.use_dali:
self.train_dataloader.reset()
- metric_msg = ", ".join([
- self.output_info[key].avg_info for key in self.output_info
- ])
+ metric_msg = ", ".join(
+ [self.output_info[key].avg_info for key in self.output_info])
logger.info("[Train][Epoch {}/{}][Avg]{}".format(
epoch_id, self.config["Global"]["epochs"], metric_msg))
self.output_info.clear()
# eval model and save model if possible
- start_eval_epoch = self.config["Global"].get("start_eval_epoch", 0) - 1
+ start_eval_epoch = self.config["Global"].get("start_eval_epoch",
+ 0) - 1
if self.config["Global"][
"eval_during_train"] and epoch_id % self.config["Global"][
"eval_interval"] == 0 and epoch_id > start_eval_epoch:
@@ -367,7 +367,8 @@ class Engine(object):
self.output_dir,
model_name=self.config["Arch"]["name"],
prefix="best_model",
- loss=self.train_loss_func)
+ loss=self.train_loss_func,
+ save_student_model=True)
logger.info("[Eval][Epoch {}][best metric: {}]".format(
epoch_id, best_metric["metric"]))
logger.scaler(
diff --git a/ppcls/utils/cls_demo/person_label_list.txt b/ppcls/utils/cls_demo/person_label_list.txt
new file mode 100644
index 0000000000000000000000000000000000000000..8eea2b6dc2433abf303a0ea508021698559b749b
--- /dev/null
+++ b/ppcls/utils/cls_demo/person_label_list.txt
@@ -0,0 +1,2 @@
+0 nobody
+1 someone
diff --git a/ppcls/utils/save_load.py b/ppcls/utils/save_load.py
index 4e27f12c1d4830f2f16580bfa976cf3ace78d934..b7ce8684fadfd3fd75a33294f731fd64947d7b98 100644
--- a/ppcls/utils/save_load.py
+++ b/ppcls/utils/save_load.py
@@ -42,6 +42,14 @@ def _mkdir_if_not_exist(path):
raise OSError('Failed to mkdir {}'.format(path))
+def _extract_student_weights(all_params, student_prefix="Student."):
+ s_params = {
+ key[len(student_prefix):]: all_params[key]
+ for key in all_params if student_prefix in key
+ }
+ return s_params
+
+
def load_dygraph_pretrain(model, path=None):
if not (os.path.isdir(path) or os.path.exists(path + '.pdparams')):
raise ValueError("Model pretrain path {}.pdparams does not "
@@ -117,7 +125,7 @@ def init_model(config, net, optimizer=None, loss: paddle.nn.Layer=None):
else: # common load
load_dygraph_pretrain(net, path=pretrained_model)
logger.info("Finish load pretrained model from {}".format(
- pretrained_model))
+ pretrained_model))
def save_model(net,
@@ -126,7 +134,8 @@ def save_model(net,
model_path,
model_name="",
prefix='ppcls',
- loss: paddle.nn.Layer=None):
+ loss: paddle.nn.Layer=None,
+ save_student_model=False):
"""
save model to the target path
"""
@@ -137,11 +146,18 @@ def save_model(net,
model_path = os.path.join(model_path, prefix)
params_state_dict = net.state_dict()
- loss_state_dict = loss.state_dict()
- keys_inter = set(params_state_dict.keys()) & set(loss_state_dict.keys())
- assert len(keys_inter) == 0, \
- f"keys in model and loss state_dict must be unique, but got intersection {keys_inter}"
- params_state_dict.update(loss_state_dict)
+ if loss is not None:
+ loss_state_dict = loss.state_dict()
+ keys_inter = set(params_state_dict.keys()) & set(loss_state_dict.keys(
+ ))
+ assert len(keys_inter) == 0, \
+ f"keys in model and loss state_dict must be unique, but got intersection {keys_inter}"
+ params_state_dict.update(loss_state_dict)
+
+ if save_student_model:
+ s_params = _extract_student_weights(params_state_dict)
+ if len(s_params) > 0:
+ paddle.save(s_params, model_path + "_student.pdparams")
paddle.save(params_state_dict, model_path + ".pdparams")
paddle.save([opt.state_dict() for opt in optimizer], model_path + ".pdopt")