diff --git a/python/paddle_serving_client/convert.py b/python/paddle_serving_client/convert.py index e3cd3a05f8e09155b0c884e3ddf12b57234de3dd..984deec609e884a4222a0be1609d068505d97f62 100644 --- a/python/paddle_serving_client/convert.py +++ b/python/paddle_serving_client/convert.py @@ -23,6 +23,12 @@ from .io import inference_model_to_serving def parse_args(): # pylint: disable=doc-string-missing parser = argparse.ArgumentParser("convert") + parser.add_argument( + "--show_proto", + type=bool, + default=False, + help='If yes, you can preview the proto and then determine your feed var alias name and fetch var alias name.' + ) parser.add_argument( "--dirname", type=str, @@ -53,6 +59,18 @@ def parse_args(): # pylint: disable=doc-string-missing default=None, help='The name of file to load all parameters. It is only used for the case that all parameters were saved in a single binary file. If parameters were saved in separate files, set it as None. Default: None.' ) + parser.add_argument( + "--feed_alias_names", + type=str, + default=None, + help='set alias names for feed vars, split by comma \',\', you should run --show_proto to check the number of feed vars' + ) + parser.add_argument( + "--fetch_alias_names", + type=str, + default=None, + help='set alias names for feed vars, split by comma \',\', you should run --show_proto to check the number of fetch vars' + ) return parser.parse_args() @@ -63,4 +81,7 @@ if __name__ == "__main__": serving_server=args.serving_server, serving_client=args.serving_client, model_filename=args.model_filename, - params_filename=args.params_filename) + params_filename=args.params_filename, + show_proto=args.show_proto, + feed_alias_names=args.feed_alias_names, + fetch_alias_names=args.fetch_alias_names) diff --git a/python/paddle_serving_client/io/__init__.py b/python/paddle_serving_client/io/__init__.py index 35b400bed60da8dab4b49cb660d4e6fcfe0f7f2c..07f443196d5e460d5158112dda33bb9c186394b5 100644 --- a/python/paddle_serving_client/io/__init__.py +++ b/python/paddle_serving_client/io/__init__.py @@ -184,7 +184,10 @@ def save_model(server_model_folder, key_len=128, encrypt_conf=None, model_filename=None, - params_filename=None): + params_filename=None, + show_proto=False, + feed_alias_names=None, + fetch_alias_names=None): executor = Executor(place=CPUPlace()) feed_var_names = [feed_var_dict[x].name for x in feed_var_dict] @@ -194,9 +197,9 @@ def save_model(server_model_folder, target_vars.append(fetch_var_dict[key]) target_var_names.append(key) - if not os.path.exists(server_model_folder): - os.makedirs(server_model_folder) - if not encryption: + if not encryption and not show_proto: + if not os.path.exists(server_model_folder): + os.makedirs(server_model_folder) if not model_filename: model_filename = "model.pdmodel" if not params_filename: @@ -207,7 +210,7 @@ def save_model(server_model_folder, with open(new_model_path, "wb") as new_model_file: new_model_file.write(main_program.desc.serialize_to_string()) - + paddle.static.save_vars( executor=executor, dirname=server_model_folder, @@ -215,7 +218,9 @@ def save_model(server_model_folder, vars=None, predicate=paddle.static.io.is_persistable, filename=params_filename) - else: + elif not show_proto: + if not os.path.exists(server_model_folder): + os.makedirs(server_model_folder) if encrypt_conf == None: aes_cipher = CipherFactory.create_cipher() else: @@ -233,10 +238,19 @@ def save_model(server_model_folder, os.chdir("..") config = model_conf.GeneralModelConfig() - - for key in feed_var_dict: + if feed_alias_names is None: + feed_alias = list(feed_var_dict.keys()) + else: + feed_alias = feed_alias_names.split(',') + if fetch_alias_names is None: + fetch_alias = target_var_names + else: + fetch_alias = fetch_alias_names.split(',') + if len(feed_alias) != len(feed_var_dict.keys()) or len(fetch_alias) != len(target_var_names): + raise ValueError("please check the input --feed_alias_names and --fetch_alias_names, should be same size with feed_vars and fetch_vars") + for i, key in enumerate(feed_var_dict): feed_var = model_conf.FeedVar() - feed_var.alias_name = key + feed_var.alias_name = feed_alias[i] feed_var.name = feed_var_dict[key].name feed_var.feed_type = var_type_conversion(feed_var_dict[key].dtype) @@ -251,9 +265,9 @@ def save_model(server_model_folder, feed_var.shape.extend(tmp_shape) config.feed_var.extend([feed_var]) - for key in target_var_names: + for i, key in enumerate(target_var_names): fetch_var = model_conf.FetchVar() - fetch_var.alias_name = key + fetch_var.alias_name = fetch_alias[i] fetch_var.name = fetch_var_dict[key].name fetch_var.fetch_type = var_type_conversion(fetch_var_dict[key].dtype) @@ -269,6 +283,9 @@ def save_model(server_model_folder, fetch_var.shape.extend(tmp_shape) config.fetch_var.extend([fetch_var]) + if show_proto: + print(str(config)) + return try: save_dirname = os.path.normpath(client_config_folder) os.makedirs(save_dirname) @@ -296,7 +313,10 @@ def inference_model_to_serving(dirname, params_filename=None, encryption=False, key_len=128, - encrypt_conf=None): + encrypt_conf=None, + show_proto=False, + feed_alias_names=None, + fetch_alias_names=None): paddle.enable_static() place = fluid.CPUPlace() exe = fluid.Executor(place) @@ -309,7 +329,7 @@ def inference_model_to_serving(dirname, fetch_dict = {x.name: x for x in fetch_targets} save_model(serving_server, serving_client, feed_dict, fetch_dict, inference_program, encryption, key_len, encrypt_conf, - model_filename, params_filename) + model_filename, params_filename, show_proto, feed_alias_names, fetch_alias_names) feed_names = feed_dict.keys() fetch_names = fetch_dict.keys() return feed_names, fetch_names diff --git a/python/setup.py.server.in b/python/setup.py.server.in index cf579db0ba082606e289eb49f8713b9441053743..dfe3761035c18cad0d74f25f9a17b268003dd201 100644 --- a/python/setup.py.server.in +++ b/python/setup.py.server.in @@ -33,7 +33,7 @@ util.gen_pipeline_code("paddle_serving_server") REQUIRED_PACKAGES = [ 'six >= 1.10.0', 'protobuf >= 3.11.0', 'grpcio <= 1.33.2', 'grpcio-tools <= 1.33.2', - 'flask >= 1.1.1', 'click==7.1.2', 'itsdangerous==1.1.0', 'Jinja2==2.11.3', + 'flask >= 1.1.1,<2.0.0', 'click==7.1.2', 'itsdangerous==1.1.0', 'Jinja2==2.11.3', 'MarkupSafe==1.1.1', 'Werkzeug==1.0.1', 'func_timeout', 'pyyaml' ]