README.md 5.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89
>运行该示例前请安装Paddle1.6或更高版本

# 检测模型蒸馏示例

## 概述

该示例使用PaddleSlim提供的[蒸馏策略](https://github.com/PaddlePaddle/models/blob/develop/PaddleSlim/docs/tutorial.md#3-蒸馏)对检测库中的模型进行蒸馏训练。
在阅读该示例前,建议您先了解以下内容:

- [检测库的常规训练方法](https://github.com/PaddlePaddle/models/tree/develop/PaddleCV/PaddleDetection)
- [PaddleSlim使用文档](https://github.com/PaddlePaddle/models/blob/develop/PaddleSlim/docs/usage.md)


## 配置文件说明

关于配置文件如何编写您可以参考:

- [PaddleSlim配置文件编写说明](https://github.com/PaddlePaddle/models/blob/develop/PaddleSlim/docs/usage.md#122-%E9%85%8D%E7%BD%AE%E6%96%87%E4%BB%B6%E7%9A%84%E4%BD%BF%E7%94%A8)
- [蒸馏策略配置文件编写说明](https://github.com/PaddlePaddle/models/blob/develop/PaddleSlim/docs/usage.md#23-蒸馏)

这里以ResNet34-YoloV3蒸馏MobileNetV1-YoloV3模型为例,首先,为了对`student model``teacher model`有个总体的认识,从而进一步确认蒸馏的对象,我们通过以下命令分别观察两个网络变量(Variable)的名称和形状:

```python
# 观察student model的Variable
for v in fluid.default_main_program().list_vars():
    if "py_reader" not in v.name and "double_buffer" not in v.name and "generated_var" not in v.name:
        print(v.name, v.shape)
# 观察teacher model的Variable
for v in teacher_program.list_vars():
    print(v.name, v.shape)
```

经过对比可以发现,`student model``teacher model`的部分中间结果分别为:

```bash
# student model
conv2d_15.tmp_0
# teacher model
teacher_teacher_conv2d_1.tmp_0
```


所以,我们用`l2_distiller`对这两个特征图做蒸馏。在配置文件中进行如下配置:

```yaml
distillers:
    l2_distiller:
        class: 'L2Distiller'
        teacher_feature_map: 'teacher_teacher_conv2d_1.tmp_0'
        student_feature_map: 'conv2d_15.tmp_0'
        distillation_loss_weight: 1
strategies:
    distillation_strategy:
        class: 'DistillationStrategy'
        distillers: ['l2_distiller']
        start_epoch: 0
        end_epoch: 270
```

我们也可以根据上述操作为蒸馏策略选择其他loss,PaddleSlim支持的有`FSP_loss`, `L2_loss``softmax_with_cross_entropy_loss`

## 训练

根据[PaddleDetection/tools/train.py](https://github.com/PaddlePaddle/models/blob/develop/PaddleCV/PaddleDetection/tools/train.py)编写压缩脚本compress.py。
在该脚本中定义了Compressor对象,用于执行压缩任务。




您可以通过运行脚本`run.sh`运行该示例。


### 保存断点(checkpoint)

如果在配置文件中设置了`checkpoint_path`, 则在蒸馏任务执行过程中会自动保存断点,当任务异常中断时,
重启任务会自动从`checkpoint_path`路径下按数字顺序加载最新的checkpoint文件。如果不想让重启的任务从断点恢复,
需要修改配置文件中的`checkpoint_path`,或者将`checkpoint_path`路径下文件清空。

>注意:配置文件中的信息不会保存在断点中,重启前对配置文件的修改将会生效。


## 评估

如果在配置文件中设置了`checkpoint_path`,则每个epoch会保存一个压缩后的用于评估的模型,
该模型会保存在`${checkpoint_path}/${epoch_id}/eval_model/`路径下,包含`__model__``__params__`两个文件。
其中,`__model__`用于保存模型结构信息,`__params__`用于保存参数(parameters)信息。

如果不需要保存评估模型,可以在定义Compressor对象时,将`save_eval_model`选项设置为False(默认为True)。

B
Bai Yifan 已提交
90 91 92 93 94 95 96 97 98 99
运行命令为:
```
python ../eval.py \
    --model_path ${checkpoint_path}/${epoch_id}/eval_model/ \
    --model_name __model__ \
    --params_name __params__ \
    -c ../../configs/yolov3_mobilenet_v1_voc.yml \
    -d "../../dataset/voc"
```

100 101 102 103 104 105 106 107 108 109 110 111 112 113
## 预测

如果在配置文件中设置了`checkpoint_path`,并且在定义Compressor对象时指定了`prune_infer_model`选项,则每个epoch都会
保存一个`inference model`。该模型是通过删除eval_program中多余的operators而得到的。

该模型会保存在`${checkpoint_path}/${epoch_id}/eval_model/`路径下,包含`__model__.infer``__params__`两个文件。
其中,`__model__.infer`用于保存模型结构信息,`__params__`用于保存参数(parameters)信息。

更多关于`prune_infer_model`选项的介绍,请参考:[Compressor介绍](https://github.com/PaddlePaddle/models/blob/develop/PaddleSlim/docs/usage.md#121-%E5%A6%82%E4%BD%95%E6%94%B9%E5%86%99%E6%99%AE%E9%80%9A%E8%AE%AD%E7%BB%83%E8%84%9A%E6%9C%AC)

### python预测

在脚本<a href="../infer.py">slim/infer.py</a>中展示了如何使用fluid python API加载使用预测模型进行预测。

B
Bai Yifan 已提交
114 115 116 117 118 119 120 121 122 123
运行命令为:
```
python ../infer.py \
    --model_path ${checkpoint_path}/${epoch_id}/eval_model/ \
    --model_name __model__ \
    --params_name __params__ \
    -c ../../configs/yolov3_mobilenet_v1_voc.yml \
    --infer_dir ../../demo
```

124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139
### PaddleLite

该示例中产出的预测(inference)模型可以直接用PaddleLite进行加载使用。
关于PaddleLite如何使用,请参考:[PaddleLite使用文档](https://github.com/PaddlePaddle/Paddle-Lite/wiki#%E4%BD%BF%E7%94%A8)

## 示例结果

### MobileNetV1-YOLO-V3

| FLOPS |Box AP|
|---|---|
|baseline|76.2     |
|蒸馏后|- |


## FAQ