Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleClas
提交
8b639e93
P
PaddleClas
项目概览
PaddlePaddle
/
PaddleClas
大约 1 年 前同步成功
通知
115
Star
4999
Fork
1114
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
19
列表
看板
标记
里程碑
合并请求
6
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleClas
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
19
Issue
19
列表
看板
标记
里程碑
合并请求
6
合并请求
6
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
8b639e93
编写于
5月 23, 2022
作者:
C
cuicheng01
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
update code and docs
上级
2abbb704
变更
15
隐藏空白更改
内联
并排
Showing
15 changed file
with
128 addition
and
227 deletion
+128
-227
deploy/configs/PULC/person/inference_person_cls.yaml
deploy/configs/PULC/person/inference_person_cls.yaml
+4
-4
deploy/images/PULC/person/objects365_01780782.jpg
deploy/images/PULC/person/objects365_01780782.jpg
+0
-0
deploy/images/PULC/person/objects365_02035329.jpg
deploy/images/PULC/person/objects365_02035329.jpg
+0
-0
docs/zh_CN/PULC/PULC_person_cls.md
docs/zh_CN/PULC/PULC_person_cls.md
+60
-29
ppcls/arch/__init__.py
ppcls/arch/__init__.py
+1
-0
ppcls/configs/PULC/person/Distillation/PPLCNet_x1_0_distillation.yaml
...s/PULC/person/Distillation/PPLCNet_x1_0_distillation.yaml
+5
-7
ppcls/configs/PULC/person/OtherModels/MobileNetV3_large_x1_0.yaml
...nfigs/PULC/person/OtherModels/MobileNetV3_large_x1_0.yaml
+4
-3
ppcls/configs/PULC/person/OtherModels/SwinTransformer_tiny_patch4_window7_224.yaml
.../OtherModels/SwinTransformer_tiny_patch4_window7_224.yaml
+4
-3
ppcls/configs/PULC/person/PPLCNet/PPLCNet_x1_0.yaml
ppcls/configs/PULC/person/PPLCNet/PPLCNet_x1_0.yaml
+4
-4
ppcls/configs/PULC/person/PPLCNet/PPLCNet_x1_0_search.yaml
ppcls/configs/PULC/person/PPLCNet/PPLCNet_x1_0_search.yaml
+4
-3
ppcls/configs/StrategySearch/person.yaml
ppcls/configs/StrategySearch/person.yaml
+3
-3
ppcls/configs/cls_demo/person/Distillation/PPLCNet_x1_0_distillation_search.yaml
...person/Distillation/PPLCNet_x1_0_distillation_search.yaml
+0
-169
ppcls/data/postprocess/__init__.py
ppcls/data/postprocess/__init__.py
+2
-1
ppcls/data/postprocess/threshoutput.py
ppcls/data/postprocess/threshoutput.py
+36
-0
tools/search_strategy.py
tools/search_strategy.py
+1
-1
未找到文件。
deploy/configs/
cls_demo
/person/inference_person_cls.yaml
→
deploy/configs/
PULC
/person/inference_person_cls.yaml
浏览文件 @
8b639e93
Global
:
infer_imgs
:
"
./images/
cls_demo
/person/objects365_02035329.jpg"
infer_imgs
:
"
./images/
PULC
/person/objects365_02035329.jpg"
inference_model_dir
:
"
./models/person_cls_infer"
batch_size
:
1
use_gpu
:
True
enable_mkldnn
:
Tru
e
enable_mkldnn
:
Fals
e
cpu_num_threads
:
10
enable_benchmark
:
True
use_fp16
:
False
...
...
@@ -30,7 +30,7 @@ PostProcess:
main_indicator
:
ThreshOutput
ThreshOutput
:
threshold
:
0.9
label_0
:
invalid
label_1
:
valid
label_0
:
nobody
label_1
:
someone
SavePreLabel
:
save_dir
:
./pre_label/
deploy/images/
cls_demo
/person/objects365_01780782.jpg
→
deploy/images/
PULC
/person/objects365_01780782.jpg
浏览文件 @
8b639e93
文件已移动
deploy/images/
cls_demo
/person/objects365_02035329.jpg
→
deploy/images/
PULC
/person/objects365_02035329.jpg
浏览文件 @
8b639e93
文件已移动
docs/zh_CN/
cls_demo/quick_start_cls_demo
.md
→
docs/zh_CN/
PULC/PULC_person_cls
.md
浏览文件 @
8b639e93
# PaddleClas构建有人/无人分类案例
此处提供了用户使用 PaddleClas 快速构建轻量级、高精度、可落地的有人/无人的分类模型教程,主要基于有人/无人场景的数据,融合了轻量级骨干网络PPLCNet、SSLD预训练权重、EDA数据增强策略、
KL-JS
-UGI知识蒸馏策略、SHAS超参数搜索策略,得到精度高、速度快、易于部署的二分类模型。
此处提供了用户使用 PaddleClas 快速构建轻量级、高精度、可落地的有人/无人的分类模型教程,主要基于有人/无人场景的数据,融合了轻量级骨干网络PPLCNet、SSLD预训练权重、EDA数据增强策略、
SKL
-UGI知识蒸馏策略、SHAS超参数搜索策略,得到精度高、速度快、易于部署的二分类模型。
------
...
...
@@ -55,7 +55,7 @@ cd deploy
mkdir models
cd models
# 下载inference 模型并解压
wget https://paddleclas.bj.bcebos.com/models/
cls_demo
/person_cls_infer.tar && tar -xf person_cls_infer.tar
wget https://paddleclas.bj.bcebos.com/models/
PULC
/person_cls_infer.tar && tar -xf person_cls_infer.tar
```
解压完毕后,
`models`
文件夹下应有如下文件结构:
...
...
@@ -75,23 +75,29 @@ wget https://paddleclas.bj.bcebos.com/models/cls_demo/person_cls_infer.tar && ta
#### 2.2.1 预测单张图像
运行下面的命令,对图像
`./images/cls_demo/person/objects365_02035329.jpg`
进行有人/无人分类。
返回
`deploy`
目录:
```
cd ../
```
运行下面的命令,对图像
`./images/PULC/person/objects365_02035329.jpg`
进行有人/无人分类。
```
shell
# 使用下面的命令使用 GPU 进行预测
python3.7 python/predict_cls.py
-c
configs/
cls_demo/person/inference_person_cls.yaml
python3.7 python/predict_cls.py
-c
configs/
PULC/person/inference_person_cls.yaml
-o
PostProcess.ThreshOutput.threshold
=
0.9794
# 使用下面的命令使用 CPU 进行预测
python3.7 python/predict_
system.py
-c
configs/inference_general.yaml
-o
Global.use_gpu
=
False
python3.7 python/predict_
cls.py
-c
configs/PULC/person/inference_person_cls.yaml
-o
PostProcess.ThreshOutput.threshold
=
0.9794
-o
Global.use_gpu
=
False
```
输出结果如下。
```
objects365_02035329.jpg: class id(s): [1
, 0], score(s): [1.00, 0.00], label_name(s): ['someone', 'nobody
']
objects365_02035329.jpg: class id(s): [1
], score(s): [1.00], label_name(s): ['someone
']
```
其中,
`someone`
表示该图里存在人,
`nobody`
表示该图里不存在人。
**备注:**
真实场景中往往需要在假正类率(Fpr)小于某一个指标下求真正类率(Tpr),该场景中的
`val`
数据集在千分之一Fpr下得到的最佳Tpr所得到的阈值为
`0.9794`
,故此处的
`threshold`
为
`0.9794`
。该阈值的确定方法可以参考
[
3.2节
](
#3.2
)
<a
name=
"2.2.2"
></a>
...
...
@@ -101,16 +107,18 @@ objects365_02035329.jpg: class id(s): [1, 0], score(s): [1.00, 0.00], label_name
```
shell
# 使用下面的命令使用 GPU 进行预测,如果希望使用 CPU 预测,可以在命令后面添加 -o Global.use_gpu=False
python3.7 python/predict_
system.py
-c
configs/inference_general.yaml
-o
Global.infer_imgs
=
"./images/cls_demo
/person/"
python3.7 python/predict_
cls.py
-c
configs/PULC/person/inference_person_cls.yaml
-o
Global.infer_imgs
=
"./images/PULC
/person/"
```
终端中会输出该文件夹内所有图像的分类结果,如下所示。
```
objects365_01780782.jpg: class id(s): [0
, 1], score(s): [1.00, 0.00], label_name(s): ['nobody', 'someone
']
objects365_02035329.jpg: class id(s): [1
, 0], score(s): [1.00, 0.00], label_name(s): ['someone', 'nobody
']
objects365_01780782.jpg: class id(s): [0
], score(s): [1.00], label_name(s): ['nobody
']
objects365_02035329.jpg: class id(s): [1
], score(s): [1.00], label_name(s): ['someone
']
```
其中,
`someone`
表示该图里存在人,
`nobody`
表示该图里不存在人。
<a
name=
"3"
></a>
## 3.有人/无人场景训练
...
...
@@ -161,7 +169,7 @@ cd ../
*
**注意**
:
*
本案例中所使用的所有数据集均为开源数据,
`train`
集合为
[
MS-COCO数据
](
https://cocodataset.org/#overview
)
的训练集的子集,
`val`
集合为
[
Object365数据
](
https://www.objects365.org/overview.html
)
的训练集的子集,
`ImageNet_val`
为
[
ImageNet数据
](
https://www.image-net.org/
)
的验证集。
*
本案例中所使用的所有数据集均为开源数据,
`train`
集合为
[
MS-COCO数据
](
https://cocodataset.org/#overview
)
的训练集的子集,
`val`
集合为
[
Object365数据
](
https://www.objects365.org/overview.html
)
的训练集的子集,
`ImageNet_val`
为
[
ImageNet数据
](
https://www.image-net.org/
)
的验证集。
数据集的筛选流程可以参考
[
有人/无人场景数据集筛选方法
](
)。
<a
name=
"3.2"
></a>
...
...
@@ -175,47 +183,53 @@ cd ../
##### 3.2.1.1 基于默认超参数训练轻量级模型
在
`ppcls/configs/
cls_demo/person/PPLCNet/PPLCNet_x1_0.yaml`
中提供了基于该场景中已经搜索好的超参数
,可以通过如下脚本启动训练:
在
`ppcls/configs/
PULC/person/PPLCNet/PPLCNet_x1_0.yaml`
中提供了基于该场景的训练配置
,可以通过如下脚本启动训练:
```
shell
export
CUDA_VISIBLE_DEVICES
=
0,1,2,3
python3
-m
paddle.distributed.launch
\
--gpus
=
"0,1,2,3"
\
tools/train.py
\
-c
./ppcls/configs/
cls_demo
/person/PPLCNet/PPLCNet_x1_0.yaml
-c
./ppcls/configs/
PULC
/person/PPLCNet/PPLCNet_x1_0.yaml
```
验证集的最佳 metric 在0.94-0.95之间(数据集较小,容易造成波动)。
验证集的最佳指标在0.94-0.95之间(数据集较小,容易造成波动)。
**备注:**
*
此时使用的指标为Tpr,该指标描述了在假正类率(Fpr)小于某一个指标时的真正类率(Tpr),是产业中二分类问题常用的指标之一。在本案例中,Fpr为千分之一。关于Fpr和Tpr的更多介绍,可以参考
[
这里
](
https://baike.baidu.com/item/AUC/19282953
)
。
*
在eval时,会打印出来当前最佳的TprAtFpr指标,具体地,其会打印当前的
`Fpr`
、
`Tpr`
值,以及当前的
`threshold`
值,
`Tpr`
值反映了在当前
`Fpr`
值下的召回率,该值越高,代表模型越好。
`threshold`
表示当前最佳
`Fpr`
所对应的分类阈值,可用于后续模型部署落地等。
<a
name=
"3.2.1.2"
></a>
##### 3.2.1.2 基于默认超参数训练教师模型
复用
`ppcls/configs/
cls_demo
/person/PPLCNet/PPLCNet_x1_0.yaml`
中的超参数,训练教师模型,训练脚本如下:
复用
`ppcls/configs/
PULC
/person/PPLCNet/PPLCNet_x1_0.yaml`
中的超参数,训练教师模型,训练脚本如下:
```
shell
export
CUDA_VISIBLE_DEVICES
=
0,1,2,3
python3
-m
paddle.distributed.launch
\
--gpus
=
"0,1,2,3"
\
tools/train.py
\
-c
./ppcls/configs/
cls_demo
/person/PPLCNet/PPLCNet_x1_0.yaml
\
-c
./ppcls/configs/
PULC
/person/PPLCNet/PPLCNet_x1_0.yaml
\
-o
Arch.name
=
ResNet101_vd
```
验证集的最佳
metric 为0.97
-0.98之间,当前教师模型最好的权重保存在
`output/ResNet101_vd/best_model.pdparams`
。
验证集的最佳
指标为0.96
-0.98之间,当前教师模型最好的权重保存在
`output/ResNet101_vd/best_model.pdparams`
。
<a
name=
"3.2.1.3"
></a>
##### 3.2.1.3 基于默认超参数进行蒸馏训练
配置文件
`ppcls/configs/
cls_demo/person/Distillation/PPLCNet_x1_0_distillation.yaml`
提供了
`KL-JS
-UGI知识蒸馏策略`
的配置。该配置将
`ResNet101_vd`
当作教师模型,
`PPLCNet_x1_0`
当作学生模型,使用ImageNet数据集的验证集作为新增的无标签数据。训练脚本如下:
配置文件
`ppcls/configs/
PULC/PULC/Distillation/PPLCNet_x1_0_distillation.yaml`
提供了
`SKL
-UGI知识蒸馏策略`
的配置。该配置将
`ResNet101_vd`
当作教师模型,
`PPLCNet_x1_0`
当作学生模型,使用ImageNet数据集的验证集作为新增的无标签数据。训练脚本如下:
```
shell
export
CUDA_VISIBLE_DEVICES
=
0,1,2,3
python3
-m
paddle.distributed.launch
\
--gpus
=
"0,1,2,3"
\
tools/train.py
\
-c
.
ppcls/configs/cls_demo
/person/Distillation/PPLCNet_x1_0_distillation.yaml
\
-c
.
/ppcls/configs/PULC
/person/Distillation/PPLCNet_x1_0_distillation.yaml
\
-o
Arch.models.0.Teacher.pretrained
=
output/ResNet101_vd/best_model
```
...
...
@@ -228,14 +242,19 @@ python3 -m paddle.distributed.launch \
*
搜索运行脚本如下:
```
shell
python tools/search_strategy.py
-c
ppcls/configs/
cls_demo/person/PPLCNet/PPLCNet_x1_0_search
.yaml
python tools/search_strategy.py
-c
ppcls/configs/
StrategySearch/person
.yaml
```
在
`ppcls/configs/StrategySearch/person.yaml`
中指定了具体的 GPU id 号和搜索配置。
*
**注意**
:
*
此过程基于当前数据集在 V100 4 卡上大概需要耗时 6 小时,如果缺少机器资源,希望体验搜索过程,可以将
`ppcls/configs/cls_demo/person/PPLCNet/PPLCNet_x1_0_search.yaml`
中的
`train_list.txt`
和
`val_list.txt`
分别替换为
`train_list.txt.debug`
和
`val_list.txt.debug`
。替换list只是为了加速跑通整个搜索过程,由于数据量较小,其搜素的结果没有参考性。
*
3.1小节提供的默认配置已经经过了搜索,所以此过程不是必要的过程,如果自己的训练数据集有变化,可以尝试此过程。
*
此过程基于当前数据集在 V100 4 卡上大概需要耗时 10 小时,如果缺少机器资源,希望体验搜索过程,可以将
`ppcls/configs/cls_demo/person/PPLCNet/PPLCNet_x1_0_search.yaml`
中的
`train_list.txt`
和
`val_list.txt`
分别替换为
`train_list.txt.debug`
和
`val_list.txt.debug`
。替换list只是为了加速跑通整个搜索过程,由于数据量较小,其搜素的结果没有参考性。另外,搜索空间可以根据当前的机器资源来调整,如果机器资源有限,可以尝试缩小搜索空间,如果机器资源较充足,可以尝试扩大搜索空间。
*
如果此过程搜索的得到的超参数与
[
3.2.1小节
](
#3.2.1
)
提供的超参数不一致,主要是由于训练数据较小造成的波动导致,可以忽略。
*
如果此过程搜索的得到的超参数与3.2.1小节提供的超参数不一致,主要是由于训练数据较小造成的波动导致,可以忽略。
<a
name=
"4"
></a>
...
...
@@ -246,11 +265,11 @@ python tools/search_strategy.py -c ppcls/configs/cls_demo/person/PPLCNet/PPLCNet
### 4.1 模型评估
训练好模型之后,可以通过以下命令实现对模型
精度
的评估。
训练好模型之后,可以通过以下命令实现对模型
指标
的评估。
```
bash
python3 tools/eval.py
\
-c
./ppcls/configs/
cls_demo
/person/PPLCNet/PPLCNet_x1_0.yaml
\
-c
./ppcls/configs/
PULC
/person/PPLCNet/PPLCNet_x1_0.yaml
\
-o
Global.pretrained_model
=
"output/PPLCNet_x1_0/best_model"
```
...
...
@@ -262,11 +281,20 @@ python3 tools/eval.py \
```
python
python3
tools
/
infer
.
py
\
-
c
.
/
ppcls
/
configs
/
cls_demo
/
person
/
PPLCNet
/
PPLCNet_x1_0
.
yaml
\
-
c
.
/
ppcls
/
configs
/
PULC
/
person
/
PPLCNet
/
PPLCNet_x1_0
.
yaml
\
-
o
Infer
.
infer_imgs
=
.
/
dataset
/
person
/
val
/
objects365_01780637
.
jpg
\
-
o
Global
.
pretrained_model
=
output
/
PPLCNet_x1_0
/
best_model
-
o
Global
.
pretrained_model
=
output
/
PPLCNet_x1_0
/
best_model
\
-
o
Global
.
pretrained_model
=
Infer
.
PostProcess
.
threshold
=
0.9794
```
输出结果如下:
```
[{'class_ids': [0], 'scores': [0.9878496769815683], 'label_names': ['nobody'], 'file_name': './dataset/person/val/objects365_01780637.jpg'}]
```
**备注:**
这里的
`Infer.PostProcess.threshold`
的值需要根据实际场景来确定,此处的
`0.9794`
是在该场景中的
`val`
数据集在千分之一Fpr下得到的最佳Tpr所得到的。
<a
name=
"4.3"
></a>
### 4.3 使用 inference 模型进行推理
...
...
@@ -280,7 +308,7 @@ python3 tools/infer.py \
```
bash
python3 tools/export_model.py
\
-c
./ppcls/configs/cls_demo/
person
/PPLCNet/PPLCNet_x1_0.yaml
\
-c
./ppcls/configs/cls_demo/
PULC
/PPLCNet/PPLCNet_x1_0.yaml
\
-o
Global.pretrained_model
=
output/PPLCNet_x1_0/best_model
\
-o
Global.save_inference_dir
=
deploy/models/PPLCNet_x1_0_person
```
...
...
@@ -292,8 +320,11 @@ python3 tools/export_model.py \
推理预测的脚本为:
```
python3.7 python/predict_cls.py -c configs/
cls_demo/person/inference_person_cls.yaml -o Global.inference_model_dir="models/PPLCNet_x1_0_person"
python3.7 python/predict_cls.py -c configs/
PULC/person/inference_person_cls.yaml -o Global.inference_model_dir="models/PPLCNet_x1_0_person" -o PostProcess.ThreshOutput.threshold=0.9794
```
更多关于推理的细节,可以参考
[
2.2节
](
#2.2
)
。
**备注:**
-
此处的
`PostProcess.ThreshOutput.threshold`
由eval时的最佳
`threshold`
来确定。
-
更多关于推理的细节,可以参考
[
2.2节
](
#2.2
)
。
ppcls/arch/__init__.py
浏览文件 @
8b639e93
...
...
@@ -40,6 +40,7 @@ def build_model(config):
arch
=
getattr
(
mod
,
model_type
)(
**
arch_config
)
if
use_sync_bn
:
arch
=
nn
.
SyncBatchNorm
.
convert_sync_batchnorm
(
arch
)
if
isinstance
(
arch
,
TheseusLayer
):
prune_model
(
config
,
arch
)
quantize_model
(
config
,
arch
)
...
...
ppcls/configs/
cls_demo
/person/Distillation/PPLCNet_x1_0_distillation.yaml
→
ppcls/configs/
PULC
/person/Distillation/PPLCNet_x1_0_distillation.yaml
浏览文件 @
8b639e93
...
...
@@ -6,7 +6,7 @@ Global:
device
:
gpu
save_interval
:
1
eval_during_train
:
True
start_eval_epoch
:
1
0
start_eval_epoch
:
1
eval_interval
:
1
epochs
:
20
print_batch_step
:
10
...
...
@@ -33,14 +33,11 @@ Arch:
-
Teacher
:
name
:
ResNet101_vd
class_num
:
*class_num
use_sync_bn
:
True
-
Student
:
name
:
PPLCNet_x1_0
class_num
:
*class_num
pretrained
:
True
use_ssld
:
True
use_sync_bn
:
True
lr_mult_list
:
[
0.0
,
0.2
,
0.4
,
0.6
,
0.8
,
1.0
]
infer_model_name
:
"
Student"
...
...
@@ -155,9 +152,10 @@ Infer:
order
:
'
'
-
ToCHWImage
:
PostProcess
:
name
:
Topk
topk
:
5
class_id_map_file
:
ppcls/utils/imagenet1k_label_list.txt
name
:
ThreshOutput
threshold
:
0.9
label_0
:
nobody
label_1
:
someone
Metric
:
Train
:
...
...
ppcls/configs/
cls_demo
/person/OtherModels/MobileNetV3_large_x1_0.yaml
→
ppcls/configs/
PULC
/person/OtherModels/MobileNetV3_large_x1_0.yaml
浏览文件 @
8b639e93
...
...
@@ -130,9 +130,10 @@ Infer:
order
:
'
'
-
ToCHWImage
:
PostProcess
:
name
:
Topk
topk
:
5
class_id_map_file
:
ppcls/utils/imagenet1k_label_list.txt
name
:
ThreshOutput
threshold
:
0.9
label_0
:
nobody
label_1
:
someone
Metric
:
Train
:
...
...
ppcls/configs/
cls_demo
/person/OtherModels/SwinTransformer_tiny_patch4_window7_224.yaml
→
ppcls/configs/
PULC
/person/OtherModels/SwinTransformer_tiny_patch4_window7_224.yaml
浏览文件 @
8b639e93
...
...
@@ -153,9 +153,10 @@ Infer:
order
:
'
'
-
ToCHWImage
:
PostProcess
:
name
:
Topk
topk
:
5
class_id_map_file
:
ppcls/utils/imagenet1k_label_list.txt
name
:
ThreshOutput
threshold
:
0.9
label_0
:
nobody
label_1
:
someone
Metric
:
Train
:
...
...
ppcls/configs/
cls_demo
/person/PPLCNet/PPLCNet_x1_0.yaml
→
ppcls/configs/
PULC
/person/PPLCNet/PPLCNet_x1_0.yaml
浏览文件 @
8b639e93
...
...
@@ -26,7 +26,6 @@ Arch:
pretrained
:
True
use_ssld
:
True
use_sync_bn
:
True
lr_mult_list
:
[
0.0
,
0.2
,
0.4
,
0.6
,
0.8
,
1.0
]
# loss function config for traing/eval process
Loss
:
...
...
@@ -137,9 +136,10 @@ Infer:
order
:
'
'
-
ToCHWImage
:
PostProcess
:
name
:
Topk
topk
:
5
class_id_map_file
:
ppcls/utils/imagenet1k_label_list.txt
name
:
ThreshOutput
threshold
:
0.9
label_0
:
nobody
label_1
:
someone
Metric
:
Train
:
...
...
ppcls/configs/
cls_demo
/person/PPLCNet/PPLCNet_x1_0_search.yaml
→
ppcls/configs/
PULC
/person/PPLCNet/PPLCNet_x1_0_search.yaml
浏览文件 @
8b639e93
...
...
@@ -136,9 +136,10 @@ Infer:
order
:
'
'
-
ToCHWImage
:
PostProcess
:
name
:
Topk
topk
:
5
class_id_map_file
:
ppcls/utils/imagenet1k_label_list.txt
name
:
ThreshOutput
threshold
:
0.9
label_0
:
nobody
label_1
:
someone
Metric
:
Train
:
...
...
ppcls/configs/StrategySearch/person.yaml
浏览文件 @
8b639e93
base_config_file
:
ppcls/configs/
cls_demo/person/PPLCNet/PPLCNet_x1_0
.yaml
distill_config_file
:
ppcls/configs/
cls_demo
/person/Distillation/PPLCNet_x1_0_distillation.yaml
base_config_file
:
ppcls/configs/
PULC/person/PPLCNet/PPLCNet_x1_0_search
.yaml
distill_config_file
:
ppcls/configs/
PULC
/person/Distillation/PPLCNet_x1_0_distillation.yaml
gpus
:
0,1,2,3
output_dir
:
output/search_person
search_times
:
3
search_times
:
1
search_dict
:
-
search_key
:
lrs
replace_config
:
...
...
ppcls/configs/cls_demo/person/Distillation/PPLCNet_x1_0_distillation_search.yaml
已删除
100644 → 0
浏览文件 @
2abbb704
# global configs
Global
:
checkpoints
:
null
pretrained_model
:
null
output_dir
:
./output
device
:
gpu
save_interval
:
1
eval_during_train
:
True
start_eval_epoch
:
10
eval_interval
:
1
epochs
:
20
print_batch_step
:
10
use_visualdl
:
False
# used for static mode and model export
image_shape
:
[
3
,
224
,
224
]
save_inference_dir
:
./inference
# training model under @to_static
to_static
:
False
use_dali
:
False
# model architecture
Arch
:
name
:
"
DistillationModel"
class_num
:
&class_num
2
# if not null, its lengths should be same as models
pretrained_list
:
# if not null, its lengths should be same as models
freeze_params_list
:
-
True
-
False
use_sync_bn
:
True
models
:
-
Teacher
:
name
:
ResNet101_vd
class_num
:
*class_num
use_sync_bn
:
True
-
Student
:
name
:
PPLCNet_x1_0
class_num
:
*class_num
pretrained
:
True
use_ssld
:
True
use_sync_bn
:
True
infer_model_name
:
"
Student"
# loss function config for traing/eval process
Loss
:
Train
:
-
DistillationDMLLoss
:
weight
:
1.0
model_name_pairs
:
-
[
"
Student"
,
"
Teacher"
]
Eval
:
-
CELoss
:
weight
:
1.0
Optimizer
:
name
:
Momentum
momentum
:
0.9
lr
:
name
:
Cosine
learning_rate
:
0.01
warmup_epoch
:
5
regularizer
:
name
:
'
L2'
coeff
:
0.00004
# data loader for train and eval
DataLoader
:
Train
:
dataset
:
name
:
ImageNetDataset
image_root
:
./dataset/person/
cls_label_path
:
./dataset/person/train_list_for_distill.txt
transform_ops
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
RandCropImage
:
size
:
224
-
RandFlipImage
:
flip_code
:
1
-
TimmAutoAugment
:
prob
:
0.0
config_str
:
rand-m9-mstd0.5-inc1
interpolation
:
bicubic
img_size
:
224
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
-
RandomErasing
:
EPSILON
:
0.0
sl
:
0.02
sh
:
1.0/3.0
r1
:
0.3
attempt
:
10
use_log_aspect
:
True
mode
:
pixel
sampler
:
name
:
DistributedBatchSampler
batch_size
:
64
drop_last
:
False
shuffle
:
True
loader
:
num_workers
:
16
use_shared_memory
:
True
Eval
:
dataset
:
name
:
ImageNetDataset
image_root
:
./dataset/person/
cls_label_path
:
./dataset/person/val_list.txt
transform_ops
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
ResizeImage
:
resize_short
:
256
-
CropImage
:
size
:
224
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
sampler
:
name
:
DistributedBatchSampler
batch_size
:
64
drop_last
:
False
shuffle
:
False
loader
:
num_workers
:
4
use_shared_memory
:
True
Infer
:
infer_imgs
:
docs/images/inference_deployment/whl_demo.jpg
batch_size
:
10
transforms
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
ResizeImage
:
resize_short
:
256
-
CropImage
:
size
:
224
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
-
ToCHWImage
:
PostProcess
:
name
:
Topk
topk
:
5
class_id_map_file
:
ppcls/utils/imagenet1k_label_list.txt
Metric
:
Train
:
-
DistillationTopkAcc
:
model_key
:
"
Student"
topk
:
[
1
,
2
]
Eval
:
-
TprAtFpr
:
-
TopkAcc
:
topk
:
[
1
,
2
]
ppcls/data/postprocess/__init__.py
浏览文件 @
8b639e93
...
...
@@ -14,9 +14,10 @@
import
copy
import
importlib
from
.
import
topk
from
.
import
topk
,
threshoutput
from
.topk
import
Topk
,
MultiLabelTopk
from
.threshoutput
import
ThreshOutput
def
build_postprocess
(
config
):
...
...
ppcls/data/postprocess/threshoutput.py
0 → 100644
浏览文件 @
8b639e93
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
paddle.nn.functional
as
F
class
ThreshOutput
(
object
):
def
__init__
(
self
,
threshold
,
label_0
=
"0"
,
label_1
=
"1"
):
self
.
threshold
=
threshold
self
.
label_0
=
label_0
self
.
label_1
=
label_1
def
__call__
(
self
,
x
,
file_names
=
None
):
y
=
[]
x
=
F
.
softmax
(
x
,
axis
=-
1
).
numpy
()
for
idx
,
probs
in
enumerate
(
x
):
score
=
probs
[
1
]
if
score
<
self
.
threshold
:
result
=
{
"class_ids"
:
[
0
],
"scores"
:
[
1
-
score
],
"label_names"
:
[
self
.
label_0
]}
else
:
result
=
{
"class_ids"
:
[
1
],
"scores"
:
[
score
],
"label_names"
:
[
self
.
label_1
]}
if
file_names
is
not
None
:
result
[
"file_name"
]
=
file_names
[
idx
]
y
.
append
(
result
)
return
y
tools/search_strategy.py
浏览文件 @
8b639e93
...
...
@@ -91,7 +91,7 @@ def search_strategy():
res
=
search_train
(
teacher_list
,
teacher_program
,
base_output_dir
,
"teacher"
,
replace_config
,
model_name
)
all_results
[
"teacher"
]
=
res
best
=
res
.
get
(
"best"
)
t_pretrained
=
"{}/{}_{}/{}/best_model"
.
format
(
base_output_dir
,
"teacher"
,
best
,
best
)
t_pretrained
=
"{}/{}_{}
_0
/{}/best_model"
.
format
(
base_output_dir
,
"teacher"
,
best
,
best
)
base_program
+=
[
"-o"
,
"Arch.models.0.Teacher.name={}"
.
format
(
best
),
"-o"
,
"Arch.models.0.Teacher.pretrained={}"
.
format
(
t_pretrained
)]
output_dir
=
"{}/search_res"
.
format
(
base_output_dir
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录