提交 45d47d4a 编写于 作者: B barrierye

add timeout

上级 96472347
......@@ -26,4 +26,6 @@ message Request {
// Reply returned by the general Python service for one inference request.
message Response {
// Raw payload of each fetched instance; filled in the same order as
// fetch_var_names.
repeated bytes fetch_insts = 1;
// Name of each fetched instance, parallel to fetch_insts.
repeated string fetch_var_names = 2;
// 0 means the request succeeded; any other value marks a failure, in which
// case the fetch_* fields are left empty by the server.
required int32 is_error = 3;
// Human-readable failure description; only set when is_error is non-zero.
optional string error_info = 4;
}
......@@ -28,6 +28,7 @@ import python_service_channel_pb2
import logging
import random
import time
import func_timeout
class _TimeProfiler(object):
......@@ -265,7 +266,8 @@ class Op(object):
client_config=None,
server_name=None,
fetch_names=None,
concurrency=1):
concurrency=1,
timeout=-1):
self._run = False
# TODO: globally unique check
self._name = name # to identify the type of OP, it must be globally unique
......@@ -282,6 +284,7 @@ class Op(object):
self._server_model = server_model
self._server_port = server_port
self._device = device
self._timeout = timeout
def set_client(self, client_config, server_name, fetch_names):
self._client = Client()
......@@ -344,6 +347,12 @@ class Op(object):
'Please override this method to convert data to the format in channel.'
))
def errorprocess(self, error_info):
    """Build a ChannelData message that reports a failure.

    Args:
        error_info: text describing what went wrong; stored verbatim in the
            message's ``error_info`` field.

    Returns:
        A new ``ChannelData`` with ``is_error`` set to 1. The ``id`` field is
        not assigned here, so the caller has to fill it in.
    """
    failure = python_service_channel_pb2.ChannelData()
    # A non-zero is_error flag marks the message as an error.
    failure.is_error = 1
    failure.error_info = error_info
    return failure
def stop(self):
self._run = False
......@@ -354,35 +363,69 @@ class Op(object):
input_data = self._input.front(self._name)
_profiler.record("{}{}-get_1".format(self._name, concurrency_idx))
data_id = None
output_data = None
error_data = None
logging.debug(self._log("input_data: {}".format(input_data)))
if isinstance(input_data, dict):
key = input_data.keys()[0]
data_id = input_data[key].id
for _, data in input_data.items():
if data.is_error != 0:
error_data = data
break
else:
data_id = input_data.id
if input_data.is_error != 0:
error_data = input_data
_profiler.record("{}{}-prep_0".format(self._name, concurrency_idx))
data = self.preprocess(input_data)
_profiler.record("{}{}-prep_1".format(self._name, concurrency_idx))
if self.with_serving():
_profiler.record("{}{}-midp_0".format(self._name,
if error_data is None:
_profiler.record("{}{}-prep_0".format(self._name,
concurrency_idx))
data = self.midprocess(data)
_profiler.record("{}{}-midp_1".format(self._name,
data = self.preprocess(input_data)
_profiler.record("{}{}-prep_1".format(self._name,
concurrency_idx))
_profiler.record("{}{}-postp_0".format(self._name, concurrency_idx))
output_data = self.postprocess(data)
_profiler.record("{}{}-postp_1".format(self._name, concurrency_idx))
if not isinstance(output_data,
python_service_channel_pb2.ChannelData):
raise TypeError(
self._log(
'output_data must be ChannelData type, but get {}'.
format(type(output_data))))
output_data.id = data_id
error_info = None
if self.with_serving():
    # Run the serving step (midprocess), bracketed by profiler marks.
    _profiler.record("{}{}-midp_0".format(self._name, concurrency_idx))
    # The time limit is stored on the instance as `_timeout`; reading any
    # other attribute name here would raise AttributeError on every request.
    # A non-positive value means "no limit".
    if self._timeout > 0:
        try:
            data = func_timeout.func_timeout(
                self._timeout, self.midprocess, args=(data, ))
        except func_timeout.FunctionTimedOut:
            # midprocess did not finish in time: remember the failure so an
            # error message is produced instead of calling postprocess.
            logging.error("error: timeout")
            error_info = "{}({}): timeout".format(self._name,
                                                  concurrency_idx)
        except Exception as e:
            # Any other failure inside midprocess is reported the same way.
            logging.error("error: {}".format(e))
            error_info = "{}({}): {}".format(self._name, concurrency_idx, e)
    else:
        # No limit configured: call midprocess directly; exceptions propagate.
        data = self.midprocess(data)
    _profiler.record("{}{}-midp_1".format(self._name, concurrency_idx))
_profiler.record("{}{}-postp_0".format(self._name,
concurrency_idx))
if error_info is not None:
output_data = self.errorprocess(error_info)
else:
output_data = self.postprocess(data)
if not isinstance(output_data,
python_service_channel_pb2.ChannelData):
raise TypeError(
self._log(
'output_data must be ChannelData type, but get {}'.
format(type(output_data))))
output_data.is_error = 0
_profiler.record("{}{}-postp_1".format(self._name,
concurrency_idx))
output_data.id = data_id
else:
output_data = error_data
_profiler.record("{}{}-push_0".format(self._name, concurrency_idx))
for channel in self._outputs:
......@@ -398,7 +441,7 @@ class Op(object):
class GeneralPythonService(
general_python_service_pb2_grpc.GeneralPythonService):
def __init__(self, in_channel, out_channel):
def __init__(self, in_channel, out_channel, retry=2):
super(GeneralPythonService, self).__init__()
self._name = "#G"
self.set_in_channel(in_channel)
......@@ -412,6 +455,7 @@ class GeneralPythonService(
self._cv = threading.Condition()
self._globel_resp_dict = {}
self._id_counter = 0
self._retry = retry
self._recive_func = threading.Thread(
target=GeneralPythonService._recive_out_channel_func, args=(self, ))
self._recive_func.start()
......@@ -480,11 +524,15 @@ class GeneralPythonService(
resp = general_python_service_pb2.Response()
logging.debug(self._log('gen resp'))
logging.debug(data)
for inst in data.insts:
logging.debug(self._log('append data'))
resp.fetch_insts.append(inst.data)
logging.debug(self._log('append name'))
resp.fetch_var_names.append(inst.name)
resp.is_error = data.is_error
if resp.is_error == 0:
for inst in data.insts:
logging.debug(self._log('append data'))
resp.fetch_insts.append(inst.data)
logging.debug(self._log('append name'))
resp.fetch_var_names.append(inst.name)
else:
resp.error_info = data.error_info
return resp
def inference(self, request, context):
......@@ -492,16 +540,21 @@ class GeneralPythonService(
data, data_id = self._pack_data_for_infer(request)
_profiler.record("{}-prepack_1".format(self._name))
logging.debug(self._log('push data'))
_profiler.record("{}-push_0".format(self._name))
self._in_channel.push(data, self._name)
_profiler.record("{}-push_1".format(self._name))
for i in range(self._retry):
logging.debug(self._log('push data'))
_profiler.record("{}-push_0".format(self._name))
self._in_channel.push(data, self._name)
_profiler.record("{}-push_1".format(self._name))
logging.debug(self._log('wait for infer'))
resp_data = None
_profiler.record("{}-fetch_0".format(self._name))
resp_data = self._get_data_in_globel_resp_dict(data_id)
_profiler.record("{}-fetch_1".format(self._name))
logging.debug(self._log('wait for infer'))
resp_data = None
_profiler.record("{}-fetch_0".format(self._name))
resp_data = self._get_data_in_globel_resp_dict(data_id)
_profiler.record("{}-fetch_1".format(self._name))
if resp_data.is_error == 0:
break
logging.warn("retry({}): {}".format(i + 1, resp_data.error_info))
_profiler.record("{}-postpack_0".format(self._name))
resp = self._pack_data_for_resp(resp_data)
......@@ -511,7 +564,7 @@ class GeneralPythonService(
class PyServer(object):
def __init__(self, profile=False):
def __init__(self, retry=2, profile=False):
self._channels = []
self._ops = []
self._op_threads = []
......@@ -519,6 +572,7 @@ class PyServer(object):
self._worker_num = None
self._in_channel = None
self._out_channel = None
self._retry = retry
_profiler.enable(profile)
def add_channel(self, channel):
......@@ -571,7 +625,8 @@ class PyServer(object):
server = grpc.server(
futures.ThreadPoolExecutor(max_workers=self._worker_num))
general_python_service_pb2_grpc.add_GeneralPythonServiceServicer_to_server(
GeneralPythonService(self._in_channel, self._out_channel), server)
GeneralPythonService(self._in_channel, self._out_channel,
self._retry), server)
server.add_insecure_port('[::]:{}'.format(self._port))
server.start()
try:
......
......@@ -19,6 +19,8 @@ message ChannelData {
required int32 id = 2;
optional string type = 3
[ default = "CD" ]; // CD(channel data), CF(channel futures)
required int32 is_error = 4;
optional string error_info = 5;
}
message Inst {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册