From cc7fd90059bdc95c902857e1d10cb8202f50ec71 Mon Sep 17 00:00:00 2001 From: Guanghua Yu <742925032@qq.com> Date: Thu, 11 Aug 2022 14:37:48 +0800 Subject: [PATCH] support load onnx mdoel in act (#1328) --- .../auto-compression/auto_compression_api.rst | 10 +- .../auto_compression/pytorch_yolov5/README.md | 23 +--- .../configs/yolov5s_qat_dis.yaml | 4 +- .../auto_compression/pytorch_yolov5/eval.py | 11 +- .../pytorch_yolov5/post_quant.py | 12 ++- .../auto_compression/pytorch_yolov5/run.py | 4 +- .../auto_compression/pytorch_yolov6/README.md | 18 +--- .../configs/yolov6s_qat_dis.yaml | 4 +- .../auto_compression/pytorch_yolov6/eval.py | 10 +- .../pytorch_yolov6/post_quant.py | 14 +-- .../auto_compression/pytorch_yolov6/run.py | 4 +- .../auto_compression/pytorch_yolov7/README.md | 19 +--- .../configs/yolov7_qat_dis.yaml | 4 +- .../auto_compression/pytorch_yolov7/eval.py | 10 +- .../pytorch_yolov7/post_quant.py | 12 ++- .../auto_compression/pytorch_yolov7/run.py | 4 +- paddleslim/auto_compression/compressor.py | 75 ++++++------- .../auto_compression/utils/dataloader.py | 3 +- .../auto_compression/utils/load_model.py | 49 ++++++++- paddleslim/common/__init__.py | 3 +- paddleslim/common/convert_model.py | 102 ++++++++++++++++++ requirements.txt | 3 +- tests/act/test_act_api.py | 24 +++++ 23 files changed, 266 insertions(+), 156 deletions(-) create mode 100644 paddleslim/common/convert_model.py diff --git a/docs/zh_cn/api_cn/static/auto-compression/auto_compression_api.rst b/docs/zh_cn/api_cn/static/auto-compression/auto_compression_api.rst index f5731df4..c308413d 100644 --- a/docs/zh_cn/api_cn/static/auto-compression/auto_compression_api.rst +++ b/docs/zh_cn/api_cn/static/auto-compression/auto_compression_api.rst @@ -3,19 +3,19 @@ AutoCompression自动压缩功能 AutoCompression --------------- -.. py:class:: paddleslim.auto_compression.AutoCompression(model_dir, model_filename, params_filename, save_dir, strategy_config, train_config, train_dataloader, eval_callback, devices='gpu') +.. py:class:: paddleslim.auto_compression.AutoCompression(model_dir, train_dataloader, model_filename, params_filename, save_dir, strategy_config, train_config, eval_callback, devices='gpu') -`源代码 `_ +`源代码 `_ 根据指定的配置对使用 ``paddle.jit.save`` 接口或者 ``paddle.static.save_inference_model`` 接口保存的推理模型进行压缩。 **参数: ** - **model_dir(str)** - 需要压缩的推理模型所在的目录。 +- **train_dataloader(paddle.io.DataLoader)** - 训练数据迭代器。注意:如果选择离线量化超参搜索策略的话, ``train_dataloader`` 和 ``eval_callback`` 设置相同的数据读取即可。 - **model_filename(str)** - 需要压缩的推理模型文件名称。 - **params_filename(str)** - 需要压缩的推理模型参数文件名称。 - **save_dir(str)** - 压缩后模型的所保存的目录。 -- **train_dataloader(paddle.io.DataLoader)** - 训练数据迭代器。注意:如果选择离线量化超参搜索策略的话, ``train_dataloader`` 和 ``eval_callback`` 设置相同的数据读取即可。 - **train_config(dict)** - 训练配置。可以配置的参数请参考: ``_ 。注意:如果选择离线量化超参搜索策略的话, ``train_config`` 直接设置为 ``None`` 即可。 - **strategy_config(dict, list(dict), 可选)** - 使用的压缩策略,可以通过设置多个单种策略来并行使用这些压缩方式。字典的关键字必须在: ``Quantization`` (量化配置, 可配置的参数参考 ``_ ), @@ -82,13 +82,13 @@ AutoCompression eval_dataloader = Cifar10(mode='eval') - ac = AutoCompression(model_path, model_filename, params_filename, save_dir, \ + ac = AutoCompression(model_path, train_dataloader, model_filename, params_filename, save_dir, \ strategy_config="Quantization": Quantization(**default_ptq_config), "Distillation": HyperParameterOptimization(**default_distill_config)}, \ - train_config=None, train_dataloader=train_dataloader, eval_callback=eval_dataloader,devices='gpu') + train_config=None, eval_callback=eval_dataloader,devices='gpu') ``` diff --git a/example/auto_compression/pytorch_yolov5/README.md b/example/auto_compression/pytorch_yolov5/README.md index 81cab034..aeaf7f28 100644 --- a/example/auto_compression/pytorch_yolov5/README.md +++ b/example/auto_compression/pytorch_yolov5/README.md @@ -22,7 +22,7 @@ | 模型 | 策略 | 输入尺寸 | mAPval
0.5:0.95 | 预测时延FP32
(ms) |预测时延FP16
(ms) | 预测时延INT8
(ms) | 配置文件 | Inference模型 | | :-------- |:-------- |:--------: | :---------------------: | :----------------: | :----------------: | :---------------: | :-----------------------------: | :-----------------------------: | -| YOLOv5s | Base模型 | 640*640 | 37.4 | 5.95ms | 2.44ms | - | - | [Model](https://bj.bcebos.com/v1/paddle-slim-models/detection/yolov5s_infer.tar) | +| YOLOv5s | Base模型 | 640*640 | 37.4 | 5.95ms | 2.44ms | - | - | [Model](https://paddle-slim-models.bj.bcebos.com/act/yolov5s.onnx) | | YOLOv5s | KL离线量化 | 640*640 | 36.0 | - | - | 1.87ms | - | - | | YOLOv5s | 量化蒸馏训练 | 640*640 | **36.9** | - | - | **1.87ms** | [config](./configs/yolov5s_qat_dis.yaml) | [Infer Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov5s_quant.tar) | [ONNX Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov5s_quant.onnx) | @@ -60,37 +60,18 @@ pip install paddledet 注:安装PaddleDet的目的是为了直接使用PaddleDetection中的Dataloader组件。 -(4)安装X2Paddle的1.3.6以上版本: -```shell -pip install x2paddle sympy onnx -``` - #### 3.2 准备数据集 本案例默认以COCO数据进行自动压缩实验,并且依赖PaddleDetection中数据读取模块,如果自定义COCO数据,或者其他格式数据,请参考[PaddleDetection数据准备文档](https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.4/docs/tutorials/PrepareDataSet.md) 来准备数据。 如果已经准备好数据集,请直接修改[./configs/yolov6_reader.yml]中`EvalDataset`的`dataset_dir`字段为自己数据集路径即可。 -#### 3.3 准备预测模型 - -(1)准备ONNX模型: +#### 3.3 准备ONNX预测模型 可通过[ultralytics/yolov5](https://github.com/ultralytics/yolov5) 官方的[导出教程](https://github.com/ultralytics/yolov5/issues/251)来准备ONNX模型。也可以下载准备好的[yolov5s.onnx](https://paddle-slim-models.bj.bcebos.com/act/yolov5s.onnx)。 ```shell python export.py --weights yolov5s.pt --include onnx ``` - -(2) 转换模型: -```shell -x2paddle --framework=onnx --model=yolov5s.onnx --save_dir=pd_model -cp -r pd_model/inference_model/ yolov5s_infer -``` -即可得到YOLOv5s模型的预测模型(`model.pdmodel` 和 `model.pdiparams`)。如想快速体验,可直接下载上方表格中YOLOv5s的[Paddle预测模型](https://bj.bcebos.com/v1/paddle-slim-models/detection/yolov5s_infer.tar)。 - - -预测模型的格式为:`model.pdmodel` 和 `model.pdiparams`两个,带`pdmodel`的是模型文件,带`pdiparams`后缀的是权重文件。 - - #### 3.4 自动压缩并产出模型 蒸馏量化自动压缩示例通过run.py脚本启动,会使用接口```paddleslim.auto_compression.AutoCompression```对模型进行自动压缩。配置config文件中模型路径、蒸馏、量化、和训练等部分的参数,配置完成后便可对模型进行量化和蒸馏。具体运行命令为: diff --git a/example/auto_compression/pytorch_yolov5/configs/yolov5s_qat_dis.yaml b/example/auto_compression/pytorch_yolov5/configs/yolov5s_qat_dis.yaml index 06fb8d89..57168222 100644 --- a/example/auto_compression/pytorch_yolov5/configs/yolov5s_qat_dis.yaml +++ b/example/auto_compression/pytorch_yolov5/configs/yolov5s_qat_dis.yaml @@ -4,9 +4,7 @@ Global: input_list: {'image': 'x2paddle_images'} Evaluation: True arch: 'YOLOv5' - model_dir: ./yolov5s_infer - model_filename: model.pdmodel - params_filename: model.pdiparams + model_dir: ./yolov5s.onnx Distillation: alpha: 1.0 diff --git a/example/auto_compression/pytorch_yolov5/eval.py b/example/auto_compression/pytorch_yolov5/eval.py index 55be2feb..69dcf237 100644 --- a/example/auto_compression/pytorch_yolov5/eval.py +++ b/example/auto_compression/pytorch_yolov5/eval.py @@ -21,7 +21,7 @@ from ppdet.core.workspace import load_config, merge_config from ppdet.core.workspace import create from ppdet.metrics import COCOMetric, VOCMetric from paddleslim.auto_compression.config_helpers import load_config as load_slim_config - +from paddleslim.common import load_onnx_model from post_process import YOLOv5PostProcess @@ -76,13 +76,8 @@ def eval(): place = paddle.CUDAPlace(0) if FLAGS.devices == 'gpu' else paddle.CPUPlace() exe = paddle.static.Executor(place) - - val_program, feed_target_names, fetch_targets = paddle.static.load_inference_model( - global_config["model_dir"], - exe, - model_filename=global_config["model_filename"], - params_filename=global_config["params_filename"]) - print('Loaded model from: {}'.format(global_config["model_dir"])) + val_program, feed_target_names, fetch_targets = load_onnx_model( + global_config["model_dir"]) metric = global_config['metric'] for batch_id, data in enumerate(val_loader): diff --git a/example/auto_compression/pytorch_yolov5/post_quant.py b/example/auto_compression/pytorch_yolov5/post_quant.py index 8c866727..a5316d29 100644 --- a/example/auto_compression/pytorch_yolov5/post_quant.py +++ b/example/auto_compression/pytorch_yolov5/post_quant.py @@ -22,6 +22,7 @@ from ppdet.core.workspace import create from ppdet.metrics import COCOMetric, VOCMetric from paddleslim.auto_compression.config_helpers import load_config as load_slim_config from paddleslim.quant import quant_post_static +from paddleslim.common import load_onnx_model def argsparser(): @@ -77,20 +78,23 @@ def main(): place = paddle.CUDAPlace(0) if FLAGS.devices == 'gpu' else paddle.CPUPlace() exe = paddle.static.Executor(place) + load_onnx_model(global_config["model_dir"]) + inference_model_path = global_config["model_dir"].rstrip().rstrip( + '.onnx') + '_infer' quant_post_static( executor=exe, - model_dir=global_config["model_dir"], + model_dir=inference_model_path, quantize_model_path=FLAGS.save_dir, data_loader=train_loader, - model_filename=global_config["model_filename"], - params_filename=global_config["params_filename"], + model_filename='model.pdmodel', + params_filename='model.pdiparams', batch_size=32, batch_nums=10, algo=FLAGS.algo, hist_percent=0.999, is_full_quantize=False, bias_correction=False, - onnx_format=False) + onnx_format=True) if __name__ == '__main__': diff --git a/example/auto_compression/pytorch_yolov5/run.py b/example/auto_compression/pytorch_yolov5/run.py index 965a546f..bd36e25e 100644 --- a/example/auto_compression/pytorch_yolov5/run.py +++ b/example/auto_compression/pytorch_yolov5/run.py @@ -159,11 +159,9 @@ def main(): ac = AutoCompression( model_dir=global_config["model_dir"], - model_filename=global_config["model_filename"], - params_filename=global_config["params_filename"], + train_dataloader=train_loader, save_dir=FLAGS.save_dir, config=all_config, - train_dataloader=train_loader, eval_callback=eval_func) ac.compress() diff --git a/example/auto_compression/pytorch_yolov6/README.md b/example/auto_compression/pytorch_yolov6/README.md index acb43c36..f841a020 100644 --- a/example/auto_compression/pytorch_yolov6/README.md +++ b/example/auto_compression/pytorch_yolov6/README.md @@ -22,7 +22,7 @@ | 模型 | 策略 | 输入尺寸 | mAPval
0.5:0.95 | 预测时延FP32
(ms) |预测时延FP16
(ms) | 预测时延INT8
(ms) | 配置文件 | Inference模型 | | :-------- |:-------- |:--------: | :---------------------: | :----------------: | :----------------: | :---------------: | :-----------------------------: | :-----------------------------: | -| YOLOv6s | Base模型 | 640*640 | 42.4 | 9.06ms | 2.90ms | - | - | [Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov6s_infer.tar) | +| YOLOv6s | Base模型 | 640*640 | 42.4 | 9.06ms | 2.90ms | - | - | [Model](https://paddle-slim-models.bj.bcebos.com/act/yolov6s.onnx) | | YOLOv6s | KL离线量化 | 640*640 | 30.3 | - | - | 1.83ms | - | - | | YOLOv6s | 量化蒸馏训练 | 640*640 | **41.3** | - | - | **1.83ms** | [config](./configs/yolov6s_qat_dis.yaml) | [Infer Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov6s_quant.tar) | [ONNX Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov6s_quant.onnx) | @@ -36,7 +36,6 @@ - PaddlePaddle >= 2.3 (可从[Paddle官网](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html)下载安装) - PaddleSlim > 2.3版本 - PaddleDet >= 2.4 -- [X2Paddle](https://github.com/PaddlePaddle/X2Paddle) >= 1.3.6 - opencv-python (1)安装paddlepaddle: @@ -59,10 +58,6 @@ pip install paddledet 注:安装PaddleDet的目的只是为了直接使用PaddleDetection中的Dataloader组件。 -(4)安装X2Paddle的1.3.6以上版本: -```shell -pip install x2paddle sympy onnx -``` #### 3.2 准备数据集 @@ -78,17 +73,6 @@ pip install x2paddle sympy onnx 可通过[meituan/YOLOv6](https://github.com/meituan/YOLOv6)官方的[导出教程](https://github.com/meituan/YOLOv6/blob/main/deploy/ONNX/README.md)来准备ONNX模型。也可以下载已经准备好的[yolov6s.onnx](https://paddle-slim-models.bj.bcebos.com/act/yolov6s.onnx)。 -(2) 转换模型: -``` -x2paddle --framework=onnx --model=yolov6s.onnx --save_dir=pd_model -cp -r pd_model/inference_model/ yolov6s_infer -``` -即可得到YOLOv6s模型的预测模型(`model.pdmodel` 和 `model.pdiparams`)。如想快速体验,可直接下载上方表格中YOLOv6s的[Paddle预测模型](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov6s_infer.tar)。 - - -预测模型的格式为:`model.pdmodel` 和 `model.pdiparams`两个,带`pdmodel`的是模型文件,带`pdiparams`后缀的是权重文件。 - - #### 3.4 自动压缩并产出模型 蒸馏量化自动压缩示例通过run.py脚本启动,会使用接口```paddleslim.auto_compression.AutoCompression```对模型进行自动压缩。配置config文件中模型路径、蒸馏、量化、和训练等部分的参数,配置完成后便可对模型进行量化和蒸馏。具体运行命令为: diff --git a/example/auto_compression/pytorch_yolov6/configs/yolov6s_qat_dis.yaml b/example/auto_compression/pytorch_yolov6/configs/yolov6s_qat_dis.yaml index 7bbd3324..7e56f81c 100644 --- a/example/auto_compression/pytorch_yolov6/configs/yolov6s_qat_dis.yaml +++ b/example/auto_compression/pytorch_yolov6/configs/yolov6s_qat_dis.yaml @@ -4,9 +4,7 @@ Global: input_list: {'image': 'x2paddle_image_arrays'} Evaluation: True arch: 'YOLOv6' - model_dir: ./yolov6s_infer - model_filename: model.pdmodel - params_filename: model.pdiparams + model_dir: ./yolov6s.onnx Distillation: alpha: 1.0 diff --git a/example/auto_compression/pytorch_yolov6/eval.py b/example/auto_compression/pytorch_yolov6/eval.py index 62127b51..cadc28aa 100644 --- a/example/auto_compression/pytorch_yolov6/eval.py +++ b/example/auto_compression/pytorch_yolov6/eval.py @@ -21,7 +21,7 @@ from ppdet.core.workspace import load_config, merge_config from ppdet.core.workspace import create from ppdet.metrics import COCOMetric, VOCMetric from paddleslim.auto_compression.config_helpers import load_config as load_slim_config - +from paddleslim.common import load_onnx_model from post_process import YOLOv6PostProcess @@ -77,12 +77,8 @@ def eval(): place = paddle.CUDAPlace(0) if FLAGS.devices == 'gpu' else paddle.CPUPlace() exe = paddle.static.Executor(place) - val_program, feed_target_names, fetch_targets = paddle.static.load_inference_model( - global_config["model_dir"], - exe, - model_filename=global_config["model_filename"], - params_filename=global_config["params_filename"]) - print('Loaded model from: {}'.format(global_config["model_dir"])) + val_program, feed_target_names, fetch_targets = load_onnx_model( + global_config["model_dir"]) metric = global_config['metric'] for batch_id, data in enumerate(val_loader): diff --git a/example/auto_compression/pytorch_yolov6/post_quant.py b/example/auto_compression/pytorch_yolov6/post_quant.py index aa4f5d8f..a5316d29 100644 --- a/example/auto_compression/pytorch_yolov6/post_quant.py +++ b/example/auto_compression/pytorch_yolov6/post_quant.py @@ -22,8 +22,7 @@ from ppdet.core.workspace import create from ppdet.metrics import COCOMetric, VOCMetric from paddleslim.auto_compression.config_helpers import load_config as load_slim_config from paddleslim.quant import quant_post_static - -from post_process import YOLOv6PostProcess +from paddleslim.common import load_onnx_model def argsparser(): @@ -79,20 +78,23 @@ def main(): place = paddle.CUDAPlace(0) if FLAGS.devices == 'gpu' else paddle.CPUPlace() exe = paddle.static.Executor(place) + load_onnx_model(global_config["model_dir"]) + inference_model_path = global_config["model_dir"].rstrip().rstrip( + '.onnx') + '_infer' quant_post_static( executor=exe, - model_dir=global_config["model_dir"], + model_dir=inference_model_path, quantize_model_path=FLAGS.save_dir, data_loader=train_loader, - model_filename=global_config["model_filename"], - params_filename=global_config["params_filename"], + model_filename='model.pdmodel', + params_filename='model.pdiparams', batch_size=32, batch_nums=10, algo=FLAGS.algo, hist_percent=0.999, is_full_quantize=False, bias_correction=False, - onnx_format=False) + onnx_format=True) if __name__ == '__main__': diff --git a/example/auto_compression/pytorch_yolov6/run.py b/example/auto_compression/pytorch_yolov6/run.py index 05fe7fdd..4db9af11 100644 --- a/example/auto_compression/pytorch_yolov6/run.py +++ b/example/auto_compression/pytorch_yolov6/run.py @@ -161,11 +161,9 @@ def main(): ac = AutoCompression( model_dir=global_config["model_dir"], - model_filename=global_config["model_filename"], - params_filename=global_config["params_filename"], + train_dataloader=train_loader, save_dir=FLAGS.save_dir, config=all_config, - train_dataloader=train_loader, eval_callback=eval_func) ac.compress() diff --git a/example/auto_compression/pytorch_yolov7/README.md b/example/auto_compression/pytorch_yolov7/README.md index 7306391f..768a99eb 100644 --- a/example/auto_compression/pytorch_yolov7/README.md +++ b/example/auto_compression/pytorch_yolov7/README.md @@ -22,7 +22,7 @@ | 模型 | 策略 | 输入尺寸 | mAPval
0.5:0.95 | 预测时延FP32
(ms) |预测时延FP16
(ms) | 预测时延INT8
(ms) | 配置文件 | Inference模型 | | :-------- |:-------- |:--------: | :---------------------: | :----------------: | :----------------: | :---------------: | :-----------------------------: | :-----------------------------: | -| YOLOv7 | Base模型 | 640*640 | 51.1 | 26.84ms | 7.44ms | - | - | [Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov7_infer.tar) | +| YOLOv7 | Base模型 | 640*640 | 51.1 | 26.84ms | 7.44ms | - | - | [Model](https://paddle-slim-models.bj.bcebos.com/act/yolov7.onnx) | | YOLOv7 | KL离线量化 | 640*640 | 50.2 | - | - | 4.55ms | - | - | | YOLOv7 | 量化蒸馏训练 | 640*640 | **50.8** | - | - | **4.55ms** | [config](./configs/yolov7_qat_dis.yaml) | [Infer Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov7_quant.tar) | [ONNX Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov7_quant.onnx) | @@ -36,7 +36,6 @@ - PaddlePaddle >= 2.3 (可从[Paddle官网](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html)下载安装) - PaddleSlim > 2.3版本 - PaddleDet >= 2.4 -- [X2Paddle](https://github.com/PaddlePaddle/X2Paddle) >= 1.3.6 - opencv-python (1)安装paddlepaddle: @@ -59,10 +58,6 @@ pip install paddledet 注:安装PaddleDet的目的只是为了直接使用PaddleDetection中的Dataloader组件。 -(4)安装X2Paddle的1.3.6以上版本: -```shell -pip install x2paddle sympy onnx -``` #### 3.2 准备数据集 @@ -86,18 +81,6 @@ python export.py --weights yolov7.pt --include onnx 也可以直接下载我们已经准备好的[yolov7.onnx](https://paddle-slim-models.bj.bcebos.com/act/yolov7.onnx)。 - -(2) 转换模型: -``` -x2paddle --framework=onnx --model=yolov7.onnx --save_dir=pd_model -cp -r pd_model/inference_model/ yolov7_infer -``` -即可得到YOLOv7模型的预测模型(`model.pdmodel` 和 `model.pdiparams`)。如想快速体验,可直接下载上方表格中YOLOv7的[Paddle预测模型](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov7_infer.tar)。 - - -预测模型的格式为:`model.pdmodel` 和 `model.pdiparams`两个,带`pdmodel`的是模型文件,带`pdiparams`后缀的是权重文件。 - - #### 3.4 自动压缩并产出模型 蒸馏量化自动压缩示例通过run.py脚本启动,会使用接口```paddleslim.auto_compression.AutoCompression```对模型进行自动压缩。配置config文件中模型路径、蒸馏、量化、和训练等部分的参数,配置完成后便可对模型进行量化和蒸馏。具体运行命令为: diff --git a/example/auto_compression/pytorch_yolov7/configs/yolov7_qat_dis.yaml b/example/auto_compression/pytorch_yolov7/configs/yolov7_qat_dis.yaml index fb3a5069..ce1edd3f 100644 --- a/example/auto_compression/pytorch_yolov7/configs/yolov7_qat_dis.yaml +++ b/example/auto_compression/pytorch_yolov7/configs/yolov7_qat_dis.yaml @@ -3,9 +3,7 @@ Global: reader_config: configs/yolov7_reader.yaml input_list: {'image': 'x2paddle_images'} Evaluation: True - model_dir: ./yolov7_infer - model_filename: model.pdmodel - params_filename: model.pdiparams + model_dir: ./yolov7.onnx Distillation: alpha: 1.0 diff --git a/example/auto_compression/pytorch_yolov7/eval.py b/example/auto_compression/pytorch_yolov7/eval.py index 478c4e1a..519775bc 100644 --- a/example/auto_compression/pytorch_yolov7/eval.py +++ b/example/auto_compression/pytorch_yolov7/eval.py @@ -21,7 +21,7 @@ from ppdet.core.workspace import load_config, merge_config from ppdet.core.workspace import create from ppdet.metrics import COCOMetric, VOCMetric from paddleslim.auto_compression.config_helpers import load_config as load_slim_config - +from paddleslim.common import load_onnx_model from post_process import YOLOv7PostProcess @@ -77,12 +77,8 @@ def eval(): place = paddle.CUDAPlace(0) if FLAGS.devices == 'gpu' else paddle.CPUPlace() exe = paddle.static.Executor(place) - val_program, feed_target_names, fetch_targets = paddle.static.load_inference_model( - global_config["model_dir"], - exe, - model_filename=global_config["model_filename"], - params_filename=global_config["params_filename"]) - print('Loaded model from: {}'.format(global_config["model_dir"])) + val_program, feed_target_names, fetch_targets = load_onnx_model( + global_config["model_dir"]) metric = global_config['metric'] for batch_id, data in enumerate(val_loader): diff --git a/example/auto_compression/pytorch_yolov7/post_quant.py b/example/auto_compression/pytorch_yolov7/post_quant.py index 8c866727..a5316d29 100644 --- a/example/auto_compression/pytorch_yolov7/post_quant.py +++ b/example/auto_compression/pytorch_yolov7/post_quant.py @@ -22,6 +22,7 @@ from ppdet.core.workspace import create from ppdet.metrics import COCOMetric, VOCMetric from paddleslim.auto_compression.config_helpers import load_config as load_slim_config from paddleslim.quant import quant_post_static +from paddleslim.common import load_onnx_model def argsparser(): @@ -77,20 +78,23 @@ def main(): place = paddle.CUDAPlace(0) if FLAGS.devices == 'gpu' else paddle.CPUPlace() exe = paddle.static.Executor(place) + load_onnx_model(global_config["model_dir"]) + inference_model_path = global_config["model_dir"].rstrip().rstrip( + '.onnx') + '_infer' quant_post_static( executor=exe, - model_dir=global_config["model_dir"], + model_dir=inference_model_path, quantize_model_path=FLAGS.save_dir, data_loader=train_loader, - model_filename=global_config["model_filename"], - params_filename=global_config["params_filename"], + model_filename='model.pdmodel', + params_filename='model.pdiparams', batch_size=32, batch_nums=10, algo=FLAGS.algo, hist_percent=0.999, is_full_quantize=False, bias_correction=False, - onnx_format=False) + onnx_format=True) if __name__ == '__main__': diff --git a/example/auto_compression/pytorch_yolov7/run.py b/example/auto_compression/pytorch_yolov7/run.py index ed73d81a..5120ed9b 100644 --- a/example/auto_compression/pytorch_yolov7/run.py +++ b/example/auto_compression/pytorch_yolov7/run.py @@ -152,11 +152,9 @@ def main(): ac = AutoCompression( model_dir=global_config["model_dir"], - model_filename=global_config["model_filename"], - params_filename=global_config["params_filename"], + train_dataloader=train_loader, save_dir=FLAGS.save_dir, config=all_config, - train_dataloader=train_loader, eval_callback=eval_func) ac.compress() diff --git a/paddleslim/auto_compression/compressor.py b/paddleslim/auto_compression/compressor.py index 5587e92d..27fbead1 100644 --- a/paddleslim/auto_compression/compressor.py +++ b/paddleslim/auto_compression/compressor.py @@ -35,7 +35,7 @@ from .strategy_config import TrainConfig, ProgramInfo, merge_config from .auto_strategy import prepare_strategy, get_final_quant_config, create_strategy_config, create_train_config from .config_helpers import load_config, extract_strategy_config, extract_train_config from .utils.predict import with_variable_shape -from .utils import get_feed_vars, wrap_dataloader, load_inference_model +from .utils import get_feed_vars, wrap_dataloader, load_inference_model, get_model_dir _logger = get_logger(__name__, level=logging.INFO) @@ -49,10 +49,10 @@ except Exception as e: class AutoCompression: def __init__(self, model_dir, - model_filename, - params_filename, - save_dir, train_dataloader, + model_filename=None, + params_filename=None, + save_dir='./output', config=None, input_shapes=None, target_speedup=None, @@ -66,13 +66,13 @@ class AutoCompression: model_dir(str): The path of inference model that will be compressed, and the model and params that saved by ``paddle.static.save_inference_model`` are under the path. + train_data_loader(Python Generator, Paddle.io.DataLoader): The + Generator or Dataloader provides train data, and it could + return a batch every time. model_filename(str): The name of model file. params_filename(str): The name of params file. save_dir(str): The path to save compressed model. The models in this directory will be overwrited after calling 'compress()' function. - train_data_loader(Python Generator, Paddle.io.DataLoader): The - Generator or Dataloader provides train data, and it could - return a batch every time. input_shapes(dict|tuple|list): It is used when the model has implicit dimensions except batch size. If it is a dict, the key is the name of input and the value is the shape. Given the input shape of input "X" is [-1, 3, -1, -1] which means the batch size, hight @@ -117,18 +117,8 @@ class AutoCompression: deploy_hardware(str, optional): The hardware you want to deploy. Default: 'gpu'. """ self.model_dir = model_dir.rstrip('/') - - if model_filename == 'None': - model_filename = None - self.model_filename = model_filename - if params_filename == 'None': - params_filename = None - self.params_filename = params_filename - - if params_filename is None and model_filename is not None: - raise NotImplementedError( - "NOT SUPPORT parameters saved in separate files. Please convert it to single binary file first." - ) + self.updated_model_dir, self.model_filename, self.params_filename = get_model_dir( + model_dir, model_filename, params_filename) self.final_dir = save_dir if not os.path.exists(self.final_dir): @@ -163,8 +153,7 @@ class AutoCompression: paddle.enable_static() self._exe, self._places = self._prepare_envs() - self.model_type = self._get_model_type(self._exe, self.model_dir, - model_filename, params_filename) + self.model_type = self._get_model_type() if self.train_config is not None and self.train_config.use_fleet: fleet.init(is_collective=True) @@ -249,8 +238,8 @@ class AutoCompression: paddle.enable_static() exe = paddle.static.Executor(paddle.CPUPlace()) [inference_program, feed_target_names, - fetch_targets] = (load_inference_model(model_dir, exe, model_filename, - params_filename)) + fetch_targets] = load_inference_model(model_dir, exe, model_filename, + params_filename) if type(input_shapes) in [list, tuple]: assert len( @@ -310,23 +299,26 @@ class AutoCompression: exe = paddle.static.Executor(places) return exe, places - def _get_model_type(self, exe, model_dir, model_filename, params_filename): - [inference_program, _, _]= (load_inference_model( \ - model_dir, \ - model_filename=model_filename, params_filename=params_filename, - executor=exe)) + def _get_model_type(self): + [inference_program, _, _] = (load_inference_model( + self.model_dir, + model_filename=self.model_filename, + params_filename=self.params_filename, + executor=self._exe)) _, _, model_type = get_patterns(inference_program) if self.model_filename is None: - new_model_filename = '__new_model__' + opt_model_filename = '__opt_model__' else: - new_model_filename = 'new_' + self.model_filename + opt_model_filename = 'opt_' + self.model_filename program_bytes = inference_program._remove_training_info( clip_extra=False).desc.serialize_to_string() - with open(os.path.join(self.model_dir, new_model_filename), "wb") as f: + with open( + os.path.join(self.updated_model_dir, opt_model_filename), + "wb") as f: f.write(program_bytes) shutil.move( - os.path.join(self.model_dir, new_model_filename), - os.path.join(self.model_dir, self.model_filename)) + os.path.join(self.updated_model_dir, opt_model_filename), + os.path.join(self.updated_model_dir, self.model_filename)) _logger.info(f"Detect model type: {model_type}") return model_type @@ -603,10 +595,16 @@ class AutoCompression: train_config): # start compress, including train/eval model # TODO: add the emd loss of evaluation model. + # If model is ONNX, convert it to inference model firstly. + load_inference_model( + self.model_dir, + model_filename=self.model_filename, + params_filename=self.params_filename, + executor=self._exe) if strategy == 'quant_post': quant_post( self._exe, - model_dir=self.model_dir, + model_dir=self.updated_model_dir, quantize_model_path=os.path.join( self.tmp_dir, 'strategy_{}'.format(str(strategy_idx + 1))), data_loader=self.train_dataloader, @@ -632,11 +630,16 @@ class AutoCompression: if platform.system().lower() != 'linux': raise NotImplementedError( "post-quant-hpo is not support in system other than linux") - + # If model is ONNX, convert it to inference model firstly. + load_inference_model( + self.model_dir, + model_filename=self.model_filename, + params_filename=self.params_filename, + executor=self._exe) post_quant_hpo.quant_post_hpo( self._exe, self._places, - model_dir=self.model_dir, + model_dir=self.updated_model_dir, quantize_model_path=os.path.join( self.tmp_dir, 'strategy_{}'.format(str(strategy_idx + 1))), train_dataloader=self.train_dataloader, diff --git a/paddleslim/auto_compression/utils/dataloader.py b/paddleslim/auto_compression/utils/dataloader.py index f0f36716..31e375de 100644 --- a/paddleslim/auto_compression/utils/dataloader.py +++ b/paddleslim/auto_compression/utils/dataloader.py @@ -3,6 +3,7 @@ import time import numpy as np import paddle from collections.abc import Iterable +from .load_model import load_inference_model __all__ = ["wrap_dataloader", "get_feed_vars"] @@ -13,7 +14,7 @@ def get_feed_vars(model_dir, model_filename, params_filename): paddle.enable_static() exe = paddle.static.Executor(paddle.CPUPlace()) [inference_program, feed_target_names, fetch_targets] = ( - paddle.static.load_inference_model( + load_inference_model( model_dir, exe, model_filename=model_filename, diff --git a/paddleslim/auto_compression/utils/load_model.py b/paddleslim/auto_compression/utils/load_model.py index bb61ab56..637e808a 100644 --- a/paddleslim/auto_compression/utils/load_model.py +++ b/paddleslim/auto_compression/utils/load_model.py @@ -14,15 +14,37 @@ import os import paddle +from ...common import load_onnx_model -__all__ = ['load_inference_model'] +__all__ = ['load_inference_model', 'get_model_dir'] def load_inference_model(path_prefix, executor, model_filename=None, params_filename=None): - if model_filename is not None and params_filename is not None: + # Load onnx model to Inference model. + if path_prefix.endswith('.onnx'): + inference_program, feed_target_names, fetch_targets = load_onnx_model( + path_prefix) + return [inference_program, feed_target_names, fetch_targets] + # Load Inference model. + # TODO: clean code + if model_filename is not None and model_filename.endswith('.pdmodel'): + model_name = '.'.join(model_filename.split('.')[:-1]) + assert os.path.exists( + os.path.join(path_prefix, model_name + '.pdmodel') + ), 'Please check {}, or fix model_filename parameter.'.format( + os.path.join(path_prefix, model_name + '.pdmodel')) + assert os.path.exists( + os.path.join(path_prefix, model_name + '.pdiparams') + ), 'Please check {}, or fix params_filename parameter.'.format( + os.path.join(path_prefix, model_name + '.pdiparams')) + model_path_prefix = os.path.join(path_prefix, model_name) + [inference_program, feed_target_names, fetch_targets] = ( + paddle.static.load_inference_model( + path_prefix=model_path_prefix, executor=executor)) + elif model_filename is not None and params_filename is not None: [inference_program, feed_target_names, fetch_targets] = ( paddle.static.load_inference_model( path_prefix=path_prefix, @@ -43,3 +65,26 @@ def load_inference_model(path_prefix, path_prefix=path_prefix, executor=executor)) return [inference_program, feed_target_names, fetch_targets] + + +def get_model_dir(model_dir, model_filename, params_filename): + if model_dir.endswith('.onnx'): + updated_model_dir = model_dir.rstrip().rstrip('.onnx') + '_infer' + else: + updated_model_dir = model_dir.rstrip('/') + + if model_filename == None: + updated_model_filename = 'model.pdmodel' + else: + updated_model_filename = model_filename + + if params_filename == None: + updated_params_filename = 'model.pdiparams' + else: + updated_params_filename = params_filename + + if params_filename is None and model_filename is not None: + raise NotImplementedError( + "NOT SUPPORT parameters saved in separate files. Please convert it to single binary file first." + ) + return updated_model_dir, updated_model_filename, updated_params_filename diff --git a/paddleslim/common/__init__.py b/paddleslim/common/__init__.py index c3e40415..e866790d 100644 --- a/paddleslim/common/__init__.py +++ b/paddleslim/common/__init__.py @@ -25,11 +25,12 @@ from .analyze_helper import VarCollector from . import wrapper_function from . import recover_program from . import patterns +from .convert_model import load_onnx_model __all__ = [ 'EvolutionaryController', 'SAController', 'get_logger', 'ControllerServer', 'ControllerClient', 'lock', 'unlock', 'cached_reader', 'AvgrageMeter', - 'Server', 'Client', 'RLBaseController', 'VarCollector' + 'Server', 'Client', 'RLBaseController', 'VarCollector', 'load_onnx_model' ] __all__ += wrapper_function.__all__ diff --git a/paddleslim/common/convert_model.py b/paddleslim/common/convert_model.py new file mode 100644 index 00000000..1b501269 --- /dev/null +++ b/paddleslim/common/convert_model.py @@ -0,0 +1,102 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +import logging +import os +import shutil +import sys + +import paddle +from x2paddle.decoder.onnx_decoder import ONNXDecoder +from x2paddle.op_mapper.onnx2paddle.onnx_op_mapper import ONNXOpMapper +from x2paddle.optimizer.optimizer import GraphOptimizer +from x2paddle.utils import ConverterCheck + +from . import get_logger +_logger = get_logger(__name__, level=logging.INFO) + +__all__ = ['load_onnx_model'] + + +def load_onnx_model(model_path, disable_feedback=False): + assert model_path.endswith( + '.onnx' + ), '{} does not end with .onnx suffix and cannot be loaded.'.format( + model_path) + inference_model_path = model_path.rstrip().rstrip('.onnx') + '_infer' + exe = paddle.static.Executor(paddle.CPUPlace()) + if os.path.exists(os.path.join( + inference_model_path, 'model.pdmodel')) and os.path.exists( + os.path.join(inference_model_path, 'model.pdiparams')): + val_program, feed_target_names, fetch_targets = paddle.static.load_inference_model( + os.path.join(inference_model_path, 'model'), exe) + _logger.info('Loaded model from: {}'.format(inference_model_path)) + return val_program, feed_target_names, fetch_targets + else: + # onnx to paddle inference model. + time_info = int(time.time()) + if not disable_feedback: + ConverterCheck( + task="ONNX", time_info=time_info, convert_state="Start").start() + # check onnx installation and version + try: + import onnx + version = onnx.version.version + v0, v1, v2 = version.split('.') + version_sum = int(v0) * 100 + int(v1) * 10 + int(v2) + if version_sum < 160: + _logger.info("[ERROR] onnx>=1.6.0 is required") + sys.exit(1) + except: + _logger.info( + "[ERROR] onnx is not installed, use \"pip install onnx==1.6.0\"." + ) + sys.exit(1) + + try: + _logger.info("Now translating model from onnx to paddle.") + model = ONNXDecoder(model_path) + mapper = ONNXOpMapper(model) + mapper.paddle_graph.build() + graph_opt = GraphOptimizer(source_frame="onnx") + graph_opt.optimize(mapper.paddle_graph) + _logger.info("Model optimized.") + onnx2paddle_out_dir = os.path.join(inference_model_path, + 'onnx2paddle') + mapper.paddle_graph.gen_model(onnx2paddle_out_dir) + _logger.info("Successfully exported Paddle static graph model!") + if not disable_feedback: + ConverterCheck( + task="ONNX", time_info=time_info, + convert_state="Success").start() + shutil.move( + os.path.join(onnx2paddle_out_dir, 'inference_model', + 'model.pdmodel'), + os.path.join(inference_model_path, 'model.pdmodel')) + shutil.move( + os.path.join(onnx2paddle_out_dir, 'inference_model', + 'model.pdiparams'), + os.path.join(inference_model_path, 'model.pdiparams')) + except: + _logger.info( + "[ERROR] x2paddle threw an exception, you can ask for help at: https://github.com/PaddlePaddle/X2Paddle/issues" + ) + sys.exit(1) + + paddle.enable_static() + val_program, feed_target_names, fetch_targets = paddle.static.load_inference_model( + os.path.join(inference_model_path, 'model'), exe) + _logger.info('Loaded model from: {}'.format(inference_model_path)) + return val_program, feed_target_names, fetch_targets diff --git a/requirements.txt b/requirements.txt index 8ed081ba..488d1d91 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,3 @@ -#paddlepaddle == 1.6.0rc0 tqdm pyzmq matplotlib @@ -6,3 +5,5 @@ pillow pyyaml scikit-learn smac +onnx +x2paddle==1.3.8 diff --git a/tests/act/test_act_api.py b/tests/act/test_act_api.py index ec73d46e..aa50bb2b 100644 --- a/tests/act/test_act_api.py +++ b/tests/act/test_act_api.py @@ -9,6 +9,7 @@ import numpy as np from paddle.io import Dataset from paddleslim.auto_compression import AutoCompression from paddleslim.auto_compression.config_helpers import load_config +from paddleslim.auto_compression.utils.load_model import load_inference_model class RandomEvalDataset(Dataset): @@ -120,5 +121,28 @@ class TestDictQATDist(ACTBase): ac.compress() +class TestLoadONNXModel(ACTBase): + def __init__(self, *args, **kwargs): + super(TestLoadONNXModel, self).__init__(*args, **kwargs) + os.system( + 'wget https://paddle-slim-models.bj.bcebos.com/act/yolov5s.onnx') + self.model_dir = 'yolov5s.onnx' + + def test_compress(self): + place = paddle.CPUPlace() + exe = paddle.static.Executor(place) + _, _, _ = load_inference_model( + self.model_dir, + executor=exe, + model_filename='model.pdmodel', + params_filename='model.paiparams') + # reload model + _, _, _ = load_inference_model( + self.model_dir, + executor=exe, + model_filename='model.pdmodel', + params_filename='model.paiparams') + + if __name__ == '__main__': unittest.main() -- GitLab