未验证 提交 5b9c85dc 编写于 作者: T TeslaZhao 提交者: GitHub

Merge branch 'PaddlePaddle:develop' into develop

...@@ -23,6 +23,12 @@ from .io import inference_model_to_serving ...@@ -23,6 +23,12 @@ from .io import inference_model_to_serving
def parse_args(): # pylint: disable=doc-string-missing def parse_args(): # pylint: disable=doc-string-missing
parser = argparse.ArgumentParser("convert") parser = argparse.ArgumentParser("convert")
parser.add_argument(
"--show_proto",
type=bool,
default=False,
help='If yes, you can preview the proto and then determine your feed var alias name and fetch var alias name.'
)
parser.add_argument( parser.add_argument(
"--dirname", "--dirname",
type=str, type=str,
...@@ -53,6 +59,18 @@ def parse_args(): # pylint: disable=doc-string-missing ...@@ -53,6 +59,18 @@ def parse_args(): # pylint: disable=doc-string-missing
default=None, default=None,
help='The name of file to load all parameters. It is only used for the case that all parameters were saved in a single binary file. If parameters were saved in separate files, set it as None. Default: None.' help='The name of file to load all parameters. It is only used for the case that all parameters were saved in a single binary file. If parameters were saved in separate files, set it as None. Default: None.'
) )
parser.add_argument(
"--feed_alias_names",
type=str,
default=None,
help='set alias names for feed vars, split by comma \',\', you should run --show_proto to check the number of feed vars'
)
parser.add_argument(
"--fetch_alias_names",
type=str,
default=None,
help='set alias names for feed vars, split by comma \',\', you should run --show_proto to check the number of fetch vars'
)
return parser.parse_args() return parser.parse_args()
...@@ -63,4 +81,7 @@ if __name__ == "__main__": ...@@ -63,4 +81,7 @@ if __name__ == "__main__":
serving_server=args.serving_server, serving_server=args.serving_server,
serving_client=args.serving_client, serving_client=args.serving_client,
model_filename=args.model_filename, model_filename=args.model_filename,
params_filename=args.params_filename) params_filename=args.params_filename,
show_proto=args.show_proto,
feed_alias_names=args.feed_alias_names,
fetch_alias_names=args.fetch_alias_names)
...@@ -184,7 +184,10 @@ def save_model(server_model_folder, ...@@ -184,7 +184,10 @@ def save_model(server_model_folder,
key_len=128, key_len=128,
encrypt_conf=None, encrypt_conf=None,
model_filename=None, model_filename=None,
params_filename=None): params_filename=None,
show_proto=False,
feed_alias_names=None,
fetch_alias_names=None):
executor = Executor(place=CPUPlace()) executor = Executor(place=CPUPlace())
feed_var_names = [feed_var_dict[x].name for x in feed_var_dict] feed_var_names = [feed_var_dict[x].name for x in feed_var_dict]
...@@ -194,9 +197,9 @@ def save_model(server_model_folder, ...@@ -194,9 +197,9 @@ def save_model(server_model_folder,
target_vars.append(fetch_var_dict[key]) target_vars.append(fetch_var_dict[key])
target_var_names.append(key) target_var_names.append(key)
if not os.path.exists(server_model_folder): if not encryption and not show_proto:
os.makedirs(server_model_folder) if not os.path.exists(server_model_folder):
if not encryption: os.makedirs(server_model_folder)
if not model_filename: if not model_filename:
model_filename = "model.pdmodel" model_filename = "model.pdmodel"
if not params_filename: if not params_filename:
...@@ -207,7 +210,7 @@ def save_model(server_model_folder, ...@@ -207,7 +210,7 @@ def save_model(server_model_folder,
with open(new_model_path, "wb") as new_model_file: with open(new_model_path, "wb") as new_model_file:
new_model_file.write(main_program.desc.serialize_to_string()) new_model_file.write(main_program.desc.serialize_to_string())
paddle.static.save_vars( paddle.static.save_vars(
executor=executor, executor=executor,
dirname=server_model_folder, dirname=server_model_folder,
...@@ -215,7 +218,9 @@ def save_model(server_model_folder, ...@@ -215,7 +218,9 @@ def save_model(server_model_folder,
vars=None, vars=None,
predicate=paddle.static.io.is_persistable, predicate=paddle.static.io.is_persistable,
filename=params_filename) filename=params_filename)
else: elif not show_proto:
if not os.path.exists(server_model_folder):
os.makedirs(server_model_folder)
if encrypt_conf == None: if encrypt_conf == None:
aes_cipher = CipherFactory.create_cipher() aes_cipher = CipherFactory.create_cipher()
else: else:
...@@ -233,10 +238,19 @@ def save_model(server_model_folder, ...@@ -233,10 +238,19 @@ def save_model(server_model_folder,
os.chdir("..") os.chdir("..")
config = model_conf.GeneralModelConfig() config = model_conf.GeneralModelConfig()
if feed_alias_names is None:
for key in feed_var_dict: feed_alias = list(feed_var_dict.keys())
else:
feed_alias = feed_alias_names.split(',')
if fetch_alias_names is None:
fetch_alias = target_var_names
else:
fetch_alias = fetch_alias_names.split(',')
if len(feed_alias) != len(feed_var_dict.keys()) or len(fetch_alias) != len(target_var_names):
raise ValueError("please check the input --feed_alias_names and --fetch_alias_names, should be same size with feed_vars and fetch_vars")
for i, key in enumerate(feed_var_dict):
feed_var = model_conf.FeedVar() feed_var = model_conf.FeedVar()
feed_var.alias_name = key feed_var.alias_name = feed_alias[i]
feed_var.name = feed_var_dict[key].name feed_var.name = feed_var_dict[key].name
feed_var.feed_type = var_type_conversion(feed_var_dict[key].dtype) feed_var.feed_type = var_type_conversion(feed_var_dict[key].dtype)
...@@ -251,9 +265,9 @@ def save_model(server_model_folder, ...@@ -251,9 +265,9 @@ def save_model(server_model_folder,
feed_var.shape.extend(tmp_shape) feed_var.shape.extend(tmp_shape)
config.feed_var.extend([feed_var]) config.feed_var.extend([feed_var])
for key in target_var_names: for i, key in enumerate(target_var_names):
fetch_var = model_conf.FetchVar() fetch_var = model_conf.FetchVar()
fetch_var.alias_name = key fetch_var.alias_name = fetch_alias[i]
fetch_var.name = fetch_var_dict[key].name fetch_var.name = fetch_var_dict[key].name
fetch_var.fetch_type = var_type_conversion(fetch_var_dict[key].dtype) fetch_var.fetch_type = var_type_conversion(fetch_var_dict[key].dtype)
...@@ -269,6 +283,9 @@ def save_model(server_model_folder, ...@@ -269,6 +283,9 @@ def save_model(server_model_folder,
fetch_var.shape.extend(tmp_shape) fetch_var.shape.extend(tmp_shape)
config.fetch_var.extend([fetch_var]) config.fetch_var.extend([fetch_var])
if show_proto:
print(str(config))
return
try: try:
save_dirname = os.path.normpath(client_config_folder) save_dirname = os.path.normpath(client_config_folder)
os.makedirs(save_dirname) os.makedirs(save_dirname)
...@@ -296,7 +313,10 @@ def inference_model_to_serving(dirname, ...@@ -296,7 +313,10 @@ def inference_model_to_serving(dirname,
params_filename=None, params_filename=None,
encryption=False, encryption=False,
key_len=128, key_len=128,
encrypt_conf=None): encrypt_conf=None,
show_proto=False,
feed_alias_names=None,
fetch_alias_names=None):
paddle.enable_static() paddle.enable_static()
place = fluid.CPUPlace() place = fluid.CPUPlace()
exe = fluid.Executor(place) exe = fluid.Executor(place)
...@@ -309,7 +329,7 @@ def inference_model_to_serving(dirname, ...@@ -309,7 +329,7 @@ def inference_model_to_serving(dirname,
fetch_dict = {x.name: x for x in fetch_targets} fetch_dict = {x.name: x for x in fetch_targets}
save_model(serving_server, serving_client, feed_dict, fetch_dict, save_model(serving_server, serving_client, feed_dict, fetch_dict,
inference_program, encryption, key_len, encrypt_conf, inference_program, encryption, key_len, encrypt_conf,
model_filename, params_filename) model_filename, params_filename, show_proto, feed_alias_names, fetch_alias_names)
feed_names = feed_dict.keys() feed_names = feed_dict.keys()
fetch_names = fetch_dict.keys() fetch_names = fetch_dict.keys()
return feed_names, fetch_names return feed_names, fetch_names
...@@ -33,7 +33,7 @@ util.gen_pipeline_code("paddle_serving_server") ...@@ -33,7 +33,7 @@ util.gen_pipeline_code("paddle_serving_server")
REQUIRED_PACKAGES = [ REQUIRED_PACKAGES = [
'six >= 1.10.0', 'protobuf >= 3.11.0', 'grpcio <= 1.33.2', 'grpcio-tools <= 1.33.2', 'six >= 1.10.0', 'protobuf >= 3.11.0', 'grpcio <= 1.33.2', 'grpcio-tools <= 1.33.2',
'flask >= 1.1.1', 'click==7.1.2', 'itsdangerous==1.1.0', 'Jinja2==2.11.3', 'flask >= 1.1.1,<2.0.0', 'click==7.1.2', 'itsdangerous==1.1.0', 'Jinja2==2.11.3',
'MarkupSafe==1.1.1', 'Werkzeug==1.0.1', 'func_timeout', 'pyyaml' 'MarkupSafe==1.1.1', 'Werkzeug==1.0.1', 'func_timeout', 'pyyaml'
] ]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册