提交 bfedfda8 编写于 作者: W wuzewu

Optimize the method of saving the model

上级 a5da11b6
......@@ -166,7 +166,7 @@ class Module(object):
model_dir = os.path.join(self.module_dir, MODEL_DIRNAME)
self.exe = fluid.Executor(fluid.CPUPlace())
self.inference_program, self.feed_target_names, self.fetch_targets = fluid.io.load_inference_model(
dirname=os.path.join(model_dir, sign_name), executor=self.exe)
model_dir, executor=self.exe)
feed_dict, fetch_dict = _process_input_output_key(
self.config.desc, sign_name)
......@@ -293,7 +293,7 @@ class ModuleConfig(object):
return os.path.join(meta_path, PARAM_FILENAME)
def create_module(sign_arr, module_dir=None, word_dict=None):
def create_module(sign_arr, module_dir=None, word_dict=None, place=None):
""" Create a module from main program
"""
assert sign_arr, "signature array should not be None"
......@@ -301,15 +301,19 @@ def create_module(sign_arr, module_dir=None, word_dict=None):
# check all variable
sign_arr = to_list(sign_arr)
program = sign_arr[0].get_inputs()[0].block.program
feeded_var_names = set()
target_vars = set()
for sign in sign_arr:
assert isinstance(sign,
Signature), "sign_arr should be list of Signature"
for input in sign.get_inputs():
feeded_var_names.add(input.name)
_tmp_program = input.block.program
assert program == _tmp_program, "all the variable should come from the same program"
for output in sign.get_outputs():
target_vars.add(output)
_tmp_program = output.block.program
assert program == _tmp_program, "all the variable should come from the same program"
......@@ -401,16 +405,15 @@ def create_module(sign_arr, module_dir=None, word_dict=None):
fetch_var.alias = fetch_names[index]
# save inference program
exe = fluid.Executor(place=fluid.CPUPlace())
model_dir = os.path.join(module_dir, "model")
mkdir(model_dir)
# TODO(wuzewu): save paddle model with a more effective way
for sign in sign_arr:
save_model_dir = os.path.join(model_dir, sign.get_name())
if not place:
place = fluid.CPUPlace()
exe = fluid.Executor(place=place)
save_model_dir = os.path.join(module_dir, "model")
mkdir(save_model_dir)
fluid.io.save_inference_model(
save_model_dir,
feeded_var_names=[var.name for var in sign.get_inputs()],
target_vars=sign.get_outputs(),
feeded_var_names=list(feeded_var_names),
target_vars=list(target_vars),
main_program=program,
executor=exe)
......
......@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
import paddle
import paddle.fluid as fluid
import os
def to_list(input):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册