From 39cc7f17024468c48f325287b44fb7830a3d20b3 Mon Sep 17 00:00:00 2001 From: barrierye Date: Fri, 12 Jun 2020 19:13:02 +0800 Subject: [PATCH] add ChannelData.NPDATA type --- core/configure/proto/pyserving_channel.proto | 1 - python/examples/imdb/test_py_client.py | 3 +- python/examples/imdb/test_py_server.py | 2 +- python/paddle_serving_server/pyserver.py | 183 ++++++++++++------- 4 files changed, 117 insertions(+), 72 deletions(-) diff --git a/core/configure/proto/pyserving_channel.proto b/core/configure/proto/pyserving_channel.proto index 060f4d72..063de92d 100644 --- a/core/configure/proto/pyserving_channel.proto +++ b/core/configure/proto/pyserving_channel.proto @@ -18,7 +18,6 @@ package baidu.paddle_serving.pyserving; message ChannelData { repeated Inst insts = 1; required int32 id = 2; - required int32 type = 3 [ default = 0 ]; required int32 ecode = 4; optional string error_info = 5; } diff --git a/python/examples/imdb/test_py_client.py b/python/examples/imdb/test_py_client.py index 3f811e16..ee964aa1 100644 --- a/python/examples/imdb/test_py_client.py +++ b/python/examples/imdb/test_py_client.py @@ -28,8 +28,7 @@ imdb_dataset.load_resource('imdb.vocab') for i in range(1): word_ids, label = imdb_dataset.get_words_and_label(words) - fetch_map = lp_wrapper( - feed={"words": word_ids}, fetch=["combined_prediction"]) + fetch_map = lp_wrapper(feed={"words": word_ids}, fetch=["prediction"]) print(fetch_map) #lp.print_stats() diff --git a/python/examples/imdb/test_py_server.py b/python/examples/imdb/test_py_server.py index d8879564..ed93a93b 100644 --- a/python/examples/imdb/test_py_server.py +++ b/python/examples/imdb/test_py_server.py @@ -33,7 +33,7 @@ class CombineOp(Op): data = channeldata.parse() logging.info("{}: {}".format(op_name, data["prediction"])) combined_prediction += data["prediction"] - data = {"combined_prediction": combined_prediction / 2} + data = {"prediction": combined_prediction / 2} return data diff --git a/python/paddle_serving_server/pyserver.py b/python/paddle_serving_server/pyserver.py index 70149bb2..8dda9a68 100644 --- a/python/paddle_serving_server/pyserver.py +++ b/python/paddle_serving_server/pyserver.py @@ -79,18 +79,22 @@ class ChannelDataEcode(enum.Enum): TIMEOUT = 1 NOT_IMPLEMENTED = 2 TYPE_ERROR = 3 - UNKNOW = 4 + RPC_PACKAGE_ERROR = 4 + UNKNOW = 5 class ChannelDataType(enum.Enum): CHANNEL_PBDATA = 0 CHANNEL_FUTURE = 1 + CHANNEL_NPDATA = 2 class ChannelData(object): def __init__(self, + datatype=None, future=None, pbdata=None, + npdata=None, data_id=None, callback_func=None, ecode=None, @@ -98,10 +102,12 @@ class ChannelData(object): ''' There are several ways to use it: - 1. ChannelData(future, pbdata[, callback_func]) - 2. ChannelData(future, data_id[, callback_func]) - 3. ChannelData(pbdata) - 4. ChannelData(ecode, error_info, data_id) + 1. ChannelData(ChannelDataType.CHANNEL_FUTURE.value, future, pbdata[, callback_func]) + 2. ChannelData(ChannelDataType.CHANNEL_FUTURE.value, future, data_id[, callback_func]) + 3. ChannelData(ChannelDataType.CHANNEL_PBDATA.value, pbdata) + 4. ChannelData(ChannelDataType.CHANNEL_PBDATA.value, npdata, data_id) + 5. ChannelData(ChannelDataType.CHANNEL_NPDATA.value, npdata, data_id) + 6. ChannelData(ecode, error_info, data_id) ''' if ecode is not None: if data_id is None or error_info is None: @@ -111,37 +117,92 @@ class ChannelData(object): pbdata.id = data_id pbdata.error_info = error_info else: - if pbdata is None: - if data_id is None: - raise ValueError("data_id cannot be None") + if datatype == ChannelDataType.CHANNEL_FUTURE.value: + if pbdata is None: + if data_id is None: + raise ValueError("data_id cannot be None") + pbdata = channel_pb2.ChannelData() + pbdata.ecode = ChannelDataEcode.OK.value + pbdata.id = data_id + elif datatype == ChannelDataType.CHANNEL_PBDATA.value: + if pbdata is None: + if data_id is None: + raise ValueError("data_id cannot be None") + pbdata = channel_pb2.ChannelData() + pbdata.id = data_id + ecode, error_info = self._check_npdata(npdata) + pbdata.ecode = ecode + if pbdata.ecode != ChannelDataEcode.OK.value: + pbdata.error_info = error_info + logging.error(pbdata.error_info) + else: + for name, value in postped_data.items(): + inst = channel_pb2.Inst() + inst.data = value.tobytes() + inst.name = name + inst.shape = np.array( + value.shape, dtype="int32").tobytes() + inst.type = str(value.dtype) + pbdata.insts.append(inst) + elif datatype == ChannelDataType.CHANNEL_NPDATA.value: + ecode, error_info = self._check_npdata(npdata) pbdata = channel_pb2.ChannelData() - pbdata.type = ChannelDataType.CHANNEL_FUTURE.value - pbdata.ecode = ChannelDataEcode.OK.value pbdata.id = data_id - elif not isinstance(pbdata, channel_pb2.ChannelData): - raise TypeError( - "pbdata must be pyserving_channel_pb2.ChannelData type({})". - format(type(pbdata))) + pbdata.ecode = ecode + if pbdata.ecode != ChannelDataEcode.OK.value: + pbdata.error_info = error_info + logging.error(pbdata.error_info) + else: + raise ValueError("datatype not match") + if not isinstance(pbdata, channel_pb2.ChannelData): + raise TypeError( + "pbdata must be pyserving_channel_pb2.ChannelData type({})". + format(type(pbdata))) self.future = future self.pbdata = pbdata + self.npdata = npdata + self.datatype = datatype self.callback_func = callback_func + def _check_npdata(self, npdata): + ecode = ChannelDataEcode.OK.value + error_info = None + for name, value in npdata.items(): + if not isinstance(name, (str, unicode)): + ecode = ChannelDataEcode.TYPE_ERROR.value + error_info = log("the key of postped_data must " \ + "be str, but get {}".format(type(name))) + break + if not isinstance(value, np.ndarray): + pbdata.ecode = ChannelDataEcode.TYPE_ERROR.value + pbdata.error_info = log("the value of postped_data must " \ + "be np.ndarray, but get {}".format(type(value))) + break + return ecode, error_info + def parse(self): # return narray - feed = {} - if self.pbdata.type == ChannelDataType.CHANNEL_PBDATA.value: + feed = None + if self.datatype == ChannelDataType.CHANNEL_PBDATA.value: + feed = {} for inst in self.pbdata.insts: feed[inst.name] = np.frombuffer(inst.data, dtype=inst.type) feed[inst.name].shape = np.frombuffer(inst.shape, dtype="int32") - elif self.pbdata.type == ChannelDataType.CHANNEL_FUTURE.value: + elif self.datatype == ChannelDataType.CHANNEL_FUTURE.value: feed = self.future.result() if self.callback_func is not None: feed = self.callback_func(feed) + elif self.datatype == ChannelDataType.CHANNEL_NPDATA.value: + feed = self.npdata else: raise TypeError("Error type({}) in pbdata.type.".format( self.pbdata.type)) return feed + def __str__(self): + return "type[{}], ecode[{}]".format( + ChannelDataType(self.datatype).name, self.pbdata.ecode) + class Channel(Queue.Queue): """ @@ -213,7 +274,7 @@ class Channel(Queue.Queue): def push(self, channeldata, op_name=None): logging.debug( self._log("{} try to push data: {}".format(op_name, - channeldata.pbdata))) + channeldata.__str__()))) if len(self._producers) == 0: raise Exception( self._log( @@ -287,7 +348,9 @@ class Channel(Queue.Queue): break except Queue.Empty: self._cv.wait() - logging.debug(self._log("{} get data succ!".format(op_name))) + logging.debug( + self._log("{} get data succ: {}".format(op_name, resp.__str__( + )))) return resp elif op_name is None: raise Exception( @@ -478,7 +541,10 @@ class Op(object): # error data in predecessor Op if error_pbdata is not None: - self._push_to_output_channels(ChannelData(pbdata=error_pbdata)) + self._push_to_output_channels( + ChannelData( + datatype=ChannelDataType.CHANNEL_PBDATA.value, + pbdata=error_pbdata)) continue # preprecess @@ -564,6 +630,7 @@ class Op(object): if self.with_serving(): # use call_future output_data = ChannelData( + datatype=ChannelDataType.CHANNEL_FUTURE.value, future=call_future, data_id=data_id, callback_func=self.postprocess) @@ -590,36 +657,10 @@ class Op(object): data_id=data_id)) continue - ecode = ChannelDataEcode.OK.value - error_info = None - pbdata = channel_pb2.ChannelData() - for name, value in postped_data.items(): - if not isinstance(name, (str, unicode)): - ecode = ChannelDataEcode.TYPE_ERROR.value - error_info = log("the key of postped_data must " \ - "be str, but get {}".format(type(name))) - break - if not isinstance(value, np.ndarray): - ecode = ChannelDataEcode.TYPE_ERROR.value - error_info = log("the value of postped_data must " \ - "be np.ndarray, but get {}".format(type(value))) - break - inst = channel_pb2.Inst() - inst.data = value.tobytes() - inst.name = name - inst.shape = np.array(value.shape, dtype="int32").tobytes() - inst.type = str(value.dtype) - pbdata.insts.append(inst) - if ecode != ChannelDataEcode.OK.value: - logging.error(error_info) - self._push_to_output_channels( - ChannelData( - ecode=ecode, error_info=error_info, - data_id=data_id)) - continue - pbdata.ecode = ecode - pbdata.id = data_id - output_data = ChannelData(pbdata=pbdata) + output_data = ChannelData( + ChannelDataType.CHANNEL_NPDATA.value, + npdata=postped_data, + data_id=data_id) _profiler.record("{}-postp_1".format(op_info_prefix)) # push data to channel (if run succ) @@ -750,34 +791,40 @@ class GeneralPythonService( pbdata = channel_pb2.ChannelData() data_id = self._get_next_id() pbdata.id = data_id - for idx, name in enumerate(request.feed_var_names): - logging.debug( - self._log('name: {}'.format(request.feed_var_names[idx]))) - logging.debug(self._log('data: {}'.format(request.feed_insts[idx]))) - inst = channel_pb2.Inst() - inst.data = request.feed_insts[idx] - inst.shape = request.shape[idx] - inst.name = name - inst.type = request.type[idx] - pbdata.insts.append(inst) - pbdata.ecode = ChannelDataEcode.OK.value #TODO: parse request error - return ChannelData(pbdata=pbdata), data_id + pbdata.ecode = ChannelDataEcode.OK.value + try: + for idx, name in enumerate(request.feed_var_names): + logging.debug( + self._log('name: {}'.format(request.feed_var_names[idx]))) + logging.debug( + self._log('data: {}'.format(request.feed_insts[idx]))) + inst = channel_pb2.Inst() + inst.data = request.feed_insts[idx] + inst.shape = request.shape[idx] + inst.name = name + inst.type = request.type[idx] + pbdata.insts.append(inst) + except Exception as e: + pbdata.ecode = ChannelDataEcode.RPC_PACKAGE_ERROR.value + pbdata.error_info = "rpc package error" + return ChannelData( + datatype=ChannelDataType.CHANNEL_PBDATA.value, + pbdata=pbdata), data_id def _pack_data_for_resp(self, channeldata): logging.debug(self._log('get channeldata')) resp = pyservice_pb2.Response() resp.ecode = channeldata.pbdata.ecode if resp.ecode == ChannelDataEcode.OK.value: - if channeldata.pbdata.type == ChannelDataType.CHANNEL_PBDATA.value: + if channeldata.datatype == ChannelDataType.CHANNEL_PBDATA.value: for inst in channeldata.pbdata.insts: resp.fetch_insts.append(inst.data) resp.fetch_var_names.append(inst.name) resp.shape.append(inst.shape) resp.type.append(inst.type) - elif channeldata.pbdata.type == ChannelDataType.CHANNEL_FUTURE.value: - feed = channeldata.future.result() - if channeldata.callback_func is not None: - feed = channeldata.callback_func(feed) + elif channeldata.datatype in (ChannelDataType.CHANNEL_FUTURE.value, + ChannelDataType.CHANNEL_NPDATA.value): + feed = channeldata.parse() for name, var in feed.items(): resp.fetch_insts.append(var.tobytes()) resp.fetch_var_names.append(name) @@ -788,7 +835,7 @@ class GeneralPythonService( else: raise TypeError( self._log("Error type({}) in pbdata.type.".format( - self.pbdata.type))) + channeldata.datatype))) else: resp.error_info = channeldata.pbdata.error_info return resp -- GitLab