提交 19b22a59 编写于 作者: B bjjwwang

fix save model

上级 1f35ae38
......@@ -19,7 +19,7 @@ from paddle.fluid.framework import core
from paddle.fluid.framework import default_main_program
from paddle.fluid.framework import Program
from paddle.fluid import CPUPlace
from paddle.fluid.io import save_inference_model
from .paddle_io import save_inference_model, normalize_program
import paddle.fluid as fluid
from paddle.fluid.core import CipherUtils
from paddle.fluid.core import CipherFactory
......@@ -191,12 +191,14 @@ def save_model(server_model_folder,
executor = Executor(place=CPUPlace())
feed_var_names = [feed_var_dict[x].name for x in feed_var_dict]
feed_vars = [feed_var_dict[x] for x in feed_var_dict]
target_vars = []
target_var_names = []
for key in sorted(fetch_var_dict.keys()):
target_vars.append(fetch_var_dict[key])
target_var_names.append(key)
main_program = normalize_program(main_program, feed_vars, target_vars)
if not encryption and not show_proto:
if not os.path.exists(server_model_folder):
os.makedirs(server_model_folder)
......@@ -209,7 +211,7 @@ def save_model(server_model_folder,
new_params_path = os.path.join(server_model_folder, params_filename)
with open(new_model_path, "wb") as new_model_file:
new_model_file.write(main_program.desc.serialize_to_string())
new_model_file.write(main_program._remove_training_info(False).desc.serialize_to_string())
paddle.static.save_vars(
executor=executor,
......@@ -229,7 +231,7 @@ def save_model(server_model_folder,
key = CipherUtils.gen_key_to_file(128, "key")
params = fluid.io.save_persistables(
executor=executor, dirname=None, main_program=main_program)
model = main_program.desc.serialize_to_string()
model = main_program._remove_training_info(False).desc.serialize_to_string()
if not os.path.exists(server_model_folder):
os.makedirs(server_model_folder)
os.chdir(server_model_folder)
......
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册