From 47d7cac155b11fb883b67a318e7bbc679e6c5c36 Mon Sep 17 00:00:00 2001
From: chenjian
Date: Mon, 28 Nov 2022 11:50:18 +0800
Subject: [PATCH] Fix save_inference_model bug in paddlehub (#2143)

---
 paddlehub/module/module.py | 105 ++++++++++++++++++++-----------------
 1 file changed, 56 insertions(+), 49 deletions(-)

diff --git a/paddlehub/module/module.py b/paddlehub/module/module.py
index d494eb39..a9d10ce0 100644
--- a/paddlehub/module/module.py
+++ b/paddlehub/module/module.py
@@ -37,7 +37,6 @@ from paddlehub.utils import utils
 
 
 class InvalidHubModule(Exception):
-
     def __init__(self, directory: str):
         self.directory = directory
 
@@ -200,11 +199,12 @@ class RunModule(object):
         for key, _sub_module in self.sub_modules().items():
             try:
                 sub_dirname = os.path.normpath(os.path.join(dirname, key))
-                _sub_module.save_inference_model(sub_dirname,
-                                                 include_sub_modules=include_sub_modules,
-                                                 model_filename=model_filename,
-                                                 params_filename=params_filename,
-                                                 combined=combined)
+                _sub_module.save_inference_model(
+                    sub_dirname,
+                    include_sub_modules=include_sub_modules,
+                    model_filename=model_filename,
+                    params_filename=params_filename,
+                    combined=combined)
             except:
                 utils.record_exception('Failed to save sub module {}'.format(_sub_module.name))
 
@@ -231,14 +231,11 @@
         if not self._pretrained_model_path:
             raise RuntimeError('Module {} does not support exporting models in Paddle Inference format.'.format(
                 self.name))
-        elif not os.path.exists(self._pretrained_model_path):
+        elif not os.path.exists(
+                self._pretrained_model_path) and not os.path.exists(self._pretrained_model_path + '.pdmodel'):
             log.logger.warning('The model path of Module {} does not exist.'.format(self.name))
             return
 
-        model_filename = '__model__' if not model_filename else model_filename
-        if combined:
-            params_filename = '__params__' if not params_filename else params_filename
-
         place = paddle.CPUPlace()
         exe = paddle.static.Executor(place)
 
@@ -250,21 +247,25 @@
 
         if os.path.exists(os.path.join(self._pretrained_model_path, '__params__')):
             _params_filename = '__params__'
 
+        if _model_filename is not None and _params_filename is not None:
+            program, feeded_var_names, target_vars = paddle.static.load_inference_model(
+                self._pretrained_model_path,
+                executor=exe,
+                model_filename=_model_filename,
+                params_filename=_params_filename,
+            )
+        else:
+            program, feeded_var_names, target_vars = paddle.static.load_inference_model(
+                self._pretrained_model_path, executor=exe)
 
-        program, feeded_var_names, target_vars = paddle.static.load_inference_model(
-            dirname=self._pretrained_model_path,
-            executor=exe,
-            model_filename=_model_filename,
-            params_filename=_params_filename,
-        )
-
-        paddle.static.save_inference_model(dirname=dirname,
-                                           main_program=program,
-                                           executor=exe,
-                                           feeded_var_names=feeded_var_names,
-                                           target_vars=target_vars,
-                                           model_filename=model_filename,
-                                           params_filename=params_filename)
+        global_block = program.global_block()
+        feed_vars = [global_block.var(item) for item in feeded_var_names]
+
+        path_prefix = dirname
+        if os.path.isdir(dirname):
+            path_prefix = os.path.join(dirname, 'model')
+        paddle.static.save_inference_model(
+            path_prefix, feed_vars=feed_vars, fetch_vars=target_vars, executor=exe, program=program)
 
         log.logger.info('Paddle Inference model saved in {}.'.format(dirname))
 
@@ -337,17 +338,19 @@
 
         save_file = os.path.join(dirname, '{}.onnx'.format(self.name))
 
-        program, inputs, outputs = paddle.static.load_inference_model(dirname=self._pretrained_model_path,
-                                                                      model_filename=model_filename,
-                                                                      params_filename=params_filename,
-                                                                      executor=exe)
+        program, inputs, outputs = paddle.static.load_inference_model(
+            dirname=self._pretrained_model_path,
+            model_filename=model_filename,
+            params_filename=params_filename,
+            executor=exe)
 
-        paddle2onnx.program2onnx(program=program,
-                                 scope=paddle.static.global_scope(),
-                                 feed_var_names=inputs,
-                                 target_vars=outputs,
-                                 save_file=save_file,
-                                 **kwargs)
+        paddle2onnx.program2onnx(
+            program=program,
+            scope=paddle.static.global_scope(),
+            feed_var_names=inputs,
+            target_vars=outputs,
+            save_file=save_file,
+            **kwargs)
 
 
 class Module(object):
@@ -387,13 +390,14 @@ class Module(object):
         from paddlehub.server.server import CacheUpdater
         # This branch come from hub.Module(name='xxx') or hub.Module(directory='xxx')
         if name:
-            module = cls.init_with_name(name=name,
-                                        version=version,
-                                        source=source,
-                                        update=update,
-                                        branch=branch,
-                                        ignore_env_mismatch=ignore_env_mismatch,
-                                        **kwargs)
+            module = cls.init_with_name(
+                name=name,
+                version=version,
+                source=source,
+                update=update,
+                branch=branch,
+                ignore_env_mismatch=ignore_env_mismatch,
+                **kwargs)
             CacheUpdater("update_cache", module=name, version=version).start()
         elif directory:
             module = cls.init_with_directory(directory=directory, **kwargs)
@@ -485,12 +489,13 @@ class Module(object):
         manager = LocalModuleManager()
         user_module_cls = manager.search(name, source=source, branch=branch)
         if not user_module_cls or not user_module_cls.version.match(version):
-            user_module_cls = manager.install(name=name,
-                                              version=version,
-                                              source=source,
-                                              update=update,
-                                              branch=branch,
-                                              ignore_env_mismatch=ignore_env_mismatch)
+            user_module_cls = manager.install(
+                name=name,
+                version=version,
+                source=source,
+                update=update,
+                branch=branch,
+                ignore_env_mismatch=ignore_env_mismatch)
 
             directory = manager._get_normalized_path(user_module_cls.name)
 
@@ -555,7 +560,9 @@ def moduleinfo(name: str,
                 _bases.append(_b)
             _bases.append(_meta)
             _bases = tuple(_bases)
-            wrap_cls = builtins.type(cls.__name__, _bases, dict(cls.__dict__))
+            attr_dict = dict(cls.__dict__)
+            attr_dict.pop('__dict__', None)
+            wrap_cls = builtins.type(cls.__name__, _bases, attr_dict)
 
         wrap_cls.name = name
         wrap_cls.version = utils.Version(version)
-- 
GitLab
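
Notes on the change:

The heart of the fix is the move from the Paddle 1.x static-graph export API
(a dirname plus the name lists feeded_var_names/target_vars, writing
__model__/__params__ files) to the Paddle 2.x API, which takes a path prefix
plus Variable objects and writes <prefix>.pdmodel/<prefix>.pdiparams. That is
also why _pretrained_model_path is now additionally probed with a .pdmodel
suffix. Below is a minimal sketch of the same conversion outside PaddleHub,
assuming a Paddle 2.x install and an old combined-format model in ./old_model;
both paths are illustrative:

    import os

    import paddle

    paddle.enable_static()  # the export code above runs in static-graph mode
    exe = paddle.static.Executor(paddle.CPUPlace())

    # With model_filename/params_filename kwargs, load_inference_model treats
    # the path as a directory holding the old combined __model__/__params__
    # layout; without them, it treats the path as a 2.x file prefix. The patch
    # branches on exactly this distinction.
    program, feed_names, fetch_vars = paddle.static.load_inference_model(
        './old_model', exe, model_filename='__model__', params_filename='__params__')

    # The 2.x save API expects Variable objects rather than names, so resolve
    # the feed names through the program's global block, as the patch does.
    feed_vars = [program.global_block().var(name) for name in feed_names]

    # It also expects a file prefix, not a directory; an existing directory is
    # rejected, hence the patch's os.path.isdir() branch.
    os.makedirs('./new_model', exist_ok=True)
    path_prefix = os.path.join('./new_model', 'model')
    paddle.static.save_inference_model(
        path_prefix, feed_vars=feed_vars, fetch_vars=fetch_vars, executor=exe, program=program)
    # writes ./new_model/model.pdmodel and ./new_model/model.pdiparams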
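
The moduleinfo change addresses a subtler issue: for a class that inherits
directly from object, dict(cls.__dict__) still carries the class's own
__dict__ getset descriptor, and copying it into the namespace of the rebuilt
class leaves instance.__dict__ resolving through a descriptor bound to the old
class, which raises TypeError on first access; presumably this is what the
rebuilt module classes tripped over during save_inference_model. A standalone
illustration of the failure and the fix, using illustrative class names that
are not PaddleHub code:

    import builtins

    class Original:
        value = 42

    # Copying the namespace wholesale drags the stale '__dict__' descriptor along.
    attrs = dict(Original.__dict__)
    Broken = builtins.type('Rebuilt', (object,), attrs)
    try:
        Broken().__dict__
    except TypeError as exc:
        # descriptor '__dict__' for 'Original' objects doesn't apply to a 'Rebuilt' object
        print(exc)

    # Dropping the descriptor before rebuilding, as the patch does, restores a
    # normal per-instance dict ('__weakref__' can misbehave the same way).
    attrs.pop('__dict__', None)
    Fixed = builtins.type('Rebuilt', (object,), attrs)
    print(Fixed().__dict__)  # prints {}: a fresh, ordinary instance dict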