diff --git a/paddlehub/compat/module/module_v1.py b/paddlehub/compat/module/module_v1.py index cbaf452c585d2379612bcca219bcafcb843781c5..48b19bd5aa7ac9ac6b89c605ead5dc03fee41ace 100644 --- a/paddlehub/compat/module/module_v1.py +++ b/paddlehub/compat/module/module_v1.py @@ -98,6 +98,9 @@ class ModuleV1(object): log.logger.info('{} pretrained paramaters loaded by PaddleHub'.format(num_param_loaded)) def _load_extra_info(self): + if not 'extra_info' in self.desc: + return + for key, value in self.desc.extra_info.items(): self.__dict__['get_{}'.format(key)] = value @@ -108,7 +111,7 @@ class ModuleV1(object): def _load_model(self): model_path = os.path.join(self.directory, 'model') exe = paddle.static.Executor(paddle.CPUPlace()) - self.program, _, _ = paddle.static.load_inference_model(model_path, executor=exe) + self.program, _, _ = paddle.fluid.io.load_inference_model(model_path, executor=exe) # Clear the callstack since it may leak the privacy of the creator. for block in self.program.blocks: @@ -240,6 +243,9 @@ class ModuleV1(object): def assets_path(self): return os.path.join(self.directory, 'assets') + def get_name_prefix(self): + return self.desc.name_prefix + @property def is_runnable(self): ''' @@ -247,3 +253,29 @@ class ModuleV1(object): `hub run` command. ''' return self.default_signature != None + + def save_inference_model(self, + dirname: str, + model_filename: str = None, + params_filename: str = None, + combined: bool = False): + if hasattr(self, 'processor'): + if hasattr(self.processor, 'save_inference_model'): + return self.processor.save_inference_model(dirname, model_filename, params_filename, combined) + + if combined: + model_filename = '__model__' if not model_filename else model_filename + params_filename = '__params__' if not params_filename else params_filename + + place = paddle.CPUPlace() + exe = paddle.static.Executor(place) + + feed_dict, fetch_dict, program = self.context(for_test=True, trainable=False) + paddle.fluid.io.save_inference_model( + dirname=dirname, + main_program=program, + executor=exe, + feeded_var_names=[var.name for var in list(feed_dict.values())], + target_vars=list(fetch_dict.values()), + model_filename=model_filename, + params_filename=params_filename)