diff --git a/python/examples/imdb/test_py_server.py b/python/examples/imdb/test_py_server.py index 28a8796d74f57ee0b97781c0b8857cd24df92dcf..e2e8de97987dfcda13ecf9de36653b647e10acb5 100644 --- a/python/examples/imdb/test_py_server.py +++ b/python/examples/imdb/test_py_server.py @@ -27,6 +27,8 @@ logging.basicConfig( class CombineOp(Op): + pass + ''' def preprocess(self, input_data): combined_prediction = 0 for op_name, channeldata in input_data.items(): @@ -35,6 +37,7 @@ class CombineOp(Op): combined_prediction += data["prediction"] data = {"combined_prediction": combined_prediction / 2} return data + ''' read_op = Op(name="read", inputs=None) @@ -47,7 +50,7 @@ bow_op = Op(name="bow", server_name="127.0.0.1:9393", fetch_names=["prediction"], concurrency=1, - timeout=0.01, + timeout=0.1, retry=2) cnn_op = Op(name="cnn", inputs=[read_op], diff --git a/python/paddle_serving_client/pyclient.py b/python/paddle_serving_client/pyclient.py index 38c96705902ea23462b2fa415102b4a1abf00721..90682e2a43a86830fa78c56908dfe6e8b014717d 100644 --- a/python/paddle_serving_client/pyclient.py +++ b/python/paddle_serving_client/pyclient.py @@ -49,6 +49,8 @@ class PyClient(object): "fetch_with_type must be list type with format: [name].") req = self._pack_data_for_infer(feed) resp = self._stub.inference(req) + if resp.ecode != 0: + raise Exception(resp.error_info) fetch_map = {} for idx, name in enumerate(resp.fetch_var_names): if name not in fetch: diff --git a/python/paddle_serving_server/pyserver.py b/python/paddle_serving_server/pyserver.py index bf781d37f69bf54aac2e7688bf1b379fcd180e85..9f77b8506c82ed1b32728eacf884347f19dfe608 100644 --- a/python/paddle_serving_server/pyserver.py +++ b/python/paddle_serving_server/pyserver.py @@ -77,6 +77,9 @@ _profiler = _TimeProfiler() class ChannelDataEcode(enum.Enum): OK = 0 TIMEOUT = 1 + NOT_IMPLEMENTED = 2 + TYPE_ERROR = 3 + UNKNOW = 4 class ChannelDataType(enum.Enum): @@ -89,15 +92,37 @@ class ChannelData(object): future=None, pbdata=None, data_id=None, - callback_func=None): - self.future = future - if pbdata is None: - if data_id is None: - raise ValueError("data_id cannot be None") + callback_func=None, + ecode=None, + error_info=None): + ''' + There are several ways to use it: + + - ChannelData(future, pbdata[, callback_func]) + - ChannelData(future, data_id[, callback_func]) + - ChannelData(pbdata) + - ChannelData(ecode, error_info, data_id) + ''' + if ecode is not None: + if data_id is None or error_info is None: + raise ValueError("data_id and error_info cannot be None") pbdata = channel_pb2.ChannelData() - pbdata.type = ChannelDataType.CHANNEL_FUTURE.value - pbdata.ecode = ChannelDataEcode.OK.value + pbdata.ecode = ecode pbdata.id = data_id + pbdata.error_info = error_info + else: + if pbdata is None: + if data_id is None: + raise ValueError("data_id cannot be None") + pbdata = channel_pb2.ChannelData() + pbdata.type = ChannelDataType.CHANNEL_FUTURE.value + pbdata.ecode = ChannelDataEcode.OK.value + pbdata.id = data_id + elif not isinstance(pbdata, channel_pb2.ChannelData): + raise TypeError( + "pbdata must be pyserving_channel_pb2.ChannelData type({})". + format(type(pbdata))) + self.future = future self.pbdata = pbdata self.callback_func = callback_func @@ -389,10 +414,9 @@ class Op(object): def preprocess(self, channeldata): if isinstance(channeldata, dict): - raise Exception( - self._log( - 'this Op has multiple previous inputs. Please override this method' - )) + raise NotImplementedError( + 'this Op has multiple previous inputs. Please override this method' + ) feed = channeldata.parse() return feed @@ -412,13 +436,6 @@ class Op(object): def postprocess(self, output_data): return output_data - def errorprocess(self, error_info, data_id): - data = channel_pb2.ChannelData() - data.ecode = 1 - data.id = data_id - data.error_info = error_info - return data - def stop(self): self._input.stop() for channel in self._outputs: @@ -433,7 +450,7 @@ class Op(object): data_id = channeldata[key].pbdata.id for _, data in channeldata.items(): if data.pbdata.ecode != 0: - error_data = data + error_data = data.pbdata break else: data_id = channeldata.pbdata.id @@ -441,101 +458,130 @@ class Op(object): error_data = channeldata.pbdata return data_id, error_data + def _push_to_output_channels(self, data): + for channel in self._outputs: + channel.push(data, self.name) + def start(self, concurrency_idx): + op_info_prefix = "[{}{}]".format(self.name, concurrency_idx) + log = self._get_log_func(op_info_prefix) self._run = True while self._run: - _profiler.record("{}{}-get_0".format(self.name, concurrency_idx)) + _profiler.record("{}-get_0".format(op_info_prefix)) input_data = self._input.front(self.name) - _profiler.record("{}{}-get_1".format(self.name, concurrency_idx)) - logging.debug(self._log("input_data: {}".format(input_data))) + _profiler.record("{}-get_1".format(op_info_prefix)) + logging.debug(log("input_data: {}".format(input_data))) data_id, error_data = self._parse_channeldata(input_data) - output_data = None - if error_data is None: - _profiler.record("{}{}-prep_0".format(self.name, - concurrency_idx)) + # predecessor Op error + if error_data is not None: + self._push_to_output_channels(ChannelData(pbdata=error_data)) + continue + + # preprocess function not implemented + try: + _profiler.record("{}-prep_0".format(op_info_prefix)) data = self.preprocess(input_data) - _profiler.record("{}{}-prep_1".format(self.name, - concurrency_idx)) - - call_future = None - error_info = None - if self.with_serving(): - _profiler.record("{}{}-midp_0".format(self.name, - concurrency_idx)) - if self._timeout > 0: - for i in range(self._retry): - try: - call_future = func_timeout.func_timeout( - self._timeout, - self.midprocess, - args=(data, )) - except func_timeout.FunctionTimedOut: - logging.error("error: timeout") - error_info = "{}({}): timeout".format( - self.name, concurrency_idx) - if i + 1 < self._retry: - error_info = None - logging.warn( - self._log("warn: timeout, retry({})". - format(i + 1))) - except Exception as e: - logging.error("error: {}".format(e)) - error_info = "{}({}): {}".format( - self.name, concurrency_idx, e) - logging.warn(self._log(e)) - # TODO - break - else: - break - else: - call_future = self.midprocess(data) + _profiler.record("{}-prep_1".format(op_info_prefix)) + except NotImplementedError as e: + error_info = log(e) + logging.error(error_info) + self._push_to_output_channels( + ChannelData( + ecode=ChannelDataEcode.NOT_IMPLEMENTED.value, + error_info=error_info, + data_id=data_id)) + continue - _profiler.record("{}{}-midp_1".format(self.name, - concurrency_idx)) - _profiler.record("{}{}-postp_0".format(self.name, - concurrency_idx)) - if error_info is not None: - error_data = self.errorprocess(error_info, data_id) - output_data = ChannelData(pbdata=error_data) + # midprocess + call_future = None + ecode = 0 + error_info = None + if self.with_serving(): + _profiler.record("{}-midp_0".format(op_info_prefix)) + if self._timeout <= 0: + try: + call_future = self.midprocess(data) + except Exception as e: + logging.error(self._log(e)) + ecode = ChannelDataEcode.UNKNOW.value + error_info = log(e) + logging.error(error_info) else: - if self.with_serving(): # use call_future - output_data = ChannelData( - future=call_future, - data_id=data_id, - callback_func=self.postprocess) - else: - post_data = self.postprocess(data) - if not isinstance(post_data, dict): - raise TypeError( - self._log( - 'output_data must be dict type, but get {}'. - format(type(output_data)))) - pbdata = channel_pb2.ChannelData() - for name, value in post_data.items(): - inst = channel_pb2.Inst() - inst.data = value.tobytes() - inst.name = name - inst.shape = np.array( - value.shape, dtype="int32").tobytes() - inst.type = str(value.dtype) - pbdata.insts.append(inst) - pbdata.ecode = 0 - pbdata.id = data_id - output_data = ChannelData(pbdata=pbdata) - _profiler.record("{}{}-postp_1".format(self.name, - concurrency_idx)) - else: - output_data = ChannelData(pbdata=error_data) - - _profiler.record("{}{}-push_0".format(self.name, concurrency_idx)) - for channel in self._outputs: - channel.push(output_data, self.name) - _profiler.record("{}{}-push_1".format(self.name, concurrency_idx)) + for i in range(self._retry): + try: + call_future = func_timeout.func_timeout( + self._timeout, self.midprocess, args=(data, )) + except func_timeout.FunctionTimedOut: + if i + 1 >= self._retry: + ecode = ChannelDataEcode.TIMEOUT.value + error_info = "{} timeout".format(op_info_prefix) + else: + logging.warn( + log("warn: timeout, retry({})".format(i + + 1))) + except Exception as e: + ecode = ChannelDataEcode.UNKNOW.value + error_info = log(e) + logging.error(error_info) + break + else: + break + if ecode != 0: + self._push_to_output_channels( + ChannelData( + ecode=ecode, error_info=error_info, + data_id=data_id)) + continue + _profiler.record("{}-midp_1".format(op_info_prefix)) - def _log(self, info_str): - return "[{}] {}".format(self.name, info_str) + # postprocess + output_data = None + _profiler.record("{}-postp_0".format(op_info_prefix)) + if self.with_serving(): # use call_future + output_data = ChannelData( + future=call_future, + data_id=data_id, + callback_func=self.postprocess) + else: + post_data = self.postprocess(data) + if not isinstance(post_data, dict): + ecode = ChannelDataEcode.TYPE_ERROR.value + error_info = log("output of postprocess funticon must be " \ + "dict type, but get {}".format(type(post_data))) + logging.error(error_info) + self._push_to_output_channels( + ChannelData( + ecode=ecode, error_info=error_info, + data_id=data_id)) + continue + pbdata = channel_pb2.ChannelData() + for name, value in post_data.items(): + inst = channel_pb2.Inst() + inst.data = value.tobytes() + inst.name = name + inst.shape = np.array(value.shape, dtype="int32").tobytes() + inst.type = str(value.dtype) + pbdata.insts.append(inst) + pbdata.ecode = 0 + pbdata.id = data_id + output_data = ChannelData(pbdata=pbdata) + _profiler.record("{}-postp_1".format(op_info_prefix)) + + # push data to channel (if run succ) + _profiler.record("{}-push_0".format(op_info_prefix)) + self._push_to_output_channels(output_data) + _profiler.record("{}-push_1".format(op_info_prefix)) + + def _log(self, info): + return "{} {}".format(self.name, info) + + def _get_log_func(self, op_info_prefix): + def log_func(info_str): + return "{} {}".format(op_info_prefix, info_str) + + return log_func def get_concurrency(self): return self._concurrency @@ -682,8 +728,9 @@ class GeneralPythonService( if resp_channeldata.pbdata.ecode == 0: break - logging.warn("retry({}): {}".format( - i + 1, resp_channeldata.pbdata.error_info)) + if i + 1 < self._retry: + logging.warn("retry({}): {}".format( + i + 1, resp_channeldata.pbdata.error_info)) _profiler.record("{}-postpack_0".format(self.name)) resp = self._pack_data_for_resp(resp_channeldata) @@ -874,8 +921,6 @@ class PyServer(object): server.start() server.wait_for_termination() self._stop_ops() # TODO - for th in self._op_threads: - th.join() def prepare_serving(self, op): model_path = op._server_model