提交 62c07152 编写于 作者: H HexToString

fix rpc bug

上级 ed1ce4b7
...@@ -59,7 +59,7 @@ message SimpleResponse { required int32 err_code = 1; } ...@@ -59,7 +59,7 @@ message SimpleResponse { required int32 err_code = 1; }
message GetClientConfigRequest {} message GetClientConfigRequest {}
message GetClientConfigResponse { repeated string client_config_str_list = 1; } message GetClientConfigResponse { required string client_config_str = 1; }
service MultiLangGeneralModelService { service MultiLangGeneralModelService {
rpc Inference(InferenceRequest) returns (InferenceResponse) {} rpc Inference(InferenceRequest) returns (InferenceResponse) {}
......
...@@ -554,15 +554,8 @@ class MultiLangClient(object): ...@@ -554,15 +554,8 @@ class MultiLangClient(object):
get_client_config_req = multi_lang_general_model_service_pb2.GetClientConfigRequest( get_client_config_req = multi_lang_general_model_service_pb2.GetClientConfigRequest(
) )
resp = self.stub_.GetClientConfig(get_client_config_req) resp = self.stub_.GetClientConfig(get_client_config_req)
model_config_path_list = resp.client_config_str_list model_config_str = resp.client_config_str
file_path_list = [] self._parse_model_config(model_config_str)
for single_model_config in model_config_path_list:
if os.path.isdir(single_model_config):
file_path_list.append("{}/serving_server_conf.prototxt".format(
single_model_config))
elif os.path.isfile(single_model_config):
file_path_list.append(single_model_config)
self._parse_model_config(file_path_list)
def _flatten_list(self, nested_list): def _flatten_list(self, nested_list):
for item in nested_list: for item in nested_list:
...@@ -572,23 +565,10 @@ class MultiLangClient(object): ...@@ -572,23 +565,10 @@ class MultiLangClient(object):
else: else:
yield item yield item
def _parse_model_config(self, model_config_path_list): def _parse_model_config(self, model_config_str):
if isinstance(model_config_path_list, str):
model_config_path_list = [model_config_path_list]
elif isinstance(model_config_path_list, list):
pass
file_path_list = []
for single_model_config in model_config_path_list:
if os.path.isdir(single_model_config):
file_path_list.append("{}/serving_client_conf.prototxt".format(
single_model_config))
elif os.path.isfile(single_model_config):
file_path_list.append(single_model_config)
model_conf = m_config.GeneralModelConfig() model_conf = m_config.GeneralModelConfig()
f = open(file_path_list[0], 'r') model_conf = google.protobuf.text_format.Merge(model_config_str,
model_conf = google.protobuf.text_format.Merge( model_conf)
str(f.read()), model_conf)
self.feed_names_ = [var.alias_name for var in model_conf.feed_var] self.feed_names_ = [var.alias_name for var in model_conf.feed_var]
self.feed_types_ = {} self.feed_types_ = {}
self.feed_shapes_ = {} self.feed_shapes_ = {}
...@@ -598,11 +578,6 @@ class MultiLangClient(object): ...@@ -598,11 +578,6 @@ class MultiLangClient(object):
self.feed_shapes_[var.alias_name] = var.shape self.feed_shapes_[var.alias_name] = var.shape
if var.is_lod_tensor: if var.is_lod_tensor:
self.lod_tensor_set_.add(var.alias_name) self.lod_tensor_set_.add(var.alias_name)
if len(file_path_list) > 1:
model_conf = m_config.GeneralModelConfig()
f = open(file_path_list[-1], 'r')
model_conf = google.protobuf.text_format.Merge(
str(f.read()), model_conf)
self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var] self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var]
self.fetch_types_ = {} self.fetch_types_ = {}
for i, var in enumerate(model_conf.fetch_var): for i, var in enumerate(model_conf.fetch_var):
......
...@@ -198,5 +198,14 @@ class MultiLangServerServiceServicer(multi_lang_general_model_service_pb2_grpc. ...@@ -198,5 +198,14 @@ class MultiLangServerServiceServicer(multi_lang_general_model_service_pb2_grpc.
#model_config_path_list is list right now. #model_config_path_list is list right now.
#dict should be added when graphMaker is used. #dict should be added when graphMaker is used.
resp = multi_lang_general_model_service_pb2.GetClientConfigResponse() resp = multi_lang_general_model_service_pb2.GetClientConfigResponse()
resp.client_config_str_list[:] = self.model_config_path_list model_config_str = []
for single_model_config in self.model_config_path_list:
if os.path.isdir(single_model_config):
with open("{}/serving_server_conf.prototxt".format(
single_model_config)) as f:
model_config_str.append(str(f.read()))
elif os.path.isfile(single_model_config):
with open(single_model_config) as f:
model_config_str.append(str(f.read()))
resp.client_config_str = model_config_str[0]
return resp return resp
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册