Unverified commit 0dd15555, authored by Guanghua Yu, committed by GitHub

support paddle2onnx in act (#1349)

Parent 3e972ce0
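This commit wires Paddle2ONNX into the Auto Compression Toolkit (ACT): AutoCompression gains an export_onnx() method, a standalone export_onnx helper is added to paddleslim.common.load_model (and re-exported from paddleslim.common), and the dependency checks in load_onnx_model now auto-install missing packages. A minimal usage sketch of the new flow; names marked as placeholders are assumptions, not part of this commit:

    from paddleslim.auto_compression import AutoCompression

    # Sketch only: all_config, train_loader and eval_func are placeholders for
    # a real ACT config, dataloader and evaluation callback; paths are
    # hypothetical.
    ac = AutoCompression(
        model_dir='./inference_model',
        model_filename='model.pdmodel',
        params_filename='model.pdiparams',
        save_dir='./act_output',
        config=all_config,
        train_dataloader=train_loader,
        eval_callback=eval_func)
    ac.compress()
    # New in this commit: convert the compressed model to ONNX.
    ac.export_onnx(model_name='quant_model.onnx', deploy_backend='tensorrt')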
@@ -113,6 +113,7 @@ def main():
         config=all_config,
         eval_callback=eval_func)
     ac.compress()
+    ac.export_onnx()
 
 
 if __name__ == '__main__':
@@ -29,7 +29,7 @@ from ..quant.quanter import convert, quant_post
 from ..common.recover_program import recover_inference_program
 from ..common import get_logger
 from ..common.patterns import get_patterns
-from ..common.load_model import load_inference_model, get_model_dir
+from ..common.load_model import load_inference_model, get_model_dir, export_onnx
 from ..common.dataloader import wrap_dataloader, get_feed_vars
 from ..common.config_helper import load_config
 from ..analysis import TableLatencyPredictor
@@ -826,3 +826,17 @@ class AutoCompression:
             fetch_vars=test_program_info.fetch_targets,
             executor=self._exe,
             program=test_program)
+
+    def export_onnx(self,
+                    model_name='quant_model.onnx',
+                    deploy_backend='tensorrt'):
+        infer_model_path = os.path.join(self.final_dir, self.model_filename)
+        assert os.path.exists(
+            infer_model_path), 'Not found {}, please check it.'.format(
+                infer_model_path)
+        export_onnx(
+            self.final_dir,
+            model_filename=self.model_filename,
+            params_filename=self.params_filename,
+            save_file_path=os.path.join(self.final_dir, model_name),
+            deploy_backend=deploy_backend)
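Note that AutoCompression.export_onnx does not re-run compression: it asserts that compress() has already written self.model_filename into self.final_dir, then delegates to the common export_onnx helper, saving the ONNX file next to the compressed Paddle model.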
@@ -25,7 +25,7 @@ from .analyze_helper import VarCollector
 from . import wrapper_function
 from . import recover_program
 from . import patterns
-from .load_model import load_inference_model, get_model_dir, load_onnx_model
+from .load_model import load_inference_model, get_model_dir, load_onnx_model, export_onnx
 from .dataloader import wrap_dataloader, get_feed_vars
 from .config_helper import load_config, save_config
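With this re-export in paddleslim/common/__init__.py, downstream code can import the helper directly from paddleslim.common, as the updated tests below do, instead of reaching into paddleslim.common.load_model.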
@@ -17,16 +17,15 @@ import logging
 import os
 import shutil
 import sys
+import pkg_resources as pkg
 import paddle
-from x2paddle.decoder.onnx_decoder import ONNXDecoder
-from x2paddle.op_mapper.onnx2paddle.onnx_op_mapper import ONNXOpMapper
-from x2paddle.optimizer.optimizer import GraphOptimizer
-from x2paddle.utils import ConverterCheck
 from . import get_logger
 
 _logger = get_logger(__name__, level=logging.INFO)
 
-__all__ = ['load_inference_model', 'get_model_dir', 'load_onnx_model']
+__all__ = [
+    'load_inference_model', 'get_model_dir', 'load_onnx_model', 'export_onnx'
+]
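A side effect of this hunk is that x2paddle is no longer imported at module load time; the imports move inside load_onnx_model (below), so the dependency is only required, and auto-installed if missing, when an ONNX model is actually loaded.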
@@ -116,24 +115,37 @@ def load_onnx_model(model_path, disable_feedback=False):
         return val_program, feed_target_names, fetch_targets
     else:
         # onnx to paddle inference model.
+        try:
+            pkg.require('x2paddle')
+        except:
+            from pip._internal import main
+            main(['install', 'x2paddle'])
+        try:
+            from x2paddle.decoder.onnx_decoder import ONNXDecoder
+            from x2paddle.op_mapper.onnx2paddle.onnx_op_mapper import ONNXOpMapper
+            from x2paddle.optimizer.optimizer import GraphOptimizer
+            from x2paddle.utils import ConverterCheck
+        except:
+            _logger.error(
+                "x2paddle is not installed, please use \"pip install x2paddle\"."
+            )
         time_info = int(time.time())
         if not disable_feedback:
             ConverterCheck(
                 task="ONNX", time_info=time_info, convert_state="Start").start()
         # check onnx installation and version
         try:
+            pkg.require('onnx')
             import onnx
             version = onnx.version.version
             v0, v1, v2 = version.split('.')
             version_sum = int(v0) * 100 + int(v1) * 10 + int(v2)
             if version_sum < 160:
-                _logger.info("[ERROR] onnx>=1.6.0 is required")
+                _logger.error(
+                    "onnx>=1.6.0 is required, please use \"pip install onnx\".")
                 sys.exit(1)
         except:
-            _logger.info(
-                "[ERROR] onnx is not installed, use \"pip install onnx==1.6.0\"."
-            )
-            sys.exit(1)
+            from pip._internal import main
+            main(['install', 'onnx==1.12.0'])
         # support distributed convert model
         model_idx = paddle.distributed.get_rank(
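The require-or-install pattern used twice above (for x2paddle and onnx) can be summarized in isolation. A minimal sketch; require_or_install is a hypothetical name, and pip._internal is pip's private API, which the diff relies on and which has moved between pip versions:

    import pkg_resources as pkg

    def require_or_install(requirement, pip_spec=None):
        # pkg.require raises DistributionNotFound / VersionConflict when the
        # requirement is missing or installed at the wrong version.
        try:
            pkg.require(requirement)
        except Exception:
            from pip._internal import main  # private pip API, as in the diff
            main(['install', pip_spec or requirement])

    require_or_install('onnx', 'onnx==1.12.0')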
@@ -182,4 +194,30 @@
         shutil.rmtree(
             os.path.join(inference_model_path, 'onnx2paddle_{}'.format(
                 model_idx)))
-        return val_program, feed_target_names, fetch_targets
\ No newline at end of file
+        return val_program, feed_target_names, fetch_targets
+
+
+def export_onnx(model_dir,
+                model_filename=None,
+                params_filename=None,
+                save_file_path='output.onnx',
+                opset_version=13,
+                deploy_backend='tensorrt'):
+    if not model_filename:
+        model_filename = 'model.pdmodel'
+    if not params_filename:
+        params_filename = 'model.pdiparams'
+    try:
+        pkg.require('paddle2onnx')
+    except:
+        from pip._internal import main
+        main(['install', 'paddle2onnx==1.0.0rc3'])
+    import paddle2onnx
+    paddle2onnx.command.c_paddle_to_onnx(
+        model_file=os.path.join(model_dir, model_filename),
+        params_file=os.path.join(model_dir, params_filename),
+        save_file=save_file_path,
+        opset_version=opset_version,
+        enable_onnx_checker=True,
+        deploy_backend=deploy_backend)
+    _logger.info('Convert model to ONNX: {}'.format(save_file_path))
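The helper delegates to paddle2onnx.command.c_paddle_to_onnx, so calling it is roughly equivalent to running the paddle2onnx CLI on a saved inference model. A hedged usage sketch (paths are hypothetical; valid deploy_backend values depend on the installed paddle2onnx version):

    from paddleslim.common import export_onnx

    export_onnx(
        './act_output',  # directory holding model.pdmodel / model.pdiparams
        save_file_path='./act_output/quant_model.onnx',
        opset_version=13,
        deploy_backend='tensorrt')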
@@ -8,8 +8,8 @@ import unittest
 import numpy as np
 from paddle.io import Dataset
 from paddleslim.auto_compression import AutoCompression
-from paddleslim.auto_compression.config_helpers import load_config
-from paddleslim.auto_compression.utils.load_model import load_inference_model
+from paddleslim.common import load_config
+from paddleslim.common import load_inference_model, export_onnx
 
 
 class RandomEvalDataset(Dataset):
@@ -142,6 +142,14 @@ class TestLoadONNXModel(ACTBase):
             executor=exe,
             model_filename='model.pdmodel',
             params_filename='model.pdiparams')
+        # convert onnx
+        export_onnx(
+            self.model_dir,
+            model_filename='model.pdmodel',
+            params_filename='model.pdiparams',
+            save_file_path='output.onnx',
+            opset_version=13,
+            deploy_backend='tensorrt')
 
 
 if __name__ == '__main__':
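After export, the resulting file can be sanity-checked outside Paddle. A minimal smoke test, assuming onnxruntime is installed and the model has a single float32 input (dynamic dimensions are pinned to 1 here):

    import numpy as np
    import onnxruntime as ort

    sess = ort.InferenceSession(
        'output.onnx', providers=['CPUExecutionProvider'])
    inp = sess.get_inputs()[0]
    # Replace dynamic/None dims with 1 to build a dummy batch.
    shape = [d if isinstance(d, int) else 1 for d in inp.shape]
    outs = sess.run(None, {inp.name: np.random.rand(*shape).astype('float32')})
    print([o.shape for o in outs])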