From 10f2931e02861c203bcaa629e622a9226df00e77 Mon Sep 17 00:00:00 2001 From: whs Date: Thu, 14 May 2020 21:05:27 +0800 Subject: [PATCH] Add scripts demo in README of pruning (#672) * Add scripts demo in README of pruning --- slim/prune/README.md | 23 +++++++++++++++++++++++ slim/prune/prune.py | 15 +++++++++++---- 2 files changed, 34 insertions(+), 4 deletions(-) diff --git a/slim/prune/README.md b/slim/prune/README.md index 227b87a75..d000c329b 100644 --- a/slim/prune/README.md +++ b/slim/prune/README.md @@ -90,3 +90,26 @@ python export_model.py \ 如果需要对自己的模型进行修改,可以参考`prune.py`中对`paddleslim.prune.Pruner`接口的调用方式,基于自己的模型训练脚本进行修改。 本节我们介绍的剪裁示例,需要用户根据先验知识指定每层的剪裁率,除此之外,PaddleSlim还提供了敏感度分析等功能,协助用户选择合适的剪裁率。更多详情请参考:[PaddleSlim使用文档](https://paddlepaddle.github.io/PaddleSlim/) + +## 9. 更多示例与注意事项 + +## 9.1 faster_rcnn与mask_rcnn + +**当前PaddleSlim的剪裁功能不支持剪裁循环体或条件判断语句块内的卷积层,请避免剪裁循环和判断语句块前的一个卷积和语句块内部的卷积。** + +对于[faster_rcnn_r50](../../configs/faster_rcnn_r50_1x.yml)或[mask_rcnn_r50](../../configs/mask_rcnn_r50_1x.yml)网络,请剪裁卷积`res4f_branch2c`之前的卷积。 + +对[faster_rcnn_r50](../../configs/faster_rcnn_r50_1x.yml)剪裁示例如下: + +``` +# demo for faster_rcnn_r50 +python prune.py -c ../../configs/faster_rcnn_r50_1x.yml --pruned_params "res4f_branch2b_weights,res4f_branch2a_weights" --pruned_ratios="0.3,0.4" --eval +``` + +对[mask_rcnn_r50](../../configs/mask_rcnn_r50_1x.yml)剪裁示例如下: + +``` +# demo for mask_rcnn_r50 +python prune.py -c ../../configs/mask_rcnn_r50_1x.yml --pruned_params "res4f_branch2b_weights,res4f_branch2a_weights" --pruned_ratios="0.2,0.3" --eval + +``` diff --git a/slim/prune/prune.py b/slim/prune/prune.py index c907df9cd..6f6dcc4ca 100644 --- a/slim/prune/prune.py +++ b/slim/prune/prune.py @@ -252,12 +252,19 @@ def main(): tb_mAP_step = 0 if FLAGS.eval: - # evaluation - results = eval_run(exe, compiled_eval_prog, eval_loader, eval_keys, - eval_values, eval_cls, cfg) resolution = None - if 'mask' in results[0]: + if 'Mask' in cfg.architecture: resolution = model.mask_head.resolution + # evaluation + results = eval_run( + exe, + compiled_eval_prog, + eval_loader, + eval_keys, + eval_values, + eval_cls, + cfg, + resolution=resolution) dataset = cfg['EvalReader']['dataset'] box_ap_stats = eval_results( results, -- GitLab