提交 39cc7f17 编写于 作者: B barrierye

add ChannelData.NPDATA type

上级 72010b40
......@@ -18,7 +18,6 @@ package baidu.paddle_serving.pyserving;
message ChannelData {
repeated Inst insts = 1;
required int32 id = 2;
required int32 type = 3 [ default = 0 ];
required int32 ecode = 4;
optional string error_info = 5;
}
......
......@@ -28,8 +28,7 @@ imdb_dataset.load_resource('imdb.vocab')
for i in range(1):
word_ids, label = imdb_dataset.get_words_and_label(words)
fetch_map = lp_wrapper(
feed={"words": word_ids}, fetch=["combined_prediction"])
fetch_map = lp_wrapper(feed={"words": word_ids}, fetch=["prediction"])
print(fetch_map)
#lp.print_stats()
......@@ -33,7 +33,7 @@ class CombineOp(Op):
data = channeldata.parse()
logging.info("{}: {}".format(op_name, data["prediction"]))
combined_prediction += data["prediction"]
data = {"combined_prediction": combined_prediction / 2}
data = {"prediction": combined_prediction / 2}
return data
......
......@@ -79,18 +79,22 @@ class ChannelDataEcode(enum.Enum):
TIMEOUT = 1
NOT_IMPLEMENTED = 2
TYPE_ERROR = 3
UNKNOW = 4
RPC_PACKAGE_ERROR = 4
UNKNOW = 5
class ChannelDataType(enum.Enum):
CHANNEL_PBDATA = 0
CHANNEL_FUTURE = 1
CHANNEL_NPDATA = 2
class ChannelData(object):
def __init__(self,
datatype=None,
future=None,
pbdata=None,
npdata=None,
data_id=None,
callback_func=None,
ecode=None,
......@@ -98,10 +102,12 @@ class ChannelData(object):
'''
There are several ways to use it:
1. ChannelData(future, pbdata[, callback_func])
2. ChannelData(future, data_id[, callback_func])
3. ChannelData(pbdata)
4. ChannelData(ecode, error_info, data_id)
1. ChannelData(ChannelDataType.CHANNEL_FUTURE.value, future, pbdata[, callback_func])
2. ChannelData(ChannelDataType.CHANNEL_FUTURE.value, future, data_id[, callback_func])
3. ChannelData(ChannelDataType.CHANNEL_PBDATA.value, pbdata)
4. ChannelData(ChannelDataType.CHANNEL_PBDATA.value, npdata, data_id)
5. ChannelData(ChannelDataType.CHANNEL_NPDATA.value, npdata, data_id)
6. ChannelData(ecode, error_info, data_id)
'''
if ecode is not None:
if data_id is None or error_info is None:
......@@ -111,37 +117,92 @@ class ChannelData(object):
pbdata.id = data_id
pbdata.error_info = error_info
else:
if datatype == ChannelDataType.CHANNEL_FUTURE.value:
if pbdata is None:
if data_id is None:
raise ValueError("data_id cannot be None")
pbdata = channel_pb2.ChannelData()
pbdata.type = ChannelDataType.CHANNEL_FUTURE.value
pbdata.ecode = ChannelDataEcode.OK.value
pbdata.id = data_id
elif not isinstance(pbdata, channel_pb2.ChannelData):
elif datatype == ChannelDataType.CHANNEL_PBDATA.value:
if pbdata is None:
if data_id is None:
raise ValueError("data_id cannot be None")
pbdata = channel_pb2.ChannelData()
pbdata.id = data_id
ecode, error_info = self._check_npdata(npdata)
pbdata.ecode = ecode
if pbdata.ecode != ChannelDataEcode.OK.value:
pbdata.error_info = error_info
logging.error(pbdata.error_info)
else:
for name, value in postped_data.items():
inst = channel_pb2.Inst()
inst.data = value.tobytes()
inst.name = name
inst.shape = np.array(
value.shape, dtype="int32").tobytes()
inst.type = str(value.dtype)
pbdata.insts.append(inst)
elif datatype == ChannelDataType.CHANNEL_NPDATA.value:
ecode, error_info = self._check_npdata(npdata)
pbdata = channel_pb2.ChannelData()
pbdata.id = data_id
pbdata.ecode = ecode
if pbdata.ecode != ChannelDataEcode.OK.value:
pbdata.error_info = error_info
logging.error(pbdata.error_info)
else:
raise ValueError("datatype not match")
if not isinstance(pbdata, channel_pb2.ChannelData):
raise TypeError(
"pbdata must be pyserving_channel_pb2.ChannelData type({})".
format(type(pbdata)))
self.future = future
self.pbdata = pbdata
self.npdata = npdata
self.datatype = datatype
self.callback_func = callback_func
def _check_npdata(self, npdata):
ecode = ChannelDataEcode.OK.value
error_info = None
for name, value in npdata.items():
if not isinstance(name, (str, unicode)):
ecode = ChannelDataEcode.TYPE_ERROR.value
error_info = log("the key of postped_data must " \
"be str, but get {}".format(type(name)))
break
if not isinstance(value, np.ndarray):
pbdata.ecode = ChannelDataEcode.TYPE_ERROR.value
pbdata.error_info = log("the value of postped_data must " \
"be np.ndarray, but get {}".format(type(value)))
break
return ecode, error_info
def parse(self):
# return narray
feed = None
if self.datatype == ChannelDataType.CHANNEL_PBDATA.value:
feed = {}
if self.pbdata.type == ChannelDataType.CHANNEL_PBDATA.value:
for inst in self.pbdata.insts:
feed[inst.name] = np.frombuffer(inst.data, dtype=inst.type)
feed[inst.name].shape = np.frombuffer(inst.shape, dtype="int32")
elif self.pbdata.type == ChannelDataType.CHANNEL_FUTURE.value:
elif self.datatype == ChannelDataType.CHANNEL_FUTURE.value:
feed = self.future.result()
if self.callback_func is not None:
feed = self.callback_func(feed)
elif self.datatype == ChannelDataType.CHANNEL_NPDATA.value:
feed = self.npdata
else:
raise TypeError("Error type({}) in pbdata.type.".format(
self.pbdata.type))
return feed
def __str__(self):
return "type[{}], ecode[{}]".format(
ChannelDataType(self.datatype).name, self.pbdata.ecode)
class Channel(Queue.Queue):
"""
......@@ -213,7 +274,7 @@ class Channel(Queue.Queue):
def push(self, channeldata, op_name=None):
logging.debug(
self._log("{} try to push data: {}".format(op_name,
channeldata.pbdata)))
channeldata.__str__())))
if len(self._producers) == 0:
raise Exception(
self._log(
......@@ -287,7 +348,9 @@ class Channel(Queue.Queue):
break
except Queue.Empty:
self._cv.wait()
logging.debug(self._log("{} get data succ!".format(op_name)))
logging.debug(
self._log("{} get data succ: {}".format(op_name, resp.__str__(
))))
return resp
elif op_name is None:
raise Exception(
......@@ -478,7 +541,10 @@ class Op(object):
# error data in predecessor Op
if error_pbdata is not None:
self._push_to_output_channels(ChannelData(pbdata=error_pbdata))
self._push_to_output_channels(
ChannelData(
datatype=ChannelDataType.CHANNEL_PBDATA.value,
pbdata=error_pbdata))
continue
# preprecess
......@@ -564,6 +630,7 @@ class Op(object):
if self.with_serving():
# use call_future
output_data = ChannelData(
datatype=ChannelDataType.CHANNEL_FUTURE.value,
future=call_future,
data_id=data_id,
callback_func=self.postprocess)
......@@ -590,36 +657,10 @@ class Op(object):
data_id=data_id))
continue
ecode = ChannelDataEcode.OK.value
error_info = None
pbdata = channel_pb2.ChannelData()
for name, value in postped_data.items():
if not isinstance(name, (str, unicode)):
ecode = ChannelDataEcode.TYPE_ERROR.value
error_info = log("the key of postped_data must " \
"be str, but get {}".format(type(name)))
break
if not isinstance(value, np.ndarray):
ecode = ChannelDataEcode.TYPE_ERROR.value
error_info = log("the value of postped_data must " \
"be np.ndarray, but get {}".format(type(value)))
break
inst = channel_pb2.Inst()
inst.data = value.tobytes()
inst.name = name
inst.shape = np.array(value.shape, dtype="int32").tobytes()
inst.type = str(value.dtype)
pbdata.insts.append(inst)
if ecode != ChannelDataEcode.OK.value:
logging.error(error_info)
self._push_to_output_channels(
ChannelData(
ecode=ecode, error_info=error_info,
data_id=data_id))
continue
pbdata.ecode = ecode
pbdata.id = data_id
output_data = ChannelData(pbdata=pbdata)
output_data = ChannelData(
ChannelDataType.CHANNEL_NPDATA.value,
npdata=postped_data,
data_id=data_id)
_profiler.record("{}-postp_1".format(op_info_prefix))
# push data to channel (if run succ)
......@@ -750,34 +791,40 @@ class GeneralPythonService(
pbdata = channel_pb2.ChannelData()
data_id = self._get_next_id()
pbdata.id = data_id
pbdata.ecode = ChannelDataEcode.OK.value
try:
for idx, name in enumerate(request.feed_var_names):
logging.debug(
self._log('name: {}'.format(request.feed_var_names[idx])))
logging.debug(self._log('data: {}'.format(request.feed_insts[idx])))
logging.debug(
self._log('data: {}'.format(request.feed_insts[idx])))
inst = channel_pb2.Inst()
inst.data = request.feed_insts[idx]
inst.shape = request.shape[idx]
inst.name = name
inst.type = request.type[idx]
pbdata.insts.append(inst)
pbdata.ecode = ChannelDataEcode.OK.value #TODO: parse request error
return ChannelData(pbdata=pbdata), data_id
except Exception as e:
pbdata.ecode = ChannelDataEcode.RPC_PACKAGE_ERROR.value
pbdata.error_info = "rpc package error"
return ChannelData(
datatype=ChannelDataType.CHANNEL_PBDATA.value,
pbdata=pbdata), data_id
def _pack_data_for_resp(self, channeldata):
logging.debug(self._log('get channeldata'))
resp = pyservice_pb2.Response()
resp.ecode = channeldata.pbdata.ecode
if resp.ecode == ChannelDataEcode.OK.value:
if channeldata.pbdata.type == ChannelDataType.CHANNEL_PBDATA.value:
if channeldata.datatype == ChannelDataType.CHANNEL_PBDATA.value:
for inst in channeldata.pbdata.insts:
resp.fetch_insts.append(inst.data)
resp.fetch_var_names.append(inst.name)
resp.shape.append(inst.shape)
resp.type.append(inst.type)
elif channeldata.pbdata.type == ChannelDataType.CHANNEL_FUTURE.value:
feed = channeldata.future.result()
if channeldata.callback_func is not None:
feed = channeldata.callback_func(feed)
elif channeldata.datatype in (ChannelDataType.CHANNEL_FUTURE.value,
ChannelDataType.CHANNEL_NPDATA.value):
feed = channeldata.parse()
for name, var in feed.items():
resp.fetch_insts.append(var.tobytes())
resp.fetch_var_names.append(name)
......@@ -788,7 +835,7 @@ class GeneralPythonService(
else:
raise TypeError(
self._log("Error type({}) in pbdata.type.".format(
self.pbdata.type)))
channeldata.datatype)))
else:
resp.error_info = channeldata.pbdata.error_info
return resp
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册