diff --git a/slim/prune/README.md b/slim/prune/README.md
index 16509624d48fcd42b2d9962bf499daa08f9e6247..31e537d5339dc6d883b9c4870591b6bec95d37c3 100644
--- a/slim/prune/README.md
+++ b/slim/prune/README.md
@@ -1,221 +1,62 @@
->运行该示例前请安装Paddle1.6或更高版本
+# Convolution Channel Pruning Tutorial
-# 检测模型卷积通道剪裁示例
+Please make sure you have correctly [installed PaddleDetection](../../docs/INSTALL_cn.md) and its dependencies.
-## 概述
+This document describes how to use the convolution channel pruning API of [PaddleSlim](https://paddlepaddle.github.io/PaddleSlim) to reduce the number of channels in the convolutional layers of models from the detection library.
-该示例使用PaddleSlim提供的[卷积通道剪裁压缩策略](https://github.com/PaddlePaddle/models/blob/develop/PaddleSlim/docs/tutorial.md#2-%E5%8D%B7%E7%A7%AF%E6%A0%B8%E5%89%AA%E8%A3%81%E5%8E%9F%E7%90%86)对检测库中的模型进行压缩。
-在阅读该示例前,建议您先了解以下内容:
+In the detection library, pruning can be run directly with the `PaddleDetection/slim/prune/prune.py` script, which calls the PaddleSlim [paddleslim.prune.Pruner](https://paddlepaddle.github.io/PaddleSlim/api/prune_api/#Pruner) API.
-- 检测库的常规训练方法
-- [检测模型数据准备](https://github.com/PaddlePaddle/PaddleDetection/blob/master/docs/INSTALL_cn.md#%E6%95%B0%E6%8D%AE%E9%9B%86)
-- [PaddleSlim使用文档](https://github.com/PaddlePaddle/models/blob/develop/PaddleSlim/docs/usage.md)
+Unless otherwise noted, all commands in this tutorial are executed from the `PaddleDetection/slim/prune/` directory.
+## 1. Data Preparation
-## 配置文件说明
+Prepare the dataset by following the detection library's [data download](../../docs/INSTALL_cn.md) documentation.
-关于配置文件如何编写您可以参考:
+## 2. Model Selection
-- [PaddleSlim配置文件编写说明](https://github.com/PaddlePaddle/models/blob/develop/PaddleSlim/docs/usage.md#122-%E9%85%8D%E7%BD%AE%E6%96%87%E4%BB%B6%E7%9A%84%E4%BD%BF%E7%94%A8)
-- [裁剪策略配置文件编写说明](https://github.com/PaddlePaddle/models/blob/develop/PaddleSlim/docs/usage.md#22-%E6%A8%A1%E5%9E%8B%E9%80%9A%E9%81%93%E5%89%AA%E8%A3%81)
+Use the `-c` option to specify the relative path of the configuration file of the model to be pruned. For more available configuration files, see the [detection library configs](https://github.com/PaddlePaddle/PaddleDetection/tree/release/0.1/configs).
-其中,配置文件中的`pruned_params`需要根据当前模型的网络结构特点设置,它用来指定要裁剪的parameters.
+For a pruning task, the weights of the original model do not necessarily help the retraining of the pruned model, so loading the original weights is not a required step.
-这里以MobileNetV1-YoloV3模型为例,其卷积可以三种:主干网络中的普通卷积,主干网络中的`depthwise convolution`和`yolo block`里的普通卷积。PaddleSlim暂时无法对`depthwise convolution`直接进行剪裁, 因为`depthwise convolution`的`channel`的变化会同时影响到前后的卷积层。我们这里只对主干网络中的普通卷积和`yolo block`里的普通卷积做裁剪。
-
-通过以下方式可视化模型结构:
+Use `-o weights` to specify the model weights, either as a URL or as a path on the local file system, for example:
```
-from paddle.fluid.framework import IrGraph
-from paddle.fluid import core
-
-graph = IrGraph(core.Graph(train_prog.desc), for_test=True)
-marked_nodes = set()
-for op in graph.all_op_nodes():
- print(op.name())
- if op.name().find('conv') > -1:
- marked_nodes.add(op)
-graph.draw('.', 'forward', marked_nodes)
+-o weights=https://paddlemodels.bj.bcebos.com/object_detection/yolov3_mobilenet_v1_voc.tar
```
-该示例中MobileNetV1-YoloV3模型结构的可视化结果:MobileNetV1-YoloV3.pdf
-
-同时通过以下命令观察目标卷积层的参数(parameters)的名称和shape:
+or
```
-for param in fluid.default_main_program().global_block().all_parameters():
- if 'weights' in param.name:
- print(param.name, param.shape)
+-o weights=output/yolov3_mobilenet_v1_voc/model_final
```
+For officially released models, see the [model zoo](https://github.com/PaddlePaddle/PaddleDetection/blob/release/0.1/docs/MODEL_ZOO_cn.md).
-从可视化结果,我们可以排除后续会做concat的卷积层,最终得到如下要裁剪的参数名称:
+## 3. Identify the Parameters to Prune
-```
-conv2_1_sep_weights
-conv2_2_sep_weights
-conv3_1_sep_weights
-conv4_1_sep_weights
-conv5_1_sep_weights
-conv5_2_sep_weights
-conv5_3_sep_weights
-conv5_4_sep_weights
-conv5_5_sep_weights
-conv5_6_sep_weights
-yolo_block.0.0.0.conv.weights
-yolo_block.0.0.1.conv.weights
-yolo_block.0.1.0.conv.weights
-yolo_block.0.1.1.conv.weights
-yolo_block.1.0.0.conv.weights
-yolo_block.1.0.1.conv.weights
-yolo_block.1.1.0.conv.weights
-yolo_block.1.1.1.conv.weights
-yolo_block.1.2.conv.weights
-yolo_block.2.0.0.conv.weights
-yolo_block.2.0.1.conv.weights
-yolo_block.2.1.1.conv.weights
-yolo_block.2.2.conv.weights
-yolo_block.2.tip.conv.weights
-```
+Channel reduction is achieved by pruning the weight parameters of convolutional layers, so before pruning we need to determine the names of the convolution parameters to be pruned.
+List all parameters of the current model with the following command:
```
-(conv2_1_sep_weights)|(conv2_2_sep_weights)|(conv3_1_sep_weights)|(conv4_1_sep_weights)|(conv5_1_sep_weights)|(conv5_2_sep_weights)|(conv5_3_sep_weights)|(conv5_4_sep_weights)|(conv5_5_sep_weights)|(conv5_6_sep_weights)|(yolo_block.0.0.0.conv.weights)|(yolo_block.0.0.1.conv.weights)|(yolo_block.0.1.0.conv.weights)|(yolo_block.0.1.1.conv.weights)|(yolo_block.1.0.0.conv.weights)|(yolo_block.1.0.1.conv.weights)|(yolo_block.1.1.0.conv.weights)|(yolo_block.1.1.1.conv.weights)|(yolo_block.1.2.conv.weights)|(yolo_block.2.0.0.conv.weights)|(yolo_block.2.0.1.conv.weights)|(yolo_block.2.1.1.conv.weights)|(yolo_block.2.2.conv.weights)|(yolo_block.2.tip.conv.weights)
+python prune.py \
+-c ../../configs/yolov3_mobilenet_v1_voc.yml \
+--print_params
```
-综上,我们将MobileNetV2配置文件中的`pruned_params`设置为以下正则表达式:
-
-```
-(conv2_1_sep_weights)|(conv2_2_sep_weights)|(conv3_1_sep_weights)|(conv4_1_sep_weights)|(conv5_1_sep_weights)|(conv5_2_sep_weights)|(conv5_3_sep_weights)|(conv5_4_sep_weights)|(conv5_5_sep_weights)|(conv5_6_sep_weights)|(yolo_block.0.0.0.conv.weights)|(yolo_block.0.0.1.conv.weights)|(yolo_block.0.1.0.conv.weights)|(yolo_block.0.1.1.conv.weights)|(yolo_block.1.0.0.conv.weights)|(yolo_block.1.0.1.conv.weights)|(yolo_block.1.1.0.conv.weights)|(yolo_block.1.1.1.conv.weights)|(yolo_block.1.2.conv.weights)|(yolo_block.2.0.0.conv.weights)|(yolo_block.2.0.1.conv.weights)|(yolo_block.2.1.1.conv.weights)|(yolo_block.2.2.conv.weights)|(yolo_block.2.tip.conv.weights)
-```
+By inspecting the parameter names and shapes, pick out the convolutional layer parameters and decide which of them to prune.
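+As a rough aid, convolution weights are 4-D tensors whose names in these models end with `weights`. The following minimal sketch (assuming an already-built fluid `Program`, e.g. the `train_prog` constructed inside `prune.py`) lists such candidates:
+```
+# Sketch: list 4-D convolution weight parameters as pruning candidates.
+# `train_prog` is assumed to be an already-built fluid Program.
+for param in train_prog.global_block().all_parameters():
+    if len(param.shape) == 4 and 'weights' in param.name:
+        print(param.name, param.shape)
+```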
-我们可以用上述操作观察其它检测模型的参数名称规律,然后设置合适的正则表达式来剪裁合适的参数。
-
-## 训练
-
-根据PaddleDetection/tools/train.py编写压缩脚本compress.py。
-在该脚本中定义了Compressor对象,用于执行压缩任务。
-
-### 执行示例
-
-step1: 设置gpu卡
-```
-export CUDA_VISIBLE_DEVICES=0
-```
-step2: 开始训练
+## 4. Launch the Pruning Task
-使用PaddleDetection提供的配置文件在用8卡进行训练:
+When launching a pruning task with `prune.py`, use the `--pruned_params` option to give the list of parameters to prune, with names separated by commas, and use the `--pruned_ratios` option to give the ratio to prune from each parameter (one ratio per parameter, separated by spaces).
```
-python compress.py \
- -s yolov3_mobilenet_v1_slim.yaml \
- -c ../../configs/yolov3_mobilenet_v1_voc.yml \
- -o max_iters=258 \
- YoloTrainFeed.batch_size=64 \
- -d "../../dataset/voc"
+python prune.py \
+-c ../../configs/yolov3_mobilenet_v1_voc.yml \
+--pruned_params "yolo_block.0.0.0.conv.weights,yolo_block.0.0.1.conv.weights,yolo_block.0.1.0.conv.weights" \
+--pruned_ratios="0.2 0.3 0.4"
```
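+`prune.py` also accepts the `--eval` flag (evaluate during retraining) and the `-o weights=...` option described in section 2. For example, a run that loads the released weights and evaluates the pruned model could look like:
+```
+python prune.py \
+-c ../../configs/yolov3_mobilenet_v1_voc.yml \
+--pruned_params "yolo_block.0.0.0.conv.weights,yolo_block.0.0.1.conv.weights,yolo_block.0.1.0.conv.weights" \
+--pruned_ratios="0.2 0.3 0.4" \
+--eval \
+-o weights=https://paddlemodels.bj.bcebos.com/object_detection/yolov3_mobilenet_v1_voc.tar
+```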
->通过命令行覆盖设置max_iters选项,因为PaddleDetection中训练是以`batch`为单位迭代的,并没有涉及`epoch`的概念,但是PaddleSlim需要知道当前训练进行到第几个`epoch`, 所以需要将`max_iters`设置为一个`epoch`内的`batch`的数量。
-
-如果要调整训练卡数,需要调整配置文件`yolov3_mobilenet_v1_voc.yml`中的以下参数:
-
-- **max_iters:** 一个`epoch`中batch的数量,需要设置为`total_num / batch_size`, 其中`total_num`为训练样本总数量,`batch_size`为多卡上总的batch size.
-- **YoloTrainFeed.batch_size:** 当使用DataLoader时,表示单张卡上的batch size; 当使用普通reader时,则表示多卡上的总的`batch_size`。`batch_size`受限于显存大小。
-- **LeaningRate.base_lr:** 根据多卡的总`batch_size`调整`base_lr`,两者大小正相关,可以简单的按比例进行调整。
-- **LearningRate.schedulers.PiecewiseDecay.milestones:** 请根据batch size的变化对其调整。
-- **LearningRate.schedulers.PiecewiseDecay.LinearWarmup.steps:** 请根据batch size的变化对其进行调整。
-
-
-以下为4卡训练示例,通过命令行覆盖`yolov3_mobilenet_v1_voc.yml`中的参数:
-
-```
-python compress.py \
- -s yolov3_mobilenet_v1_slim.yaml \
- -c ../../configs/yolov3_mobilenet_v1_voc.yml \
- -o max_iters=258 \
- YoloTrainFeed.batch_size=64 \
- -d "../../dataset/voc"
-```
-
-以下为2卡训练示例,受显存所制,单卡`batch_size`不变,总`batch_size`减小,`base_lr`减小,一个epoch内batch数量增加,同时需要调整学习率相关参数,如下:
-```
-python compress.py \
- -s yolov3_mobilenet_v1_slim.yaml \
- -c ../../configs/yolov3_mobilenet_v1_voc.yml \
- -o max_iters=516 \
- LeaningRate.base_lr=0.005 \
- YoloTrainFeed.batch_size=32 \
- LearningRate.schedulers='[!PiecewiseDecay {gamma: 0.1, milestones: [110000, 124000]}, !LinearWarmup {start_factor: 0., steps: 2000}]' \
- -d "../../dataset/voc"
-```
-
-通过`python compress.py --help`查看可配置参数。
-通过`python ../../tools/configure.py ${option_name} help`查看如何通过命令行覆盖配置文件`yolov3_mobilenet_v1_voc.yml`中的参数。
-
-### 保存断点(checkpoint)
-
-如果在配置文件中设置了`checkpoint_path`, 则在压缩任务执行过程中会自动保存断点,当任务异常中断时,
-重启任务会自动从`checkpoint_path`路径下按数字顺序加载最新的checkpoint文件。如果不想让重启的任务从断点恢复,
-需要修改配置文件中的`checkpoint_path`,或者将`checkpoint_path`路径下文件清空。
-
->注意:配置文件中的信息不会保存在断点中,重启前对配置文件的修改将会生效。
-
-
-## 评估
-
-如果在配置文件中设置了`checkpoint_path`,则每个epoch会保存一个压缩后的用于评估的模型,
-该模型会保存在`${checkpoint_path}/${epoch_id}/eval_model/`路径下,包含`__model__`和`__params__`两个文件。
-其中,`__model__`用于保存模型结构信息,`__params__`用于保存参数(parameters)信息。
-
-如果不需要保存评估模型,可以在定义Compressor对象时,将`save_eval_model`选项设置为False(默认为True)。
-
-运行命令为:
-```
-python ../eval.py \
- --model_path ${checkpoint_path}/${epoch_id}/eval_model/ \
- --model_name __model__ \
- --params_name __params__ \
- -c ../../configs/yolov3_mobilenet_v1_voc.yml \
- -d "../../dataset/voc"
-```
-
-## 预测
-
-如果在配置文件中设置了`checkpoint_path`,并且在定义Compressor对象时指定了`prune_infer_model`选项,则每个epoch都会
-保存一个`inference model`。该模型是通过删除eval_program中多余的operators而得到的。
-
-该模型会保存在`${checkpoint_path}/${epoch_id}/eval_model/`路径下,包含`__model__.infer`和`__params__`两个文件。
-其中,`__model__.infer`用于保存模型结构信息,`__params__`用于保存参数(parameters)信息。
-
-更多关于`prune_infer_model`选项的介绍,请参考:[Compressor介绍](https://github.com/PaddlePaddle/models/blob/develop/PaddleSlim/docs/usage.md#121-%E5%A6%82%E4%BD%95%E6%94%B9%E5%86%99%E6%99%AE%E9%80%9A%E8%AE%AD%E7%BB%83%E8%84%9A%E6%9C%AC)
-
-### python预测
-
-在脚本PaddleDetection/tools/infer.py中展示了如何使用fluid python API加载使用预测模型进行预测。
-
-运行命令为:
-```
-python ../infer.py \
- --model_path ${checkpoint_path}/${epoch_id}/eval_model/ \
- --model_name __model__.infer \
- --params_name __params__ \
- -c ../../configs/yolov3_mobilenet_v1_voc.yml \
- --infer_dir ../../demo
-```
-
-### PaddleLite
-
-该示例中产出的预测(inference)模型可以直接用PaddleLite进行加载使用。
-关于PaddleLite如何使用,请参考:[PaddleLite使用文档](https://github.com/PaddlePaddle/Paddle-Lite/wiki#%E4%BD%BF%E7%94%A8)
-
-## 示例结果
-
-> 当前release的结果并非超参调优后的最好结果,仅做示例参考,后续我们会优化当前结果。
-
-### MobileNetV1-YOLO-V3
-
-| FLOPS |Box AP| model_size |Paddle Fluid inference time(ms)| Paddle Lite inference time(ms)|
-|---|---|---|---|---|
-|baseline|76.2 |93M |- |-|
-|-50%|69.48 |51M |- |-|
+## 5. Extending to Your Own Model
-## FAQ
+To prune a model of your own, refer to how `prune.py` calls the `paddleslim.prune.Pruner` API and adapt your own training script accordingly.
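+The core call in `prune.py` is roughly the following (a simplified sketch; `train_prog`, `place`, `pruned_params` and `pruned_ratios` are prepared as in the full script):
+```
+from paddle import fluid
+from paddleslim.prune import Pruner
+
+# Prune the listed convolution weights; with only_graph=False both the
+# graph and the parameter values in the global scope are pruned.
+pruner = Pruner()
+pruned_prog = pruner.prune(
+    train_prog,
+    fluid.global_scope(),
+    params=pruned_params,
+    ratios=pruned_ratios,
+    place=place,
+    only_graph=False)[0]
+```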
+The pruning example in this tutorial requires the user to choose a pruning ratio for each layer based on prior knowledge. Beyond that, PaddleSlim also provides sensitivity analysis and other tools to help select suitable pruning ratios. For more details, see the [PaddleSlim documentation](https://paddlepaddle.github.io/PaddleSlim/).
diff --git a/slim/prune/prune.py b/slim/prune/prune.py
new file mode 100644
index 0000000000000000000000000000000000000000..90dfc5732b92ace4767dd744e5f42b291d5a8754
--- /dev/null
+++ b/slim/prune/prune.py
@@ -0,0 +1,386 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import time
+import numpy as np
+import datetime
+from collections import deque
+from paddleslim.prune import Pruner
+from paddleslim.analysis import flops
+from paddle import fluid
+from ppdet.experimental import mixed_precision_context
+from ppdet.core.workspace import load_config, merge_config, create
+from ppdet.data.reader import create_reader
+from ppdet.utils.cli import print_total_cfg
+from ppdet.utils import dist_utils
+from ppdet.utils.eval_utils import parse_fetches, eval_run, eval_results
+from ppdet.utils.stats import TrainingStats
+from ppdet.utils.cli import ArgsParser
+from ppdet.utils.check import check_gpu, check_version
+import ppdet.utils.checkpoint as checkpoint
+from ppdet.modeling.model_input import create_feed
+
+import logging
+FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
+logging.basicConfig(level=logging.INFO, format=FORMAT)
+logger = logging.getLogger(__name__)
+
+
+def main():
+ env = os.environ
+ FLAGS.dist = 'PADDLE_TRAINER_ID' in env and 'PADDLE_TRAINERS_NUM' in env
+ if FLAGS.dist:
+ trainer_id = int(env['PADDLE_TRAINER_ID'])
+ import random
+ local_seed = (99 + trainer_id)
+ random.seed(local_seed)
+ np.random.seed(local_seed)
+
+ cfg = load_config(FLAGS.config)
+ if 'architecture' in cfg:
+ main_arch = cfg.architecture
+ else:
+ raise ValueError("'architecture' not specified in config file.")
+
+ merge_config(FLAGS.opt)
+
+ if 'log_iter' not in cfg:
+ cfg.log_iter = 20
+
+ # check if set use_gpu=True in paddlepaddle cpu version
+ check_gpu(cfg.use_gpu)
+ # check if paddlepaddle version is satisfied
+ check_version()
+ if not FLAGS.dist or trainer_id == 0:
+ print_total_cfg(cfg)
+
+ if cfg.use_gpu:
+ devices_num = fluid.core.get_cuda_device_count()
+ else:
+ devices_num = int(os.environ.get('CPU_NUM', 1))
+
+ if 'FLAGS_selected_gpus' in env:
+ device_id = int(env['FLAGS_selected_gpus'])
+ else:
+ device_id = 0
+ place = fluid.CUDAPlace(device_id) if cfg.use_gpu else fluid.CPUPlace()
+ exe = fluid.Executor(place)
+
+ lr_builder = create('LearningRate')
+ optim_builder = create('OptimizerBuilder')
+
+ # build program
+ startup_prog = fluid.Program()
+ train_prog = fluid.Program()
+ with fluid.program_guard(train_prog, startup_prog):
+ with fluid.unique_name.guard():
+ model = create(main_arch)
+ if FLAGS.fp16:
+ assert (getattr(model.backbone, 'norm_type', None)
+ != 'affine_channel'), \
+ '--fp16 currently does not support affine channel, ' \
+ ' please modify backbone settings to use batch norm'
+
+ with mixed_precision_context(FLAGS.loss_scale, FLAGS.fp16) as ctx:
+ inputs_def = cfg['TrainReader']['inputs_def']
+ feed_vars, train_loader = model.build_inputs(**inputs_def)
+ train_fetches = model.train(feed_vars)
+ loss = train_fetches['loss']
+ if FLAGS.fp16:
+ loss *= ctx.get_loss_scale_var()
+ lr = lr_builder()
+ optimizer = optim_builder(lr)
+ optimizer.minimize(loss)
+ if FLAGS.fp16:
+ loss /= ctx.get_loss_scale_var()
+
+ # parse train fetches
+ train_keys, train_values, _ = parse_fetches(train_fetches)
+ train_values.append(lr)
+
+ if FLAGS.print_params:
+ print("-------------------------All parameters in current graph----------------------")
+ for block in train_prog.blocks:
+ for param in block.all_parameters():
+ print("parameter name: {}\tshape: {}".format(param.name, param.shape))
+ print("------------------------------------------------------------------------------")
+ return
+
+ if FLAGS.eval:
+ eval_prog = fluid.Program()
+ with fluid.program_guard(eval_prog, startup_prog):
+ with fluid.unique_name.guard():
+ model = create(main_arch)
+ inputs_def = cfg['EvalReader']['inputs_def']
+ feed_vars, eval_loader = model.build_inputs(**inputs_def)
+ fetches = model.eval(feed_vars)
+ eval_prog = eval_prog.clone(True)
+
+ eval_reader = create_reader(cfg.EvalReader)
+ eval_loader.set_sample_list_generator(eval_reader, place)
+
+ # parse eval fetches
+ extra_keys = []
+ if cfg.metric == 'COCO':
+ extra_keys = ['im_info', 'im_id', 'im_shape']
+ if cfg.metric == 'VOC':
+ extra_keys = ['gt_box', 'gt_label', 'is_difficult']
+ if cfg.metric == 'WIDERFACE':
+ extra_keys = ['im_id', 'im_shape', 'gt_box']
+ eval_keys, eval_values, eval_cls = parse_fetches(fetches, eval_prog,
+ extra_keys)
+
+ # compile program for multi-devices
+ build_strategy = fluid.BuildStrategy()
+ build_strategy.fuse_all_optimizer_ops = False
+ build_strategy.fuse_elewise_add_act_ops = True
+ # only enable sync_bn in multi GPU devices
+ sync_bn = getattr(model.backbone, 'norm_type', None) == 'sync_bn'
+ build_strategy.sync_batch_norm = sync_bn and devices_num > 1 \
+ and cfg.use_gpu
+
+ exec_strategy = fluid.ExecutionStrategy()
+ # iteration number when CompiledProgram tries to drop local execution scopes.
+ # Set it to be 1 to save memory usages, so that unused variables in
+ # local execution scopes can be deleted after each iteration.
+ exec_strategy.num_iteration_per_drop_scope = 1
+ if FLAGS.dist:
+ dist_utils.prepare_for_multi_process(exe, build_strategy, startup_prog,
+ train_prog)
+ exec_strategy.num_threads = 1
+
+ exe.run(startup_prog)
+
+ fuse_bn = getattr(model.backbone, 'norm_type', None) == 'affine_channel'
+
+ start_iter = 0
+ if FLAGS.resume_checkpoint:
+ checkpoint.load_checkpoint(exe, train_prog, FLAGS.resume_checkpoint)
+ start_iter = checkpoint.global_step()
+ elif cfg.pretrain_weights:
+ checkpoint.load_params(
+ exe, train_prog, cfg.pretrain_weights)
+
+    assert FLAGS.pruned_params is not None, \
+        "FLAGS.pruned_params is empty. Please set it with the '--pruned_params' option."
+    # parameter names are comma-separated; pruning ratios are space-separated
+    pruned_params = FLAGS.pruned_params.strip().split(",")
+    logger.info("pruned params: {}".format(pruned_params))
+    pruned_ratios = [float(n) for n in FLAGS.pruned_ratios.strip().split(" ")]
+    logger.info("pruned ratios: {}".format(pruned_ratios))
+    assert len(pruned_params) == len(pruned_ratios), \
+        "The length of pruned params and pruned ratios should be equal."
+    assert all(0 < r < 1 for r in pruned_ratios), \
+        "Each pruned ratio should be in the range (0, 1)."
+
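+    # Prune the selected convolution weights in the training program. With
+    # only_graph=False, the parameter values already loaded into the global
+    # scope are pruned along with the graph structure.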
+ pruner = Pruner()
+ train_prog = pruner.prune(
+ train_prog,
+ fluid.global_scope(),
+ params=pruned_params,
+ ratios=pruned_ratios,
+ place=place,
+ only_graph=False)[0]
+
+ compiled_train_prog = fluid.CompiledProgram(train_prog).with_data_parallel(
+ loss_name=loss.name,
+ build_strategy=build_strategy,
+ exec_strategy=exec_strategy)
+
+ if FLAGS.eval:
+
+ base_flops = flops(eval_prog)
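+        # The parameter values in the scope were already pruned together with
+        # the training program above, so only the eval graph itself needs to
+        # be pruned here (only_graph=True).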
+ eval_prog = pruner.prune(
+ eval_prog,
+ fluid.global_scope(),
+ params=pruned_params,
+ ratios=pruned_ratios,
+ place=place,
+ only_graph=True)[0]
+ pruned_flops = flops(eval_prog)
+ logger.info("FLOPs -{}; total FLOPs: {}; pruned FLOPs: {}".format(float(base_flops - pruned_flops)/base_flops, base_flops, pruned_flops))
+ compiled_eval_prog = fluid.compiler.CompiledProgram(eval_prog)
+
+ train_reader = create_reader(cfg.TrainReader, (cfg.max_iters - start_iter) *
+ devices_num, cfg)
+ train_loader.set_sample_list_generator(train_reader, place)
+
+ # whether output bbox is normalized in model output layer
+ is_bbox_normalized = False
+ if hasattr(model, 'is_bbox_normalized') and \
+ callable(model.is_bbox_normalized):
+ is_bbox_normalized = model.is_bbox_normalized()
+
+ # if map_type not set, use default 11point, only use in VOC eval
+ map_type = cfg.map_type if 'map_type' in cfg else '11point'
+
+ train_stats = TrainingStats(cfg.log_smooth_window, train_keys)
+ train_loader.start()
+ start_time = time.time()
+ end_time = time.time()
+
+ cfg_name = os.path.basename(FLAGS.config).split('.')[0]
+ save_dir = os.path.join(cfg.save_dir, cfg_name)
+ time_stat = deque(maxlen=cfg.log_smooth_window)
+ best_box_ap_list = [0.0, 0] #[map, iter]
+
+ # use tb-paddle to log data
+ if FLAGS.use_tb:
+ from tb_paddle import SummaryWriter
+ tb_writer = SummaryWriter(FLAGS.tb_log_dir)
+ tb_loss_step = 0
+ tb_mAP_step = 0
+
+ if FLAGS.eval:
+ # evaluation
+ results = eval_run(exe, compiled_eval_prog, eval_loader,
+ eval_keys, eval_values, eval_cls)
+ resolution = None
+ if 'mask' in results[0]:
+ resolution = model.mask_head.resolution
+ dataset = cfg['EvalReader']['dataset']
+ box_ap_stats = eval_results(
+ results,
+ cfg.metric,
+ cfg.num_classes,
+ resolution,
+ is_bbox_normalized,
+ FLAGS.output_eval,
+ map_type,
+ dataset=dataset)
+
+ for it in range(start_iter, cfg.max_iters):
+ start_time = end_time
+ end_time = time.time()
+ time_stat.append(end_time - start_time)
+ time_cost = np.mean(time_stat)
+ eta_sec = (cfg.max_iters - it) * time_cost
+ eta = str(datetime.timedelta(seconds=int(eta_sec)))
+ outs = exe.run(compiled_train_prog, fetch_list=train_values)
+ stats = {k: np.array(v).mean() for k, v in zip(train_keys, outs[:-1])}
+
+ # use tb-paddle to log loss
+ if FLAGS.use_tb:
+ if it % cfg.log_iter == 0:
+ for loss_name, loss_value in stats.items():
+ tb_writer.add_scalar(loss_name, loss_value, tb_loss_step)
+ tb_loss_step += 1
+
+ train_stats.update(stats)
+ logs = train_stats.log()
+ if it % cfg.log_iter == 0 and (not FLAGS.dist or trainer_id == 0):
+ strs = 'iter: {}, lr: {:.6f}, {}, time: {:.3f}, eta: {}'.format(
+ it, np.mean(outs[-1]), logs, time_cost, eta)
+ logger.info(strs)
+
+ if (it > 0 and it % cfg.snapshot_iter == 0 or it == cfg.max_iters - 1) \
+ and (not FLAGS.dist or trainer_id == 0):
+ save_name = str(it) if it != cfg.max_iters - 1 else "model_final"
+ checkpoint.save(exe, train_prog, os.path.join(save_dir, save_name))
+
+ if FLAGS.eval:
+ # evaluation
+ results = eval_run(exe, compiled_eval_prog, eval_loader,
+ eval_keys, eval_values, eval_cls)
+ resolution = None
+ if 'mask' in results[0]:
+ resolution = model.mask_head.resolution
+                box_ap_stats = eval_results(
+                    results, cfg.metric, cfg.num_classes, resolution,
+                    is_bbox_normalized, FLAGS.output_eval, map_type,
+                    dataset=dataset)
+
+ # use tb_paddle to log mAP
+ if FLAGS.use_tb:
+ tb_writer.add_scalar("mAP", box_ap_stats[0], tb_mAP_step)
+ tb_mAP_step += 1
+
+ if box_ap_stats[0] > best_box_ap_list[0]:
+ best_box_ap_list[0] = box_ap_stats[0]
+ best_box_ap_list[1] = it
+ checkpoint.save(exe, train_prog,
+ os.path.join(save_dir, "best_model"))
+ logger.info("Best test box ap: {}, in iter: {}".format(
+ best_box_ap_list[0], best_box_ap_list[1]))
+
+ train_loader.reset()
+
+
+if __name__ == '__main__':
+ parser = ArgsParser()
+ parser.add_argument(
+ "-r",
+ "--resume_checkpoint",
+ default=None,
+ type=str,
+ help="Checkpoint path for resuming training.")
+ parser.add_argument(
+ "--fp16",
+ action='store_true',
+ default=False,
+ help="Enable mixed precision training.")
+ parser.add_argument(
+ "--loss_scale",
+ default=8.,
+ type=float,
+ help="Mixed precision training loss scale.")
+ parser.add_argument(
+ "--eval",
+ action='store_true',
+ default=False,
+ help="Whether to perform evaluation in train")
+ parser.add_argument(
+ "--output_eval",
+ default=None,
+ type=str,
+ help="Evaluation directory, default is current directory.")
+ parser.add_argument(
+ "--use_tb",
+ type=bool,
+ default=False,
+ help="whether to record the data to Tensorboard.")
+ parser.add_argument(
+ '--tb_log_dir',
+ type=str,
+ default="tb_log_dir/scalar",
+ help='Tensorboard logging directory for scalar.')
+
+ parser.add_argument(
+ "-p",
+ "--pruned_params",
+ default=None,
+ type=str,
+ help="The parameters to be pruned when calculating sensitivities.")
+ parser.add_argument(
+ "--pruned_ratios",
+ default="0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9",
+ type=str,
+ help="The ratios pruned iteratively for each parameter when calculating sensitivities.")
+ parser.add_argument(
+ "-P",
+ "--print_params",
+ default=False,
+ action='store_true',
+ help="Whether to only print the parameters' names and shapes.")
+ FLAGS = parser.parse_args()
+ main()