Commit a2510413 authored by barrierye

add SetTimeout and GetClientConfig RPC functions

Parent 0b09eed6
@@ -28,17 +28,17 @@ message FeedInst { repeated Tensor tensor_array = 1; };
message FetchInst { repeated Tensor tensor_array = 1; };
message Request {
message InferenceRequest {
repeated FeedInst insts = 1;
repeated string feed_var_names = 2;
repeated string fetch_var_names = 3;
required bool is_python = 4 [ default = false ];
};
message Response {
message InferenceResponse {
repeated ModelOutput outputs = 1;
optional string tag = 2;
optional bool brpc_predict_error = 3;
required int32 err_code = 3;
};
message ModelOutput {
@@ -46,6 +46,17 @@ message ModelOutput {
optional string engine_name = 2;
}
message SetTimeoutRequest { required int32 timeout_ms = 1; }
message SimpleResponse { required int32 err_code = 1; }
message GetClientConfigRequest {}
message GetClientConfigResponse { required string client_config_str = 1; }
service MultiLangGeneralModelService {
rpc inference(Request) returns (Response) {}
rpc Inference(InferenceRequest) returns (InferenceResponse) {}
rpc SetTimeout(SetTimeoutRequest) returns (SimpleResponse) {}
rpc GetClientConfig(GetClientConfigRequest)
returns (GetClientConfigResponse) {}
};
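Taken together, the service now exposes a renamed Inference RPC plus two new control RPCs. A minimal sketch of calling the two new ones through the generated gRPC stub (module names follow protoc conventions and the endpoint is a placeholder, both assumptions):

```python
import grpc

# Modules generated from multi_lang_general_model_service.proto (names assumed).
import multi_lang_general_model_service_pb2 as pb2
import multi_lang_general_model_service_pb2_grpc as pb2_grpc

channel = grpc.insecure_channel("127.0.0.1:9393")  # placeholder endpoint
stub = pb2_grpc.MultiLangGeneralModelServiceStub(channel)

# Ask the server to rebuild its internal brpc client with a 5000 ms timeout.
resp = stub.SetTimeout(pb2.SetTimeoutRequest(timeout_ms=5000))
assert resp.err_code == 0

# Fetch the serialized client config instead of reading it from local disk.
cfg = stub.GetClientConfig(pb2.GetClientConfigRequest())
print(cfg.client_config_str)
```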
@@ -384,22 +384,24 @@ class Client(object):
class MultiLangClient(object):
def __init__(self):
self.channel_ = None
self.stub_ = None
self.rpc_timeout_s_ = 2
def load_client_config(self, path):
if not isinstance(path, str):
raise Exception("GClient only supports multi-model temporarily")
self._parse_model_config(path)
def add_variant(self, tag, cluster, variant_weight):
# TODO
raise Exception("cannot support ABtest yet")
def set_rpc_timeout_ms(self, rpc_timeout):
if rpc_timeout > 2000:
print("WARN: you must also need to modify Server timeout, " \
"because the default timeout on Server side is 2000ms.")
if self.stub_ is None:
raise Exception("set timeout must be set after connect.")
if not isinstance(rpc_timeout, int):
# for bclient
raise ValueError("rpc_timeout must be int type.")
self.rpc_timeout_s_ = rpc_timeout / 1000.0
timeout_req = multi_lang_general_model_service_pb2.SetTimeoutRequest()
timeout_req.timeout_ms = rpc_timeout
resp = self.stub_.SetTimeout(timeout_req)
return resp.err_code == 0
def connect(self, endpoints):
# https://github.com/tensorflow/serving/issues/1382
@@ -411,6 +413,12 @@ class MultiLangClient(object):
self.channel_ = grpc.insecure_channel(g_endpoint, options=options)
self.stub_ = multi_lang_general_model_service_pb2_grpc.MultiLangGeneralModelServiceStub(
self.channel_)
# get client model config
get_client_config_req = multi_lang_general_model_service_pb2.GetClientConfigRequest(
)
resp = self.stub_.GetClientConfig(get_client_config_req)
model_config_str = resp.client_config_str
self._parse_model_config(model_config_str)
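With this change the call order matters: connect() both opens the channel and pulls the model config from the server via GetClientConfig, and set_rpc_timeout_ms() must come after it because the timeout is now propagated over the new RPC. A hedged usage sketch (import path and endpoint assumed):

```python
from paddle_serving_client import MultiLangClient  # import path assumed

client = MultiLangClient()
client.connect(["127.0.0.1:9393"])  # also fetches the client config via GetClientConfig
# Updates the local gRPC deadline and the server-side brpc client timeout.
if not client.set_rpc_timeout_ms(5000):
    raise RuntimeError("server rejected the timeout update")
```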
def _flatten_list(self, nested_list):
for item in nested_list:
@@ -420,11 +428,10 @@ class MultiLangClient(object):
else:
yield item
def _parse_model_config(self, model_config_path):
def _parse_model_config(self, model_config_str):
model_conf = m_config.GeneralModelConfig()
f = open(model_config_path, 'r')
model_conf = google.protobuf.text_format.Merge(
str(f.read()), model_conf)
model_conf = google.protobuf.text_format.Merge(model_config_str,
model_conf)
self.feed_names_ = [var.alias_name for var in model_conf.feed_var]
self.feed_types_ = {}
self.feed_shapes_ = {}
@@ -445,8 +452,8 @@ class MultiLangClient(object):
if var.is_lod_tensor:
self.lod_tensor_set_.add(var.alias_name)
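Since _parse_model_config now takes the config as a string, the same text_format parse works whether the string came from a local file or from GetClientConfig. A sketch of that parse step (the m_config import path is assumed):

```python
import google.protobuf.text_format
# Matches the m_config alias used above; exact import path assumed.
from paddle_serving_client.proto import general_model_config_pb2 as m_config

def parse_feed_names(model_config_str):
    model_conf = m_config.GeneralModelConfig()
    google.protobuf.text_format.Merge(model_config_str, model_conf)
    return [var.alias_name for var in model_conf.feed_var]
```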
def _pack_feed_data(self, feed, fetch, is_python):
req = multi_lang_general_model_service_pb2.Request()
def _pack_inference_request(self, feed, fetch, is_python):
req = multi_lang_general_model_service_pb2.InferenceRequest()
req.fetch_var_names.extend(fetch)
req.is_python = is_python
feed_batch = None
@@ -499,8 +506,9 @@ class MultiLangClient(object):
req.insts.append(inst)
return req
def _unpack_resp(self, resp, fetch, is_python, need_variant_tag):
if resp.brpc_predict_error:
def _unpack_inference_response(self, resp, fetch, is_python,
need_variant_tag):
if resp.err_code != 0:
return None
tag = resp.tag
multi_result_map = {}
@@ -541,7 +549,8 @@ class MultiLangClient(object):
def _done_callback_func(self, fetch, is_python, need_variant_tag):
def unpack_resp(resp):
return self._unpack_resp(resp, fetch, is_python, need_variant_tag)
return self._unpack_inference_response(resp, fetch, is_python,
need_variant_tag)
return unpack_resp
@@ -553,22 +562,18 @@ class MultiLangClient(object):
fetch,
need_variant_tag=False,
asyn=False,
is_python=True,
timeout_ms=None):
if timeout_ms is None:
timeout = self.rpc_timeout_s_
else:
timeout = timeout_ms / 1000.0
req = self._pack_feed_data(feed, fetch, is_python=is_python)
is_python=True):
req = self._pack_inference_request(feed, fetch, is_python=is_python)
if not asyn:
resp = self.stub_.inference(req, timeout=timeout)
return self._unpack_resp(
resp = self.stub_.Inference(req, timeout=self.rpc_timeout_s_)
return self._unpack_inference_response(
resp,
fetch,
is_python=is_python,
need_variant_tag=need_variant_tag)
else:
call_future = self.stub_.inference.future(req, timeout=timeout)
call_future = self.stub_.Inference.future(
req, timeout=self.rpc_timeout_s_)
return MultiLangPredictFuture(
call_future,
self._done_callback_func(
......
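With the per-call timeout_ms argument removed, both the synchronous and asynchronous paths take their deadline from rpc_timeout_s_, i.e. from set_rpc_timeout_ms. A hedged sketch of the two call styles, reusing the connected client from the sketch above (feed/fetch names are placeholders, and the future's result() accessor is assumed from MultiLangPredictFuture):

```python
import numpy as np

x = np.random.rand(1, 13).astype("float32")  # placeholder input

# Synchronous: raises grpc.RpcError if the deadline set via
# set_rpc_timeout_ms is exceeded.
fetch_map = client.predict(feed={"x": x}, fetch=["price"])

# Asynchronous: returns a MultiLangPredictFuture immediately; result()
# (assumed) blocks and then unpacks the InferenceResponse.
future = client.predict(feed={"x": x}, fetch=["price"], asyn=True)
fetch_map = future.result()
```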
@@ -440,29 +440,29 @@ class Server(object):
os.system(command)
class MultiLangServerService(
multi_lang_general_model_service_pb2_grpc.MultiLangGeneralModelService):
def __init__(self,
model_config_path,
is_multi_model,
endpoints,
timeout_ms=None):
class MultiLangServerServiceServicer(multi_lang_general_model_service_pb2_grpc.
MultiLangGeneralModelServiceServicer):
def __init__(self, model_config_path, is_multi_model, endpoints):
self.is_multi_model_ = is_multi_model
self.model_config_path_ = model_config_path
self.endpoints_ = endpoints
with open(self.model_config_path_) as f:
self.model_config_str_ = str(f.read())
self._parse_model_config(self.model_config_str_)
self._init_bclient(self.model_config_path_, self.endpoints_)
def _init_bclient(self, model_config_path, endpoints, timeout_ms=None):
from paddle_serving_client import Client
self._parse_model_config(model_config_path)
self.bclient_ = Client()
if timeout_ms is not None:
self.bclient_.set_rpc_timeout_ms(timeout_ms)
self.bclient_.load_client_config(
"{}/serving_server_conf.prototxt".format(model_config_path))
self.bclient_.load_client_config(model_config_path)
self.bclient_.connect(endpoints)
def _parse_model_config(self, model_config_path):
def _parse_model_config(self, model_config_str):
model_conf = m_config.GeneralModelConfig()
f = open("{}/serving_server_conf.prototxt".format(model_config_path),
'r')
model_conf = google.protobuf.text_format.Merge(
str(f.read()), model_conf)
model_conf = google.protobuf.text_format.Merge(model_config_str,
model_conf)
self.feed_names_ = [var.alias_name for var in model_conf.feed_var]
self.feed_types_ = {}
self.feed_shapes_ = {}
@@ -487,7 +487,7 @@ class MultiLangServerService(
else:
yield item
def _unpack_request(self, request):
def _unpack_inference_request(self, request):
feed_names = list(request.feed_var_names)
fetch_names = list(request.fetch_var_names)
is_python = request.is_python
@@ -517,14 +517,14 @@ class MultiLangServerService(
feed_batch.append(feed_dict)
return feed_batch, fetch_names, is_python
def _pack_resp_package(self, ret, fetch_names, is_python):
resp = multi_lang_general_model_service_pb2.Response()
def _pack_inference_response(self, ret, fetch_names, is_python):
resp = multi_lang_general_model_service_pb2.InferenceResponse()
if ret is None:
resp.brpc_predict_error = True
resp.err_code = 1
return resp
results, tag = ret
resp.tag = tag
resp.brpc_predict_error = False
resp.err_code = 0
if not self.is_multi_model_:
results = {'general_infer_0': results}
for model_name, model_result in results.items():
@@ -554,11 +554,26 @@ class MultiLangServerService(
resp.outputs.append(model_output)
return resp
def inference(self, request, context):
feed_dict, fetch_names, is_python = self._unpack_request(request)
def SetTimeout(self, request, context):
# This process and the Inference process cannot run at the same time.
# For performance reasons, no thread lock is added for now.
timeout_ms = request.timeout_ms
self._init_bclient(self.model_config_path_, self.endpoints_, timeout_ms)
resp = multi_lang_general_model_service_pb2.SimpleResponse()
resp.err_code = 0
return resp
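As the comment above notes, SetTimeout is not locked against Inference on the server. Until it is, a caller that may reconfigure the timeout while requests are in flight can serialize the two on its own side; a purely hypothetical caller-side guard:

```python
import threading

_rpc_lock = threading.Lock()  # caller-side guard, hypothetical

def safe_set_timeout(client, timeout_ms):
    with _rpc_lock:
        return client.set_rpc_timeout_ms(timeout_ms)

def safe_predict(client, feed, fetch):
    with _rpc_lock:
        return client.predict(feed=feed, fetch=fetch)
```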
def Inference(self, request, context):
feed_dict, fetch_names, is_python = self._unpack_inference_request(
request)
ret = self.bclient_.predict(
feed=feed_dict, fetch=fetch_names, need_variant_tag=True)
return self._pack_resp_package(ret, fetch_names, is_python)
return self._pack_inference_response(ret, fetch_names, is_python)
def GetClientConfig(self, request, context):
resp = multi_lang_general_model_service_pb2.GetClientConfigResponse()
resp.client_config_str = self.model_config_str_
return resp
class MultiLangServer(object):
@@ -567,12 +582,8 @@ class MultiLangServer(object):
self.worker_num_ = 4
self.body_size_ = 64 * 1024 * 1024
self.concurrency_ = 100000
self.bclient_timeout_ms_ = 2000
self.is_multi_model_ = False # for model ensemble
def set_bclient_timeout_ms(self, timeout):
self.bclient_timeout_ms_ = timeout
def set_max_concurrency(self, concurrency):
self.concurrency_ = concurrency
self.bserver_.set_max_concurrency(concurrency)
@@ -617,15 +628,17 @@ class MultiLangServer(object):
def use_mkl(self, flag):
self.bserver_.use_mkl(flag)
def load_model_config(self, model_config_paths):
self.bserver_.load_model_config(model_config_paths)
if isinstance(model_config_paths, dict):
# print("You have specified multiple model paths, please ensure "
# "that the input and output of multiple models are the same.")
self.model_config_path_ = list(model_config_paths.items())[0][1]
self.is_multi_model_ = True
else:
self.model_config_path_ = model_config_paths
def load_model_config(self, server_config_paths, client_config_path=None):
self.bserver_.load_model_config(server_config_paths)
if client_config_path is None:
if isinstance(server_config_paths, dict):
self.is_multi_model_ = True
client_config_path = '{}/serving_server_conf.prototxt'.format(
list(server_config_paths.items())[0][1])
else:
client_config_path = '{}/serving_server_conf.prototxt'.format(
server_config_paths)
self.bclient_config_path_ = client_config_path
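The reworked load_model_config accepts either a single server config path or a dict of named model paths (the ensemble case), deriving the client config from the first entry unless one is passed explicitly. A sketch with placeholder paths:

```python
from paddle_serving_server import MultiLangServer  # import path assumed

server = MultiLangServer()

# Single model: the client config defaults to
# <path>/serving_server_conf.prototxt.
server.load_model_config("uci_housing_model")

# Ensemble: pass a dict of server configs, and supply the client config
# explicitly if the default derived from the first entry is wrong.
server.load_model_config(
    {"general_infer_0": "model_a", "general_infer_1": "model_b"},
    client_config_path="model_a/serving_server_conf.prototxt")
```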
def prepare_server(self, workdir=None, port=9292, device="cpu"):
if not self._port_is_available(port):
@@ -661,12 +674,9 @@ class MultiLangServer(object):
options=options,
maximum_concurrent_rpcs=self.concurrency_)
multi_lang_general_model_service_pb2_grpc.add_MultiLangGeneralModelServiceServicer_to_server(
MultiLangServerService(
self.model_config_path_,
self.is_multi_model_,
["0.0.0.0:{}".format(self.port_list_[0])],
timeout_ms=self.bclient_timeout_ms_),
server)
MultiLangServerServiceServicer(
self.bclient_config_path_, self.is_multi_model_,
["0.0.0.0:{}".format(self.port_list_[0])]), server)
server.add_insecure_port('[::]:{}'.format(self.gport_))
server.start()
p_bserver.join()
......
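Putting the server-side pieces together, a hedged end-to-end sketch of bringing up the service (import path, the run_server method name, and all paths are assumptions):

```python
from paddle_serving_server import MultiLangServer  # import path assumed

server = MultiLangServer()
server.load_model_config("uci_housing_model")  # placeholder path
server.prepare_server(workdir="workdir", port=9393, device="cpu")
# Starts the brpc worker process and the gRPC front server shown above.
server.run_server()  # method name assumed
```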