未验证 提交 47d7cac1 编写于 作者: C chenjian 提交者: GitHub

Fix save_inference_model bug in paddlehub (#2143)

上级 52766374
...@@ -37,7 +37,6 @@ from paddlehub.utils import utils ...@@ -37,7 +37,6 @@ from paddlehub.utils import utils
class InvalidHubModule(Exception): class InvalidHubModule(Exception):
def __init__(self, directory: str): def __init__(self, directory: str):
self.directory = directory self.directory = directory
...@@ -200,7 +199,8 @@ class RunModule(object): ...@@ -200,7 +199,8 @@ class RunModule(object):
for key, _sub_module in self.sub_modules().items(): for key, _sub_module in self.sub_modules().items():
try: try:
sub_dirname = os.path.normpath(os.path.join(dirname, key)) sub_dirname = os.path.normpath(os.path.join(dirname, key))
_sub_module.save_inference_model(sub_dirname, _sub_module.save_inference_model(
sub_dirname,
include_sub_modules=include_sub_modules, include_sub_modules=include_sub_modules,
model_filename=model_filename, model_filename=model_filename,
params_filename=params_filename, params_filename=params_filename,
...@@ -231,14 +231,11 @@ class RunModule(object): ...@@ -231,14 +231,11 @@ class RunModule(object):
if not self._pretrained_model_path: if not self._pretrained_model_path:
raise RuntimeError('Module {} does not support exporting models in Paddle Inference format.'.format( raise RuntimeError('Module {} does not support exporting models in Paddle Inference format.'.format(
self.name)) self.name))
elif not os.path.exists(self._pretrained_model_path): elif not os.path.exists(
self._pretrained_model_path) and not os.path.exists(self._pretrained_model_path + '.pdmodel'):
log.logger.warning('The model path of Module {} does not exist.'.format(self.name)) log.logger.warning('The model path of Module {} does not exist.'.format(self.name))
return return
model_filename = '__model__' if not model_filename else model_filename
if combined:
params_filename = '__params__' if not params_filename else params_filename
place = paddle.CPUPlace() place = paddle.CPUPlace()
exe = paddle.static.Executor(place) exe = paddle.static.Executor(place)
...@@ -253,21 +250,25 @@ class RunModule(object): ...@@ -253,21 +250,25 @@ class RunModule(object):
if os.path.exists(os.path.join(self._pretrained_model_path, '__params__')): if os.path.exists(os.path.join(self._pretrained_model_path, '__params__')):
_params_filename = '__params__' _params_filename = '__params__'
if _model_filename is not None and _params_filename is not None:
program, feeded_var_names, target_vars = paddle.static.load_inference_model( program, feeded_var_names, target_vars = paddle.static.load_inference_model(
dirname=self._pretrained_model_path, self._pretrained_model_path,
executor=exe, executor=exe,
model_filename=_model_filename, model_filename=_model_filename,
params_filename=_params_filename, params_filename=_params_filename,
) )
else:
program, feeded_var_names, target_vars = paddle.static.load_inference_model(
self._pretrained_model_path, executor=exe)
paddle.static.save_inference_model(dirname=dirname, global_block = program.global_block()
main_program=program, feed_vars = [global_block.var(item) for item in feeded_var_names]
executor=exe,
feeded_var_names=feeded_var_names, path_prefix = dirname
target_vars=target_vars, if os.path.isdir(dirname):
model_filename=model_filename, path_prefix = os.path.join(dirname, 'model')
params_filename=params_filename) paddle.static.save_inference_model(
path_prefix, feed_vars=feed_vars, fetch_vars=target_vars, executor=exe, program=program)
log.logger.info('Paddle Inference model saved in {}.'.format(dirname)) log.logger.info('Paddle Inference model saved in {}.'.format(dirname))
...@@ -337,12 +338,14 @@ class RunModule(object): ...@@ -337,12 +338,14 @@ class RunModule(object):
save_file = os.path.join(dirname, '{}.onnx'.format(self.name)) save_file = os.path.join(dirname, '{}.onnx'.format(self.name))
program, inputs, outputs = paddle.static.load_inference_model(dirname=self._pretrained_model_path, program, inputs, outputs = paddle.static.load_inference_model(
dirname=self._pretrained_model_path,
model_filename=model_filename, model_filename=model_filename,
params_filename=params_filename, params_filename=params_filename,
executor=exe) executor=exe)
paddle2onnx.program2onnx(program=program, paddle2onnx.program2onnx(
program=program,
scope=paddle.static.global_scope(), scope=paddle.static.global_scope(),
feed_var_names=inputs, feed_var_names=inputs,
target_vars=outputs, target_vars=outputs,
...@@ -387,7 +390,8 @@ class Module(object): ...@@ -387,7 +390,8 @@ class Module(object):
from paddlehub.server.server import CacheUpdater from paddlehub.server.server import CacheUpdater
# This branch come from hub.Module(name='xxx') or hub.Module(directory='xxx') # This branch come from hub.Module(name='xxx') or hub.Module(directory='xxx')
if name: if name:
module = cls.init_with_name(name=name, module = cls.init_with_name(
name=name,
version=version, version=version,
source=source, source=source,
update=update, update=update,
...@@ -485,7 +489,8 @@ class Module(object): ...@@ -485,7 +489,8 @@ class Module(object):
manager = LocalModuleManager() manager = LocalModuleManager()
user_module_cls = manager.search(name, source=source, branch=branch) user_module_cls = manager.search(name, source=source, branch=branch)
if not user_module_cls or not user_module_cls.version.match(version): if not user_module_cls or not user_module_cls.version.match(version):
user_module_cls = manager.install(name=name, user_module_cls = manager.install(
name=name,
version=version, version=version,
source=source, source=source,
update=update, update=update,
...@@ -555,7 +560,9 @@ def moduleinfo(name: str, ...@@ -555,7 +560,9 @@ def moduleinfo(name: str,
_bases.append(_b) _bases.append(_b)
_bases.append(_meta) _bases.append(_meta)
_bases = tuple(_bases) _bases = tuple(_bases)
wrap_cls = builtins.type(cls.__name__, _bases, dict(cls.__dict__)) attr_dict = dict(cls.__dict__)
attr_dict.pop('__dict__', None)
wrap_cls = builtins.type(cls.__name__, _bases, attr_dict)
wrap_cls.name = name wrap_cls.name = name
wrap_cls.version = utils.Version(version) wrap_cls.version = utils.Version(version)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册