Commit a2a8409b authored by barrierye

add some error code

Parent 5de538d5
@@ -27,6 +27,8 @@ logging.basicConfig(
class CombineOp(Op):
pass
'''
def preprocess(self, input_data):
combined_prediction = 0
for op_name, channeldata in input_data.items():
@@ -35,6 +37,7 @@ class CombineOp(Op):
combined_prediction += data["prediction"]
data = {"combined_prediction": combined_prediction / 2}
return data
'''
read_op = Op(name="read", inputs=None)
@@ -47,7 +50,7 @@ bow_op = Op(name="bow",
server_name="127.0.0.1:9393",
fetch_names=["prediction"],
concurrency=1,
timeout=0.01,
timeout=0.1,
retry=2)
cnn_op = Op(name="cnn",
inputs=[read_op],
......
@@ -49,6 +49,8 @@ class PyClient(object):
"fetch_with_type must be list type with format: [name].")
req = self._pack_data_for_infer(feed)
resp = self._stub.inference(req)
if resp.ecode != 0:
raise Exception(resp.error_info)
fetch_map = {}
for idx, name in enumerate(resp.fetch_var_names):
if name not in fetch:
......
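For context, a hypothetical caller-side sketch (not part of this diff) of how the new ecode check surfaces to users: the import path, the connect() call, the predict() method name, and the feed contents are all assumptions, since only the response-handling fragment is shown above.

from pyclient import PyClient  # hypothetical import path

client = PyClient()
client.connect("127.0.0.1:8080")  # hypothetical service address
try:
    # predict() is assumed to be the public wrapper around the
    # _pack_data_for_infer / _stub.inference code shown above
    fetch_map = client.predict(feed={"words": [12, 7, 93]}, fetch=["prediction"])
except Exception as e:
    # with this change, a non-zero resp.ecode is raised here as resp.error_info
    print("server-side error: {}".format(e))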
@@ -77,6 +77,9 @@ _profiler = _TimeProfiler()
class ChannelDataEcode(enum.Enum):
OK = 0
TIMEOUT = 1
NOT_IMPLEMENTED = 2
TYPE_ERROR = 3
UNKNOW = 4
class ChannelDataType(enum.Enum):
@@ -89,15 +92,37 @@ class ChannelData(object):
future=None,
pbdata=None,
data_id=None,
callback_func=None):
self.future = future
if pbdata is None:
if data_id is None:
raise ValueError("data_id cannot be None")
callback_func=None,
ecode=None,
error_info=None):
'''
There are several ways to use it:
- ChannelData(future, pbdata[, callback_func])
- ChannelData(future, data_id[, callback_func])
- ChannelData(pbdata)
- ChannelData(ecode, error_info, data_id)
'''
if ecode is not None:
if data_id is None or error_info is None:
raise ValueError("data_id and error_info cannot be None")
pbdata = channel_pb2.ChannelData()
pbdata.type = ChannelDataType.CHANNEL_FUTURE.value
pbdata.ecode = ChannelDataEcode.OK.value
pbdata.ecode = ecode
pbdata.id = data_id
pbdata.error_info = error_info
else:
if pbdata is None:
if data_id is None:
raise ValueError("data_id cannot be None")
pbdata = channel_pb2.ChannelData()
pbdata.type = ChannelDataType.CHANNEL_FUTURE.value
pbdata.ecode = ChannelDataEcode.OK.value
pbdata.id = data_id
elif not isinstance(pbdata, channel_pb2.ChannelData):
raise TypeError(
"pbdata must be pyserving_channel_pb2.ChannelData type({})".
format(type(pbdata)))
self.future = future
self.pbdata = pbdata
self.callback_func = callback_func
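A minimal sketch of the two main construction paths listed in the docstring above, assuming ChannelData and ChannelDataEcode are imported from this module; on the error path only ecode, error_info and data_id are needed, and the protobuf pbdata is filled in internally.

data_id = 0  # hypothetical request id

# error path: wraps ecode/error_info into a ChannelData
err = ChannelData(
    ecode=ChannelDataEcode.TIMEOUT.value,
    error_info="[bow0] timeout",
    data_id=data_id)
assert err.pbdata.ecode == ChannelDataEcode.TIMEOUT.value

# plain OK placeholder: pbdata is created with ecode OK and the given id
ok = ChannelData(data_id=data_id)
assert ok.pbdata.ecode == ChannelDataEcode.OK.value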
@@ -389,10 +414,9 @@ class Op(object):
def preprocess(self, channeldata):
if isinstance(channeldata, dict):
raise Exception(
self._log(
'this Op has multiple previous inputs. Please override this method'
))
raise NotImplementedError(
'this Op has multiple previous inputs. Please override this method'
)
feed = channeldata.parse()
return feed
@@ -412,13 +436,6 @@ class Op(object):
def postprocess(self, output_data):
return output_data
def errorprocess(self, error_info, data_id):
data = channel_pb2.ChannelData()
data.ecode = 1
data.id = data_id
data.error_info = error_info
return data
def stop(self):
self._input.stop()
for channel in self._outputs:
@@ -433,7 +450,7 @@ class Op(object):
data_id = channeldata[key].pbdata.id
for _, data in channeldata.items():
if data.pbdata.ecode != 0:
error_data = data
error_data = data.pbdata
break
else:
data_id = channeldata.pbdata.id
@@ -441,101 +458,130 @@ class Op(object):
error_data = channeldata.pbdata
return data_id, error_data
def _push_to_output_channels(self, data):
for channel in self._outputs:
channel.push(data, self.name)
def start(self, concurrency_idx):
op_info_prefix = "[{}{}]".format(self.name, concurrency_idx)
log = self._get_log_func(op_info_prefix)
self._run = True
while self._run:
_profiler.record("{}{}-get_0".format(self.name, concurrency_idx))
_profiler.record("{}-get_0".format(op_info_prefix))
input_data = self._input.front(self.name)
_profiler.record("{}{}-get_1".format(self.name, concurrency_idx))
logging.debug(self._log("input_data: {}".format(input_data)))
_profiler.record("{}-get_1".format(op_info_prefix))
logging.debug(log("input_data: {}".format(input_data)))
data_id, error_data = self._parse_channeldata(input_data)
output_data = None
if error_data is None:
_profiler.record("{}{}-prep_0".format(self.name,
concurrency_idx))
# predecessor Op error
if error_data is not None:
self._push_to_output_channels(ChannelData(pbdata=error_data))
continue
# preprocess function not implemented
try:
_profiler.record("{}-prep_0".format(op_info_prefix))
data = self.preprocess(input_data)
_profiler.record("{}{}-prep_1".format(self.name,
concurrency_idx))
call_future = None
error_info = None
if self.with_serving():
_profiler.record("{}{}-midp_0".format(self.name,
concurrency_idx))
if self._timeout > 0:
for i in range(self._retry):
try:
call_future = func_timeout.func_timeout(
self._timeout,
self.midprocess,
args=(data, ))
except func_timeout.FunctionTimedOut:
logging.error("error: timeout")
error_info = "{}({}): timeout".format(
self.name, concurrency_idx)
if i + 1 < self._retry:
error_info = None
logging.warn(
self._log("warn: timeout, retry({})".
format(i + 1)))
except Exception as e:
logging.error("error: {}".format(e))
error_info = "{}({}): {}".format(
self.name, concurrency_idx, e)
logging.warn(self._log(e))
# TODO
break
else:
break
else:
call_future = self.midprocess(data)
_profiler.record("{}-prep_1".format(op_info_prefix))
except NotImplementedError as e:
error_info = log(e)
logging.error(error_info)
self._push_to_output_channels(
ChannelData(
ecode=ChannelDataEcode.NOT_IMPLEMENTED.value,
error_info=error_info,
data_id=data_id))
continue
_profiler.record("{}{}-midp_1".format(self.name,
concurrency_idx))
_profiler.record("{}{}-postp_0".format(self.name,
concurrency_idx))
if error_info is not None:
error_data = self.errorprocess(error_info, data_id)
output_data = ChannelData(pbdata=error_data)
# midprocess
call_future = None
ecode = 0
error_info = None
if self.with_serving():
_profiler.record("{}-midp_0".format(op_info_prefix))
if self._timeout <= 0:
try:
call_future = self.midprocess(data)
except Exception as e:
logging.error(self._log(e))
ecode = ChannelDataEcode.UNKNOW.value
error_info = log(e)
logging.error(error_info)
else:
if self.with_serving(): # use call_future
output_data = ChannelData(
future=call_future,
data_id=data_id,
callback_func=self.postprocess)
else:
post_data = self.postprocess(data)
if not isinstance(post_data, dict):
raise TypeError(
self._log(
'output_data must be dict type, but get {}'.
format(type(output_data))))
pbdata = channel_pb2.ChannelData()
for name, value in post_data.items():
inst = channel_pb2.Inst()
inst.data = value.tobytes()
inst.name = name
inst.shape = np.array(
value.shape, dtype="int32").tobytes()
inst.type = str(value.dtype)
pbdata.insts.append(inst)
pbdata.ecode = 0
pbdata.id = data_id
output_data = ChannelData(pbdata=pbdata)
_profiler.record("{}{}-postp_1".format(self.name,
concurrency_idx))
else:
output_data = ChannelData(pbdata=error_data)
_profiler.record("{}{}-push_0".format(self.name, concurrency_idx))
for channel in self._outputs:
channel.push(output_data, self.name)
_profiler.record("{}{}-push_1".format(self.name, concurrency_idx))
for i in range(self._retry):
try:
call_future = func_timeout.func_timeout(
self._timeout, self.midprocess, args=(data, ))
except func_timeout.FunctionTimedOut:
if i + 1 >= self._retry:
ecode = ChannelDataEcode.TIMEOUT.value
error_info = "{} timeout".format(op_info_prefix)
else:
logging.warn(
log("warn: timeout, retry({})".format(i +
1)))
except Exception as e:
ecode = ChannelDataEcode.UNKNOW.value
error_info = log(e)
logging.error(error_info)
break
else:
break
if ecode != 0:
self._push_to_output_channels(
ChannelData(
ecode=ecode, error_info=error_info,
data_id=data_id))
continue
_profiler.record("{}-midp_1".format(op_info_prefix))
def _log(self, info_str):
return "[{}] {}".format(self.name, info_str)
# postprocess
output_data = None
_profiler.record("{}-postp_0".format(op_info_prefix))
if self.with_serving(): # use call_future
output_data = ChannelData(
future=call_future,
data_id=data_id,
callback_func=self.postprocess)
else:
post_data = self.postprocess(data)
if not isinstance(post_data, dict):
ecode = ChannelDataEcode.TYPE_ERROR.value
error_info = log("output of postprocess funticon must be " \
"dict type, but get {}".format(type(post_data)))
logging.error(error_info)
self._push_to_output_channels(
ChannelData(
ecode=ecode, error_info=error_info,
data_id=data_id))
continue
pbdata = channel_pb2.ChannelData()
for name, value in post_data.items():
inst = channel_pb2.Inst()
inst.data = value.tobytes()
inst.name = name
inst.shape = np.array(value.shape, dtype="int32").tobytes()
inst.type = str(value.dtype)
pbdata.insts.append(inst)
pbdata.ecode = 0
pbdata.id = data_id
output_data = ChannelData(pbdata=pbdata)
_profiler.record("{}-postp_1".format(op_info_prefix))
# push data to channel (if run succ)
_profiler.record("{}-push_0".format(op_info_prefix))
self._push_to_output_channels(output_data)
_profiler.record("{}-push_1".format(op_info_prefix))
def _log(self, info):
return "{} {}".format(self.name, info)
def _get_log_func(self, op_info_prefix):
def log_func(info_str):
return "{} {}".format(op_info_prefix, info_str)
return log_func
def get_concurrency(self):
return self._concurrency
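For reference, a standalone sketch of the retry-with-timeout pattern that start() wraps around midprocess(); it relies on the func_timeout package used above, with midprocess, timeout and retry standing in for self.midprocess, self._timeout and self._retry.

import func_timeout

def call_with_retry(midprocess, data, timeout, retry):
    # mirrors the loop in Op.start(): with a non-positive timeout the call is
    # made directly, otherwise each attempt is bounded by `timeout` seconds
    if timeout <= 0:
        return midprocess(data)
    for i in range(retry):
        try:
            return func_timeout.func_timeout(timeout, midprocess, args=(data, ))
        except func_timeout.FunctionTimedOut:
            if i + 1 >= retry:
                # the caller maps this to ChannelDataEcode.TIMEOUT
                raise
            # otherwise fall through and retry the call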
@@ -682,8 +728,9 @@ class GeneralPythonService(
if resp_channeldata.pbdata.ecode == 0:
break
logging.warn("retry({}): {}".format(
i + 1, resp_channeldata.pbdata.error_info))
if i + 1 < self._retry:
logging.warn("retry({}): {}".format(
i + 1, resp_channeldata.pbdata.error_info))
_profiler.record("{}-postpack_0".format(self.name))
resp = self._pack_data_for_resp(resp_channeldata)
@@ -874,8 +921,6 @@ class PyServer(object):
server.start()
server.wait_for_termination()
self._stop_ops() # TODO
for th in self._op_threads:
th.join()
def prepare_serving(self, op):
model_path = op._server_model
......