diff --git a/core/configure/proto/multi_lang_general_model_service.proto b/core/configure/proto/multi_lang_general_model_service.proto index 2a4764a041d7f817aba1d516427a241498d4c2e0..18fbcf760647e1694e738c0832fe45f4f7d9934f 100755 --- a/core/configure/proto/multi_lang_general_model_service.proto +++ b/core/configure/proto/multi_lang_general_model_service.proto @@ -59,7 +59,7 @@ message SimpleResponse { required int32 err_code = 1; } message GetClientConfigRequest {} -message GetClientConfigResponse { repeated string client_config_str_list = 1; } +message GetClientConfigResponse { required string client_config_str = 1; } service MultiLangGeneralModelService { rpc Inference(InferenceRequest) returns (InferenceResponse) {} diff --git a/python/paddle_serving_client/client.py b/python/paddle_serving_client/client.py index 48ad112ab015242b85753489f84422c4187f6ec1..8b1fc38032133230f450f83b9139d5f347b2ae1b 100755 --- a/python/paddle_serving_client/client.py +++ b/python/paddle_serving_client/client.py @@ -554,15 +554,8 @@ class MultiLangClient(object): get_client_config_req = multi_lang_general_model_service_pb2.GetClientConfigRequest( ) resp = self.stub_.GetClientConfig(get_client_config_req) - model_config_path_list = resp.client_config_str_list - file_path_list = [] - for single_model_config in model_config_path_list: - if os.path.isdir(single_model_config): - file_path_list.append("{}/serving_server_conf.prototxt".format( - single_model_config)) - elif os.path.isfile(single_model_config): - file_path_list.append(single_model_config) - self._parse_model_config(file_path_list) + model_config_str = resp.client_config_str + self._parse_model_config(model_config_str) def _flatten_list(self, nested_list): for item in nested_list: @@ -572,23 +565,10 @@ class MultiLangClient(object): else: yield item - def _parse_model_config(self, model_config_path_list): - if isinstance(model_config_path_list, str): - model_config_path_list = [model_config_path_list] - elif isinstance(model_config_path_list, list): - pass - - file_path_list = [] - for single_model_config in model_config_path_list: - if os.path.isdir(single_model_config): - file_path_list.append("{}/serving_client_conf.prototxt".format( - single_model_config)) - elif os.path.isfile(single_model_config): - file_path_list.append(single_model_config) + def _parse_model_config(self, model_config_str): model_conf = m_config.GeneralModelConfig() - f = open(file_path_list[0], 'r') - model_conf = google.protobuf.text_format.Merge( - str(f.read()), model_conf) + model_conf = google.protobuf.text_format.Merge(model_config_str, + model_conf) self.feed_names_ = [var.alias_name for var in model_conf.feed_var] self.feed_types_ = {} self.feed_shapes_ = {} @@ -598,11 +578,6 @@ class MultiLangClient(object): self.feed_shapes_[var.alias_name] = var.shape if var.is_lod_tensor: self.lod_tensor_set_.add(var.alias_name) - if len(file_path_list) > 1: - model_conf = m_config.GeneralModelConfig() - f = open(file_path_list[-1], 'r') - model_conf = google.protobuf.text_format.Merge( - str(f.read()), model_conf) self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var] self.fetch_types_ = {} for i, var in enumerate(model_conf.fetch_var): diff --git a/python/paddle_serving_server/rpc_service.py b/python/paddle_serving_server/rpc_service.py index d9d302831fd2e3148547e24772005efb38cb8f32..f2503a5d86b032499543f5f4fc78b8b824218a44 100755 --- a/python/paddle_serving_server/rpc_service.py +++ b/python/paddle_serving_server/rpc_service.py @@ -198,5 +198,14 @@ class MultiLangServerServiceServicer(multi_lang_general_model_service_pb2_grpc. #model_config_path_list is list right now. #dict should be added when graphMaker is used. resp = multi_lang_general_model_service_pb2.GetClientConfigResponse() - resp.client_config_str_list[:] = self.model_config_path_list + model_config_str = [] + for single_model_config in self.model_config_path_list: + if os.path.isdir(single_model_config): + with open("{}/serving_server_conf.prototxt".format( + single_model_config)) as f: + model_config_str.append(str(f.read())) + elif os.path.isfile(single_model_config): + with open(single_model_config) as f: + model_config_str.append(str(f.read())) + resp.client_config_str = model_config_str[0] return resp