提交 bfedfda8 编写于 作者: W wuzewu

optimize the method of preserving model

上级 a5da11b6
......@@ -166,7 +166,7 @@ class Module(object):
model_dir = os.path.join(self.module_dir, MODEL_DIRNAME)
self.exe = fluid.Executor(fluid.CPUPlace())
self.inference_program, self.feed_target_names, self.fetch_targets = fluid.io.load_inference_model(
dirname=os.path.join(model_dir, sign_name), executor=self.exe)
model_dir, executor=self.exe)
feed_dict, fetch_dict = _process_input_output_key(
self.config.desc, sign_name)
......@@ -293,7 +293,7 @@ class ModuleConfig(object):
return os.path.join(meta_path, PARAM_FILENAME)
def create_module(sign_arr, module_dir=None, word_dict=None):
def create_module(sign_arr, module_dir=None, word_dict=None, place=None):
""" Create a module from main program
"""
assert sign_arr, "signature array should not be None"
......@@ -301,15 +301,19 @@ def create_module(sign_arr, module_dir=None, word_dict=None):
# check all variable
sign_arr = to_list(sign_arr)
program = sign_arr[0].get_inputs()[0].block.program
feeded_var_names = set()
target_vars = set()
for sign in sign_arr:
assert isinstance(sign,
Signature), "sign_arr should be list of Signature"
for input in sign.get_inputs():
feeded_var_names.add(input.name)
_tmp_program = input.block.program
assert program == _tmp_program, "all the variable should come from the same program"
for output in sign.get_outputs():
target_vars.add(output)
_tmp_program = output.block.program
assert program == _tmp_program, "all the variable should come from the same program"
......@@ -401,42 +405,41 @@ def create_module(sign_arr, module_dir=None, word_dict=None):
fetch_var.alias = fetch_names[index]
# save inference program
exe = fluid.Executor(place=fluid.CPUPlace())
model_dir = os.path.join(module_dir, "model")
mkdir(model_dir)
# TODO(wuzewu): save paddle model with a more effective way
for sign in sign_arr:
save_model_dir = os.path.join(model_dir, sign.get_name())
fluid.io.save_inference_model(
save_model_dir,
feeded_var_names=[var.name for var in sign.get_inputs()],
target_vars=sign.get_outputs(),
main_program=program,
executor=exe)
with open(os.path.join(save_model_dir, "__model__"), "rb") as file:
program_desc_str = file.read()
rename_program = fluid.framework.Program.parse_from_string(
program_desc_str)
varlist = {
var: block
for block in rename_program.blocks for var in block.vars
if HUB_VAR_PREFIX not in var
}
for var, block in varlist.items():
old_name = var
new_name = HUB_VAR_PREFIX + old_name
block._rename_var(old_name, new_name)
mkdir(save_model_dir)
with open(os.path.join(save_model_dir, "__model__"), "wb") as f:
f.write(rename_program.desc.serialize_to_string())
for file in os.listdir(save_model_dir):
if (file == "__model__" or HUB_VAR_PREFIX in file):
continue
os.rename(
os.path.join(save_model_dir, file),
os.path.join(save_model_dir, HUB_VAR_PREFIX + file))
if not place:
place = fluid.CPUPlace()
exe = fluid.Executor(place=place)
save_model_dir = os.path.join(module_dir, "model")
mkdir(save_model_dir)
fluid.io.save_inference_model(
save_model_dir,
feeded_var_names=list(feeded_var_names),
target_vars=list(target_vars),
main_program=program,
executor=exe)
with open(os.path.join(save_model_dir, "__model__"), "rb") as file:
program_desc_str = file.read()
rename_program = fluid.framework.Program.parse_from_string(
program_desc_str)
varlist = {
var: block
for block in rename_program.blocks for var in block.vars
if HUB_VAR_PREFIX not in var
}
for var, block in varlist.items():
old_name = var
new_name = HUB_VAR_PREFIX + old_name
block._rename_var(old_name, new_name)
mkdir(save_model_dir)
with open(os.path.join(save_model_dir, "__model__"), "wb") as f:
f.write(rename_program.desc.serialize_to_string())
for file in os.listdir(save_model_dir):
if (file == "__model__" or HUB_VAR_PREFIX in file):
continue
os.rename(
os.path.join(save_model_dir, file),
os.path.join(save_model_dir, HUB_VAR_PREFIX + file))
# Serialize module_desc pb
module_pb = module_desc.SerializeToString()
......
......@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
import paddle
import paddle.fluid as fluid
import os
def to_list(input):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册