未验证 提交 1a843c55 编写于 作者: B Bai Yifan 提交者: GitHub

Update distillation demo (#128)

* update distillation demo
上级 18645132
>运行该示例前请安装Paddle1.6或更高版本 >运行该示例前请安装PaddleSlim和Paddle1.6或更高版本
# 检测模型蒸馏示例 # 检测模型蒸馏示例
## 概述 ## 概述
该示例使用PaddleSlim提供的[蒸馏策略](https://github.com/PaddlePaddle/models/blob/develop/PaddleSlim/docs/tutorial.md#3-蒸馏)对检测库中的模型进行蒸馏训练。 该示例使用PaddleSlim提供的[蒸馏策略](https://paddlepaddle.github.io/PaddleSlim/algo/algo/#3)对检测库中的模型进行蒸馏训练。
在阅读该示例前,建议您先了解以下内容: 在阅读该示例前,建议您先了解以下内容:
- [检测库的常规训练方法](https://github.com/PaddlePaddle/PaddleDetection) - [检测库的常规训练方法](https://github.com/PaddlePaddle/PaddleDetection)
- [PaddleSlim使用文档](https://github.com/PaddlePaddle/models/blob/develop/PaddleSlim/docs/usage.md) - [PaddleSlim蒸馏API文档](https://paddlepaddle.github.io/PaddleSlim/api/single_distiller_api/)
## 安装PaddleSlim
可按照[PaddleSlim使用文档](https://paddlepaddle.github.io/PaddleSlim/)中的步骤安装PaddleSlim
## 配置文件说明 ## 蒸馏策略说明
关于配置文件如何编写您可以参考: 关于蒸馏API如何使用您可以参考PaddleSlim蒸馏API文档
- [PaddleSlim配置文件编写说明](https://github.com/PaddlePaddle/models/blob/develop/PaddleSlim/docs/usage.md#122-%E9%85%8D%E7%BD%AE%E6%96%87%E4%BB%B6%E7%9A%84%E4%BD%BF%E7%94%A8) 这里以ResNet34-YoloV3蒸馏训练MobileNetV1-YoloV3模型为例,首先,为了对`student model``teacher model`有个总体的认识,进一步确认蒸馏的对象,我们通过以下命令分别观察两个网络变量(Variables)的名称和形状:
- [蒸馏策略配置文件编写说明](https://github.com/PaddlePaddle/models/blob/develop/PaddleSlim/docs/usage.md#23-蒸馏)
这里以ResNet34-YoloV3蒸馏MobileNetV1-YoloV3模型为例,首先,为了对`student model``teacher model`有个总体的认识,从而进一步确认蒸馏的对象,我们通过以下命令分别观察两个网络变量(Variable)的名称和形状:
```python ```python
# 观察student model的Variable # 观察student model的Variables
student_vars = []
for v in fluid.default_main_program().list_vars(): for v in fluid.default_main_program().list_vars():
if "py_reader" not in v.name and "double_buffer" not in v.name and "generated_var" not in v.name: try:
print(v.name, v.shape) student_vars.append((v.name, v.shape))
# 观察teacher model的Variable except:
pass
print("="*50+"student_model_vars"+"="*50)
print(student_vars)
# 观察teacher model的Variables
teacher_vars = []
for v in teacher_program.list_vars(): for v in teacher_program.list_vars():
print(v.name, v.shape) try:
teacher_vars.append((v.name, v.shape))
except:
pass
print("="*50+"teacher_model_vars"+"="*50)
print(teacher_vars)
``` ```
经过对比可以发现,`student model``teacher model`的部分中间结果分别为: 经过对比可以发现,`student model``teacher model`输入到3个`yolov3_loss`的特征图分别为:
```bash ```bash
# student model # student model
conv2d_15.tmp_0 conv2d_20.tmp_1, conv2d_28.tmp_1, conv2d_36.tmp_1
# teacher model # teacher model
teacher_teacher_conv2d_1.tmp_0 conv2d_6.tmp_1, conv2d_14.tmp_1, conv2d_22.tmp_1
``` ```
所以,我们用`l2_distiller`对这两个特征图做蒸馏。在配置文件中进行如下配置: 它们形状两两相同,且分别处于两个网络的输出部分。所以,我们用`l2_loss`对这几个特征图两两对应添加蒸馏loss。需要注意的是,teacher的Variable在merge过程中被自动添加了一个`name_prefix`,所以这里也需要加上这个前缀`"teacher_"`,merge过程请参考[蒸馏API文档](https://paddlepaddle.github.io/PaddleSlim/api/single_distiller_api/#merge)
```yaml ```python
distillers: dist_loss_1 = l2_loss('teacher_conv2d_6.tmp_1', 'conv2d_20.tmp_1')
l2_distiller: dist_loss_2 = l2_loss('teacher_conv2d_14.tmp_1', 'conv2d_28.tmp_1')
class: 'L2Distiller' dist_loss_3 = l2_loss('teacher_conv2d_22.tmp_1', 'conv2d_36.tmp_1')
teacher_feature_map: 'teacher_teacher_conv2d_1.tmp_0'
student_feature_map: 'conv2d_15.tmp_0'
distillation_loss_weight: 1
strategies:
distillation_strategy:
class: 'DistillationStrategy'
distillers: ['l2_distiller']
start_epoch: 0
end_epoch: 270
``` ```
我们也可以根据上述操作为蒸馏策略选择其他loss,PaddleSlim支持的有`FSP_loss`, `L2_loss``softmax_with_cross_entropy_loss` 我们也可以根据上述操作为蒸馏策略选择其他loss,PaddleSlim支持的有`FSP_loss`, `L2_loss`, `softmax_with_cross_entropy_loss` 以及自定义的任何loss
## 训练 ## 训练
根据[PaddleDetection/tools/train.py](https://github.com/PaddlePaddle/PaddleDetection/tree/master/tools/train.py)编写压缩脚本compress.py 根据[PaddleDetection/tools/train.py](../../tools/train.py)编写压缩脚本`distill.py`
在该脚本中定义了Compressor对象,用于执行压缩任务。 在该脚本中定义了teacher_model和student_model,用teacher_model的输出指导student_model的训练
### 执行示例
step1: 设置GPU卡
```shell
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
```
您可以通过运行脚本`run.sh`运行该示例。 step2: 开始训练
```bash
python slim/distillation/distill.py \
-c configs/yolov3_mobilenet_v1_voc.yml \
-t configs/yolov3_r34_voc.yml \
--teacher_pretrained https://paddlemodels.bj.bcebos.com/object_detection/yolov3_r34_voc.tar
```
### 保存断点(checkpoint) 如果要调整训练卡数,需要调整配置文件`yolov3_mobilenet_v1_voc.yml`中的以下参数:
如果在配置文件中设置了`checkpoint_path`, 则在蒸馏任务执行过程中会自动保存断点,当任务异常中断时, - **max_iters:** 训练过程迭代总步数。
重启任务会自动从`checkpoint_path`路径下按数字顺序加载最新的checkpoint文件。如果不想让重启的任务从断点恢复, - **YOLOv3Loss.batch_size:** 该参数表示单张GPU卡上的`batch_size`, 总`batch_size`是GPU卡数乘以这个值, `batch_size`的设定受限于显存大小。
需要修改配置文件中的`checkpoint_path`,或者将`checkpoint_path`路径下文件清空。 - **LeaningRate.base_lr:** 根据多卡的总`batch_size`调整`base_lr`,两者大小正相关,可以简单的按比例进行调整。
- **LearningRate.schedulers.PiecewiseDecay.milestones:** 请根据batch size的变化对其调整。
- **LearningRate.schedulers.PiecewiseDecay.LinearWarmup.steps:** 请根据batch size的变化对其进行调整。
>注意:配置文件中的信息不会保存在断点中,重启前对配置文件的修改将会生效。 以下为4卡训练示例,通过命令行覆盖`yolov3_mobilenet_v1_voc.yml`中的参数:
```shell
CUDA_VISIBLE_DEVICES=0,1,2,3
python slim/distillation/distill.py \
-c configs/yolov3_mobilenet_v1_voc.yml \
-t configs/yolov3_r34_voc.yml \
-o YoloTrainFeed.batch_size=16 \
--teacher_pretrained https://paddlemodels.bj.bcebos.com/object_detection/yolov3_r34_voc.tar
```
## 评估
如果在配置文件中设置了`checkpoint_path`,则每个epoch会保存一个压缩后的用于评估的模型,
该模型会保存在`${checkpoint_path}/${epoch_id}/eval_model/`路径下,包含`__model__``__params__`两个文件。
其中,`__model__`用于保存模型结构信息,`__params__`用于保存参数(parameters)信息。
如果不需要保存评估模型,可以在定义Compressor对象时,将`save_eval_model`选项设置为False(默认为True)。
运行命令为: ### 保存断点(checkpoint)
```
python ../eval.py \
--model_path ${checkpoint_path}/${epoch_id}/eval_model/ \
--model_name __model__ \
--params_name __params__ \
-c ../../configs/yolov3_mobilenet_v1_voc.yml \
-d "../../dataset/voc"
```
## 预测 蒸馏任务执行过程中会自动保存断点。如果需要从断点继续训练请用`-r`参数指定checkpoint路径,示例如下:
如果在配置文件中设置了`checkpoint_path`,并且在定义Compressor对象时指定了`prune_infer_model`选项,则每个epoch都会 ```bash
保存一个`inference model`。该模型是通过删除eval_program中多余的operators而得到的。 python -u slim/distillation/distill.py \
-c configs/yolov3_mobilenet_v1_voc.yml \
-t configs/yolov3_r34_voc.yml \
-r output/yolov3_mobilenet_v1_voc/10000 \
--teacher_pretrained https://paddlemodels.bj.bcebos.com/object_detection/yolov3_r34_voc.tar
```
该模型会保存在`${checkpoint_path}/${epoch_id}/eval_model/`路径下,包含`__model__.infer``__params__`两个文件。
其中,`__model__.infer`用于保存模型结构信息,`__params__`用于保存参数(parameters)信息。
更多关于`prune_infer_model`选项的介绍,请参考:[Compressor介绍](https://github.com/PaddlePaddle/models/blob/develop/PaddleSlim/docs/usage.md#121-%E5%A6%82%E4%BD%95%E6%94%B9%E5%86%99%E6%99%AE%E9%80%9A%E8%AE%AD%E7%BB%83%E8%84%9A%E6%9C%AC)
### python预测
在脚本<a href="../infer.py">slim/infer.py</a>中展示了如何使用fluid python API加载使用预测模型进行预测。 ## 评估
每隔`snap_shot_iter`步后会保存一个checkpoint模型可以用于评估,使用PaddleDetection目录下[tools/eval.py](../../tools/eval.py)评估脚本,并指定`weights`为训练得到的模型路径
运行命令为: 运行命令为:
``` ```bash
python ../infer.py \ export CUDA_VISIBLE_DEVICES=0
--model_path ${checkpoint_path}/${epoch_id}/eval_model/ \ python -u tools/eval.py -c configs/yolov3_mobilenet_v1_voc.yml \
--model_name __model__.infer \ -o weights=output/yolov3_mobilenet_v1_voc/model_final \
--params_name __params__ \
-c ../../configs/yolov3_mobilenet_v1_voc.yml \
--infer_dir ../../demo
``` ```
### PaddleLite ## 预测
该示例中产出的预测(inference)模型可以直接用PaddleLite进行加载使用。 每隔`snap_shot_iter`步后保存的checkpoint模型也可以用于预测,使用PaddleDetection目录下[tools/infer.py](../../tools/infer.py)评估脚本,并指定`weights`为训练得到的模型路径
关于PaddleLite如何使用,请参考:[PaddleLite使用文档](https://github.com/PaddlePaddle/Paddle-Lite/wiki#%E4%BD%BF%E7%94%A8)
## 示例结果 ### Python预测
>当前release的结果并非超参调优后的最好结果,仅做示例参考,后续我们会优化当前结果。 运行命令为:
```
export CUDA_VISIBLE_DEVICES=0
python -u tools/infer.py -c configs/yolov3_mobilenet_v1_voc.yml \
--infer_img=demo/000000570688.jpg \
--output_dir=infer_output/ \
--draw_threshold=0.5 \
-o weights=output/yolov3_mobilenet_v1_voc/model_final
```
### MobileNetV1-YOLO-V3 ## 示例结果
| FLOPS |Box AP| ### MobileNetV1-YOLO-V3-VOC
|---|---|
|baseline|76.2 |
|蒸馏后|- |
| FLOPS |输入尺寸|每张GPU图片个数|推理时间(fps)|Box AP|下载|
|:-:|:-:|:-:|:-:|:-:|:-:|
|baseline|608 |16|104.291|76.2|[下载链接](https://paddlemodels.bj.bcebos.com/object_detection/yolov3_mobilenet_v1_voc.tar)|
|baseline|416 |16|-|76.7|[下载链接](https://paddlemodels.bj.bcebos.com/object_detection/yolov3_mobilenet_v1_voc.tar)|
|baseline|320 |16|-|75.3|[下载链接](https://paddlemodels.bj.bcebos.com/object_detection/yolov3_mobilenet_v1_voc.tar)|
|蒸馏后|608 |16|106.914|79.0||
|蒸馏后|416 |16|-|78.2||
|蒸馏后|320 |16|-|75.5||
## FAQ > 蒸馏后的结果用ResNet34-YOLO-V3做teacher,4GPU总batch_size64训练90000 iter得到
...@@ -17,38 +17,18 @@ from __future__ import division ...@@ -17,38 +17,18 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import os import os
import time
import multiprocessing
import numpy as np import numpy as np
from collections import deque, OrderedDict from collections import OrderedDict
from paddle.fluid.contrib.slim.core import Compressor from paddleslim.dist.single_distiller import merge, l2_loss
from paddle.fluid.framework import IrGraph
def set_paddle_flags(**kwargs):
for key, value in kwargs.items():
if os.environ.get(key, None) is None:
os.environ[key] = str(value)
# NOTE(paddle-dev): All of these flags should be set before
# `import paddle`. Otherwise, it would not take any effect.
set_paddle_flags(
FLAGS_eager_delete_tensor_gb=0, # enable GC to save memory
)
from paddle import fluid from paddle import fluid
import sys
sys.path.append("../../")
from ppdet.core.workspace import load_config, merge_config, create from ppdet.core.workspace import load_config, merge_config, create
from ppdet.data.data_feed import create_reader from ppdet.data.reader import create_reader
from ppdet.utils.eval_utils import parse_fetches, eval_results from ppdet.utils.eval_utils import parse_fetches, eval_results, eval_run
from ppdet.utils.stats import TrainingStats from ppdet.utils.stats import TrainingStats
from ppdet.utils.cli import ArgsParser from ppdet.utils.cli import ArgsParser
from ppdet.utils.check import check_gpu from ppdet.utils.check import check_gpu
import ppdet.utils.checkpoint as checkpoint import ppdet.utils.checkpoint as checkpoint
from ppdet.modeling.model_input import create_feed
import logging import logging
FORMAT = '%(asctime)s-%(levelname)s: %(message)s' FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
...@@ -56,56 +36,8 @@ logging.basicConfig(level=logging.INFO, format=FORMAT) ...@@ -56,56 +36,8 @@ logging.basicConfig(level=logging.INFO, format=FORMAT)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def eval_run(exe, compile_program, reader, keys, values, cls, test_feed):
"""
Run evaluation program, return program outputs.
"""
iter_id = 0
results = []
if len(cls) != 0:
values = []
for i in range(len(cls)):
_, accum_map = cls[i].get_map_var()
cls[i].reset(exe)
values.append(accum_map)
images_num = 0
start_time = time.time()
has_bbox = 'bbox' in keys
for data in reader():
data = test_feed.feed(data)
feed_data = {'image': data['image'], 'im_size': data['im_size']}
outs = exe.run(compile_program,
feed=feed_data,
fetch_list=[values[0]],
return_numpy=False)
outs.append(data['gt_box'])
outs.append(data['gt_label'])
outs.append(data['is_difficult'])
res = {
k: (np.array(v), v.recursive_sequence_lengths())
for k, v in zip(keys, outs)
}
results.append(res)
if iter_id % 100 == 0:
logger.info('Test iter {}'.format(iter_id))
iter_id += 1
images_num += len(res['bbox'][1][0]) if has_bbox else 1
logger.info('Test finish iter {}'.format(iter_id))
end_time = time.time()
fps = images_num / (end_time - start_time)
if has_bbox:
logger.info('Total number of images: {}, inference time: {} fps.'.
format(images_num, fps))
else:
logger.info('Total iteration: {}, inference time: {} batch/s.'.format(
images_num, fps))
return results
def main(): def main():
env = os.environ
cfg = load_config(FLAGS.config) cfg = load_config(FLAGS.config)
if 'architecture' in cfg: if 'architecture' in cfg:
main_arch = cfg.architecture main_arch = cfg.architecture
...@@ -122,112 +54,60 @@ def main(): ...@@ -122,112 +54,60 @@ def main():
if cfg.use_gpu: if cfg.use_gpu:
devices_num = fluid.core.get_cuda_device_count() devices_num = fluid.core.get_cuda_device_count()
else: else:
devices_num = int( devices_num = int(os.environ.get('CPU_NUM', 1))
os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
if 'train_feed' not in cfg:
train_feed = create(main_arch + 'TrainFeed')
else:
train_feed = create(cfg.train_feed)
if 'eval_feed' not in cfg: if 'FLAGS_selected_gpus' in env:
eval_feed = create(main_arch + 'EvalFeed') device_id = int(env['FLAGS_selected_gpus'])
else: else:
eval_feed = create(cfg.eval_feed) device_id = 0
place = fluid.CUDAPlace(device_id) if cfg.use_gpu else fluid.CPUPlace()
place = fluid.CUDAPlace(0) if cfg.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place) exe = fluid.Executor(place)
lr_builder = create('LearningRate')
optim_builder = create('OptimizerBuilder')
# build program # build program
model = create(main_arch) model = create(main_arch)
_, train_feed_vars = create_feed(train_feed, True) inputs_def = cfg['TrainReader']['inputs_def']
train_feed_vars, train_loader = model.build_inputs(**inputs_def)
train_fetches = model.train(train_feed_vars) train_fetches = model.train(train_feed_vars)
loss = train_fetches['loss'] loss = train_fetches['loss']
lr = lr_builder() # get all student variables
opt = optim_builder(lr) student_vars = []
opt.minimize(loss) for v in fluid.default_main_program().list_vars():
#for v in fluid.default_main_program().list_vars(): try:
# if "py_reader" not in v.name and "double_buffer" not in v.name and "generated_var" not in v.name: student_vars.append((v.name, v.shape))
# print(v.name, v.shape) except:
pass
cfg.max_iters = 258 # uncomment the following lines to print all student variables
train_reader = create_reader(train_feed, cfg.max_iters, FLAGS.dataset_dir) # print("="*50 + "student_model_vars" + "="*50)
# print(student_vars)
exe.run(fluid.default_startup_program())
# parse train fetches
train_keys, train_values, _ = parse_fetches(train_fetches)
train_keys.append('lr')
train_values.append(lr.name)
train_fetch_list = []
for k, v in zip(train_keys, train_values):
train_fetch_list.append((k, v))
print("train_fetch_list: {}".format(train_fetch_list))
eval_prog = fluid.Program() eval_prog = fluid.Program()
startup_prog = fluid.Program() with fluid.program_guard(eval_prog, fluid.default_startup_program()):
with fluid.program_guard(eval_prog, startup_prog):
with fluid.unique_name.guard(): with fluid.unique_name.guard():
model = create(main_arch) model = create(main_arch)
_, test_feed_vars = create_feed(eval_feed, True) inputs_def = cfg['EvalReader']['inputs_def']
test_feed_vars, eval_loader = model.build_inputs(**inputs_def)
fetches = model.eval(test_feed_vars) fetches = model.eval(test_feed_vars)
eval_prog = eval_prog.clone(True) eval_prog = eval_prog.clone(True)
eval_reader = create_reader(eval_feed, args_path=FLAGS.dataset_dir) eval_reader = create_reader(cfg.EvalReader)
test_data_feed = fluid.DataFeeder(test_feed_vars.values(), place) eval_loader.set_sample_list_generator(eval_reader, place)
# parse eval fetches # parse eval fetches
extra_keys = [] extra_keys = []
if cfg.metric == 'COCO': if cfg.metric == 'COCO':
extra_keys = ['im_info', 'im_id', 'im_shape'] extra_keys = ['im_info', 'im_id', 'im_shape']
if cfg.metric == 'VOC': if cfg.metric == 'VOC':
extra_keys = ['gt_box', 'gt_label', 'is_difficult'] extra_keys = ['gt_bbox', 'gt_class', 'is_difficult']
eval_keys, eval_values, eval_cls = parse_fetches(fetches, eval_prog, eval_keys, eval_values, eval_cls = parse_fetches(fetches, eval_prog,
extra_keys) extra_keys)
eval_fetch_list = []
for k, v in zip(eval_keys, eval_values):
eval_fetch_list.append((k, v))
print("eval_fetch_list: {}".format(eval_fetch_list))
exe.run(startup_prog)
checkpoint.load_params(exe,
fluid.default_main_program(), cfg.pretrain_weights)
best_box_ap_list = []
def eval_func(program, scope):
results = eval_run(exe, program, eval_reader, eval_keys, eval_values,
eval_cls, test_data_feed)
resolution = None
is_bbox_normalized = False
if 'mask' in results[0]:
resolution = model.mask_head.resolution
box_ap_stats = eval_results(results, eval_feed, cfg.metric,
cfg.num_classes, resolution,
is_bbox_normalized, FLAGS.output_eval)
if len(best_box_ap_list) == 0:
best_box_ap_list.append(box_ap_stats[0])
elif box_ap_stats[0] > best_box_ap_list[0]:
best_box_ap_list[0] = box_ap_stats[0]
logger.info("Best test box ap: {}".format(best_box_ap_list[0]))
return best_box_ap_list[0]
test_feed = [('image', test_feed_vars['image'].name),
('im_size', test_feed_vars['im_size'].name)]
teacher_cfg = load_config(FLAGS.teacher_config) teacher_cfg = load_config(FLAGS.teacher_config)
teacher_arch = teacher_cfg.architecture teacher_arch = teacher_cfg.architecture
teacher_programs = []
teacher_program = fluid.Program() teacher_program = fluid.Program()
teacher_startup_program = fluid.Program() teacher_startup_program = fluid.Program()
with fluid.program_guard(teacher_program, teacher_startup_program): with fluid.program_guard(teacher_program, teacher_startup_program):
with fluid.unique_name.guard('teacher_'): with fluid.unique_name.guard():
teacher_feed_vars = OrderedDict() teacher_feed_vars = OrderedDict()
for name, var in train_feed_vars.items(): for name, var in train_feed_vars.items():
teacher_feed_vars[name] = teacher_program.global_block( teacher_feed_vars[name] = teacher_program.global_block(
...@@ -235,64 +115,154 @@ def main(): ...@@ -235,64 +115,154 @@ def main():
var, force_persistable=False) var, force_persistable=False)
model = create(teacher_arch) model = create(teacher_arch)
train_fetches = model.train(teacher_feed_vars) train_fetches = model.train(teacher_feed_vars)
#print("="*50+"teacher_model_params"+"="*50) teacher_loss = train_fetches['loss']
#for v in teacher_program.list_vars():
# print(v.name, v.shape) # get all teacher variables
#return teacher_vars = []
for v in teacher_program.list_vars():
try:
teacher_vars.append((v.name, v.shape))
except:
pass
# uncomment the following lines to print all teacher variables
# print("="*50 + "teacher_model_vars" + "="*50)
# print(teacher_vars)
exe.run(teacher_startup_program) exe.run(teacher_startup_program)
assert FLAGS.teacher_pretrained and os.path.exists( assert FLAGS.teacher_pretrained, "teacher_pretrained should be set"
FLAGS.teacher_pretrained checkpoint.load_params(exe, teacher_program, FLAGS.teacher_pretrained)
), "teacher_pretrained should be set when teacher_model is not None." teacher_program = teacher_program.clone(for_test=True)
def if_exist(var): cfg = load_config(FLAGS.config)
return os.path.exists(os.path.join(FLAGS.teacher_pretrained, var.name)) data_name_map = {
'image': 'image',
fluid.io.load_vars( 'gt_bbox': 'gt_bbox',
exe, 'gt_class': 'gt_class',
FLAGS.teacher_pretrained, 'gt_score': 'gt_score'
main_program=teacher_program, }
predicate=if_exist) distill_prog = merge(teacher_program,
fluid.default_main_program(), data_name_map, place)
teacher_programs.append(teacher_program.clone(for_test=True))
distill_weight = 100
com = Compressor( distill_pairs = [['teacher_conv2d_6.tmp_1', 'conv2d_20.tmp_1'],
place, ['teacher_conv2d_14.tmp_1', 'conv2d_28.tmp_1'],
fluid.global_scope(), ['teacher_conv2d_22.tmp_1', 'conv2d_36.tmp_1']]
fluid.default_main_program(),
train_reader=train_reader, def l2_distill(pairs, weight):
train_feed_list=[(key, value.name) """
for key, value in train_feed_vars.items()], Add l2 distillation losses composed of multi pairs of feature maps,
train_fetch_list=train_fetch_list, each pair of feature maps is the input of teacher and student's
eval_program=eval_prog, yolov3_loss respectively
eval_reader=eval_reader, """
eval_feed_list=test_feed, loss = []
eval_func={'map': eval_func}, for pair in pairs:
eval_fetch_list=eval_fetch_list[0:1], loss.append(l2_loss(pair[0], pair[1]))
save_eval_model=True, loss = fluid.layers.sum(loss)
prune_infer_model=[["image", "im_size"], ["multiclass_nms_0.tmp_0"]], weighted_loss = loss * weight
teacher_programs=teacher_programs, return weighted_loss
train_optimizer=None,
distiller_optimizer=opt, distill_loss = l2_distill(distill_pairs, distill_weight)
log_period=20) loss = distill_loss + loss
com.config(FLAGS.slim_file) lr_builder = create('LearningRate')
com.run() optim_builder = create('OptimizerBuilder')
lr = lr_builder()
opt = optim_builder(lr)
opt.minimize(loss)
exe.run(fluid.default_startup_program())
checkpoint.load_params(exe,
fluid.default_main_program(), cfg.pretrain_weights)
build_strategy = fluid.BuildStrategy()
build_strategy.fuse_all_reduce_ops = False
build_strategy.fuse_all_optimizer_ops = False
build_strategy.fuse_elewise_add_act_ops = True
# only enable sync_bn in multi GPU devices
sync_bn = getattr(model.backbone, 'norm_type', None) == 'sync_bn'
build_strategy.sync_batch_norm = sync_bn and devices_num > 1 \
and cfg.use_gpu
exec_strategy = fluid.ExecutionStrategy()
# iteration number when CompiledProgram tries to drop local execution scopes.
# Set it to be 1 to save memory usages, so that unused variables in
# local execution scopes can be deleted after each iteration.
exec_strategy.num_iteration_per_drop_scope = 1
parallel_main = fluid.CompiledProgram(distill_prog).with_data_parallel(
loss_name=loss.name,
build_strategy=build_strategy,
exec_strategy=exec_strategy)
compiled_eval_prog = fluid.compiler.CompiledProgram(eval_prog)
fuse_bn = getattr(model.backbone, 'norm_type', None) == 'affine_channel'
ignore_params = cfg.finetune_exclude_pretrained_params \
if 'finetune_exclude_pretrained_params' in cfg else []
start_iter = 0
if FLAGS.resume_checkpoint:
checkpoint.load_checkpoint(exe, distill_prog, FLAGS.resume_checkpoint)
start_iter = checkpoint.global_step()
elif cfg.pretrain_weights and fuse_bn and not ignore_params:
checkpoint.load_and_fusebn(exe, distill_prog, cfg.pretrain_weights)
elif cfg.pretrain_weights:
checkpoint.load_params(
exe,
distill_prog,
cfg.pretrain_weights,
ignore_params=ignore_params)
train_reader = create_reader(cfg.TrainReader, (cfg.max_iters - start_iter) *
devices_num, cfg)
train_loader.set_sample_list_generator(train_reader, place)
# whether output bbox is normalized in model output layer
is_bbox_normalized = False
if hasattr(model, 'is_bbox_normalized') and \
callable(model.is_bbox_normalized):
is_bbox_normalized = model.is_bbox_normalized()
map_type = cfg.map_type if 'map_type' in cfg else '11point'
best_box_ap_list = [0.0, 0] #[map, iter]
cfg_name = os.path.basename(FLAGS.config).split('.')[0]
save_dir = os.path.join(cfg.save_dir, cfg_name)
train_loader.start()
for step_id in range(start_iter, cfg.max_iters):
teacher_loss_np, distill_loss_np, loss_np, lr_np = exe.run(
parallel_main,
fetch_list=[
'teacher_' + teacher_loss.name, distill_loss.name, loss.name,
lr.name
])
if step_id % cfg.log_iter == 0:
logger.info(
"step {} lr {:.6f}, loss {:.6f}, distill_loss {:.6f}, teacher_loss {:.6f}".
format(step_id, lr_np[0], loss_np[0], distill_loss_np[0],
teacher_loss_np[0]))
if step_id % cfg.snapshot_iter == 0 and step_id != 0 or step_id == cfg.max_iters - 1:
save_name = str(
step_id) if step_id != cfg.max_iters - 1 else "model_final"
checkpoint.save(exe, distill_prog,
os.path.join(save_dir, save_name))
# eval
results = eval_run(exe, compiled_eval_prog, eval_loader, eval_keys,
eval_values, eval_cls)
resolution = None
box_ap_stats = eval_results(results, cfg.metric, cfg.num_classes,
resolution, is_bbox_normalized,
FLAGS.output_eval, map_type,
cfg['EvalReader']['dataset'])
if box_ap_stats[0] > best_box_ap_list[0]:
best_box_ap_list[0] = box_ap_stats[0]
best_box_ap_list[1] = step_id
checkpoint.save(exe, distill_prog,
os.path.join("./", "best_model"))
logger.info("Best test box ap: {}, in step: {}".format(
best_box_ap_list[0], best_box_ap_list[1]))
train_loader.reset()
if __name__ == '__main__': if __name__ == '__main__':
parser = ArgsParser() parser = ArgsParser()
parser.add_argument(
"-t",
"--teacher_config",
default=None,
type=str,
help="Config file of teacher architecture.")
parser.add_argument(
"-s",
"--slim_file",
default=None,
type=str,
help="Config file of PaddleSlim.")
parser.add_argument( parser.add_argument(
"-r", "-r",
"--resume_checkpoint", "--resume_checkpoint",
...@@ -300,10 +270,11 @@ if __name__ == '__main__': ...@@ -300,10 +270,11 @@ if __name__ == '__main__':
type=str, type=str,
help="Checkpoint path for resuming training.") help="Checkpoint path for resuming training.")
parser.add_argument( parser.add_argument(
"--eval", "-t",
action='store_true', "--teacher_config",
default=False, default=None,
help="Whether to perform evaluation in train") type=str,
help="Config file of teacher architecture.")
parser.add_argument( parser.add_argument(
"--teacher_pretrained", "--teacher_pretrained",
default=None, default=None,
...@@ -314,11 +285,5 @@ if __name__ == '__main__': ...@@ -314,11 +285,5 @@ if __name__ == '__main__':
default=None, default=None,
type=str, type=str,
help="Evaluation directory, default is current directory.") help="Evaluation directory, default is current directory.")
parser.add_argument(
"-d",
"--dataset_dir",
default=None,
type=str,
help="Dataset path, same as DataFeed.dataset.dataset_dir")
FLAGS = parser.parse_args() FLAGS = parser.parse_args()
main() main()
#!/usr/bin/env bash
# download pretrain model
root_url="https://paddlemodels.bj.bcebos.com/object_detection"
yolov3_r34_voc="yolov3_r34_voc.tar"
pretrain_dir='./pretrain'
if [ ! -d ${pretrain_dir} ]; then
mkdir ${pretrain_dir}
fi
cd ${pretrain_dir}
if [ ! -f ${yolov3_r34_voc} ]; then
wget ${root_url}/${yolov3_r34_voc}
tar xf ${yolov3_r34_voc}
fi
cd -
# enable GC strategy
export FLAGS_fast_eager_deletion_mode=1
export FLAGS_eager_delete_tensor_gb=0.0
# for distillation
#-----------------
export CUDA_VISIBLE_DEVICES=0,1,2,3
# Fixing name conflicts in distillation
cd ${pretrain_dir}/yolov3_r34_voc
for files in $(ls teacher_*)
do mv $files ${files#*_}
done
for files in $(ls *)
do mv $files "teacher_"$files
done
cd -
python -u compress.py \
-c ../../configs/yolov3_mobilenet_v1_voc.yml \
-t yolov3_resnet34.yml \
-s yolov3_mobilenet_v1_yolov3_resnet34_distillation.yml \
-o YoloTrainFeed.batch_size=64 \
-d ../../dataset/voc \
--teacher_pretrained ./pretrain/yolov3_r34_voc \
> yolov3_distallation.log 2>&1 &
tailf yolov3_distallation.log
version: 1.0
distillers:
l2_distiller:
class: 'L2Distiller'
teacher_feature_map: 'teacher_teacher_conv2d_1.tmp_0'
student_feature_map: 'conv2d_15.tmp_0'
distillation_loss_weight: 1
strategies:
distillation_strategy:
class: 'DistillationStrategy'
distillers: ['l2_distiller']
start_epoch: 0
end_epoch: 270
compressor:
epoch: 271
checkpoint_path: './checkpoints/'
strategies:
- distillation_strategy
architecture: YOLOv3
log_smooth_window: 20
metric: VOC
map_type: 11point
num_classes: 20
weight_prefix_name: teacher_
YOLOv3:
backbone: ResNet
yolo_head: YOLOv3Head
ResNet:
norm_type: sync_bn
freeze_at: 0
freeze_norm: false
norm_decay: 0.
depth: 34
feature_maps: [3, 4, 5]
YOLOv3Head:
anchor_masks: [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
anchors: [[10, 13], [16, 30], [33, 23],
[30, 61], [62, 45], [59, 119],
[116, 90], [156, 198], [373, 326]]
norm_decay: 0.
ignore_thresh: 0.7
label_smooth: false
nms:
background_label: -1
keep_top_k: 100
nms_threshold: 0.45
nms_top_k: 1000
normalized: false
score_threshold: 0.01
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册