未验证 提交 cc7fd900 编写于 作者: G Guanghua Yu 提交者: GitHub

support load onnx mdoel in act (#1328)

上级 5784dfe1
...@@ -3,19 +3,19 @@ AutoCompression自动压缩功能 ...@@ -3,19 +3,19 @@ AutoCompression自动压缩功能
AutoCompression AutoCompression
--------------- ---------------
.. py:class:: paddleslim.auto_compression.AutoCompression(model_dir, model_filename, params_filename, save_dir, strategy_config, train_config, train_dataloader, eval_callback, devices='gpu') .. py:class:: paddleslim.auto_compression.AutoCompression(model_dir, train_dataloader, model_filename, params_filename, save_dir, strategy_config, train_config, eval_callback, devices='gpu')
`源代码 <https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/auto_compression/auto_compression.py#L32>`_ `源代码 <https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/auto_compression/auto_compression.py#L49>`_
根据指定的配置对使用 ``paddle.jit.save`` 接口或者 ``paddle.static.save_inference_model`` 接口保存的推理模型进行压缩。 根据指定的配置对使用 ``paddle.jit.save`` 接口或者 ``paddle.static.save_inference_model`` 接口保存的推理模型进行压缩。
**参数: ** **参数: **
- **model_dir(str)** - 需要压缩的推理模型所在的目录。 - **model_dir(str)** - 需要压缩的推理模型所在的目录。
- **train_dataloader(paddle.io.DataLoader)** - 训练数据迭代器。注意:如果选择离线量化超参搜索策略的话, ``train_dataloader`` 和 ``eval_callback`` 设置相同的数据读取即可。
- **model_filename(str)** - 需要压缩的推理模型文件名称。 - **model_filename(str)** - 需要压缩的推理模型文件名称。
- **params_filename(str)** - 需要压缩的推理模型参数文件名称。 - **params_filename(str)** - 需要压缩的推理模型参数文件名称。
- **save_dir(str)** - 压缩后模型的所保存的目录。 - **save_dir(str)** - 压缩后模型的所保存的目录。
- **train_dataloader(paddle.io.DataLoader)** - 训练数据迭代器。注意:如果选择离线量化超参搜索策略的话, ``train_dataloader`` 和 ``eval_callback`` 设置相同的数据读取即可。
- **train_config(dict)** - 训练配置。可以配置的参数请参考: `<https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/auto_compression/strategy_config.py#L103>`_ 。注意:如果选择离线量化超参搜索策略的话, ``train_config`` 直接设置为 ``None`` 即可。 - **train_config(dict)** - 训练配置。可以配置的参数请参考: `<https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/auto_compression/strategy_config.py#L103>`_ 。注意:如果选择离线量化超参搜索策略的话, ``train_config`` 直接设置为 ``None`` 即可。
- **strategy_config(dict, list(dict), 可选)** - 使用的压缩策略,可以通过设置多个单种策略来并行使用这些压缩方式。字典的关键字必须在: - **strategy_config(dict, list(dict), 可选)** - 使用的压缩策略,可以通过设置多个单种策略来并行使用这些压缩方式。字典的关键字必须在:
``Quantization`` (量化配置, 可配置的参数参考 `<https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/auto_compression/strategy_config.py#L24>`_ ), ``Quantization`` (量化配置, 可配置的参数参考 `<https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/auto_compression/strategy_config.py#L24>`_ ),
...@@ -82,13 +82,13 @@ AutoCompression ...@@ -82,13 +82,13 @@ AutoCompression
eval_dataloader = Cifar10(mode='eval') eval_dataloader = Cifar10(mode='eval')
ac = AutoCompression(model_path, model_filename, params_filename, save_dir, \ ac = AutoCompression(model_path, train_dataloader, model_filename, params_filename, save_dir, \
strategy_config="Quantization": Quantization(**default_ptq_config), strategy_config="Quantization": Quantization(**default_ptq_config),
"Distillation": HyperParameterOptimization(**default_distill_config)}, \ "Distillation": HyperParameterOptimization(**default_distill_config)}, \
train_config=None, train_dataloader=train_dataloader, eval_callback=eval_dataloader,devices='gpu') train_config=None, eval_callback=eval_dataloader,devices='gpu')
``` ```
......
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
| 模型 | 策略 | 输入尺寸 | mAP<sup>val<br>0.5:0.95 | 预测时延<sup><small>FP32</small><sup><br><sup>(ms) |预测时延<sup><small>FP16</small><sup><br><sup>(ms) | 预测时延<sup><small>INT8</small><sup><br><sup>(ms) | 配置文件 | Inference模型 | | 模型 | 策略 | 输入尺寸 | mAP<sup>val<br>0.5:0.95 | 预测时延<sup><small>FP32</small><sup><br><sup>(ms) |预测时延<sup><small>FP16</small><sup><br><sup>(ms) | 预测时延<sup><small>INT8</small><sup><br><sup>(ms) | 配置文件 | Inference模型 |
| :-------- |:-------- |:--------: | :---------------------: | :----------------: | :----------------: | :---------------: | :-----------------------------: | :-----------------------------: | | :-------- |:-------- |:--------: | :---------------------: | :----------------: | :----------------: | :---------------: | :-----------------------------: | :-----------------------------: |
| YOLOv5s | Base模型 | 640*640 | 37.4 | 5.95ms | 2.44ms | - | - | [Model](https://bj.bcebos.com/v1/paddle-slim-models/detection/yolov5s_infer.tar) | | YOLOv5s | Base模型 | 640*640 | 37.4 | 5.95ms | 2.44ms | - | - | [Model](https://paddle-slim-models.bj.bcebos.com/act/yolov5s.onnx) |
| YOLOv5s | KL离线量化 | 640*640 | 36.0 | - | - | 1.87ms | - | - | | YOLOv5s | KL离线量化 | 640*640 | 36.0 | - | - | 1.87ms | - | - |
| YOLOv5s | 量化蒸馏训练 | 640*640 | **36.9** | - | - | **1.87ms** | [config](./configs/yolov5s_qat_dis.yaml) | [Infer Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov5s_quant.tar) &#124; [ONNX Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov5s_quant.onnx) | | YOLOv5s | 量化蒸馏训练 | 640*640 | **36.9** | - | - | **1.87ms** | [config](./configs/yolov5s_qat_dis.yaml) | [Infer Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov5s_quant.tar) &#124; [ONNX Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov5s_quant.onnx) |
...@@ -60,37 +60,18 @@ pip install paddledet ...@@ -60,37 +60,18 @@ pip install paddledet
注:安装PaddleDet的目的是为了直接使用PaddleDetection中的Dataloader组件。 注:安装PaddleDet的目的是为了直接使用PaddleDetection中的Dataloader组件。
(4)安装X2Paddle的1.3.6以上版本:
```shell
pip install x2paddle sympy onnx
```
#### 3.2 准备数据集 #### 3.2 准备数据集
本案例默认以COCO数据进行自动压缩实验,并且依赖PaddleDetection中数据读取模块,如果自定义COCO数据,或者其他格式数据,请参考[PaddleDetection数据准备文档](https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.4/docs/tutorials/PrepareDataSet.md) 来准备数据。 本案例默认以COCO数据进行自动压缩实验,并且依赖PaddleDetection中数据读取模块,如果自定义COCO数据,或者其他格式数据,请参考[PaddleDetection数据准备文档](https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.4/docs/tutorials/PrepareDataSet.md) 来准备数据。
如果已经准备好数据集,请直接修改[./configs/yolov6_reader.yml]中`EvalDataset``dataset_dir`字段为自己数据集路径即可。 如果已经准备好数据集,请直接修改[./configs/yolov6_reader.yml]中`EvalDataset``dataset_dir`字段为自己数据集路径即可。
#### 3.3 准备预测模型 #### 3.3 准备ONNX预测模型
(1)准备ONNX模型:
可通过[ultralytics/yolov5](https://github.com/ultralytics/yolov5) 官方的[导出教程](https://github.com/ultralytics/yolov5/issues/251)来准备ONNX模型。也可以下载准备好的[yolov5s.onnx](https://paddle-slim-models.bj.bcebos.com/act/yolov5s.onnx) 可通过[ultralytics/yolov5](https://github.com/ultralytics/yolov5) 官方的[导出教程](https://github.com/ultralytics/yolov5/issues/251)来准备ONNX模型。也可以下载准备好的[yolov5s.onnx](https://paddle-slim-models.bj.bcebos.com/act/yolov5s.onnx)
```shell ```shell
python export.py --weights yolov5s.pt --include onnx python export.py --weights yolov5s.pt --include onnx
``` ```
(2) 转换模型:
```shell
x2paddle --framework=onnx --model=yolov5s.onnx --save_dir=pd_model
cp -r pd_model/inference_model/ yolov5s_infer
```
即可得到YOLOv5s模型的预测模型(`model.pdmodel``model.pdiparams`)。如想快速体验,可直接下载上方表格中YOLOv5s的[Paddle预测模型](https://bj.bcebos.com/v1/paddle-slim-models/detection/yolov5s_infer.tar)
预测模型的格式为:`model.pdmodel``model.pdiparams`两个,带`pdmodel`的是模型文件,带`pdiparams`后缀的是权重文件。
#### 3.4 自动压缩并产出模型 #### 3.4 自动压缩并产出模型
蒸馏量化自动压缩示例通过run.py脚本启动,会使用接口```paddleslim.auto_compression.AutoCompression```对模型进行自动压缩。配置config文件中模型路径、蒸馏、量化、和训练等部分的参数,配置完成后便可对模型进行量化和蒸馏。具体运行命令为: 蒸馏量化自动压缩示例通过run.py脚本启动,会使用接口```paddleslim.auto_compression.AutoCompression```对模型进行自动压缩。配置config文件中模型路径、蒸馏、量化、和训练等部分的参数,配置完成后便可对模型进行量化和蒸馏。具体运行命令为:
......
...@@ -4,9 +4,7 @@ Global: ...@@ -4,9 +4,7 @@ Global:
input_list: {'image': 'x2paddle_images'} input_list: {'image': 'x2paddle_images'}
Evaluation: True Evaluation: True
arch: 'YOLOv5' arch: 'YOLOv5'
model_dir: ./yolov5s_infer model_dir: ./yolov5s.onnx
model_filename: model.pdmodel
params_filename: model.pdiparams
Distillation: Distillation:
alpha: 1.0 alpha: 1.0
......
...@@ -21,7 +21,7 @@ from ppdet.core.workspace import load_config, merge_config ...@@ -21,7 +21,7 @@ from ppdet.core.workspace import load_config, merge_config
from ppdet.core.workspace import create from ppdet.core.workspace import create
from ppdet.metrics import COCOMetric, VOCMetric from ppdet.metrics import COCOMetric, VOCMetric
from paddleslim.auto_compression.config_helpers import load_config as load_slim_config from paddleslim.auto_compression.config_helpers import load_config as load_slim_config
from paddleslim.common import load_onnx_model
from post_process import YOLOv5PostProcess from post_process import YOLOv5PostProcess
...@@ -76,13 +76,8 @@ def eval(): ...@@ -76,13 +76,8 @@ def eval():
place = paddle.CUDAPlace(0) if FLAGS.devices == 'gpu' else paddle.CPUPlace() place = paddle.CUDAPlace(0) if FLAGS.devices == 'gpu' else paddle.CPUPlace()
exe = paddle.static.Executor(place) exe = paddle.static.Executor(place)
val_program, feed_target_names, fetch_targets = load_onnx_model(
val_program, feed_target_names, fetch_targets = paddle.static.load_inference_model( global_config["model_dir"])
global_config["model_dir"],
exe,
model_filename=global_config["model_filename"],
params_filename=global_config["params_filename"])
print('Loaded model from: {}'.format(global_config["model_dir"]))
metric = global_config['metric'] metric = global_config['metric']
for batch_id, data in enumerate(val_loader): for batch_id, data in enumerate(val_loader):
......
...@@ -22,6 +22,7 @@ from ppdet.core.workspace import create ...@@ -22,6 +22,7 @@ from ppdet.core.workspace import create
from ppdet.metrics import COCOMetric, VOCMetric from ppdet.metrics import COCOMetric, VOCMetric
from paddleslim.auto_compression.config_helpers import load_config as load_slim_config from paddleslim.auto_compression.config_helpers import load_config as load_slim_config
from paddleslim.quant import quant_post_static from paddleslim.quant import quant_post_static
from paddleslim.common import load_onnx_model
def argsparser(): def argsparser():
...@@ -77,20 +78,23 @@ def main(): ...@@ -77,20 +78,23 @@ def main():
place = paddle.CUDAPlace(0) if FLAGS.devices == 'gpu' else paddle.CPUPlace() place = paddle.CUDAPlace(0) if FLAGS.devices == 'gpu' else paddle.CPUPlace()
exe = paddle.static.Executor(place) exe = paddle.static.Executor(place)
load_onnx_model(global_config["model_dir"])
inference_model_path = global_config["model_dir"].rstrip().rstrip(
'.onnx') + '_infer'
quant_post_static( quant_post_static(
executor=exe, executor=exe,
model_dir=global_config["model_dir"], model_dir=inference_model_path,
quantize_model_path=FLAGS.save_dir, quantize_model_path=FLAGS.save_dir,
data_loader=train_loader, data_loader=train_loader,
model_filename=global_config["model_filename"], model_filename='model.pdmodel',
params_filename=global_config["params_filename"], params_filename='model.pdiparams',
batch_size=32, batch_size=32,
batch_nums=10, batch_nums=10,
algo=FLAGS.algo, algo=FLAGS.algo,
hist_percent=0.999, hist_percent=0.999,
is_full_quantize=False, is_full_quantize=False,
bias_correction=False, bias_correction=False,
onnx_format=False) onnx_format=True)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -159,11 +159,9 @@ def main(): ...@@ -159,11 +159,9 @@ def main():
ac = AutoCompression( ac = AutoCompression(
model_dir=global_config["model_dir"], model_dir=global_config["model_dir"],
model_filename=global_config["model_filename"], train_dataloader=train_loader,
params_filename=global_config["params_filename"],
save_dir=FLAGS.save_dir, save_dir=FLAGS.save_dir,
config=all_config, config=all_config,
train_dataloader=train_loader,
eval_callback=eval_func) eval_callback=eval_func)
ac.compress() ac.compress()
......
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
| 模型 | 策略 | 输入尺寸 | mAP<sup>val<br>0.5:0.95 | 预测时延<sup><small>FP32</small><sup><br><sup>(ms) |预测时延<sup><small>FP16</small><sup><br><sup>(ms) | 预测时延<sup><small>INT8</small><sup><br><sup>(ms) | 配置文件 | Inference模型 | | 模型 | 策略 | 输入尺寸 | mAP<sup>val<br>0.5:0.95 | 预测时延<sup><small>FP32</small><sup><br><sup>(ms) |预测时延<sup><small>FP16</small><sup><br><sup>(ms) | 预测时延<sup><small>INT8</small><sup><br><sup>(ms) | 配置文件 | Inference模型 |
| :-------- |:-------- |:--------: | :---------------------: | :----------------: | :----------------: | :---------------: | :-----------------------------: | :-----------------------------: | | :-------- |:-------- |:--------: | :---------------------: | :----------------: | :----------------: | :---------------: | :-----------------------------: | :-----------------------------: |
| YOLOv6s | Base模型 | 640*640 | 42.4 | 9.06ms | 2.90ms | - | - | [Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov6s_infer.tar) | | YOLOv6s | Base模型 | 640*640 | 42.4 | 9.06ms | 2.90ms | - | - | [Model](https://paddle-slim-models.bj.bcebos.com/act/yolov6s.onnx) |
| YOLOv6s | KL离线量化 | 640*640 | 30.3 | - | - | 1.83ms | - | - | | YOLOv6s | KL离线量化 | 640*640 | 30.3 | - | - | 1.83ms | - | - |
| YOLOv6s | 量化蒸馏训练 | 640*640 | **41.3** | - | - | **1.83ms** | [config](./configs/yolov6s_qat_dis.yaml) | [Infer Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov6s_quant.tar) &#124; [ONNX Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov6s_quant.onnx) | | YOLOv6s | 量化蒸馏训练 | 640*640 | **41.3** | - | - | **1.83ms** | [config](./configs/yolov6s_qat_dis.yaml) | [Infer Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov6s_quant.tar) &#124; [ONNX Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov6s_quant.onnx) |
...@@ -36,7 +36,6 @@ ...@@ -36,7 +36,6 @@
- PaddlePaddle >= 2.3 (可从[Paddle官网](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html)下载安装) - PaddlePaddle >= 2.3 (可从[Paddle官网](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html)下载安装)
- PaddleSlim > 2.3版本 - PaddleSlim > 2.3版本
- PaddleDet >= 2.4 - PaddleDet >= 2.4
- [X2Paddle](https://github.com/PaddlePaddle/X2Paddle) >= 1.3.6
- opencv-python - opencv-python
(1)安装paddlepaddle: (1)安装paddlepaddle:
...@@ -59,10 +58,6 @@ pip install paddledet ...@@ -59,10 +58,6 @@ pip install paddledet
注:安装PaddleDet的目的只是为了直接使用PaddleDetection中的Dataloader组件。 注:安装PaddleDet的目的只是为了直接使用PaddleDetection中的Dataloader组件。
(4)安装X2Paddle的1.3.6以上版本:
```shell
pip install x2paddle sympy onnx
```
#### 3.2 准备数据集 #### 3.2 准备数据集
...@@ -78,17 +73,6 @@ pip install x2paddle sympy onnx ...@@ -78,17 +73,6 @@ pip install x2paddle sympy onnx
可通过[meituan/YOLOv6](https://github.com/meituan/YOLOv6)官方的[导出教程](https://github.com/meituan/YOLOv6/blob/main/deploy/ONNX/README.md)来准备ONNX模型。也可以下载已经准备好的[yolov6s.onnx](https://paddle-slim-models.bj.bcebos.com/act/yolov6s.onnx) 可通过[meituan/YOLOv6](https://github.com/meituan/YOLOv6)官方的[导出教程](https://github.com/meituan/YOLOv6/blob/main/deploy/ONNX/README.md)来准备ONNX模型。也可以下载已经准备好的[yolov6s.onnx](https://paddle-slim-models.bj.bcebos.com/act/yolov6s.onnx)
(2) 转换模型:
```
x2paddle --framework=onnx --model=yolov6s.onnx --save_dir=pd_model
cp -r pd_model/inference_model/ yolov6s_infer
```
即可得到YOLOv6s模型的预测模型(`model.pdmodel``model.pdiparams`)。如想快速体验,可直接下载上方表格中YOLOv6s的[Paddle预测模型](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov6s_infer.tar)
预测模型的格式为:`model.pdmodel``model.pdiparams`两个,带`pdmodel`的是模型文件,带`pdiparams`后缀的是权重文件。
#### 3.4 自动压缩并产出模型 #### 3.4 自动压缩并产出模型
蒸馏量化自动压缩示例通过run.py脚本启动,会使用接口```paddleslim.auto_compression.AutoCompression```对模型进行自动压缩。配置config文件中模型路径、蒸馏、量化、和训练等部分的参数,配置完成后便可对模型进行量化和蒸馏。具体运行命令为: 蒸馏量化自动压缩示例通过run.py脚本启动,会使用接口```paddleslim.auto_compression.AutoCompression```对模型进行自动压缩。配置config文件中模型路径、蒸馏、量化、和训练等部分的参数,配置完成后便可对模型进行量化和蒸馏。具体运行命令为:
......
...@@ -4,9 +4,7 @@ Global: ...@@ -4,9 +4,7 @@ Global:
input_list: {'image': 'x2paddle_image_arrays'} input_list: {'image': 'x2paddle_image_arrays'}
Evaluation: True Evaluation: True
arch: 'YOLOv6' arch: 'YOLOv6'
model_dir: ./yolov6s_infer model_dir: ./yolov6s.onnx
model_filename: model.pdmodel
params_filename: model.pdiparams
Distillation: Distillation:
alpha: 1.0 alpha: 1.0
......
...@@ -21,7 +21,7 @@ from ppdet.core.workspace import load_config, merge_config ...@@ -21,7 +21,7 @@ from ppdet.core.workspace import load_config, merge_config
from ppdet.core.workspace import create from ppdet.core.workspace import create
from ppdet.metrics import COCOMetric, VOCMetric from ppdet.metrics import COCOMetric, VOCMetric
from paddleslim.auto_compression.config_helpers import load_config as load_slim_config from paddleslim.auto_compression.config_helpers import load_config as load_slim_config
from paddleslim.common import load_onnx_model
from post_process import YOLOv6PostProcess from post_process import YOLOv6PostProcess
...@@ -77,12 +77,8 @@ def eval(): ...@@ -77,12 +77,8 @@ def eval():
place = paddle.CUDAPlace(0) if FLAGS.devices == 'gpu' else paddle.CPUPlace() place = paddle.CUDAPlace(0) if FLAGS.devices == 'gpu' else paddle.CPUPlace()
exe = paddle.static.Executor(place) exe = paddle.static.Executor(place)
val_program, feed_target_names, fetch_targets = paddle.static.load_inference_model( val_program, feed_target_names, fetch_targets = load_onnx_model(
global_config["model_dir"], global_config["model_dir"])
exe,
model_filename=global_config["model_filename"],
params_filename=global_config["params_filename"])
print('Loaded model from: {}'.format(global_config["model_dir"]))
metric = global_config['metric'] metric = global_config['metric']
for batch_id, data in enumerate(val_loader): for batch_id, data in enumerate(val_loader):
......
...@@ -22,8 +22,7 @@ from ppdet.core.workspace import create ...@@ -22,8 +22,7 @@ from ppdet.core.workspace import create
from ppdet.metrics import COCOMetric, VOCMetric from ppdet.metrics import COCOMetric, VOCMetric
from paddleslim.auto_compression.config_helpers import load_config as load_slim_config from paddleslim.auto_compression.config_helpers import load_config as load_slim_config
from paddleslim.quant import quant_post_static from paddleslim.quant import quant_post_static
from paddleslim.common import load_onnx_model
from post_process import YOLOv6PostProcess
def argsparser(): def argsparser():
...@@ -79,20 +78,23 @@ def main(): ...@@ -79,20 +78,23 @@ def main():
place = paddle.CUDAPlace(0) if FLAGS.devices == 'gpu' else paddle.CPUPlace() place = paddle.CUDAPlace(0) if FLAGS.devices == 'gpu' else paddle.CPUPlace()
exe = paddle.static.Executor(place) exe = paddle.static.Executor(place)
load_onnx_model(global_config["model_dir"])
inference_model_path = global_config["model_dir"].rstrip().rstrip(
'.onnx') + '_infer'
quant_post_static( quant_post_static(
executor=exe, executor=exe,
model_dir=global_config["model_dir"], model_dir=inference_model_path,
quantize_model_path=FLAGS.save_dir, quantize_model_path=FLAGS.save_dir,
data_loader=train_loader, data_loader=train_loader,
model_filename=global_config["model_filename"], model_filename='model.pdmodel',
params_filename=global_config["params_filename"], params_filename='model.pdiparams',
batch_size=32, batch_size=32,
batch_nums=10, batch_nums=10,
algo=FLAGS.algo, algo=FLAGS.algo,
hist_percent=0.999, hist_percent=0.999,
is_full_quantize=False, is_full_quantize=False,
bias_correction=False, bias_correction=False,
onnx_format=False) onnx_format=True)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -161,11 +161,9 @@ def main(): ...@@ -161,11 +161,9 @@ def main():
ac = AutoCompression( ac = AutoCompression(
model_dir=global_config["model_dir"], model_dir=global_config["model_dir"],
model_filename=global_config["model_filename"], train_dataloader=train_loader,
params_filename=global_config["params_filename"],
save_dir=FLAGS.save_dir, save_dir=FLAGS.save_dir,
config=all_config, config=all_config,
train_dataloader=train_loader,
eval_callback=eval_func) eval_callback=eval_func)
ac.compress() ac.compress()
......
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
| 模型 | 策略 | 输入尺寸 | mAP<sup>val<br>0.5:0.95 | 预测时延<sup><small>FP32</small><sup><br><sup>(ms) |预测时延<sup><small>FP16</small><sup><br><sup>(ms) | 预测时延<sup><small>INT8</small><sup><br><sup>(ms) | 配置文件 | Inference模型 | | 模型 | 策略 | 输入尺寸 | mAP<sup>val<br>0.5:0.95 | 预测时延<sup><small>FP32</small><sup><br><sup>(ms) |预测时延<sup><small>FP16</small><sup><br><sup>(ms) | 预测时延<sup><small>INT8</small><sup><br><sup>(ms) | 配置文件 | Inference模型 |
| :-------- |:-------- |:--------: | :---------------------: | :----------------: | :----------------: | :---------------: | :-----------------------------: | :-----------------------------: | | :-------- |:-------- |:--------: | :---------------------: | :----------------: | :----------------: | :---------------: | :-----------------------------: | :-----------------------------: |
| YOLOv7 | Base模型 | 640*640 | 51.1 | 26.84ms | 7.44ms | - | - | [Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov7_infer.tar) | | YOLOv7 | Base模型 | 640*640 | 51.1 | 26.84ms | 7.44ms | - | - | [Model](https://paddle-slim-models.bj.bcebos.com/act/yolov7.onnx) |
| YOLOv7 | KL离线量化 | 640*640 | 50.2 | - | - | 4.55ms | - | - | | YOLOv7 | KL离线量化 | 640*640 | 50.2 | - | - | 4.55ms | - | - |
| YOLOv7 | 量化蒸馏训练 | 640*640 | **50.8** | - | - | **4.55ms** | [config](./configs/yolov7_qat_dis.yaml) | [Infer Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov7_quant.tar) &#124; [ONNX Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov7_quant.onnx) | | YOLOv7 | 量化蒸馏训练 | 640*640 | **50.8** | - | - | **4.55ms** | [config](./configs/yolov7_qat_dis.yaml) | [Infer Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov7_quant.tar) &#124; [ONNX Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov7_quant.onnx) |
...@@ -36,7 +36,6 @@ ...@@ -36,7 +36,6 @@
- PaddlePaddle >= 2.3 (可从[Paddle官网](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html)下载安装) - PaddlePaddle >= 2.3 (可从[Paddle官网](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html)下载安装)
- PaddleSlim > 2.3版本 - PaddleSlim > 2.3版本
- PaddleDet >= 2.4 - PaddleDet >= 2.4
- [X2Paddle](https://github.com/PaddlePaddle/X2Paddle) >= 1.3.6
- opencv-python - opencv-python
(1)安装paddlepaddle: (1)安装paddlepaddle:
...@@ -59,10 +58,6 @@ pip install paddledet ...@@ -59,10 +58,6 @@ pip install paddledet
注:安装PaddleDet的目的只是为了直接使用PaddleDetection中的Dataloader组件。 注:安装PaddleDet的目的只是为了直接使用PaddleDetection中的Dataloader组件。
(4)安装X2Paddle的1.3.6以上版本:
```shell
pip install x2paddle sympy onnx
```
#### 3.2 准备数据集 #### 3.2 准备数据集
...@@ -86,18 +81,6 @@ python export.py --weights yolov7.pt --include onnx ...@@ -86,18 +81,6 @@ python export.py --weights yolov7.pt --include onnx
也可以直接下载我们已经准备好的[yolov7.onnx](https://paddle-slim-models.bj.bcebos.com/act/yolov7.onnx) 也可以直接下载我们已经准备好的[yolov7.onnx](https://paddle-slim-models.bj.bcebos.com/act/yolov7.onnx)
(2) 转换模型:
```
x2paddle --framework=onnx --model=yolov7.onnx --save_dir=pd_model
cp -r pd_model/inference_model/ yolov7_infer
```
即可得到YOLOv7模型的预测模型(`model.pdmodel``model.pdiparams`)。如想快速体验,可直接下载上方表格中YOLOv7的[Paddle预测模型](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov7_infer.tar)
预测模型的格式为:`model.pdmodel``model.pdiparams`两个,带`pdmodel`的是模型文件,带`pdiparams`后缀的是权重文件。
#### 3.4 自动压缩并产出模型 #### 3.4 自动压缩并产出模型
蒸馏量化自动压缩示例通过run.py脚本启动,会使用接口```paddleslim.auto_compression.AutoCompression```对模型进行自动压缩。配置config文件中模型路径、蒸馏、量化、和训练等部分的参数,配置完成后便可对模型进行量化和蒸馏。具体运行命令为: 蒸馏量化自动压缩示例通过run.py脚本启动,会使用接口```paddleslim.auto_compression.AutoCompression```对模型进行自动压缩。配置config文件中模型路径、蒸馏、量化、和训练等部分的参数,配置完成后便可对模型进行量化和蒸馏。具体运行命令为:
......
...@@ -3,9 +3,7 @@ Global: ...@@ -3,9 +3,7 @@ Global:
reader_config: configs/yolov7_reader.yaml reader_config: configs/yolov7_reader.yaml
input_list: {'image': 'x2paddle_images'} input_list: {'image': 'x2paddle_images'}
Evaluation: True Evaluation: True
model_dir: ./yolov7_infer model_dir: ./yolov7.onnx
model_filename: model.pdmodel
params_filename: model.pdiparams
Distillation: Distillation:
alpha: 1.0 alpha: 1.0
......
...@@ -21,7 +21,7 @@ from ppdet.core.workspace import load_config, merge_config ...@@ -21,7 +21,7 @@ from ppdet.core.workspace import load_config, merge_config
from ppdet.core.workspace import create from ppdet.core.workspace import create
from ppdet.metrics import COCOMetric, VOCMetric from ppdet.metrics import COCOMetric, VOCMetric
from paddleslim.auto_compression.config_helpers import load_config as load_slim_config from paddleslim.auto_compression.config_helpers import load_config as load_slim_config
from paddleslim.common import load_onnx_model
from post_process import YOLOv7PostProcess from post_process import YOLOv7PostProcess
...@@ -77,12 +77,8 @@ def eval(): ...@@ -77,12 +77,8 @@ def eval():
place = paddle.CUDAPlace(0) if FLAGS.devices == 'gpu' else paddle.CPUPlace() place = paddle.CUDAPlace(0) if FLAGS.devices == 'gpu' else paddle.CPUPlace()
exe = paddle.static.Executor(place) exe = paddle.static.Executor(place)
val_program, feed_target_names, fetch_targets = paddle.static.load_inference_model( val_program, feed_target_names, fetch_targets = load_onnx_model(
global_config["model_dir"], global_config["model_dir"])
exe,
model_filename=global_config["model_filename"],
params_filename=global_config["params_filename"])
print('Loaded model from: {}'.format(global_config["model_dir"]))
metric = global_config['metric'] metric = global_config['metric']
for batch_id, data in enumerate(val_loader): for batch_id, data in enumerate(val_loader):
......
...@@ -22,6 +22,7 @@ from ppdet.core.workspace import create ...@@ -22,6 +22,7 @@ from ppdet.core.workspace import create
from ppdet.metrics import COCOMetric, VOCMetric from ppdet.metrics import COCOMetric, VOCMetric
from paddleslim.auto_compression.config_helpers import load_config as load_slim_config from paddleslim.auto_compression.config_helpers import load_config as load_slim_config
from paddleslim.quant import quant_post_static from paddleslim.quant import quant_post_static
from paddleslim.common import load_onnx_model
def argsparser(): def argsparser():
...@@ -77,20 +78,23 @@ def main(): ...@@ -77,20 +78,23 @@ def main():
place = paddle.CUDAPlace(0) if FLAGS.devices == 'gpu' else paddle.CPUPlace() place = paddle.CUDAPlace(0) if FLAGS.devices == 'gpu' else paddle.CPUPlace()
exe = paddle.static.Executor(place) exe = paddle.static.Executor(place)
load_onnx_model(global_config["model_dir"])
inference_model_path = global_config["model_dir"].rstrip().rstrip(
'.onnx') + '_infer'
quant_post_static( quant_post_static(
executor=exe, executor=exe,
model_dir=global_config["model_dir"], model_dir=inference_model_path,
quantize_model_path=FLAGS.save_dir, quantize_model_path=FLAGS.save_dir,
data_loader=train_loader, data_loader=train_loader,
model_filename=global_config["model_filename"], model_filename='model.pdmodel',
params_filename=global_config["params_filename"], params_filename='model.pdiparams',
batch_size=32, batch_size=32,
batch_nums=10, batch_nums=10,
algo=FLAGS.algo, algo=FLAGS.algo,
hist_percent=0.999, hist_percent=0.999,
is_full_quantize=False, is_full_quantize=False,
bias_correction=False, bias_correction=False,
onnx_format=False) onnx_format=True)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -152,11 +152,9 @@ def main(): ...@@ -152,11 +152,9 @@ def main():
ac = AutoCompression( ac = AutoCompression(
model_dir=global_config["model_dir"], model_dir=global_config["model_dir"],
model_filename=global_config["model_filename"], train_dataloader=train_loader,
params_filename=global_config["params_filename"],
save_dir=FLAGS.save_dir, save_dir=FLAGS.save_dir,
config=all_config, config=all_config,
train_dataloader=train_loader,
eval_callback=eval_func) eval_callback=eval_func)
ac.compress() ac.compress()
......
...@@ -35,7 +35,7 @@ from .strategy_config import TrainConfig, ProgramInfo, merge_config ...@@ -35,7 +35,7 @@ from .strategy_config import TrainConfig, ProgramInfo, merge_config
from .auto_strategy import prepare_strategy, get_final_quant_config, create_strategy_config, create_train_config from .auto_strategy import prepare_strategy, get_final_quant_config, create_strategy_config, create_train_config
from .config_helpers import load_config, extract_strategy_config, extract_train_config from .config_helpers import load_config, extract_strategy_config, extract_train_config
from .utils.predict import with_variable_shape from .utils.predict import with_variable_shape
from .utils import get_feed_vars, wrap_dataloader, load_inference_model from .utils import get_feed_vars, wrap_dataloader, load_inference_model, get_model_dir
_logger = get_logger(__name__, level=logging.INFO) _logger = get_logger(__name__, level=logging.INFO)
...@@ -49,10 +49,10 @@ except Exception as e: ...@@ -49,10 +49,10 @@ except Exception as e:
class AutoCompression: class AutoCompression:
def __init__(self, def __init__(self,
model_dir, model_dir,
model_filename,
params_filename,
save_dir,
train_dataloader, train_dataloader,
model_filename=None,
params_filename=None,
save_dir='./output',
config=None, config=None,
input_shapes=None, input_shapes=None,
target_speedup=None, target_speedup=None,
...@@ -66,13 +66,13 @@ class AutoCompression: ...@@ -66,13 +66,13 @@ class AutoCompression:
model_dir(str): The path of inference model that will be compressed, and model_dir(str): The path of inference model that will be compressed, and
the model and params that saved by ``paddle.static.save_inference_model`` the model and params that saved by ``paddle.static.save_inference_model``
are under the path. are under the path.
train_data_loader(Python Generator, Paddle.io.DataLoader): The
Generator or Dataloader provides train data, and it could
return a batch every time.
model_filename(str): The name of model file. model_filename(str): The name of model file.
params_filename(str): The name of params file. params_filename(str): The name of params file.
save_dir(str): The path to save compressed model. The models in this directory will be overwrited save_dir(str): The path to save compressed model. The models in this directory will be overwrited
after calling 'compress()' function. after calling 'compress()' function.
train_data_loader(Python Generator, Paddle.io.DataLoader): The
Generator or Dataloader provides train data, and it could
return a batch every time.
input_shapes(dict|tuple|list): It is used when the model has implicit dimensions except batch size. input_shapes(dict|tuple|list): It is used when the model has implicit dimensions except batch size.
If it is a dict, the key is the name of input and the value is the shape. If it is a dict, the key is the name of input and the value is the shape.
Given the input shape of input "X" is [-1, 3, -1, -1] which means the batch size, hight Given the input shape of input "X" is [-1, 3, -1, -1] which means the batch size, hight
...@@ -117,18 +117,8 @@ class AutoCompression: ...@@ -117,18 +117,8 @@ class AutoCompression:
deploy_hardware(str, optional): The hardware you want to deploy. Default: 'gpu'. deploy_hardware(str, optional): The hardware you want to deploy. Default: 'gpu'.
""" """
self.model_dir = model_dir.rstrip('/') self.model_dir = model_dir.rstrip('/')
self.updated_model_dir, self.model_filename, self.params_filename = get_model_dir(
if model_filename == 'None': model_dir, model_filename, params_filename)
model_filename = None
self.model_filename = model_filename
if params_filename == 'None':
params_filename = None
self.params_filename = params_filename
if params_filename is None and model_filename is not None:
raise NotImplementedError(
"NOT SUPPORT parameters saved in separate files. Please convert it to single binary file first."
)
self.final_dir = save_dir self.final_dir = save_dir
if not os.path.exists(self.final_dir): if not os.path.exists(self.final_dir):
...@@ -163,8 +153,7 @@ class AutoCompression: ...@@ -163,8 +153,7 @@ class AutoCompression:
paddle.enable_static() paddle.enable_static()
self._exe, self._places = self._prepare_envs() self._exe, self._places = self._prepare_envs()
self.model_type = self._get_model_type(self._exe, self.model_dir, self.model_type = self._get_model_type()
model_filename, params_filename)
if self.train_config is not None and self.train_config.use_fleet: if self.train_config is not None and self.train_config.use_fleet:
fleet.init(is_collective=True) fleet.init(is_collective=True)
...@@ -249,8 +238,8 @@ class AutoCompression: ...@@ -249,8 +238,8 @@ class AutoCompression:
paddle.enable_static() paddle.enable_static()
exe = paddle.static.Executor(paddle.CPUPlace()) exe = paddle.static.Executor(paddle.CPUPlace())
[inference_program, feed_target_names, [inference_program, feed_target_names,
fetch_targets] = (load_inference_model(model_dir, exe, model_filename, fetch_targets] = load_inference_model(model_dir, exe, model_filename,
params_filename)) params_filename)
if type(input_shapes) in [list, tuple]: if type(input_shapes) in [list, tuple]:
assert len( assert len(
...@@ -310,23 +299,26 @@ class AutoCompression: ...@@ -310,23 +299,26 @@ class AutoCompression:
exe = paddle.static.Executor(places) exe = paddle.static.Executor(places)
return exe, places return exe, places
def _get_model_type(self, exe, model_dir, model_filename, params_filename): def _get_model_type(self):
[inference_program, _, _]= (load_inference_model( \ [inference_program, _, _] = (load_inference_model(
model_dir, \ self.model_dir,
model_filename=model_filename, params_filename=params_filename, model_filename=self.model_filename,
executor=exe)) params_filename=self.params_filename,
executor=self._exe))
_, _, model_type = get_patterns(inference_program) _, _, model_type = get_patterns(inference_program)
if self.model_filename is None: if self.model_filename is None:
new_model_filename = '__new_model__' opt_model_filename = '__opt_model__'
else: else:
new_model_filename = 'new_' + self.model_filename opt_model_filename = 'opt_' + self.model_filename
program_bytes = inference_program._remove_training_info( program_bytes = inference_program._remove_training_info(
clip_extra=False).desc.serialize_to_string() clip_extra=False).desc.serialize_to_string()
with open(os.path.join(self.model_dir, new_model_filename), "wb") as f: with open(
os.path.join(self.updated_model_dir, opt_model_filename),
"wb") as f:
f.write(program_bytes) f.write(program_bytes)
shutil.move( shutil.move(
os.path.join(self.model_dir, new_model_filename), os.path.join(self.updated_model_dir, opt_model_filename),
os.path.join(self.model_dir, self.model_filename)) os.path.join(self.updated_model_dir, self.model_filename))
_logger.info(f"Detect model type: {model_type}") _logger.info(f"Detect model type: {model_type}")
return model_type return model_type
...@@ -603,10 +595,16 @@ class AutoCompression: ...@@ -603,10 +595,16 @@ class AutoCompression:
train_config): train_config):
# start compress, including train/eval model # start compress, including train/eval model
# TODO: add the emd loss of evaluation model. # TODO: add the emd loss of evaluation model.
# If model is ONNX, convert it to inference model firstly.
load_inference_model(
self.model_dir,
model_filename=self.model_filename,
params_filename=self.params_filename,
executor=self._exe)
if strategy == 'quant_post': if strategy == 'quant_post':
quant_post( quant_post(
self._exe, self._exe,
model_dir=self.model_dir, model_dir=self.updated_model_dir,
quantize_model_path=os.path.join( quantize_model_path=os.path.join(
self.tmp_dir, 'strategy_{}'.format(str(strategy_idx + 1))), self.tmp_dir, 'strategy_{}'.format(str(strategy_idx + 1))),
data_loader=self.train_dataloader, data_loader=self.train_dataloader,
...@@ -632,11 +630,16 @@ class AutoCompression: ...@@ -632,11 +630,16 @@ class AutoCompression:
if platform.system().lower() != 'linux': if platform.system().lower() != 'linux':
raise NotImplementedError( raise NotImplementedError(
"post-quant-hpo is not support in system other than linux") "post-quant-hpo is not support in system other than linux")
# If model is ONNX, convert it to inference model firstly.
load_inference_model(
self.model_dir,
model_filename=self.model_filename,
params_filename=self.params_filename,
executor=self._exe)
post_quant_hpo.quant_post_hpo( post_quant_hpo.quant_post_hpo(
self._exe, self._exe,
self._places, self._places,
model_dir=self.model_dir, model_dir=self.updated_model_dir,
quantize_model_path=os.path.join( quantize_model_path=os.path.join(
self.tmp_dir, 'strategy_{}'.format(str(strategy_idx + 1))), self.tmp_dir, 'strategy_{}'.format(str(strategy_idx + 1))),
train_dataloader=self.train_dataloader, train_dataloader=self.train_dataloader,
......
...@@ -3,6 +3,7 @@ import time ...@@ -3,6 +3,7 @@ import time
import numpy as np import numpy as np
import paddle import paddle
from collections.abc import Iterable from collections.abc import Iterable
from .load_model import load_inference_model
__all__ = ["wrap_dataloader", "get_feed_vars"] __all__ = ["wrap_dataloader", "get_feed_vars"]
...@@ -13,7 +14,7 @@ def get_feed_vars(model_dir, model_filename, params_filename): ...@@ -13,7 +14,7 @@ def get_feed_vars(model_dir, model_filename, params_filename):
paddle.enable_static() paddle.enable_static()
exe = paddle.static.Executor(paddle.CPUPlace()) exe = paddle.static.Executor(paddle.CPUPlace())
[inference_program, feed_target_names, fetch_targets] = ( [inference_program, feed_target_names, fetch_targets] = (
paddle.static.load_inference_model( load_inference_model(
model_dir, model_dir,
exe, exe,
model_filename=model_filename, model_filename=model_filename,
......
...@@ -14,15 +14,37 @@ ...@@ -14,15 +14,37 @@
import os import os
import paddle import paddle
from ...common import load_onnx_model
__all__ = ['load_inference_model'] __all__ = ['load_inference_model', 'get_model_dir']
def load_inference_model(path_prefix, def load_inference_model(path_prefix,
executor, executor,
model_filename=None, model_filename=None,
params_filename=None): params_filename=None):
if model_filename is not None and params_filename is not None: # Load onnx model to Inference model.
if path_prefix.endswith('.onnx'):
inference_program, feed_target_names, fetch_targets = load_onnx_model(
path_prefix)
return [inference_program, feed_target_names, fetch_targets]
# Load Inference model.
# TODO: clean code
if model_filename is not None and model_filename.endswith('.pdmodel'):
model_name = '.'.join(model_filename.split('.')[:-1])
assert os.path.exists(
os.path.join(path_prefix, model_name + '.pdmodel')
), 'Please check {}, or fix model_filename parameter.'.format(
os.path.join(path_prefix, model_name + '.pdmodel'))
assert os.path.exists(
os.path.join(path_prefix, model_name + '.pdiparams')
), 'Please check {}, or fix params_filename parameter.'.format(
os.path.join(path_prefix, model_name + '.pdiparams'))
model_path_prefix = os.path.join(path_prefix, model_name)
[inference_program, feed_target_names, fetch_targets] = (
paddle.static.load_inference_model(
path_prefix=model_path_prefix, executor=executor))
elif model_filename is not None and params_filename is not None:
[inference_program, feed_target_names, fetch_targets] = ( [inference_program, feed_target_names, fetch_targets] = (
paddle.static.load_inference_model( paddle.static.load_inference_model(
path_prefix=path_prefix, path_prefix=path_prefix,
...@@ -43,3 +65,26 @@ def load_inference_model(path_prefix, ...@@ -43,3 +65,26 @@ def load_inference_model(path_prefix,
path_prefix=path_prefix, executor=executor)) path_prefix=path_prefix, executor=executor))
return [inference_program, feed_target_names, fetch_targets] return [inference_program, feed_target_names, fetch_targets]
def get_model_dir(model_dir, model_filename, params_filename):
if model_dir.endswith('.onnx'):
updated_model_dir = model_dir.rstrip().rstrip('.onnx') + '_infer'
else:
updated_model_dir = model_dir.rstrip('/')
if model_filename == None:
updated_model_filename = 'model.pdmodel'
else:
updated_model_filename = model_filename
if params_filename == None:
updated_params_filename = 'model.pdiparams'
else:
updated_params_filename = params_filename
if params_filename is None and model_filename is not None:
raise NotImplementedError(
"NOT SUPPORT parameters saved in separate files. Please convert it to single binary file first."
)
return updated_model_dir, updated_model_filename, updated_params_filename
...@@ -25,11 +25,12 @@ from .analyze_helper import VarCollector ...@@ -25,11 +25,12 @@ from .analyze_helper import VarCollector
from . import wrapper_function from . import wrapper_function
from . import recover_program from . import recover_program
from . import patterns from . import patterns
from .convert_model import load_onnx_model
__all__ = [ __all__ = [
'EvolutionaryController', 'SAController', 'get_logger', 'ControllerServer', 'EvolutionaryController', 'SAController', 'get_logger', 'ControllerServer',
'ControllerClient', 'lock', 'unlock', 'cached_reader', 'AvgrageMeter', 'ControllerClient', 'lock', 'unlock', 'cached_reader', 'AvgrageMeter',
'Server', 'Client', 'RLBaseController', 'VarCollector' 'Server', 'Client', 'RLBaseController', 'VarCollector', 'load_onnx_model'
] ]
__all__ += wrapper_function.__all__ __all__ += wrapper_function.__all__
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
import logging
import os
import shutil
import sys
import paddle
from x2paddle.decoder.onnx_decoder import ONNXDecoder
from x2paddle.op_mapper.onnx2paddle.onnx_op_mapper import ONNXOpMapper
from x2paddle.optimizer.optimizer import GraphOptimizer
from x2paddle.utils import ConverterCheck
from . import get_logger
_logger = get_logger(__name__, level=logging.INFO)
__all__ = ['load_onnx_model']
def load_onnx_model(model_path, disable_feedback=False):
assert model_path.endswith(
'.onnx'
), '{} does not end with .onnx suffix and cannot be loaded.'.format(
model_path)
inference_model_path = model_path.rstrip().rstrip('.onnx') + '_infer'
exe = paddle.static.Executor(paddle.CPUPlace())
if os.path.exists(os.path.join(
inference_model_path, 'model.pdmodel')) and os.path.exists(
os.path.join(inference_model_path, 'model.pdiparams')):
val_program, feed_target_names, fetch_targets = paddle.static.load_inference_model(
os.path.join(inference_model_path, 'model'), exe)
_logger.info('Loaded model from: {}'.format(inference_model_path))
return val_program, feed_target_names, fetch_targets
else:
# onnx to paddle inference model.
time_info = int(time.time())
if not disable_feedback:
ConverterCheck(
task="ONNX", time_info=time_info, convert_state="Start").start()
# check onnx installation and version
try:
import onnx
version = onnx.version.version
v0, v1, v2 = version.split('.')
version_sum = int(v0) * 100 + int(v1) * 10 + int(v2)
if version_sum < 160:
_logger.info("[ERROR] onnx>=1.6.0 is required")
sys.exit(1)
except:
_logger.info(
"[ERROR] onnx is not installed, use \"pip install onnx==1.6.0\"."
)
sys.exit(1)
try:
_logger.info("Now translating model from onnx to paddle.")
model = ONNXDecoder(model_path)
mapper = ONNXOpMapper(model)
mapper.paddle_graph.build()
graph_opt = GraphOptimizer(source_frame="onnx")
graph_opt.optimize(mapper.paddle_graph)
_logger.info("Model optimized.")
onnx2paddle_out_dir = os.path.join(inference_model_path,
'onnx2paddle')
mapper.paddle_graph.gen_model(onnx2paddle_out_dir)
_logger.info("Successfully exported Paddle static graph model!")
if not disable_feedback:
ConverterCheck(
task="ONNX", time_info=time_info,
convert_state="Success").start()
shutil.move(
os.path.join(onnx2paddle_out_dir, 'inference_model',
'model.pdmodel'),
os.path.join(inference_model_path, 'model.pdmodel'))
shutil.move(
os.path.join(onnx2paddle_out_dir, 'inference_model',
'model.pdiparams'),
os.path.join(inference_model_path, 'model.pdiparams'))
except:
_logger.info(
"[ERROR] x2paddle threw an exception, you can ask for help at: https://github.com/PaddlePaddle/X2Paddle/issues"
)
sys.exit(1)
paddle.enable_static()
val_program, feed_target_names, fetch_targets = paddle.static.load_inference_model(
os.path.join(inference_model_path, 'model'), exe)
_logger.info('Loaded model from: {}'.format(inference_model_path))
return val_program, feed_target_names, fetch_targets
#paddlepaddle == 1.6.0rc0
tqdm tqdm
pyzmq pyzmq
matplotlib matplotlib
...@@ -6,3 +5,5 @@ pillow ...@@ -6,3 +5,5 @@ pillow
pyyaml pyyaml
scikit-learn scikit-learn
smac smac
onnx
x2paddle==1.3.8
...@@ -9,6 +9,7 @@ import numpy as np ...@@ -9,6 +9,7 @@ import numpy as np
from paddle.io import Dataset from paddle.io import Dataset
from paddleslim.auto_compression import AutoCompression from paddleslim.auto_compression import AutoCompression
from paddleslim.auto_compression.config_helpers import load_config from paddleslim.auto_compression.config_helpers import load_config
from paddleslim.auto_compression.utils.load_model import load_inference_model
class RandomEvalDataset(Dataset): class RandomEvalDataset(Dataset):
...@@ -120,5 +121,28 @@ class TestDictQATDist(ACTBase): ...@@ -120,5 +121,28 @@ class TestDictQATDist(ACTBase):
ac.compress() ac.compress()
class TestLoadONNXModel(ACTBase):
def __init__(self, *args, **kwargs):
super(TestLoadONNXModel, self).__init__(*args, **kwargs)
os.system(
'wget https://paddle-slim-models.bj.bcebos.com/act/yolov5s.onnx')
self.model_dir = 'yolov5s.onnx'
def test_compress(self):
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
_, _, _ = load_inference_model(
self.model_dir,
executor=exe,
model_filename='model.pdmodel',
params_filename='model.paiparams')
# reload model
_, _, _ = load_inference_model(
self.model_dir,
executor=exe,
model_filename='model.pdmodel',
params_filename='model.paiparams')
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册