Unverified commit 6f9fd567, authored by Guanghua Yu, committed by GitHub

add prune dygraph, test=dygraph (#2123)

* add prune dygraph, test=dygraph

* fix slim readme and prune
Parent 944eee5c
# Model Compression

PaddleDetection provides complete tutorials and benchmarks for model compression based on [PaddleSlim](https://github.com/PaddlePaddle/PaddleSlim). Currently supported methods:

- [Pruning](prune)

We recommend compressing detection models by combining pruning with distillation training, or by combining pruning with quantization. The experiments below use YOLOv3 as an example for pruning, distillation, and quantization.
## Benchmark
### Pruning
#### Benchmark on Pascal VOC

| Model | Compression Strategy | GFLOPs | Model Size (MB) | Input Size | Inference Latency (SD855) | Box AP | Download | Model Config | Slim Config |
| :----------------| :-------: | :------------: | :-------------: | :------: | :--------: | :------: | :-----------------------------------------------------: |:-------------: | :------: |
| YOLOv3-MobileNetV1 | baseline | 24.13 | 93 | 608 | 289.9ms | 75.1 | [Download](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/yolov3_mobilenet_v1_270e_voc.pdparams) | [Config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/yolov3/yolov3_mobilenet_v1_270e_voc.yml) | - |
| YOLOv3-MobileNetV1 | Pruning - l1_norm (sensitivity) | 15.78 (-34.49%) | 66 (-29%) | 608 | - | 77.6 (+2.5) | [Download](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/slim/yolov3_mobilenet_v1_voc_prune_l1_norm.pdparams) | [Config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/yolov3/yolov3_mobilenet_v1_270e_voc.yml) | [Slim config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/slim/prune/yolov3_prune_l1_norm.yml) |

- The SD855 inference latency is measured with PaddleLite deployment on the armv8 architecture using 4 threads.
## Environment
- Python 3.7+
- PaddlePaddle >= 2.0.0
- PaddleSlim >= 2.0.0
- CUDA 9.0+
- cuDNN >= 7.5
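
The dependencies above can be installed with pip. This is only a minimal sketch: the exact PaddlePaddle package (GPU vs. CPU build, CUDA/cuDNN version) depends on your environment, so follow the official installation guide for your platform.

```shell
# Minimal sketch, assuming pip installation; pick the PaddlePaddle build that matches your CUDA/cuDNN setup
pip install "paddlepaddle-gpu>=2.0.0"
pip install "paddleslim>=2.0.0"
```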
## Quick Start
### Training
```shell
python tools/train.py -c configs/{MODEL.yml} --slim_config configs/slim/{SLIM_CONFIG.yml}
```
- `-c`: specify the model config file.
- `--slim_config`: specify the compression strategy config file.
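
For example, to prune the YOLOv3-MobileNetV1 VOC model from the benchmark above with the l1_norm strategy (paths are relative to the `dygraph` directory):

```shell
# Prune YOLOv3-MobileNetV1 (VOC) with the l1_norm slim config from the benchmark table
python tools/train.py -c configs/yolov3/yolov3_mobilenet_v1_270e_voc.yml \
       --slim_config configs/slim/prune/yolov3_prune_l1_norm.yml
```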
### Evaluation
```shell
python tools/eval.py -c configs/{MODEL.yml} --slim_config configs/slim/{SLIM_CONFIG.yml} -o weights=output/{SLIM_CONFIG}/model_final
```
- `-c`: specify the model config file.
- `--slim_config`: specify the compression strategy config file.
- `-o weights`: specify the path of the model weights trained with the compression algorithm.
### Inference
```shell
python tools/infer.py -c configs/{MODEL.yml} --slim_config configs/slim/{SLIM_CONFIG.yml} \
       -o weights=output/{SLIM_CONFIG}/model_final \
       --infer_img={IMAGE_PATH}
```
- `-c`: specify the model config file.
- `--slim_config`: specify the compression strategy config file.
- `-o weights`: specify the path of the model weights trained with the compression algorithm.
- `--infer_img`: specify the path of the test image.
### Export Model (Dynamic-to-Static)
```shell
python tools/export_model.py -c configs/{MODEL.yml} --slim_config configs/slim/{SLIM_CONFIG.yml} -o weights=output/{SLIM_CONFIG}/model_final
```
- `-c`: specify the model config file.
- `--slim_config`: specify the compression strategy config file.
- `-o weights`: specify the path of the model weights trained with the compression algorithm.
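
As a concrete sketch, exporting the pruned YOLOv3-MobileNetV1 VOC model could look like the following; the weights path assumes the default `output/{SLIM_CONFIG}` layout described above, and the exported files are written to `output_inference` by default (the `--output_dir` option of `tools/export_model.py`).

```shell
# Export the pruned model for deployment; the weights path assumes the default output layout
python tools/export_model.py -c configs/yolov3/yolov3_mobilenet_v1_270e_voc.yml \
       --slim_config configs/slim/prune/yolov3_prune_l1_norm.yml \
       -o weights=output/yolov3_prune_l1_norm/model_final
```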
# Weights of yolov3_mobilenet_v1_voc
pretrain_weights: https://paddlemodels.bj.bcebos.com/object_detection/dygraph/yolov3_mobilenet_v1_270e_voc.pdparams
load_static_weights: False
weight_type: resume
slim: Pruner
Pruner:
  criterion: fpgm
  pruned_params: ['yolo_block.0.0.0.conv.weights', 'yolo_block.0.0.1.conv.weights', 'yolo_block.0.1.0.conv.weights',
                  'yolo_block.0.1.1.conv.weights', 'yolo_block.0.2.conv.weights', 'yolo_block.0.tip.conv.weights',
                  'yolo_block.1.0.0.conv.weights', 'yolo_block.1.0.1.conv.weights', 'yolo_block.1.1.0.conv.weights',
                  'yolo_block.1.1.1.conv.weights', 'yolo_block.1.2.conv.weights', 'yolo_block.1.tip.conv.weights',
                  'yolo_block.2.0.0.conv.weights', 'yolo_block.2.0.1.conv.weights', 'yolo_block.2.1.0.conv.weights',
                  'yolo_block.2.1.1.conv.weights', 'yolo_block.2.2.conv.weights', 'yolo_block.2.tip.conv.weights']
  pruned_ratios: [0.1,0.2,0.2,0.2,0.2,0.1,0.2,0.3,0.3,0.3,0.2,0.1,0.3,0.4,0.4,0.4,0.4,0.3]
  print_params: False
# Weights of yolov3_mobilenet_v1_voc
pretrain_weights: https://paddlemodels.bj.bcebos.com/object_detection/dygraph/yolov3_mobilenet_v1_270e_voc.pdparams
load_static_weights: False
weight_type: resume
slim: Pruner
Pruner:
  criterion: l1_norm
  pruned_params: ['yolo_block.0.0.0.conv.weights', 'yolo_block.0.0.1.conv.weights', 'yolo_block.0.1.0.conv.weights',
                  'yolo_block.0.1.1.conv.weights', 'yolo_block.0.2.conv.weights', 'yolo_block.0.tip.conv.weights',
                  'yolo_block.1.0.0.conv.weights', 'yolo_block.1.0.1.conv.weights', 'yolo_block.1.1.0.conv.weights',
                  'yolo_block.1.1.1.conv.weights', 'yolo_block.1.2.conv.weights', 'yolo_block.1.tip.conv.weights',
                  'yolo_block.2.0.0.conv.weights', 'yolo_block.2.0.1.conv.weights', 'yolo_block.2.1.0.conv.weights',
                  'yolo_block.2.1.1.conv.weights', 'yolo_block.2.2.conv.weights', 'yolo_block.2.tip.conv.weights']
  pruned_ratios: [0.1,0.2,0.2,0.2,0.2,0.1,0.2,0.3,0.3,0.3,0.2,0.1,0.3,0.4,0.4,0.4,0.4,0.3]
  print_params: False
@@ -13,4 +13,4 @@
# limitations under the License.
from . import (core, data, engine, modeling, model_zoo, optimizer, metrics,
               py_op, utils)
               py_op, utils, slim)
@@ -48,9 +48,18 @@ class Trainer(object):
        assert mode.lower() in ['train', 'eval', 'test'], \
                "mode should be 'train', 'eval' or 'test'"
        self.mode = mode.lower()
        self.optimizer = None
        # build model
        self.model = create(cfg.architecture)
        # model slim build
        if cfg.slim:
            if self.mode == 'train':
                self.load_weights(cfg.pretrain_weights, cfg.weight_type)
            slim = create(cfg.slim)
            slim(self.model)
        if ParallelEnv().nranks > 1:
            self.model = paddle.DataParallel(self.model)
@@ -62,7 +71,6 @@ class Trainer(object):
            self.dataset, cfg.worker_num)
        # build optimizer in train mode
        self.optimizer = None
        if self.mode == 'train':
            steps_per_epoch = len(self.loader)
            self.lr = create('LearningRate')(steps_per_epoch)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import prune
from .prune import *
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
from paddle.utils import try_import
from ppdet.core.workspace import register, serializable
from ppdet.utils.logger import setup_logger
logger = setup_logger(__name__)
def print_prune_params(model):
    model_dict = model.state_dict()
    for key in model_dict.keys():
        weight_name = model_dict[key].name
        logger.info('Parameter name: {}, shape: {}'.format(
            weight_name, model_dict[key].shape))


@register
@serializable
class Pruner(object):
    def __init__(self,
                 criterion,
                 pruned_params,
                 pruned_ratios,
                 print_params=False):
        super(Pruner, self).__init__()
        assert criterion in ['l1_norm', 'fpgm'], \
            "unsupported prune criterion: {}".format(criterion)
        self.criterion = criterion
        self.pruned_params = pruned_params
        self.pruned_ratios = pruned_ratios
        self.print_params = print_params

    def __call__(self, model):
        paddleslim = try_import('paddleslim')
        from paddleslim.analysis import dygraph_flops as flops
        input_spec = [{
            "image": paddle.ones(
                shape=[1, 3, 640, 640], dtype='float32'),
            "im_shape": paddle.full(
                [1, 2], 640, dtype='float32'),
            "scale_factor": paddle.ones(
                shape=[1, 2], dtype='float32')
        }]
        if self.print_params:
            print_prune_params(model)

        ori_flops = flops(model, input_spec) / 1000
        logger.info("FLOPs before pruning: {}GFLOPs".format(ori_flops))
        if self.criterion == 'fpgm':
            pruner = paddleslim.dygraph.FPGMFilterPruner(model, input_spec)
        elif self.criterion == 'l1_norm':
            pruner = paddleslim.dygraph.L1NormFilterPruner(model, input_spec)

        logger.info("pruned params: {}".format(self.pruned_params))
        pruned_ratios = [float(n) for n in self.pruned_ratios]
        ratios = {}
        for i, param in enumerate(self.pruned_params):
            ratios[param] = pruned_ratios[i]
        pruner.prune_vars(ratios, [0])

        pruned_flops = flops(model, input_spec) / 1000
        logger.info("FLOPs after pruning: {}GFLOPs; pruned ratio: {}".format(
            pruned_flops, (ori_flops - pruned_flops) / ori_flops))

        return model
@@ -49,6 +49,12 @@ def parse_args():
    parser.add_argument(
        '--json_eval', action='store_true', default=False, help='')
    parser.add_argument(
        "--slim_config",
        default=None,
        type=str,
        help="Configuration file of slim method.")
    parser.add_argument(
        '--use_gpu', action='store_true', default=False, help='')
@@ -72,6 +78,9 @@ def main():
    cfg = load_config(FLAGS.config)
    merge_config(FLAGS.opt)
    if FLAGS.slim_config:
        slim_cfg = load_config(FLAGS.slim_config)
        merge_config(slim_cfg)
    check_config(cfg)
    check_gpu(cfg.use_gpu)
    check_version()
@@ -43,6 +43,11 @@ def parse_args():
        type=str,
        default="output_inference",
        help="Directory for storing the output model files.")
    parser.add_argument(
        "--slim_config",
        default=None,
        type=str,
        help="Configuration file of slim method.")
    args = parser.parse_args()
    return args
@@ -67,6 +72,9 @@ def main():
    if 'norm_type' in cfg and cfg['norm_type'] == 'sync_bn':
        FLAGS.opt['norm_type'] = 'bn'
    merge_config(FLAGS.opt)
    if FLAGS.slim_config:
        slim_cfg = load_config(FLAGS.slim_config)
        merge_config(slim_cfg)
    check_config(cfg)
    check_gpu(cfg.use_gpu)
    check_version()
@@ -59,6 +59,11 @@ def parse_args():
        type=float,
        default=0.5,
        help="Threshold to reserve the result for visualization.")
    parser.add_argument(
        "--slim_config",
        default=None,
        type=str,
        help="Configuration file of slim method.")
    parser.add_argument(
        "--use_vdl",
        type=bool,
@@ -126,6 +131,9 @@ def main():
    cfg = load_config(FLAGS.config)
    merge_config(FLAGS.opt)
    if FLAGS.slim_config:
        slim_cfg = load_config(FLAGS.slim_config)
        merge_config(slim_cfg)
    check_config(cfg)
    check_gpu(cfg.use_gpu)
    check_version()
@@ -65,10 +65,10 @@ def parse_args():
        default=False,
        help="Whether to perform evaluation in train")
    parser.add_argument(
        "--output_eval",
        "--slim_config",
        default=None,
        type=str,
        help="Evaluation directory, default is current directory.")
        help="Configuration file of slim method.")
    parser.add_argument(
        "--enable_ce",
        type=bool,
@@ -92,7 +92,8 @@ def run(FLAGS, cfg):
    trainer = Trainer(cfg, mode='train')
    # load weights
    trainer.load_weights(cfg.pretrain_weights, FLAGS.weight_type)
    if not FLAGS.slim_config:
        trainer.load_weights(cfg.pretrain_weights, FLAGS.weight_type)
    # training
    trainer.train()
@@ -103,6 +104,11 @@ def main():
    cfg = load_config(FLAGS.config)
    merge_config(FLAGS.opt)
    if FLAGS.slim_config:
        slim_cfg = load_config(FLAGS.slim_config)
        merge_config(slim_cfg)
    if 'weight_type' not in cfg:
        cfg.weight_type = FLAGS.weight_type
    check.check_config(cfg)
    check.check_gpu(cfg.use_gpu)
    check.check_version()