From 366eb59c9ff75d8cab093f0449ead860e0b75649 Mon Sep 17 00:00:00 2001
From: Kaipeng Deng
Date: Mon, 23 Mar 2020 11:37:54 +0800
Subject: [PATCH] add prune export_model (#378)

---
 .../extensions/distill_pruned_model/README.md |  12 ++
 slim/prune/README.md                          |  14 +-
 slim/prune/eval.py                            |   2 +-
 slim/prune/export_model.py                    | 157 ++++++++++++++++++
 4 files changed, 183 insertions(+), 2 deletions(-)
 create mode 100644 slim/prune/export_model.py

diff --git a/slim/extensions/distill_pruned_model/README.md b/slim/extensions/distill_pruned_model/README.md
index ea096b0ee..e82129104 100644
--- a/slim/extensions/distill_pruned_model/README.md
+++ b/slim/extensions/distill_pruned_model/README.md
@@ -67,3 +67,15 @@ python ../../prune/eval.py \
 --pruned_ratios="0.2,0.3,0.4" \
 -o weights=output/yolov3_mobilenet_v1_voc/model_final
 ```
+
+## 6. Model Export
+
+To deploy the pruned model with the C++ inference library or a Serving service, export it with `../../prune/export_model.py`:
+
+```
+python ../../prune/export_model.py \
+-c ../../../configs/yolov3_mobilenet_v1_voc.yml \
+--pruned_params "yolo_block.0.0.0.conv.weights,yolo_block.0.0.1.conv.weights,yolo_block.0.1.0.conv.weights" \
+--pruned_ratios="0.2,0.3,0.4" \
+-o weights=output/yolov3_mobilenet_v1_voc/model_final
+```
diff --git a/slim/prune/README.md b/slim/prune/README.md
index ef354dbf3..227b87a75 100644
--- a/slim/prune/README.md
+++ b/slim/prune/README.md
@@ -74,7 +74,19 @@ python eval.py \
 -o weights=output/yolov3_mobilenet_v1_voc/model_final
 ```
 
-## 7. Extending to Your Own Model
+## 7. Model Export
+
+To deploy the pruned model with the C++ inference library or a Serving service, export it with `export_model.py`:
+
+```
+python export_model.py \
+-c ../../configs/yolov3_mobilenet_v1_voc.yml \
+--pruned_params "yolo_block.0.0.0.conv.weights,yolo_block.0.0.1.conv.weights,yolo_block.0.1.0.conv.weights" \
+--pruned_ratios="0.2,0.3,0.4" \
+-o weights=output/yolov3_mobilenet_v1_voc/model_final
+```
+
+## 8. Extending to Your Own Model
 
 To prune your own model, refer to how `prune.py` calls the `paddleslim.prune.Pruner` API and adapt your own training script accordingly.
 The pruning example in this section requires you to set the pruning ratio of each layer from prior knowledge. Beyond that, PaddleSlim also provides tools such as sensitivity analysis to help choose suitable pruning ratios. For details, see the [PaddleSlim documentation](https://paddlepaddle.github.io/PaddleSlim/).
diff --git a/slim/prune/eval.py b/slim/prune/eval.py
index 7d421a277..db74a2174 100644
--- a/slim/prune/eval.py
+++ b/slim/prune/eval.py
@@ -176,7 +176,7 @@ def main():
     # load model
     exe.run(startup_prog)
     if 'weights' in cfg:
-        checkpoint.load_params(exe, eval_prog, cfg.weights)
+        checkpoint.load_checkpoint(exe, eval_prog, cfg.weights)
     results = eval_run(exe, compile_program, loader, keys, values, cls,
                        cfg, sub_eval_prog, sub_keys, sub_values)
diff --git a/slim/prune/export_model.py b/slim/prune/export_model.py
new file mode 100644
index 000000000..bb5863bc4
--- /dev/null
+++ b/slim/prune/export_model.py
@@ -0,0 +1,157 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
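+
+# Export a channel-pruned detection model as a Paddle inference model
+# (a "__model__" program plus a single "__params__" weights file) so it
+# can be deployed with the C++ inference library or a Serving service.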
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from paddle import fluid
+
+from ppdet.core.workspace import load_config, merge_config, create
+from ppdet.utils.cli import ArgsParser
+import ppdet.utils.checkpoint as checkpoint
+from paddleslim.prune import Pruner
+from paddleslim.analysis import flops
+
+import logging
+FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
+logging.basicConfig(level=logging.INFO, format=FORMAT)
+logger = logging.getLogger(__name__)
+
+
+def prune_feed_vars(feeded_var_names, target_vars, prog):
+    """
+    Filter out feed variables that are not used in the program.
+    Such variables are only needed for post-processing of the model
+    output (e.g. im_id to identify image order, im_shape to clip
+    bboxes to the image) and must not be fed to the inference program.
+    """
+    exist_var_names = []
+    prog = prog.clone()
+    prog = prog._prune(targets=target_vars)
+    global_block = prog.global_block()
+    for name in feeded_var_names:
+        try:
+            v = global_block.var(name)
+            exist_var_names.append(str(v.name))
+        except Exception:
+            logger.info('save_inference_model pruned unused feed '
+                        'variable {}'.format(name))
+    return exist_var_names
+
+
+def save_infer_model(FLAGS, exe, feed_vars, test_fetches, infer_prog):
+    cfg_name = os.path.basename(FLAGS.config).split('.')[0]
+    save_dir = os.path.join(FLAGS.output_dir, cfg_name)
+    feed_var_names = [var.name for var in feed_vars.values()]
+    target_vars = list(test_fetches.values())
+    feed_var_names = prune_feed_vars(feed_var_names, target_vars, infer_prog)
+    logger.info("Export inference model to {}, input: {}, output: "
+                "{}...".format(save_dir, feed_var_names,
+                               [str(var.name) for var in target_vars]))
+    # Save all weights into a single "__params__" file for easier deployment.
+    fluid.io.save_inference_model(
+        save_dir,
+        feeded_var_names=feed_var_names,
+        target_vars=target_vars,
+        executor=exe,
+        main_program=infer_prog,
+        params_filename="__params__")
+
+
+def main():
+    cfg = load_config(FLAGS.config)
+
+    if 'architecture' in cfg:
+        main_arch = cfg.architecture
+    else:
+        raise ValueError("'architecture' not specified in config file.")
+
+    merge_config(FLAGS.opt)
+
+    # Use CPU for exporting the inference model instead of GPU.
+    place = fluid.CPUPlace()
+    exe = fluid.Executor(place)
+
+    model = create(main_arch)
+
+    startup_prog = fluid.Program()
+    infer_prog = fluid.Program()
+    with fluid.program_guard(infer_prog, startup_prog):
+        with fluid.unique_name.guard():
+            inputs_def = cfg['TestReader']['inputs_def']
+            inputs_def['use_dataloader'] = False
+            feed_vars, _ = model.build_inputs(**inputs_def)
+            test_fetches = model.test(feed_vars)
+    infer_prog = infer_prog.clone(True)
+
+    assert FLAGS.pruned_params is not None, \
+        "FLAGS.pruned_params is empty! Please set it with the '--pruned_params' option."
+    pruned_params = FLAGS.pruned_params.strip().split(",")
+    logger.info("pruned params: {}".format(pruned_params))
+    pruned_ratios = [float(n) for n in FLAGS.pruned_ratios.strip().split(",")]
+    logger.info("pruned ratios: {}".format(pruned_ratios))
+    assert len(pruned_params) == len(pruned_ratios), \
+        "The length of pruned params and pruned ratios should be equal."
+    # Check each ratio element-wise; comparing the lists directly (as in
+    # `pruned_ratios > [0] * len(pruned_ratios)`) would be lexicographic
+    # and could let out-of-range elements slip through.
+    assert all(0 < r < 1 for r in pruned_ratios), \
+        "The elements of pruned ratios should be in range (0, 1)."
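+
+    # Prune only the graph structure here (only_graph=True): the actual
+    # pruned weights come from the trained checkpoint loaded below, so
+    # the program's parameter shapes just need to match that checkpoint.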
+    base_flops = flops(infer_prog)
+    pruner = Pruner()
+    infer_prog, _, _ = pruner.prune(
+        infer_prog,
+        fluid.global_scope(),
+        params=pruned_params,
+        ratios=pruned_ratios,
+        place=place,
+        only_graph=True)
+    pruned_flops = flops(infer_prog)
+    logger.info("FLOPs reduction ratio from pruning: {}".format(
+        float(base_flops - pruned_flops) / base_flops))
+
+    exe.run(startup_prog)
+    checkpoint.load_checkpoint(exe, infer_prog, cfg.weights)
+
+    save_infer_model(FLAGS, exe, feed_vars, test_fetches, infer_prog)
+
+
+if __name__ == '__main__':
+    parser = ArgsParser()
+    parser.add_argument(
+        "--output_dir",
+        type=str,
+        default="output",
+        help="Directory for storing the output model files.")
+
+    parser.add_argument(
+        "-p",
+        "--pruned_params",
+        default=None,
+        type=str,
+        help="Comma-separated list of the parameters to be pruned.")
+    parser.add_argument(
+        "--pruned_ratios",
+        default=None,
+        type=str,
+        help="Comma-separated pruning ratio for each parameter in --pruned_params."
+    )
+
+    FLAGS = parser.parse_args()
+    main()
-- 
GitLab
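
As a quick sanity check, the exported model can be loaded back with the Paddle 1.x inference API. Below is a minimal sketch, assuming the default `--output_dir` (so the export landed in `output/yolov3_mobilenet_v1_voc`); the feed names and shapes are illustrative assumptions and should be taken from the "Export inference model ..., input: ..., output: ..." log line printed by `export_model.py`:

```
import numpy as np
from paddle import fluid

place = fluid.CPUPlace()
exe = fluid.Executor(place)

# export_model.py saved all weights into a single "__params__" file.
save_dir = "output/yolov3_mobilenet_v1_voc"
infer_prog, feed_names, fetch_targets = fluid.io.load_inference_model(
    save_dir, exe, params_filename="__params__")

# Dummy inputs; for YOLOv3 the feeds are typically 'image' and 'im_size',
# but the names and shapes here are assumptions, check the export log.
feed = {
    "image": np.random.rand(1, 3, 608, 608).astype("float32"),
    "im_size": np.array([[608, 608]], dtype="int32"),
}
outs = exe.run(infer_prog,
               feed={name: feed[name] for name in feed_names},
               fetch_list=fetch_targets,
               return_numpy=False)
print("fetched {} output tensor(s)".format(len(outs)))
```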