diff --git a/demo/auto_compression/detection/README.md b/demo/auto_compression/detection/README.md index d1e9c7d8538e74709ad81dbe71bf3af82002041e..2e842475e909cea0e46ec30685e4efa047eb1414 100644 --- a/demo/auto_compression/detection/README.md +++ b/demo/auto_compression/detection/README.md @@ -113,24 +113,29 @@ wget https://bj.bcebos.com/v1/paddle-slim-models/detection/ppyoloe_crn_l_300e_co tar -xf ppyoloe_crn_l_300e_coco.tar ``` -#### 3.4. 测试模型精度 +**注意**:TinyPose模型暂不支持精度测试。 -使用run.py脚本得到模型的mAP: +#### 3.4 自动压缩并产出模型 + +蒸馏量化自动压缩示例通过run.py脚本启动,会使用接口```paddleslim.auto_compression.AutoCompression```对模型进行自动压缩。配置config文件中模型路径、蒸馏、量化、和训练等部分的参数,配置完成后便可对模型进行量化和蒸馏。具体运行命令为: ``` +# 单卡 export CUDA_VISIBLE_DEVICES=0 -python run.py --config_path=./configs/ppyoloe_l_qat_dis.yaml --eval=True +# 多卡 +# export CUDA_VISIBLE_DEVICES=0,1,2,3 +python run.py --config_path=./configs/ppyoloe_l_qat_dis.yaml --save_dir='./output/' ``` -**注意**:TinyPose模型暂不支持精度测试。 +#### 3.5 测试模型精度 -#### 3.5 自动压缩并产出模型 - -蒸馏量化自动压缩示例通过run.py脚本启动,会使用接口```paddleslim.auto_compression.AutoCompression```对模型进行自动压缩。配置config文件中模型路径、蒸馏、量化、和训练等部分的参数,配置完成后便可对模型进行量化和蒸馏。具体运行命令为: +使用eval.py脚本得到模型的mAP: ``` export CUDA_VISIBLE_DEVICES=0 -python run.py --config_path=./configs/ppyoloe_l_qat_dis.yaml --save_dir='./output/' +python eval.py --config_path=./configs/ppyoloe_l_qat_dis.yaml ``` +**注意**:要测试的模型路径可以在配置文件中`model_dir`字段下进行修改。 + ## 4.预测部署 diff --git a/demo/auto_compression/detection/configs/ssd_mbv1_voc_qat_dis.yaml b/demo/auto_compression/detection/configs/ssd_mbv1_voc_qat_dis.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3030a283d4ca8c1ac60c5eff4558d25ef21f7363 --- /dev/null +++ b/demo/auto_compression/detection/configs/ssd_mbv1_voc_qat_dis.yaml @@ -0,0 +1,33 @@ +Global: + reader_config: configs/ssd_reader.yml + input_list: ['image', 'scale_factor', 'im_shape'] + Evaluation: True + model_dir: ./ssd_mobilenet_v1_300_120e_voc/ + model_filename: model.pdmodel + params_filename: model.pdiparams + 
+Distillation: + alpha: 1.0 + loss: l2 + node: + - concat_0.tmp_0 + - concat_2.tmp_0 + +Quantization: + activation_quantize_type: 'range_abs_max' + quantize_op_types: + - conv2d + - depthwise_conv2d + +TrainConfig: + train_iter: 80000 + eval_iter: 1000 + learning_rate: + type: CosineAnnealingDecay + learning_rate: 0.00001 + T_max: 120000 + optimizer_builder: + optimizer: + type: SGD + weight_decay: 4.0e-05 + diff --git a/demo/auto_compression/detection/configs/ssd_reader.yml b/demo/auto_compression/detection/configs/ssd_reader.yml new file mode 100644 index 0000000000000000000000000000000000000000..4d8ef6844c96e7d2515cb792704cb0b3319ea8de --- /dev/null +++ b/demo/auto_compression/detection/configs/ssd_reader.yml @@ -0,0 +1,30 @@ +metric: VOC +map_type: 11point +num_classes: 20 + +# Dataset configuration +TrainDataset: + !VOCDataSet + dataset_dir: dataset/voc + anno_path: trainval.txt + label_list: label_list.txt + data_fields: ['image', 'gt_bbox', 'gt_class', 'difficult'] + +EvalDataset: + !VOCDataSet + dataset_dir: dataset/voc + anno_path: test.txt + label_list: label_list.txt + data_fields: ['image', 'gt_bbox', 'gt_class', 'difficult'] + +worker_num: 0 + +# preprocess reader in test +EvalReader: + sample_transforms: + - Decode: {} + - Resize: {target_size: [300, 300], keep_ratio: False, interp: 1} + - NormalizeImage: {mean: [127.5, 127.5, 127.5], std: [127.502231, 127.502231, 127.502231], is_scale: false} + - Permute: {} + batch_size: 16 + collate_batch: false diff --git a/demo/auto_compression/detection/eval.py b/demo/auto_compression/detection/eval.py new file mode 100644 index 0000000000000000000000000000000000000000..b91d4300f11ffe10ff3fa9b2d9ce035ac44512ad --- /dev/null +++ b/demo/auto_compression/detection/eval.py @@ -0,0 +1,167 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +import numpy as np +import argparse +import paddle +from ppdet.core.workspace import load_config, merge_config +from ppdet.core.workspace import create +from ppdet.metrics import COCOMetric, VOCMetric +from paddleslim.auto_compression.config_helpers import load_config as load_slim_config + +from post_process import YOLOv5PostProcess + + +def argsparser(): + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + '--config_path', + type=str, + default=None, + help="path of compression strategy config.", + required=True) + parser.add_argument( + '--devices', + type=str, + default='gpu', + help="which device used to compress.") + + return parser + + +def print_arguments(args): + print('----------- Running Arguments -----------') + for arg, value in sorted(vars(args).items()): + print('%s: %s' % (arg, value)) + print('------------------------------------------') + + +def reader_wrapper(reader, input_list): + def gen(): + for data in reader: + in_dict = {} + if isinstance(input_list, list): + for input_name in input_list: + in_dict[input_name] = data[input_name] + elif isinstance(input_list, dict): + for input_name in input_list.keys(): + in_dict[input_list[input_name]] = data[input_name] + yield in_dict + + return gen + + +def convert_numpy_data(data, metric): + data_all = {} + data_all = {k: np.array(v) for k, v in data.items()} + if isinstance(metric, VOCMetric): + for k, v in data_all.items(): + if not isinstance(v[0], np.ndarray): + tmp_list = [] + for t in v: + tmp_list.append(np.array(t)) + data_all[k] = 
np.array(tmp_list) + else: + data_all = {k: np.array(v) for k, v in data.items()} + return data_all + + +def eval(): + + place = paddle.CUDAPlace(0) if FLAGS.devices == 'gpu' else paddle.CPUPlace() + exe = paddle.static.Executor(place) + + val_program, feed_target_names, fetch_targets = paddle.fluid.io.load_inference_model( + global_config["model_dir"], + exe, + model_filename=global_config["model_filename"], + params_filename=global_config["params_filename"]) + print('Loaded model from: {}'.format(global_config["model_dir"])) + + metric = global_config['metric'] + for batch_id, data in enumerate(val_loader): + data_all = convert_numpy_data(data, metric) + data_input = {} + for k, v in data.items(): + if isinstance(global_config['input_list'], list): + if k in global_config['input_list']: + data_input[k] = np.array(v) + elif isinstance(global_config['input_list'], dict): + if k in global_config['input_list'].keys(): + data_input[global_config['input_list'][k]] = np.array(v) + + outs = exe.run(val_program, + feed=data_input, + fetch_list=fetch_targets, + return_numpy=False) + res = {} + if 'arch' in global_config and global_config['arch'] == 'YOLOv5': + postprocess = YOLOv5PostProcess( + score_threshold=0.001, nms_threshold=0.6, multi_label=True) + res = postprocess(np.array(outs[0]), data_all['scale_factor']) + else: + for out in outs: + v = np.array(out) + if len(v.shape) > 1: + res['bbox'] = v + else: + res['bbox_num'] = v + metric.update(data_all, res) + if batch_id % 100 == 0: + print('Eval iter:', batch_id) + metric.accumulate() + metric.log() + metric.reset() + + +def main(): + global global_config + _, _, global_config = load_slim_config(FLAGS.config_path) + reader_cfg = load_config(global_config['reader_config']) + + dataset = reader_cfg['EvalDataset'] + global val_loader + val_loader = create('EvalReader')(reader_cfg['EvalDataset'], + reader_cfg['worker_num'], + return_list=True) + metric = None + if reader_cfg['metric'] == 'COCO': + clsid2catid = {v: k 
for k, v in dataset.catid2clsid.items()} + anno_file = dataset.get_anno() + metric = COCOMetric( + anno_file=anno_file, clsid2catid=clsid2catid, IouType='bbox') + elif reader_cfg['metric'] == 'VOC': + metric = VOCMetric( + label_list=dataset.get_label_list(), + class_num=reader_cfg['num_classes'], + map_type=reader_cfg['map_type']) + else: + raise ValueError("metric currently only supports COCO and VOC.") + global_config['metric'] = metric + + eval() + + +if __name__ == '__main__': + paddle.enable_static() + parser = argsparser() + FLAGS = parser.parse_args() + print_arguments(FLAGS) + + assert FLAGS.devices in ['cpu', 'gpu', 'xpu', 'npu'] + paddle.set_device(FLAGS.devices) + + main() diff --git a/demo/auto_compression/detection/run.py b/demo/auto_compression/detection/run.py index 05d0428fade54821dc2867caf202efa91a1dae82..2d6dcc9ba7d1e3e1cea810bf39b734ee79406906 100644 --- a/demo/auto_compression/detection/run.py +++ b/demo/auto_compression/detection/run.py @@ -19,7 +19,7 @@ import argparse import paddle from ppdet.core.workspace import load_config, merge_config from ppdet.core.workspace import create -from ppdet.metrics import COCOMetric +from ppdet.metrics import COCOMetric, VOCMetric from paddleslim.auto_compression.config_helpers import load_config as load_slim_config from paddleslim.auto_compression import AutoCompression @@ -72,65 +72,25 @@ def reader_wrapper(reader, input_list): return gen -def eval(config): - - place = paddle.CUDAPlace(0) if FLAGS.devices == 'gpu' else paddle.CPUPlace() - exe = paddle.static.Executor(place) - - val_program, feed_target_names, fetch_targets = paddle.fluid.io.load_inference_model( - config["model_dir"], - exe, - model_filename=config["model_filename"], - params_filename=config["params_filename"], ) - clsid2catid = {v: k for k, v in dataset.catid2clsid.items()} - - anno_file = dataset.get_anno() - metric = COCOMetric( - anno_file=anno_file, clsid2catid=clsid2catid, IouType='bbox') - for batch_id, data in 
enumerate(val_loader): +def convert_numpy_data(data, metric): + data_all = {} + data_all = {k: np.array(v) for k, v in data.items()} + if isinstance(metric, VOCMetric): + for k, v in data_all.items(): + if not isinstance(v[0], np.ndarray): + tmp_list = [] + for t in v: + tmp_list.append(np.array(t)) + data_all[k] = np.array(tmp_list) + else: data_all = {k: np.array(v) for k, v in data.items()} - data_input = {} - for k, v in data.items(): - if isinstance(config['input_list'], list): - if k in config['input_list']: - data_input[k] = np.array(v) - elif isinstance(config['input_list'], dict): - if k in config['input_list'].keys(): - data_input[config['input_list'][k]] = np.array(v) - - outs = exe.run(val_program, - feed=data_input, - fetch_list=fetch_targets, - return_numpy=False) - res = {} - if 'arch' in config and config['arch'] == 'YOLOv5': - postprocess = YOLOv5PostProcess( - score_threshold=0.001, nms_threshold=0.6, multi_label=True) - res = postprocess(np.array(outs[0]), data_all['scale_factor']) - else: - for out in outs: - v = np.array(out) - if len(v.shape) > 1: - res['bbox'] = v - else: - res['bbox_num'] = v - - metric.update(data_all, res) - if batch_id % 100 == 0: - print('Eval iter:', batch_id) - metric.accumulate() - metric.log() - metric.reset() + return data_all def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list): - clsid2catid = {v: k for k, v in dataset.catid2clsid.items()} - - anno_file = dataset.get_anno() - metric = COCOMetric( - anno_file=anno_file, clsid2catid=clsid2catid, IouType='bbox') + metric = global_config['metric'] for batch_id, data in enumerate(val_loader): - data_all = {k: np.array(v) for k, v in data.items()} + data_all = convert_numpy_data(data, metric) data_input = {} for k, v in data.items(): if isinstance(global_config['input_list'], list): @@ -177,16 +137,25 @@ def main(): return_list=True) train_loader = reader_wrapper(train_loader, global_config['input_list']) - global dataset dataset = 
reader_cfg['EvalDataset'] global val_loader val_loader = create('EvalReader')(reader_cfg['EvalDataset'], reader_cfg['worker_num'], return_list=True) - - if FLAGS.eval: - eval(global_config) - sys.exit(0) + metric = None + if reader_cfg['metric'] == 'COCO': + clsid2catid = {v: k for k, v in dataset.catid2clsid.items()} + anno_file = dataset.get_anno() + metric = COCOMetric( + anno_file=anno_file, clsid2catid=clsid2catid, IouType='bbox') + elif reader_cfg['metric'] == 'VOC': + metric = VOCMetric( + label_list=dataset.get_label_list(), + class_num=reader_cfg['num_classes'], + map_type=reader_cfg['map_type']) + else: + raise ValueError("metric currently only supports COCO and VOC.") + global_config['metric'] = metric if 'Evaluation' in global_config.keys() and global_config['Evaluation']: eval_func = eval_function diff --git a/paddleslim/auto_compression/compressor.py b/paddleslim/auto_compression/compressor.py index b7b1c26b6d18778db1482bb6eba9838c3475a665..4ea8cf4a00f6a064af4f88f5ca9ead0a863d015f 100644 --- a/paddleslim/auto_compression/compressor.py +++ b/paddleslim/auto_compression/compressor.py @@ -611,7 +611,8 @@ class AutoCompression: def _start_train(self, train_program_info, test_program_info, strategy): best_metric = -1.0 - total_epochs = self.train_config.epochs if self.train_config.epochs else 1 + total_epochs = self.train_config.epochs if self.train_config.epochs else 100 + total_train_iter = 0 for epoch_id in range(total_epochs): for batch_id, data in enumerate(self.train_dataloader()): np_probs_float, = self._exe.run(train_program_info.program, \ @@ -627,11 +628,13 @@ class AutoCompression: else: logging_iter = self.train_config.logging_iter if batch_id % int(logging_iter) == 0: - _logger.info("epoch: {}, batch: {}, loss: {}".format( - epoch_id, batch_id, np_probs_float)) - - if batch_id % int( - self.train_config.eval_iter) == 0 and batch_id != 0: + _logger.info( + "Total iter: {}, epoch: {}, batch: {}, loss: {}".format( + total_train_iter, epoch_id, 
batch_id, + np_probs_float)) + total_train_iter += 1 + if total_train_iter % int(self.train_config.eval_iter + ) == 0 and total_train_iter != 0: if self.eval_function is not None: # GMP pruner step 3: update params before summrizing sparsity, saving model or evaluation. @@ -644,8 +647,9 @@ class AutoCompression: test_program_info.fetch_targets) _logger.info( - "epoch: {}, batch: {} metric of compressed model is: {}, best metric of compressed model is {}". - format(epoch_id, batch_id, metric, best_metric)) + "epoch: {} metric of compressed model is: {:.6f}, best metric of compressed model is {:.6f}". + format(epoch_id, metric, best_metric)) + if metric > best_metric: paddle.static.save( program=test_program_info.program._program, @@ -665,7 +669,7 @@ class AutoCompression: _logger.warning( "Not set eval function, so unable to test accuracy performance." ) - if self.train_config.train_iter and batch_id >= self.train_config.train_iter: + if self.train_config.train_iter and total_train_iter >= self.train_config.train_iter: break if 'unstructure' in self._strategy or self.train_config.sparse_model: