Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleDetection
提交
1a843c55
P
PaddleDetection
项目概览
s920243400
/
PaddleDetection
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleDetection
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
1a843c55
编写于
1月 03, 2020
作者:
B
Bai Yifan
提交者:
GitHub
1月 03, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Update distillation demo (#128)
* update distillation demo
上级
18645132
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
391 addition
and
180 deletion
+391
-180
slim/distillation/README.md
slim/distillation/README.md
+102
-81
slim/distillation/distill.py
slim/distillation/distill.py
+289
-0
slim/distillation/run.sh
slim/distillation/run.sh
+0
-47
slim/distillation/yolov3_mobilenet_v1_yolov3_resnet34_distillation.yml
...tion/yolov3_mobilenet_v1_yolov3_resnet34_distillation.yml
+0
-18
slim/distillation/yolov3_resnet34.yml
slim/distillation/yolov3_resnet34.yml
+0
-34
未找到文件。
slim/distillation/README.md
100755 → 100644
浏览文件 @
1a843c55
>运行该示例前请安装Paddle1.6或更高版本
>运行该示例前请安装Paddle
Slim和Paddle
1.6或更高版本
# 检测模型蒸馏示例
## 概述
该示例使用PaddleSlim提供的
[
蒸馏策略
](
https://
github.com/PaddlePaddle/models/blob/develop/PaddleSlim/docs/tutorial.md#3-蒸馏
)
对检测库中的模型进行蒸馏训练。
该示例使用PaddleSlim提供的
[
蒸馏策略
](
https://
paddlepaddle.github.io/PaddleSlim/algo/algo/#3
)
对检测库中的模型进行蒸馏训练。
在阅读该示例前,建议您先了解以下内容:
-
[
检测库的常规训练方法
](
https://github.com/PaddlePaddle/PaddleDetection
)
-
[
PaddleSlim
使用文档
](
https://github.com/PaddlePaddle/models/blob/develop/PaddleSlim/docs/usage.md
)
-
[
PaddleSlim
蒸馏API文档
](
https://paddlepaddle.github.io/PaddleSlim/api/single_distiller_api/
)
## 安装PaddleSlim
可按照
[
PaddleSlim使用文档
](
https://paddlepaddle.github.io/PaddleSlim/
)
中的步骤安装PaddleSlim
##
配置文件
说明
##
蒸馏策略
说明
关于
配置文件如何编写您可以参考:
关于
蒸馏API如何使用您可以参考PaddleSlim蒸馏API文档
-
[
PaddleSlim配置文件编写说明
](
https://github.com/PaddlePaddle/models/blob/develop/PaddleSlim/docs/usage.md#122-%E9%85%8D%E7%BD%AE%E6%96%87%E4%BB%B6%E7%9A%84%E4%BD%BF%E7%94%A8
)
-
[
蒸馏策略配置文件编写说明
](
https://github.com/PaddlePaddle/models/blob/develop/PaddleSlim/docs/usage.md#23-蒸馏
)
这里以ResNet34-YoloV3蒸馏MobileNetV1-YoloV3模型为例,首先,为了对
`student model`
和
`teacher model`
有个总体的认识,从而进一步确认蒸馏的对象,我们通过以下命令分别观察两个网络变量(Variable)的名称和形状:
这里以ResNet34-YoloV3蒸馏训练MobileNetV1-YoloV3模型为例,首先,为了对
`student model`
和
`teacher model`
有个总体的认识,进一步确认蒸馏的对象,我们通过以下命令分别观察两个网络变量(Variables)的名称和形状:
```
python
# 观察student model的Variable
# 观察student model的Variables
student_vars
=
[]
for
v
in
fluid
.
default_main_program
().
list_vars
():
if
"py_reader"
not
in
v
.
name
and
"double_buffer"
not
in
v
.
name
and
"generated_var"
not
in
v
.
name
:
print
(
v
.
name
,
v
.
shape
)
# 观察teacher model的Variable
try
:
student_vars
.
append
((
v
.
name
,
v
.
shape
))
except
:
pass
print
(
"="
*
50
+
"student_model_vars"
+
"="
*
50
)
print
(
student_vars
)
# 观察teacher model的Variables
teacher_vars
=
[]
for
v
in
teacher_program
.
list_vars
():
print
(
v
.
name
,
v
.
shape
)
try
:
teacher_vars
.
append
((
v
.
name
,
v
.
shape
))
except
:
pass
print
(
"="
*
50
+
"teacher_model_vars"
+
"="
*
50
)
print
(
teacher_vars
)
```
经过对比可以发现,
`student model`
和
`teacher model`
的部分中间结果
分别为:
经过对比可以发现,
`student model`
和
`teacher model`
输入到3个
`yolov3_loss`
的特征图
分别为:
```
bash
# student model
conv2d_
15.tmp_0
conv2d_
20.tmp_1, conv2d_28.tmp_1, conv2d_36.tmp_1
# teacher model
teacher_teacher_conv2d_1.tmp_0
conv2d_6.tmp_1, conv2d_14.tmp_1, conv2d_22.tmp_1
```
所以,我们用
`l2_distiller`
对这两个特征图做蒸馏。在配置文件中进行如下配置:
```
yaml
distillers
:
l2_distiller
:
class
:
'
L2Distiller'
teacher_feature_map
:
'
teacher_teacher_conv2d_1.tmp_0'
student_feature_map
:
'
conv2d_15.tmp_0'
distillation_loss_weight
:
1
strategies
:
distillation_strategy
:
class
:
'
DistillationStrategy'
distillers
:
[
'
l2_distiller'
]
start_epoch
:
0
end_epoch
:
270
它们形状两两相同,且分别处于两个网络的输出部分。所以,我们用
`l2_loss`
对这几个特征图两两对应添加蒸馏loss。需要注意的是,teacher的Variable在merge过程中被自动添加了一个
`name_prefix`
,所以这里也需要加上这个前缀
`"teacher_"`
,merge过程请参考
[
蒸馏API文档
](
https://paddlepaddle.github.io/PaddleSlim/api/single_distiller_api/#merge
)
```
python
dist_loss_1
=
l2_loss
(
'teacher_conv2d_6.tmp_1'
,
'conv2d_20.tmp_1'
)
dist_loss_2
=
l2_loss
(
'teacher_conv2d_14.tmp_1'
,
'conv2d_28.tmp_1'
)
dist_loss_3
=
l2_loss
(
'teacher_conv2d_22.tmp_1'
,
'conv2d_36.tmp_1'
)
```
我们也可以根据上述操作为蒸馏策略选择其他loss,PaddleSlim支持的有
`FSP_loss`
,
`L2_loss`
和
`softmax_with_cross_entropy_loss`
。
我们也可以根据上述操作为蒸馏策略选择其他loss,PaddleSlim支持的有
`FSP_loss`
,
`L2_loss`
,
`softmax_with_cross_entropy_loss`
以及自定义的任何loss
。
## 训练
根据
[
PaddleDetection/tools/train.py
](
https://github.com/PaddlePaddle/PaddleDetection/tree/master/tools/train.py
)
编写压缩脚本compress.py
。
在该脚本中定义了
Compressor对象,用于执行压缩任务。
根据
[
PaddleDetection/tools/train.py
](
../../tools/train.py
)
编写压缩脚本
`distill.py`
。
在该脚本中定义了
teacher_model和student_model,用teacher_model的输出指导student_model的训练
### 执行示例
step1: 设置GPU卡
```
shell
export
CUDA_VISIBLE_DEVICES
=
0,1,2,3,4,5,6,7
```
您可以通过运行脚本
`run.sh`
运行该示例。
step2: 开始训练
```
bash
python slim/distillation/distill.py
\
-c
configs/yolov3_mobilenet_v1_voc.yml
\
-t
configs/yolov3_r34_voc.yml
\
--teacher_pretrained
https://paddlemodels.bj.bcebos.com/object_detection/yolov3_r34_voc.tar
```
### 保存断点(checkpoint)
如果要调整训练卡数,需要调整配置文件
`yolov3_mobilenet_v1_voc.yml`
中的以下参数:
如果在配置文件中设置了
`checkpoint_path`
, 则在蒸馏任务执行过程中会自动保存断点,当任务异常中断时,
重启任务会自动从
`checkpoint_path`
路径下按数字顺序加载最新的checkpoint文件。如果不想让重启的任务从断点恢复,
需要修改配置文件中的
`checkpoint_path`
,或者将
`checkpoint_path`
路径下文件清空。
-
**max_iters:**
训练过程迭代总步数。
-
**YOLOv3Loss.batch_size:**
该参数表示单张GPU卡上的
`batch_size`
, 总
`batch_size`
是GPU卡数乘以这个值,
`batch_size`
的设定受限于显存大小。
-
**LeaningRate.base_lr:**
根据多卡的总
`batch_size`
调整
`base_lr`
,两者大小正相关,可以简单的按比例进行调整。
-
**LearningRate.schedulers.PiecewiseDecay.milestones:**
请根据batch size的变化对其调整。
-
**LearningRate.schedulers.PiecewiseDecay.LinearWarmup.steps:**
请根据batch size的变化对其进行调整。
>注意:配置文件中的信息不会保存在断点中,重启前对配置文件的修改将会生效。
以下为4卡训练示例,通过命令行覆盖
`yolov3_mobilenet_v1_voc.yml`
中的参数:
```
shell
CUDA_VISIBLE_DEVICES
=
0,1,2,3
python slim/distillation/distill.py
\
-c
configs/yolov3_mobilenet_v1_voc.yml
\
-t
configs/yolov3_r34_voc.yml
\
-o
YoloTrainFeed.batch_size
=
16
\
--teacher_pretrained
https://paddlemodels.bj.bcebos.com/object_detection/yolov3_r34_voc.tar
```
## 评估
如果在配置文件中设置了
`checkpoint_path`
,则每个epoch会保存一个压缩后的用于评估的模型,
该模型会保存在
`${checkpoint_path}/${epoch_id}/eval_model/`
路径下,包含
`__model__`
和
`__params__`
两个文件。
其中,
`__model__`
用于保存模型结构信息,
`__params__`
用于保存参数(parameters)信息。
如果不需要保存评估模型,可以在定义Compressor对象时,将
`save_eval_model`
选项设置为False(默认为True)。
运行命令为:
```
python ../eval.py \
--model_path ${checkpoint_path}/${epoch_id}/eval_model/ \
--model_name __model__ \
--params_name __params__ \
-c ../../configs/yolov3_mobilenet_v1_voc.yml \
-d "../../dataset/voc"
```
### 保存断点(checkpoint)
## 预测
蒸馏任务执行过程中会自动保存断点。如果需要从断点继续训练请用
`-r`
参数指定checkpoint路径,示例如下:
如果在配置文件中设置了
`checkpoint_path`
,并且在定义Compressor对象时指定了
`prune_infer_model`
选项,则每个epoch都会
保存一个
`inference model`
。该模型是通过删除eval_program中多余的operators而得到的。
```
bash
python
-u
slim/distillation/distill.py
\
-c
configs/yolov3_mobilenet_v1_voc.yml
\
-t
configs/yolov3_r34_voc.yml
\
-r
output/yolov3_mobilenet_v1_voc/10000
\
--teacher_pretrained
https://paddlemodels.bj.bcebos.com/object_detection/yolov3_r34_voc.tar
```
该模型会保存在
`${checkpoint_path}/${epoch_id}/eval_model/`
路径下,包含
`__model__.infer`
和
`__params__`
两个文件。
其中,
`__model__.infer`
用于保存模型结构信息,
`__params__`
用于保存参数(parameters)信息。
更多关于
`prune_infer_model`
选项的介绍,请参考:
[
Compressor介绍
](
https://github.com/PaddlePaddle/models/blob/develop/PaddleSlim/docs/usage.md#121-%E5%A6%82%E4%BD%95%E6%94%B9%E5%86%99%E6%99%AE%E9%80%9A%E8%AE%AD%E7%BB%83%E8%84%9A%E6%9C%AC
)
### python预测
在脚本
<a
href=
"../infer.py"
>
slim/infer.py
</a>
中展示了如何使用fluid python API加载使用预测模型进行预测。
## 评估
每隔
`snap_shot_iter`
步后会保存一个checkpoint模型可以用于评估,使用PaddleDetection目录下
[
tools/eval.py
](
../../tools/eval.py
)
评估脚本,并指定
`weights`
为训练得到的模型路径
运行命令为:
```
python ../infer.py \
--model_path ${checkpoint_path}/${epoch_id}/eval_model/ \
--model_name __model__.infer \
--params_name __params__ \
-c ../../configs/yolov3_mobilenet_v1_voc.yml \
--infer_dir ../../demo
```
bash
export
CUDA_VISIBLE_DEVICES
=
0
python
-u
tools/eval.py
-c
configs/yolov3_mobilenet_v1_voc.yml
\
-o
weights
=
output/yolov3_mobilenet_v1_voc/model_final
\
```
##
# PaddleLite
##
预测
该示例中产出的预测(inference)模型可以直接用PaddleLite进行加载使用。
关于PaddleLite如何使用,请参考:
[
PaddleLite使用文档
](
https://github.com/PaddlePaddle/Paddle-Lite/wiki#%E4%BD%BF%E7%94%A8
)
每隔
`snap_shot_iter`
步后保存的checkpoint模型也可以用于预测,使用PaddleDetection目录下
[
tools/infer.py
](
../../tools/infer.py
)
评估脚本,并指定
`weights`
为训练得到的模型路径
##
示例结果
##
# Python预测
>当前release的结果并非超参调优后的最好结果,仅做示例参考,后续我们会优化当前结果。
运行命令为:
```
export CUDA_VISIBLE_DEVICES=0
python -u tools/infer.py -c configs/yolov3_mobilenet_v1_voc.yml \
--infer_img=demo/000000570688.jpg \
--output_dir=infer_output/ \
--draw_threshold=0.5 \
-o weights=output/yolov3_mobilenet_v1_voc/model_final
```
##
# MobileNetV1-YOLO-V3
##
示例结果
| FLOPS |Box AP|
|---|---|
|baseline|76.2 |
|蒸馏后|- |
### MobileNetV1-YOLO-V3-VOC
| FLOPS |输入尺寸|每张GPU图片个数|推理时间(fps)|Box AP|下载|
|:-:|:-:|:-:|:-:|:-:|:-:|
|baseline|608 |16|104.291|76.2|
[
下载链接
](
https://paddlemodels.bj.bcebos.com/object_detection/yolov3_mobilenet_v1_voc.tar
)
|
|baseline|416 |16|-|76.7|
[
下载链接
](
https://paddlemodels.bj.bcebos.com/object_detection/yolov3_mobilenet_v1_voc.tar
)
|
|baseline|320 |16|-|75.3|
[
下载链接
](
https://paddlemodels.bj.bcebos.com/object_detection/yolov3_mobilenet_v1_voc.tar
)
|
|蒸馏后|608 |16|106.914|79.0||
|蒸馏后|416 |16|-|78.2||
|蒸馏后|320 |16|-|75.5||
## FAQ
> 蒸馏后的结果用ResNet34-YOLO-V3做teacher,4GPU总batch_size64训练90000 iter得到
slim/distillation/
compress
.py
→
slim/distillation/
distill
.py
浏览文件 @
1a843c55
...
...
@@ -17,38 +17,18 @@ from __future__ import division
from
__future__
import
print_function
import
os
import
time
import
multiprocessing
import
numpy
as
np
from
collections
import
deque
,
OrderedDict
from
paddle.fluid.contrib.slim.core
import
Compressor
from
paddle.fluid.framework
import
IrGraph
def
set_paddle_flags
(
**
kwargs
):
for
key
,
value
in
kwargs
.
items
():
if
os
.
environ
.
get
(
key
,
None
)
is
None
:
os
.
environ
[
key
]
=
str
(
value
)
# NOTE(paddle-dev): All of these flags should be set before
# `import paddle`. Otherwise, it would not take any effect.
set_paddle_flags
(
FLAGS_eager_delete_tensor_gb
=
0
,
# enable GC to save memory
)
from
collections
import
OrderedDict
from
paddleslim.dist.single_distiller
import
merge
,
l2_loss
from
paddle
import
fluid
import
sys
sys
.
path
.
append
(
"../../"
)
from
ppdet.core.workspace
import
load_config
,
merge_config
,
create
from
ppdet.data.
data_feed
import
create_reader
from
ppdet.utils.eval_utils
import
parse_fetches
,
eval_results
from
ppdet.data.
reader
import
create_reader
from
ppdet.utils.eval_utils
import
parse_fetches
,
eval_results
,
eval_run
from
ppdet.utils.stats
import
TrainingStats
from
ppdet.utils.cli
import
ArgsParser
from
ppdet.utils.check
import
check_gpu
import
ppdet.utils.checkpoint
as
checkpoint
from
ppdet.modeling.model_input
import
create_feed
import
logging
FORMAT
=
'%(asctime)s-%(levelname)s: %(message)s'
...
...
@@ -56,56 +36,8 @@ logging.basicConfig(level=logging.INFO, format=FORMAT)
logger
=
logging
.
getLogger
(
__name__
)
def
eval_run
(
exe
,
compile_program
,
reader
,
keys
,
values
,
cls
,
test_feed
):
"""
Run evaluation program, return program outputs.
"""
iter_id
=
0
results
=
[]
if
len
(
cls
)
!=
0
:
values
=
[]
for
i
in
range
(
len
(
cls
)):
_
,
accum_map
=
cls
[
i
].
get_map_var
()
cls
[
i
].
reset
(
exe
)
values
.
append
(
accum_map
)
images_num
=
0
start_time
=
time
.
time
()
has_bbox
=
'bbox'
in
keys
for
data
in
reader
():
data
=
test_feed
.
feed
(
data
)
feed_data
=
{
'image'
:
data
[
'image'
],
'im_size'
:
data
[
'im_size'
]}
outs
=
exe
.
run
(
compile_program
,
feed
=
feed_data
,
fetch_list
=
[
values
[
0
]],
return_numpy
=
False
)
outs
.
append
(
data
[
'gt_box'
])
outs
.
append
(
data
[
'gt_label'
])
outs
.
append
(
data
[
'is_difficult'
])
res
=
{
k
:
(
np
.
array
(
v
),
v
.
recursive_sequence_lengths
())
for
k
,
v
in
zip
(
keys
,
outs
)
}
results
.
append
(
res
)
if
iter_id
%
100
==
0
:
logger
.
info
(
'Test iter {}'
.
format
(
iter_id
))
iter_id
+=
1
images_num
+=
len
(
res
[
'bbox'
][
1
][
0
])
if
has_bbox
else
1
logger
.
info
(
'Test finish iter {}'
.
format
(
iter_id
))
end_time
=
time
.
time
()
fps
=
images_num
/
(
end_time
-
start_time
)
if
has_bbox
:
logger
.
info
(
'Total number of images: {}, inference time: {} fps.'
.
format
(
images_num
,
fps
))
else
:
logger
.
info
(
'Total iteration: {}, inference time: {} batch/s.'
.
format
(
images_num
,
fps
))
return
results
def
main
():
env
=
os
.
environ
cfg
=
load_config
(
FLAGS
.
config
)
if
'architecture'
in
cfg
:
main_arch
=
cfg
.
architecture
...
...
@@ -122,112 +54,60 @@ def main():
if
cfg
.
use_gpu
:
devices_num
=
fluid
.
core
.
get_cuda_device_count
()
else
:
devices_num
=
int
(
os
.
environ
.
get
(
'CPU_NUM'
,
multiprocessing
.
cpu_count
()))
if
'train_feed'
not
in
cfg
:
train_feed
=
create
(
main_arch
+
'TrainFeed'
)
else
:
train_feed
=
create
(
cfg
.
train_feed
)
devices_num
=
int
(
os
.
environ
.
get
(
'CPU_NUM'
,
1
))
if
'
eval_feed'
not
in
cfg
:
eval_feed
=
create
(
main_arch
+
'EvalFeed'
)
if
'
FLAGS_selected_gpus'
in
env
:
device_id
=
int
(
env
[
'FLAGS_selected_gpus'
]
)
else
:
eval_feed
=
create
(
cfg
.
eval_feed
)
place
=
fluid
.
CUDAPlace
(
0
)
if
cfg
.
use_gpu
else
fluid
.
CPUPlace
()
device_id
=
0
place
=
fluid
.
CUDAPlace
(
device_id
)
if
cfg
.
use_gpu
else
fluid
.
CPUPlace
()
exe
=
fluid
.
Executor
(
place
)
lr_builder
=
create
(
'LearningRate'
)
optim_builder
=
create
(
'OptimizerBuilder'
)
# build program
model
=
create
(
main_arch
)
_
,
train_feed_vars
=
create_feed
(
train_feed
,
True
)
inputs_def
=
cfg
[
'TrainReader'
][
'inputs_def'
]
train_feed_vars
,
train_loader
=
model
.
build_inputs
(
**
inputs_def
)
train_fetches
=
model
.
train
(
train_feed_vars
)
loss
=
train_fetches
[
'loss'
]
lr
=
lr_builder
()
opt
=
optim_builder
(
lr
)
opt
.
minimize
(
loss
)
#for v in fluid.default_main_program().list_vars():
# if "py_reader" not in v.name and "double_buffer" not in v.name and "generated_var" not in v.name:
# print(v.name, v.shape)
cfg
.
max_iters
=
258
train_reader
=
create_reader
(
train_feed
,
cfg
.
max_iters
,
FLAGS
.
dataset_dir
)
exe
.
run
(
fluid
.
default_startup_program
())
# parse train fetches
train_keys
,
train_values
,
_
=
parse_fetches
(
train_fetches
)
train_keys
.
append
(
'lr'
)
train_values
.
append
(
lr
.
name
)
train_fetch_list
=
[]
for
k
,
v
in
zip
(
train_keys
,
train_values
):
train_fetch_list
.
append
((
k
,
v
))
print
(
"train_fetch_list: {}"
.
format
(
train_fetch_list
))
# get all student variables
student_vars
=
[]
for
v
in
fluid
.
default_main_program
().
list_vars
():
try
:
student_vars
.
append
((
v
.
name
,
v
.
shape
))
except
:
pass
# uncomment the following lines to print all student variables
# print("="*50 + "student_model_vars" + "="*50)
# print(student_vars)
eval_prog
=
fluid
.
Program
()
startup_prog
=
fluid
.
Program
()
with
fluid
.
program_guard
(
eval_prog
,
startup_prog
):
with
fluid
.
program_guard
(
eval_prog
,
fluid
.
default_startup_program
()):
with
fluid
.
unique_name
.
guard
():
model
=
create
(
main_arch
)
_
,
test_feed_vars
=
create_feed
(
eval_feed
,
True
)
inputs_def
=
cfg
[
'EvalReader'
][
'inputs_def'
]
test_feed_vars
,
eval_loader
=
model
.
build_inputs
(
**
inputs_def
)
fetches
=
model
.
eval
(
test_feed_vars
)
eval_prog
=
eval_prog
.
clone
(
True
)
eval_reader
=
create_reader
(
eval_feed
,
args_path
=
FLAGS
.
dataset_di
r
)
test_data_feed
=
fluid
.
DataFeeder
(
test_feed_vars
.
values
()
,
place
)
eval_reader
=
create_reader
(
cfg
.
EvalReade
r
)
eval_loader
.
set_sample_list_generator
(
eval_reader
,
place
)
# parse eval fetches
extra_keys
=
[]
if
cfg
.
metric
==
'COCO'
:
extra_keys
=
[
'im_info'
,
'im_id'
,
'im_shape'
]
if
cfg
.
metric
==
'VOC'
:
extra_keys
=
[
'gt_b
ox'
,
'gt_label
'
,
'is_difficult'
]
extra_keys
=
[
'gt_b
box'
,
'gt_class
'
,
'is_difficult'
]
eval_keys
,
eval_values
,
eval_cls
=
parse_fetches
(
fetches
,
eval_prog
,
extra_keys
)
eval_fetch_list
=
[]
for
k
,
v
in
zip
(
eval_keys
,
eval_values
):
eval_fetch_list
.
append
((
k
,
v
))
print
(
"eval_fetch_list: {}"
.
format
(
eval_fetch_list
))
exe
.
run
(
startup_prog
)
checkpoint
.
load_params
(
exe
,
fluid
.
default_main_program
(),
cfg
.
pretrain_weights
)
best_box_ap_list
=
[]
def
eval_func
(
program
,
scope
):
results
=
eval_run
(
exe
,
program
,
eval_reader
,
eval_keys
,
eval_values
,
eval_cls
,
test_data_feed
)
resolution
=
None
is_bbox_normalized
=
False
if
'mask'
in
results
[
0
]:
resolution
=
model
.
mask_head
.
resolution
box_ap_stats
=
eval_results
(
results
,
eval_feed
,
cfg
.
metric
,
cfg
.
num_classes
,
resolution
,
is_bbox_normalized
,
FLAGS
.
output_eval
)
if
len
(
best_box_ap_list
)
==
0
:
best_box_ap_list
.
append
(
box_ap_stats
[
0
])
elif
box_ap_stats
[
0
]
>
best_box_ap_list
[
0
]:
best_box_ap_list
[
0
]
=
box_ap_stats
[
0
]
logger
.
info
(
"Best test box ap: {}"
.
format
(
best_box_ap_list
[
0
]))
return
best_box_ap_list
[
0
]
test_feed
=
[(
'image'
,
test_feed_vars
[
'image'
].
name
),
(
'im_size'
,
test_feed_vars
[
'im_size'
].
name
)]
teacher_cfg
=
load_config
(
FLAGS
.
teacher_config
)
teacher_arch
=
teacher_cfg
.
architecture
teacher_programs
=
[]
teacher_program
=
fluid
.
Program
()
teacher_startup_program
=
fluid
.
Program
()
with
fluid
.
program_guard
(
teacher_program
,
teacher_startup_program
):
with
fluid
.
unique_name
.
guard
(
'teacher_'
):
with
fluid
.
unique_name
.
guard
():
teacher_feed_vars
=
OrderedDict
()
for
name
,
var
in
train_feed_vars
.
items
():
teacher_feed_vars
[
name
]
=
teacher_program
.
global_block
(
...
...
@@ -235,64 +115,154 @@ def main():
var
,
force_persistable
=
False
)
model
=
create
(
teacher_arch
)
train_fetches
=
model
.
train
(
teacher_feed_vars
)
#print("="*50+"teacher_model_params"+"="*50)
#for v in teacher_program.list_vars():
# print(v.name, v.shape)
#return
teacher_loss
=
train_fetches
[
'loss'
]
# get all teacher variables
teacher_vars
=
[]
for
v
in
teacher_program
.
list_vars
():
try
:
teacher_vars
.
append
((
v
.
name
,
v
.
shape
))
except
:
pass
# uncomment the following lines to print all teacher variables
# print("="*50 + "teacher_model_vars" + "="*50)
# print(teacher_vars)
exe
.
run
(
teacher_startup_program
)
assert
FLAGS
.
teacher_pretrained
and
os
.
path
.
exists
(
FLAGS
.
teacher_pretrained
),
"teacher_pretrained should be set when teacher_model is not None."
def
if_exist
(
var
):
return
os
.
path
.
exists
(
os
.
path
.
join
(
FLAGS
.
teacher_pretrained
,
var
.
name
))
fluid
.
io
.
load_vars
(
exe
,
FLAGS
.
teacher_pretrained
,
main_program
=
teacher_program
,
predicate
=
if_exist
)
teacher_programs
.
append
(
teacher_program
.
clone
(
for_test
=
True
))
com
=
Compressor
(
place
,
fluid
.
global_scope
(),
fluid
.
default_main_program
(),
train_reader
=
train_reader
,
train_feed_list
=
[(
key
,
value
.
name
)
for
key
,
value
in
train_feed_vars
.
items
()],
train_fetch_list
=
train_fetch_list
,
eval_program
=
eval_prog
,
eval_reader
=
eval_reader
,
eval_feed_list
=
test_feed
,
eval_func
=
{
'map'
:
eval_func
},
eval_fetch_list
=
eval_fetch_list
[
0
:
1
],
save_eval_model
=
True
,
prune_infer_model
=
[[
"image"
,
"im_size"
],
[
"multiclass_nms_0.tmp_0"
]],
teacher_programs
=
teacher_programs
,
train_optimizer
=
None
,
distiller_optimizer
=
opt
,
log_period
=
20
)
com
.
config
(
FLAGS
.
slim_file
)
com
.
run
()
assert
FLAGS
.
teacher_pretrained
,
"teacher_pretrained should be set"
checkpoint
.
load_params
(
exe
,
teacher_program
,
FLAGS
.
teacher_pretrained
)
teacher_program
=
teacher_program
.
clone
(
for_test
=
True
)
cfg
=
load_config
(
FLAGS
.
config
)
data_name_map
=
{
'image'
:
'image'
,
'gt_bbox'
:
'gt_bbox'
,
'gt_class'
:
'gt_class'
,
'gt_score'
:
'gt_score'
}
distill_prog
=
merge
(
teacher_program
,
fluid
.
default_main_program
(),
data_name_map
,
place
)
distill_weight
=
100
distill_pairs
=
[[
'teacher_conv2d_6.tmp_1'
,
'conv2d_20.tmp_1'
],
[
'teacher_conv2d_14.tmp_1'
,
'conv2d_28.tmp_1'
],
[
'teacher_conv2d_22.tmp_1'
,
'conv2d_36.tmp_1'
]]
def
l2_distill
(
pairs
,
weight
):
"""
Add l2 distillation losses composed of multi pairs of feature maps,
each pair of feature maps is the input of teacher and student's
yolov3_loss respectively
"""
loss
=
[]
for
pair
in
pairs
:
loss
.
append
(
l2_loss
(
pair
[
0
],
pair
[
1
]))
loss
=
fluid
.
layers
.
sum
(
loss
)
weighted_loss
=
loss
*
weight
return
weighted_loss
distill_loss
=
l2_distill
(
distill_pairs
,
distill_weight
)
loss
=
distill_loss
+
loss
lr_builder
=
create
(
'LearningRate'
)
optim_builder
=
create
(
'OptimizerBuilder'
)
lr
=
lr_builder
()
opt
=
optim_builder
(
lr
)
opt
.
minimize
(
loss
)
exe
.
run
(
fluid
.
default_startup_program
())
checkpoint
.
load_params
(
exe
,
fluid
.
default_main_program
(),
cfg
.
pretrain_weights
)
build_strategy
=
fluid
.
BuildStrategy
()
build_strategy
.
fuse_all_reduce_ops
=
False
build_strategy
.
fuse_all_optimizer_ops
=
False
build_strategy
.
fuse_elewise_add_act_ops
=
True
# only enable sync_bn in multi GPU devices
sync_bn
=
getattr
(
model
.
backbone
,
'norm_type'
,
None
)
==
'sync_bn'
build_strategy
.
sync_batch_norm
=
sync_bn
and
devices_num
>
1
\
and
cfg
.
use_gpu
exec_strategy
=
fluid
.
ExecutionStrategy
()
# iteration number when CompiledProgram tries to drop local execution scopes.
# Set it to be 1 to save memory usages, so that unused variables in
# local execution scopes can be deleted after each iteration.
exec_strategy
.
num_iteration_per_drop_scope
=
1
parallel_main
=
fluid
.
CompiledProgram
(
distill_prog
).
with_data_parallel
(
loss_name
=
loss
.
name
,
build_strategy
=
build_strategy
,
exec_strategy
=
exec_strategy
)
compiled_eval_prog
=
fluid
.
compiler
.
CompiledProgram
(
eval_prog
)
fuse_bn
=
getattr
(
model
.
backbone
,
'norm_type'
,
None
)
==
'affine_channel'
ignore_params
=
cfg
.
finetune_exclude_pretrained_params
\
if
'finetune_exclude_pretrained_params'
in
cfg
else
[]
start_iter
=
0
if
FLAGS
.
resume_checkpoint
:
checkpoint
.
load_checkpoint
(
exe
,
distill_prog
,
FLAGS
.
resume_checkpoint
)
start_iter
=
checkpoint
.
global_step
()
elif
cfg
.
pretrain_weights
and
fuse_bn
and
not
ignore_params
:
checkpoint
.
load_and_fusebn
(
exe
,
distill_prog
,
cfg
.
pretrain_weights
)
elif
cfg
.
pretrain_weights
:
checkpoint
.
load_params
(
exe
,
distill_prog
,
cfg
.
pretrain_weights
,
ignore_params
=
ignore_params
)
train_reader
=
create_reader
(
cfg
.
TrainReader
,
(
cfg
.
max_iters
-
start_iter
)
*
devices_num
,
cfg
)
train_loader
.
set_sample_list_generator
(
train_reader
,
place
)
# whether output bbox is normalized in model output layer
is_bbox_normalized
=
False
if
hasattr
(
model
,
'is_bbox_normalized'
)
and
\
callable
(
model
.
is_bbox_normalized
):
is_bbox_normalized
=
model
.
is_bbox_normalized
()
map_type
=
cfg
.
map_type
if
'map_type'
in
cfg
else
'11point'
best_box_ap_list
=
[
0.0
,
0
]
#[map, iter]
cfg_name
=
os
.
path
.
basename
(
FLAGS
.
config
).
split
(
'.'
)[
0
]
save_dir
=
os
.
path
.
join
(
cfg
.
save_dir
,
cfg_name
)
train_loader
.
start
()
for
step_id
in
range
(
start_iter
,
cfg
.
max_iters
):
teacher_loss_np
,
distill_loss_np
,
loss_np
,
lr_np
=
exe
.
run
(
parallel_main
,
fetch_list
=
[
'teacher_'
+
teacher_loss
.
name
,
distill_loss
.
name
,
loss
.
name
,
lr
.
name
])
if
step_id
%
cfg
.
log_iter
==
0
:
logger
.
info
(
"step {} lr {:.6f}, loss {:.6f}, distill_loss {:.6f}, teacher_loss {:.6f}"
.
format
(
step_id
,
lr_np
[
0
],
loss_np
[
0
],
distill_loss_np
[
0
],
teacher_loss_np
[
0
]))
if
step_id
%
cfg
.
snapshot_iter
==
0
and
step_id
!=
0
or
step_id
==
cfg
.
max_iters
-
1
:
save_name
=
str
(
step_id
)
if
step_id
!=
cfg
.
max_iters
-
1
else
"model_final"
checkpoint
.
save
(
exe
,
distill_prog
,
os
.
path
.
join
(
save_dir
,
save_name
))
# eval
results
=
eval_run
(
exe
,
compiled_eval_prog
,
eval_loader
,
eval_keys
,
eval_values
,
eval_cls
)
resolution
=
None
box_ap_stats
=
eval_results
(
results
,
cfg
.
metric
,
cfg
.
num_classes
,
resolution
,
is_bbox_normalized
,
FLAGS
.
output_eval
,
map_type
,
cfg
[
'EvalReader'
][
'dataset'
])
if
box_ap_stats
[
0
]
>
best_box_ap_list
[
0
]:
best_box_ap_list
[
0
]
=
box_ap_stats
[
0
]
best_box_ap_list
[
1
]
=
step_id
checkpoint
.
save
(
exe
,
distill_prog
,
os
.
path
.
join
(
"./"
,
"best_model"
))
logger
.
info
(
"Best test box ap: {}, in step: {}"
.
format
(
best_box_ap_list
[
0
],
best_box_ap_list
[
1
]))
train_loader
.
reset
()
if
__name__
==
'__main__'
:
parser
=
ArgsParser
()
parser
.
add_argument
(
"-t"
,
"--teacher_config"
,
default
=
None
,
type
=
str
,
help
=
"Config file of teacher architecture."
)
parser
.
add_argument
(
"-s"
,
"--slim_file"
,
default
=
None
,
type
=
str
,
help
=
"Config file of PaddleSlim."
)
parser
.
add_argument
(
"-r"
,
"--resume_checkpoint"
,
...
...
@@ -300,10 +270,11 @@ if __name__ == '__main__':
type
=
str
,
help
=
"Checkpoint path for resuming training."
)
parser
.
add_argument
(
"--eval"
,
action
=
'store_true'
,
default
=
False
,
help
=
"Whether to perform evaluation in train"
)
"-t"
,
"--teacher_config"
,
default
=
None
,
type
=
str
,
help
=
"Config file of teacher architecture."
)
parser
.
add_argument
(
"--teacher_pretrained"
,
default
=
None
,
...
...
@@ -314,11 +285,5 @@ if __name__ == '__main__':
default
=
None
,
type
=
str
,
help
=
"Evaluation directory, default is current directory."
)
parser
.
add_argument
(
"-d"
,
"--dataset_dir"
,
default
=
None
,
type
=
str
,
help
=
"Dataset path, same as DataFeed.dataset.dataset_dir"
)
FLAGS
=
parser
.
parse_args
()
main
()
slim/distillation/run.sh
已删除
100644 → 0
浏览文件 @
18645132
#!/usr/bin/env bash
# download pretrain model
root_url
=
"https://paddlemodels.bj.bcebos.com/object_detection"
yolov3_r34_voc
=
"yolov3_r34_voc.tar"
pretrain_dir
=
'./pretrain'
if
[
!
-d
${
pretrain_dir
}
]
;
then
mkdir
${
pretrain_dir
}
fi
cd
${
pretrain_dir
}
if
[
!
-f
${
yolov3_r34_voc
}
]
;
then
wget
${
root_url
}
/
${
yolov3_r34_voc
}
tar
xf
${
yolov3_r34_voc
}
fi
cd
-
# enable GC strategy
export
FLAGS_fast_eager_deletion_mode
=
1
export
FLAGS_eager_delete_tensor_gb
=
0.0
# for distillation
#-----------------
export
CUDA_VISIBLE_DEVICES
=
0,1,2,3
# Fixing name conflicts in distillation
cd
${
pretrain_dir
}
/yolov3_r34_voc
for
files
in
$(
ls
teacher_
*
)
do
mv
$files
${
files
#*_
}
done
for
files
in
$(
ls
*
)
do
mv
$files
"teacher_"
$files
done
cd
-
python
-u
compress.py
\
-c
../../configs/yolov3_mobilenet_v1_voc.yml
\
-t
yolov3_resnet34.yml
\
-s
yolov3_mobilenet_v1_yolov3_resnet34_distillation.yml
\
-o
YoloTrainFeed.batch_size
=
64
\
-d
../../dataset/voc
\
--teacher_pretrained
./pretrain/yolov3_r34_voc
\
>
yolov3_distallation.log 2>&1 &
tailf yolov3_distallation.log
slim/distillation/yolov3_mobilenet_v1_yolov3_resnet34_distillation.yml
已删除
100644 → 0
浏览文件 @
18645132
version
:
1.0
distillers
:
l2_distiller
:
class
:
'
L2Distiller'
teacher_feature_map
:
'
teacher_teacher_conv2d_1.tmp_0'
student_feature_map
:
'
conv2d_15.tmp_0'
distillation_loss_weight
:
1
strategies
:
distillation_strategy
:
class
:
'
DistillationStrategy'
distillers
:
[
'
l2_distiller'
]
start_epoch
:
0
end_epoch
:
270
compressor
:
epoch
:
271
checkpoint_path
:
'
./checkpoints/'
strategies
:
-
distillation_strategy
slim/distillation/yolov3_resnet34.yml
已删除
100644 → 0
浏览文件 @
18645132
architecture
:
YOLOv3
log_smooth_window
:
20
metric
:
VOC
map_type
:
11point
num_classes
:
20
weight_prefix_name
:
teacher_
YOLOv3
:
backbone
:
ResNet
yolo_head
:
YOLOv3Head
ResNet
:
norm_type
:
sync_bn
freeze_at
:
0
freeze_norm
:
false
norm_decay
:
0.
depth
:
34
feature_maps
:
[
3
,
4
,
5
]
YOLOv3Head
:
anchor_masks
:
[[
6
,
7
,
8
],
[
3
,
4
,
5
],
[
0
,
1
,
2
]]
anchors
:
[[
10
,
13
],
[
16
,
30
],
[
33
,
23
],
[
30
,
61
],
[
62
,
45
],
[
59
,
119
],
[
116
,
90
],
[
156
,
198
],
[
373
,
326
]]
norm_decay
:
0.
ignore_thresh
:
0.7
label_smooth
:
false
nms
:
background_label
:
-1
keep_top_k
:
100
nms_threshold
:
0.45
nms_top_k
:
1000
normalized
:
false
score_threshold
:
0.01
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录