From 7e4da85870a29fe8fe5088cac8f490df0946daeb Mon Sep 17 00:00:00 2001
From: Liufang Sang
Date: Wed, 9 Oct 2019 20:42:08 +0800
Subject: [PATCH] [PaddleSlim]Yolov3 quantization demo (#3440)

---
 slim/quantization/README.md                   | 130 +++++++++
 slim/quantization/compress.py                 | 267 ++++++++++++++++++
 slim/quantization/eval.py                     | 184 ++++++++++++
 slim/quantization/freeze.py                   | 243 ++++++++++++++++
 .../yolov3_mobilenet_v1_slim.yaml             |  20 ++
 5 files changed, 844 insertions(+)
 create mode 100644 slim/quantization/README.md
 create mode 100644 slim/quantization/compress.py
 create mode 100644 slim/quantization/eval.py
 create mode 100644 slim/quantization/freeze.py
 create mode 100644 slim/quantization/yolov3_mobilenet_v1_slim.yaml

diff --git a/slim/quantization/README.md b/slim/quantization/README.md
new file mode 100644
index 000000000..5c53d2758
--- /dev/null
+++ b/slim/quantization/README.md
@@ -0,0 +1,130 @@
+>运行该示例前请安装Paddle1.6或更高版本
+
+# 检测模型量化压缩示例
+
+## 概述
+
+该示例使用PaddleSlim提供的[量化压缩策略](https://github.com/PaddlePaddle/models/blob/develop/PaddleSlim/docs/tutorial.md#1-quantization-aware-training%E9%87%8F%E5%8C%96%E4%BB%8B%E7%BB%8D)对检测模型进行压缩。
+在阅读该示例前,建议您先了解以下内容:
+
+- [检测模型的常规训练方法](https://github.com/PaddlePaddle/models/tree/develop/PaddleCV/PaddleDetection)
+- [PaddleSlim使用文档](https://github.com/PaddlePaddle/models/blob/develop/PaddleSlim/docs/usage.md)
+
+
+## 配置文件说明
+
+关于配置文件如何编写,您可以参考:
+
+- [PaddleSlim配置文件编写说明](https://github.com/PaddlePaddle/models/blob/develop/PaddleSlim/docs/usage.md#122-%E9%85%8D%E7%BD%AE%E6%96%87%E4%BB%B6%E7%9A%84%E4%BD%BF%E7%94%A8)
+- [量化策略配置文件编写说明](https://github.com/PaddlePaddle/models/blob/develop/PaddleSlim/docs/usage.md#21-%E9%87%8F%E5%8C%96%E8%AE%AD%E7%BB%83)
+
+其中`save_out_nodes`需要填写检测结果对应Variable的名称,下面介绍如何确定save_out_nodes参数的取值。
+以MobileNet V1为例,可在compress.py中构建好网络之后,直接打印Variable得到其名称信息。
+代码示例:
+```
+    eval_keys, eval_values, eval_cls = parse_fetches(fetches, eval_prog,
+                                                     extra_keys)
+    # print(eval_values)
+```
+根据运行结果可看到检测结果Variable的名字为:`multiclass_nms_0.tmp_0`。
+
+## 训练
+
+根据 [PaddleCV/PaddleDetection/tools/train.py](https://github.com/PaddlePaddle/models/blob/develop/PaddleCV/PaddleDetection/tools/train.py) 编写压缩脚本compress.py。
+在该脚本中定义了Compressor对象,用于执行压缩任务。
+
+通过`python compress.py --help`查看可配置参数,简述如下:
+
+- config: 检测库的配置,其中配置了训练超参数、数据集信息等。
+- slim_file: PaddleSlim的配置文件,参见[配置文件说明](#配置文件说明)。
+
+您可以通过运行脚本`run.sh`运行该示例,请确保已正确下载[pretrained model](https://github.com/PaddlePaddle/models/tree/develop/PaddleCV/image_classification#%E5%B7%B2%E5%8F%91%E5%B8%83%E6%A8%A1%E5%9E%8B%E5%8F%8A%E5%85%B6%E6%80%A7%E8%83%BD)。
+
+### 训练时的模型结构
+这部分介绍来源于[量化low-level API介绍](https://github.com/PaddlePaddle/models/tree/develop/PaddleSlim/quant_low_level_api#1-%E9%87%8F%E5%8C%96%E8%AE%AD%E7%BB%83low-level-apis%E4%BB%8B%E7%BB%8D)。
+
+PaddlePaddle框架中有四个和量化相关的IrPass,分别是QuantizationTransformPass、QuantizationFreezePass、ConvertToInt8Pass以及TransformForMobilePass。在训练时,对网络应用了QuantizationTransformPass,作用是在网络中的conv2d、depthwise_conv2d、mul等算子的各个输入前插入连续的量化op和反量化op,并改变相应反向算子的某些输入。示例图如下:
+
+图1:应用QuantizationTransformPass后的结果
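+
+下面给出一个应用QuantizationTransformPass的最小示意(假设`train_prog`、`place`为compress.py中已构建好的训练Program和设备;实际训练中这一步由Compressor根据量化配置自动完成,无需手动调用):
+
+```
+from paddle import fluid
+from paddle.fluid import core
+from paddle.fluid.framework import IrGraph
+from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass
+
+# 将Program转为IrGraph后插入fake_quantize/fake_dequantize op
+graph = IrGraph(core.Graph(train_prog.desc), for_test=False)
+transform_pass = QuantizationTransformPass(
+    scope=fluid.global_scope(),
+    place=place,
+    weight_bits=8,
+    activation_bits=8,
+    weight_quantize_type='abs_max',
+    activation_quantize_type='moving_average_abs_max')
+transform_pass.apply(graph)
+quant_train_prog = graph.to_program()
+```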
+
+### 保存断点(checkpoint)
+
+如果在配置文件中设置了`checkpoint_path`,则在压缩任务执行过程中会自动保存断点,当任务异常中断时,
+重启任务会自动从`checkpoint_path`路径下按数字顺序加载最新的checkpoint文件。如果不想让重启的任务从断点恢复,
+需要修改配置文件中的`checkpoint_path`,或者将`checkpoint_path`路径下的文件清空。
+
+>注意:配置文件中的信息不会保存在断点中,重启前对配置文件的修改将会生效。
+
+
+## 评估
+
+如果在配置文件中设置了`checkpoint_path`,则每个epoch会保存一个量化后的用于评估的模型,
+该模型会保存在`${checkpoint_path}/${epoch_id}/eval_model/`路径下,包含`__model__`和`__params__`两个文件。
+其中,`__model__`用于保存模型结构信息,`__params__`用于保存参数(parameters)信息。模型结构和训练时一样。
+
+如果不需要保存评估模型,可以在定义Compressor对象时,将`save_eval_model`选项设置为False(默认为True)。
+
+脚本slim/quantization/eval.py给出了使用该模型在评估数据集上做评估的示例。
+
+
+## 预测
+
+如果在配置文件的量化策略中设置了`float_model_save_path`、`int8_model_save_path`、`mobile_model_save_path`,则在训练结束后,会保存经量化压缩之后用于预测的模型。接下来介绍这三种预测模型的区别。
+
+### float预测模型
+在介绍训练时的模型结构时已提到,PaddlePaddle框架中有四个和量化相关的IrPass,分别是QuantizationTransformPass、QuantizationFreezePass、ConvertToInt8Pass以及TransformForMobilePass。float预测模型是在应用QuantizationFreezePass并删除eval_program中多余的operators之后保存的模型。
+
+QuantizationFreezePass主要用于改变IrGraph中量化op和反量化op的顺序,即将类似图1中的量化op和反量化op顺序改变为图2中的布局。除此之外,QuantizationFreezePass还会将`conv2d`、`depthwise_conv2d`、`mul`等算子的权重离线量化为int8_t范围内的值(但数据类型仍为float32),以减少预测过程中对权重的量化操作,示例如图2:
+
+图2:应用QuantizationFreezePass后的结果
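+
+本示例的freeze.py即按此流程导出float预测模型,核心代码节选如下(其中`infer_prog`为从评估模型目录加载的Program,`place`、`exe`、`feed_names`、`fetch_targets`均在freeze.py中构建,`FLAGS.weight_quant_type`需与训练配置中的`weight_quantize_type`保持一致):
+
+```
+test_graph = IrGraph(core.Graph(infer_prog.desc), for_test=True)
+
+freeze_pass = QuantizationFreezePass(
+    scope=fluid.global_scope(),
+    place=place,
+    weight_quantize_type=FLAGS.weight_quant_type)
+freeze_pass.apply(test_graph)
+server_program = test_graph.to_program()
+fluid.io.save_inference_model(
+    dirname=os.path.join(FLAGS.save_path, 'float'),
+    feeded_var_names=feed_names,
+    target_vars=fetch_targets,
+    executor=exe,
+    main_program=server_program,
+    model_filename='model',
+    params_filename='params')
+```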
+
+### int8预测模型
+在对训练网络应用QuantizationFreezePass之后,执行ConvertToInt8Pass,其主要目的是将执行完QuantizationFreezePass后输出的权重类型由`FP32`更改为`INT8`。换言之,用户可以选择将量化后的权重保存为float32类型(不执行ConvertToInt8Pass)或者int8_t类型(执行ConvertToInt8Pass),示例如图3:
+
+图3:应用ConvertToInt8Pass后的结果
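+
+freeze.py在同一个`test_graph`上接着应用ConvertToInt8Pass并保存int8模型,核心代码节选如下:
+
+```
+convert_int8_pass = ConvertToInt8Pass(
+    scope=fluid.global_scope(),
+    place=place)
+convert_int8_pass.apply(test_graph)
+server_int8_program = test_graph.to_program()
+```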
+
+### mobile预测模型
+经TransformForMobilePass转换后,用户可得到兼容[paddle-lite](https://github.com/PaddlePaddle/Paddle-Lite)移动端预测库的量化模型。paddle-lite中的量化op和反量化op的名称分别为`quantize`和`dequantize`。`quantize`算子和PaddlePaddle框架中的`fake_quantize_abs_max`算子簇的功能类似,`dequantize`算子和PaddlePaddle框架中的`fake_dequantize_max_abs`算子簇的功能相同。若选择paddle-lite执行量化训练输出的模型,则需要将`fake_quantize_abs_max`等算子替换为`quantize`算子,并将`fake_dequantize_max_abs`等算子替换为`dequantize`算子,示例如图4:
+
+图4:应用TransformForMobilePass后的结果
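+
+freeze.py中对应的转换流程节选如下,TransformForMobilePass本身不需要额外参数,转换后同样通过fluid.io.save_inference_model保存到`${save_path}/mobile/`:
+
+```
+mobile_pass = TransformForMobilePass()
+mobile_pass.apply(test_graph)
+mobile_program = test_graph.to_program()
+```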
+
+### python预测
+
+float预测模型可使用`fluid.io.load_inference_model`接口在python端加载并执行预测,加载与评估的完整流程可参考本目录下的eval.py。
+
+### PaddleLite预测
+float预测模型可使用PaddleLite进行加载预测,可参见教程[Paddle-Lite如何加载运行量化模型](https://github.com/PaddlePaddle/Paddle-Lite/wiki/model_quantization)。
+
+## 从评估模型保存预测模型
+从[配置文件说明](#配置文件说明)中可以看到,在 `end_epoch` 时将保存可用于预测的 `float`、`int8`、`mobile` 模型,但在训练开始前无法确定结果最好的是哪个epoch,因此提供了将 `${checkpoint_path}/${epoch_id}/eval_model/` 下保存的评估模型转化为预测模型的脚本 `freeze.py`,需要配置的参数为:
+
+- model_path: 加载的模型路径,为 `${checkpoint_path}/${epoch_id}/eval_model/`
+- weight_quant_type: 模型参数的量化方式,和配置文件中的类型保持一致
+- save_path: `float`、`int8`、`mobile` 模型的保存路径,分别为 `${save_path}/float/`、`${save_path}/int8/`、`${save_path}/mobile/`
+
+## 示例结果
+
+### MobileNetV1
+
+| weight量化方式 | activation量化方式 | Box ap | Paddle Fluid inference time(ms) | Paddle Lite inference time(ms) |
+|---|---|---|---|---|
+| baseline | - | 76.2% | - | - |
+| abs_max | abs_max | - | - | - |
+| abs_max | moving_average_abs_max | - | - | - |
+| channel_wise_abs_max | abs_max | - | - | - |
+
+>训练超参:
+
+
+## FAQ
diff --git a/slim/quantization/compress.py b/slim/quantization/compress.py
new file mode 100644
index 000000000..ef4319e14
--- /dev/null
+++ b/slim/quantization/compress.py
@@ -0,0 +1,267 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import time
+import multiprocessing
+import numpy as np
+import datetime
+from collections import deque
+import sys
+sys.path.append("../../")
+from paddle.fluid.contrib.slim import Compressor
+from paddle.fluid.framework import IrGraph
+from paddle.fluid import core
+
+
+def set_paddle_flags(**kwargs):
+    for key, value in kwargs.items():
+        if os.environ.get(key, None) is None:
+            os.environ[key] = str(value)
+
+
+# NOTE(paddle-dev): All of these flags should be set before
+# `import paddle`. Otherwise, it would not take any effect.
+set_paddle_flags(
+    FLAGS_eager_delete_tensor_gb=0,  # enable GC to save memory
+)
+
+from paddle import fluid
+
+from ppdet.core.workspace import load_config, merge_config, create
+from ppdet.data.data_feed import create_reader
+
+from ppdet.utils.eval_utils import parse_fetches, eval_results
+from ppdet.utils.stats import TrainingStats
+from ppdet.utils.cli import ArgsParser
+from ppdet.utils.check import check_gpu
+import ppdet.utils.checkpoint as checkpoint
+from ppdet.modeling.model_input import create_feed
+
+import logging
+FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
+logging.basicConfig(level=logging.INFO, format=FORMAT)
+logger = logging.getLogger(__name__)
+
+
+def eval_run(exe, compile_program, reader, keys, values, cls, test_feed):
+    """
+    Run evaluation program, return program outputs.
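+
+    Args:
+        exe: fluid Executor used to run the evaluation program.
+        compile_program: program (or compiled program) to execute.
+        reader: batch reader over the evaluation dataset.
+        keys: fetch names parsed from the eval program.
+        values: fetch Variables corresponding to `keys`.
+        cls: metric accumulator list, may be empty.
+        test_feed: DataFeeder that converts reader samples into feed dicts.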
+ """ + iter_id = 0 + results = [] + if len(cls) != 0: + values = [] + for i in range(len(cls)): + _, accum_map = cls[i].get_map_var() + cls[i].reset(exe) + values.append(accum_map) + + images_num = 0 + start_time = time.time() + has_bbox = 'bbox' in keys + for data in reader(): + data = test_feed.feed(data) + feed_data = {'image': data['image'], + 'im_size': data['im_size']} + outs = exe.run(compile_program, + feed=feed_data, + fetch_list=values[0], + return_numpy=False) + outs.append(data['gt_box']) + outs.append(data['gt_label']) + outs.append(data['is_difficult']) + res = { + k: (np.array(v), v.recursive_sequence_lengths()) + for k, v in zip(keys, outs) + } + results.append(res) + if iter_id % 100 == 0: + logger.info('Test iter {}'.format(iter_id)) + iter_id += 1 + images_num += len(res['bbox'][1][0]) if has_bbox else 1 + logger.info('Test finish iter {}'.format(iter_id)) + + end_time = time.time() + fps = images_num / (end_time - start_time) + if has_bbox: + logger.info('Total number of images: {}, inference time: {} fps.'. + format(images_num, fps)) + else: + logger.info('Total iteration: {}, inference time: {} batch/s.'.format( + images_num, fps)) + + return results + + +def main(): + cfg = load_config(FLAGS.config) + if 'architecture' in cfg: + main_arch = cfg.architecture + else: + raise ValueError("'architecture' not specified in config file.") + + merge_config(FLAGS.opt) + if 'log_iter' not in cfg: + cfg.log_iter = 20 + + # check if set use_gpu=True in paddlepaddle cpu version + check_gpu(cfg.use_gpu) + + if cfg.use_gpu: + devices_num = fluid.core.get_cuda_device_count() + else: + devices_num = int( + os.environ.get('CPU_NUM', multiprocessing.cpu_count())) + + if 'train_feed' not in cfg: + train_feed = create(main_arch + 'TrainFeed') + else: + train_feed = create(cfg.train_feed) + + if 'eval_feed' not in cfg: + eval_feed = create(main_arch + 'EvalFeed') + else: + eval_feed = create(cfg.eval_feed) + + place = fluid.CUDAPlace(0) if cfg.use_gpu else fluid.CPUPlace() + exe = fluid.Executor(place) + + lr_builder = create('LearningRate') + optim_builder = create('OptimizerBuilder') + + # build program + startup_prog = fluid.Program() + train_prog = fluid.Program() + with fluid.program_guard(train_prog, startup_prog): + with fluid.unique_name.guard(): + model = create(main_arch) + train_pyreader, feed_vars = create_feed(train_feed) + train_fetches = model.train(feed_vars) + loss = train_fetches['loss'] + lr = lr_builder() + optimizer = optim_builder(lr) + optimizer.minimize(loss) + + + train_reader = create_reader(train_feed, cfg.max_iters * devices_num, + FLAGS.dataset_dir) + train_pyreader.decorate_sample_list_generator(train_reader, place) + + # parse train fetches + train_keys, train_values, _ = parse_fetches(train_fetches) + train_values.append(lr) + + train_fetch_list=[] + for k, v in zip(train_keys, train_values): + train_fetch_list.append((k, v)) + print("train_fetch_list: {}".format(train_fetch_list)) + + eval_prog = fluid.Program() + with fluid.program_guard(eval_prog, startup_prog): + with fluid.unique_name.guard(): + model = create(main_arch) + eval_pyreader, test_feed_vars = create_feed(eval_feed, use_pyreader=False) + fetches = model.eval(test_feed_vars) + eval_prog = eval_prog.clone(True) + + eval_reader = create_reader(eval_feed, args_path=FLAGS.dataset_dir) + #eval_pyreader.decorate_sample_list_generator(eval_reader, place) + test_data_feed = fluid.DataFeeder(test_feed_vars.values(), place) + + # parse eval fetches + extra_keys = [] + if cfg.metric == 'COCO': + 
extra_keys = ['im_info', 'im_id', 'im_shape']
+    if cfg.metric == 'VOC':
+        extra_keys = ['gt_box', 'gt_label', 'is_difficult']
+    eval_keys, eval_values, eval_cls = parse_fetches(fetches, eval_prog,
+                                                     extra_keys)
+    # print(eval_values)
+
+    eval_fetch_list = []
+    for k, v in zip(eval_keys, eval_values):
+        eval_fetch_list.append((k, v))
+
+    exe.run(startup_prog)
+
+    start_iter = 0
+    checkpoint.load_pretrain(exe, train_prog, cfg.pretrain_weights)
+
+    # keep the best-mAP state outside eval_func so it persists between
+    # evaluations; save_dir is an assumed output directory, adjust as needed
+    best_box_ap_list = []
+    save_dir = './checkpoints'
+
+    def eval_func(program, scope):
+        #place = fluid.CPUPlace()
+        #exe = fluid.Executor(place)
+        results = eval_run(exe, program, eval_reader,
+                           eval_keys, eval_values, eval_cls, test_data_feed)
+
+        resolution = None
+        if 'mask' in results[0]:
+            resolution = model.mask_head.resolution
+        box_ap_stats = eval_results(results, eval_feed, cfg.metric,
+                                    cfg.num_classes, resolution, False,
+                                    FLAGS.output_eval)
+        if len(best_box_ap_list) == 0:
+            best_box_ap_list.append(box_ap_stats[0])
+        elif box_ap_stats[0] > best_box_ap_list[0]:
+            best_box_ap_list[0] = box_ap_stats[0]
+            checkpoint.save(exe, train_prog,
+                            os.path.join(save_dir, "best_model"))
+        logger.info("Best test box ap: {}".format(best_box_ap_list[0]))
+        return best_box_ap_list[0]
+
+    test_feed = [('image', test_feed_vars['image'].name),
+                 ('im_size', test_feed_vars['im_size'].name)]
+
+    com = Compressor(
+        place,
+        fluid.global_scope(),
+        train_prog,
+        train_reader=train_pyreader,
+        train_feed_list=None,
+        train_fetch_list=train_fetch_list,
+        eval_program=eval_prog,
+        eval_reader=eval_reader,
+        eval_feed_list=test_feed,
+        eval_func={'map': eval_func},
+        eval_fetch_list=[eval_fetch_list[0]],
+        train_optimizer=None)
+    com.config(FLAGS.slim_file)
+    com.run()
+
+
+if __name__ == '__main__':
+    parser = ArgsParser()
+    parser.add_argument(
+        "-s",
+        "--slim_file",
+        default=None,
+        type=str,
+        help="Config file of PaddleSlim.")
+    parser.add_argument(
+        "--output_eval",
+        default=None,
+        type=str,
+        help="Evaluation directory, default is current directory.")
+    parser.add_argument(
+        "-d",
+        "--dataset_dir",
+        default=None,
+        type=str,
+        help="Dataset path, same as DataFeed.dataset.dataset_dir")
+    FLAGS = parser.parse_args()
+    main()
diff --git a/slim/quantization/eval.py b/slim/quantization/eval.py
new file mode 100644
index 000000000..df817a662
--- /dev/null
+++ b/slim/quantization/eval.py
@@ -0,0 +1,184 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
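+
+# Loads a quantized inference model saved as `model`/`params` files (e.g. the
+# float or int8 model exported by freeze.py) and evaluates it on the dataset
+# configured in the detection config file.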
+ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import time +import multiprocessing +import numpy as np +import datetime +from collections import deque +import sys +sys.path.append("../../") +from paddle.fluid.contrib.slim import Compressor +from paddle.fluid.framework import IrGraph +from paddle.fluid import core +from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass +from paddle.fluid.contrib.slim.quantization import QuantizationFreezePass +from paddle.fluid.contrib.slim.quantization import ConvertToInt8Pass +from paddle.fluid.contrib.slim.quantization import TransformForMobilePass + +def set_paddle_flags(**kwargs): + for key, value in kwargs.items(): + if os.environ.get(key, None) is None: + os.environ[key] = str(value) + +# NOTE(paddle-dev): All of these flags should be set before +# `import paddle`. Otherwise, it would not take any effect. +set_paddle_flags( + FLAGS_eager_delete_tensor_gb=0, # enable GC to save memory +) + +from paddle import fluid + +from ppdet.core.workspace import load_config, merge_config, create +from ppdet.data.data_feed import create_reader + +from ppdet.utils.eval_utils import parse_fetches, eval_results +from ppdet.utils.stats import TrainingStats +from ppdet.utils.cli import ArgsParser +from ppdet.utils.check import check_gpu +import ppdet.utils.checkpoint as checkpoint +from ppdet.modeling.model_input import create_feed + +import logging +FORMAT = '%(asctime)s-%(levelname)s: %(message)s' +logging.basicConfig(level=logging.INFO, format=FORMAT) +logger = logging.getLogger(__name__) +def eval_run(exe, compile_program, reader, keys, values, cls, test_feed): + """ + Run evaluation program, return program outputs. + """ + iter_id = 0 + results = [] + + images_num = 0 + start_time = time.time() + has_bbox = 'bbox' in keys + for data in reader(): + data = test_feed.feed(data) + feed_data = {'image': data['image'], + 'im_size': data['im_size']} + outs = exe.run(compile_program, + feed=feed_data, + fetch_list=values[0], + return_numpy=False) + outs.append(data['gt_box']) + outs.append(data['gt_label']) + outs.append(data['is_difficult']) + res = { + k: (np.array(v), v.recursive_sequence_lengths()) + for k, v in zip(keys, outs) + } + results.append(res) + if iter_id % 100 == 0: + logger.info('Test iter {}'.format(iter_id)) + iter_id += 1 + images_num += len(res['bbox'][1][0]) if has_bbox else 1 + logger.info('Test finish iter {}'.format(iter_id)) + + end_time = time.time() + fps = images_num / (end_time - start_time) + if has_bbox: + logger.info('Total number of images: {}, inference time: {} fps.'. 
+ format(images_num, fps)) + else: + logger.info('Total iteration: {}, inference time: {} batch/s.'.format( + images_num, fps)) + + return results + + +def main(): + cfg = load_config(FLAGS.config) + if 'architecture' in cfg: + main_arch = cfg.architecture + else: + raise ValueError("'architecture' not specified in config file.") + + merge_config(FLAGS.opt) + if 'log_iter' not in cfg: + cfg.log_iter = 20 + + # check if set use_gpu=True in paddlepaddle cpu version + check_gpu(cfg.use_gpu) + + if cfg.use_gpu: + devices_num = fluid.core.get_cuda_device_count() + else: + devices_num = int( + os.environ.get('CPU_NUM', multiprocessing.cpu_count())) + + + if 'eval_feed' not in cfg: + eval_feed = create(main_arch + 'EvalFeed') + else: + eval_feed = create(cfg.eval_feed) + + place = fluid.CUDAPlace(0) if cfg.use_gpu else fluid.CPUPlace() + exe = fluid.Executor(place) + + eval_pyreader, test_feed_vars = create_feed(eval_feed, use_pyreader=False) + + eval_reader = create_reader(eval_feed, args_path=FLAGS.dataset_dir) + #eval_pyreader.decorate_sample_list_generator(eval_reader, place) + test_data_feed = fluid.DataFeeder(test_feed_vars.values(), place) + + + assert os.path.exists(FLAGS.model_path) + infer_prog, feed_names, fetch_targets = fluid.io.load_inference_model( + dirname=FLAGS.model_path, executor=exe, + model_filename='model', + params_filename='params') + + eval_keys = ['bbox', 'gt_box', 'gt_label', 'is_difficult'] + eval_values = ['multiclass_nms_0.tmp_0', 'gt_box', 'gt_label', 'is_difficult'] + eval_cls = [] + eval_values[0] = fetch_targets[0] + + results = eval_run(exe, infer_prog, eval_reader, + eval_keys, eval_values, eval_cls, test_data_feed) + + resolution = None + if 'mask' in results[0]: + resolution = model.mask_head.resolution + eval_results(results, eval_feed, cfg.metric, cfg.num_classes, + resolution, False, FLAGS.output_eval) + + +if __name__ == '__main__': + parser = ArgsParser() + parser.add_argument( + "-m", + "--model_path", + default=None, + type=str, + help="path of checkpoint") + parser.add_argument( + "--output_eval", + default=None, + type=str, + help="Evaluation directory, default is current directory.") + parser.add_argument( + "-d", + "--dataset_dir", + default=None, + type=str, + help="Dataset path, same as DataFeed.dataset.dataset_dir") + + FLAGS = parser.parse_args() + main() diff --git a/slim/quantization/freeze.py b/slim/quantization/freeze.py new file mode 100644 index 000000000..f9785a080 --- /dev/null +++ b/slim/quantization/freeze.py @@ -0,0 +1,243 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import time +import multiprocessing +import numpy as np +import datetime +from collections import deque +import sys +sys.path.append("../../") +from paddle.fluid.contrib.slim import Compressor +from paddle.fluid.framework import IrGraph +from paddle.fluid import core +from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass +from paddle.fluid.contrib.slim.quantization import QuantizationFreezePass +from paddle.fluid.contrib.slim.quantization import ConvertToInt8Pass +from paddle.fluid.contrib.slim.quantization import TransformForMobilePass + +def set_paddle_flags(**kwargs): + for key, value in kwargs.items(): + if os.environ.get(key, None) is None: + os.environ[key] = str(value) + +# NOTE(paddle-dev): All of these flags should be set before +# `import paddle`. Otherwise, it would not take any effect. +set_paddle_flags( + FLAGS_eager_delete_tensor_gb=0, # enable GC to save memory +) + +from paddle import fluid + +from ppdet.core.workspace import load_config, merge_config, create +from ppdet.data.data_feed import create_reader + +from ppdet.utils.eval_utils import parse_fetches, eval_results +from ppdet.utils.stats import TrainingStats +from ppdet.utils.cli import ArgsParser +from ppdet.utils.check import check_gpu +import ppdet.utils.checkpoint as checkpoint +from ppdet.modeling.model_input import create_feed + +import logging +FORMAT = '%(asctime)s-%(levelname)s: %(message)s' +logging.basicConfig(level=logging.INFO, format=FORMAT) +logger = logging.getLogger(__name__) +def eval_run(exe, compile_program, reader, keys, values, cls, test_feed): + """ + Run evaluation program, return program outputs. + """ + iter_id = 0 + results = [] + + images_num = 0 + start_time = time.time() + has_bbox = 'bbox' in keys + for data in reader(): + data = test_feed.feed(data) + feed_data = {'image': data['image'], + 'im_size': data['im_size']} + outs = exe.run(compile_program, + feed=feed_data, + fetch_list=values[0], + return_numpy=False) + outs.append(data['gt_box']) + outs.append(data['gt_label']) + outs.append(data['is_difficult']) + res = { + k: (np.array(v), v.recursive_sequence_lengths()) + for k, v in zip(keys, outs) + } + results.append(res) + if iter_id % 100 == 0: + logger.info('Test iter {}'.format(iter_id)) + iter_id += 1 + images_num += len(res['bbox'][1][0]) if has_bbox else 1 + logger.info('Test finish iter {}'.format(iter_id)) + + end_time = time.time() + fps = images_num / (end_time - start_time) + if has_bbox: + logger.info('Total number of images: {}, inference time: {} fps.'. 
+ format(images_num, fps)) + else: + logger.info('Total iteration: {}, inference time: {} batch/s.'.format( + images_num, fps)) + + return results + + +def main(): + cfg = load_config(FLAGS.config) + if 'architecture' in cfg: + main_arch = cfg.architecture + else: + raise ValueError("'architecture' not specified in config file.") + + merge_config(FLAGS.opt) + if 'log_iter' not in cfg: + cfg.log_iter = 20 + + # check if set use_gpu=True in paddlepaddle cpu version + check_gpu(cfg.use_gpu) + + if cfg.use_gpu: + devices_num = fluid.core.get_cuda_device_count() + else: + devices_num = int( + os.environ.get('CPU_NUM', multiprocessing.cpu_count())) + + + if 'eval_feed' not in cfg: + eval_feed = create(main_arch + 'EvalFeed') + else: + eval_feed = create(cfg.eval_feed) + + place = fluid.CUDAPlace(0) if cfg.use_gpu else fluid.CPUPlace() + exe = fluid.Executor(place) + + eval_pyreader, test_feed_vars = create_feed(eval_feed, use_pyreader=False) + + eval_reader = create_reader(eval_feed, args_path=FLAGS.dataset_dir) + #eval_pyreader.decorate_sample_list_generator(eval_reader, place) + test_data_feed = fluid.DataFeeder(test_feed_vars.values(), place) + + + assert os.path.exists(FLAGS.model_path) + infer_prog, feed_names, fetch_targets = fluid.io.load_inference_model( + dirname=FLAGS.model_path, executor=exe, + model_filename='__model__', + params_filename='__params__') + + eval_keys = ['bbox', 'gt_box', 'gt_label', 'is_difficult'] + eval_values = ['multiclass_nms_0.tmp_0', 'gt_box', 'gt_label', 'is_difficult'] + eval_cls = [] + eval_values[0] = fetch_targets[0] + + results = eval_run(exe, infer_prog, eval_reader, + eval_keys, eval_values, eval_cls, test_data_feed) + + resolution = None + if 'mask' in results[0]: + resolution = model.mask_head.resolution + box_ap_stats = eval_results(results, eval_feed, cfg.metric, cfg.num_classes, + resolution, False, FLAGS.output_eval) + + logger.info("freeze the graph for inference") + test_graph = IrGraph(core.Graph(infer_prog.desc), for_test=True) + + freeze_pass = QuantizationFreezePass( + scope=fluid.global_scope(), + place=place, + weight_quantize_type=FLAGS.weight_quant_type) + freeze_pass.apply(test_graph) + server_program = test_graph.to_program() + fluid.io.save_inference_model( + dirname=os.path.join(FLAGS.save_path, 'float'), + feeded_var_names=feed_names, + target_vars=fetch_targets, + executor=exe, + main_program=server_program, + model_filename='model', + params_filename='params') + + logger.info("convert the weights into int8 type") + convert_int8_pass = ConvertToInt8Pass( + scope=fluid.global_scope(), + place=place) + convert_int8_pass.apply(test_graph) + server_int8_program = test_graph.to_program() + fluid.io.save_inference_model( + dirname=os.path.join(FLAGS.save_path, 'int8'), + feeded_var_names=feed_names, + target_vars=fetch_targets, + executor=exe, + main_program=server_int8_program, + model_filename='model', + params_filename='params') + + logger.info("convert the freezed pass to paddle-lite execution") + mobile_pass = TransformForMobilePass() + mobile_pass.apply(test_graph) + mobile_program = test_graph.to_program() + fluid.io.save_inference_model( + dirname=os.path.join(FLAGS.save_path, 'mobile'), + feeded_var_names=feed_names, + target_vars=fetch_targets, + executor=exe, + main_program=mobile_program, + model_filename='model', + params_filename='params') + + + + + +if __name__ == '__main__': + parser = ArgsParser() + parser.add_argument( + "-m", + "--model_path", + default=None, + type=str, + help="path of checkpoint") + 
parser.add_argument( + "--output_eval", + default=None, + type=str, + help="Evaluation directory, default is current directory.") + parser.add_argument( + "-d", + "--dataset_dir", + default=None, + type=str, + help="Dataset path, same as DataFeed.dataset.dataset_dir") + parser.add_argument( + "--weight_quant_type", + default='abs_max', + type=str, + help="quantization type for weight") + parser.add_argument( + "--save_path", + default='./output', + type=str, + help="path to save quantization inference model") + + FLAGS = parser.parse_args() + main() diff --git a/slim/quantization/yolov3_mobilenet_v1_slim.yaml b/slim/quantization/yolov3_mobilenet_v1_slim.yaml new file mode 100644 index 000000000..3ad506b1c --- /dev/null +++ b/slim/quantization/yolov3_mobilenet_v1_slim.yaml @@ -0,0 +1,20 @@ +version: 1.0 +strategies: + quantization_strategy: + class: 'QuantizationStrategy' + start_epoch: 0 + end_epoch: 0 + float_model_save_path: './output/yolov3/float' + mobile_model_save_path: './output/yolov3/mobile' + int8_model_save_path: './output/yolov3/int8' + weight_bits: 8 + activation_bits: 8 + weight_quantize_type: 'abs_max' + activation_quantize_type: 'moving_average_abs_max' + save_in_nodes: ['image', 'im_size'] + save_out_nodes: ['multiclass_nms_0.tmp_0'] +compressor: + epoch: 1 + checkpoint_path: './checkpoints/yolov3/' + strategies: + - quantization_strategy -- GitLab