Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
ccf94523
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
1 年多 前同步成功
通知
696
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
ccf94523
编写于
1月 17, 2020
作者:
B
Bai Yifan
提交者:
qingqing01
1月 17, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add COCO distillation demo (#189)
* update merge api * add coco distillation demo * fix details * resolve some issues
上级
9e306e6e
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
230 addition
and
48 deletion
+230
-48
slim/distillation/README.md
slim/distillation/README.md
+94
-7
slim/distillation/distill.py
slim/distillation/distill.py
+136
-41
未找到文件。
slim/distillation/README.md
浏览文件 @
ccf94523
...
...
@@ -17,7 +17,9 @@
关于蒸馏API如何使用您可以参考PaddleSlim蒸馏API文档
这里以ResNet34-YoloV3蒸馏训练MobileNetV1-YoloV3模型为例,首先,为了对
`student model`
和
`teacher model`
有个总体的认识,进一步确认蒸馏的对象,我们通过以下命令分别观察两个网络变量(Variables)的名称和形状:
### MobileNetV1-YOLOv3在VOC数据集上的蒸馏
这里以ResNet34-YOLOv3蒸馏训练MobileNetV1-YOLOv3模型为例,首先,为了对
`student model`
和
`teacher model`
有个总体的认识,进一步确认蒸馏的对象,我们通过以下命令分别观察两个网络变量(Variables)的名称和形状:
```
python
# 观察student model的Variables
...
...
@@ -60,6 +62,29 @@ dist_loss_3 = l2_loss('teacher_conv2d_22.tmp_1', 'conv2d_36.tmp_1')
我们也可以根据上述操作为蒸馏策略选择其他loss,PaddleSlim支持的有
`FSP_loss`
,
`L2_loss`
,
`softmax_with_cross_entropy_loss`
以及自定义的任何loss。
### MobileNetV1-YOLOv3在COCO数据集上的蒸馏
这里以ResNet34-YOLOv3作为蒸馏训练的teacher网络, 对MobileNetV1-YOLOv3结构的student网络进行蒸馏。
COCO数据集作为目标检测任务的训练目标难度更大,意味着teacher网络会预测出更多的背景bbox,如果直接用teacher的预测输出作为student学习的
`soft label`
会有严重的类别不均衡问题。解决这个问题需要引入新的方法,详细背景请参考论文:
[
Object detection at 200 Frames Per Second
](
https://arxiv.org/abs/1805.06361
)
为了确定蒸馏的对象,我们首先需要找到student和teacher网络得到的
`x,y,w,h,cls.objness`
等变量在PaddlePaddle框架中的实际名称(var.name)。进而根据名称取出这些变量,用teacher得到的结果指导student训练。找到的所有变量如下:
```
python
yolo_output_names
=
[
'strided_slice_0.tmp_0'
,
'strided_slice_1.tmp_0'
,
'strided_slice_2.tmp_0'
,
'strided_slice_3.tmp_0'
,
'strided_slice_4.tmp_0'
,
'transpose_0.tmp_0'
,
'strided_slice_5.tmp_0'
,
'strided_slice_6.tmp_0'
,
'strided_slice_7.tmp_0'
,
'strided_slice_8.tmp_0'
,
'strided_slice_9.tmp_0'
,
'transpose_2.tmp_0'
,
'strided_slice_10.tmp_0'
,
'strided_slice_11.tmp_0'
,
'strided_slice_12.tmp_0'
,
'strided_slice_13.tmp_0'
,
'strided_slice_14.tmp_0'
,
'transpose_4.tmp_0'
]
```
然后,就可以根据论文
<
<
Object
detection
at
200
Frames
Per
Second
>
>的方法为YOLOv3中分类、回归、objness三个不同的head适配不同的蒸馏损失函数,并对分类和回归的损失函数用objness分值进行抑制,以解决前景背景类别不均衡问题。
## 训练
根据
[
PaddleDetection/tools/train.py
](
../../tools/train.py
)
编写压缩脚本
`distill.py`
。
...
...
@@ -76,12 +101,22 @@ export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
step2: 开始训练
```
bash
# yolov3_mobilenet_v1在voc数据集上蒸馏
python slim/distillation/distill.py
\
-c
configs/yolov3_mobilenet_v1_voc.yml
\
-t
configs/yolov3_r34_voc.yml
\
--teacher_pretrained
https://paddlemodels.bj.bcebos.com/object_detection/yolov3_r34_voc.tar
```
```
bash
# yolov3_mobilenet_v1在COCO数据集上蒸馏
python slim/distillation/distill.py
\
-c
configs/yolov3_mobilenet_v1.yml
\
-o
use_fine_grained_loss
=
true
\
-t
configs/yolov3_r34.yml
\
--teacher_pretrained
https://paddlemodels.bj.bcebos.com/object_detection/yolov3_r34.tar
```
如果要调整训练卡数,需要调整配置文件
`yolov3_mobilenet_v1_voc.yml`
中的以下参数:
-
**max_iters:**
训练过程迭代总步数。
...
...
@@ -90,17 +125,27 @@ python slim/distillation/distill.py \
-
**LearningRate.schedulers.PiecewiseDecay.milestones:**
请根据batch size的变化对其调整。
-
**LearningRate.schedulers.PiecewiseDecay.LinearWarmup.steps:**
请根据batch size的变化对其进行调整。
以下为4卡训练示例,通过命令行
覆盖
`yolov3_mobilenet_v1_voc.yml`
中的参数:
以下为4卡训练示例,通过命令行
-o参数覆盖
`yolov3_mobilenet_v1_voc.yml`
中的参数, 修改GPU卡数后应尽量确保总batch_size(GPU卡数
\*
YoloTrainFeed.batch_size)不变, 以确保训练效果不因bs大小受影响:
```
shell
```
bash
# yolov3_mobilenet_v1在VOC数据集上蒸馏
CUDA_VISIBLE_DEVICES
=
0,1,2,3
python slim/distillation/distill.py
\
-c
configs/yolov3_mobilenet_v1_voc.yml
\
-t
configs/yolov3_r34_voc.yml
\
-o
Y
oloTrainFeed
.batch_size
=
16
\
-o
Y
OLOv3Loss
.batch_size
=
16
\
--teacher_pretrained
https://paddlemodels.bj.bcebos.com/object_detection/yolov3_r34_voc.tar
```
```
bash
# yolov3_mobilenet_v1在COCO数据集上蒸馏
CUDA_VISIBLE_DEVICES
=
0,1,2,3
python slim/distillation/distill.py
\
-c
configs/yolov3_mobilenet_v1.yml
\
-t
configs/yolov3_r34.yml
\
-o
use_fine_grained_loss
=
true
YOLOv3Loss.batch_size
=
16
\
--teacher_pretrained
https://paddlemodels.bj.bcebos.com/object_detection/yolov3_r34.tar
```
...
...
@@ -109,6 +154,7 @@ python slim/distillation/distill.py \
蒸馏任务执行过程中会自动保存断点。如果需要从断点继续训练请用
`-r`
参数指定checkpoint路径,示例如下:
```
bash
# yolov3_mobilenet_v1在VOC数据集上恢复断点
python
-u
slim/distillation/distill.py
\
-c
configs/yolov3_mobilenet_v1_voc.yml
\
-t
configs/yolov3_r34_voc.yml
\
...
...
@@ -116,6 +162,15 @@ python -u slim/distillation/distill.py \
--teacher_pretrained
https://paddlemodels.bj.bcebos.com/object_detection/yolov3_r34_voc.tar
```
```
bash
# yolov3_mobilenet_v1在COCO数据集上恢复断点
python
-u
slim/distillation/distill.py
\
-c
configs/yolov3_mobilenet_v1.yml
\
-t
configs/yolov3_r34.yml
\
-o
use_fine_grained_loss
=
true
\
-r
output/yolov3_mobilenet_v1/10000
\
--teacher_pretrained
https://paddlemodels.bj.bcebos.com/object_detection/yolov3_r34.tar
```
...
...
@@ -125,11 +180,19 @@ python -u slim/distillation/distill.py \
运行命令为:
```
bash
# yolov3_mobilenet_v1在VOC数据集上评估
export
CUDA_VISIBLE_DEVICES
=
0
python
-u
tools/eval.py
-c
configs/yolov3_mobilenet_v1_voc.yml
\
-o
weights
=
output/yolov3_mobilenet_v1_voc/model_final
\
```
```
bash
# yolov3_mobilenet_v1在COCO数据集上评估
export
CUDA_VISIBLE_DEVICES
=
0
python
-u
tools/eval.py
-c
configs/yolov3_mobilenet_v1.yml
\
-o
weights
=
output/yolov3_mobilenet_v1/model_final
\
```
## 预测
每隔
`snap_shot_iter`
步后保存的checkpoint模型也可以用于预测,使用PaddleDetection目录下
[
tools/infer.py
](
../../tools/infer.py
)
评估脚本,并指定
`weights`
为训练得到的模型路径
...
...
@@ -138,6 +201,7 @@ python -u tools/eval.py -c configs/yolov3_mobilenet_v1_voc.yml \
运行命令为:
```
# 使用yolov3_mobilenet_v1_voc模型进行预测
export CUDA_VISIBLE_DEVICES=0
python -u tools/infer.py -c configs/yolov3_mobilenet_v1_voc.yml \
--infer_img=demo/000000570688.jpg \
...
...
@@ -146,6 +210,16 @@ python -u tools/infer.py -c configs/yolov3_mobilenet_v1_voc.yml \
-o weights=output/yolov3_mobilenet_v1_voc/model_final
```
```
# 使用yolov3_mobilenet_v1_coco模型进行预测
export CUDA_VISIBLE_DEVICES=0
python -u tools/infer.py -c configs/yolov3_mobilenet_v1.yml \
--infer_img=demo/000000570688.jpg \
--output_dir=infer_output/ \
--draw_threshold=0.5 \
-o weights=output/yolov3_mobilenet_v1/model_final
```
## 示例结果
### MobileNetV1-YOLO-V3-VOC
...
...
@@ -153,10 +227,23 @@ python -u tools/infer.py -c configs/yolov3_mobilenet_v1_voc.yml \
| FLOPS |输入尺寸|每张GPU图片个数|推理时间(fps)|Box AP|下载|
|:-:|:-:|:-:|:-:|:-:|:-:|
|baseline|608 |16|104.291|76.2|
[
下载链接
](
https://paddlemodels.bj.bcebos.com/object_detection/yolov3_mobilenet_v1_voc.tar
)
|
|蒸馏后|608 |16|106.914|79.0|
[
下载链接
](
https://paddlemodels.bj.bcebos.com/PaddleSlim/yolov3_mobilenetv1_voc_distilled.tar
)
|
|baseline|416 |16|-|76.7|
[
下载链接
](
https://paddlemodels.bj.bcebos.com/object_detection/yolov3_mobilenet_v1_voc.tar
)
|
|蒸馏后|416 |16|-|78.2|
[
下载链接
](
https://paddlemodels.bj.bcebos.com/PaddleSlim/yolov3_mobilenetv1_voc_distilled.tar
)
|
|baseline|320 |16|-|75.3|
[
下载链接
](
https://paddlemodels.bj.bcebos.com/object_detection/yolov3_mobilenet_v1_voc.tar
)
|
|蒸馏后|608 |16|106.914|79.0||
|蒸馏后|416 |16|-|78.2||
|蒸馏后|320 |16|-|75.5||
|蒸馏后|320 |16|-|75.5|
[
下载链接
](
https://paddlemodels.bj.bcebos.com/PaddleSlim/yolov3_mobilenetv1_voc_distilled.tar
)
|
> 蒸馏后的结果用ResNet34-YOLO-V3做teacher,4GPU总batch_size64训练90000 iter得到
### MobileNetV1-YOLO-V3-COCO
| FLOPS |输入尺寸|每张GPU图片个数|推理时间(fps)|Box AP|下载|
|:-:|:-:|:-:|:-:|:-:|:-:|
|baseline|608 |16|78.302|29.3|
[
下载链接
](
https://paddlemodels.bj.bcebos.com/object_detection/yolov3_mobilenet_v1_voc.tar
)
|
|蒸馏后|608 |16|78.523|31.4|
[
下载链接
](
https://paddlemodels.bj.bcebos.com/PaddleSlim/yolov3_mobilenetv1_coco_distilled.tar
)
|
|baseline|416 |16|-|29.3|
[
下载链接
](
https://paddlemodels.bj.bcebos.com/object_detection/yolov3_mobilenet_v1_voc.tar
)
|
|蒸馏后|416 |16|-|30.0|
[
下载链接
](
https://paddlemodels.bj.bcebos.com/PaddleSlim/yolov3_mobilenetv1_coco_distilled.tar
)
|
|baseline|320 |16|-|27.0|
[
下载链接
](
https://paddlemodels.bj.bcebos.com/object_detection/yolov3_mobilenet_v1_voc.tar
)
|
|蒸馏后|320 |16|-|27.1|
[
下载链接
](
https://paddlemodels.bj.bcebos.com/PaddleSlim/yolov3_mobilenetv1_coco_distilled.tar
)
|
> 蒸馏后的结果用ResNet34-YOLO-V3做teacher,4GPU总batch_size64训练600000 iter得到
slim/distillation/distill.py
浏览文件 @
ccf94523
...
...
@@ -36,6 +36,92 @@ logging.basicConfig(level=logging.INFO, format=FORMAT)
logger
=
logging
.
getLogger
(
__name__
)
def
l2_distill
(
pairs
,
weight
):
"""
Add l2 distillation losses composed of multi pairs of feature maps,
each pair of feature maps is the input of teacher and student's
yolov3_loss respectively
"""
loss
=
[]
for
pair
in
pairs
:
loss
.
append
(
l2_loss
(
pair
[
0
],
pair
[
1
]))
loss
=
fluid
.
layers
.
sum
(
loss
)
weighted_loss
=
loss
*
weight
return
weighted_loss
def
split_distill
(
split_output_names
,
weight
):
"""
Add fine grained distillation losses.
Each loss is composed by distill_reg_loss, distill_cls_loss and
distill_obj_loss
"""
student_var
=
[]
for
name
in
split_output_names
:
student_var
.
append
(
fluid
.
default_main_program
().
global_block
().
var
(
name
))
s_x0
,
s_y0
,
s_w0
,
s_h0
,
s_obj0
,
s_cls0
=
student_var
[
0
:
6
]
s_x1
,
s_y1
,
s_w1
,
s_h1
,
s_obj1
,
s_cls1
=
student_var
[
6
:
12
]
s_x2
,
s_y2
,
s_w2
,
s_h2
,
s_obj2
,
s_cls2
=
student_var
[
12
:
18
]
teacher_var
=
[]
for
name
in
split_output_names
:
teacher_var
.
append
(
fluid
.
default_main_program
().
global_block
().
var
(
'teacher_'
+
name
))
t_x0
,
t_y0
,
t_w0
,
t_h0
,
t_obj0
,
t_cls0
=
teacher_var
[
0
:
6
]
t_x1
,
t_y1
,
t_w1
,
t_h1
,
t_obj1
,
t_cls1
=
teacher_var
[
6
:
12
]
t_x2
,
t_y2
,
t_w2
,
t_h2
,
t_obj2
,
t_cls2
=
teacher_var
[
12
:
18
]
def
obj_weighted_reg
(
sx
,
sy
,
sw
,
sh
,
tx
,
ty
,
tw
,
th
,
tobj
):
loss_x
=
fluid
.
layers
.
sigmoid_cross_entropy_with_logits
(
sx
,
fluid
.
layers
.
sigmoid
(
tx
))
loss_y
=
fluid
.
layers
.
sigmoid_cross_entropy_with_logits
(
sy
,
fluid
.
layers
.
sigmoid
(
ty
))
loss_w
=
fluid
.
layers
.
abs
(
sw
-
tw
)
loss_h
=
fluid
.
layers
.
abs
(
sh
-
th
)
loss
=
fluid
.
layers
.
sum
([
loss_x
,
loss_y
,
loss_w
,
loss_h
])
weighted_loss
=
fluid
.
layers
.
reduce_mean
(
loss
*
fluid
.
layers
.
sigmoid
(
tobj
))
return
weighted_loss
def
obj_weighted_cls
(
scls
,
tcls
,
tobj
):
loss
=
fluid
.
layers
.
sigmoid_cross_entropy_with_logits
(
scls
,
fluid
.
layers
.
sigmoid
(
tcls
))
weighted_loss
=
fluid
.
layers
.
reduce_mean
(
fluid
.
layers
.
elementwise_mul
(
loss
,
fluid
.
layers
.
sigmoid
(
tobj
),
axis
=
0
))
return
weighted_loss
def
obj_loss
(
sobj
,
tobj
):
obj_mask
=
fluid
.
layers
.
cast
(
tobj
>
0.
,
dtype
=
"float32"
)
obj_mask
.
stop_gradient
=
True
loss
=
fluid
.
layers
.
reduce_mean
(
fluid
.
layers
.
sigmoid_cross_entropy_with_logits
(
sobj
,
obj_mask
))
return
loss
distill_reg_loss0
=
obj_weighted_reg
(
s_x0
,
s_y0
,
s_w0
,
s_h0
,
t_x0
,
t_y0
,
t_w0
,
t_h0
,
t_obj0
)
distill_reg_loss1
=
obj_weighted_reg
(
s_x1
,
s_y1
,
s_w1
,
s_h1
,
t_x1
,
t_y1
,
t_w1
,
t_h1
,
t_obj1
)
distill_reg_loss2
=
obj_weighted_reg
(
s_x2
,
s_y2
,
s_w2
,
s_h2
,
t_x2
,
t_y2
,
t_w2
,
t_h2
,
t_obj2
)
distill_reg_loss
=
fluid
.
layers
.
sum
(
[
distill_reg_loss0
,
distill_reg_loss1
,
distill_reg_loss2
])
distill_cls_loss0
=
obj_weighted_cls
(
s_cls0
,
t_cls0
,
t_obj0
)
distill_cls_loss1
=
obj_weighted_cls
(
s_cls1
,
t_cls1
,
t_obj1
)
distill_cls_loss2
=
obj_weighted_cls
(
s_cls2
,
t_cls2
,
t_obj2
)
distill_cls_loss
=
fluid
.
layers
.
sum
(
[
distill_cls_loss0
,
distill_cls_loss1
,
distill_cls_loss2
])
distill_obj_loss0
=
obj_loss
(
s_obj0
,
t_obj0
)
distill_obj_loss1
=
obj_loss
(
s_obj1
,
t_obj1
)
distill_obj_loss2
=
obj_loss
(
s_obj2
,
t_obj2
)
distill_obj_loss
=
fluid
.
layers
.
sum
(
[
distill_obj_loss0
,
distill_obj_loss1
,
distill_obj_loss2
])
loss
=
(
distill_reg_loss
+
distill_cls_loss
+
distill_obj_loss
)
*
weight
return
loss
def
main
():
env
=
os
.
environ
cfg
=
load_config
(
FLAGS
.
config
)
...
...
@@ -69,6 +155,31 @@ def main():
train_feed_vars
,
train_loader
=
model
.
build_inputs
(
**
inputs_def
)
train_fetches
=
model
.
train
(
train_feed_vars
)
loss
=
train_fetches
[
'loss'
]
fuse_bn
=
getattr
(
model
.
backbone
,
'norm_type'
,
None
)
==
'affine_channel'
ignore_params
=
cfg
.
finetune_exclude_pretrained_params
\
if
'finetune_exclude_pretrained_params'
in
cfg
else
[]
start_iter
=
0
if
FLAGS
.
resume_checkpoint
:
checkpoint
.
load_checkpoint
(
exe
,
fluid
.
default_main_program
(),
FLAGS
.
resume_checkpoint
)
start_iter
=
checkpoint
.
global_step
()
elif
cfg
.
pretrain_weights
and
fuse_bn
and
not
ignore_params
:
checkpoint
.
load_and_fusebn
(
exe
,
fluid
.
default_main_program
(),
cfg
.
pretrain_weights
)
elif
cfg
.
pretrain_weights
:
checkpoint
.
load_params
(
exe
,
fluid
.
default_main_program
(),
cfg
.
pretrain_weights
,
ignore_params
=
ignore_params
)
train_reader
=
create_reader
(
cfg
.
TrainReader
,
(
cfg
.
max_iters
-
start_iter
)
*
devices_num
,
cfg
)
train_loader
.
set_sample_list_generator
(
train_reader
,
place
)
# get all student variables
student_vars
=
[]
for
v
in
fluid
.
default_main_program
().
list_vars
():
...
...
@@ -102,6 +213,7 @@ def main():
extra_keys
)
teacher_cfg
=
load_config
(
FLAGS
.
teacher_config
)
merge_config
(
FLAGS
.
opt
)
teacher_arch
=
teacher_cfg
.
architecture
teacher_program
=
fluid
.
Program
()
teacher_startup_program
=
fluid
.
Program
()
...
...
@@ -135,33 +247,34 @@ def main():
cfg
=
load_config
(
FLAGS
.
config
)
data_name_map
=
{
'target0'
:
'target0'
,
'target1'
:
'target1'
,
'target2'
:
'target2'
,
'image'
:
'image'
,
'gt_bbox'
:
'gt_bbox'
,
'gt_class'
:
'gt_class'
,
'gt_score'
:
'gt_score'
}
distill_prog
=
merge
(
teacher_program
,
fluid
.
default_main_program
(),
data_name_map
,
place
)
merge
(
teacher_program
,
fluid
.
default_main_program
(),
data_name_map
,
place
)
yolo_output_names
=
[
'strided_slice_0.tmp_0'
,
'strided_slice_1.tmp_0'
,
'strided_slice_2.tmp_0'
,
'strided_slice_3.tmp_0'
,
'strided_slice_4.tmp_0'
,
'transpose_0.tmp_0'
,
'strided_slice_5.tmp_0'
,
'strided_slice_6.tmp_0'
,
'strided_slice_7.tmp_0'
,
'strided_slice_8.tmp_0'
,
'strided_slice_9.tmp_0'
,
'transpose_2.tmp_0'
,
'strided_slice_10.tmp_0'
,
'strided_slice_11.tmp_0'
,
'strided_slice_12.tmp_0'
,
'strided_slice_13.tmp_0'
,
'strided_slice_14.tmp_0'
,
'transpose_4.tmp_0'
]
distill_weight
=
100
distill_pairs
=
[[
'teacher_conv2d_6.tmp_1'
,
'conv2d_20.tmp_1'
],
[
'teacher_conv2d_14.tmp_1'
,
'conv2d_28.tmp_1'
],
[
'teacher_conv2d_22.tmp_1'
,
'conv2d_36.tmp_1'
]]
def
l2_distill
(
pairs
,
weight
):
"""
Add l2 distillation losses composed of multi pairs of feature maps,
each pair of feature maps is the input of teacher and student's
yolov3_loss respectively
"""
loss
=
[]
for
pair
in
pairs
:
loss
.
append
(
l2_loss
(
pair
[
0
],
pair
[
1
]))
loss
=
fluid
.
layers
.
sum
(
loss
)
weighted_loss
=
loss
*
weight
return
weighted_loss
distill_loss
=
l2_distill
(
distill_pairs
,
distill_weight
)
distill_loss
=
l2_distill
(
distill_pairs
,
100
)
if
not
cfg
.
use_fine_grained_loss
else
split_distill
(
yolo_output_names
,
1000
)
loss
=
distill_loss
+
loss
lr_builder
=
create
(
'LearningRate'
)
optim_builder
=
create
(
'OptimizerBuilder'
)
...
...
@@ -170,8 +283,6 @@ def main():
opt
.
minimize
(
loss
)
exe
.
run
(
fluid
.
default_startup_program
())
checkpoint
.
load_params
(
exe
,
fluid
.
default_main_program
(),
cfg
.
pretrain_weights
)
build_strategy
=
fluid
.
BuildStrategy
()
build_strategy
.
fuse_all_reduce_ops
=
False
...
...
@@ -188,32 +299,14 @@ def main():
# local execution scopes can be deleted after each iteration.
exec_strategy
.
num_iteration_per_drop_scope
=
1
parallel_main
=
fluid
.
CompiledProgram
(
distill_prog
).
with_data_parallel
(
parallel_main
=
fluid
.
CompiledProgram
(
fluid
.
default_main_program
(
)).
with_data_parallel
(
loss_name
=
loss
.
name
,
build_strategy
=
build_strategy
,
exec_strategy
=
exec_strategy
)
compiled_eval_prog
=
fluid
.
compiler
.
CompiledProgram
(
eval_prog
)
fuse_bn
=
getattr
(
model
.
backbone
,
'norm_type'
,
None
)
==
'affine_channel'
ignore_params
=
cfg
.
finetune_exclude_pretrained_params
\
if
'finetune_exclude_pretrained_params'
in
cfg
else
[]
start_iter
=
0
if
FLAGS
.
resume_checkpoint
:
checkpoint
.
load_checkpoint
(
exe
,
distill_prog
,
FLAGS
.
resume_checkpoint
)
start_iter
=
checkpoint
.
global_step
()
elif
cfg
.
pretrain_weights
and
fuse_bn
and
not
ignore_params
:
checkpoint
.
load_and_fusebn
(
exe
,
distill_prog
,
cfg
.
pretrain_weights
)
elif
cfg
.
pretrain_weights
:
checkpoint
.
load_params
(
exe
,
distill_prog
,
cfg
.
pretrain_weights
,
ignore_params
=
ignore_params
)
train_reader
=
create_reader
(
cfg
.
TrainReader
,
(
cfg
.
max_iters
-
start_iter
)
*
devices_num
,
cfg
)
train_loader
.
set_sample_list_generator
(
train_reader
,
place
)
# whether output bbox is normalized in model output layer
is_bbox_normalized
=
False
if
hasattr
(
model
,
'is_bbox_normalized'
)
and
\
...
...
@@ -240,7 +333,8 @@ def main():
if
step_id
%
cfg
.
snapshot_iter
==
0
and
step_id
!=
0
or
step_id
==
cfg
.
max_iters
-
1
:
save_name
=
str
(
step_id
)
if
step_id
!=
cfg
.
max_iters
-
1
else
"model_final"
checkpoint
.
save
(
exe
,
distill_prog
,
checkpoint
.
save
(
exe
,
fluid
.
default_main_program
(),
os
.
path
.
join
(
save_dir
,
save_name
))
# eval
results
=
eval_run
(
exe
,
compiled_eval_prog
,
eval_loader
,
eval_keys
,
...
...
@@ -254,7 +348,8 @@ def main():
if
box_ap_stats
[
0
]
>
best_box_ap_list
[
0
]:
best_box_ap_list
[
0
]
=
box_ap_stats
[
0
]
best_box_ap_list
[
1
]
=
step_id
checkpoint
.
save
(
exe
,
distill_prog
,
checkpoint
.
save
(
exe
,
fluid
.
default_main_program
(),
os
.
path
.
join
(
"./"
,
"best_model"
))
logger
.
info
(
"Best test box ap: {}, in step: {}"
.
format
(
best_box_ap_list
[
0
],
best_box_ap_list
[
1
]))
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录