提交 77581372 编写于 作者: L LiuChiachi

Support exporting models with string inputs

上级 b51c7922
......@@ -211,8 +211,10 @@ def save_model(server_model_folder,
new_params_path = os.path.join(server_model_folder, params_filename)
with open(new_model_path, "wb") as new_model_file:
new_model_file.write(main_program._remove_training_info(False).desc.serialize_to_string())
new_model_file.write(
main_program._remove_training_info(False)
.desc.serialize_to_string())
paddle.static.save_vars(
executor=executor,
dirname=server_model_folder,
......@@ -231,7 +233,8 @@ def save_model(server_model_folder,
key = CipherUtils.gen_key_to_file(128, "key")
params = fluid.io.save_persistables(
executor=executor, dirname=None, main_program=main_program)
model = main_program._remove_training_info(False).desc.serialize_to_string()
model = main_program._remove_training_info(
False).desc.serialize_to_string()
if not os.path.exists(server_model_folder):
os.makedirs(server_model_folder)
os.chdir(server_model_folder)
......@@ -248,15 +251,20 @@ def save_model(server_model_folder,
fetch_alias = target_var_names
else:
fetch_alias = fetch_alias_names.split(',')
if len(feed_alias) != len(feed_var_dict.keys()) or len(fetch_alias) != len(target_var_names):
raise ValueError("please check the input --feed_alias_names and --fetch_alias_names, should be same size with feed_vars and fetch_vars")
if len(feed_alias) != len(feed_var_dict.keys()) or len(fetch_alias) != len(
target_var_names):
raise ValueError(
"please check the input --feed_alias_names and --fetch_alias_names, should be same size with feed_vars and fetch_vars"
)
for i, key in enumerate(feed_var_dict):
feed_var = model_conf.FeedVar()
feed_var.alias_name = feed_alias[i]
feed_var.name = feed_var_dict[key].name
feed_var.feed_type = var_type_conversion(feed_var_dict[key].dtype)
feed_var.is_lod_tensor = feed_var_dict[key].lod_level >= 1
feed_var.is_lod_tensor = feed_var_dict[
key].lod_level >= 1 if feed_var_dict[
key].lod_level is not None else False
if feed_var.is_lod_tensor:
feed_var.shape.extend([-1])
else:
......@@ -331,7 +339,8 @@ def inference_model_to_serving(dirname,
fetch_dict = {x.name: x for x in fetch_targets}
save_model(serving_server, serving_client, feed_dict, fetch_dict,
inference_program, encryption, key_len, encrypt_conf,
model_filename, params_filename, show_proto, feed_alias_names, fetch_alias_names)
model_filename, params_filename, show_proto, feed_alias_names,
fetch_alias_names)
feed_names = feed_dict.keys()
fetch_names = fetch_dict.keys()
return feed_names, fetch_names
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册