Commit 523ad1ec authored by barrierye

Add SetTimeout and GetClientConfig RPC functions

Parent fdea4da7
@@ -28,17 +28,17 @@ message FeedInst { repeated Tensor tensor_array = 1; };
 message FetchInst { repeated Tensor tensor_array = 1; };
 
-message Request {
+message InferenceRequest {
   repeated FeedInst insts = 1;
   repeated string feed_var_names = 2;
   repeated string fetch_var_names = 3;
   required bool is_python = 4 [ default = false ];
 };
 
-message Response {
+message InferenceResponse {
   repeated ModelOutput outputs = 1;
   optional string tag = 2;
-  optional bool brpc_predict_error = 3;
+  required int32 err_code = 3;
 };
 
 message ModelOutput {
@@ -46,6 +46,17 @@ message ModelOutput {
   optional string engine_name = 2;
 }
 
+message SetTimeoutRequest { required int32 timeout_ms = 1; }
+
+message SimpleResponse { required int32 err_code = 1; }
+
+message GetClientConfigRequest {}
+
+message GetClientConfigResponse { required string client_config_str = 1; }
+
 service MultiLangGeneralModelService {
-  rpc inference(Request) returns (Response) {}
+  rpc Inference(InferenceRequest) returns (InferenceResponse) {}
+  rpc SetTimeout(SetTimeoutRequest) returns (SimpleResponse) {}
+  rpc GetClientConfig(GetClientConfigRequest)
+      returns (GetClientConfigResponse) {}
 };
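The two new RPCs in the proto above can be exercised directly with the generated stubs. Below is a minimal sketch, assuming the generated `multi_lang_general_model_service_pb2` / `multi_lang_general_model_service_pb2_grpc` modules are importable and a multi-language server is listening on a hypothetical `localhost:9393`:

```python
import grpc
import multi_lang_general_model_service_pb2 as pb2
import multi_lang_general_model_service_pb2_grpc as pb2_grpc

# Hypothetical endpoint: the gRPC port of a running MultiLangServer.
channel = grpc.insecure_channel("localhost:9393")
stub = pb2_grpc.MultiLangGeneralModelServiceStub(channel)

# SetTimeout: push a new timeout (in ms) to the server-side brpc client.
resp = stub.SetTimeout(pb2.SetTimeoutRequest(timeout_ms=5000))
assert resp.err_code == 0

# GetClientConfig: fetch the client-side model config as a prototxt string.
cfg = stub.GetClientConfig(pb2.GetClientConfigRequest())
print(cfg.client_config_str)
```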
@@ -384,22 +384,24 @@ class Client(object):
 class MultiLangClient(object):
     def __init__(self):
         self.channel_ = None
+        self.stub_ = None
         self.rpc_timeout_s_ = 2
 
-    def load_client_config(self, path):
-        if not isinstance(path, str):
-            raise Exception("GClient only supports multi-model temporarily")
-        self._parse_model_config(path)
-
     def add_variant(self, tag, cluster, variant_weight):
         # TODO
         raise Exception("cannot support ABtest yet")
 
     def set_rpc_timeout_ms(self, rpc_timeout):
-        if rpc_timeout > 2000:
-            print("WARN: you must also need to modify Server timeout, " \
-                  "because the default timeout on Server side is 2000ms.")
+        if self.stub_ is None:
+            raise Exception("set timeout must be set after connect.")
+        if not isinstance(rpc_timeout, int):
+            # for bclient
+            raise ValueError("rpc_timeout must be int type.")
         self.rpc_timeout_s_ = rpc_timeout / 1000.0
+        timeout_req = multi_lang_general_model_service_pb2.SetTimeoutRequest()
+        timeout_req.timeout_ms = rpc_timeout
+        resp = self.stub_.SetTimeout(timeout_req)
+        return resp.err_code == 0
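With this change `set_rpc_timeout_ms` does two things: it updates the local gRPC deadline (`rpc_timeout_s_`) and forwards the value to the server through the new `SetTimeout` RPC, which is why it must now be called after `connect`. A usage sketch, with a hypothetical endpoint:

```python
from paddle_serving_client import MultiLangClient

client = MultiLangClient()
client.connect(["127.0.0.1:9393"])    # hypothetical endpoint
ok = client.set_rpc_timeout_ms(5000)  # True once the server acknowledged it
if not ok:
    raise RuntimeError("server did not apply the new timeout")
```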
     def connect(self, endpoints):
         # https://github.com/tensorflow/serving/issues/1382
@@ -411,6 +413,12 @@ class MultiLangClient(object):
         self.channel_ = grpc.insecure_channel(g_endpoint, options=options)
         self.stub_ = multi_lang_general_model_service_pb2_grpc.MultiLangGeneralModelServiceStub(
             self.channel_)
+        # get client model config
+        get_client_config_req = multi_lang_general_model_service_pb2.GetClientConfigRequest(
+        )
+        resp = self.stub_.GetClientConfig(get_client_config_req)
+        model_config_str = resp.client_config_str
+        self._parse_model_config(model_config_str)
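Since `connect` now pulls the model config from the server via `GetClientConfig`, the old client-side `load_client_config` step disappears entirely. A sketch of the new flow; the feed/fetch names are hypothetical and depend on the deployed model:

```python
import numpy as np
from paddle_serving_client import MultiLangClient

client = MultiLangClient()
client.connect(["127.0.0.1:9393"])  # model config is fetched and parsed here
fetch_map = client.predict(
    feed={"x": np.ones(13, dtype=np.float32)},  # hypothetical feed var
    fetch=["price"])                            # hypothetical fetch var
```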
     def _flatten_list(self, nested_list):
         for item in nested_list:
@@ -420,11 +428,10 @@ class MultiLangClient(object):
             else:
                 yield item
 
-    def _parse_model_config(self, model_config_path):
+    def _parse_model_config(self, model_config_str):
         model_conf = m_config.GeneralModelConfig()
-        f = open(model_config_path, 'r')
-        model_conf = google.protobuf.text_format.Merge(
-            str(f.read()), model_conf)
+        model_conf = google.protobuf.text_format.Merge(model_config_str,
+                                                       model_conf)
         self.feed_names_ = [var.alias_name for var in model_conf.feed_var]
         self.feed_types_ = {}
         self.feed_shapes_ = {}
@@ -445,8 +452,8 @@ class MultiLangClient(object):
             if var.is_lod_tensor:
                 self.lod_tensor_set_.add(var.alias_name)
 
-    def _pack_feed_data(self, feed, fetch, is_python):
-        req = multi_lang_general_model_service_pb2.Request()
+    def _pack_inference_request(self, feed, fetch, is_python):
+        req = multi_lang_general_model_service_pb2.InferenceRequest()
         req.fetch_var_names.extend(fetch)
         req.is_python = is_python
         feed_batch = None
@@ -499,8 +506,9 @@ class MultiLangClient(object):
             req.insts.append(inst)
         return req
 
-    def _unpack_resp(self, resp, fetch, is_python, need_variant_tag):
-        if resp.brpc_predict_error:
+    def _unpack_inference_response(self, resp, fetch, is_python,
+                                   need_variant_tag):
+        if resp.err_code != 0:
             return None
         tag = resp.tag
         multi_result_map = {}
@@ -541,7 +549,8 @@ class MultiLangClient(object):
     def _done_callback_func(self, fetch, is_python, need_variant_tag):
         def unpack_resp(resp):
-            return self._unpack_resp(resp, fetch, is_python, need_variant_tag)
+            return self._unpack_inference_response(resp, fetch, is_python,
+                                                   need_variant_tag)
 
         return unpack_resp
 
@@ -553,22 +562,18 @@ class MultiLangClient(object):
                 fetch,
                 need_variant_tag=False,
                 asyn=False,
-                is_python=True,
-                timeout_ms=None):
-        if timeout_ms is None:
-            timeout = self.rpc_timeout_s_
-        else:
-            timeout = timeout_ms / 1000.0
-        req = self._pack_feed_data(feed, fetch, is_python=is_python)
+                is_python=True):
+        req = self._pack_inference_request(feed, fetch, is_python=is_python)
         if not asyn:
-            resp = self.stub_.inference(req, timeout=timeout)
-            return self._unpack_resp(
+            resp = self.stub_.Inference(req, timeout=self.rpc_timeout_s_)
+            return self._unpack_inference_response(
                 resp,
                 fetch,
                 is_python=is_python,
                 need_variant_tag=need_variant_tag)
         else:
-            call_future = self.stub_.inference.future(req, timeout=timeout)
+            call_future = self.stub_.Inference.future(
+                req, timeout=self.rpc_timeout_s_)
             return MultiLangPredictFuture(
                 call_future,
                 self._done_callback_func(
...
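Note that the per-call `timeout_ms` argument of `predict` is gone: both the synchronous and the future-based paths now use the client-wide `rpc_timeout_s_`. The asynchronous path still returns a `MultiLangPredictFuture`; a caller sketch, with the same hypothetical names as above and assuming the future exposes `result()` as in the elided part of this file:

```python
future = client.predict(
    feed={"x": np.ones(13, dtype=np.float32)},
    fetch=["price"],
    asyn=True)
fetch_map = future.result()  # response unpacked by _done_callback_func
```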
@@ -440,29 +440,29 @@ class Server(object):
         os.system(command)
 
-class MultiLangServerService(
-        multi_lang_general_model_service_pb2_grpc.MultiLangGeneralModelService):
-    def __init__(self,
-                 model_config_path,
-                 is_multi_model,
-                 endpoints,
-                 timeout_ms=None):
+class MultiLangServerServiceServicer(multi_lang_general_model_service_pb2_grpc.
+                                     MultiLangGeneralModelServiceServicer):
+    def __init__(self, model_config_path, is_multi_model, endpoints):
         self.is_multi_model_ = is_multi_model
+        self.model_config_path_ = model_config_path
+        self.endpoints_ = endpoints
+        with open(self.model_config_path_) as f:
+            self.model_config_str_ = str(f.read())
+        self._parse_model_config(self.model_config_str_)
+        self._init_bclient(self.model_config_path_, self.endpoints_)
+
+    def _init_bclient(self, model_config_path, endpoints, timeout_ms=None):
         from paddle_serving_client import Client
-        self._parse_model_config(model_config_path)
         self.bclient_ = Client()
         if timeout_ms is not None:
             self.bclient_.set_rpc_timeout_ms(timeout_ms)
-        self.bclient_.load_client_config(
-            "{}/serving_server_conf.prototxt".format(model_config_path))
+        self.bclient_.load_client_config(model_config_path)
         self.bclient_.connect(endpoints)
 
-    def _parse_model_config(self, model_config_path):
+    def _parse_model_config(self, model_config_str):
         model_conf = m_config.GeneralModelConfig()
-        f = open("{}/serving_server_conf.prototxt".format(model_config_path),
-                 'r')
-        model_conf = google.protobuf.text_format.Merge(
-            str(f.read()), model_conf)
+        model_conf = google.protobuf.text_format.Merge(model_config_str,
+                                                       model_conf)
         self.feed_names_ = [var.alias_name for var in model_conf.feed_var]
         self.feed_types_ = {}
         self.feed_shapes_ = {}
@@ -487,7 +487,7 @@ class MultiLangServerService(
             else:
                 yield item
 
-    def _unpack_request(self, request):
+    def _unpack_inference_request(self, request):
         feed_names = list(request.feed_var_names)
         fetch_names = list(request.fetch_var_names)
         is_python = request.is_python
@@ -517,14 +517,14 @@ class MultiLangServerService(
             feed_batch.append(feed_dict)
         return feed_batch, fetch_names, is_python
 
-    def _pack_resp_package(self, ret, fetch_names, is_python):
-        resp = multi_lang_general_model_service_pb2.Response()
+    def _pack_inference_response(self, ret, fetch_names, is_python):
+        resp = multi_lang_general_model_service_pb2.InferenceResponse()
         if ret is None:
-            resp.brpc_predict_error = True
+            resp.err_code = 1
             return resp
         results, tag = ret
         resp.tag = tag
-        resp.brpc_predict_error = False
+        resp.err_code = 0
         if not self.is_multi_model_:
             results = {'general_infer_0': results}
         for model_name, model_result in results.items():
@@ -554,11 +554,26 @@ class MultiLangServerService(
             resp.outputs.append(model_output)
         return resp
 
-    def inference(self, request, context):
-        feed_dict, fetch_names, is_python = self._unpack_request(request)
+    def SetTimeout(self, request, context):
+        # This process and the Inference process cannot run at the same time.
+        # For performance reasons, no thread lock is added for now.
+        timeout_ms = request.timeout_ms
+        self._init_bclient(self.model_config_path_, self.endpoints_, timeout_ms)
+        resp = multi_lang_general_model_service_pb2.SimpleResponse()
+        resp.err_code = 0
+        return resp
+
+    def Inference(self, request, context):
+        feed_dict, fetch_names, is_python = self._unpack_inference_request(
+            request)
         ret = self.bclient_.predict(
             feed=feed_dict, fetch=fetch_names, need_variant_tag=True)
-        return self._pack_resp_package(ret, fetch_names, is_python)
+        return self._pack_inference_response(ret, fetch_names, is_python)
+
+    def GetClientConfig(self, request, context):
+        resp = multi_lang_general_model_service_pb2.GetClientConfigResponse()
+        resp.client_config_str = self.model_config_str_
+        return resp
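`SetTimeout` is implemented by rebuilding the inner brpc `Client` with the new timeout, since that client's timeout is fixed when it connects. Because the handlers are plain methods that ignore `context`, they can be smoke-tested without a gRPC channel; a test-only sketch with hypothetical paths and ports (the inner `Client` must still be able to reach a running brpc endpoint):

```python
servicer = MultiLangServerServiceServicer(
    "uci_housing_model/serving_server_conf.prototxt",  # hypothetical path
    is_multi_model=False,
    endpoints=["127.0.0.1:9292"])                      # hypothetical brpc port

req = multi_lang_general_model_service_pb2.SetTimeoutRequest(timeout_ms=4000)
assert servicer.SetTimeout(req, None).err_code == 0  # context is unused

cfg = servicer.GetClientConfig(
    multi_lang_general_model_service_pb2.GetClientConfigRequest(), None)
print(cfg.client_config_str)
```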
 class MultiLangServer(object):
@@ -567,12 +582,8 @@ class MultiLangServer(object):
         self.worker_num_ = 4
         self.body_size_ = 64 * 1024 * 1024
         self.concurrency_ = 100000
-        self.bclient_timeout_ms_ = 2000
         self.is_multi_model_ = False  # for model ensemble
 
-    def set_bclient_timeout_ms(self, timeout):
-        self.bclient_timeout_ms_ = timeout
-
     def set_max_concurrency(self, concurrency):
         self.concurrency_ = concurrency
         self.bserver_.set_max_concurrency(concurrency)
@@ -617,15 +628,17 @@ class MultiLangServer(object):
     def use_mkl(self, flag):
         self.bserver_.use_mkl(flag)
 
-    def load_model_config(self, model_config_paths):
-        self.bserver_.load_model_config(model_config_paths)
-        if isinstance(model_config_paths, dict):
-            # print("You have specified multiple model paths, please ensure "
-            #       "that the input and output of multiple models are the same.")
-            self.model_config_path_ = list(model_config_paths.items())[0][1]
-            self.is_multi_model_ = True
-        else:
-            self.model_config_path_ = model_config_paths
+    def load_model_config(self, server_config_paths, client_config_path=None):
+        self.bserver_.load_model_config(server_config_paths)
+        if client_config_path is None:
+            if isinstance(server_config_paths, dict):
+                self.is_multi_model_ = True
+                client_config_path = '{}/serving_server_conf.prototxt'.format(
+                    list(server_config_paths.items())[0][1])
+            else:
+                client_config_path = '{}/serving_server_conf.prototxt'.format(
+                    server_config_paths)
+        self.bclient_config_path_ = client_config_path
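`load_model_config` now derives the client-side config path from the server config unless an explicit `client_config_path` is given. Equivalent calls, with a hypothetical model directory:

```python
from paddle_serving_server import MultiLangServer

server = MultiLangServer()
server.load_model_config("uci_housing_model")  # hypothetical model dir
# equivalent to:
# server.load_model_config(
#     "uci_housing_model",
#     client_config_path="uci_housing_model/serving_server_conf.prototxt")
```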
     def prepare_server(self, workdir=None, port=9292, device="cpu"):
         if not self._port_is_available(port):
@@ -661,12 +674,9 @@ class MultiLangServer(object):
             options=options,
             maximum_concurrent_rpcs=self.concurrency_)
         multi_lang_general_model_service_pb2_grpc.add_MultiLangGeneralModelServiceServicer_to_server(
-            MultiLangServerService(
-                self.model_config_path_,
-                self.is_multi_model_,
-                ["0.0.0.0:{}".format(self.port_list_[0])],
-                timeout_ms=self.bclient_timeout_ms_),
-            server)
+            MultiLangServerServiceServicer(
+                self.bclient_config_path_, self.is_multi_model_,
+                ["0.0.0.0:{}".format(self.port_list_[0])]), server)
         server.add_insecure_port('[::]:{}'.format(self.gport_))
         server.start()
         p_bserver.join()
...