Commit bcaefb3f authored by: W wangjiawei04

add paddle jit save

Parent 6783bcb6
@@ -23,7 +23,90 @@ from paddle.fluid.io import save_inference_model
import paddle.fluid as fluid
from ..proto import general_model_config_pb2 as model_conf
import os
import paddle
import paddle.nn.functional as F
from paddle.jit import to_static

def save_dygraph_model(serving_model_folder, client_config_folder, model):
    # Trace and save the dygraph Layer as a static-graph model, then load it
    # back to inspect its input/output specs and the translated program.
    paddle.jit.save(model, "serving_tmp")
    loaded_layer = paddle.jit.load(
        path=".",
        model_filename="serving_tmp.pdmodel",
        params_filename="serving_tmp.pdiparams")
    feed_target_names = [x.name for x in loaded_layer._input_spec()]
    fetch_target_names = [x.name for x in loaded_layer._output_spec()]
    inference_program = loaded_layer.program()
    feed_var_dict = {
        x: inference_program.global_block().var(x)
        for x in feed_target_names
    }
    fetch_var_dict = {
        x: inference_program.global_block().var(x)
        for x in fetch_target_names
    }
    config = model_conf.GeneralModelConfig()
    # feed_type / fetch_type encoding: int64 = 0; float32 = 1; int32 = 2
    for key in feed_var_dict:
        feed_var = model_conf.FeedVar()
        feed_var.alias_name = key
        feed_var.name = feed_var_dict[key].name
        feed_var.is_lod_tensor = feed_var_dict[key].lod_level >= 1
        if feed_var_dict[key].dtype == core.VarDesc.VarType.INT64:
            feed_var.feed_type = 0
        if feed_var_dict[key].dtype == core.VarDesc.VarType.FP32:
            feed_var.feed_type = 1
        if feed_var_dict[key].dtype == core.VarDesc.VarType.INT32:
            feed_var.feed_type = 2
        if feed_var.is_lod_tensor:
            feed_var.shape.extend([-1])
        else:
            # Keep only the fixed dimensions; drop the batch (-1) dimension.
            tmp_shape = []
            for v in feed_var_dict[key].shape:
                if v >= 0:
                    tmp_shape.append(v)
            feed_var.shape.extend(tmp_shape)
        config.feed_var.extend([feed_var])
    for key in fetch_var_dict:
        fetch_var = model_conf.FetchVar()
        fetch_var.alias_name = key
        fetch_var.name = fetch_var_dict[key].name
        # Outputs are always treated as LoD tensors here.
        fetch_var.is_lod_tensor = 1
        if fetch_var_dict[key].dtype == core.VarDesc.VarType.INT64:
            fetch_var.fetch_type = 0
        if fetch_var_dict[key].dtype == core.VarDesc.VarType.FP32:
            fetch_var.fetch_type = 1
        if fetch_var_dict[key].dtype == core.VarDesc.VarType.INT32:
            fetch_var.fetch_type = 2
        if fetch_var.is_lod_tensor:
            fetch_var.shape.extend([-1])
        else:
            tmp_shape = []
            for v in fetch_var_dict[key].shape:
                if v >= 0:
                    tmp_shape.append(v)
            fetch_var.shape.extend(tmp_shape)
        config.fetch_var.extend([fetch_var])
    # Lay out the serving model folder: the traced program becomes __model__
    # and the combined parameters become __params__.
    os.system("mkdir -p {}".format(client_config_folder))
    os.system("mkdir -p {}".format(serving_model_folder))
    os.system("mv serving_tmp.pdmodel {}/__model__".format(
        serving_model_folder))
    os.system("mv serving_tmp.pdiparams {}/__params__".format(
        serving_model_folder))
    os.system("rm -rf serving_tmp.pd*")
    # Write the config twice for each side: a human-readable prototxt for
    # inspection and a serialized stream for the client/server to parse.
    with open("{}/serving_client_conf.prototxt".format(client_config_folder),
              "w") as fout:
        fout.write(str(config))
    with open("{}/serving_server_conf.prototxt".format(serving_model_folder),
              "w") as fout:
        fout.write(str(config))
    with open("{}/serving_client_conf.stream.prototxt".format(
            client_config_folder), "wb") as fout:
        fout.write(config.SerializeToString())
    with open("{}/serving_server_conf.stream.prototxt".format(
            serving_model_folder), "wb") as fout:
        fout.write(config.SerializeToString())
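
For context, a minimal usage sketch (not part of this commit; the ToyModel
name, layer sizes, and output folders below are illustrative). paddle.jit.save
needs a to_static-decorated forward (or an input_spec) to trace the Layer:

import paddle
from paddle.jit import to_static
from paddle.static import InputSpec

class ToyModel(paddle.nn.Layer):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.fc = paddle.nn.Linear(784, 10)

    # input_spec lets paddle.jit.save trace forward without a sample input
    @to_static(input_spec=[InputSpec(shape=[None, 784], dtype='float32')])
    def forward(self, x):
        return self.fc(x)

model = ToyModel()
save_dygraph_model("serving_server", "serving_client", model)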
def save_model(server_model_folder,
               client_config_folder,
@@ -44,6 +127,8 @@ def save_model(server_model_folder,
        feed_var_names,
        target_vars,
        executor,
        model_filename="__model__",
        params_filename="__params__",
        main_program=main_program)
    config = model_conf.GeneralModelConfig()
......
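
The two new arguments above make save_inference_model write the program as a
single __model__ file with all parameters combined into one __params__ file,
matching the on-disk layout the serving server expects. A sketch (illustrative,
not from the commit; "serving_server" is a hypothetical output folder) of
loading that layout back with fluid's public API:

import paddle.fluid as fluid

exe = fluid.Executor(fluid.CPUPlace())
# model_filename/params_filename must match the names used when saving
program, feed_names, fetch_targets = fluid.io.load_inference_model(
    dirname="serving_server",
    executor=exe,
    model_filename="__model__",
    params_filename="__params__")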