Commit 030daa23 authored by: W wuzewu

Add an interface for exporting onnx format model.

Parent 247f2e92
@@ -18,6 +18,7 @@ import os
from typing import Tuple, List
import paddle
import paddle2onnx
from easydict import EasyDict
from paddlehub.compat import paddle_utils
@@ -280,3 +281,30 @@ class ModuleV1(object):
            target_vars=list(fetch_dict.values()),
            model_filename=model_filename,
            params_filename=params_filename)

    def export_onnx_model(self, dirname: str, **kwargs):
        '''
        Export the model to ONNX format.

        Args:
            dirname(str): The directory to save the onnx model.
            **kwargs(dict|optional): Other export configuration options kept for compatibility; some may be
                removed in the future. Don't use them unless necessary. Refer to
                https://github.com/PaddlePaddle/paddle2onnx for more information.
        '''
        feed_dict, fetch_dict, program = self.context(for_test=True, trainable=False)
        inputs = set([var.name for var in feed_dict.values()])

        # Classification modules expose several fetch targets; only the class
        # probabilities are needed at inference time, so export that output alone.
        if self.type == 'CV/classification':
            outputs = [fetch_dict['class_probs'].name]
        else:
            outputs = set([var.name for var in fetch_dict.values()])
        outputs = [program.global_block().vars[key] for key in outputs]

        save_file = os.path.join(dirname, '{}.onnx'.format(self.name))
        paddle2onnx.program2onnx(
            program=program,
            scope=paddle.static.global_scope(),
            feed_var_names=inputs,
            target_vars=outputs,
            save_file=save_file,
            **kwargs)
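
A minimal usage sketch of the new ModuleV1 interface (not part of this commit). It assumes a PaddleHub installation with paddle2onnx available; the module name 'mobilenet_v2_imagenet' and the export directory are only illustrative.

# Hypothetical usage of ModuleV1.export_onnx_model; names are illustrative.
import os
import paddlehub as hub

module = hub.Module(name='mobilenet_v2_imagenet')  # any static-graph (v1) module
export_dir = './onnx_export'
os.makedirs(export_dir, exist_ok=True)
module.export_onnx_model(dirname=export_dir)
# The converted model is written to ./onnx_export/<module.name>.onnx
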
@@ -22,9 +22,10 @@ import re
import sys
from typing import Callable, Generic, List, Optional, Union
import paddle
import paddle2onnx
from easydict import EasyDict
import paddle
from paddlehub.utils import parser, log, utils
from paddlehub.compat import paddle_utils
from paddlehub.compat.module.module_v1 import ModuleV1
@@ -131,6 +132,77 @@ class RunModule(object):
    def serving_func_name(self):
        return self._get_func_name(self.__class__, _module_serving_func)

    @property
    def _pretrained_model_path(self):
        _pretrained_model_attrs = [
            'pretrained_model_path', 'rec_pretrained_model_path', 'default_pretrained_model_path'
        ]

        for _attr in _pretrained_model_attrs:
            if hasattr(self, _attr):
                path = getattr(self, _attr)
                # Some modules store the path of a model file rather than its
                # directory; normalize to the containing directory.
                if os.path.exists(path) and os.path.isfile(path):
                    path = os.path.dirname(path)
                return path

        return None

    def export_onnx_model(self, dirname: str, **kwargs):
        '''
        Export the model to ONNX format.

        Args:
            dirname(str): The directory to save the onnx model.
            **kwargs(dict|optional): Other export configuration options kept for compatibility; some may be
                removed in the future. Don't use them unless necessary. Refer to
                https://github.com/PaddlePaddle/paddle2onnx for more information.
        '''
        if not self._pretrained_model_path:
            # Dynamic-graph module: trace the Layer directly with paddle.onnx.export.
            if isinstance(self, paddle.nn.Layer):
                save_file = os.path.join(dirname, '{}'.format(self.name))
                if hasattr(self, 'input_spec'):
                    input_spec = self.input_spec
                else:
                    _type = self.type.lower()
                    if _type.startswith('cv/image'):
                        input_spec = paddle.static.InputSpec(shape=[None, 3, None, None], dtype='float32')
                    else:
                        raise NotImplementedError
                paddle.onnx.export(self, save_file, input_spec=[input_spec])
                return

            raise NotImplementedError

        # Static-graph module: load the saved inference model and convert it with paddle2onnx.
        place = paddle.CPUPlace()
        exe = paddle.static.Executor(place)

        model_filename = None
        params_filename = None
        if os.path.exists(os.path.join(self._pretrained_model_path, 'model')):
            model_filename = 'model'
        if os.path.exists(os.path.join(self._pretrained_model_path, 'params')):
            params_filename = 'params'
        if os.path.exists(os.path.join(self._pretrained_model_path, '__params__')):
            params_filename = '__params__'

        save_file = os.path.join(dirname, '{}.onnx'.format(self.name))

        program, inputs, outputs = paddle.fluid.io.load_inference_model(
            dirname=self._pretrained_model_path,
            model_filename=model_filename,
            params_filename=params_filename,
            executor=exe)

        paddle2onnx.program2onnx(
            program=program,
            scope=paddle.static.global_scope(),
            feed_var_names=inputs,
            target_vars=outputs,
            save_file=save_file,
            **kwargs)

class Module(object):
    '''
......
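
For the dynamic-graph branch of RunModule.export_onnx_model above, paddle.onnx.export traces a paddle.nn.Layer against an InputSpec. A self-contained sketch of that mechanism follows; ToyModel and the output path are illustrative and not part of PaddleHub or this commit.

# Standalone sketch of the paddle.onnx.export path used for Layer-based modules.
import paddle

class ToyModel(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        self.conv = paddle.nn.Conv2D(3, 8, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv(x)

model = ToyModel()
# The same fallback spec shape the method uses for 'cv/image*' modules:
# NCHW with dynamic batch size and spatial dimensions.
spec = paddle.static.InputSpec(shape=[None, 3, None, None], dtype='float32')
paddle.onnx.export(model, 'toy_model', input_spec=[spec])  # writes toy_model.onnx
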
@@ -8,6 +8,8 @@ numpy
matplotlib
opencv-python
packaging
paddle2onnx >= 0.5
paddlenlp >= 2.0.0rc5
Pillow
pyyaml
pyzmq
@@ -16,4 +18,3 @@ tqdm
visualdl >= 2.0.0
# gunicorn not support windows
gunicorn >= 19.10.0; sys_platform != "win32"
paddlenlp >= 2.0.0rc5