提交 fd51120b 编写于 作者: B barrierye

add ndarray impl

上级 743892b6
...@@ -410,7 +410,6 @@ class MultiLangClient(object): ...@@ -410,7 +410,6 @@ class MultiLangClient(object):
self.feed_shapes_ = {} self.feed_shapes_ = {}
self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var] self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var]
self.fetch_types_ = {} self.fetch_types_ = {}
self.type_map_ = {0: "int64", 1: "float32"}
self.lod_tensor_set_ = set() self.lod_tensor_set_ = set()
for i, var in enumerate(model_conf.feed_var): for i, var in enumerate(model_conf.feed_var):
self.feed_types_[var.alias_name] = var.feed_type self.feed_types_[var.alias_name] = var.feed_type
...@@ -426,10 +425,11 @@ class MultiLangClient(object): ...@@ -426,10 +425,11 @@ class MultiLangClient(object):
if var.is_lod_tensor: if var.is_lod_tensor:
self.lod_tensor_set_.add(var.alias_name) self.lod_tensor_set_.add(var.alias_name)
def _pack_feed_data(self, feed, fetch): def _pack_feed_data(self, feed, fetch, is_python):
req = multi_lang_general_model_service_pb2.Request() req = multi_lang_general_model_service_pb2.Request()
req.fetch_var_names.extend(fetch) req.fetch_var_names.extend(fetch)
req.feed_var_names.extend(feed.keys()) req.feed_var_names.extend(feed.keys())
req.is_python = is_python
feed_batch = None feed_batch = None
if isinstance(feed, dict): if isinstance(feed, dict):
feed_batch = [feed] feed_batch = [feed]
...@@ -444,18 +444,33 @@ class MultiLangClient(object): ...@@ -444,18 +444,33 @@ class MultiLangClient(object):
tensor = multi_lang_general_model_service_pb2.Tensor() tensor = multi_lang_general_model_service_pb2.Tensor()
var = feed_data[name] var = feed_data[name]
v_type = self.feed_types_[name] v_type = self.feed_types_[name]
if v_type == 0: # int64 if is_python:
if isinstance(var, np.ndarray): data = None
tensor.int64_data.extend(var.reshape(-1).tolist()) if isinstance(var, list):
else: if v_type == 0: # int64
tensor.int64_data.extend(self._flatten_list(var)) data = np.array(var, dtype="int64")
elif v_type == 1: # float32 elif v_type == 1: # float32
if isinstance(var, np.ndarray): data = np.array(var, dtype="float32")
tensor.float_data.extend(var.reshape(-1).tolist()) else:
raise Exception("error type.")
else: else:
tensor.float_data.extend(self._flatten_list(var)) data = var
if var.dtype == "float64":
data = data.astype("float32")
tensor.data = data.tobytes()
else: else:
raise Exception("error type.") if v_type == 0: # int64
if isinstance(var, np.ndarray):
tensor.int64_data.extend(var.reshape(-1).tolist())
else:
tensor.int64_data.extend(self._flatten_list(var))
elif v_type == 1: # float32
if isinstance(var, np.ndarray):
tensor.float_data.extend(var.reshape(-1).tolist())
else:
tensor.float_data.extend(self._flatten_list(var))
else:
raise Exception("error type.")
if isinstance(var, np.ndarray): if isinstance(var, np.ndarray):
tensor.shape.extend(list(var.shape)) tensor.shape.extend(list(var.shape))
else: else:
...@@ -464,39 +479,60 @@ class MultiLangClient(object): ...@@ -464,39 +479,60 @@ class MultiLangClient(object):
req.insts.append(inst) req.insts.append(inst)
return req return req
def _unpack_resp(self, resp, fetch, need_variant_tag): def _unpack_resp(self, resp, fetch, is_python, need_variant_tag):
result_map = {} result_map = {}
inst = resp.outputs[0].insts[0] inst = resp.outputs[0].insts[0]
tag = resp.tag tag = resp.tag
for i, name in enumerate(fetch): for i, name in enumerate(fetch):
var = inst.tensor_array[i] var = inst.tensor_array[i]
v_type = self.fetch_types_[name] v_type = self.fetch_types_[name]
if v_type == 0: # int64 if is_python:
result_map[name] = np.array(list(var.int64_data)) if v_type == 0: # int64
elif v_type == 1: # flot32 result_map[name] = np.frombuffer(var.data, dtype="int64")
result_map[name] = np.array(list(var.float_data)) elif v_type == 1: # float32
result_map[name] = np.frombuffer(var.data, dtype="float32")
else:
raise Exception("error type.")
else: else:
raise Exception("error type.") if v_type == 0: # int64
result_map[name] = np.array(list(var.int64_data))
elif v_type == 1: # float32
result_map[name] = np.array(list(var.float_data))
else:
raise Exception("error type.")
result_map[name].shape = list(var.shape) result_map[name].shape = list(var.shape)
if name in self.lod_tensor_set_: if name in self.lod_tensor_set_:
result_map["{}.lod".format(name)] = np.array(list(var.lod)) result_map["{}.lod".format(name)] = np.array(list(var.lod))
return result_map if not need_variant_tag else [result_map, tag] return result_map if not need_variant_tag else [result_map, tag]
def _done_callback_func(self, fetch, need_variant_tag): def _done_callback_func(self, fetch, is_python, need_variant_tag):
def unpack_resp(resp): def unpack_resp(resp):
return self._unpack_resp(resp, fetch, need_variant_tag) return self._unpack_resp(resp, fetch, is_python, need_variant_tag)
return unpack_resp return unpack_resp
def predict(self, feed, fetch, need_variant_tag=False, asyn=False): def predict(self,
req = self._pack_feed_data(feed, fetch) feed,
fetch,
need_variant_tag=False,
asyn=False,
is_python=True):
req = self._pack_feed_data(feed, fetch, is_python=is_python)
if not asyn: if not asyn:
resp = self.stub_.inference(req) resp = self.stub_.inference(req)
return self._unpack_resp(resp, fetch, need_variant_tag) return self._unpack_resp(
resp,
fetch,
is_python=is_python,
need_variant_tag=need_variant_tag)
else: else:
call_future = self.stub_.inference.future(req) call_future = self.stub_.inference.future(req)
return MultiLangPredictFuture( return MultiLangPredictFuture(
call_future, self._done_callback_func(fetch, need_variant_tag)) call_future,
self._done_callback_func(
fetch,
is_python=is_python,
need_variant_tag=need_variant_tag))
class MultiLangPredictFuture(object): class MultiLangPredictFuture(object):
......
...@@ -458,7 +458,6 @@ class MultiLangServerService( ...@@ -458,7 +458,6 @@ class MultiLangServerService(
self.feed_shapes_ = {} self.feed_shapes_ = {}
self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var] self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var]
self.fetch_types_ = {} self.fetch_types_ = {}
self.type_map_ = {0: "int64", 1: "float32"}
self.lod_tensor_set_ = set() self.lod_tensor_set_ = set()
for i, var in enumerate(model_conf.feed_var): for i, var in enumerate(model_conf.feed_var):
self.feed_types_[var.alias_name] = var.feed_type self.feed_types_[var.alias_name] = var.feed_type
...@@ -481,43 +480,50 @@ class MultiLangServerService( ...@@ -481,43 +480,50 @@ class MultiLangServerService(
def _unpack_request(self, request): def _unpack_request(self, request):
feed_names = list(request.feed_var_names) feed_names = list(request.feed_var_names)
fetch_names = list(request.fetch_var_names) fetch_names = list(request.fetch_var_names)
is_python = request.is_python
feed_batch = [] feed_batch = []
for feed_inst in request.insts: for feed_inst in request.insts:
feed_dict = {} feed_dict = {}
for idx, name in enumerate(feed_names): for idx, name in enumerate(feed_names):
var = feed_inst.tensor_array[idx]
v_type = self.feed_types_[name] v_type = self.feed_types_[name]
data = None data = None
if v_type == 0: # int64 if is_python:
data = np.array( if v_type == 0:
list(feed_inst.tensor_array[idx].int64_data), data = np.frombuffer(var.data, dtype="int64")
dtype="int64") elif v_type == 1:
elif v_type == 1: # float32 data = np.frombuffer(var.data, dtype="float32")
data = np.array( else:
list(feed_inst.tensor_array[idx].float_data), raise Exception("error type.")
dtype="float")
else: else:
raise Exception("error type.") if v_type == 0: # int64
shape = list(feed_inst.tensor_array[idx].shape) data = np.array(list(var.int64_data), dtype="int64")
data.shape = shape elif v_type == 1: # float32
data = np.array(list(var.float_data), dtype="float32")
else:
raise Exception("error type.")
data.shape = list(feed_inst.tensor_array[idx].shape)
feed_dict[name] = data feed_dict[name] = data
feed_batch.append(feed_dict) feed_batch.append(feed_dict)
return feed_batch, fetch_names return feed_batch, fetch_names, is_python
def _pack_resp_package(self, result, fetch_names, tag): def _pack_resp_package(self, result, fetch_names, is_python, tag):
resp = multi_lang_general_model_service_pb2.Response() resp = multi_lang_general_model_service_pb2.Response()
# Only one model is supported temporarily # Only one model is supported temporarily
model_output = multi_lang_general_model_service_pb2.ModelOutput() model_output = multi_lang_general_model_service_pb2.ModelOutput()
inst = multi_lang_general_model_service_pb2.FetchInst() inst = multi_lang_general_model_service_pb2.FetchInst()
for idx, name in enumerate(fetch_names): for idx, name in enumerate(fetch_names):
# model_output.fetch_var_names.append(name)
tensor = multi_lang_general_model_service_pb2.Tensor() tensor = multi_lang_general_model_service_pb2.Tensor()
v_type = self.fetch_types_[name] v_type = self.fetch_types_[name]
if v_type == 0: # int64 if is_python:
tensor.int64_data.extend(result[name].reshape(-1).tolist()) tensor.data = result[name].tobytes()
elif v_type == 1: # float32
tensor.float_data.extend(result[name].reshape(-1).tolist())
else: else:
raise Exception("error type.") if v_type == 0: # int64
tensor.int64_data.extend(result[name].reshape(-1).tolist())
elif v_type == 1: # float32
tensor.float_data.extend(result[name].reshape(-1).tolist())
else:
raise Exception("error type.")
tensor.shape.extend(list(result[name].shape)) tensor.shape.extend(list(result[name].shape))
if name in self.lod_tensor_set_: if name in self.lod_tensor_set_:
tensor.lod.extend(result["{}.lod".format(name)].tolist()) tensor.lod.extend(result["{}.lod".format(name)].tolist())
...@@ -528,10 +534,10 @@ class MultiLangServerService( ...@@ -528,10 +534,10 @@ class MultiLangServerService(
return resp return resp
def inference(self, request, context): def inference(self, request, context):
feed_dict, fetch_names = self._unpack_request(request) feed_dict, fetch_names, is_python = self._unpack_request(request)
data, tag = self.bclient_.predict( data, tag = self.bclient_.predict(
feed=feed_dict, fetch=fetch_names, need_variant_tag=True) feed=feed_dict, fetch=fetch_names, need_variant_tag=True)
return self._pack_resp_package(data, fetch_names, tag) return self._pack_resp_package(data, fetch_names, is_python, tag)
class MultiLangServer(object): class MultiLangServer(object):
......
...@@ -499,7 +499,6 @@ class MultiLangServerService( ...@@ -499,7 +499,6 @@ class MultiLangServerService(
self.feed_shapes_ = {} self.feed_shapes_ = {}
self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var] self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var]
self.fetch_types_ = {} self.fetch_types_ = {}
self.type_map_ = {0: "int64", 1: "float32"}
self.lod_tensor_set_ = set() self.lod_tensor_set_ = set()
for i, var in enumerate(model_conf.feed_var): for i, var in enumerate(model_conf.feed_var):
self.feed_types_[var.alias_name] = var.feed_type self.feed_types_[var.alias_name] = var.feed_type
...@@ -522,43 +521,50 @@ class MultiLangServerService( ...@@ -522,43 +521,50 @@ class MultiLangServerService(
def _unpack_request(self, request): def _unpack_request(self, request):
feed_names = list(request.feed_var_names) feed_names = list(request.feed_var_names)
fetch_names = list(request.fetch_var_names) fetch_names = list(request.fetch_var_names)
is_python = request.is_python
feed_batch = [] feed_batch = []
for feed_inst in request.insts: for feed_inst in request.insts:
feed_dict = {} feed_dict = {}
for idx, name in enumerate(feed_names): for idx, name in enumerate(feed_names):
var = feed_inst.tensor_array[idx]
v_type = self.feed_types_[name] v_type = self.feed_types_[name]
data = None data = None
if v_type == 0: # int64 if is_python:
data = np.array( if v_type == 0:
list(feed_inst.tensor_array[idx].int64_data), data = np.frombuffer(var.data, dtype="int64")
dtype="int64") elif v_type == 1:
elif v_type == 1: # float32 data = np.frombuffer(var.data, dtype="float32")
data = np.array( else:
list(feed_inst.tensor_array[idx].float_data), raise Exception("error type.")
dtype="float")
else: else:
raise Exception("error type.") if v_type == 0: # int64
shape = list(feed_inst.tensor_array[idx].shape) data = np.array(list(var.int64_data), dtype="int64")
data.shape = shape elif v_type == 1: # float32
data = np.array(list(var.float_data), dtype="float32")
else:
raise Exception("error type.")
data.shape = list(feed_inst.tensor_array[idx].shape)
feed_dict[name] = data feed_dict[name] = data
feed_batch.append(feed_dict) feed_batch.append(feed_dict)
return feed_batch, fetch_names return feed_batch, fetch_names, is_python
def _pack_resp_package(self, result, fetch_names, tag): def _pack_resp_package(self, result, fetch_names, is_python, tag):
resp = multi_lang_general_model_service_pb2.Response() resp = multi_lang_general_model_service_pb2.Response()
# Only one model is supported temporarily # Only one model is supported temporarily
model_output = multi_lang_general_model_service_pb2.ModelOutput() model_output = multi_lang_general_model_service_pb2.ModelOutput()
inst = multi_lang_general_model_service_pb2.FetchInst() inst = multi_lang_general_model_service_pb2.FetchInst()
for idx, name in enumerate(fetch_names): for idx, name in enumerate(fetch_names):
# model_output.fetch_var_names.append(name)
tensor = multi_lang_general_model_service_pb2.Tensor() tensor = multi_lang_general_model_service_pb2.Tensor()
v_type = self.fetch_types_[name] v_type = self.fetch_types_[name]
if v_type == 0: # int64 if is_python:
tensor.int64_data.extend(result[name].reshape(-1).tolist()) tensor.data = result[name].tobytes()
elif v_type == 1: # float32
tensor.float_data.extend(result[name].reshape(-1).tolist())
else: else:
raise Exception("error type.") if v_type == 0: # int64
tensor.int64_data.extend(result[name].reshape(-1).tolist())
elif v_type == 1: # float32
tensor.float_data.extend(result[name].reshape(-1).tolist())
else:
raise Exception("error type.")
tensor.shape.extend(list(result[name].shape)) tensor.shape.extend(list(result[name].shape))
if name in self.lod_tensor_set_: if name in self.lod_tensor_set_:
tensor.lod.extend(result["{}.lod".format(name)].tolist()) tensor.lod.extend(result["{}.lod".format(name)].tolist())
...@@ -569,10 +575,10 @@ class MultiLangServerService( ...@@ -569,10 +575,10 @@ class MultiLangServerService(
return resp return resp
def inference(self, request, context): def inference(self, request, context):
feed_dict, fetch_names = self._unpack_request(request) feed_dict, fetch_names, is_python = self._unpack_request(request)
data, tag = self.bclient_.predict( data, tag = self.bclient_.predict(
feed=feed_dict, fetch=fetch_names, need_variant_tag=True) feed=feed_dict, fetch=fetch_names, need_variant_tag=True)
return self._pack_resp_package(data, fetch_names, tag) return self._pack_resp_package(data, fetch_names, is_python, tag)
class MultiLangServer(object): class MultiLangServer(object):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册