diff --git a/paddle_hub/module.py b/paddle_hub/module.py index cc9130e23bdf149d62d02e1335394349bba7d753..cc71d3763e31a9c1090899916b3ec7f9ba06769a 100644 --- a/paddle_hub/module.py +++ b/paddle_hub/module.py @@ -166,7 +166,7 @@ class Module(object): model_dir = os.path.join(self.module_dir, MODEL_DIRNAME) self.exe = fluid.Executor(fluid.CPUPlace()) self.inference_program, self.feed_target_names, self.fetch_targets = fluid.io.load_inference_model( - dirname=os.path.join(model_dir, sign_name), executor=self.exe) + model_dir, executor=self.exe) feed_dict, fetch_dict = _process_input_output_key( self.config.desc, sign_name) @@ -293,7 +293,7 @@ class ModuleConfig(object): return os.path.join(meta_path, PARAM_FILENAME) -def create_module(sign_arr, module_dir=None, word_dict=None): +def create_module(sign_arr, module_dir=None, word_dict=None, place=None): """ Create a module from main program """ assert sign_arr, "signature array should not be None" @@ -301,15 +301,19 @@ def create_module(sign_arr, module_dir=None, word_dict=None): # check all variable sign_arr = to_list(sign_arr) program = sign_arr[0].get_inputs()[0].block.program + feeded_var_names = set() + target_vars = set() for sign in sign_arr: assert isinstance(sign, Signature), "sign_arr should be list of Signature" for input in sign.get_inputs(): + feeded_var_names.add(input.name) _tmp_program = input.block.program assert program == _tmp_program, "all the variable should come from the same program" for output in sign.get_outputs(): + target_vars.add(output) _tmp_program = output.block.program assert program == _tmp_program, "all the variable should come from the same program" @@ -401,42 +405,41 @@ def create_module(sign_arr, module_dir=None, word_dict=None): fetch_var.alias = fetch_names[index] # save inference program - exe = fluid.Executor(place=fluid.CPUPlace()) - model_dir = os.path.join(module_dir, "model") - mkdir(model_dir) - # TODO(wuzewu): save paddle model with a more effective way - for sign in sign_arr: - save_model_dir = os.path.join(model_dir, sign.get_name()) - fluid.io.save_inference_model( - save_model_dir, - feeded_var_names=[var.name for var in sign.get_inputs()], - target_vars=sign.get_outputs(), - main_program=program, - executor=exe) - - with open(os.path.join(save_model_dir, "__model__"), "rb") as file: - program_desc_str = file.read() - rename_program = fluid.framework.Program.parse_from_string( - program_desc_str) - varlist = { - var: block - for block in rename_program.blocks for var in block.vars - if HUB_VAR_PREFIX not in var - } - for var, block in varlist.items(): - old_name = var - new_name = HUB_VAR_PREFIX + old_name - block._rename_var(old_name, new_name) - mkdir(save_model_dir) - with open(os.path.join(save_model_dir, "__model__"), "wb") as f: - f.write(rename_program.desc.serialize_to_string()) - - for file in os.listdir(save_model_dir): - if (file == "__model__" or HUB_VAR_PREFIX in file): - continue - os.rename( - os.path.join(save_model_dir, file), - os.path.join(save_model_dir, HUB_VAR_PREFIX + file)) + if not place: + place = fluid.CPUPlace() + exe = fluid.Executor(place=place) + save_model_dir = os.path.join(module_dir, "model") + mkdir(save_model_dir) + fluid.io.save_inference_model( + save_model_dir, + feeded_var_names=list(feeded_var_names), + target_vars=list(target_vars), + main_program=program, + executor=exe) + + with open(os.path.join(save_model_dir, "__model__"), "rb") as file: + program_desc_str = file.read() + rename_program = fluid.framework.Program.parse_from_string( + program_desc_str) + varlist = { + var: block + for block in rename_program.blocks for var in block.vars + if HUB_VAR_PREFIX not in var + } + for var, block in varlist.items(): + old_name = var + new_name = HUB_VAR_PREFIX + old_name + block._rename_var(old_name, new_name) + mkdir(save_model_dir) + with open(os.path.join(save_model_dir, "__model__"), "wb") as f: + f.write(rename_program.desc.serialize_to_string()) + + for file in os.listdir(save_model_dir): + if (file == "__model__" or HUB_VAR_PREFIX in file): + continue + os.rename( + os.path.join(save_model_dir, file), + os.path.join(save_model_dir, HUB_VAR_PREFIX + file)) # Serialize module_desc pb module_pb = module_desc.SerializeToString() diff --git a/paddle_hub/utils.py b/paddle_hub/utils.py index c32d6c385c442e0e0c155e269be0266a5a6c9cce..767ebbb345bbf79b28923ca18a6b4cdecd8597e3 100644 --- a/paddle_hub/utils.py +++ b/paddle_hub/utils.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function import paddle import paddle.fluid as fluid +import os def to_list(input):