提交 45d47d4a 编写于 作者: B barrierye

add timeout

上级 96472347
...@@ -26,4 +26,6 @@ message Request { ...@@ -26,4 +26,6 @@ message Request {
message Response { message Response {
repeated bytes fetch_insts = 1; repeated bytes fetch_insts = 1;
repeated string fetch_var_names = 2; repeated string fetch_var_names = 2;
required int32 is_error = 3;
optional string error_info = 4;
} }
...@@ -28,6 +28,7 @@ import python_service_channel_pb2 ...@@ -28,6 +28,7 @@ import python_service_channel_pb2
import logging import logging
import random import random
import time import time
import func_timeout
class _TimeProfiler(object): class _TimeProfiler(object):
...@@ -265,7 +266,8 @@ class Op(object): ...@@ -265,7 +266,8 @@ class Op(object):
client_config=None, client_config=None,
server_name=None, server_name=None,
fetch_names=None, fetch_names=None,
concurrency=1): concurrency=1,
timeout=-1):
self._run = False self._run = False
# TODO: globally unique check # TODO: globally unique check
self._name = name # to identify the type of OP, it must be globally unique self._name = name # to identify the type of OP, it must be globally unique
...@@ -282,6 +284,7 @@ class Op(object): ...@@ -282,6 +284,7 @@ class Op(object):
self._server_model = server_model self._server_model = server_model
self._server_port = server_port self._server_port = server_port
self._device = device self._device = device
self._timeout = timeout
def set_client(self, client_config, server_name, fetch_names): def set_client(self, client_config, server_name, fetch_names):
self._client = Client() self._client = Client()
...@@ -344,6 +347,12 @@ class Op(object): ...@@ -344,6 +347,12 @@ class Op(object):
'Please override this method to convert data to the format in channel.' 'Please override this method to convert data to the format in channel.'
)) ))
def errorprocess(self, error_info):
data = python_service_channel_pb2.ChannelData()
data.is_error = 1
data.error_info = error_info
return data
def stop(self): def stop(self):
self._run = False self._run = False
...@@ -354,27 +363,55 @@ class Op(object): ...@@ -354,27 +363,55 @@ class Op(object):
input_data = self._input.front(self._name) input_data = self._input.front(self._name)
_profiler.record("{}{}-get_1".format(self._name, concurrency_idx)) _profiler.record("{}{}-get_1".format(self._name, concurrency_idx))
data_id = None data_id = None
output_data = None
error_data = None
logging.debug(self._log("input_data: {}".format(input_data))) logging.debug(self._log("input_data: {}".format(input_data)))
if isinstance(input_data, dict): if isinstance(input_data, dict):
key = input_data.keys()[0] key = input_data.keys()[0]
data_id = input_data[key].id data_id = input_data[key].id
for _, data in input_data.items():
if data.is_error != 0:
error_data = data
break
else: else:
data_id = input_data.id data_id = input_data.id
if input_data.is_error != 0:
error_data = input_data
_profiler.record("{}{}-prep_0".format(self._name, concurrency_idx)) if error_data is None:
_profiler.record("{}{}-prep_0".format(self._name,
concurrency_idx))
data = self.preprocess(input_data) data = self.preprocess(input_data)
_profiler.record("{}{}-prep_1".format(self._name, concurrency_idx)) _profiler.record("{}{}-prep_1".format(self._name,
concurrency_idx))
error_info = None
if self.with_serving(): if self.with_serving():
_profiler.record("{}{}-midp_0".format(self._name, _profiler.record("{}{}-midp_0".format(self._name,
concurrency_idx)) concurrency_idx))
if self._time > 0:
try:
data = func_timeout.func_timeout(
self._time, self.midprocess, args=(data, ))
except func_timeout.FunctionTimedOut:
logging.error("error: timeout")
error_info = "{}({}): timeout".format(
self._name, concurrency_idx)
except Exception as e:
logging.error("error: {}".format(e))
error_info = "{}({}): {}".format(self._name,
concurrency_idx, e)
else:
data = self.midprocess(data) data = self.midprocess(data)
_profiler.record("{}{}-midp_1".format(self._name, _profiler.record("{}{}-midp_1".format(self._name,
concurrency_idx)) concurrency_idx))
_profiler.record("{}{}-postp_0".format(self._name, concurrency_idx)) _profiler.record("{}{}-postp_0".format(self._name,
concurrency_idx))
if error_info is not None:
output_data = self.errorprocess(error_info)
else:
output_data = self.postprocess(data) output_data = self.postprocess(data)
_profiler.record("{}{}-postp_1".format(self._name, concurrency_idx))
if not isinstance(output_data, if not isinstance(output_data,
python_service_channel_pb2.ChannelData): python_service_channel_pb2.ChannelData):
...@@ -382,7 +419,13 @@ class Op(object): ...@@ -382,7 +419,13 @@ class Op(object):
self._log( self._log(
'output_data must be ChannelData type, but get {}'. 'output_data must be ChannelData type, but get {}'.
format(type(output_data)))) format(type(output_data))))
output_data.is_error = 0
_profiler.record("{}{}-postp_1".format(self._name,
concurrency_idx))
output_data.id = data_id output_data.id = data_id
else:
output_data = error_data
_profiler.record("{}{}-push_0".format(self._name, concurrency_idx)) _profiler.record("{}{}-push_0".format(self._name, concurrency_idx))
for channel in self._outputs: for channel in self._outputs:
...@@ -398,7 +441,7 @@ class Op(object): ...@@ -398,7 +441,7 @@ class Op(object):
class GeneralPythonService( class GeneralPythonService(
general_python_service_pb2_grpc.GeneralPythonService): general_python_service_pb2_grpc.GeneralPythonService):
def __init__(self, in_channel, out_channel): def __init__(self, in_channel, out_channel, retry=2):
super(GeneralPythonService, self).__init__() super(GeneralPythonService, self).__init__()
self._name = "#G" self._name = "#G"
self.set_in_channel(in_channel) self.set_in_channel(in_channel)
...@@ -412,6 +455,7 @@ class GeneralPythonService( ...@@ -412,6 +455,7 @@ class GeneralPythonService(
self._cv = threading.Condition() self._cv = threading.Condition()
self._globel_resp_dict = {} self._globel_resp_dict = {}
self._id_counter = 0 self._id_counter = 0
self._retry = retry
self._recive_func = threading.Thread( self._recive_func = threading.Thread(
target=GeneralPythonService._recive_out_channel_func, args=(self, )) target=GeneralPythonService._recive_out_channel_func, args=(self, ))
self._recive_func.start() self._recive_func.start()
...@@ -480,11 +524,15 @@ class GeneralPythonService( ...@@ -480,11 +524,15 @@ class GeneralPythonService(
resp = general_python_service_pb2.Response() resp = general_python_service_pb2.Response()
logging.debug(self._log('gen resp')) logging.debug(self._log('gen resp'))
logging.debug(data) logging.debug(data)
resp.is_error = data.is_error
if resp.is_error == 0:
for inst in data.insts: for inst in data.insts:
logging.debug(self._log('append data')) logging.debug(self._log('append data'))
resp.fetch_insts.append(inst.data) resp.fetch_insts.append(inst.data)
logging.debug(self._log('append name')) logging.debug(self._log('append name'))
resp.fetch_var_names.append(inst.name) resp.fetch_var_names.append(inst.name)
else:
resp.error_info = data.error_info
return resp return resp
def inference(self, request, context): def inference(self, request, context):
...@@ -492,6 +540,7 @@ class GeneralPythonService( ...@@ -492,6 +540,7 @@ class GeneralPythonService(
data, data_id = self._pack_data_for_infer(request) data, data_id = self._pack_data_for_infer(request)
_profiler.record("{}-prepack_1".format(self._name)) _profiler.record("{}-prepack_1".format(self._name))
for i in range(self._retry):
logging.debug(self._log('push data')) logging.debug(self._log('push data'))
_profiler.record("{}-push_0".format(self._name)) _profiler.record("{}-push_0".format(self._name))
self._in_channel.push(data, self._name) self._in_channel.push(data, self._name)
...@@ -503,6 +552,10 @@ class GeneralPythonService( ...@@ -503,6 +552,10 @@ class GeneralPythonService(
resp_data = self._get_data_in_globel_resp_dict(data_id) resp_data = self._get_data_in_globel_resp_dict(data_id)
_profiler.record("{}-fetch_1".format(self._name)) _profiler.record("{}-fetch_1".format(self._name))
if resp_data.is_error == 0:
break
logging.warn("retry({}): {}".format(i + 1, resp_data.error_info))
_profiler.record("{}-postpack_0".format(self._name)) _profiler.record("{}-postpack_0".format(self._name))
resp = self._pack_data_for_resp(resp_data) resp = self._pack_data_for_resp(resp_data)
_profiler.record("{}-postpack_1".format(self._name)) _profiler.record("{}-postpack_1".format(self._name))
...@@ -511,7 +564,7 @@ class GeneralPythonService( ...@@ -511,7 +564,7 @@ class GeneralPythonService(
class PyServer(object): class PyServer(object):
def __init__(self, profile=False): def __init__(self, retry=2, profile=False):
self._channels = [] self._channels = []
self._ops = [] self._ops = []
self._op_threads = [] self._op_threads = []
...@@ -519,6 +572,7 @@ class PyServer(object): ...@@ -519,6 +572,7 @@ class PyServer(object):
self._worker_num = None self._worker_num = None
self._in_channel = None self._in_channel = None
self._out_channel = None self._out_channel = None
self._retry = retry
_profiler.enable(profile) _profiler.enable(profile)
def add_channel(self, channel): def add_channel(self, channel):
...@@ -571,7 +625,8 @@ class PyServer(object): ...@@ -571,7 +625,8 @@ class PyServer(object):
server = grpc.server( server = grpc.server(
futures.ThreadPoolExecutor(max_workers=self._worker_num)) futures.ThreadPoolExecutor(max_workers=self._worker_num))
general_python_service_pb2_grpc.add_GeneralPythonServiceServicer_to_server( general_python_service_pb2_grpc.add_GeneralPythonServiceServicer_to_server(
GeneralPythonService(self._in_channel, self._out_channel), server) GeneralPythonService(self._in_channel, self._out_channel,
self._retry), server)
server.add_insecure_port('[::]:{}'.format(self._port)) server.add_insecure_port('[::]:{}'.format(self._port))
server.start() server.start()
try: try:
......
...@@ -19,6 +19,8 @@ message ChannelData { ...@@ -19,6 +19,8 @@ message ChannelData {
required int32 id = 2; required int32 id = 2;
optional string type = 3 optional string type = 3
[ default = "CD" ]; // CD(channel data), CF(channel futures) [ default = "CD" ]; // CD(channel data), CF(channel futures)
required int32 is_error = 4;
optional string error_info = 5;
} }
message Inst { message Inst {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册