diff --git a/python/paddle_serving_server/general_python_service.proto b/python/paddle_serving_server/general_python_service.proto index 5613558663f3f61b3ad8e9f1aed35cda6afa98d3..7f3af66df8d011b9a0a4fbcd9fb14a704f0c4bb2 100644 --- a/python/paddle_serving_server/general_python_service.proto +++ b/python/paddle_serving_server/general_python_service.proto @@ -26,4 +26,6 @@ message Request { message Response { repeated bytes fetch_insts = 1; repeated string fetch_var_names = 2; + required int32 is_error = 3; + optional string error_info = 4; } diff --git a/python/paddle_serving_server/pyserver.py b/python/paddle_serving_server/pyserver.py index dad372f3cceba6df6bf51e8c68d186c8e642e2ff..29ff56c422f59c9795e47ebe14402ff06634bf5b 100644 --- a/python/paddle_serving_server/pyserver.py +++ b/python/paddle_serving_server/pyserver.py @@ -28,6 +28,7 @@ import python_service_channel_pb2 import logging import random import time +import func_timeout class _TimeProfiler(object): @@ -265,7 +266,8 @@ class Op(object): client_config=None, server_name=None, fetch_names=None, - concurrency=1): + concurrency=1, + timeout=-1): self._run = False # TODO: globally unique check self._name = name # to identify the type of OP, it must be globally unique @@ -282,6 +284,7 @@ class Op(object): self._server_model = server_model self._server_port = server_port self._device = device + self._timeout = timeout def set_client(self, client_config, server_name, fetch_names): self._client = Client() @@ -344,6 +347,12 @@ class Op(object): 'Please override this method to convert data to the format in channel.' )) + def errorprocess(self, error_info): + data = python_service_channel_pb2.ChannelData() + data.is_error = 1 + data.error_info = error_info + return data + def stop(self): self._run = False @@ -354,35 +363,69 @@ class Op(object): input_data = self._input.front(self._name) _profiler.record("{}{}-get_1".format(self._name, concurrency_idx)) data_id = None + output_data = None + error_data = None logging.debug(self._log("input_data: {}".format(input_data))) if isinstance(input_data, dict): key = input_data.keys()[0] data_id = input_data[key].id + for _, data in input_data.items(): + if data.is_error != 0: + error_data = data + break else: data_id = input_data.id + if input_data.is_error != 0: + error_data = input_data - _profiler.record("{}{}-prep_0".format(self._name, concurrency_idx)) - data = self.preprocess(input_data) - _profiler.record("{}{}-prep_1".format(self._name, concurrency_idx)) - - if self.with_serving(): - _profiler.record("{}{}-midp_0".format(self._name, + if error_data is None: + _profiler.record("{}{}-prep_0".format(self._name, concurrency_idx)) - data = self.midprocess(data) - _profiler.record("{}{}-midp_1".format(self._name, + data = self.preprocess(input_data) + _profiler.record("{}{}-prep_1".format(self._name, concurrency_idx)) - _profiler.record("{}{}-postp_0".format(self._name, concurrency_idx)) - output_data = self.postprocess(data) - _profiler.record("{}{}-postp_1".format(self._name, concurrency_idx)) - - if not isinstance(output_data, - python_service_channel_pb2.ChannelData): - raise TypeError( - self._log( - 'output_data must be ChannelData type, but get {}'. - format(type(output_data)))) - output_data.id = data_id + error_info = None + if self.with_serving(): + _profiler.record("{}{}-midp_0".format(self._name, + concurrency_idx)) + if self._time > 0: + try: + data = func_timeout.func_timeout( + self._time, self.midprocess, args=(data, )) + except func_timeout.FunctionTimedOut: + logging.error("error: timeout") + error_info = "{}({}): timeout".format( + self._name, concurrency_idx) + except Exception as e: + logging.error("error: {}".format(e)) + error_info = "{}({}): {}".format(self._name, + concurrency_idx, e) + else: + data = self.midprocess(data) + _profiler.record("{}{}-midp_1".format(self._name, + concurrency_idx)) + + _profiler.record("{}{}-postp_0".format(self._name, + concurrency_idx)) + if error_info is not None: + output_data = self.errorprocess(error_info) + else: + output_data = self.postprocess(data) + + if not isinstance(output_data, + python_service_channel_pb2.ChannelData): + raise TypeError( + self._log( + 'output_data must be ChannelData type, but get {}'. + format(type(output_data)))) + output_data.is_error = 0 + _profiler.record("{}{}-postp_1".format(self._name, + concurrency_idx)) + + output_data.id = data_id + else: + output_data = error_data _profiler.record("{}{}-push_0".format(self._name, concurrency_idx)) for channel in self._outputs: @@ -398,7 +441,7 @@ class Op(object): class GeneralPythonService( general_python_service_pb2_grpc.GeneralPythonService): - def __init__(self, in_channel, out_channel): + def __init__(self, in_channel, out_channel, retry=2): super(GeneralPythonService, self).__init__() self._name = "#G" self.set_in_channel(in_channel) @@ -412,6 +455,7 @@ class GeneralPythonService( self._cv = threading.Condition() self._globel_resp_dict = {} self._id_counter = 0 + self._retry = retry self._recive_func = threading.Thread( target=GeneralPythonService._recive_out_channel_func, args=(self, )) self._recive_func.start() @@ -480,11 +524,15 @@ class GeneralPythonService( resp = general_python_service_pb2.Response() logging.debug(self._log('gen resp')) logging.debug(data) - for inst in data.insts: - logging.debug(self._log('append data')) - resp.fetch_insts.append(inst.data) - logging.debug(self._log('append name')) - resp.fetch_var_names.append(inst.name) + resp.is_error = data.is_error + if resp.is_error == 0: + for inst in data.insts: + logging.debug(self._log('append data')) + resp.fetch_insts.append(inst.data) + logging.debug(self._log('append name')) + resp.fetch_var_names.append(inst.name) + else: + resp.error_info = data.error_info return resp def inference(self, request, context): @@ -492,16 +540,21 @@ class GeneralPythonService( data, data_id = self._pack_data_for_infer(request) _profiler.record("{}-prepack_1".format(self._name)) - logging.debug(self._log('push data')) - _profiler.record("{}-push_0".format(self._name)) - self._in_channel.push(data, self._name) - _profiler.record("{}-push_1".format(self._name)) + for i in range(self._retry): + logging.debug(self._log('push data')) + _profiler.record("{}-push_0".format(self._name)) + self._in_channel.push(data, self._name) + _profiler.record("{}-push_1".format(self._name)) + + logging.debug(self._log('wait for infer')) + resp_data = None + _profiler.record("{}-fetch_0".format(self._name)) + resp_data = self._get_data_in_globel_resp_dict(data_id) + _profiler.record("{}-fetch_1".format(self._name)) - logging.debug(self._log('wait for infer')) - resp_data = None - _profiler.record("{}-fetch_0".format(self._name)) - resp_data = self._get_data_in_globel_resp_dict(data_id) - _profiler.record("{}-fetch_1".format(self._name)) + if resp_data.is_error == 0: + break + logging.warn("retry({}): {}".format(i + 1, resp_data.error_info)) _profiler.record("{}-postpack_0".format(self._name)) resp = self._pack_data_for_resp(resp_data) @@ -511,7 +564,7 @@ class GeneralPythonService( class PyServer(object): - def __init__(self, profile=False): + def __init__(self, retry=2, profile=False): self._channels = [] self._ops = [] self._op_threads = [] @@ -519,6 +572,7 @@ class PyServer(object): self._worker_num = None self._in_channel = None self._out_channel = None + self._retry = retry _profiler.enable(profile) def add_channel(self, channel): @@ -571,7 +625,8 @@ class PyServer(object): server = grpc.server( futures.ThreadPoolExecutor(max_workers=self._worker_num)) general_python_service_pb2_grpc.add_GeneralPythonServiceServicer_to_server( - GeneralPythonService(self._in_channel, self._out_channel), server) + GeneralPythonService(self._in_channel, self._out_channel, + self._retry), server) server.add_insecure_port('[::]:{}'.format(self._port)) server.start() try: diff --git a/python/paddle_serving_server/python_service_channel.proto b/python/paddle_serving_server/python_service_channel.proto index e61a17b9c528d99f5711247ecff130410df06e76..76a0d99c5cfb9f34e7478e66a89c416f135b73c1 100644 --- a/python/paddle_serving_server/python_service_channel.proto +++ b/python/paddle_serving_server/python_service_channel.proto @@ -19,6 +19,8 @@ message ChannelData { required int32 id = 2; optional string type = 3 [ default = "CD" ]; // CD(channel data), CF(channel futures) + required int32 is_error = 4; + optional string error_info = 5; } message Inst {