提交 2c23bcf1 编写于 作者: W wangjiawei04

Merge branch 'pyserving' of https://github.com/barrierye/Serving into pyserving

...@@ -18,7 +18,6 @@ package baidu.paddle_serving.pyserving; ...@@ -18,7 +18,6 @@ package baidu.paddle_serving.pyserving;
message ChannelData { message ChannelData {
repeated Inst insts = 1; repeated Inst insts = 1;
required int32 id = 2; required int32 id = 2;
required int32 type = 3 [ default = 0 ];
required int32 ecode = 4; required int32 ecode = 4;
optional string error_info = 5; optional string error_info = 5;
} }
......
...@@ -28,8 +28,7 @@ imdb_dataset.load_resource('imdb.vocab') ...@@ -28,8 +28,7 @@ imdb_dataset.load_resource('imdb.vocab')
for i in range(1): for i in range(1):
word_ids, label = imdb_dataset.get_words_and_label(words) word_ids, label = imdb_dataset.get_words_and_label(words)
fetch_map = lp_wrapper( fetch_map = lp_wrapper(feed={"words": word_ids}, fetch=["prediction"])
feed={"words": word_ids}, fetch=["combined_prediction"])
print(fetch_map) print(fetch_map)
#lp.print_stats() #lp.print_stats()
...@@ -33,7 +33,7 @@ class CombineOp(Op): ...@@ -33,7 +33,7 @@ class CombineOp(Op):
data = channeldata.parse() data = channeldata.parse()
logging.info("{}: {}".format(op_name, data["prediction"])) logging.info("{}: {}".format(op_name, data["prediction"]))
combined_prediction += data["prediction"] combined_prediction += data["prediction"]
data = {"combined_prediction": combined_prediction / 2} data = {"prediction": combined_prediction / 2}
return data return data
......
...@@ -79,18 +79,22 @@ class ChannelDataEcode(enum.Enum): ...@@ -79,18 +79,22 @@ class ChannelDataEcode(enum.Enum):
TIMEOUT = 1 TIMEOUT = 1
NOT_IMPLEMENTED = 2 NOT_IMPLEMENTED = 2
TYPE_ERROR = 3 TYPE_ERROR = 3
UNKNOW = 4 RPC_PACKAGE_ERROR = 4
UNKNOW = 5
class ChannelDataType(enum.Enum): class ChannelDataType(enum.Enum):
CHANNEL_PBDATA = 0 CHANNEL_PBDATA = 0
CHANNEL_FUTURE = 1 CHANNEL_FUTURE = 1
CHANNEL_NPDATA = 2
class ChannelData(object): class ChannelData(object):
def __init__(self, def __init__(self,
datatype=None,
future=None, future=None,
pbdata=None, pbdata=None,
npdata=None,
data_id=None, data_id=None,
callback_func=None, callback_func=None,
ecode=None, ecode=None,
...@@ -98,10 +102,12 @@ class ChannelData(object): ...@@ -98,10 +102,12 @@ class ChannelData(object):
''' '''
There are several ways to use it: There are several ways to use it:
1. ChannelData(future, pbdata[, callback_func]) 1. ChannelData(ChannelDataType.CHANNEL_FUTURE.value, future, pbdata[, callback_func])
2. ChannelData(future, data_id[, callback_func]) 2. ChannelData(ChannelDataType.CHANNEL_FUTURE.value, future, data_id[, callback_func])
3. ChannelData(pbdata) 3. ChannelData(ChannelDataType.CHANNEL_PBDATA.value, pbdata)
4. ChannelData(ecode, error_info, data_id) 4. ChannelData(ChannelDataType.CHANNEL_PBDATA.value, npdata, data_id)
5. ChannelData(ChannelDataType.CHANNEL_NPDATA.value, npdata, data_id)
6. ChannelData(ecode, error_info, data_id)
''' '''
if ecode is not None: if ecode is not None:
if data_id is None or error_info is None: if data_id is None or error_info is None:
...@@ -111,37 +117,92 @@ class ChannelData(object): ...@@ -111,37 +117,92 @@ class ChannelData(object):
pbdata.id = data_id pbdata.id = data_id
pbdata.error_info = error_info pbdata.error_info = error_info
else: else:
if pbdata is None: if datatype == ChannelDataType.CHANNEL_FUTURE.value:
if data_id is None: if pbdata is None:
raise ValueError("data_id cannot be None") if data_id is None:
raise ValueError("data_id cannot be None")
pbdata = channel_pb2.ChannelData()
pbdata.ecode = ChannelDataEcode.OK.value
pbdata.id = data_id
elif datatype == ChannelDataType.CHANNEL_PBDATA.value:
if pbdata is None:
if data_id is None:
raise ValueError("data_id cannot be None")
pbdata = channel_pb2.ChannelData()
pbdata.id = data_id
ecode, error_info = self._check_npdata(npdata)
pbdata.ecode = ecode
if pbdata.ecode != ChannelDataEcode.OK.value:
pbdata.error_info = error_info
logging.error(pbdata.error_info)
else:
for name, value in postped_data.items():
inst = channel_pb2.Inst()
inst.data = value.tobytes()
inst.name = name
inst.shape = np.array(
value.shape, dtype="int32").tobytes()
inst.type = str(value.dtype)
pbdata.insts.append(inst)
elif datatype == ChannelDataType.CHANNEL_NPDATA.value:
ecode, error_info = self._check_npdata(npdata)
pbdata = channel_pb2.ChannelData() pbdata = channel_pb2.ChannelData()
pbdata.type = ChannelDataType.CHANNEL_FUTURE.value
pbdata.ecode = ChannelDataEcode.OK.value
pbdata.id = data_id pbdata.id = data_id
elif not isinstance(pbdata, channel_pb2.ChannelData): pbdata.ecode = ecode
raise TypeError( if pbdata.ecode != ChannelDataEcode.OK.value:
"pbdata must be pyserving_channel_pb2.ChannelData type({})". pbdata.error_info = error_info
format(type(pbdata))) logging.error(pbdata.error_info)
else:
raise ValueError("datatype not match")
if not isinstance(pbdata, channel_pb2.ChannelData):
raise TypeError(
"pbdata must be pyserving_channel_pb2.ChannelData type({})".
format(type(pbdata)))
self.future = future self.future = future
self.pbdata = pbdata self.pbdata = pbdata
self.npdata = npdata
self.datatype = datatype
self.callback_func = callback_func self.callback_func = callback_func
def _check_npdata(self, npdata):
ecode = ChannelDataEcode.OK.value
error_info = None
for name, value in npdata.items():
if not isinstance(name, (str, unicode)):
ecode = ChannelDataEcode.TYPE_ERROR.value
error_info = log("the key of postped_data must " \
"be str, but get {}".format(type(name)))
break
if not isinstance(value, np.ndarray):
pbdata.ecode = ChannelDataEcode.TYPE_ERROR.value
pbdata.error_info = log("the value of postped_data must " \
"be np.ndarray, but get {}".format(type(value)))
break
return ecode, error_info
def parse(self): def parse(self):
# return narray # return narray
feed = {} feed = None
if self.pbdata.type == ChannelDataType.CHANNEL_PBDATA.value: if self.datatype == ChannelDataType.CHANNEL_PBDATA.value:
feed = {}
for inst in self.pbdata.insts: for inst in self.pbdata.insts:
feed[inst.name] = np.frombuffer(inst.data, dtype=inst.type) feed[inst.name] = np.frombuffer(inst.data, dtype=inst.type)
feed[inst.name].shape = np.frombuffer(inst.shape, dtype="int32") feed[inst.name].shape = np.frombuffer(inst.shape, dtype="int32")
elif self.pbdata.type == ChannelDataType.CHANNEL_FUTURE.value: elif self.datatype == ChannelDataType.CHANNEL_FUTURE.value:
feed = self.future.result() feed = self.future.result()
if self.callback_func is not None: if self.callback_func is not None:
feed = self.callback_func(feed) feed = self.callback_func(feed)
elif self.datatype == ChannelDataType.CHANNEL_NPDATA.value:
feed = self.npdata
else: else:
raise TypeError("Error type({}) in pbdata.type.".format( raise TypeError("Error type({}) in pbdata.type.".format(
self.pbdata.type)) self.pbdata.type))
return feed return feed
def __str__(self):
return "type[{}], ecode[{}]".format(
ChannelDataType(self.datatype).name, self.pbdata.ecode)
class Channel(Queue.Queue): class Channel(Queue.Queue):
""" """
...@@ -213,7 +274,7 @@ class Channel(Queue.Queue): ...@@ -213,7 +274,7 @@ class Channel(Queue.Queue):
def push(self, channeldata, op_name=None): def push(self, channeldata, op_name=None):
logging.debug( logging.debug(
self._log("{} try to push data: {}".format(op_name, self._log("{} try to push data: {}".format(op_name,
channeldata.pbdata))) channeldata.__str__())))
if len(self._producers) == 0: if len(self._producers) == 0:
raise Exception( raise Exception(
self._log( self._log(
...@@ -287,7 +348,9 @@ class Channel(Queue.Queue): ...@@ -287,7 +348,9 @@ class Channel(Queue.Queue):
break break
except Queue.Empty: except Queue.Empty:
self._cv.wait() self._cv.wait()
logging.debug(self._log("{} get data succ!".format(op_name))) logging.debug(
self._log("{} get data succ: {}".format(op_name, resp.__str__(
))))
return resp return resp
elif op_name is None: elif op_name is None:
raise Exception( raise Exception(
...@@ -478,7 +541,10 @@ class Op(object): ...@@ -478,7 +541,10 @@ class Op(object):
# error data in predecessor Op # error data in predecessor Op
if error_pbdata is not None: if error_pbdata is not None:
self._push_to_output_channels(ChannelData(pbdata=error_pbdata)) self._push_to_output_channels(
ChannelData(
datatype=ChannelDataType.CHANNEL_PBDATA.value,
pbdata=error_pbdata))
continue continue
# preprecess # preprecess
...@@ -564,6 +630,7 @@ class Op(object): ...@@ -564,6 +630,7 @@ class Op(object):
if self.with_serving(): if self.with_serving():
# use call_future # use call_future
output_data = ChannelData( output_data = ChannelData(
datatype=ChannelDataType.CHANNEL_FUTURE.value,
future=call_future, future=call_future,
data_id=data_id, data_id=data_id,
callback_func=self.postprocess) callback_func=self.postprocess)
...@@ -590,36 +657,10 @@ class Op(object): ...@@ -590,36 +657,10 @@ class Op(object):
data_id=data_id)) data_id=data_id))
continue continue
ecode = ChannelDataEcode.OK.value output_data = ChannelData(
error_info = None ChannelDataType.CHANNEL_NPDATA.value,
pbdata = channel_pb2.ChannelData() npdata=postped_data,
for name, value in postped_data.items(): data_id=data_id)
if not isinstance(name, (str, unicode)):
ecode = ChannelDataEcode.TYPE_ERROR.value
error_info = log("the key of postped_data must " \
"be str, but get {}".format(type(name)))
break
if not isinstance(value, np.ndarray):
ecode = ChannelDataEcode.TYPE_ERROR.value
error_info = log("the value of postped_data must " \
"be np.ndarray, but get {}".format(type(value)))
break
inst = channel_pb2.Inst()
inst.data = value.tobytes()
inst.name = name
inst.shape = np.array(value.shape, dtype="int32").tobytes()
inst.type = str(value.dtype)
pbdata.insts.append(inst)
if ecode != ChannelDataEcode.OK.value:
logging.error(error_info)
self._push_to_output_channels(
ChannelData(
ecode=ecode, error_info=error_info,
data_id=data_id))
continue
pbdata.ecode = ecode
pbdata.id = data_id
output_data = ChannelData(pbdata=pbdata)
_profiler.record("{}-postp_1".format(op_info_prefix)) _profiler.record("{}-postp_1".format(op_info_prefix))
# push data to channel (if run succ) # push data to channel (if run succ)
...@@ -750,35 +791,41 @@ class GeneralPythonService( ...@@ -750,35 +791,41 @@ class GeneralPythonService(
pbdata = channel_pb2.ChannelData() pbdata = channel_pb2.ChannelData()
data_id = self._get_next_id() data_id = self._get_next_id()
pbdata.id = data_id pbdata.id = data_id
for idx, name in enumerate(request.feed_var_names): pbdata.ecode = ChannelDataEcode.OK.value
logging.debug( try:
self._log('name: {}'.format(request.feed_var_names[idx]))) for idx, name in enumerate(request.feed_var_names):
logging.debug(self._log('data: {}'.format(request.feed_insts[idx]))) logging.debug(
inst = channel_pb2.Inst() self._log('name: {}'.format(request.feed_var_names[idx])))
inst.data = request.feed_insts[idx] logging.debug(
inst.shape = request.shape[idx] self._log('data: {}'.format(request.feed_insts[idx])))
inst.name = name inst = channel_pb2.Inst()
inst.type = request.type[idx] inst.data = request.feed_insts[idx]
pbdata.insts.append(inst) inst.shape = request.shape[idx]
pbdata.ecode = ChannelDataEcode.OK.value #TODO: parse request error inst.name = name
return ChannelData(pbdata=pbdata), data_id inst.type = request.type[idx]
pbdata.insts.append(inst)
except Exception as e:
pbdata.ecode = ChannelDataEcode.RPC_PACKAGE_ERROR.value
pbdata.error_info = "rpc package error"
return ChannelData(
datatype=ChannelDataType.CHANNEL_PBDATA.value,
pbdata=pbdata), data_id
def _pack_data_for_resp(self, channeldata): def _pack_data_for_resp(self, channeldata):
logging.debug(self._log('get channeldata')) logging.debug(self._log('get channeldata'))
resp = pyservice_pb2.Response() resp = pyservice_pb2.Response()
resp.ecode = channeldata.pbdata.ecode resp.ecode = channeldata.pbdata.ecode
if resp.ecode == ChannelDataEcode.OK.value: if resp.ecode == ChannelDataEcode.OK.value:
if channeldata.pbdata.type == ChannelDataType.CHANNEL_PBDATA.value: if channeldata.datatype == ChannelDataType.CHANNEL_PBDATA.value:
for inst in channeldata.pbdata.insts: for inst in channeldata.pbdata.insts:
resp.fetch_insts.append(inst.data) resp.fetch_insts.append(inst.data)
resp.fetch_var_names.append(inst.name) resp.fetch_var_names.append(inst.name)
resp.shape.append(inst.shape) resp.shape.append(inst.shape)
resp.type.append(inst.type) resp.type.append(inst.type)
elif channeldata.pbdata.type == ChannelDataType.CHANNEL_FUTURE.value: elif channeldata.datatype in (ChannelDataType.CHANNEL_FUTURE.value,
feed = channeldata.futures.result() ChannelDataType.CHANNEL_NPDATA.value):
if channeldata.callback_func is not None: feed = channeldata.parse()
feed = channeldata.callback_func(feed) for name, var in feed.items():
for name, var in feed:
resp.fetch_insts.append(var.tobytes()) resp.fetch_insts.append(var.tobytes())
resp.fetch_var_names.append(name) resp.fetch_var_names.append(name)
resp.shape.append( resp.shape.append(
...@@ -788,7 +835,7 @@ class GeneralPythonService( ...@@ -788,7 +835,7 @@ class GeneralPythonService(
else: else:
raise TypeError( raise TypeError(
self._log("Error type({}) in pbdata.type.".format( self._log("Error type({}) in pbdata.type.".format(
self.pbdata.type))) channeldata.datatype)))
else: else:
resp.error_info = channeldata.pbdata.error_info resp.error_info = channeldata.pbdata.error_info
return resp return resp
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册