提交 030daa23 编写于 作者: W wuzewu

Add an interface for exporting onnx format model.

上级 247f2e92
......@@ -18,6 +18,7 @@ import os
from typing import Tuple, List
import paddle
import paddle2onnx
from easydict import EasyDict
from paddlehub.compat import paddle_utils
......@@ -280,3 +281,30 @@ class ModuleV1(object):
target_vars=list(fetch_dict.values()),
model_filename=model_filename,
params_filename=params_filename)
def export_onnx_model(self, dirname: str, **kwargs):
'''
Export the model to ONNX format.
Args:
dirname(str): The directory to save the onnx model.
**kwargs(dict|optional): Other export configuration options for compatibility, some may be removed
in the future. Don't use them If not necessary. Refer to https://github.com/PaddlePaddle/paddle2onnx
for more information.
'''
feed_dict, fetch_dict, program = self.context(for_test=True, trainable=False)
inputs = set([var.name for var in feed_dict.values()])
if self.type == 'CV/classification':
outputs = [fetch_dict['class_probs']]
else:
outputs = set([var.name for var in fetch_dict.values()])
outputs = [program.global_block().vars[key] for key in outputs]
save_file = os.path.join(dirname, '{}.onnx'.format(self.name))
paddle2onnx.program2onnx(
program=program,
scope=paddle.static.global_scope(),
feed_var_names=inputs,
target_vars=outputs,
save_file=save_file,
**kwargs)
......@@ -22,9 +22,10 @@ import re
import sys
from typing import Callable, Generic, List, Optional, Union
import paddle
import paddle2onnx
from easydict import EasyDict
import paddle
from paddlehub.utils import parser, log, utils
from paddlehub.compat import paddle_utils
from paddlehub.compat.module.module_v1 import ModuleV1
......@@ -131,6 +132,77 @@ class RunModule(object):
def serving_func_name(self):
return self._get_func_name(self.__class__, _module_serving_func)
@property
def _pretrained_model_path(self):
_pretrained_model_attrs = [
'pretrained_model_path', 'rec_pretrained_model_path', 'default_pretrained_model_path'
]
for _attr in _pretrained_model_attrs:
if hasattr(self, _attr):
path = getattr(self, _attr)
if os.path.exists(path) and os.path.isfile(path):
path = os.path.dirname(path)
return path
return None
def export_onnx_model(self, dirname: str, **kwargs):
'''
Export the model to ONNX format.
Args:
dirname(str): The directory to save the onnx model.
**kwargs(dict|optional): Other export configuration options for compatibility, some may be removed
in the future. Don't use them If not necessary. Refer to https://github.com/PaddlePaddle/paddle2onnx
for more information.
'''
if not self._pretrained_model_path:
if isinstance(self, paddle.nn.Layer):
save_file = os.path.join(dirname, '{}'.format(self.name))
if hasattr(self, 'input_spec'):
input_spec = self.input_sepc
else:
_type = self.type.lower()
if _type.startswith('cv/image'):
input_spec = paddle.static.InputSpec(shape=[None, 3, None, None], dtype='float32')
else:
raise NotImplementedError
paddle.onnx.export(self, save_file, input_spec=[input_spec])
return
raise NotImplementedError
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
model_filename = None
params_filename = None
if os.path.exists(os.path.join(self._pretrained_model_path, 'model')):
model_filename = 'model'
if os.path.exists(os.path.join(self._pretrained_model_path, 'params')):
params_filename = 'params'
if os.path.exists(os.path.join(self._pretrained_model_path, '__params__')):
params_filename = '__params__'
save_file = os.path.join(dirname, '{}.onnx'.format(self.name))
program, inputs, outputs = paddle.fluid.io.load_inference_model(
dirname=self._pretrained_model_path,
model_filename=model_filename,
params_filename=params_filename,
executor=exe)
paddle2onnx.program2onnx(
program=program,
scope=paddle.static.global_scope(),
feed_var_names=inputs,
target_vars=outputs,
save_file=save_file,
**kwargs)
class Module(object):
'''
......
......@@ -8,6 +8,8 @@ numpy
matplotlib
opencv-python
packaging
paddle2onnx >= 0.5
paddlenlp >= 2.0.0rc5
Pillow
pyyaml
pyzmq
......@@ -16,4 +18,3 @@ tqdm
visualdl >= 2.0.0
# gunicorn not support windows
gunicorn >= 19.10.0; sys_platform != "win32"
paddlenlp >= 2.0.0rc5
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册