未验证 提交 d941928d 编写于 作者: W whs 提交者: GitHub

Add tutorial for pruning. (#152)

上级 684fe1b6
>运行该示例前请安装Paddle1.6或更高版本
# 卷积层通道剪裁教程
# 检测模型卷积通道剪裁示例
请确保已正确[安装PaddleDetection](../../docs/INSTALL_cn.md)及其依赖。
## 概述
该文档介绍如何使用[PaddleSlim](https://paddlepaddle.github.io/PaddleSlim)的卷积通道剪裁接口对检测库中的模型的卷积层的通道数进行剪裁。
该示例使用PaddleSlim提供的[卷积通道剪裁压缩策略](https://github.com/PaddlePaddle/models/blob/develop/PaddleSlim/docs/tutorial.md#2-%E5%8D%B7%E7%A7%AF%E6%A0%B8%E5%89%AA%E8%A3%81%E5%8E%9F%E7%90%86)对检测库中的模型进行压缩。
在阅读该示例前,建议您先了解以下内容:
在检测库中,可以直接调用`PaddleDetection/slim/prune/prune.py`脚本实现剪裁,在该脚本中调用了PaddleSlim的[paddleslim.prune.Pruner](https://paddlepaddle.github.io/PaddleSlim/api/prune_api/#Pruner)接口。
- <a href="../../README_cn.md">检测库的常规训练方法</a>
- [检测模型数据准备](https://github.com/PaddlePaddle/PaddleDetection/blob/master/docs/INSTALL_cn.md#%E6%95%B0%E6%8D%AE%E9%9B%86)
- [PaddleSlim使用文档](https://github.com/PaddlePaddle/models/blob/develop/PaddleSlim/docs/usage.md)
该教程中所示操作,如无特殊说明,均在`PaddleDetection/slim/prune/`路径下执行。
## 1. 数据准备
## 配置文件说明
请参考检测库[数据下载](../../docs/INSTALL_cn.md)文档准备数据。
关于配置文件如何编写您可以参考:
## 2. 模型选择
- [PaddleSlim配置文件编写说明](https://github.com/PaddlePaddle/models/blob/develop/PaddleSlim/docs/usage.md#122-%E9%85%8D%E7%BD%AE%E6%96%87%E4%BB%B6%E7%9A%84%E4%BD%BF%E7%94%A8)
- [裁剪策略配置文件编写说明](https://github.com/PaddlePaddle/models/blob/develop/PaddleSlim/docs/usage.md#22-%E6%A8%A1%E5%9E%8B%E9%80%9A%E9%81%93%E5%89%AA%E8%A3%81)
通过`-c`选项指定待裁剪模型的配置文件的相对路径,更多可选配置文件请参考: [检测库配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/release/0.1/configs)
其中,配置文件中的`pruned_params`需要根据当前模型的网络结构特点设置,它用来指定要裁剪的parameters.
对于剪裁任务,原模型的权重不一定对剪裁后的模型训练的重训练有贡献,所以加载原模型的权重不是必需的步骤。
这里以MobileNetV1-YoloV3模型为例,其卷积可以三种:主干网络中的普通卷积,主干网络中的`depthwise convolution``yolo block`里的普通卷积。PaddleSlim暂时无法对`depthwise convolution`直接进行剪裁, 因为`depthwise convolution``channel`的变化会同时影响到前后的卷积层。我们这里只对主干网络中的普通卷积和`yolo block`里的普通卷积做裁剪。
通过以下方式可视化模型结构:
通过`-o weights`指定模型的权重,可以指定url或本地文件系统的路径。如下所示:
```
from paddle.fluid.framework import IrGraph
from paddle.fluid import core
graph = IrGraph(core.Graph(train_prog.desc), for_test=True)
marked_nodes = set()
for op in graph.all_op_nodes():
print(op.name())
if op.name().find('conv') > -1:
marked_nodes.add(op)
graph.draw('.', 'forward', marked_nodes)
-o weights=https://paddlemodels.bj.bcebos.com/object_detection/yolov3_mobilenet_v1_voc.tar
```
该示例中MobileNetV1-YoloV3模型结构的可视化结果:<a href="./images/MobileNetV1-YoloV3.pdf">MobileNetV1-YoloV3.pdf</a>
同时通过以下命令观察目标卷积层的参数(parameters)的名称和shape:
```
for param in fluid.default_main_program().global_block().all_parameters():
if 'weights' in param.name:
print(param.name, param.shape)
-o weights=output/yolov3_mobilenet_v1_voc/model_final
```
官方已发布的模型请参考: [模型库](https://github.com/PaddlePaddle/PaddleDetection/blob/release/0.1/docs/MODEL_ZOO_cn.md)
从可视化结果,我们可以排除后续会做concat的卷积层,最终得到如下要裁剪的参数名称:
## 3. 确定待分析参数
```
conv2_1_sep_weights
conv2_2_sep_weights
conv3_1_sep_weights
conv4_1_sep_weights
conv5_1_sep_weights
conv5_2_sep_weights
conv5_3_sep_weights
conv5_4_sep_weights
conv5_5_sep_weights
conv5_6_sep_weights
yolo_block.0.0.0.conv.weights
yolo_block.0.0.1.conv.weights
yolo_block.0.1.0.conv.weights
yolo_block.0.1.1.conv.weights
yolo_block.1.0.0.conv.weights
yolo_block.1.0.1.conv.weights
yolo_block.1.1.0.conv.weights
yolo_block.1.1.1.conv.weights
yolo_block.1.2.conv.weights
yolo_block.2.0.0.conv.weights
yolo_block.2.0.1.conv.weights
yolo_block.2.1.1.conv.weights
yolo_block.2.2.conv.weights
yolo_block.2.tip.conv.weights
```
我们通过剪裁卷积层参数达到缩减卷积层通道数的目的,在剪裁之前,我们需要确定待裁卷积层的参数的名称。
通过以下命令查看当前模型的所有参数:
```
(conv2_1_sep_weights)|(conv2_2_sep_weights)|(conv3_1_sep_weights)|(conv4_1_sep_weights)|(conv5_1_sep_weights)|(conv5_2_sep_weights)|(conv5_3_sep_weights)|(conv5_4_sep_weights)|(conv5_5_sep_weights)|(conv5_6_sep_weights)|(yolo_block.0.0.0.conv.weights)|(yolo_block.0.0.1.conv.weights)|(yolo_block.0.1.0.conv.weights)|(yolo_block.0.1.1.conv.weights)|(yolo_block.1.0.0.conv.weights)|(yolo_block.1.0.1.conv.weights)|(yolo_block.1.1.0.conv.weights)|(yolo_block.1.1.1.conv.weights)|(yolo_block.1.2.conv.weights)|(yolo_block.2.0.0.conv.weights)|(yolo_block.2.0.1.conv.weights)|(yolo_block.2.1.1.conv.weights)|(yolo_block.2.2.conv.weights)|(yolo_block.2.tip.conv.weights)
python prune.py \
-c ../../configs/yolov3_mobilenet_v1_voc.yml \
--print_params
```
综上,我们将MobileNetV2配置文件中的`pruned_params`设置为以下正则表达式:
```
(conv2_1_sep_weights)|(conv2_2_sep_weights)|(conv3_1_sep_weights)|(conv4_1_sep_weights)|(conv5_1_sep_weights)|(conv5_2_sep_weights)|(conv5_3_sep_weights)|(conv5_4_sep_weights)|(conv5_5_sep_weights)|(conv5_6_sep_weights)|(yolo_block.0.0.0.conv.weights)|(yolo_block.0.0.1.conv.weights)|(yolo_block.0.1.0.conv.weights)|(yolo_block.0.1.1.conv.weights)|(yolo_block.1.0.0.conv.weights)|(yolo_block.1.0.1.conv.weights)|(yolo_block.1.1.0.conv.weights)|(yolo_block.1.1.1.conv.weights)|(yolo_block.1.2.conv.weights)|(yolo_block.2.0.0.conv.weights)|(yolo_block.2.0.1.conv.weights)|(yolo_block.2.1.1.conv.weights)|(yolo_block.2.2.conv.weights)|(yolo_block.2.tip.conv.weights)
```
通过观察参数名称和参数的形状,筛选出所有卷积层参数,并确定要裁剪的卷积层参数。
我们可以用上述操作观察其它检测模型的参数名称规律,然后设置合适的正则表达式来剪裁合适的参数。
## 训练
根据<a href="../../tools/train.py">PaddleDetection/tools/train.py</a>编写压缩脚本compress.py。
在该脚本中定义了Compressor对象,用于执行压缩任务。
### 执行示例
step1: 设置gpu卡
```
export CUDA_VISIBLE_DEVICES=0
```
step2: 开始训练
## 4. 启动剪裁任务
使用PaddleDetection提供的配置文件在用8卡进行训练:
使用`prune.py`启动裁剪任务时,通过`--pruned_params`选项指定待裁剪的参数名称列表,参数名之间用空格分隔,通过`--pruned_ratios`选项指定各个参数被裁掉的比例。
```
python compress.py \
-s yolov3_mobilenet_v1_slim.yaml \
-c ../../configs/yolov3_mobilenet_v1_voc.yml \
-o max_iters=258 \
YoloTrainFeed.batch_size=64 \
-d "../../dataset/voc"
python prune.py \
-c ../../configs/yolov3_mobilenet_v1_voc.yml \
--pruned_params "yolo_block.0.0.0.conv.weights,yolo_block.0.0.1.conv.weights,yolo_block.0.1.0.conv.weights" \
--pruned_ratios="0.2 0.3 0.4"
```
>通过命令行覆盖设置max_iters选项,因为PaddleDetection中训练是以`batch`为单位迭代的,并没有涉及`epoch`的概念,但是PaddleSlim需要知道当前训练进行到第几个`epoch`, 所以需要将`max_iters`设置为一个`epoch`内的`batch`的数量。
如果要调整训练卡数,需要调整配置文件`yolov3_mobilenet_v1_voc.yml`中的以下参数:
- **max_iters:** 一个`epoch`中batch的数量,需要设置为`total_num / batch_size`, 其中`total_num`为训练样本总数量,`batch_size`为多卡上总的batch size.
- **YoloTrainFeed.batch_size:** 当使用DataLoader时,表示单张卡上的batch size; 当使用普通reader时,则表示多卡上的总的`batch_size``batch_size`受限于显存大小。
- **LeaningRate.base_lr:** 根据多卡的总`batch_size`调整`base_lr`,两者大小正相关,可以简单的按比例进行调整。
- **LearningRate.schedulers.PiecewiseDecay.milestones:** 请根据batch size的变化对其调整。
- **LearningRate.schedulers.PiecewiseDecay.LinearWarmup.steps:** 请根据batch size的变化对其进行调整。
以下为4卡训练示例,通过命令行覆盖`yolov3_mobilenet_v1_voc.yml`中的参数:
```
python compress.py \
-s yolov3_mobilenet_v1_slim.yaml \
-c ../../configs/yolov3_mobilenet_v1_voc.yml \
-o max_iters=258 \
YoloTrainFeed.batch_size=64 \
-d "../../dataset/voc"
```
以下为2卡训练示例,受显存所制,单卡`batch_size`不变,总`batch_size`减小,`base_lr`减小,一个epoch内batch数量增加,同时需要调整学习率相关参数,如下:
```
python compress.py \
-s yolov3_mobilenet_v1_slim.yaml \
-c ../../configs/yolov3_mobilenet_v1_voc.yml \
-o max_iters=516 \
LeaningRate.base_lr=0.005 \
YoloTrainFeed.batch_size=32 \
LearningRate.schedulers='[!PiecewiseDecay {gamma: 0.1, milestones: [110000, 124000]}, !LinearWarmup {start_factor: 0., steps: 2000}]' \
-d "../../dataset/voc"
```
通过`python compress.py --help`查看可配置参数。
通过`python ../../tools/configure.py ${option_name} help`查看如何通过命令行覆盖配置文件`yolov3_mobilenet_v1_voc.yml`中的参数。
### 保存断点(checkpoint)
如果在配置文件中设置了`checkpoint_path`, 则在压缩任务执行过程中会自动保存断点,当任务异常中断时,
重启任务会自动从`checkpoint_path`路径下按数字顺序加载最新的checkpoint文件。如果不想让重启的任务从断点恢复,
需要修改配置文件中的`checkpoint_path`,或者将`checkpoint_path`路径下文件清空。
>注意:配置文件中的信息不会保存在断点中,重启前对配置文件的修改将会生效。
## 评估
如果在配置文件中设置了`checkpoint_path`,则每个epoch会保存一个压缩后的用于评估的模型,
该模型会保存在`${checkpoint_path}/${epoch_id}/eval_model/`路径下,包含`__model__``__params__`两个文件。
其中,`__model__`用于保存模型结构信息,`__params__`用于保存参数(parameters)信息。
如果不需要保存评估模型,可以在定义Compressor对象时,将`save_eval_model`选项设置为False(默认为True)。
运行命令为:
```
python ../eval.py \
--model_path ${checkpoint_path}/${epoch_id}/eval_model/ \
--model_name __model__ \
--params_name __params__ \
-c ../../configs/yolov3_mobilenet_v1_voc.yml \
-d "../../dataset/voc"
```
## 预测
如果在配置文件中设置了`checkpoint_path`,并且在定义Compressor对象时指定了`prune_infer_model`选项,则每个epoch都会
保存一个`inference model`。该模型是通过删除eval_program中多余的operators而得到的。
该模型会保存在`${checkpoint_path}/${epoch_id}/eval_model/`路径下,包含`__model__.infer``__params__`两个文件。
其中,`__model__.infer`用于保存模型结构信息,`__params__`用于保存参数(parameters)信息。
更多关于`prune_infer_model`选项的介绍,请参考:[Compressor介绍](https://github.com/PaddlePaddle/models/blob/develop/PaddleSlim/docs/usage.md#121-%E5%A6%82%E4%BD%95%E6%94%B9%E5%86%99%E6%99%AE%E9%80%9A%E8%AE%AD%E7%BB%83%E8%84%9A%E6%9C%AC)
### python预测
在脚本<a href="../infer.py">PaddleDetection/tools/infer.py</a>中展示了如何使用fluid python API加载使用预测模型进行预测。
运行命令为:
```
python ../infer.py \
--model_path ${checkpoint_path}/${epoch_id}/eval_model/ \
--model_name __model__.infer \
--params_name __params__ \
-c ../../configs/yolov3_mobilenet_v1_voc.yml \
--infer_dir ../../demo
```
### PaddleLite
该示例中产出的预测(inference)模型可以直接用PaddleLite进行加载使用。
关于PaddleLite如何使用,请参考:[PaddleLite使用文档](https://github.com/PaddlePaddle/Paddle-Lite/wiki#%E4%BD%BF%E7%94%A8)
## 示例结果
> 当前release的结果并非超参调优后的最好结果,仅做示例参考,后续我们会优化当前结果。
### MobileNetV1-YOLO-V3
| FLOPS |Box AP| model_size |Paddle Fluid inference time(ms)| Paddle Lite inference time(ms)|
|---|---|---|---|---|
|baseline|76.2 |93M |- |-|
|-50%|69.48 |51M |- |-|
## 5. 扩展模型
## FAQ
如果需要对自己的模型进行修改,可以参考`prune.py`中对`paddleslim.prune.Pruner`接口的调用方式,基于自己的模型训练脚本进行修改。
本节我们介绍的剪裁示例,需要用户根据先验知识指定每层的剪裁率,除此之外,PaddleSlim还提供了敏感度分析等功能,协助用户选择合适的剪裁率。更多详情请参考:[PaddleSlim使用文档](https://paddlepaddle.github.io/PaddleSlim/)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import time
import numpy as np
import datetime
from collections import deque
from paddleslim.prune import Pruner
from paddleslim.analysis import flops
from paddle import fluid
from ppdet.experimental import mixed_precision_context
from ppdet.core.workspace import load_config, merge_config, create
from ppdet.data.reader import create_reader
from ppdet.utils.cli import print_total_cfg
from ppdet.utils import dist_utils
from ppdet.utils.eval_utils import parse_fetches, eval_run, eval_results
from ppdet.utils.stats import TrainingStats
from ppdet.utils.cli import ArgsParser
from ppdet.utils.check import check_gpu, check_version
import ppdet.utils.checkpoint as checkpoint
from ppdet.modeling.model_input import create_feed
import logging
FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT)
logger = logging.getLogger(__name__)
def main():
env = os.environ
FLAGS.dist = 'PADDLE_TRAINER_ID' in env and 'PADDLE_TRAINERS_NUM' in env
if FLAGS.dist:
trainer_id = int(env['PADDLE_TRAINER_ID'])
import random
local_seed = (99 + trainer_id)
random.seed(local_seed)
np.random.seed(local_seed)
cfg = load_config(FLAGS.config)
if 'architecture' in cfg:
main_arch = cfg.architecture
else:
raise ValueError("'architecture' not specified in config file.")
merge_config(FLAGS.opt)
if 'log_iter' not in cfg:
cfg.log_iter = 20
# check if set use_gpu=True in paddlepaddle cpu version
check_gpu(cfg.use_gpu)
# check if paddlepaddle version is satisfied
check_version()
if not FLAGS.dist or trainer_id == 0:
print_total_cfg(cfg)
if cfg.use_gpu:
devices_num = fluid.core.get_cuda_device_count()
else:
devices_num = int(os.environ.get('CPU_NUM', 1))
if 'FLAGS_selected_gpus' in env:
device_id = int(env['FLAGS_selected_gpus'])
else:
device_id = 0
place = fluid.CUDAPlace(device_id) if cfg.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
lr_builder = create('LearningRate')
optim_builder = create('OptimizerBuilder')
# build program
startup_prog = fluid.Program()
train_prog = fluid.Program()
with fluid.program_guard(train_prog, startup_prog):
with fluid.unique_name.guard():
model = create(main_arch)
if FLAGS.fp16:
assert (getattr(model.backbone, 'norm_type', None)
!= 'affine_channel'), \
'--fp16 currently does not support affine channel, ' \
' please modify backbone settings to use batch norm'
with mixed_precision_context(FLAGS.loss_scale, FLAGS.fp16) as ctx:
inputs_def = cfg['TrainReader']['inputs_def']
feed_vars, train_loader = model.build_inputs(**inputs_def)
train_fetches = model.train(feed_vars)
loss = train_fetches['loss']
if FLAGS.fp16:
loss *= ctx.get_loss_scale_var()
lr = lr_builder()
optimizer = optim_builder(lr)
optimizer.minimize(loss)
if FLAGS.fp16:
loss /= ctx.get_loss_scale_var()
# parse train fetches
train_keys, train_values, _ = parse_fetches(train_fetches)
train_values.append(lr)
if FLAGS.print_params:
print("-------------------------All parameters in current graph----------------------")
for block in train_prog.blocks:
for param in block.all_parameters():
print("parameter name: {}\tshape: {}".format(param.name, param.shape))
print("------------------------------------------------------------------------------")
return
if FLAGS.eval:
eval_prog = fluid.Program()
with fluid.program_guard(eval_prog, startup_prog):
with fluid.unique_name.guard():
model = create(main_arch)
inputs_def = cfg['EvalReader']['inputs_def']
feed_vars, eval_loader = model.build_inputs(**inputs_def)
fetches = model.eval(feed_vars)
eval_prog = eval_prog.clone(True)
eval_reader = create_reader(cfg.EvalReader)
eval_loader.set_sample_list_generator(eval_reader, place)
# parse eval fetches
extra_keys = []
if cfg.metric == 'COCO':
extra_keys = ['im_info', 'im_id', 'im_shape']
if cfg.metric == 'VOC':
extra_keys = ['gt_box', 'gt_label', 'is_difficult']
if cfg.metric == 'WIDERFACE':
extra_keys = ['im_id', 'im_shape', 'gt_box']
eval_keys, eval_values, eval_cls = parse_fetches(fetches, eval_prog,
extra_keys)
# compile program for multi-devices
build_strategy = fluid.BuildStrategy()
build_strategy.fuse_all_optimizer_ops = False
build_strategy.fuse_elewise_add_act_ops = True
# only enable sync_bn in multi GPU devices
sync_bn = getattr(model.backbone, 'norm_type', None) == 'sync_bn'
build_strategy.sync_batch_norm = sync_bn and devices_num > 1 \
and cfg.use_gpu
exec_strategy = fluid.ExecutionStrategy()
# iteration number when CompiledProgram tries to drop local execution scopes.
# Set it to be 1 to save memory usages, so that unused variables in
# local execution scopes can be deleted after each iteration.
exec_strategy.num_iteration_per_drop_scope = 1
if FLAGS.dist:
dist_utils.prepare_for_multi_process(exe, build_strategy, startup_prog,
train_prog)
exec_strategy.num_threads = 1
exe.run(startup_prog)
fuse_bn = getattr(model.backbone, 'norm_type', None) == 'affine_channel'
start_iter = 0
if FLAGS.resume_checkpoint:
checkpoint.load_checkpoint(exe, train_prog, FLAGS.resume_checkpoint)
start_iter = checkpoint.global_step()
elif cfg.pretrain_weights:
checkpoint.load_params(
exe, train_prog, cfg.pretrain_weights)
pruned_params = FLAGS.pruned_params
assert (FLAGS.pruned_params is not None), "FLAGS.pruned_params is empty!!! Please set it by '--pruned_params' option."
pruned_params = FLAGS.pruned_params.strip().split(",")
logger.info("pruned params: {}".format(pruned_params))
pruned_ratios = [float(n) for n in FLAGS.pruned_ratios.strip().split(" ")]
logger.info("pruned ratios: {}".format(pruned_ratios))
assert(len(pruned_params) == len(pruned_ratios)), "The length of pruned params and pruned ratios should be equal."
assert(pruned_ratios > [0] * len(pruned_ratios) and pruned_ratios < [1] * len(pruned_ratios)), "The elements of pruned ratios should be in range (0, 1)."
pruner = Pruner()
train_prog = pruner.prune(
train_prog,
fluid.global_scope(),
params=pruned_params,
ratios=pruned_ratios,
place=place,
only_graph=False)[0]
compiled_train_prog = fluid.CompiledProgram(train_prog).with_data_parallel(
loss_name=loss.name,
build_strategy=build_strategy,
exec_strategy=exec_strategy)
if FLAGS.eval:
base_flops = flops(eval_prog)
eval_prog = pruner.prune(
eval_prog,
fluid.global_scope(),
params=pruned_params,
ratios=pruned_ratios,
place=place,
only_graph=True)[0]
pruned_flops = flops(eval_prog)
logger.info("FLOPs -{}; total FLOPs: {}; pruned FLOPs: {}".format(float(base_flops - pruned_flops)/base_flops, base_flops, pruned_flops))
compiled_eval_prog = fluid.compiler.CompiledProgram(eval_prog)
train_reader = create_reader(cfg.TrainReader, (cfg.max_iters - start_iter) *
devices_num, cfg)
train_loader.set_sample_list_generator(train_reader, place)
# whether output bbox is normalized in model output layer
is_bbox_normalized = False
if hasattr(model, 'is_bbox_normalized') and \
callable(model.is_bbox_normalized):
is_bbox_normalized = model.is_bbox_normalized()
# if map_type not set, use default 11point, only use in VOC eval
map_type = cfg.map_type if 'map_type' in cfg else '11point'
train_stats = TrainingStats(cfg.log_smooth_window, train_keys)
train_loader.start()
start_time = time.time()
end_time = time.time()
cfg_name = os.path.basename(FLAGS.config).split('.')[0]
save_dir = os.path.join(cfg.save_dir, cfg_name)
time_stat = deque(maxlen=cfg.log_smooth_window)
best_box_ap_list = [0.0, 0] #[map, iter]
# use tb-paddle to log data
if FLAGS.use_tb:
from tb_paddle import SummaryWriter
tb_writer = SummaryWriter(FLAGS.tb_log_dir)
tb_loss_step = 0
tb_mAP_step = 0
if FLAGS.eval:
# evaluation
results = eval_run(exe, compiled_eval_prog, eval_loader,
eval_keys, eval_values, eval_cls)
resolution = None
if 'mask' in results[0]:
resolution = model.mask_head.resolution
dataset = cfg['EvalReader']['dataset']
box_ap_stats = eval_results(
results,
cfg.metric,
cfg.num_classes,
resolution,
is_bbox_normalized,
FLAGS.output_eval,
map_type,
dataset=dataset)
for it in range(start_iter, cfg.max_iters):
start_time = end_time
end_time = time.time()
time_stat.append(end_time - start_time)
time_cost = np.mean(time_stat)
eta_sec = (cfg.max_iters - it) * time_cost
eta = str(datetime.timedelta(seconds=int(eta_sec)))
outs = exe.run(compiled_train_prog, fetch_list=train_values)
stats = {k: np.array(v).mean() for k, v in zip(train_keys, outs[:-1])}
# use tb-paddle to log loss
if FLAGS.use_tb:
if it % cfg.log_iter == 0:
for loss_name, loss_value in stats.items():
tb_writer.add_scalar(loss_name, loss_value, tb_loss_step)
tb_loss_step += 1
train_stats.update(stats)
logs = train_stats.log()
if it % cfg.log_iter == 0 and (not FLAGS.dist or trainer_id == 0):
strs = 'iter: {}, lr: {:.6f}, {}, time: {:.3f}, eta: {}'.format(
it, np.mean(outs[-1]), logs, time_cost, eta)
logger.info(strs)
if (it > 0 and it % cfg.snapshot_iter == 0 or it == cfg.max_iters - 1) \
and (not FLAGS.dist or trainer_id == 0):
save_name = str(it) if it != cfg.max_iters - 1 else "model_final"
checkpoint.save(exe, train_prog, os.path.join(save_dir, save_name))
if FLAGS.eval:
# evaluation
results = eval_run(exe, compiled_eval_prog, eval_loader,
eval_keys, eval_values, eval_cls)
resolution = None
if 'mask' in results[0]:
resolution = model.mask_head.resolution
box_ap_stats = eval_results(
results, eval_feed, cfg.metric, cfg.num_classes, resolution,
is_bbox_normalized, FLAGS.output_eval, map_type)
# use tb_paddle to log mAP
if FLAGS.use_tb:
tb_writer.add_scalar("mAP", box_ap_stats[0], tb_mAP_step)
tb_mAP_step += 1
if box_ap_stats[0] > best_box_ap_list[0]:
best_box_ap_list[0] = box_ap_stats[0]
best_box_ap_list[1] = it
checkpoint.save(exe, train_prog,
os.path.join(save_dir, "best_model"))
logger.info("Best test box ap: {}, in iter: {}".format(
best_box_ap_list[0], best_box_ap_list[1]))
train_loader.reset()
if __name__ == '__main__':
parser = ArgsParser()
parser.add_argument(
"-r",
"--resume_checkpoint",
default=None,
type=str,
help="Checkpoint path for resuming training.")
parser.add_argument(
"--fp16",
action='store_true',
default=False,
help="Enable mixed precision training.")
parser.add_argument(
"--loss_scale",
default=8.,
type=float,
help="Mixed precision training loss scale.")
parser.add_argument(
"--eval",
action='store_true',
default=False,
help="Whether to perform evaluation in train")
parser.add_argument(
"--output_eval",
default=None,
type=str,
help="Evaluation directory, default is current directory.")
parser.add_argument(
"--use_tb",
type=bool,
default=False,
help="whether to record the data to Tensorboard.")
parser.add_argument(
'--tb_log_dir',
type=str,
default="tb_log_dir/scalar",
help='Tensorboard logging directory for scalar.')
parser.add_argument(
"-p",
"--pruned_params",
default=None,
type=str,
help="The parameters to be pruned when calculating sensitivities.")
parser.add_argument(
"--pruned_ratios",
default="0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9",
type=str,
help="The ratios pruned iteratively for each parameter when calculating sensitivities.")
parser.add_argument(
"-P",
"--print_params",
default=False,
action='store_true',
help="Whether to only print the parameters' names and shapes.")
FLAGS = parser.parse_args()
main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册