提交 fb06526c 编写于 作者: B barrierye

update gpu part

上级 0888751c
......@@ -577,8 +577,13 @@ class MultiLangServer(object):
self.bserver_.set_num_threads(threads)
def set_max_body_size(self, body_size):
    """Set the maximum request body size.

    The value is always forwarded to the wrapped brpc server. The size
    remembered on this object (``self.body_size_``) is only ever raised:
    a value below the current one is ignored here and a notice is printed.

    Args:
        body_size: Requested maximum body size (same unit as
            ``self.body_size_``).
    """
    self.bserver_.set_max_body_size(body_size)
    # Compare against the *current* value before touching it; assigning
    # first would make the comparison always true and the notice unreachable.
    if body_size >= self.body_size_:
        self.body_size_ = body_size
    else:
        print(
            "max_body_size is less than default value, will use default value in service."
        )
def set_port(self, port):
self.gport_ = port
......@@ -610,8 +615,8 @@ class MultiLangServer(object):
def load_model_config(self, model_config_paths):
self.bserver_.load_model_config(model_config_paths)
if isinstance(model_config_paths, dict):
print("You have specified multiple model paths, please ensure "
"that the input and output of multiple models are the same.")
# print("You have specified multiple model paths, please ensure "
# "that the input and output of multiple models are the same.")
self.model_config_path_ = list(model_config_paths.items())[0][1]
self.is_multi_model_ = True
else:
......
......@@ -491,10 +491,17 @@ class Server(object):
class MultiLangServerService(
multi_lang_general_model_service_pb2_grpc.MultiLangGeneralModelService):
def __init__(self, model_config_path, endpoints):
def __init__(self,
model_config_path,
is_multi_model,
endpoints,
timeout_ms=None):
self.is_multi_model_ = is_multi_model
from paddle_serving_client import Client
self._parse_model_config(model_config_path)
self.bclient_ = Client()
if timeout_ms is not None:
self.bclient_.set_rpc_timeout_ms(timeout_ms)
self.bclient_.load_client_config(
"{}/serving_server_conf.prototxt".format(model_config_path))
self.bclient_.connect(endpoints)
......@@ -559,29 +566,35 @@ class MultiLangServerService(
feed_batch.append(feed_dict)
return feed_batch, fetch_names, is_python
# Pack inference output into a multi_lang_general_model_service_pb2.Response,
# stamp it with `tag`, and return it.
#
# NOTE(review): this span is taken from a rendered diff. Indentation is lost
# and two `def` lines (first parameter `result` vs `results`) plus two copies
# of the tensor-filling logic appear back to back. Presumably the lines that
# read `result[...]` are the pre-change version and the lines that read
# `model_result[...]` are the post-change version -- confirm against the
# repository before relying on either.
def _pack_resp_package(self, result, fetch_names, is_python, tag):
def _pack_resp_package(self, results, fetch_names, is_python, tag):
# A single-model result is wrapped in a one-entry dict keyed
# 'general_infer_0' so it has the same dict shape that `results.items()`
# iterates over further down.
if not self.is_multi_model_:
results = {'general_infer_0': results}
resp = multi_lang_general_model_service_pb2.Response()
# Only one model is supported temporarily
model_output = multi_lang_general_model_service_pb2.ModelOutput()
inst = multi_lang_general_model_service_pb2.FetchInst()
for idx, name in enumerate(fetch_names):
tensor = multi_lang_general_model_service_pb2.Tensor()
v_type = self.fetch_types_[name]
if is_python:
tensor.data = result[name].tobytes()
else:
if v_type == 0: # int64
tensor.int64_data.extend(result[name].reshape(-1).tolist())
elif v_type == 1: # float32
tensor.float_data.extend(result[name].reshape(-1).tolist())
# One ModelOutput is produced per (model_name, model_result) pair.
for model_name, model_result in results.items():
model_output = multi_lang_general_model_service_pb2.ModelOutput()
inst = multi_lang_general_model_service_pb2.FetchInst()
# One Tensor is produced per requested fetch name.
for idx, name in enumerate(fetch_names):
tensor = multi_lang_general_model_service_pb2.Tensor()
v_type = self.fetch_types_[name]
# With is_python set, the value's raw bytes (.tobytes()) go into tensor.data;
# otherwise the flattened values go into the typed field chosen by v_type
# (0 -> int64_data, 1 -> float_data, anything else raises).
# Values are presumably numpy arrays (they expose tobytes/reshape/shape) --
# TODO confirm.
if is_python:
tensor.data = model_result[name].tobytes()
else:
raise Exception("error type.")
tensor.shape.extend(list(result[name].shape))
if name in self.lod_tensor_set_:
tensor.lod.extend(result["{}.lod".format(name)].tolist())
inst.tensor_array.append(tensor)
model_output.insts.append(inst)
resp.outputs.append(model_output)
if v_type == 0: # int64
tensor.int64_data.extend(model_result[name].reshape(-1)
.tolist())
elif v_type == 1: # float32
tensor.float_data.extend(model_result[name].reshape(-1)
.tolist())
else:
raise Exception("error type.")
tensor.shape.extend(list(model_result[name].shape))
# Names in lod_tensor_set_ also have a "<name>.lod" entry in the result,
# which is copied into tensor.lod.
if name in self.lod_tensor_set_:
tensor.lod.extend(model_result["{}.lod".format(name)]
.tolist())
inst.tensor_array.append(tensor)
model_output.insts.append(inst)
# engine_name records which entry of `results` this output came from.
model_output.engine_name = model_name
resp.outputs.append(model_output)
resp.tag = tag
return resp
......@@ -593,19 +606,33 @@ class MultiLangServerService(
class MultiLangServer(object):
# Constructor of MultiLangServer: creates the wrapped Server and sets the
# default tuning values stored on the instance.
#
# NOTE(review): this span is taken from a rendered diff with indentation
# removed. Two `__init__` signatures and two assignments to `worker_num_`
# appear back to back; presumably `worker_num=2` / `= worker_num` are the
# pre-change lines and `__init__(self)` / `= 4` the post-change lines --
# confirm against the repository.
def __init__(self, worker_num=2):
def __init__(self):
self.bserver_ = Server()
self.worker_num_ = worker_num
# Used as max_workers of the gRPC ThreadPoolExecutor when the server is built.
self.worker_num_ = 4
# 64 MiB; used for the grpc.max_send/receive_message_length options.
self.body_size_ = 64 * 1024 * 1024
# Passed to grpc.server as maximum_concurrent_rpcs.
self.concurrency_ = 100000
# Passed as timeout_ms to MultiLangServerService.
self.bclient_timeout_ms_ = 2000
self.is_multi_model_ = False # for model ensemble
def set_bclient_timeout_ms(self, timeout):
    """Remember the timeout (milliseconds) later handed to the service as ``timeout_ms``.

    Args:
        timeout: Timeout value stored in ``self.bclient_timeout_ms_``.
    """
    self.bclient_timeout_ms_ = timeout
def set_max_concurrency(self, concurrency):
    """Record the concurrency limit and forward it to the wrapped server.

    Args:
        concurrency: Limit stored in ``self.concurrency_`` and passed to
            ``self.bserver_.set_max_concurrency``.
    """
    self.concurrency_ = concurrency
    backend = self.bserver_
    backend.set_max_concurrency(concurrency)
def set_num_threads(self, threads):
    """Record the worker count and forward it to the wrapped server.

    Args:
        threads: Count stored in ``self.worker_num_`` and passed to
            ``self.bserver_.set_num_threads``.
    """
    self.worker_num_ = threads
    backend = self.bserver_
    backend.set_num_threads(threads)
def set_max_body_size(self, body_size):
    """Forward the body-size limit to the wrapped server and raise the local one.

    ``self.body_size_`` is only replaced when ``body_size`` is at least as
    large as the value already stored; otherwise a notice is printed and the
    stored value is kept.

    Args:
        body_size: Requested maximum body size.
    """
    # TODO: grpc body
    self.bserver_.set_max_body_size(body_size)
    if body_size < self.body_size_:
        print(
            "max_body_size is less than default value, will use default value in service."
        )
        return
    self.body_size_ = body_size
def set_port(self, port):
self.gport_ = port
......@@ -628,15 +655,15 @@ class MultiLangServer(object):
def set_gpuid(self, gpuid=0):
    """Forward the GPU id to the wrapped server.

    Args:
        gpuid: Value passed to ``self.bserver_.set_gpuid``; defaults to 0.
    """
    backend = self.bserver_
    backend.set_gpuid(gpuid)
def use_mkl(self, flag):
    """Forward the MKL switch to the wrapped server.

    Args:
        flag: Value passed unchanged to ``self.bserver_.use_mkl``.
    """
    backend = self.bserver_
    backend.use_mkl(flag)
def load_model_config(self, model_config_path):
    """Load the model configuration into the wrapped server.

    Args:
        model_config_path: Either a single model directory path, or a dict
            whose values are model directory paths (multi-model case).

    Side effects:
        Sets ``self.model_config_path_`` to the single path, or to the
        first dict value, and sets ``self.is_multi_model_`` to True when a
        dict is given.

    Raises:
        IndexError: If an empty dict is passed.
    """
    # The parameter is used under its declared name throughout; the body
    # previously referred to an undefined `model_config_paths`.
    self.bserver_.load_model_config(model_config_path)
    if isinstance(model_config_path, dict):
        # Only the first entry's path is kept. NOTE(review): presumably all
        # models are expected to share the same inputs and outputs -- confirm.
        self.model_config_path_ = list(model_config_path.items())[0][1]
        self.is_multi_model_ = True
    else:
        self.model_config_path_ = model_config_path
def prepare_server(self, workdir=None, port=9292, device="cpu"):
if not self._port_is_available(port):
......@@ -665,11 +692,18 @@ class MultiLangServer(object):
p_bserver = Process(
target=self._launch_brpc_service, args=(self.bserver_, ))
p_bserver.start()
options = [('grpc.max_send_message_length', self.body_size_),
('grpc.max_receive_message_length', self.body_size_)]
server = grpc.server(
futures.ThreadPoolExecutor(max_workers=self.worker_num_))
futures.ThreadPoolExecutor(max_workers=self.worker_num_),
options=options,
maximum_concurrent_rpcs=self.concurrency_)
multi_lang_general_model_service_pb2_grpc.add_MultiLangGeneralModelServiceServicer_to_server(
MultiLangServerService(self.model_config_path_,
["0.0.0.0:{}".format(self.port_list_[0])]),
MultiLangServerService(
self.model_config_path_,
self.is_multi_model_,
["0.0.0.0:{}".format(self.port_list_[0])],
timeout_ms=self.bclient_timeout_ms_),
server)
server.add_insecure_port('[::]:{}'.format(self.gport_))
server.start()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册