提交 08c4d181 编写于 作者: B barrierye

add ndarray impl

上级 f62370c4
......@@ -410,7 +410,6 @@ class MultiLangClient(object):
self.feed_shapes_ = {}
self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var]
self.fetch_types_ = {}
self.type_map_ = {0: "int64", 1: "float32"}
self.lod_tensor_set_ = set()
for i, var in enumerate(model_conf.feed_var):
self.feed_types_[var.alias_name] = var.feed_type
......@@ -426,10 +425,11 @@ class MultiLangClient(object):
if var.is_lod_tensor:
self.lod_tensor_set_.add(var.alias_name)
def _pack_feed_data(self, feed, fetch):
def _pack_feed_data(self, feed, fetch, is_python):
req = multi_lang_general_model_service_pb2.Request()
req.fetch_var_names.extend(fetch)
req.feed_var_names.extend(feed.keys())
req.is_python = is_python
feed_batch = None
if isinstance(feed, dict):
feed_batch = [feed]
......@@ -444,18 +444,33 @@ class MultiLangClient(object):
tensor = multi_lang_general_model_service_pb2.Tensor()
var = feed_data[name]
v_type = self.feed_types_[name]
if v_type == 0: # int64
if isinstance(var, np.ndarray):
tensor.int64_data.extend(var.reshape(-1).tolist())
else:
tensor.int64_data.extend(self._flatten_list(var))
elif v_type == 1: # float32
if isinstance(var, np.ndarray):
tensor.float_data.extend(var.reshape(-1).tolist())
if is_python:
data = None
if isinstance(var, list):
if v_type == 0: # int64
data = np.array(var, dtype="int64")
elif v_type == 1: # float32
data = np.array(var, dtype="float32")
else:
raise Exception("error type.")
else:
tensor.float_data.extend(self._flatten_list(var))
data = var
if var.dtype == "float64":
data = data.astype("float32")
tensor.data = data.tobytes()
else:
raise Exception("error type.")
if v_type == 0: # int64
if isinstance(var, np.ndarray):
tensor.int64_data.extend(var.reshape(-1).tolist())
else:
tensor.int64_data.extend(self._flatten_list(var))
elif v_type == 1: # float32
if isinstance(var, np.ndarray):
tensor.float_data.extend(var.reshape(-1).tolist())
else:
tensor.float_data.extend(self._flatten_list(var))
else:
raise Exception("error type.")
if isinstance(var, np.ndarray):
tensor.shape.extend(list(var.shape))
else:
......@@ -464,39 +479,60 @@ class MultiLangClient(object):
req.insts.append(inst)
return req
def _unpack_resp(self, resp, fetch, need_variant_tag):
def _unpack_resp(self, resp, fetch, is_python, need_variant_tag):
result_map = {}
inst = resp.outputs[0].insts[0]
tag = resp.tag
for i, name in enumerate(fetch):
var = inst.tensor_array[i]
v_type = self.fetch_types_[name]
if v_type == 0: # int64
result_map[name] = np.array(list(var.int64_data))
elif v_type == 1: # flot32
result_map[name] = np.array(list(var.float_data))
if is_python:
if v_type == 0: # int64
result_map[name] = np.frombuffer(var.data, dtype="int64")
elif v_type == 1: # float32
result_map[name] = np.frombuffer(var.data, dtype="float32")
else:
raise Exception("error type.")
else:
raise Exception("error type.")
if v_type == 0: # int64
result_map[name] = np.array(list(var.int64_data))
elif v_type == 1: # float32
result_map[name] = np.array(list(var.float_data))
else:
raise Exception("error type.")
result_map[name].shape = list(var.shape)
if name in self.lod_tensor_set_:
result_map["{}.lod".format(name)] = np.array(list(var.lod))
return result_map if not need_variant_tag else [result_map, tag]
def _done_callback_func(self, fetch, need_variant_tag):
def _done_callback_func(self, fetch, is_python, need_variant_tag):
def unpack_resp(resp):
return self._unpack_resp(resp, fetch, need_variant_tag)
return self._unpack_resp(resp, fetch, is_python, need_variant_tag)
return unpack_resp
def predict(self, feed, fetch, need_variant_tag=False, asyn=False):
req = self._pack_feed_data(feed, fetch)
def predict(self,
feed,
fetch,
need_variant_tag=False,
asyn=False,
is_python=True):
req = self._pack_feed_data(feed, fetch, is_python=is_python)
if not asyn:
resp = self.stub_.inference(req)
return self._unpack_resp(resp, fetch, need_variant_tag)
return self._unpack_resp(
resp,
fetch,
is_python=is_python,
need_variant_tag=need_variant_tag)
else:
call_future = self.stub_.inference.future(req)
return MultiLangPredictFuture(
call_future, self._done_callback_func(fetch, need_variant_tag))
call_future,
self._done_callback_func(
fetch,
is_python=is_python,
need_variant_tag=need_variant_tag))
class MultiLangPredictFuture(object):
......
......@@ -458,7 +458,6 @@ class MultiLangServerService(
self.feed_shapes_ = {}
self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var]
self.fetch_types_ = {}
self.type_map_ = {0: "int64", 1: "float32"}
self.lod_tensor_set_ = set()
for i, var in enumerate(model_conf.feed_var):
self.feed_types_[var.alias_name] = var.feed_type
......@@ -481,43 +480,50 @@ class MultiLangServerService(
def _unpack_request(self, request):
feed_names = list(request.feed_var_names)
fetch_names = list(request.fetch_var_names)
is_python = request.is_python
feed_batch = []
for feed_inst in request.insts:
feed_dict = {}
for idx, name in enumerate(feed_names):
var = feed_inst.tensor_array[idx]
v_type = self.feed_types_[name]
data = None
if v_type == 0: # int64
data = np.array(
list(feed_inst.tensor_array[idx].int64_data),
dtype="int64")
elif v_type == 1: # float32
data = np.array(
list(feed_inst.tensor_array[idx].float_data),
dtype="float")
if is_python:
if v_type == 0:
data = np.frombuffer(var.data, dtype="int64")
elif v_type == 1:
data = np.frombuffer(var.data, dtype="float32")
else:
raise Exception("error type.")
else:
raise Exception("error type.")
shape = list(feed_inst.tensor_array[idx].shape)
data.shape = shape
if v_type == 0: # int64
data = np.array(list(var.int64_data), dtype="int64")
elif v_type == 1: # float32
data = np.array(list(var.float_data), dtype="float32")
else:
raise Exception("error type.")
data.shape = list(feed_inst.tensor_array[idx].shape)
feed_dict[name] = data
feed_batch.append(feed_dict)
return feed_batch, fetch_names
return feed_batch, fetch_names, is_python
def _pack_resp_package(self, result, fetch_names, tag):
def _pack_resp_package(self, result, fetch_names, is_python, tag):
resp = multi_lang_general_model_service_pb2.Response()
# Only one model is supported temporarily
model_output = multi_lang_general_model_service_pb2.ModelOutput()
inst = multi_lang_general_model_service_pb2.FetchInst()
for idx, name in enumerate(fetch_names):
# model_output.fetch_var_names.append(name)
tensor = multi_lang_general_model_service_pb2.Tensor()
v_type = self.fetch_types_[name]
if v_type == 0: # int64
tensor.int64_data.extend(result[name].reshape(-1).tolist())
elif v_type == 1: # float32
tensor.float_data.extend(result[name].reshape(-1).tolist())
if is_python:
tensor.data = result[name].tobytes()
else:
raise Exception("error type.")
if v_type == 0: # int64
tensor.int64_data.extend(result[name].reshape(-1).tolist())
elif v_type == 1: # float32
tensor.float_data.extend(result[name].reshape(-1).tolist())
else:
raise Exception("error type.")
tensor.shape.extend(list(result[name].shape))
if name in self.lod_tensor_set_:
tensor.lod.extend(result["{}.lod".format(name)].tolist())
......@@ -528,10 +534,10 @@ class MultiLangServerService(
return resp
def inference(self, request, context):
feed_dict, fetch_names = self._unpack_request(request)
feed_dict, fetch_names, is_python = self._unpack_request(request)
data, tag = self.bclient_.predict(
feed=feed_dict, fetch=fetch_names, need_variant_tag=True)
return self._pack_resp_package(data, fetch_names, tag)
return self._pack_resp_package(data, fetch_names, is_python, tag)
class MultiLangServer(object):
......
......@@ -499,7 +499,6 @@ class MultiLangServerService(
self.feed_shapes_ = {}
self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var]
self.fetch_types_ = {}
self.type_map_ = {0: "int64", 1: "float32"}
self.lod_tensor_set_ = set()
for i, var in enumerate(model_conf.feed_var):
self.feed_types_[var.alias_name] = var.feed_type
......@@ -522,43 +521,50 @@ class MultiLangServerService(
def _unpack_request(self, request):
feed_names = list(request.feed_var_names)
fetch_names = list(request.fetch_var_names)
is_python = request.is_python
feed_batch = []
for feed_inst in request.insts:
feed_dict = {}
for idx, name in enumerate(feed_names):
var = feed_inst.tensor_array[idx]
v_type = self.feed_types_[name]
data = None
if v_type == 0: # int64
data = np.array(
list(feed_inst.tensor_array[idx].int64_data),
dtype="int64")
elif v_type == 1: # float32
data = np.array(
list(feed_inst.tensor_array[idx].float_data),
dtype="float")
if is_python:
if v_type == 0:
data = np.frombuffer(var.data, dtype="int64")
elif v_type == 1:
data = np.frombuffer(var.data, dtype="float32")
else:
raise Exception("error type.")
else:
raise Exception("error type.")
shape = list(feed_inst.tensor_array[idx].shape)
data.shape = shape
if v_type == 0: # int64
data = np.array(list(var.int64_data), dtype="int64")
elif v_type == 1: # float32
data = np.array(list(var.float_data), dtype="float32")
else:
raise Exception("error type.")
data.shape = list(feed_inst.tensor_array[idx].shape)
feed_dict[name] = data
feed_batch.append(feed_dict)
return feed_batch, fetch_names
return feed_batch, fetch_names, is_python
def _pack_resp_package(self, result, fetch_names, tag):
def _pack_resp_package(self, result, fetch_names, is_python, tag):
resp = multi_lang_general_model_service_pb2.Response()
# Only one model is supported temporarily
model_output = multi_lang_general_model_service_pb2.ModelOutput()
inst = multi_lang_general_model_service_pb2.FetchInst()
for idx, name in enumerate(fetch_names):
# model_output.fetch_var_names.append(name)
tensor = multi_lang_general_model_service_pb2.Tensor()
v_type = self.fetch_types_[name]
if v_type == 0: # int64
tensor.int64_data.extend(result[name].reshape(-1).tolist())
elif v_type == 1: # float32
tensor.float_data.extend(result[name].reshape(-1).tolist())
if is_python:
tensor.data = result[name].tobytes()
else:
raise Exception("error type.")
if v_type == 0: # int64
tensor.int64_data.extend(result[name].reshape(-1).tolist())
elif v_type == 1: # float32
tensor.float_data.extend(result[name].reshape(-1).tolist())
else:
raise Exception("error type.")
tensor.shape.extend(list(result[name].shape))
if name in self.lod_tensor_set_:
tensor.lod.extend(result["{}.lod".format(name)].tolist())
......@@ -569,10 +575,10 @@ class MultiLangServerService(
return resp
def inference(self, request, context):
feed_dict, fetch_names = self._unpack_request(request)
feed_dict, fetch_names, is_python = self._unpack_request(request)
data, tag = self.bclient_.predict(
feed=feed_dict, fetch=fetch_names, need_variant_tag=True)
return self._pack_resp_package(data, fetch_names, tag)
return self._pack_resp_package(data, fetch_names, is_python, tag)
class MultiLangServer(object):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册