提交 d893810f 编写于 作者: W wuzewu

process muti signature

上级 a7967cef
...@@ -74,33 +74,6 @@ class Module(object): ...@@ -74,33 +74,6 @@ class Module(object):
self.module_name = module_dir.split("/")[-1] self.module_name = module_dir.split("/")[-1]
#TODO(ZeyuChen) add more check about loading module from local path #TODO(ZeyuChen) add more check about loading module from local path
# load paddle inference model
place = fluid.CPUPlace()
model_dir = os.path.join(self.module_dir, MODEL_DIRNAME)
self.exe = fluid.Executor(fluid.CPUPlace())
[self.inference_program, self.feed_target_names,
self.fetch_targets] = fluid.io.load_inference_model(
dirname=model_dir, executor=self.exe)
# remove feed fetch operator and variable
ModuleUtils.remove_feed_fetch_op(self.inference_program)
# print("inference_program")
# print(self.inference_program)
print("**feed_target_names**\n{}".format(self.feed_target_names))
print("**fetch_targets**\n{}".format(self.fetch_targets))
self.config = ModuleConfig(self.module_dir)
self.config.load()
self._process_parameter()
#TODO(wuzewu): recover the default unique name generator someother where
self._process_uqn()
def _process_uqn(self):
name_generator_path = ModuleConfig.name_generator_path(self.module_dir)
with open(name_generator_path, "rb") as fi:
fluid.unique_name.switch(pickle.load(fi))
def _process_parameter(self): def _process_parameter(self):
global_block = self.inference_program.global_block() global_block = self.inference_program.global_block()
param_path = ModuleConfig.meta_param_path(self.module_dir) param_path = ModuleConfig.meta_param_path(self.module_dir)
...@@ -133,12 +106,33 @@ class Module(object): ...@@ -133,12 +106,33 @@ class Module(object):
if op.has_attr("is_test"): if op.has_attr("is_test"):
op._set_attr("is_test", is_test) op._set_attr("is_test", is_test)
# load paddle inference model
place = fluid.CPUPlace()
model_dir = os.path.join(self.module_dir, MODEL_DIRNAME)
self.exe = fluid.Executor(fluid.CPUPlace())
self.inference_program, self.feed_target_names, self.fetch_targets = fluid.io.load_inference_model(
dirname=os.path.join(model_dir, sign_name, executor=self.exe))
# remove feed fetch operator and variable
ModuleUtils.remove_feed_fetch_op(self.inference_program)
# print("inference_program")
# print(self.inference_program)
print("**feed_target_names**\n{}".format(self.feed_target_names))
print("**fetch_targets**\n{}".format(self.fetch_targets))
self.config = ModuleConfig(self.module_dir)
self.config.load()
self._process_parameter()
name_generator_path = ModuleConfig.name_generator_path(self.module_dir)
with open(name_generator_path, "rb") as data:
generator = pickle.load(data)
program = self.get_inference_program().clone() program = self.get_inference_program().clone()
_process_op_attr(program=program, is_test=False) _process_op_attr(program=program, is_test=False)
_set_param_trainable(program=program, trainable=trainable) _set_param_trainable(program=program, trainable=trainable)
return self.feed_target_names, self.fetch_targets, program return self.feed_target_names, self.fetch_targets, program, generator
def get_inference_program(self): def get_inference_program(self):
return self.inference_program return self.inference_program
...@@ -323,14 +317,15 @@ def create_module(sign_arr, program, module_dir=None, word_dict=None): ...@@ -323,14 +317,15 @@ def create_module(sign_arr, program, module_dir=None, word_dict=None):
exe = fluid.Executor(place=fluid.CPUPlace()) exe = fluid.Executor(place=fluid.CPUPlace())
model_dir = os.path.join(module_dir, "model") model_dir = os.path.join(module_dir, "model")
mkdir(model_dir) mkdir(model_dir)
# TODO(ZeyuChen): here only deal with one signature # TODO(wuzewu): save paddle model with a more effective way
first_sign = sign_arr[0] for sign in sign_arr:
fluid.io.save_inference_model( save_model_dir = os.path.join(model_dir, sign.get_name())
model_dir, fluid.io.save_inference_model(
feeded_var_names=[var.name for var in first_sign.get_inputs()], save_model_dir,
target_vars=first_sign.get_outputs(), feeded_var_names=[var.name for var in sign.get_inputs()],
main_program=program, target_vars=sign.get_outputs(),
executor=exe) main_program=program,
executor=exe)
# save to disk # save to disk
data = module.SerializeToString() data = module.SerializeToString()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册