Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleClas
提交
70c45dcd
P
PaddleClas
项目概览
PaddlePaddle
/
PaddleClas
大约 1 年 前同步成功
通知
115
Star
4999
Fork
1114
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
19
列表
看板
标记
里程碑
合并请求
6
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleClas
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
19
Issue
19
列表
看板
标记
里程碑
合并请求
6
合并请求
6
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
70c45dcd
编写于
5月 19, 2022
作者:
W
weishengyu
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'add_person_demo' of github.com:cuicheng01/PaddleClas into add_person_demo
上级
94ef3405
1989b660
变更
11
隐藏空白更改
内联
并排
Showing
11 changed file
with
694 addition
and
18 deletion
+694
-18
deploy/configs/cls_demo/person/inference_person_cls.yaml
deploy/configs/cls_demo/person/inference_person_cls.yaml
+35
-0
deploy/images/cls_demo/person/objects365_01780782.jpg
deploy/images/cls_demo/person/objects365_01780782.jpg
+0
-0
deploy/images/cls_demo/person/objects365_02035329.jpg
deploy/images/cls_demo/person/objects365_02035329.jpg
+0
-0
docs/zh_CN/cls_demo/quick_start_cls_demo.md
docs/zh_CN/cls_demo/quick_start_cls_demo.md
+299
-0
ppcls/configs/cls_demo/person/Distillation/PPLCNet_x1_0_distillation.yaml
...s_demo/person/Distillation/PPLCNet_x1_0_distillation.yaml
+6
-3
ppcls/configs/cls_demo/person/Distillation/PPLCNet_x1_0_distillation_search.yaml
...person/Distillation/PPLCNet_x1_0_distillation_search.yaml
+169
-0
ppcls/configs/cls_demo/person/PPLCNet/PPLCNet_x1_0.yaml
ppcls/configs/cls_demo/person/PPLCNet/PPLCNet_x1_0.yaml
+4
-3
ppcls/configs/cls_demo/person/PPLCNet/PPLCNet_x1_0_search.yaml
.../configs/cls_demo/person/PPLCNet/PPLCNet_x1_0_search.yaml
+150
-0
ppcls/engine/engine.py
ppcls/engine/engine.py
+6
-5
ppcls/utils/cls_demo/person_label_list.txt
ppcls/utils/cls_demo/person_label_list.txt
+2
-0
ppcls/utils/save_load.py
ppcls/utils/save_load.py
+23
-7
未找到文件。
deploy/configs/cls_demo/person/inference_person_cls.yaml
0 → 100644
浏览文件 @
70c45dcd
Global
:
infer_imgs
:
"
./images/cls_demo/person/objects365_02035329.jpg"
inference_model_dir
:
"
./models/person_cls_infer"
batch_size
:
1
use_gpu
:
True
enable_mkldnn
:
True
cpu_num_threads
:
10
enable_benchmark
:
True
use_fp16
:
False
ir_optim
:
True
use_tensorrt
:
False
gpu_mem
:
8000
enable_profile
:
False
PreProcess
:
transform_ops
:
-
ResizeImage
:
resize_short
:
256
-
CropImage
:
size
:
224
-
NormalizeImage
:
scale
:
0.00392157
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
channel_num
:
3
-
ToCHWImage
:
PostProcess
:
main_indicator
:
Topk
Topk
:
topk
:
5
class_id_map_file
:
"
../ppcls/utils/cls_demo/person_label_list.txt"
SavePreLabel
:
save_dir
:
./pre_label/
deploy/images/cls_demo/person/objects365_01780782.jpg
0 → 100755
浏览文件 @
70c45dcd
275.2 KB
deploy/images/cls_demo/person/objects365_02035329.jpg
0 → 100755
浏览文件 @
70c45dcd
230.1 KB
docs/zh_CN/cls_demo/quick_start_cls_demo.md
0 → 100644
浏览文件 @
70c45dcd
# PaddleClas构建有人/无人分类案例
此处提供了用户使用 PaddleClas 快速构建轻量级、高精度、可落地的有人/无人的分类模型教程,主要基于有人/无人场景的数据,融合了轻量级骨干网络PPLCNet、SSLD预训练权重、EDA数据增强策略、KL-JS-UGI知识蒸馏策略、SHAS超参数搜索策略,得到精度高、速度快、易于部署的二分类模型。
------
## 目录
-
[
1. 环境配置
](
#1
)
-
[
2. 有人/无人场景推理预测
](
#2
)
-
[
2.1 下载模型
](
#2.1
)
-
[
2.2 模型推理预测
](
#2.2
)
-
[
2.2.1 预测单张图像
](
#2.2.1
)
-
[
2.2.2 基于文件夹的批量预测
](
#2.2.2
)
-
[
3.有人/无人场景训练
](
#3
)
-
[
3.1 数据准备
](
#3.1
)
-
[
3.2 模型训练
](
#3.2
)
-
[
3.2.1 基于默认超参数训练
](
#3.2.1
)
-
[
3.2.1.1 基于默认超参数训练轻量级模型
](
#3.2.1.1
)
-
[
3.2.1.2 基于默认超参数训练教师模型
](
#3.2.1.2
)
-
[
3.2.1.3 基于默认超参数进行蒸馏训练
](
#3.2.1.3
)
-
[
3.2.2 超参数搜索训练
](
#3.2
)
-
[
4. 模型评估与推理
](
#4
)
-
[
4.1 模型评估
](
#3.1
)
-
[
4.2 模型预测
](
#3.2
)
-
[
4.3 使用 inference 模型进行推理
](
#4.3
)
-
[
4.3.1 导出 inference 模型
](
#4.3.1
)
-
[
4.3.2 模型推理预测
](
#4.3.2
)
<a
name=
"1"
></a>
## 1. 环境配置
*
安装:请先参考
[
Paddle 安装教程
](
../installation/install_paddle.md
)
以及
[
PaddleClas 安装教程
](
../installation/install_paddleclas.md
)
配置 PaddleClas 运行环境。
<a
name=
"2"
></a>
## 2. 有人/无人场景推理预测
<a
name=
"2.1"
></a>
### 2.1 下载模型
*
进入
`deploy`
运行目录。
```
cd deploy
```
下载有人/无人分类的模型。
```
mkdir models
cd models
# 下载inference 模型并解压
wget https://paddleclas.bj.bcebos.com/models/cls_demo/person_cls_infer.tar && tar -xf person_cls_infer.tar
```
解压完毕后,
`models`
文件夹下应有如下文件结构:
```
├── person_cls_infer
│ ├── inference.pdiparams
│ ├── inference.pdiparams.info
│ └── inference.pdmodel
```
<a
name=
"2.2"
></a>
### 2.2 模型推理预测
<a
name=
"2.2.1"
></a>
#### 2.2.1 预测单张图像
运行下面的命令,对图像
`./images/cls_demo/person/objects365_02035329.jpg`
进行有人/无人分类。
```
shell
# 使用下面的命令使用 GPU 进行预测
python3.7 python/predict_cls.py
-c
configs/cls_demo/person/inference_person_cls.yaml
# 使用下面的命令使用 CPU 进行预测
python3.7 python/predict_system.py
-c
configs/inference_general.yaml
-o
Global.use_gpu
=
False
```
输出结果如下。
```
objects365_02035329.jpg: class id(s): [1, 0], score(s): [1.00, 0.00], label_name(s): ['someone', 'nobody']
```
其中,
`someone`
表示该图里存在人,
`nobody`
表示该图里不存在人。
<a
name=
"2.2.2"
></a>
#### 2.2.2 基于文件夹的批量预测
如果希望预测文件夹内的图像,可以直接修改配置文件中的
`Global.infer_imgs`
字段,也可以通过下面的
`-o`
参数修改对应的配置。
```
shell
# 使用下面的命令使用 GPU 进行预测,如果希望使用 CPU 预测,可以在命令后面添加 -o Global.use_gpu=False
python3.7 python/predict_system.py
-c
configs/inference_general.yaml
-o
Global.infer_imgs
=
"./images/cls_demo/person/"
```
终端中会输出该文件夹内所有图像的分类结果,如下所示。
```
objects365_01780782.jpg: class id(s): [0, 1], score(s): [1.00, 0.00], label_name(s): ['nobody', 'someone']
objects365_02035329.jpg: class id(s): [1, 0], score(s): [1.00, 0.00], label_name(s): ['someone', 'nobody']
```
<a
name=
"3"
></a>
## 3.有人/无人场景训练
<a
name=
"3.1"
></a>
### 3.1 数据准备
进入 PaddleClas 目录。
```
cd path_to_PaddleClas
```
进入
`dataset/`
目录,下载并解压有人/无人场景的数据。
```
shell
cd
dataset
wget https://paddleclas.bj.bcebos.com/data/cls_demo/person.tar
tar
-xf
person.tar
cd
../
```
执行上述命令后,
`dataset/`
下存在
`person`
目录,该目录中具有以下数据:
```
├── train
│ ├── 000000000009.jpg
│ ├── 000000000025.jpg
...
├── val
│ ├── objects365_01780637.jpg
│ ├── objects365_01780640.jpg
...
├── ImageNet_val
│ ├── ILSVRC2012_val_00000001.JPEG
│ ├── ILSVRC2012_val_00000002.JPEG
...
├── train_list.txt
├── train_list.txt.debug
├── train_list_for_distill.txt
├── val_list.txt
└── val_list.txt.debug
```
其中
`train/`
和
`val/`
分别为训练集和验证集。
`train_list.txt`
和
`val_list.txt`
分别为训练集和验证集的标签文件,
`train_list.txt.debug`
和
`val_list.txt.debug`
分别为训练集和验证集的
`debug`
标签文件,其分别是
`train_list.txt`
和
`val_list.txt`
的子集,用该文件可以快速体验本案例的流程。
`ImageNet_val/`
是ImageNet的验证集,该集合和
`train`
集合的混合数据用于本案例的
`KL-JS-UGI知识蒸馏策略`
,对应的训练标签文件为
`train_list_for_distill.txt`
。
*
**注意**
:
*
本案例中所使用的所有数据集均为开源数据,
`train`
集合为
[
MS-COCO数据
](
https://cocodataset.org/#overview
)
的训练集的子集,
`val`
集合为
[
Object365数据
](
https://www.objects365.org/overview.html
)
的训练集的子集,
`ImageNet_val`
为
[
ImageNet数据
](
https://www.image-net.org/
)
的验证集。
<a
name=
"3.2"
></a>
### 3.2 模型训练
<a
name=
"3.2.1"
></a>
#### 3.2.1 基于默认超参数训练
<a
name=
"3.2.1.1"
></a>
##### 3.2.1.1 基于默认超参数训练轻量级模型
在
`ppcls/configs/cls_demo/person/PPLCNet/PPLCNet_x1_0.yaml`
中提供了基于该场景中已经搜索好的超参数,可以通过如下脚本启动训练:
```
shell
export
CUDA_VISIBLE_DEVICES
=
0,1,2,3
python3
-m
paddle.distributed.launch
\
--gpus
=
"0,1,2,3"
\
tools/train.py
\
-c
./ppcls/configs/cls_demo/person/PPLCNet/PPLCNet_x1_0.yaml
```
验证集的最佳 metric 在0.94-0.95之间(数据集较小,容易造成波动)。
<a
name=
"3.2.1.2"
></a>
##### 3.2.1.2 基于默认超参数训练教师模型
复用
`ppcls/configs/cls_demo/person/PPLCNet/PPLCNet_x1_0.yaml`
中的超参数,训练教师模型,训练脚本如下:
```
shell
export
CUDA_VISIBLE_DEVICES
=
0,1,2,3
python3
-m
paddle.distributed.launch
\
--gpus
=
"0,1,2,3"
\
tools/train.py
\
-c
./ppcls/configs/cls_demo/person/PPLCNet/PPLCNet_x1_0.yaml
\
-o
Arch.name
=
ResNet101_vd
```
验证集的最佳 metric 为0.97-0.98之间,当前教师模型最好的权重保存在
`output/ResNet101_vd/best_model.pdparams`
。
<a
name=
"3.2.1.3"
></a>
##### 3.2.1.3 基于默认超参数进行蒸馏训练
配置文件
`ppcls/configs/cls_demo/person/Distillation/PPLCNet_x1_0_distillation.yaml`
提供了
`KL-JS-UGI知识蒸馏策略`
的配置。该配置将
`ResNet101_vd`
当作教师模型,
`PPLCNet_x1_0`
当作学生模型,使用ImageNet数据集的验证集作为新增的无标签数据。训练脚本如下:
```
shell
export
CUDA_VISIBLE_DEVICES
=
0,1,2,3
python3
-m
paddle.distributed.launch
\
--gpus
=
"0,1,2,3"
\
tools/train.py
\
-c
.ppcls/configs/cls_demo/person/Distillation/PPLCNet_x1_0_distillation.yaml
\
-o
Arch.models.0.Teacher.pretrained
=
output/ResNet101_vd/best_model
```
<a
name=
"3.2.2"
></a>
#### 3.2.2 超参数搜索训练
[
3.2 小节
](
#3.2
)
提供了在已经搜索并得到的超参数上进行了训练,此部分内容提供了搜索的过程,此过程是为了得到更好的训练超参数。
*
搜索运行脚本如下:
```
shell
python tools/search_strategy.py
-c
ppcls/configs/cls_demo/person/PPLCNet/PPLCNet_x1_0_search.yaml
```
*
**注意**
:
*
此过程基于当前数据集在 V100 4 卡上大概需要耗时 6 小时,如果缺少机器资源,希望体验搜索过程,可以将
`ppcls/configs/cls_demo/person/PPLCNet/PPLCNet_x1_0_search.yaml`
中的
`train_list.txt`
和
`val_list.txt`
分别替换为
`train_list.txt.debug`
和
`val_list.txt.debug`
。替换list只是为了加速跑通整个搜索过程,由于数据量较小,其搜素的结果没有参考性。
*
如果此过程搜索的得到的超参数与3.2.1小节提供的超参数不一致,主要是由于训练数据较小造成的波动导致,可以忽略。
<a
name=
"4"
></a>
## 4. 模型评估与推理
<a
name=
"4.1"
></a>
### 4.1 模型评估
训练好模型之后,可以通过以下命令实现对模型精度的评估。
```
bash
python3 tools/eval.py
\
-c
./ppcls/configs/cls_demo/person/PPLCNet/PPLCNet_x1_0.yaml
\
-o
Global.pretrained_model
=
"output/PPLCNet_x1_0/best_model"
```
<a
name=
"4.2"
></a>
### 4.2 模型预测
模型训练完成之后,可以加载训练得到的预训练模型,进行模型预测。在模型库的
`tools/infer.py`
中提供了完整的示例,只需执行下述命令即可完成模型预测:
```
python
python3
tools
/
infer
.
py
\
-
c
.
/
ppcls
/
configs
/
cls_demo
/
person
/
PPLCNet
/
PPLCNet_x1_0
.
yaml
\
-
o
Infer
.
infer_imgs
=
.
/
dataset
/
person
/
val
/
objects365_01780637
.
jpg
\
-
o
Global
.
pretrained_model
=
output
/
PPLCNet_x1_0
/
best_model
```
<a
name=
"4.3"
></a>
### 4.3 使用 inference 模型进行推理
<a
name=
"4.3.1"
></a>
### 4.3.1 导出 inference 模型
通过导出 inference 模型,PaddlePaddle 支持使用预测引擎进行预测推理。接下来介绍如何用预测引擎进行推理:
首先,对训练好的模型进行转换:
```
bash
python3 tools/export_model.py
\
-c
./ppcls/configs/cls_demo/person/PPLCNet/PPLCNet_x1_0.yaml
\
-o
Global.pretrained_model
=
output/PPLCNet_x1_0/best_model
\
-o
Global.save_inference_dir
=
deploy/models/PPLCNet_x1_0_person
```
执行完该脚本后会在
`deploy/models/`
下生成
`PPLCNet_x1_0_person`
文件夹,该文件夹中的模型与 2.2 节下载的推理预测模型格式一致。
<a
name=
"4.3.2"
></a>
### 4.3.2 基于 inference 模型推理预测
推理预测的脚本为:
```
python3.7 python/predict_cls.py -c configs/cls_demo/person/inference_person_cls.yaml -o Global.inference_model_dir="models/PPLCNet_x1_0_person"
```
更多关于推理的细节,可以参考
[
2.2节
](
#2.2
)
。
ppcls/configs/cls_demo/person/Distillation/PPLCNet_x1_0_distillation.yaml
浏览文件 @
70c45dcd
...
...
@@ -33,11 +33,14 @@ Arch:
-
Teacher
:
name
:
ResNet101_vd
class_num
:
*class_num
use_sync_bn
:
True
-
Student
:
name
:
PPLCNet_x1_0
class_num
:
*class_num
pretrained
:
True
use_ssld
:
True
use_sync_bn
:
True
lr_mult_list
:
[
0.0
,
0.2
,
0.4
,
0.6
,
0.8
,
1.0
]
infer_model_name
:
"
Student"
...
...
@@ -77,21 +80,21 @@ DataLoader:
to_rgb
:
True
channel_first
:
False
-
RandCropImage
:
size
:
224
size
:
192
-
RandFlipImage
:
flip_code
:
1
-
TimmAutoAugment
:
prob
:
0.0
config_str
:
rand-m9-mstd0.5-inc1
interpolation
:
bicubic
img_size
:
224
img_size
:
192
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
-
RandomErasing
:
EPSILON
:
0.
0
EPSILON
:
0.
1
sl
:
0.02
sh
:
1.0/3.0
r1
:
0.3
...
...
ppcls/configs/cls_demo/person/Distillation/PPLCNet_x1_0_distillation_search.yaml
0 → 100644
浏览文件 @
70c45dcd
# global configs
Global
:
checkpoints
:
null
pretrained_model
:
null
output_dir
:
./output
device
:
gpu
save_interval
:
1
eval_during_train
:
True
start_eval_epoch
:
10
eval_interval
:
1
epochs
:
20
print_batch_step
:
10
use_visualdl
:
False
# used for static mode and model export
image_shape
:
[
3
,
224
,
224
]
save_inference_dir
:
./inference
# training model under @to_static
to_static
:
False
use_dali
:
False
# model architecture
Arch
:
name
:
"
DistillationModel"
class_num
:
&class_num
2
# if not null, its lengths should be same as models
pretrained_list
:
# if not null, its lengths should be same as models
freeze_params_list
:
-
True
-
False
use_sync_bn
:
True
models
:
-
Teacher
:
name
:
ResNet101_vd
class_num
:
*class_num
use_sync_bn
:
True
-
Student
:
name
:
PPLCNet_x1_0
class_num
:
*class_num
pretrained
:
True
use_ssld
:
True
use_sync_bn
:
True
infer_model_name
:
"
Student"
# loss function config for traing/eval process
Loss
:
Train
:
-
DistillationDMLLoss
:
weight
:
1.0
model_name_pairs
:
-
[
"
Student"
,
"
Teacher"
]
Eval
:
-
CELoss
:
weight
:
1.0
Optimizer
:
name
:
Momentum
momentum
:
0.9
lr
:
name
:
Cosine
learning_rate
:
0.01
warmup_epoch
:
5
regularizer
:
name
:
'
L2'
coeff
:
0.00004
# data loader for train and eval
DataLoader
:
Train
:
dataset
:
name
:
ImageNetDataset
image_root
:
./dataset/person/
cls_label_path
:
./dataset/person/train_list_for_distill.txt
transform_ops
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
RandCropImage
:
size
:
224
-
RandFlipImage
:
flip_code
:
1
-
TimmAutoAugment
:
prob
:
0.0
config_str
:
rand-m9-mstd0.5-inc1
interpolation
:
bicubic
img_size
:
224
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
-
RandomErasing
:
EPSILON
:
0.0
sl
:
0.02
sh
:
1.0/3.0
r1
:
0.3
attempt
:
10
use_log_aspect
:
True
mode
:
pixel
sampler
:
name
:
DistributedBatchSampler
batch_size
:
64
drop_last
:
False
shuffle
:
True
loader
:
num_workers
:
16
use_shared_memory
:
True
Eval
:
dataset
:
name
:
ImageNetDataset
image_root
:
./dataset/person/
cls_label_path
:
./dataset/person/val_list.txt
transform_ops
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
ResizeImage
:
resize_short
:
256
-
CropImage
:
size
:
224
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
sampler
:
name
:
DistributedBatchSampler
batch_size
:
64
drop_last
:
False
shuffle
:
False
loader
:
num_workers
:
4
use_shared_memory
:
True
Infer
:
infer_imgs
:
docs/images/inference_deployment/whl_demo.jpg
batch_size
:
10
transforms
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
ResizeImage
:
resize_short
:
256
-
CropImage
:
size
:
224
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
-
ToCHWImage
:
PostProcess
:
name
:
Topk
topk
:
5
class_id_map_file
:
ppcls/utils/imagenet1k_label_list.txt
Metric
:
Train
:
-
DistillationTopkAcc
:
model_key
:
"
Student"
topk
:
[
1
,
2
]
Eval
:
-
TprAtFpr
:
-
TopkAcc
:
topk
:
[
1
,
2
]
ppcls/configs/cls_demo/person/PPLCNet/PPLCNet_x1_0.yaml
浏览文件 @
70c45dcd
...
...
@@ -26,6 +26,7 @@ Arch:
pretrained
:
True
use_ssld
:
True
use_sync_bn
:
True
lr_mult_list
:
[
0.0
,
0.2
,
0.4
,
0.6
,
0.8
,
1.0
]
# loss function config for traing/eval process
Loss
:
...
...
@@ -61,21 +62,21 @@ DataLoader:
to_rgb
:
True
channel_first
:
False
-
RandCropImage
:
size
:
224
size
:
192
-
RandFlipImage
:
flip_code
:
1
-
TimmAutoAugment
:
prob
:
0.0
config_str
:
rand-m9-mstd0.5-inc1
interpolation
:
bicubic
img_size
:
224
img_size
:
192
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
-
RandomErasing
:
EPSILON
:
0.
0
EPSILON
:
0.
1
sl
:
0.02
sh
:
1.0/3.0
r1
:
0.3
...
...
ppcls/configs/cls_demo/person/PPLCNet/PPLCNet_x1_0_search.yaml
0 → 100644
浏览文件 @
70c45dcd
# global configs
Global
:
checkpoints
:
null
pretrained_model
:
null
output_dir
:
./output/
device
:
gpu
save_interval
:
1
eval_during_train
:
True
eval_interval
:
1
start_eval_epoch
:
10
epochs
:
20
print_batch_step
:
10
use_visualdl
:
False
# used for static mode and model export
image_shape
:
[
3
,
224
,
224
]
save_inference_dir
:
./inference
# training model under @to_static
to_static
:
False
use_dali
:
False
# model architecture
Arch
:
name
:
PPLCNet_x1_0
class_num
:
2
pretrained
:
True
use_ssld
:
True
use_sync_bn
:
True
# loss function config for traing/eval process
Loss
:
Train
:
-
CELoss
:
weight
:
1.0
Eval
:
-
CELoss
:
weight
:
1.0
Optimizer
:
name
:
Momentum
momentum
:
0.9
lr
:
name
:
Cosine
learning_rate
:
0.01
warmup_epoch
:
5
regularizer
:
name
:
'
L2'
coeff
:
0.00004
# data loader for train and eval
DataLoader
:
Train
:
dataset
:
name
:
ImageNetDataset
image_root
:
./dataset/person/
cls_label_path
:
./dataset/person/train_list.txt
transform_ops
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
RandCropImage
:
size
:
224
-
RandFlipImage
:
flip_code
:
1
-
TimmAutoAugment
:
prob
:
0.0
config_str
:
rand-m9-mstd0.5-inc1
interpolation
:
bicubic
img_size
:
224
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
-
RandomErasing
:
EPSILON
:
0.0
sl
:
0.02
sh
:
1.0/3.0
r1
:
0.3
attempt
:
10
use_log_aspect
:
True
mode
:
pixel
sampler
:
name
:
DistributedBatchSampler
batch_size
:
64
drop_last
:
False
shuffle
:
True
loader
:
num_workers
:
8
use_shared_memory
:
True
Eval
:
dataset
:
name
:
ImageNetDataset
image_root
:
./dataset/person/
cls_label_path
:
./dataset/person/val_list.txt
transform_ops
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
ResizeImage
:
resize_short
:
256
-
CropImage
:
size
:
224
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
sampler
:
name
:
DistributedBatchSampler
batch_size
:
64
drop_last
:
False
shuffle
:
False
loader
:
num_workers
:
4
use_shared_memory
:
True
Infer
:
infer_imgs
:
docs/images/inference_deployment/whl_demo.jpg
batch_size
:
10
transforms
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
ResizeImage
:
resize_short
:
256
-
CropImage
:
size
:
224
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
-
ToCHWImage
:
PostProcess
:
name
:
Topk
topk
:
5
class_id_map_file
:
ppcls/utils/imagenet1k_label_list.txt
Metric
:
Train
:
-
TopkAcc
:
topk
:
[
1
,
2
]
Eval
:
-
TprAtFpr
:
-
TopkAcc
:
topk
:
[
1
,
2
]
ppcls/engine/engine.py
浏览文件 @
70c45dcd
...
...
@@ -344,15 +344,15 @@ class Engine(object):
if
self
.
use_dali
:
self
.
train_dataloader
.
reset
()
metric_msg
=
", "
.
join
([
self
.
output_info
[
key
].
avg_info
for
key
in
self
.
output_info
])
metric_msg
=
", "
.
join
(
[
self
.
output_info
[
key
].
avg_info
for
key
in
self
.
output_info
])
logger
.
info
(
"[Train][Epoch {}/{}][Avg]{}"
.
format
(
epoch_id
,
self
.
config
[
"Global"
][
"epochs"
],
metric_msg
))
self
.
output_info
.
clear
()
# eval model and save model if possible
start_eval_epoch
=
self
.
config
[
"Global"
].
get
(
"start_eval_epoch"
,
0
)
-
1
start_eval_epoch
=
self
.
config
[
"Global"
].
get
(
"start_eval_epoch"
,
0
)
-
1
if
self
.
config
[
"Global"
][
"eval_during_train"
]
and
epoch_id
%
self
.
config
[
"Global"
][
"eval_interval"
]
==
0
and
epoch_id
>
start_eval_epoch
:
...
...
@@ -367,7 +367,8 @@ class Engine(object):
self
.
output_dir
,
model_name
=
self
.
config
[
"Arch"
][
"name"
],
prefix
=
"best_model"
,
loss
=
self
.
train_loss_func
)
loss
=
self
.
train_loss_func
,
save_student_model
=
True
)
logger
.
info
(
"[Eval][Epoch {}][best metric: {}]"
.
format
(
epoch_id
,
best_metric
[
"metric"
]))
logger
.
scaler
(
...
...
ppcls/utils/cls_demo/person_label_list.txt
0 → 100644
浏览文件 @
70c45dcd
0 nobody
1 someone
ppcls/utils/save_load.py
浏览文件 @
70c45dcd
...
...
@@ -42,6 +42,14 @@ def _mkdir_if_not_exist(path):
raise
OSError
(
'Failed to mkdir {}'
.
format
(
path
))
def
_extract_student_weights
(
all_params
,
student_prefix
=
"Student."
):
s_params
=
{
key
[
len
(
student_prefix
):]:
all_params
[
key
]
for
key
in
all_params
if
student_prefix
in
key
}
return
s_params
def
load_dygraph_pretrain
(
model
,
path
=
None
):
if
not
(
os
.
path
.
isdir
(
path
)
or
os
.
path
.
exists
(
path
+
'.pdparams'
)):
raise
ValueError
(
"Model pretrain path {}.pdparams does not "
...
...
@@ -117,7 +125,7 @@ def init_model(config, net, optimizer=None, loss: paddle.nn.Layer=None):
else
:
# common load
load_dygraph_pretrain
(
net
,
path
=
pretrained_model
)
logger
.
info
(
"Finish load pretrained model from {}"
.
format
(
pretrained_model
))
pretrained_model
))
def
save_model
(
net
,
...
...
@@ -126,7 +134,8 @@ def save_model(net,
model_path
,
model_name
=
""
,
prefix
=
'ppcls'
,
loss
:
paddle
.
nn
.
Layer
=
None
):
loss
:
paddle
.
nn
.
Layer
=
None
,
save_student_model
=
False
):
"""
save model to the target path
"""
...
...
@@ -137,11 +146,18 @@ def save_model(net,
model_path
=
os
.
path
.
join
(
model_path
,
prefix
)
params_state_dict
=
net
.
state_dict
()
loss_state_dict
=
loss
.
state_dict
()
keys_inter
=
set
(
params_state_dict
.
keys
())
&
set
(
loss_state_dict
.
keys
())
assert
len
(
keys_inter
)
==
0
,
\
f
"keys in model and loss state_dict must be unique, but got intersection
{
keys_inter
}
"
params_state_dict
.
update
(
loss_state_dict
)
if
loss
is
not
None
:
loss_state_dict
=
loss
.
state_dict
()
keys_inter
=
set
(
params_state_dict
.
keys
())
&
set
(
loss_state_dict
.
keys
(
))
assert
len
(
keys_inter
)
==
0
,
\
f
"keys in model and loss state_dict must be unique, but got intersection
{
keys_inter
}
"
params_state_dict
.
update
(
loss_state_dict
)
if
save_student_model
:
s_params
=
_extract_student_weights
(
params_state_dict
)
if
len
(
s_params
)
>
0
:
paddle
.
save
(
s_params
,
model_path
+
"_student.pdparams"
)
paddle
.
save
(
params_state_dict
,
model_path
+
".pdparams"
)
paddle
.
save
([
opt
.
state_dict
()
for
opt
in
optimizer
],
model_path
+
".pdopt"
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录