提交 ccf94523 编写于 作者: B Bai Yifan 提交者: qingqing01

Add COCO distillation demo (#189)

* update merge api
* add coco distillation demo
* fix details
* resolve some issues
上级 9e306e6e
...@@ -17,7 +17,9 @@ ...@@ -17,7 +17,9 @@
关于蒸馏API如何使用您可以参考PaddleSlim蒸馏API文档 关于蒸馏API如何使用您可以参考PaddleSlim蒸馏API文档
这里以ResNet34-YoloV3蒸馏训练MobileNetV1-YoloV3模型为例,首先,为了对`student model``teacher model`有个总体的认识,进一步确认蒸馏的对象,我们通过以下命令分别观察两个网络变量(Variables)的名称和形状: ### MobileNetV1-YOLOv3在VOC数据集上的蒸馏
这里以ResNet34-YOLOv3蒸馏训练MobileNetV1-YOLOv3模型为例,首先,为了对`student model``teacher model`有个总体的认识,进一步确认蒸馏的对象,我们通过以下命令分别观察两个网络变量(Variables)的名称和形状:
```python ```python
# 观察student model的Variables # 观察student model的Variables
...@@ -60,6 +62,29 @@ dist_loss_3 = l2_loss('teacher_conv2d_22.tmp_1', 'conv2d_36.tmp_1') ...@@ -60,6 +62,29 @@ dist_loss_3 = l2_loss('teacher_conv2d_22.tmp_1', 'conv2d_36.tmp_1')
我们也可以根据上述操作为蒸馏策略选择其他loss,PaddleSlim支持的有`FSP_loss`, `L2_loss`, `softmax_with_cross_entropy_loss` 以及自定义的任何loss。 我们也可以根据上述操作为蒸馏策略选择其他loss,PaddleSlim支持的有`FSP_loss`, `L2_loss`, `softmax_with_cross_entropy_loss` 以及自定义的任何loss。
### MobileNetV1-YOLOv3在COCO数据集上的蒸馏
这里以ResNet34-YOLOv3作为蒸馏训练的teacher网络, 对MobileNetV1-YOLOv3结构的student网络进行蒸馏。
COCO数据集作为目标检测任务的训练目标难度更大,意味着teacher网络会预测出更多的背景bbox,如果直接用teacher的预测输出作为student学习的`soft label`会有严重的类别不均衡问题。解决这个问题需要引入新的方法,详细背景请参考论文:[Object detection at 200 Frames Per Second](https://arxiv.org/abs/1805.06361)
为了确定蒸馏的对象,我们首先需要找到student和teacher网络得到的`x,y,w,h,cls.objness`等变量在PaddlePaddle框架中的实际名称(var.name)。进而根据名称取出这些变量,用teacher得到的结果指导student训练。找到的所有变量如下:
```python
yolo_output_names = [
'strided_slice_0.tmp_0', 'strided_slice_1.tmp_0',
'strided_slice_2.tmp_0', 'strided_slice_3.tmp_0',
'strided_slice_4.tmp_0', 'transpose_0.tmp_0', 'strided_slice_5.tmp_0',
'strided_slice_6.tmp_0', 'strided_slice_7.tmp_0',
'strided_slice_8.tmp_0', 'strided_slice_9.tmp_0', 'transpose_2.tmp_0',
'strided_slice_10.tmp_0', 'strided_slice_11.tmp_0',
'strided_slice_12.tmp_0', 'strided_slice_13.tmp_0',
'strided_slice_14.tmp_0', 'transpose_4.tmp_0'
]
```
然后,就可以根据论文<<Object detection at 200 Frames Per Second>>的方法为YOLOv3中分类、回归、objness三个不同的head适配不同的蒸馏损失函数,并对分类和回归的损失函数用objness分值进行抑制,以解决前景背景类别不均衡问题。
## 训练 ## 训练
根据[PaddleDetection/tools/train.py](../../tools/train.py)编写压缩脚本`distill.py` 根据[PaddleDetection/tools/train.py](../../tools/train.py)编写压缩脚本`distill.py`
...@@ -76,12 +101,22 @@ export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ...@@ -76,12 +101,22 @@ export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
step2: 开始训练 step2: 开始训练
```bash ```bash
# yolov3_mobilenet_v1在voc数据集上蒸馏
python slim/distillation/distill.py \ python slim/distillation/distill.py \
-c configs/yolov3_mobilenet_v1_voc.yml \ -c configs/yolov3_mobilenet_v1_voc.yml \
-t configs/yolov3_r34_voc.yml \ -t configs/yolov3_r34_voc.yml \
--teacher_pretrained https://paddlemodels.bj.bcebos.com/object_detection/yolov3_r34_voc.tar --teacher_pretrained https://paddlemodels.bj.bcebos.com/object_detection/yolov3_r34_voc.tar
``` ```
```bash
# yolov3_mobilenet_v1在COCO数据集上蒸馏
python slim/distillation/distill.py \
-c configs/yolov3_mobilenet_v1.yml \
-o use_fine_grained_loss=true \
-t configs/yolov3_r34.yml \
--teacher_pretrained https://paddlemodels.bj.bcebos.com/object_detection/yolov3_r34.tar
```
如果要调整训练卡数,需要调整配置文件`yolov3_mobilenet_v1_voc.yml`中的以下参数: 如果要调整训练卡数,需要调整配置文件`yolov3_mobilenet_v1_voc.yml`中的以下参数:
- **max_iters:** 训练过程迭代总步数。 - **max_iters:** 训练过程迭代总步数。
...@@ -90,17 +125,27 @@ python slim/distillation/distill.py \ ...@@ -90,17 +125,27 @@ python slim/distillation/distill.py \
- **LearningRate.schedulers.PiecewiseDecay.milestones:** 请根据batch size的变化对其调整。 - **LearningRate.schedulers.PiecewiseDecay.milestones:** 请根据batch size的变化对其调整。
- **LearningRate.schedulers.PiecewiseDecay.LinearWarmup.steps:** 请根据batch size的变化对其进行调整。 - **LearningRate.schedulers.PiecewiseDecay.LinearWarmup.steps:** 请根据batch size的变化对其进行调整。
以下为4卡训练示例,通过命令行覆盖`yolov3_mobilenet_v1_voc.yml`中的参数: 以下为4卡训练示例,通过命令行-o参数覆盖`yolov3_mobilenet_v1_voc.yml`中的参数, 修改GPU卡数后应尽量确保总batch_size(GPU卡数\*YoloTrainFeed.batch_size)不变, 以确保训练效果不因bs大小受影响:
```shell ```bash
# yolov3_mobilenet_v1在VOC数据集上蒸馏
CUDA_VISIBLE_DEVICES=0,1,2,3 CUDA_VISIBLE_DEVICES=0,1,2,3
python slim/distillation/distill.py \ python slim/distillation/distill.py \
-c configs/yolov3_mobilenet_v1_voc.yml \ -c configs/yolov3_mobilenet_v1_voc.yml \
-t configs/yolov3_r34_voc.yml \ -t configs/yolov3_r34_voc.yml \
-o YoloTrainFeed.batch_size=16 \ -o YOLOv3Loss.batch_size=16 \
--teacher_pretrained https://paddlemodels.bj.bcebos.com/object_detection/yolov3_r34_voc.tar --teacher_pretrained https://paddlemodels.bj.bcebos.com/object_detection/yolov3_r34_voc.tar
``` ```
```bash
# yolov3_mobilenet_v1在COCO数据集上蒸馏
CUDA_VISIBLE_DEVICES=0,1,2,3
python slim/distillation/distill.py \
-c configs/yolov3_mobilenet_v1.yml \
-t configs/yolov3_r34.yml \
-o use_fine_grained_loss=true YOLOv3Loss.batch_size=16 \
--teacher_pretrained https://paddlemodels.bj.bcebos.com/object_detection/yolov3_r34.tar
```
...@@ -109,6 +154,7 @@ python slim/distillation/distill.py \ ...@@ -109,6 +154,7 @@ python slim/distillation/distill.py \
蒸馏任务执行过程中会自动保存断点。如果需要从断点继续训练请用`-r`参数指定checkpoint路径,示例如下: 蒸馏任务执行过程中会自动保存断点。如果需要从断点继续训练请用`-r`参数指定checkpoint路径,示例如下:
```bash ```bash
# yolov3_mobilenet_v1在VOC数据集上恢复断点
python -u slim/distillation/distill.py \ python -u slim/distillation/distill.py \
-c configs/yolov3_mobilenet_v1_voc.yml \ -c configs/yolov3_mobilenet_v1_voc.yml \
-t configs/yolov3_r34_voc.yml \ -t configs/yolov3_r34_voc.yml \
...@@ -116,6 +162,15 @@ python -u slim/distillation/distill.py \ ...@@ -116,6 +162,15 @@ python -u slim/distillation/distill.py \
--teacher_pretrained https://paddlemodels.bj.bcebos.com/object_detection/yolov3_r34_voc.tar --teacher_pretrained https://paddlemodels.bj.bcebos.com/object_detection/yolov3_r34_voc.tar
``` ```
```bash
# yolov3_mobilenet_v1在COCO数据集上恢复断点
python -u slim/distillation/distill.py \
-c configs/yolov3_mobilenet_v1.yml \
-t configs/yolov3_r34.yml \
-o use_fine_grained_loss=true \
-r output/yolov3_mobilenet_v1/10000 \
--teacher_pretrained https://paddlemodels.bj.bcebos.com/object_detection/yolov3_r34.tar
```
...@@ -125,11 +180,19 @@ python -u slim/distillation/distill.py \ ...@@ -125,11 +180,19 @@ python -u slim/distillation/distill.py \
运行命令为: 运行命令为:
```bash ```bash
# yolov3_mobilenet_v1在VOC数据集上评估
export CUDA_VISIBLE_DEVICES=0 export CUDA_VISIBLE_DEVICES=0
python -u tools/eval.py -c configs/yolov3_mobilenet_v1_voc.yml \ python -u tools/eval.py -c configs/yolov3_mobilenet_v1_voc.yml \
-o weights=output/yolov3_mobilenet_v1_voc/model_final \ -o weights=output/yolov3_mobilenet_v1_voc/model_final \
``` ```
```bash
# yolov3_mobilenet_v1在COCO数据集上评估
export CUDA_VISIBLE_DEVICES=0
python -u tools/eval.py -c configs/yolov3_mobilenet_v1.yml \
-o weights=output/yolov3_mobilenet_v1/model_final \
```
## 预测 ## 预测
每隔`snap_shot_iter`步后保存的checkpoint模型也可以用于预测,使用PaddleDetection目录下[tools/infer.py](../../tools/infer.py)评估脚本,并指定`weights`为训练得到的模型路径 每隔`snap_shot_iter`步后保存的checkpoint模型也可以用于预测,使用PaddleDetection目录下[tools/infer.py](../../tools/infer.py)评估脚本,并指定`weights`为训练得到的模型路径
...@@ -138,6 +201,7 @@ python -u tools/eval.py -c configs/yolov3_mobilenet_v1_voc.yml \ ...@@ -138,6 +201,7 @@ python -u tools/eval.py -c configs/yolov3_mobilenet_v1_voc.yml \
运行命令为: 运行命令为:
``` ```
# 使用yolov3_mobilenet_v1_voc模型进行预测
export CUDA_VISIBLE_DEVICES=0 export CUDA_VISIBLE_DEVICES=0
python -u tools/infer.py -c configs/yolov3_mobilenet_v1_voc.yml \ python -u tools/infer.py -c configs/yolov3_mobilenet_v1_voc.yml \
--infer_img=demo/000000570688.jpg \ --infer_img=demo/000000570688.jpg \
...@@ -146,6 +210,16 @@ python -u tools/infer.py -c configs/yolov3_mobilenet_v1_voc.yml \ ...@@ -146,6 +210,16 @@ python -u tools/infer.py -c configs/yolov3_mobilenet_v1_voc.yml \
-o weights=output/yolov3_mobilenet_v1_voc/model_final -o weights=output/yolov3_mobilenet_v1_voc/model_final
``` ```
```
# 使用yolov3_mobilenet_v1_coco模型进行预测
export CUDA_VISIBLE_DEVICES=0
python -u tools/infer.py -c configs/yolov3_mobilenet_v1.yml \
--infer_img=demo/000000570688.jpg \
--output_dir=infer_output/ \
--draw_threshold=0.5 \
-o weights=output/yolov3_mobilenet_v1/model_final
```
## 示例结果 ## 示例结果
### MobileNetV1-YOLO-V3-VOC ### MobileNetV1-YOLO-V3-VOC
...@@ -153,10 +227,23 @@ python -u tools/infer.py -c configs/yolov3_mobilenet_v1_voc.yml \ ...@@ -153,10 +227,23 @@ python -u tools/infer.py -c configs/yolov3_mobilenet_v1_voc.yml \
| FLOPS |输入尺寸|每张GPU图片个数|推理时间(fps)|Box AP|下载| | FLOPS |输入尺寸|每张GPU图片个数|推理时间(fps)|Box AP|下载|
|:-:|:-:|:-:|:-:|:-:|:-:| |:-:|:-:|:-:|:-:|:-:|:-:|
|baseline|608 |16|104.291|76.2|[下载链接](https://paddlemodels.bj.bcebos.com/object_detection/yolov3_mobilenet_v1_voc.tar)| |baseline|608 |16|104.291|76.2|[下载链接](https://paddlemodels.bj.bcebos.com/object_detection/yolov3_mobilenet_v1_voc.tar)|
|蒸馏后|608 |16|106.914|79.0|[下载链接](https://paddlemodels.bj.bcebos.com/PaddleSlim/yolov3_mobilenetv1_voc_distilled.tar)|
|baseline|416 |16|-|76.7|[下载链接](https://paddlemodels.bj.bcebos.com/object_detection/yolov3_mobilenet_v1_voc.tar)| |baseline|416 |16|-|76.7|[下载链接](https://paddlemodels.bj.bcebos.com/object_detection/yolov3_mobilenet_v1_voc.tar)|
|蒸馏后|416 |16|-|78.2|[下载链接](https://paddlemodels.bj.bcebos.com/PaddleSlim/yolov3_mobilenetv1_voc_distilled.tar)|
|baseline|320 |16|-|75.3|[下载链接](https://paddlemodels.bj.bcebos.com/object_detection/yolov3_mobilenet_v1_voc.tar)| |baseline|320 |16|-|75.3|[下载链接](https://paddlemodels.bj.bcebos.com/object_detection/yolov3_mobilenet_v1_voc.tar)|
|蒸馏后|608 |16|106.914|79.0|| |蒸馏后|320 |16|-|75.5|[下载链接](https://paddlemodels.bj.bcebos.com/PaddleSlim/yolov3_mobilenetv1_voc_distilled.tar)|
|蒸馏后|416 |16|-|78.2||
|蒸馏后|320 |16|-|75.5||
> 蒸馏后的结果用ResNet34-YOLO-V3做teacher,4GPU总batch_size64训练90000 iter得到 > 蒸馏后的结果用ResNet34-YOLO-V3做teacher,4GPU总batch_size64训练90000 iter得到
### MobileNetV1-YOLO-V3-COCO
| FLOPS |输入尺寸|每张GPU图片个数|推理时间(fps)|Box AP|下载|
|:-:|:-:|:-:|:-:|:-:|:-:|
|baseline|608 |16|78.302|29.3|[下载链接](https://paddlemodels.bj.bcebos.com/object_detection/yolov3_mobilenet_v1_voc.tar)|
|蒸馏后|608 |16|78.523|31.4|[下载链接](https://paddlemodels.bj.bcebos.com/PaddleSlim/yolov3_mobilenetv1_coco_distilled.tar)|
|baseline|416 |16|-|29.3|[下载链接](https://paddlemodels.bj.bcebos.com/object_detection/yolov3_mobilenet_v1_voc.tar)|
|蒸馏后|416 |16|-|30.0|[下载链接](https://paddlemodels.bj.bcebos.com/PaddleSlim/yolov3_mobilenetv1_coco_distilled.tar)|
|baseline|320 |16|-|27.0|[下载链接](https://paddlemodels.bj.bcebos.com/object_detection/yolov3_mobilenet_v1_voc.tar)|
|蒸馏后|320 |16|-|27.1|[下载链接](https://paddlemodels.bj.bcebos.com/PaddleSlim/yolov3_mobilenetv1_coco_distilled.tar)|
> 蒸馏后的结果用ResNet34-YOLO-V3做teacher,4GPU总batch_size64训练600000 iter得到
...@@ -36,6 +36,92 @@ logging.basicConfig(level=logging.INFO, format=FORMAT) ...@@ -36,6 +36,92 @@ logging.basicConfig(level=logging.INFO, format=FORMAT)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def l2_distill(pairs, weight):
"""
Add l2 distillation losses composed of multi pairs of feature maps,
each pair of feature maps is the input of teacher and student's
yolov3_loss respectively
"""
loss = []
for pair in pairs:
loss.append(l2_loss(pair[0], pair[1]))
loss = fluid.layers.sum(loss)
weighted_loss = loss * weight
return weighted_loss
def split_distill(split_output_names, weight):
"""
Add fine grained distillation losses.
Each loss is composed by distill_reg_loss, distill_cls_loss and
distill_obj_loss
"""
student_var = []
for name in split_output_names:
student_var.append(fluid.default_main_program().global_block().var(
name))
s_x0, s_y0, s_w0, s_h0, s_obj0, s_cls0 = student_var[0:6]
s_x1, s_y1, s_w1, s_h1, s_obj1, s_cls1 = student_var[6:12]
s_x2, s_y2, s_w2, s_h2, s_obj2, s_cls2 = student_var[12:18]
teacher_var = []
for name in split_output_names:
teacher_var.append(fluid.default_main_program().global_block().var(
'teacher_' + name))
t_x0, t_y0, t_w0, t_h0, t_obj0, t_cls0 = teacher_var[0:6]
t_x1, t_y1, t_w1, t_h1, t_obj1, t_cls1 = teacher_var[6:12]
t_x2, t_y2, t_w2, t_h2, t_obj2, t_cls2 = teacher_var[12:18]
def obj_weighted_reg(sx, sy, sw, sh, tx, ty, tw, th, tobj):
loss_x = fluid.layers.sigmoid_cross_entropy_with_logits(
sx, fluid.layers.sigmoid(tx))
loss_y = fluid.layers.sigmoid_cross_entropy_with_logits(
sy, fluid.layers.sigmoid(ty))
loss_w = fluid.layers.abs(sw - tw)
loss_h = fluid.layers.abs(sh - th)
loss = fluid.layers.sum([loss_x, loss_y, loss_w, loss_h])
weighted_loss = fluid.layers.reduce_mean(loss *
fluid.layers.sigmoid(tobj))
return weighted_loss
def obj_weighted_cls(scls, tcls, tobj):
loss = fluid.layers.sigmoid_cross_entropy_with_logits(
scls, fluid.layers.sigmoid(tcls))
weighted_loss = fluid.layers.reduce_mean(
fluid.layers.elementwise_mul(
loss, fluid.layers.sigmoid(tobj), axis=0))
return weighted_loss
def obj_loss(sobj, tobj):
obj_mask = fluid.layers.cast(tobj > 0., dtype="float32")
obj_mask.stop_gradient = True
loss = fluid.layers.reduce_mean(
fluid.layers.sigmoid_cross_entropy_with_logits(sobj, obj_mask))
return loss
distill_reg_loss0 = obj_weighted_reg(s_x0, s_y0, s_w0, s_h0, t_x0, t_y0,
t_w0, t_h0, t_obj0)
distill_reg_loss1 = obj_weighted_reg(s_x1, s_y1, s_w1, s_h1, t_x1, t_y1,
t_w1, t_h1, t_obj1)
distill_reg_loss2 = obj_weighted_reg(s_x2, s_y2, s_w2, s_h2, t_x2, t_y2,
t_w2, t_h2, t_obj2)
distill_reg_loss = fluid.layers.sum(
[distill_reg_loss0, distill_reg_loss1, distill_reg_loss2])
distill_cls_loss0 = obj_weighted_cls(s_cls0, t_cls0, t_obj0)
distill_cls_loss1 = obj_weighted_cls(s_cls1, t_cls1, t_obj1)
distill_cls_loss2 = obj_weighted_cls(s_cls2, t_cls2, t_obj2)
distill_cls_loss = fluid.layers.sum(
[distill_cls_loss0, distill_cls_loss1, distill_cls_loss2])
distill_obj_loss0 = obj_loss(s_obj0, t_obj0)
distill_obj_loss1 = obj_loss(s_obj1, t_obj1)
distill_obj_loss2 = obj_loss(s_obj2, t_obj2)
distill_obj_loss = fluid.layers.sum(
[distill_obj_loss0, distill_obj_loss1, distill_obj_loss2])
loss = (distill_reg_loss + distill_cls_loss + distill_obj_loss) * weight
return loss
def main(): def main():
env = os.environ env = os.environ
cfg = load_config(FLAGS.config) cfg = load_config(FLAGS.config)
...@@ -69,6 +155,31 @@ def main(): ...@@ -69,6 +155,31 @@ def main():
train_feed_vars, train_loader = model.build_inputs(**inputs_def) train_feed_vars, train_loader = model.build_inputs(**inputs_def)
train_fetches = model.train(train_feed_vars) train_fetches = model.train(train_feed_vars)
loss = train_fetches['loss'] loss = train_fetches['loss']
fuse_bn = getattr(model.backbone, 'norm_type', None) == 'affine_channel'
ignore_params = cfg.finetune_exclude_pretrained_params \
if 'finetune_exclude_pretrained_params' in cfg else []
start_iter = 0
if FLAGS.resume_checkpoint:
checkpoint.load_checkpoint(exe,
fluid.default_main_program(),
FLAGS.resume_checkpoint)
start_iter = checkpoint.global_step()
elif cfg.pretrain_weights and fuse_bn and not ignore_params:
checkpoint.load_and_fusebn(exe,
fluid.default_main_program(),
cfg.pretrain_weights)
elif cfg.pretrain_weights:
checkpoint.load_params(
exe,
fluid.default_main_program(),
cfg.pretrain_weights,
ignore_params=ignore_params)
train_reader = create_reader(cfg.TrainReader, (cfg.max_iters - start_iter) *
devices_num, cfg)
train_loader.set_sample_list_generator(train_reader, place)
# get all student variables # get all student variables
student_vars = [] student_vars = []
for v in fluid.default_main_program().list_vars(): for v in fluid.default_main_program().list_vars():
...@@ -102,6 +213,7 @@ def main(): ...@@ -102,6 +213,7 @@ def main():
extra_keys) extra_keys)
teacher_cfg = load_config(FLAGS.teacher_config) teacher_cfg = load_config(FLAGS.teacher_config)
merge_config(FLAGS.opt)
teacher_arch = teacher_cfg.architecture teacher_arch = teacher_cfg.architecture
teacher_program = fluid.Program() teacher_program = fluid.Program()
teacher_startup_program = fluid.Program() teacher_startup_program = fluid.Program()
...@@ -135,33 +247,34 @@ def main(): ...@@ -135,33 +247,34 @@ def main():
cfg = load_config(FLAGS.config) cfg = load_config(FLAGS.config)
data_name_map = { data_name_map = {
'target0': 'target0',
'target1': 'target1',
'target2': 'target2',
'image': 'image', 'image': 'image',
'gt_bbox': 'gt_bbox', 'gt_bbox': 'gt_bbox',
'gt_class': 'gt_class', 'gt_class': 'gt_class',
'gt_score': 'gt_score' 'gt_score': 'gt_score'
} }
distill_prog = merge(teacher_program, merge(teacher_program, fluid.default_main_program(), data_name_map, place)
fluid.default_main_program(), data_name_map, place)
yolo_output_names = [
'strided_slice_0.tmp_0', 'strided_slice_1.tmp_0',
'strided_slice_2.tmp_0', 'strided_slice_3.tmp_0',
'strided_slice_4.tmp_0', 'transpose_0.tmp_0', 'strided_slice_5.tmp_0',
'strided_slice_6.tmp_0', 'strided_slice_7.tmp_0',
'strided_slice_8.tmp_0', 'strided_slice_9.tmp_0', 'transpose_2.tmp_0',
'strided_slice_10.tmp_0', 'strided_slice_11.tmp_0',
'strided_slice_12.tmp_0', 'strided_slice_13.tmp_0',
'strided_slice_14.tmp_0', 'transpose_4.tmp_0'
]
distill_weight = 100
distill_pairs = [['teacher_conv2d_6.tmp_1', 'conv2d_20.tmp_1'], distill_pairs = [['teacher_conv2d_6.tmp_1', 'conv2d_20.tmp_1'],
['teacher_conv2d_14.tmp_1', 'conv2d_28.tmp_1'], ['teacher_conv2d_14.tmp_1', 'conv2d_28.tmp_1'],
['teacher_conv2d_22.tmp_1', 'conv2d_36.tmp_1']] ['teacher_conv2d_22.tmp_1', 'conv2d_36.tmp_1']]
def l2_distill(pairs, weight): distill_loss = l2_distill(
""" distill_pairs, 100) if not cfg.use_fine_grained_loss else split_distill(
Add l2 distillation losses composed of multi pairs of feature maps, yolo_output_names, 1000)
each pair of feature maps is the input of teacher and student's
yolov3_loss respectively
"""
loss = []
for pair in pairs:
loss.append(l2_loss(pair[0], pair[1]))
loss = fluid.layers.sum(loss)
weighted_loss = loss * weight
return weighted_loss
distill_loss = l2_distill(distill_pairs, distill_weight)
loss = distill_loss + loss loss = distill_loss + loss
lr_builder = create('LearningRate') lr_builder = create('LearningRate')
optim_builder = create('OptimizerBuilder') optim_builder = create('OptimizerBuilder')
...@@ -170,8 +283,6 @@ def main(): ...@@ -170,8 +283,6 @@ def main():
opt.minimize(loss) opt.minimize(loss)
exe.run(fluid.default_startup_program()) exe.run(fluid.default_startup_program())
checkpoint.load_params(exe,
fluid.default_main_program(), cfg.pretrain_weights)
build_strategy = fluid.BuildStrategy() build_strategy = fluid.BuildStrategy()
build_strategy.fuse_all_reduce_ops = False build_strategy.fuse_all_reduce_ops = False
...@@ -188,32 +299,14 @@ def main(): ...@@ -188,32 +299,14 @@ def main():
# local execution scopes can be deleted after each iteration. # local execution scopes can be deleted after each iteration.
exec_strategy.num_iteration_per_drop_scope = 1 exec_strategy.num_iteration_per_drop_scope = 1
parallel_main = fluid.CompiledProgram(distill_prog).with_data_parallel( parallel_main = fluid.CompiledProgram(fluid.default_main_program(
)).with_data_parallel(
loss_name=loss.name, loss_name=loss.name,
build_strategy=build_strategy, build_strategy=build_strategy,
exec_strategy=exec_strategy) exec_strategy=exec_strategy)
compiled_eval_prog = fluid.compiler.CompiledProgram(eval_prog) compiled_eval_prog = fluid.compiler.CompiledProgram(eval_prog)
fuse_bn = getattr(model.backbone, 'norm_type', None) == 'affine_channel'
ignore_params = cfg.finetune_exclude_pretrained_params \
if 'finetune_exclude_pretrained_params' in cfg else []
start_iter = 0
if FLAGS.resume_checkpoint:
checkpoint.load_checkpoint(exe, distill_prog, FLAGS.resume_checkpoint)
start_iter = checkpoint.global_step()
elif cfg.pretrain_weights and fuse_bn and not ignore_params:
checkpoint.load_and_fusebn(exe, distill_prog, cfg.pretrain_weights)
elif cfg.pretrain_weights:
checkpoint.load_params(
exe,
distill_prog,
cfg.pretrain_weights,
ignore_params=ignore_params)
train_reader = create_reader(cfg.TrainReader, (cfg.max_iters - start_iter) *
devices_num, cfg)
train_loader.set_sample_list_generator(train_reader, place)
# whether output bbox is normalized in model output layer # whether output bbox is normalized in model output layer
is_bbox_normalized = False is_bbox_normalized = False
if hasattr(model, 'is_bbox_normalized') and \ if hasattr(model, 'is_bbox_normalized') and \
...@@ -240,7 +333,8 @@ def main(): ...@@ -240,7 +333,8 @@ def main():
if step_id % cfg.snapshot_iter == 0 and step_id != 0 or step_id == cfg.max_iters - 1: if step_id % cfg.snapshot_iter == 0 and step_id != 0 or step_id == cfg.max_iters - 1:
save_name = str( save_name = str(
step_id) if step_id != cfg.max_iters - 1 else "model_final" step_id) if step_id != cfg.max_iters - 1 else "model_final"
checkpoint.save(exe, distill_prog, checkpoint.save(exe,
fluid.default_main_program(),
os.path.join(save_dir, save_name)) os.path.join(save_dir, save_name))
# eval # eval
results = eval_run(exe, compiled_eval_prog, eval_loader, eval_keys, results = eval_run(exe, compiled_eval_prog, eval_loader, eval_keys,
...@@ -254,7 +348,8 @@ def main(): ...@@ -254,7 +348,8 @@ def main():
if box_ap_stats[0] > best_box_ap_list[0]: if box_ap_stats[0] > best_box_ap_list[0]:
best_box_ap_list[0] = box_ap_stats[0] best_box_ap_list[0] = box_ap_stats[0]
best_box_ap_list[1] = step_id best_box_ap_list[1] = step_id
checkpoint.save(exe, distill_prog, checkpoint.save(exe,
fluid.default_main_program(),
os.path.join("./", "best_model")) os.path.join("./", "best_model"))
logger.info("Best test box ap: {}, in step: {}".format( logger.info("Best test box ap: {}, in step: {}".format(
best_box_ap_list[0], best_box_ap_list[1])) best_box_ap_list[0], best_box_ap_list[1]))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册