Commit 35e8b338 authored by barrierye

add int32 support

Parent 523ad1ec
......@@ -20,7 +20,6 @@ import time
import threading
client = MultiLangClient()
client.load_client_config(sys.argv[1])
client.connect(["127.0.0.1:9393"])
import paddle
......
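
For context, a minimal end-to-end sketch of the client flow this example hunk exercises. The config path comes from argv as in the hunk; the feed/fetch names and the int32 feed are illustrative assumptions, not taken from the diff:

    # Sketch only: "ids" and "prob" are placeholder feed/fetch names.
    import sys
    import numpy as np
    from paddle_serving_client import MultiLangClient

    client = MultiLangClient()
    client.load_client_config(sys.argv[1])
    client.connect(["127.0.0.1:9393"])

    # With this commit, int32 ndarrays can be fed directly.
    feed = {"ids": np.array([1, 2, 3], dtype="int32")}
    fetch_map = client.predict(feed=feed, fetch=["prob"])
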
......@@ -478,26 +478,47 @@ class MultiLangClient(object):
data = np.array(var, dtype="int64")
elif v_type == 1: # float32
data = np.array(var, dtype="float32")
elif v_type == 2: # int32
data = np.array(var, dtype="int32")
else:
raise Exception("error type.")
else:
raise Exception("error tensor value type.")
elif isinstance(var, np.ndarray):
data = var
if var.dtype == "float64":
if v_type == 0 and data.dtype != 'int64':
data = data.astype("int64")
elif v_type == 1 and data.dtype != 'float32':
data = data.astype("float32")
elif v_type == 2 and data.dtype != 'int32':
data = data.astype("int32")
else:
raise Exception("error tensor value type.")
else:
raise Exception("var must be list or ndarray.")
tensor.data = data.tobytes()
else:
if v_type == 0: # int64
if isinstance(var, np.ndarray):
tensor.int64_data.extend(var.reshape(-1).tolist())
if isinstance(var, np.ndarray):
if v_type == 0: # int64
tensor.int64_data.extend(
var.reshape(-1).astype("int64").tolist())
elif v_type == 1:
tensor.float_data.extend(
var.reshape(-1).astype('float32').tolist())
elif v_type == 2:
tensor.int32_data.extend(
var.reshape(-1).astype('int32').tolist())
else:
raise Exception("error tensor value type.")
elif isinstance(var, list):
if v_type == 0:
tensor.int64_data.extend(self._flatten_list(var))
elif v_type == 1: # float32
if isinstance(var, np.ndarray):
tensor.float_data.extend(var.reshape(-1).tolist())
else:
elif v_type == 1:
tensor.float_data.extend(self._flatten_list(var))
elif v_type == 2:
tensor.int32_data.extend(self._flatten_list(var))
else:
raise Exception("error tensor value type.")
else:
raise Exception("error type.")
raise Exception("var must be list or ndarray.")
if isinstance(var, np.ndarray):
tensor.shape.extend(list(var.shape))
else:
......
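
The packing logic above keys off a per-variable type code from the client config: 0 for int64, 1 for float32, and (with this commit) 2 for int32. A minimal sketch of that mapping; the helper name is a hypothetical, not part of the diff:

    import numpy as np

    # Type codes as used in the branches above.
    _V_TYPE_TO_DTYPE = {0: "int64", 1: "float32", 2: "int32"}

    def as_feed_array(var, v_type):
        # Hypothetical helper mirroring the branches above: accept a
        # list or ndarray and return an ndarray of the required dtype.
        if v_type not in _V_TYPE_TO_DTYPE:
            raise Exception("error tensor value type.")
        dtype = _V_TYPE_TO_DTYPE[v_type]
        data = np.array(var) if isinstance(var, list) else var
        if not isinstance(data, np.ndarray):
            raise Exception("var must be list or ndarray.")
        return data if data.dtype == dtype else data.astype(dtype)
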
......@@ -499,10 +499,12 @@ class MultiLangServerServiceServicer(multi_lang_general_model_service_pb2_grpc.
v_type = self.feed_types_[name]
data = None
if is_python:
if v_type == 0:
if v_type == 0: # int64
data = np.frombuffer(var.data, dtype="int64")
elif v_type == 1:
elif v_type == 1: # float32
data = np.frombuffer(var.data, dtype="float32")
elif v_type == 2: # int32
data = np.frombuffer(var.data, dtype="int32")
else:
raise Exception("error type.")
else:
......@@ -510,6 +512,8 @@ class MultiLangServerServiceServicer(multi_lang_general_model_service_pb2_grpc.
data = np.array(list(var.int64_data), dtype="int64")
elif v_type == 1: # float32
data = np.array(list(var.float_data), dtype="float32")
elif v_type == 2: # int32
data = np.array(list(var.int32_data), dtype="int32")
else:
raise Exception("error type.")
data.shape = list(feed_inst.tensor_array[idx].shape)
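
Both decode paths above reconstruct the same ndarray: the is_python fast path reads the raw bytes the client produced with tobytes(), while the cross-language path reads the typed repeated field. A sketch of the int32 case under those assumptions (the function name is illustrative):

    import numpy as np

    def decode_int32_tensor(var, shape, is_python):
        # Sketch of the two v_type == 2 branches above.
        if is_python:
            # Fast path: frombuffer reconstructs the array from the
            # raw bytes without a per-element copy.
            data = np.frombuffer(var.data, dtype="int32")
        else:
            # Cross-language path: values travel in the repeated
            # int32_data field and are copied element by element.
            data = np.array(list(var.int32_data), dtype="int32")
        data.shape = shape
        return data
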
......@@ -542,6 +546,9 @@ class MultiLangServerServiceServicer(multi_lang_general_model_service_pb2_grpc.
elif v_type == 1: # float32
tensor.float_data.extend(model_result[name].reshape(-1)
.tolist())
elif v_type == 2: # int32
tensor.int32_data.extend(model_result[name].reshape(-1)
.tolist())
else:
raise Exception("error type.")
tensor.shape.extend(list(model_result[name].shape))
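
The response side mirrors the feed side: each fetched ndarray is flattened into the repeated field matching its type code, and its shape is recorded separately. A sketch with a hypothetical helper name:

    def pack_fetch_tensor(tensor, arr, v_type):
        # Hypothetical helper mirroring the branches above.
        flat = arr.reshape(-1).tolist()
        if v_type == 0:      # int64
            tensor.int64_data.extend(flat)
        elif v_type == 1:    # float32
            tensor.float_data.extend(flat)
        elif v_type == 2:    # int32
            tensor.int32_data.extend(flat)
        else:
            raise Exception("error type.")
        tensor.shape.extend(list(arr.shape))
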
......@@ -619,9 +626,6 @@ class MultiLangServer(object):
def set_ir_optimize(self, flag=False):
self.bserver_.set_ir_optimize(flag)
def set_gpuid(self, gpuid=0):
self.bserver_.set_gpuid(gpuid)
def set_op_sequence(self, op_seq):
self.bserver_.set_op_sequence(op_seq)
......
......@@ -489,29 +489,29 @@ class Server(object):
os.system(command)
class MultiLangServerService(
multi_lang_general_model_service_pb2_grpc.MultiLangGeneralModelService):
def __init__(self,
model_config_path,
is_multi_model,
endpoints,
timeout_ms=None):
class MultiLangServerServiceServicer(multi_lang_general_model_service_pb2_grpc.
MultiLangGeneralModelServiceServicer):
def __init__(self, model_config_path, is_multi_model, endpoints):
self.is_multi_model_ = is_multi_model
self.model_config_path_ = model_config_path
self.endpoints_ = endpoints
with open(self.model_config_path_) as f:
self.model_config_str_ = str(f.read())
self._parse_model_config(self.model_config_str_)
self._init_bclient(self.model_config_path_, self.endpoints_)
def _init_bclient(self, model_config_path, endpoints, timeout_ms=None):
from paddle_serving_client import Client
self._parse_model_config(model_config_path)
self.bclient_ = Client()
if timeout_ms is not None:
self.bclient_.set_rpc_timeout_ms(timeout_ms)
self.bclient_.load_client_config(
"{}/serving_server_conf.prototxt".format(model_config_path))
self.bclient_.load_client_config(model_config_path)
self.bclient_.connect(endpoints)
def _parse_model_config(self, model_config_path):
def _parse_model_config(self, model_config_str):
model_conf = m_config.GeneralModelConfig()
f = open("{}/serving_server_conf.prototxt".format(model_config_path),
'r')
model_conf = google.protobuf.text_format.Merge(
str(f.read()), model_conf)
model_conf = google.protobuf.text_format.Merge(model_config_str,
model_conf)
self.feed_names_ = [var.alias_name for var in model_conf.feed_var]
self.feed_types_ = {}
self.feed_shapes_ = {}
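
The refactor above makes _parse_model_config take the prototxt content as a string (read once in __init__) instead of a directory path, which also lets the same string be served back verbatim by GetClientConfig. A minimal sketch of the parsing step; the m_config import path is an assumption:

    import google.protobuf.text_format
    from paddle_serving_server.proto import general_model_config_pb2 as m_config  # path assumed

    # Read the prototxt once; reuse the string for parsing here and
    # for answering GetClientConfig requests later.
    with open("serving_server_conf.prototxt") as f:
        model_config_str = f.read()

    model_conf = m_config.GeneralModelConfig()
    google.protobuf.text_format.Merge(model_config_str, model_conf)
    feed_names = [var.alias_name for var in model_conf.feed_var]
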
......@@ -536,7 +536,7 @@ class MultiLangServerService(
else:
yield item
def _unpack_request(self, request):
def _unpack_inference_request(self, request):
feed_names = list(request.feed_var_names)
fetch_names = list(request.fetch_var_names)
is_python = request.is_python
......@@ -552,6 +552,8 @@ class MultiLangServerService(
data = np.frombuffer(var.data, dtype="int64")
elif v_type == 1:
data = np.frombuffer(var.data, dtype="float32")
elif v_type == 2:
data = np.frombuffer(var.data, dtype="int32")
else:
raise Exception("error type.")
else:
......@@ -559,6 +561,8 @@ class MultiLangServerService(
data = np.array(list(var.int64_data), dtype="int64")
elif v_type == 1: # float32
data = np.array(list(var.float_data), dtype="float32")
elif v_type == 2:
data = np.array(list(var.int32_data), dtype="int32")
else:
raise Exception("error type.")
data.shape = list(feed_inst.tensor_array[idx].shape)
......@@ -566,14 +570,14 @@ class MultiLangServerService(
feed_batch.append(feed_dict)
return feed_batch, fetch_names, is_python
def _pack_resp_package(self, ret, fetch_names, is_python):
def _pack_inference_response(self, ret, fetch_names, is_python):
resp = multi_lang_general_model_service_pb2.Response()
if ret is None:
resp.brpc_predict_error = True
resp.err_code = 1
return resp
results, tag = ret
resp.tag = tag
resp.brpc_predict_error = False
resp.err_code = 0
if not self.is_multi_model_:
results = {'general_infer_0': results}
for model_name, model_result in results.items():
......@@ -591,6 +595,9 @@ class MultiLangServerService(
elif v_type == 1: # float32
tensor.float_data.extend(model_result[name].reshape(-1)
.tolist())
elif v_type == 2: # int32
tensor.int32_data.extend(model_result[name].reshape(-1)
.tolist())
else:
raise Exception("error type.")
tensor.shape.extend(list(model_result[name].shape))
......@@ -603,11 +610,26 @@ class MultiLangServerService(
resp.outputs.append(model_output)
return resp
def inference(self, request, context):
feed_dict, fetch_names, is_python = self._unpack_request(request)
def SetTimeout(self, request, context):
# This process and the Inference process cannot run at the same time.
# For performance reasons, no thread lock is added for now.
timeout_ms = request.timeout_ms
self._init_bclient(self.model_config_path_, self.endpoints_, timeout_ms)
resp = multi_lang_general_model_service_pb2.SimpleResponse()
resp.err_code = 0
return resp
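
A hedged sketch of how a gRPC caller might invoke the new SetTimeout endpoint. The SetTimeoutRequest message name is an assumption inferred from the request.timeout_ms field above; only the RPC name, timeout_ms, and err_code appear in the diff:

    import grpc
    import multi_lang_general_model_service_pb2 as pb2
    import multi_lang_general_model_service_pb2_grpc as pb2_grpc

    channel = grpc.insecure_channel("127.0.0.1:9393")
    stub = pb2_grpc.MultiLangGeneralModelServiceStub(channel)

    req = pb2.SetTimeoutRequest(timeout_ms=5000)  # message name assumed
    resp = stub.SetTimeout(req)
    # The server re-inits its brpc client with the new timeout.
    assert resp.err_code == 0
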
def Inference(self, request, context):
feed_dict, fetch_names, is_python = self._unpack_inference_request(
request)
ret = self.bclient_.predict(
feed=feed_dict, fetch=fetch_names, need_variant_tag=True)
return self._pack_resp_package(ret, fetch_names, is_python)
return self._pack_inference_response(ret, fetch_names, is_python)
def GetClientConfig(self, request, context):
resp = multi_lang_general_model_service_pb2.GetClientConfigResponse()
resp.client_config_str = self.model_config_str_
return resp
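
GetClientConfig lets a client bootstrap without a local prototxt by fetching the config string from the server at connect time. A sketch, assuming GetClientConfigRequest is the (empty) request message; GetClientConfigResponse and client_config_str are taken from the diff:

    import grpc
    import multi_lang_general_model_service_pb2 as pb2
    import multi_lang_general_model_service_pb2_grpc as pb2_grpc

    channel = grpc.insecure_channel("127.0.0.1:9393")
    stub = pb2_grpc.MultiLangGeneralModelServiceStub(channel)
    resp = stub.GetClientConfig(pb2.GetClientConfigRequest())  # request name assumed
    client_config_str = resp.client_config_str  # prototxt content as a string
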
class MultiLangServer(object):
......@@ -616,12 +638,8 @@ class MultiLangServer(object):
self.worker_num_ = 4
self.body_size_ = 64 * 1024 * 1024
self.concurrency_ = 100000
self.bclient_timeout_ms_ = 2000
self.is_multi_model_ = False # for model ensemble
def set_bclient_timeout_ms(self, timeout):
self.bclient_timeout_ms_ = timeout
def set_max_concurrency(self, concurrency):
self.concurrency_ = concurrency
self.bserver_.set_max_concurrency(concurrency)
......@@ -660,15 +678,17 @@ class MultiLangServer(object):
def set_gpuid(self, gpuid=0):
self.bserver_.set_gpuid(gpuid)
def load_model_config(self, model_config_paths):
self.bserver_.load_model_config(model_config_paths)
if isinstance(model_config_paths, dict):
# print("You have specified multiple model paths, please ensure "
# "that the input and output of multiple models are the same.")
self.model_config_path_ = list(model_config_paths.items())[0][1]
self.is_multi_model_ = True
else:
self.model_config_path_ = model_config_paths
def load_model_config(self, server_config_paths, client_config_path=None):
self.bserver_.load_model_config(server_config_paths)
if client_config_path is None:
if isinstance(server_config_paths, dict):
self.is_multi_model_ = True
client_config_path = '{}/serving_server_conf.prototxt'.format(
list(server_config_paths.items())[0][1])
else:
client_config_path = '{}/serving_server_conf.prototxt'.format(
server_config_paths)
self.bclient_config_path_ = client_config_path
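
With the new signature, the gRPC layer's client-side config can be supplied explicitly or defaults to serving_server_conf.prototxt inside the server config directory. A usage sketch; the import path and model paths are placeholders:

    from paddle_serving_server import MultiLangServer  # import path is an assumption

    server = MultiLangServer()
    # Single model: client config defaults to
    # <server_config_path>/serving_server_conf.prototxt.
    server.load_model_config("uci_housing_model")

    # Or point the gRPC layer at a separate client-side prototxt.
    server.load_model_config(
        "uci_housing_model",
        client_config_path="uci_housing_client/serving_client_conf.prototxt")
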
def prepare_server(self, workdir=None, port=9292, device="cpu"):
if not self._port_is_available(port):
......