提交 bfedfda8 编写于 作者: W wuzewu

optimize the method of preserving model

上级 a5da11b6
...@@ -166,7 +166,7 @@ class Module(object): ...@@ -166,7 +166,7 @@ class Module(object):
model_dir = os.path.join(self.module_dir, MODEL_DIRNAME) model_dir = os.path.join(self.module_dir, MODEL_DIRNAME)
self.exe = fluid.Executor(fluid.CPUPlace()) self.exe = fluid.Executor(fluid.CPUPlace())
self.inference_program, self.feed_target_names, self.fetch_targets = fluid.io.load_inference_model( self.inference_program, self.feed_target_names, self.fetch_targets = fluid.io.load_inference_model(
dirname=os.path.join(model_dir, sign_name), executor=self.exe) model_dir, executor=self.exe)
feed_dict, fetch_dict = _process_input_output_key( feed_dict, fetch_dict = _process_input_output_key(
self.config.desc, sign_name) self.config.desc, sign_name)
...@@ -293,7 +293,7 @@ class ModuleConfig(object): ...@@ -293,7 +293,7 @@ class ModuleConfig(object):
return os.path.join(meta_path, PARAM_FILENAME) return os.path.join(meta_path, PARAM_FILENAME)
def create_module(sign_arr, module_dir=None, word_dict=None): def create_module(sign_arr, module_dir=None, word_dict=None, place=None):
""" Create a module from main program """ Create a module from main program
""" """
assert sign_arr, "signature array should not be None" assert sign_arr, "signature array should not be None"
...@@ -301,15 +301,19 @@ def create_module(sign_arr, module_dir=None, word_dict=None): ...@@ -301,15 +301,19 @@ def create_module(sign_arr, module_dir=None, word_dict=None):
# check all variable # check all variable
sign_arr = to_list(sign_arr) sign_arr = to_list(sign_arr)
program = sign_arr[0].get_inputs()[0].block.program program = sign_arr[0].get_inputs()[0].block.program
feeded_var_names = set()
target_vars = set()
for sign in sign_arr: for sign in sign_arr:
assert isinstance(sign, assert isinstance(sign,
Signature), "sign_arr should be list of Signature" Signature), "sign_arr should be list of Signature"
for input in sign.get_inputs(): for input in sign.get_inputs():
feeded_var_names.add(input.name)
_tmp_program = input.block.program _tmp_program = input.block.program
assert program == _tmp_program, "all the variable should come from the same program" assert program == _tmp_program, "all the variable should come from the same program"
for output in sign.get_outputs(): for output in sign.get_outputs():
target_vars.add(output)
_tmp_program = output.block.program _tmp_program = output.block.program
assert program == _tmp_program, "all the variable should come from the same program" assert program == _tmp_program, "all the variable should come from the same program"
...@@ -401,42 +405,41 @@ def create_module(sign_arr, module_dir=None, word_dict=None): ...@@ -401,42 +405,41 @@ def create_module(sign_arr, module_dir=None, word_dict=None):
fetch_var.alias = fetch_names[index] fetch_var.alias = fetch_names[index]
# save inference program # save inference program
exe = fluid.Executor(place=fluid.CPUPlace()) if not place:
model_dir = os.path.join(module_dir, "model") place = fluid.CPUPlace()
mkdir(model_dir) exe = fluid.Executor(place=place)
# TODO(wuzewu): save paddle model with a more effective way save_model_dir = os.path.join(module_dir, "model")
for sign in sign_arr: mkdir(save_model_dir)
save_model_dir = os.path.join(model_dir, sign.get_name()) fluid.io.save_inference_model(
fluid.io.save_inference_model( save_model_dir,
save_model_dir, feeded_var_names=list(feeded_var_names),
feeded_var_names=[var.name for var in sign.get_inputs()], target_vars=list(target_vars),
target_vars=sign.get_outputs(), main_program=program,
main_program=program, executor=exe)
executor=exe)
with open(os.path.join(save_model_dir, "__model__"), "rb") as file:
with open(os.path.join(save_model_dir, "__model__"), "rb") as file: program_desc_str = file.read()
program_desc_str = file.read() rename_program = fluid.framework.Program.parse_from_string(
rename_program = fluid.framework.Program.parse_from_string( program_desc_str)
program_desc_str) varlist = {
varlist = { var: block
var: block for block in rename_program.blocks for var in block.vars
for block in rename_program.blocks for var in block.vars if HUB_VAR_PREFIX not in var
if HUB_VAR_PREFIX not in var }
} for var, block in varlist.items():
for var, block in varlist.items(): old_name = var
old_name = var new_name = HUB_VAR_PREFIX + old_name
new_name = HUB_VAR_PREFIX + old_name block._rename_var(old_name, new_name)
block._rename_var(old_name, new_name) mkdir(save_model_dir)
mkdir(save_model_dir) with open(os.path.join(save_model_dir, "__model__"), "wb") as f:
with open(os.path.join(save_model_dir, "__model__"), "wb") as f: f.write(rename_program.desc.serialize_to_string())
f.write(rename_program.desc.serialize_to_string())
for file in os.listdir(save_model_dir):
for file in os.listdir(save_model_dir): if (file == "__model__" or HUB_VAR_PREFIX in file):
if (file == "__model__" or HUB_VAR_PREFIX in file): continue
continue os.rename(
os.rename( os.path.join(save_model_dir, file),
os.path.join(save_model_dir, file), os.path.join(save_model_dir, HUB_VAR_PREFIX + file))
os.path.join(save_model_dir, HUB_VAR_PREFIX + file))
# Serialize module_desc pb # Serialize module_desc pb
module_pb = module_desc.SerializeToString() module_pb = module_desc.SerializeToString()
......
...@@ -19,6 +19,7 @@ from __future__ import division ...@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
import os
def to_list(input): def to_list(input):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册