Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
1a843c55
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
695
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
1a843c55
编写于
1月 03, 2020
作者:
B
Bai Yifan
提交者:
GitHub
1月 03, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Update distillation demo (#128)
* update distillation demo
上级
18645132
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
391 addition
and
180 deletion
+391
-180
slim/distillation/README.md
slim/distillation/README.md
+102
-81
slim/distillation/distill.py
slim/distillation/distill.py
+289
-0
slim/distillation/run.sh
slim/distillation/run.sh
+0
-47
slim/distillation/yolov3_mobilenet_v1_yolov3_resnet34_distillation.yml
...tion/yolov3_mobilenet_v1_yolov3_resnet34_distillation.yml
+0
-18
slim/distillation/yolov3_resnet34.yml
slim/distillation/yolov3_resnet34.yml
+0
-34
未找到文件。
slim/distillation/README.md
100755 → 100644
浏览文件 @
1a843c55
>运行该示例前请安装Paddle1.6或更高版本
>运行该示例前请安装Paddle
Slim和Paddle
1.6或更高版本
# 检测模型蒸馏示例
## 概述
该示例使用PaddleSlim提供的
[
蒸馏策略
](
https://
github.com/PaddlePaddle/models/blob/develop/PaddleSlim/docs/tutorial.md#3-蒸馏
)
对检测库中的模型进行蒸馏训练。
该示例使用PaddleSlim提供的
[
蒸馏策略
](
https://
paddlepaddle.github.io/PaddleSlim/algo/algo/#3
)
对检测库中的模型进行蒸馏训练。
在阅读该示例前,建议您先了解以下内容:
-
[
检测库的常规训练方法
](
https://github.com/PaddlePaddle/PaddleDetection
)
-
[
PaddleSlim
使用文档
](
https://github.com/PaddlePaddle/models/blob/develop/PaddleSlim/docs/usage.md
)
-
[
PaddleSlim
蒸馏API文档
](
https://paddlepaddle.github.io/PaddleSlim/api/single_distiller_api/
)
## 安装PaddleSlim
可按照
[
PaddleSlim使用文档
](
https://paddlepaddle.github.io/PaddleSlim/
)
中的步骤安装PaddleSlim
##
配置文件
说明
##
蒸馏策略
说明
关于
配置文件如何编写您可以参考:
关于
蒸馏API如何使用您可以参考PaddleSlim蒸馏API文档
-
[
PaddleSlim配置文件编写说明
](
https://github.com/PaddlePaddle/models/blob/develop/PaddleSlim/docs/usage.md#122-%E9%85%8D%E7%BD%AE%E6%96%87%E4%BB%B6%E7%9A%84%E4%BD%BF%E7%94%A8
)
-
[
蒸馏策略配置文件编写说明
](
https://github.com/PaddlePaddle/models/blob/develop/PaddleSlim/docs/usage.md#23-蒸馏
)
这里以ResNet34-YoloV3蒸馏MobileNetV1-YoloV3模型为例,首先,为了对
`student model`
和
`teacher model`
有个总体的认识,从而进一步确认蒸馏的对象,我们通过以下命令分别观察两个网络变量(Variable)的名称和形状:
这里以ResNet34-YoloV3蒸馏训练MobileNetV1-YoloV3模型为例,首先,为了对
`student model`
和
`teacher model`
有个总体的认识,进一步确认蒸馏的对象,我们通过以下命令分别观察两个网络变量(Variables)的名称和形状:
```
python
# 观察student model的Variable
# 观察student model的Variables
student_vars
=
[]
for
v
in
fluid
.
default_main_program
().
list_vars
():
if
"py_reader"
not
in
v
.
name
and
"double_buffer"
not
in
v
.
name
and
"generated_var"
not
in
v
.
name
:
print
(
v
.
name
,
v
.
shape
)
# 观察teacher model的Variable
try
:
student_vars
.
append
((
v
.
name
,
v
.
shape
))
except
:
pass
print
(
"="
*
50
+
"student_model_vars"
+
"="
*
50
)
print
(
student_vars
)
# 观察teacher model的Variables
teacher_vars
=
[]
for
v
in
teacher_program
.
list_vars
():
print
(
v
.
name
,
v
.
shape
)
try
:
teacher_vars
.
append
((
v
.
name
,
v
.
shape
))
except
:
pass
print
(
"="
*
50
+
"teacher_model_vars"
+
"="
*
50
)
print
(
teacher_vars
)
```
经过对比可以发现,
`student model`
和
`teacher model`
的部分中间结果
分别为:
经过对比可以发现,
`student model`
和
`teacher model`
输入到3个
`yolov3_loss`
的特征图
分别为:
```
bash
# student model
conv2d_
15.tmp_0
conv2d_
20.tmp_1, conv2d_28.tmp_1, conv2d_36.tmp_1
# teacher model
teacher_teacher_conv2d_1.tmp_0
conv2d_6.tmp_1, conv2d_14.tmp_1, conv2d_22.tmp_1
```
所以,我们用
`l2_distiller`
对这两个特征图做蒸馏。在配置文件中进行如下配置:
```
yaml
distillers
:
l2_distiller
:
class
:
'
L2Distiller'
teacher_feature_map
:
'
teacher_teacher_conv2d_1.tmp_0'
student_feature_map
:
'
conv2d_15.tmp_0'
distillation_loss_weight
:
1
strategies
:
distillation_strategy
:
class
:
'
DistillationStrategy'
distillers
:
[
'
l2_distiller'
]
start_epoch
:
0
end_epoch
:
270
它们形状两两相同,且分别处于两个网络的输出部分。所以,我们用
`l2_loss`
对这几个特征图两两对应添加蒸馏loss。需要注意的是,teacher的Variable在merge过程中被自动添加了一个
`name_prefix`
,所以这里也需要加上这个前缀
`"teacher_"`
,merge过程请参考
[
蒸馏API文档
](
https://paddlepaddle.github.io/PaddleSlim/api/single_distiller_api/#merge
)
```
python
dist_loss_1
=
l2_loss
(
'teacher_conv2d_6.tmp_1'
,
'conv2d_20.tmp_1'
)
dist_loss_2
=
l2_loss
(
'teacher_conv2d_14.tmp_1'
,
'conv2d_28.tmp_1'
)
dist_loss_3
=
l2_loss
(
'teacher_conv2d_22.tmp_1'
,
'conv2d_36.tmp_1'
)
```
我们也可以根据上述操作为蒸馏策略选择其他loss,PaddleSlim支持的有
`FSP_loss`
,
`L2_loss`
和
`softmax_with_cross_entropy_loss`
。
我们也可以根据上述操作为蒸馏策略选择其他loss,PaddleSlim支持的有
`FSP_loss`
,
`L2_loss`
,
`softmax_with_cross_entropy_loss`
以及自定义的任何loss
。
## 训练
根据
[
PaddleDetection/tools/train.py
](
https://github.com/PaddlePaddle/PaddleDetection/tree/master/tools/train.py
)
编写压缩脚本compress.py
。
在该脚本中定义了
Compressor对象,用于执行压缩任务。
根据
[
PaddleDetection/tools/train.py
](
../../tools/train.py
)
编写压缩脚本
`distill.py`
。
在该脚本中定义了
teacher_model和student_model,用teacher_model的输出指导student_model的训练
### 执行示例
step1: 设置GPU卡
```
shell
export
CUDA_VISIBLE_DEVICES
=
0,1,2,3,4,5,6,7
```
您可以通过运行脚本
`run.sh`
运行该示例。
step2: 开始训练
```
bash
python slim/distillation/distill.py
\
-c
configs/yolov3_mobilenet_v1_voc.yml
\
-t
configs/yolov3_r34_voc.yml
\
--teacher_pretrained
https://paddlemodels.bj.bcebos.com/object_detection/yolov3_r34_voc.tar
```
### 保存断点(checkpoint)
如果要调整训练卡数,需要调整配置文件
`yolov3_mobilenet_v1_voc.yml`
中的以下参数:
如果在配置文件中设置了
`checkpoint_path`
, 则在蒸馏任务执行过程中会自动保存断点,当任务异常中断时,
重启任务会自动从
`checkpoint_path`
路径下按数字顺序加载最新的checkpoint文件。如果不想让重启的任务从断点恢复,
需要修改配置文件中的
`checkpoint_path`
,或者将
`checkpoint_path`
路径下文件清空。
-
**max_iters:**
训练过程迭代总步数。
-
**YOLOv3Loss.batch_size:**
该参数表示单张GPU卡上的
`batch_size`
, 总
`batch_size`
是GPU卡数乘以这个值,
`batch_size`
的设定受限于显存大小。
-
**LeaningRate.base_lr:**
根据多卡的总
`batch_size`
调整
`base_lr`
,两者大小正相关,可以简单的按比例进行调整。
-
**LearningRate.schedulers.PiecewiseDecay.milestones:**
请根据batch size的变化对其调整。
-
**LearningRate.schedulers.PiecewiseDecay.LinearWarmup.steps:**
请根据batch size的变化对其进行调整。
>注意:配置文件中的信息不会保存在断点中,重启前对配置文件的修改将会生效。
以下为4卡训练示例,通过命令行覆盖
`yolov3_mobilenet_v1_voc.yml`
中的参数:
```
shell
CUDA_VISIBLE_DEVICES
=
0,1,2,3
python slim/distillation/distill.py
\
-c
configs/yolov3_mobilenet_v1_voc.yml
\
-t
configs/yolov3_r34_voc.yml
\
-o
YoloTrainFeed.batch_size
=
16
\
--teacher_pretrained
https://paddlemodels.bj.bcebos.com/object_detection/yolov3_r34_voc.tar
```
## 评估
如果在配置文件中设置了
`checkpoint_path`
,则每个epoch会保存一个压缩后的用于评估的模型,
该模型会保存在
`${checkpoint_path}/${epoch_id}/eval_model/`
路径下,包含
`__model__`
和
`__params__`
两个文件。
其中,
`__model__`
用于保存模型结构信息,
`__params__`
用于保存参数(parameters)信息。
如果不需要保存评估模型,可以在定义Compressor对象时,将
`save_eval_model`
选项设置为False(默认为True)。
运行命令为:
```
python ../eval.py \
--model_path ${checkpoint_path}/${epoch_id}/eval_model/ \
--model_name __model__ \
--params_name __params__ \
-c ../../configs/yolov3_mobilenet_v1_voc.yml \
-d "../../dataset/voc"
```
### 保存断点(checkpoint)
## 预测
蒸馏任务执行过程中会自动保存断点。如果需要从断点继续训练请用
`-r`
参数指定checkpoint路径,示例如下:
如果在配置文件中设置了
`checkpoint_path`
,并且在定义Compressor对象时指定了
`prune_infer_model`
选项,则每个epoch都会
保存一个
`inference model`
。该模型是通过删除eval_program中多余的operators而得到的。
```
bash
python
-u
slim/distillation/distill.py
\
-c
configs/yolov3_mobilenet_v1_voc.yml
\
-t
configs/yolov3_r34_voc.yml
\
-r
output/yolov3_mobilenet_v1_voc/10000
\
--teacher_pretrained
https://paddlemodels.bj.bcebos.com/object_detection/yolov3_r34_voc.tar
```
该模型会保存在
`${checkpoint_path}/${epoch_id}/eval_model/`
路径下,包含
`__model__.infer`
和
`__params__`
两个文件。
其中,
`__model__.infer`
用于保存模型结构信息,
`__params__`
用于保存参数(parameters)信息。
更多关于
`prune_infer_model`
选项的介绍,请参考:
[
Compressor介绍
](
https://github.com/PaddlePaddle/models/blob/develop/PaddleSlim/docs/usage.md#121-%E5%A6%82%E4%BD%95%E6%94%B9%E5%86%99%E6%99%AE%E9%80%9A%E8%AE%AD%E7%BB%83%E8%84%9A%E6%9C%AC
)
### python预测
在脚本
<a
href=
"../infer.py"
>
slim/infer.py
</a>
中展示了如何使用fluid python API加载使用预测模型进行预测。
## 评估
每隔
`snap_shot_iter`
步后会保存一个checkpoint模型可以用于评估,使用PaddleDetection目录下
[
tools/eval.py
](
../../tools/eval.py
)
评估脚本,并指定
`weights`
为训练得到的模型路径
运行命令为:
```
python ../infer.py \
--model_path ${checkpoint_path}/${epoch_id}/eval_model/ \
--model_name __model__.infer \
--params_name __params__ \
-c ../../configs/yolov3_mobilenet_v1_voc.yml \
--infer_dir ../../demo
```
bash
export
CUDA_VISIBLE_DEVICES
=
0
python
-u
tools/eval.py
-c
configs/yolov3_mobilenet_v1_voc.yml
\
-o
weights
=
output/yolov3_mobilenet_v1_voc/model_final
\
```
##
# PaddleLite
##
预测
该示例中产出的预测(inference)模型可以直接用PaddleLite进行加载使用。
关于PaddleLite如何使用,请参考:
[
PaddleLite使用文档
](
https://github.com/PaddlePaddle/Paddle-Lite/wiki#%E4%BD%BF%E7%94%A8
)
每隔
`snap_shot_iter`
步后保存的checkpoint模型也可以用于预测,使用PaddleDetection目录下
[
tools/infer.py
](
../../tools/infer.py
)
评估脚本,并指定
`weights`
为训练得到的模型路径
##
示例结果
##
# Python预测
>当前release的结果并非超参调优后的最好结果,仅做示例参考,后续我们会优化当前结果。
运行命令为:
```
export CUDA_VISIBLE_DEVICES=0
python -u tools/infer.py -c configs/yolov3_mobilenet_v1_voc.yml \
--infer_img=demo/000000570688.jpg \
--output_dir=infer_output/ \
--draw_threshold=0.5 \
-o weights=output/yolov3_mobilenet_v1_voc/model_final
```
##
# MobileNetV1-YOLO-V3
##
示例结果
| FLOPS |Box AP|
|---|---|
|baseline|76.2 |
|蒸馏后|- |
### MobileNetV1-YOLO-V3-VOC
| FLOPS |输入尺寸|每张GPU图片个数|推理时间(fps)|Box AP|下载|
|:-:|:-:|:-:|:-:|:-:|:-:|
|baseline|608 |16|104.291|76.2|
[
下载链接
](
https://paddlemodels.bj.bcebos.com/object_detection/yolov3_mobilenet_v1_voc.tar
)
|
|baseline|416 |16|-|76.7|
[
下载链接
](
https://paddlemodels.bj.bcebos.com/object_detection/yolov3_mobilenet_v1_voc.tar
)
|
|baseline|320 |16|-|75.3|
[
下载链接
](
https://paddlemodels.bj.bcebos.com/object_detection/yolov3_mobilenet_v1_voc.tar
)
|
|蒸馏后|608 |16|106.914|79.0||
|蒸馏后|416 |16|-|78.2||
|蒸馏后|320 |16|-|75.5||
## FAQ
> 蒸馏后的结果用ResNet34-YOLO-V3做teacher,4GPU总batch_size64训练90000 iter得到
slim/distillation/
compress
.py
→
slim/distillation/
distill
.py
浏览文件 @
1a843c55
...
...
@@ -17,38 +17,18 @@ from __future__ import division
from
__future__
import
print_function
import
os
import
time
import
multiprocessing
import
numpy
as
np
from
collections
import
deque
,
OrderedDict
from
paddle.fluid.contrib.slim.core
import
Compressor
from
paddle.fluid.framework
import
IrGraph
def
set_paddle_flags
(
**
kwargs
):
for
key
,
value
in
kwargs
.
items
():
if
os
.
environ
.
get
(
key
,
None
)
is
None
:
os
.
environ
[
key
]
=
str
(
value
)
# NOTE(paddle-dev): All of these flags should be set before
# `import paddle`. Otherwise, it would not take any effect.
set_paddle_flags
(
FLAGS_eager_delete_tensor_gb
=
0
,
# enable GC to save memory
)
from
collections
import
OrderedDict
from
paddleslim.dist.single_distiller
import
merge
,
l2_loss
from
paddle
import
fluid
import
sys
sys
.
path
.
append
(
"../../"
)
from
ppdet.core.workspace
import
load_config
,
merge_config
,
create
from
ppdet.data.
data_feed
import
create_reader
from
ppdet.utils.eval_utils
import
parse_fetches
,
eval_results
from
ppdet.data.
reader
import
create_reader
from
ppdet.utils.eval_utils
import
parse_fetches
,
eval_results
,
eval_run
from
ppdet.utils.stats
import
TrainingStats
from
ppdet.utils.cli
import
ArgsParser
from
ppdet.utils.check
import
check_gpu
import
ppdet.utils.checkpoint
as
checkpoint
from
ppdet.modeling.model_input
import
create_feed
import
logging
FORMAT
=
'%(asctime)s-%(levelname)s: %(message)s'
...
...
@@ -56,56 +36,8 @@ logging.basicConfig(level=logging.INFO, format=FORMAT)
logger
=
logging
.
getLogger
(
__name__
)
def
eval_run
(
exe
,
compile_program
,
reader
,
keys
,
values
,
cls
,
test_feed
):
"""
Run evaluation program, return program outputs.
"""
iter_id
=
0
results
=
[]
if
len
(
cls
)
!=
0
:
values
=
[]
for
i
in
range
(
len
(
cls
)):
_
,
accum_map
=
cls
[
i
].
get_map_var
()
cls
[
i
].
reset
(
exe
)
values
.
append
(
accum_map
)
images_num
=
0
start_time
=
time
.
time
()
has_bbox
=
'bbox'
in
keys
for
data
in
reader
():
data
=
test_feed
.
feed
(
data
)
feed_data
=
{
'image'
:
data
[
'image'
],
'im_size'
:
data
[
'im_size'
]}
outs
=
exe
.
run
(
compile_program
,
feed
=
feed_data
,
fetch_list
=
[
values
[
0
]],
return_numpy
=
False
)
outs
.
append
(
data
[
'gt_box'
])
outs
.
append
(
data
[
'gt_label'
])
outs
.
append
(
data
[
'is_difficult'
])
res
=
{
k
:
(
np
.
array
(
v
),
v
.
recursive_sequence_lengths
())
for
k
,
v
in
zip
(
keys
,
outs
)
}
results
.
append
(
res
)
if
iter_id
%
100
==
0
:
logger
.
info
(
'Test iter {}'
.
format
(
iter_id
))
iter_id
+=
1
images_num
+=
len
(
res
[
'bbox'
][
1
][
0
])
if
has_bbox
else
1
logger
.
info
(
'Test finish iter {}'
.
format
(
iter_id
))
end_time
=
time
.
time
()
fps
=
images_num
/
(
end_time
-
start_time
)
if
has_bbox
:
logger
.
info
(
'Total number of images: {}, inference time: {} fps.'
.
format
(
images_num
,
fps
))
else
:
logger
.
info
(
'Total iteration: {}, inference time: {} batch/s.'
.
format
(
images_num
,
fps
))
return
results
def
main
():
env
=
os
.
environ
cfg
=
load_config
(
FLAGS
.
config
)
if
'architecture'
in
cfg
:
main_arch
=
cfg
.
architecture
...
...
@@ -122,112 +54,60 @@ def main():
if
cfg
.
use_gpu
:
devices_num
=
fluid
.
core
.
get_cuda_device_count
()
else
:
devices_num
=
int
(
os
.
environ
.
get
(
'CPU_NUM'
,
multiprocessing
.
cpu_count
()))
if
'train_feed'
not
in
cfg
:
train_feed
=
create
(
main_arch
+
'TrainFeed'
)
else
:
train_feed
=
create
(
cfg
.
train_feed
)
devices_num
=
int
(
os
.
environ
.
get
(
'CPU_NUM'
,
1
))
if
'
eval_feed'
not
in
cfg
:
eval_feed
=
create
(
main_arch
+
'EvalFeed'
)
if
'
FLAGS_selected_gpus'
in
env
:
device_id
=
int
(
env
[
'FLAGS_selected_gpus'
]
)
else
:
eval_feed
=
create
(
cfg
.
eval_feed
)
place
=
fluid
.
CUDAPlace
(
0
)
if
cfg
.
use_gpu
else
fluid
.
CPUPlace
()
device_id
=
0
place
=
fluid
.
CUDAPlace
(
device_id
)
if
cfg
.
use_gpu
else
fluid
.
CPUPlace
()
exe
=
fluid
.
Executor
(
place
)
lr_builder
=
create
(
'LearningRate'
)
optim_builder
=
create
(
'OptimizerBuilder'
)
# build program
model
=
create
(
main_arch
)
_
,
train_feed_vars
=
create_feed
(
train_feed
,
True
)
inputs_def
=
cfg
[
'TrainReader'
][
'inputs_def'
]
train_feed_vars
,
train_loader
=
model
.
build_inputs
(
**
inputs_def
)
train_fetches
=
model
.
train
(
train_feed_vars
)
loss
=
train_fetches
[
'loss'
]
lr
=
lr_builder
()
opt
=
optim_builder
(
lr
)
opt
.
minimize
(
loss
)
#for v in fluid.default_main_program().list_vars():
# if "py_reader" not in v.name and "double_buffer" not in v.name and "generated_var" not in v.name:
# print(v.name, v.shape)
cfg
.
max_iters
=
258
train_reader
=
create_reader
(
train_feed
,
cfg
.
max_iters
,
FLAGS
.
dataset_dir
)
exe
.
run
(
fluid
.
default_startup_program
())
# parse train fetches
train_keys
,
train_values
,
_
=
parse_fetches
(
train_fetches
)
train_keys
.
append
(
'lr'
)
train_values
.
append
(
lr
.
name
)
train_fetch_list
=
[]
for
k
,
v
in
zip
(
train_keys
,
train_values
):
train_fetch_list
.
append
((
k
,
v
))
print
(
"train_fetch_list: {}"
.
format
(
train_fetch_list
))
# get all student variables
student_vars
=
[]
for
v
in
fluid
.
default_main_program
().
list_vars
():
try
:
student_vars
.
append
((
v
.
name
,
v
.
shape
))
except
:
pass
# uncomment the following lines to print all student variables
# print("="*50 + "student_model_vars" + "="*50)
# print(student_vars)
eval_prog
=
fluid
.
Program
()
startup_prog
=
fluid
.
Program
()
with
fluid
.
program_guard
(
eval_prog
,
startup_prog
):
with
fluid
.
program_guard
(
eval_prog
,
fluid
.
default_startup_program
()):
with
fluid
.
unique_name
.
guard
():
model
=
create
(
main_arch
)
_
,
test_feed_vars
=
create_feed
(
eval_feed
,
True
)
inputs_def
=
cfg
[
'EvalReader'
][
'inputs_def'
]
test_feed_vars
,
eval_loader
=
model
.
build_inputs
(
**
inputs_def
)
fetches
=
model
.
eval
(
test_feed_vars
)
eval_prog
=
eval_prog
.
clone
(
True
)
eval_reader
=
create_reader
(
eval_feed
,
args_path
=
FLAGS
.
dataset_di
r
)
test_data_feed
=
fluid
.
DataFeeder
(
test_feed_vars
.
values
()
,
place
)
eval_reader
=
create_reader
(
cfg
.
EvalReade
r
)
eval_loader
.
set_sample_list_generator
(
eval_reader
,
place
)
# parse eval fetches
extra_keys
=
[]
if
cfg
.
metric
==
'COCO'
:
extra_keys
=
[
'im_info'
,
'im_id'
,
'im_shape'
]
if
cfg
.
metric
==
'VOC'
:
extra_keys
=
[
'gt_b
ox'
,
'gt_label
'
,
'is_difficult'
]
extra_keys
=
[
'gt_b
box'
,
'gt_class
'
,
'is_difficult'
]
eval_keys
,
eval_values
,
eval_cls
=
parse_fetches
(
fetches
,
eval_prog
,
extra_keys
)
eval_fetch_list
=
[]
for
k
,
v
in
zip
(
eval_keys
,
eval_values
):
eval_fetch_list
.
append
((
k
,
v
))
print
(
"eval_fetch_list: {}"
.
format
(
eval_fetch_list
))
exe
.
run
(
startup_prog
)
checkpoint
.
load_params
(
exe
,
fluid
.
default_main_program
(),
cfg
.
pretrain_weights
)
best_box_ap_list
=
[]
def
eval_func
(
program
,
scope
):
results
=
eval_run
(
exe
,
program
,
eval_reader
,
eval_keys
,
eval_values
,
eval_cls
,
test_data_feed
)
resolution
=
None
is_bbox_normalized
=
False
if
'mask'
in
results
[
0
]:
resolution
=
model
.
mask_head
.
resolution
box_ap_stats
=
eval_results
(
results
,
eval_feed
,
cfg
.
metric
,
cfg
.
num_classes
,
resolution
,
is_bbox_normalized
,
FLAGS
.
output_eval
)
if
len
(
best_box_ap_list
)
==
0
:
best_box_ap_list
.
append
(
box_ap_stats
[
0
])
elif
box_ap_stats
[
0
]
>
best_box_ap_list
[
0
]:
best_box_ap_list
[
0
]
=
box_ap_stats
[
0
]
logger
.
info
(
"Best test box ap: {}"
.
format
(
best_box_ap_list
[
0
]))
return
best_box_ap_list
[
0
]
test_feed
=
[(
'image'
,
test_feed_vars
[
'image'
].
name
),
(
'im_size'
,
test_feed_vars
[
'im_size'
].
name
)]
teacher_cfg
=
load_config
(
FLAGS
.
teacher_config
)
teacher_arch
=
teacher_cfg
.
architecture
teacher_programs
=
[]
teacher_program
=
fluid
.
Program
()
teacher_startup_program
=
fluid
.
Program
()
with
fluid
.
program_guard
(
teacher_program
,
teacher_startup_program
):
with
fluid
.
unique_name
.
guard
(
'teacher_'
):
with
fluid
.
unique_name
.
guard
():
teacher_feed_vars
=
OrderedDict
()
for
name
,
var
in
train_feed_vars
.
items
():
teacher_feed_vars
[
name
]
=
teacher_program
.
global_block
(
...
...
@@ -235,64 +115,154 @@ def main():
var
,
force_persistable
=
False
)
model
=
create
(
teacher_arch
)
train_fetches
=
model
.
train
(
teacher_feed_vars
)
#print("="*50+"teacher_model_params"+"="*50)
#for v in teacher_program.list_vars():
# print(v.name, v.shape)
#return
teacher_loss
=
train_fetches
[
'loss'
]
# get all teacher variables
teacher_vars
=
[]
for
v
in
teacher_program
.
list_vars
():
try
:
teacher_vars
.
append
((
v
.
name
,
v
.
shape
))
except
:
pass
# uncomment the following lines to print all teacher variables
# print("="*50 + "teacher_model_vars" + "="*50)
# print(teacher_vars)
exe
.
run
(
teacher_startup_program
)
assert
FLAGS
.
teacher_pretrained
and
os
.
path
.
exists
(
FLAGS
.
teacher_pretrained
),
"teacher_pretrained should be set when teacher_model is not None."
def
if_exist
(
var
):
return
os
.
path
.
exists
(
os
.
path
.
join
(
FLAGS
.
teacher_pretrained
,
var
.
name
))
fluid
.
io
.
load_vars
(
exe
,
FLAGS
.
teacher_pretrained
,
main_program
=
teacher_program
,
predicate
=
if_exist
)
teacher_programs
.
append
(
teacher_program
.
clone
(
for_test
=
True
))
com
=
Compressor
(
place
,
fluid
.
global_scope
(),
fluid
.
default_main_program
(),
train_reader
=
train_reader
,
train_feed_list
=
[(
key
,
value
.
name
)
for
key
,
value
in
train_feed_vars
.
items
()],
train_fetch_list
=
train_fetch_list
,
eval_program
=
eval_prog
,
eval_reader
=
eval_reader
,
eval_feed_list
=
test_feed
,
eval_func
=
{
'map'
:
eval_func
},
eval_fetch_list
=
eval_fetch_list
[
0
:
1
],
save_eval_model
=
True
,
prune_infer_model
=
[[
"image"
,
"im_size"
],
[
"multiclass_nms_0.tmp_0"
]],
teacher_programs
=
teacher_programs
,
train_optimizer
=
None
,
distiller_optimizer
=
opt
,
log_period
=
20
)
com
.
config
(
FLAGS
.
slim_file
)
com
.
run
()
assert
FLAGS
.
teacher_pretrained
,
"teacher_pretrained should be set"
checkpoint
.
load_params
(
exe
,
teacher_program
,
FLAGS
.
teacher_pretrained
)
teacher_program
=
teacher_program
.
clone
(
for_test
=
True
)
cfg
=
load_config
(
FLAGS
.
config
)
data_name_map
=
{
'image'
:
'image'
,
'gt_bbox'
:
'gt_bbox'
,
'gt_class'
:
'gt_class'
,
'gt_score'
:
'gt_score'
}
distill_prog
=
merge
(
teacher_program
,
fluid
.
default_main_program
(),
data_name_map
,
place
)
distill_weight
=
100
distill_pairs
=
[[
'teacher_conv2d_6.tmp_1'
,
'conv2d_20.tmp_1'
],
[
'teacher_conv2d_14.tmp_1'
,
'conv2d_28.tmp_1'
],
[
'teacher_conv2d_22.tmp_1'
,
'conv2d_36.tmp_1'
]]
def
l2_distill
(
pairs
,
weight
):
"""
Add l2 distillation losses composed of multi pairs of feature maps,
each pair of feature maps is the input of teacher and student's
yolov3_loss respectively
"""
loss
=
[]
for
pair
in
pairs
:
loss
.
append
(
l2_loss
(
pair
[
0
],
pair
[
1
]))
loss
=
fluid
.
layers
.
sum
(
loss
)
weighted_loss
=
loss
*
weight
return
weighted_loss
distill_loss
=
l2_distill
(
distill_pairs
,
distill_weight
)
loss
=
distill_loss
+
loss
lr_builder
=
create
(
'LearningRate'
)
optim_builder
=
create
(
'OptimizerBuilder'
)
lr
=
lr_builder
()
opt
=
optim_builder
(
lr
)
opt
.
minimize
(
loss
)
exe
.
run
(
fluid
.
default_startup_program
())
checkpoint
.
load_params
(
exe
,
fluid
.
default_main_program
(),
cfg
.
pretrain_weights
)
build_strategy
=
fluid
.
BuildStrategy
()
build_strategy
.
fuse_all_reduce_ops
=
False
build_strategy
.
fuse_all_optimizer_ops
=
False
build_strategy
.
fuse_elewise_add_act_ops
=
True
# only enable sync_bn in multi GPU devices
sync_bn
=
getattr
(
model
.
backbone
,
'norm_type'
,
None
)
==
'sync_bn'
build_strategy
.
sync_batch_norm
=
sync_bn
and
devices_num
>
1
\
and
cfg
.
use_gpu
exec_strategy
=
fluid
.
ExecutionStrategy
()
# iteration number when CompiledProgram tries to drop local execution scopes.
# Set it to be 1 to save memory usages, so that unused variables in
# local execution scopes can be deleted after each iteration.
exec_strategy
.
num_iteration_per_drop_scope
=
1
parallel_main
=
fluid
.
CompiledProgram
(
distill_prog
).
with_data_parallel
(
loss_name
=
loss
.
name
,
build_strategy
=
build_strategy
,
exec_strategy
=
exec_strategy
)
compiled_eval_prog
=
fluid
.
compiler
.
CompiledProgram
(
eval_prog
)
fuse_bn
=
getattr
(
model
.
backbone
,
'norm_type'
,
None
)
==
'affine_channel'
ignore_params
=
cfg
.
finetune_exclude_pretrained_params
\
if
'finetune_exclude_pretrained_params'
in
cfg
else
[]
start_iter
=
0
if
FLAGS
.
resume_checkpoint
:
checkpoint
.
load_checkpoint
(
exe
,
distill_prog
,
FLAGS
.
resume_checkpoint
)
start_iter
=
checkpoint
.
global_step
()
elif
cfg
.
pretrain_weights
and
fuse_bn
and
not
ignore_params
:
checkpoint
.
load_and_fusebn
(
exe
,
distill_prog
,
cfg
.
pretrain_weights
)
elif
cfg
.
pretrain_weights
:
checkpoint
.
load_params
(
exe
,
distill_prog
,
cfg
.
pretrain_weights
,
ignore_params
=
ignore_params
)
train_reader
=
create_reader
(
cfg
.
TrainReader
,
(
cfg
.
max_iters
-
start_iter
)
*
devices_num
,
cfg
)
train_loader
.
set_sample_list_generator
(
train_reader
,
place
)
# whether output bbox is normalized in model output layer
is_bbox_normalized
=
False
if
hasattr
(
model
,
'is_bbox_normalized'
)
and
\
callable
(
model
.
is_bbox_normalized
):
is_bbox_normalized
=
model
.
is_bbox_normalized
()
map_type
=
cfg
.
map_type
if
'map_type'
in
cfg
else
'11point'
best_box_ap_list
=
[
0.0
,
0
]
#[map, iter]
cfg_name
=
os
.
path
.
basename
(
FLAGS
.
config
).
split
(
'.'
)[
0
]
save_dir
=
os
.
path
.
join
(
cfg
.
save_dir
,
cfg_name
)
train_loader
.
start
()
for
step_id
in
range
(
start_iter
,
cfg
.
max_iters
):
teacher_loss_np
,
distill_loss_np
,
loss_np
,
lr_np
=
exe
.
run
(
parallel_main
,
fetch_list
=
[
'teacher_'
+
teacher_loss
.
name
,
distill_loss
.
name
,
loss
.
name
,
lr
.
name
])
if
step_id
%
cfg
.
log_iter
==
0
:
logger
.
info
(
"step {} lr {:.6f}, loss {:.6f}, distill_loss {:.6f}, teacher_loss {:.6f}"
.
format
(
step_id
,
lr_np
[
0
],
loss_np
[
0
],
distill_loss_np
[
0
],
teacher_loss_np
[
0
]))
if
step_id
%
cfg
.
snapshot_iter
==
0
and
step_id
!=
0
or
step_id
==
cfg
.
max_iters
-
1
:
save_name
=
str
(
step_id
)
if
step_id
!=
cfg
.
max_iters
-
1
else
"model_final"
checkpoint
.
save
(
exe
,
distill_prog
,
os
.
path
.
join
(
save_dir
,
save_name
))
# eval
results
=
eval_run
(
exe
,
compiled_eval_prog
,
eval_loader
,
eval_keys
,
eval_values
,
eval_cls
)
resolution
=
None
box_ap_stats
=
eval_results
(
results
,
cfg
.
metric
,
cfg
.
num_classes
,
resolution
,
is_bbox_normalized
,
FLAGS
.
output_eval
,
map_type
,
cfg
[
'EvalReader'
][
'dataset'
])
if
box_ap_stats
[
0
]
>
best_box_ap_list
[
0
]:
best_box_ap_list
[
0
]
=
box_ap_stats
[
0
]
best_box_ap_list
[
1
]
=
step_id
checkpoint
.
save
(
exe
,
distill_prog
,
os
.
path
.
join
(
"./"
,
"best_model"
))
logger
.
info
(
"Best test box ap: {}, in step: {}"
.
format
(
best_box_ap_list
[
0
],
best_box_ap_list
[
1
]))
train_loader
.
reset
()
if
__name__
==
'__main__'
:
parser
=
ArgsParser
()
parser
.
add_argument
(
"-t"
,
"--teacher_config"
,
default
=
None
,
type
=
str
,
help
=
"Config file of teacher architecture."
)
parser
.
add_argument
(
"-s"
,
"--slim_file"
,
default
=
None
,
type
=
str
,
help
=
"Config file of PaddleSlim."
)
parser
.
add_argument
(
"-r"
,
"--resume_checkpoint"
,
...
...
@@ -300,10 +270,11 @@ if __name__ == '__main__':
type
=
str
,
help
=
"Checkpoint path for resuming training."
)
parser
.
add_argument
(
"--eval"
,
action
=
'store_true'
,
default
=
False
,
help
=
"Whether to perform evaluation in train"
)
"-t"
,
"--teacher_config"
,
default
=
None
,
type
=
str
,
help
=
"Config file of teacher architecture."
)
parser
.
add_argument
(
"--teacher_pretrained"
,
default
=
None
,
...
...
@@ -314,11 +285,5 @@ if __name__ == '__main__':
default
=
None
,
type
=
str
,
help
=
"Evaluation directory, default is current directory."
)
parser
.
add_argument
(
"-d"
,
"--dataset_dir"
,
default
=
None
,
type
=
str
,
help
=
"Dataset path, same as DataFeed.dataset.dataset_dir"
)
FLAGS
=
parser
.
parse_args
()
main
()
slim/distillation/run.sh
已删除
100644 → 0
浏览文件 @
18645132
#!/usr/bin/env bash
# download pretrain model
root_url
=
"https://paddlemodels.bj.bcebos.com/object_detection"
yolov3_r34_voc
=
"yolov3_r34_voc.tar"
pretrain_dir
=
'./pretrain'
if
[
!
-d
${
pretrain_dir
}
]
;
then
mkdir
${
pretrain_dir
}
fi
cd
${
pretrain_dir
}
if
[
!
-f
${
yolov3_r34_voc
}
]
;
then
wget
${
root_url
}
/
${
yolov3_r34_voc
}
tar
xf
${
yolov3_r34_voc
}
fi
cd
-
# enable GC strategy
export
FLAGS_fast_eager_deletion_mode
=
1
export
FLAGS_eager_delete_tensor_gb
=
0.0
# for distillation
#-----------------
export
CUDA_VISIBLE_DEVICES
=
0,1,2,3
# Fixing name conflicts in distillation
cd
${
pretrain_dir
}
/yolov3_r34_voc
for
files
in
$(
ls
teacher_
*
)
do
mv
$files
${
files
#*_
}
done
for
files
in
$(
ls
*
)
do
mv
$files
"teacher_"
$files
done
cd
-
python
-u
compress.py
\
-c
../../configs/yolov3_mobilenet_v1_voc.yml
\
-t
yolov3_resnet34.yml
\
-s
yolov3_mobilenet_v1_yolov3_resnet34_distillation.yml
\
-o
YoloTrainFeed.batch_size
=
64
\
-d
../../dataset/voc
\
--teacher_pretrained
./pretrain/yolov3_r34_voc
\
>
yolov3_distallation.log 2>&1 &
tailf yolov3_distallation.log
slim/distillation/yolov3_mobilenet_v1_yolov3_resnet34_distillation.yml
已删除
100644 → 0
浏览文件 @
18645132
version
:
1.0
distillers
:
l2_distiller
:
class
:
'
L2Distiller'
teacher_feature_map
:
'
teacher_teacher_conv2d_1.tmp_0'
student_feature_map
:
'
conv2d_15.tmp_0'
distillation_loss_weight
:
1
strategies
:
distillation_strategy
:
class
:
'
DistillationStrategy'
distillers
:
[
'
l2_distiller'
]
start_epoch
:
0
end_epoch
:
270
compressor
:
epoch
:
271
checkpoint_path
:
'
./checkpoints/'
strategies
:
-
distillation_strategy
slim/distillation/yolov3_resnet34.yml
已删除
100644 → 0
浏览文件 @
18645132
architecture
:
YOLOv3
log_smooth_window
:
20
metric
:
VOC
map_type
:
11point
num_classes
:
20
weight_prefix_name
:
teacher_
YOLOv3
:
backbone
:
ResNet
yolo_head
:
YOLOv3Head
ResNet
:
norm_type
:
sync_bn
freeze_at
:
0
freeze_norm
:
false
norm_decay
:
0.
depth
:
34
feature_maps
:
[
3
,
4
,
5
]
YOLOv3Head
:
anchor_masks
:
[[
6
,
7
,
8
],
[
3
,
4
,
5
],
[
0
,
1
,
2
]]
anchors
:
[[
10
,
13
],
[
16
,
30
],
[
33
,
23
],
[
30
,
61
],
[
62
,
45
],
[
59
,
119
],
[
116
,
90
],
[
156
,
198
],
[
373
,
326
]]
norm_decay
:
0.
ignore_thresh
:
0.7
label_smooth
:
false
nms
:
background_label
:
-1
keep_top_k
:
100
nms_threshold
:
0.45
nms_top_k
:
1000
normalized
:
false
score_threshold
:
0.01
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录