未验证 提交 47d7cac1 编写于 作者: C chenjian 提交者: GitHub

Fix save_inference_model bug in paddlehub (#2143)

上级 52766374
...@@ -37,7 +37,6 @@ from paddlehub.utils import utils ...@@ -37,7 +37,6 @@ from paddlehub.utils import utils
class InvalidHubModule(Exception): class InvalidHubModule(Exception):
def __init__(self, directory: str): def __init__(self, directory: str):
self.directory = directory self.directory = directory
...@@ -200,11 +199,12 @@ class RunModule(object): ...@@ -200,11 +199,12 @@ class RunModule(object):
for key, _sub_module in self.sub_modules().items(): for key, _sub_module in self.sub_modules().items():
try: try:
sub_dirname = os.path.normpath(os.path.join(dirname, key)) sub_dirname = os.path.normpath(os.path.join(dirname, key))
_sub_module.save_inference_model(sub_dirname, _sub_module.save_inference_model(
include_sub_modules=include_sub_modules, sub_dirname,
model_filename=model_filename, include_sub_modules=include_sub_modules,
params_filename=params_filename, model_filename=model_filename,
combined=combined) params_filename=params_filename,
combined=combined)
except: except:
utils.record_exception('Failed to save sub module {}'.format(_sub_module.name)) utils.record_exception('Failed to save sub module {}'.format(_sub_module.name))
...@@ -231,14 +231,11 @@ class RunModule(object): ...@@ -231,14 +231,11 @@ class RunModule(object):
if not self._pretrained_model_path: if not self._pretrained_model_path:
raise RuntimeError('Module {} does not support exporting models in Paddle Inference format.'.format( raise RuntimeError('Module {} does not support exporting models in Paddle Inference format.'.format(
self.name)) self.name))
elif not os.path.exists(self._pretrained_model_path): elif not os.path.exists(
self._pretrained_model_path) and not os.path.exists(self._pretrained_model_path + '.pdmodel'):
log.logger.warning('The model path of Module {} does not exist.'.format(self.name)) log.logger.warning('The model path of Module {} does not exist.'.format(self.name))
return return
model_filename = '__model__' if not model_filename else model_filename
if combined:
params_filename = '__params__' if not params_filename else params_filename
place = paddle.CPUPlace() place = paddle.CPUPlace()
exe = paddle.static.Executor(place) exe = paddle.static.Executor(place)
...@@ -253,21 +250,25 @@ class RunModule(object): ...@@ -253,21 +250,25 @@ class RunModule(object):
if os.path.exists(os.path.join(self._pretrained_model_path, '__params__')): if os.path.exists(os.path.join(self._pretrained_model_path, '__params__')):
_params_filename = '__params__' _params_filename = '__params__'
if _model_filename is not None and _params_filename is not None:
program, feeded_var_names, target_vars = paddle.static.load_inference_model(
self._pretrained_model_path,
executor=exe,
model_filename=_model_filename,
params_filename=_params_filename,
)
else:
program, feeded_var_names, target_vars = paddle.static.load_inference_model(
self._pretrained_model_path, executor=exe)
program, feeded_var_names, target_vars = paddle.static.load_inference_model( global_block = program.global_block()
dirname=self._pretrained_model_path, feed_vars = [global_block.var(item) for item in feeded_var_names]
executor=exe,
model_filename=_model_filename, path_prefix = dirname
params_filename=_params_filename, if os.path.isdir(dirname):
) path_prefix = os.path.join(dirname, 'model')
paddle.static.save_inference_model(
paddle.static.save_inference_model(dirname=dirname, path_prefix, feed_vars=feed_vars, fetch_vars=target_vars, executor=exe, program=program)
main_program=program,
executor=exe,
feeded_var_names=feeded_var_names,
target_vars=target_vars,
model_filename=model_filename,
params_filename=params_filename)
log.logger.info('Paddle Inference model saved in {}.'.format(dirname)) log.logger.info('Paddle Inference model saved in {}.'.format(dirname))
...@@ -337,17 +338,19 @@ class RunModule(object): ...@@ -337,17 +338,19 @@ class RunModule(object):
save_file = os.path.join(dirname, '{}.onnx'.format(self.name)) save_file = os.path.join(dirname, '{}.onnx'.format(self.name))
program, inputs, outputs = paddle.static.load_inference_model(dirname=self._pretrained_model_path, program, inputs, outputs = paddle.static.load_inference_model(
model_filename=model_filename, dirname=self._pretrained_model_path,
params_filename=params_filename, model_filename=model_filename,
executor=exe) params_filename=params_filename,
executor=exe)
paddle2onnx.program2onnx(program=program, paddle2onnx.program2onnx(
scope=paddle.static.global_scope(), program=program,
feed_var_names=inputs, scope=paddle.static.global_scope(),
target_vars=outputs, feed_var_names=inputs,
save_file=save_file, target_vars=outputs,
**kwargs) save_file=save_file,
**kwargs)
class Module(object): class Module(object):
...@@ -387,13 +390,14 @@ class Module(object): ...@@ -387,13 +390,14 @@ class Module(object):
from paddlehub.server.server import CacheUpdater from paddlehub.server.server import CacheUpdater
# This branch come from hub.Module(name='xxx') or hub.Module(directory='xxx') # This branch come from hub.Module(name='xxx') or hub.Module(directory='xxx')
if name: if name:
module = cls.init_with_name(name=name, module = cls.init_with_name(
version=version, name=name,
source=source, version=version,
update=update, source=source,
branch=branch, update=update,
ignore_env_mismatch=ignore_env_mismatch, branch=branch,
**kwargs) ignore_env_mismatch=ignore_env_mismatch,
**kwargs)
CacheUpdater("update_cache", module=name, version=version).start() CacheUpdater("update_cache", module=name, version=version).start()
elif directory: elif directory:
module = cls.init_with_directory(directory=directory, **kwargs) module = cls.init_with_directory(directory=directory, **kwargs)
...@@ -485,12 +489,13 @@ class Module(object): ...@@ -485,12 +489,13 @@ class Module(object):
manager = LocalModuleManager() manager = LocalModuleManager()
user_module_cls = manager.search(name, source=source, branch=branch) user_module_cls = manager.search(name, source=source, branch=branch)
if not user_module_cls or not user_module_cls.version.match(version): if not user_module_cls or not user_module_cls.version.match(version):
user_module_cls = manager.install(name=name, user_module_cls = manager.install(
version=version, name=name,
source=source, version=version,
update=update, source=source,
branch=branch, update=update,
ignore_env_mismatch=ignore_env_mismatch) branch=branch,
ignore_env_mismatch=ignore_env_mismatch)
directory = manager._get_normalized_path(user_module_cls.name) directory = manager._get_normalized_path(user_module_cls.name)
...@@ -555,7 +560,9 @@ def moduleinfo(name: str, ...@@ -555,7 +560,9 @@ def moduleinfo(name: str,
_bases.append(_b) _bases.append(_b)
_bases.append(_meta) _bases.append(_meta)
_bases = tuple(_bases) _bases = tuple(_bases)
wrap_cls = builtins.type(cls.__name__, _bases, dict(cls.__dict__)) attr_dict = dict(cls.__dict__)
attr_dict.pop('__dict__', None)
wrap_cls = builtins.type(cls.__name__, _bases, attr_dict)
wrap_cls.name = name wrap_cls.name = name
wrap_cls.version = utils.Version(version) wrap_cls.version = utils.Version(version)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册