提交 427c4b11 编写于 作者: J Jiawei Wang

convert and flask version

上级 6477fe08
......@@ -23,6 +23,12 @@ from .io import inference_model_to_serving
def parse_args(): # pylint: disable=doc-string-missing
parser = argparse.ArgumentParser("convert")
parser.add_argument(
"--show_proto",
type=bool,
default=False,
help='If yes, you can preview the proto and then determine your feed var alias name and fetch var alias name.'
)
parser.add_argument(
"--dirname",
type=str,
......@@ -53,6 +59,18 @@ def parse_args(): # pylint: disable=doc-string-missing
default=None,
help='The name of file to load all parameters. It is only used for the case that all parameters were saved in a single binary file. If parameters were saved in separate files, set it as None. Default: None.'
)
parser.add_argument(
"--feed_alias_names",
type=str,
default=None,
help='set alias names for feed vars, split by comma \',\', you should run --show_proto to check the number of feed vars'
)
parser.add_argument(
"--fetch_alias_names",
type=str,
default=None,
help='set alias names for feed vars, split by comma \',\', you should run --show_proto to check the number of fetch vars'
)
return parser.parse_args()
......@@ -63,4 +81,7 @@ if __name__ == "__main__":
serving_server=args.serving_server,
serving_client=args.serving_client,
model_filename=args.model_filename,
params_filename=args.params_filename)
params_filename=args.params_filename,
show_proto=args.show_proto,
feed_alias_names=args.feed_alias_names,
fetch_alias_names=args.fetch_alias_names)
......@@ -184,7 +184,10 @@ def save_model(server_model_folder,
key_len=128,
encrypt_conf=None,
model_filename=None,
params_filename=None):
params_filename=None,
show_proto=False,
feed_alias_names=None,
fetch_alias_names=None):
executor = Executor(place=CPUPlace())
feed_var_names = [feed_var_dict[x].name for x in feed_var_dict]
......@@ -194,9 +197,9 @@ def save_model(server_model_folder,
target_vars.append(fetch_var_dict[key])
target_var_names.append(key)
if not os.path.exists(server_model_folder):
os.makedirs(server_model_folder)
if not encryption:
if not encryption and not show_proto:
if not os.path.exists(server_model_folder):
os.makedirs(server_model_folder)
if not model_filename:
model_filename = "model.pdmodel"
if not params_filename:
......@@ -207,7 +210,7 @@ def save_model(server_model_folder,
with open(new_model_path, "wb") as new_model_file:
new_model_file.write(main_program.desc.serialize_to_string())
paddle.static.save_vars(
executor=executor,
dirname=server_model_folder,
......@@ -215,7 +218,9 @@ def save_model(server_model_folder,
vars=None,
predicate=paddle.static.io.is_persistable,
filename=params_filename)
else:
elif not show_proto:
if not os.path.exists(server_model_folder):
os.makedirs(server_model_folder)
if encrypt_conf == None:
aes_cipher = CipherFactory.create_cipher()
else:
......@@ -233,10 +238,19 @@ def save_model(server_model_folder,
os.chdir("..")
config = model_conf.GeneralModelConfig()
for key in feed_var_dict:
if feed_alias_names is None:
feed_alias = list(feed_var_dict.keys())
else:
feed_alias = feed_alias_names.split(',')
if fetch_alias_names is None:
fetch_alias = target_var_names
else:
fetch_alias = fetch_alias_names.split(',')
if len(feed_alias) != len(feed_var_dict.keys()) or len(fetch_alias) != len(target_var_names):
raise ValueError("please check the input --feed_alias_names and --fetch_alias_names, should be same size with feed_vars and fetch_vars")
for i, key in enumerate(feed_var_dict):
feed_var = model_conf.FeedVar()
feed_var.alias_name = key
feed_var.alias_name = feed_alias[i]
feed_var.name = feed_var_dict[key].name
feed_var.feed_type = var_type_conversion(feed_var_dict[key].dtype)
......@@ -251,9 +265,9 @@ def save_model(server_model_folder,
feed_var.shape.extend(tmp_shape)
config.feed_var.extend([feed_var])
for key in target_var_names:
for i, key in enumerate(target_var_names):
fetch_var = model_conf.FetchVar()
fetch_var.alias_name = key
fetch_var.alias_name = fetch_alias[i]
fetch_var.name = fetch_var_dict[key].name
fetch_var.fetch_type = var_type_conversion(fetch_var_dict[key].dtype)
......@@ -269,6 +283,9 @@ def save_model(server_model_folder,
fetch_var.shape.extend(tmp_shape)
config.fetch_var.extend([fetch_var])
if show_proto:
print(str(config))
return
try:
save_dirname = os.path.normpath(client_config_folder)
os.makedirs(save_dirname)
......@@ -296,7 +313,10 @@ def inference_model_to_serving(dirname,
params_filename=None,
encryption=False,
key_len=128,
encrypt_conf=None):
encrypt_conf=None,
show_proto=False,
feed_alias_names=None,
fetch_alias_names=None):
paddle.enable_static()
place = fluid.CPUPlace()
exe = fluid.Executor(place)
......@@ -309,7 +329,7 @@ def inference_model_to_serving(dirname,
fetch_dict = {x.name: x for x in fetch_targets}
save_model(serving_server, serving_client, feed_dict, fetch_dict,
inference_program, encryption, key_len, encrypt_conf,
model_filename, params_filename)
model_filename, params_filename, show_proto, feed_alias_names, fetch_alias_names)
feed_names = feed_dict.keys()
fetch_names = fetch_dict.keys()
return feed_names, fetch_names
......@@ -33,7 +33,7 @@ util.gen_pipeline_code("paddle_serving_server")
REQUIRED_PACKAGES = [
'six >= 1.10.0', 'protobuf >= 3.11.0', 'grpcio <= 1.33.2', 'grpcio-tools <= 1.33.2',
'flask >= 1.1.1', 'click==7.1.2', 'itsdangerous==1.1.0', 'Jinja2==2.11.3',
'flask >= 1.1.1,<2.0.0', 'click==7.1.2', 'itsdangerous==1.1.0', 'Jinja2==2.11.3',
'MarkupSafe==1.1.1', 'Werkzeug==1.0.1', 'func_timeout', 'pyyaml'
]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册