Commit a2a8409b authored by barrierye

add some error code

Parent 5de538d5
@@ -27,6 +27,8 @@ logging.basicConfig(
 class CombineOp(Op):
+    pass
+    '''
     def preprocess(self, input_data):
         combined_prediction = 0
         for op_name, channeldata in input_data.items():
@@ -35,6 +37,7 @@ class CombineOp(Op):
             combined_prediction += data["prediction"]
         data = {"combined_prediction": combined_prediction / 2}
         return data
+    '''

 read_op = Op(name="read", inputs=None)
@@ -47,7 +50,7 @@ bow_op = Op(name="bow",
             server_name="127.0.0.1:9393",
             fetch_names=["prediction"],
             concurrency=1,
-            timeout=0.01,
+            timeout=0.1,
             retry=2)
 cnn_op = Op(name="cnn",
             inputs=[read_op],
......
@@ -49,6 +49,8 @@ class PyClient(object):
                 "fetch_with_type must be list type with format: [name].")
         req = self._pack_data_for_infer(feed)
         resp = self._stub.inference(req)
+        if resp.ecode != 0:
+            raise Exception(resp.error_info)
         fetch_map = {}
         for idx, name in enumerate(resp.fetch_var_names):
             if name not in fetch:
......
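With this change the client raises as soon as the response carries a non-zero ecode, so callers see server-side errors as exceptions rather than empty fetch maps. A minimal client-side sketch, assuming a PyClient instance whose public entry point is a predict(feed, fetch) method; the method name, connection call, address, and feed keys are illustrative assumptions, not taken from this diff:

    # Sketch only: construction/connection details of PyClient are assumed.
    client = PyClient()
    client.connect("127.0.0.1:8080")
    try:
        fetch_map = client.predict(feed={"words": [1, 2, 3]},
                                   fetch=["prediction"])
    except Exception as e:
        # raised when resp.ecode != 0; the message is resp.error_info
        print("inference failed: {}".format(e))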
@@ -77,6 +77,9 @@ _profiler = _TimeProfiler()
 class ChannelDataEcode(enum.Enum):
     OK = 0
     TIMEOUT = 1
+    NOT_IMPLEMENTED = 2
+    TYPE_ERROR = 3
+    UNKNOW = 4

 class ChannelDataType(enum.Enum):
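The error codes travel through the protobuf as plain integers (pbdata.ecode / resp.ecode), so a reader can map a raw value back to the enum for logging. A self-contained sketch of that lookup; the enum is copied from the hunk above and the raw value is illustrative:

    import enum

    class ChannelDataEcode(enum.Enum):
        OK = 0
        TIMEOUT = 1
        NOT_IMPLEMENTED = 2
        TYPE_ERROR = 3
        UNKNOW = 4

    ecode = 1  # e.g. read from resp.ecode or pbdata.ecode
    if ecode != ChannelDataEcode.OK.value:
        # Enum(value) recovers the symbolic name of a raw integer code
        print("error code {} ({})".format(ecode, ChannelDataEcode(ecode).name))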
@@ -89,8 +92,25 @@ class ChannelData(object):
                  future=None,
                  pbdata=None,
                  data_id=None,
-                 callback_func=None):
-        self.future = future
-        if pbdata is None:
-            if data_id is None:
-                raise ValueError("data_id cannot be None")
+                 callback_func=None,
+                 ecode=None,
+                 error_info=None):
+        '''
+        There are several ways to use it:
+
+        - ChannelData(future, pbdata[, callback_func])
+        - ChannelData(future, data_id[, callback_func])
+        - ChannelData(pbdata)
+        - ChannelData(ecode, error_info, data_id)
+        '''
+        if ecode is not None:
+            if data_id is None or error_info is None:
+                raise ValueError("data_id and error_info cannot be None")
+            pbdata = channel_pb2.ChannelData()
+            pbdata.ecode = ecode
+            pbdata.id = data_id
+            pbdata.error_info = error_info
+        else:
+            if pbdata is None:
+                if data_id is None:
+                    raise ValueError("data_id cannot be None")
@@ -98,6 +118,11 @@ class ChannelData(object):
                 pbdata.type = ChannelDataType.CHANNEL_FUTURE.value
                 pbdata.ecode = ChannelDataEcode.OK.value
                 pbdata.id = data_id
+            elif not isinstance(pbdata, channel_pb2.ChannelData):
+                raise TypeError(
+                    "pbdata must be pyserving_channel_pb2.ChannelData type({})".
+                    format(type(pbdata)))
+        self.future = future
         self.pbdata = pbdata
         self.callback_func = callback_func
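Per the new docstring, the constructor gains a fourth mode that builds the error protobuf itself when ecode is passed; this is what the Op loop now uses in place of the removed errorprocess helper. A minimal sketch of the new mode, assuming ChannelData and ChannelDataEcode from this module are in scope; the id and message values are illustrative:

    # Error path: ecode, error_info and data_id are required together; the
    # constructor fills a channel_pb2.ChannelData under the hood.
    err = ChannelData(
        ecode=ChannelDataEcode.TIMEOUT.value,
        error_info="[bow0] timeout",
        data_id=42)

    # The unchanged paths still work, e.g. wrapping an existing protobuf:
    # ChannelData(pbdata=some_pbdata)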
@@ -389,10 +414,9 @@ class Op(object):
     def preprocess(self, channeldata):
         if isinstance(channeldata, dict):
-            raise Exception(
-                self._log(
-                    'this Op has multiple previous inputs. Please override this method'
-                ))
+            raise NotImplementedError(
+                'this Op has multiple previous inputs. Please override this method'
+            )
         feed = channeldata.parse()
         return feed
@@ -412,13 +436,6 @@ class Op(object):
     def postprocess(self, output_data):
         return output_data

-    def errorprocess(self, error_info, data_id):
-        data = channel_pb2.ChannelData()
-        data.ecode = 1
-        data.id = data_id
-        data.error_info = error_info
-        return data
-
     def stop(self):
         self._input.stop()
         for channel in self._outputs:
@@ -433,7 +450,7 @@ class Op(object):
                 data_id = channeldata[key].pbdata.id
             for _, data in channeldata.items():
                 if data.pbdata.ecode != 0:
-                    error_data = data
+                    error_data = data.pbdata
                     break
         else:
             data_id = channeldata.pbdata.id
@@ -441,65 +458,87 @@ class Op(object):
             error_data = channeldata.pbdata
         return data_id, error_data

+    def _push_to_output_channels(self, data):
+        for channel in self._outputs:
+            channel.push(data, self.name)
+
     def start(self, concurrency_idx):
+        op_info_prefix = "[{}{}]".format(self.name, concurrency_idx)
+        log = self._get_log_func(op_info_prefix)
         self._run = True
         while self._run:
-            _profiler.record("{}{}-get_0".format(self.name, concurrency_idx))
+            _profiler.record("{}-get_0".format(op_info_prefix))
             input_data = self._input.front(self.name)
-            _profiler.record("{}{}-get_1".format(self.name, concurrency_idx))
-            logging.debug(self._log("input_data: {}".format(input_data)))
+            _profiler.record("{}-get_1".format(op_info_prefix))
+            logging.debug(log("input_data: {}".format(input_data)))

             data_id, error_data = self._parse_channeldata(input_data)

-            output_data = None
-            if error_data is None:
-                _profiler.record("{}{}-prep_0".format(self.name,
-                                                      concurrency_idx))
-                data = self.preprocess(input_data)
-                _profiler.record("{}{}-prep_1".format(self.name,
-                                                      concurrency_idx))
-
-                call_future = None
-                error_info = None
-                if self.with_serving():
-                    _profiler.record("{}{}-midp_0".format(self.name,
-                                                          concurrency_idx))
-                    if self._timeout > 0:
-                        for i in range(self._retry):
-                            try:
-                                call_future = func_timeout.func_timeout(
-                                    self._timeout,
-                                    self.midprocess,
-                                    args=(data, ))
-                            except func_timeout.FunctionTimedOut:
-                                logging.error("error: timeout")
-                                error_info = "{}({}): timeout".format(
-                                    self.name, concurrency_idx)
-                                if i + 1 < self._retry:
-                                    error_info = None
-                                    logging.warn(
-                                        self._log("warn: timeout, retry({})".
-                                                  format(i + 1)))
-                            except Exception as e:
-                                logging.error("error: {}".format(e))
-                                error_info = "{}({}): {}".format(
-                                    self.name, concurrency_idx, e)
-                                logging.warn(self._log(e))
-                                # TODO
-                                break
-                            else:
-                                break
-                    else:
-                        call_future = self.midprocess(data)
-                    _profiler.record("{}{}-midp_1".format(self.name,
-                                                          concurrency_idx))
-                _profiler.record("{}{}-postp_0".format(self.name,
-                                                       concurrency_idx))
-                if error_info is not None:
-                    error_data = self.errorprocess(error_info, data_id)
-                    output_data = ChannelData(pbdata=error_data)
-                else:
-                    if self.with_serving():  # use call_future
-                        output_data = ChannelData(
-                            future=call_future,
+            # predecessor Op error
+            if error_data is not None:
+                self._push_to_output_channels(ChannelData(pbdata=error_data))
+                continue
+
+            # preprocess function not implemented
+            try:
+                _profiler.record("{}-prep_0".format(op_info_prefix))
+                data = self.preprocess(input_data)
+                _profiler.record("{}-prep_1".format(op_info_prefix))
+            except NotImplementedError as e:
+                error_info = log(e)
+                logging.error(error_info)
+                self._push_to_output_channels(
+                    ChannelData(
+                        ecode=ChannelDataEcode.NOT_IMPLEMENTED.value,
+                        error_info=error_info,
+                        data_id=data_id))
+                continue
+
+            # midprocess
+            call_future = None
+            ecode = 0
+            error_info = None
+            if self.with_serving():
+                _profiler.record("{}-midp_0".format(op_info_prefix))
+                if self._timeout <= 0:
+                    try:
+                        call_future = self.midprocess(data)
+                    except Exception as e:
+                        logging.error(self._log(e))
+                        ecode = ChannelDataEcode.UNKNOW.value
+                        error_info = log(e)
+                        logging.error(error_info)
+                else:
+                    for i in range(self._retry):
+                        try:
+                            call_future = func_timeout.func_timeout(
+                                self._timeout, self.midprocess, args=(data, ))
+                        except func_timeout.FunctionTimedOut:
+                            if i + 1 >= self._retry:
+                                ecode = ChannelDataEcode.TIMEOUT.value
+                                error_info = "{} timeout".format(op_info_prefix)
+                            else:
+                                logging.warn(
+                                    log("warn: timeout, retry({})".format(
+                                        i + 1)))
+                        except Exception as e:
+                            ecode = ChannelDataEcode.UNKNOW.value
+                            error_info = log(e)
+                            logging.error(error_info)
+                            break
+                        else:
+                            break
+                if ecode != 0:
+                    self._push_to_output_channels(
+                        ChannelData(
+                            ecode=ecode, error_info=error_info,
+                            data_id=data_id))
+                    continue
+                _profiler.record("{}-midp_1".format(op_info_prefix))
+
+            # postprocess
+            output_data = None
+            _profiler.record("{}-postp_0".format(op_info_prefix))
+            if self.with_serving():  # use call_future
+                output_data = ChannelData(
+                    future=call_future,
@@ -508,34 +547,41 @@ class Op(object):
             else:
                 post_data = self.postprocess(data)
                 if not isinstance(post_data, dict):
-                    raise TypeError(
-                        self._log(
-                            'output_data must be dict type, but get {}'.
-                            format(type(output_data))))
+                    ecode = ChannelDataEcode.TYPE_ERROR.value
+                    error_info = log("output of postprocess funticon must be " \
+                        "dict type, but get {}".format(type(post_data)))
+                    logging.error(error_info)
+                    self._push_to_output_channels(
+                        ChannelData(
+                            ecode=ecode, error_info=error_info,
+                            data_id=data_id))
+                    continue
                 pbdata = channel_pb2.ChannelData()
                 for name, value in post_data.items():
                     inst = channel_pb2.Inst()
                     inst.data = value.tobytes()
                     inst.name = name
-                    inst.shape = np.array(
-                        value.shape, dtype="int32").tobytes()
+                    inst.shape = np.array(value.shape, dtype="int32").tobytes()
                     inst.type = str(value.dtype)
                     pbdata.insts.append(inst)
                 pbdata.ecode = 0
                 pbdata.id = data_id
                 output_data = ChannelData(pbdata=pbdata)
-                _profiler.record("{}{}-postp_1".format(self.name,
-                                                       concurrency_idx))
-            else:
-                output_data = ChannelData(pbdata=error_data)
+            _profiler.record("{}-postp_1".format(op_info_prefix))

-            _profiler.record("{}{}-push_0".format(self.name, concurrency_idx))
-            for channel in self._outputs:
-                channel.push(output_data, self.name)
-            _profiler.record("{}{}-push_1".format(self.name, concurrency_idx))
+            # push data to channel (if run succ)
+            _profiler.record("{}-push_0".format(op_info_prefix))
+            self._push_to_output_channels(output_data)
+            _profiler.record("{}-push_1".format(op_info_prefix))

-    def _log(self, info_str):
-        return "[{}] {}".format(self.name, info_str)
+    def _log(self, info):
+        return "{} {}".format(self.name, info)
+
+    def _get_log_func(self, op_info_prefix):
+        def log_func(info_str):
+            return "{} {}".format(op_info_prefix, info_str)
+        return log_func

     def get_concurrency(self):
         return self._concurrency
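The midprocess call is now guarded by func_timeout: func_timeout.func_timeout(timeout, fn, args=...) raises func_timeout.FunctionTimedOut when fn does not return in time, and the loop retries before giving up with ChannelDataEcode.TIMEOUT. A self-contained sketch of the same retry pattern; slow_call and the timeout/retry values are illustrative stand-ins for midprocess and the Op settings:

    import time
    import func_timeout

    def slow_call(x):
        # illustrative stand-in for Op.midprocess
        time.sleep(0.5)
        return x * 2

    result, error_info = None, None
    timeout, retry = 0.1, 2
    for i in range(retry):
        try:
            result = func_timeout.func_timeout(timeout, slow_call, args=(1, ))
        except func_timeout.FunctionTimedOut:
            if i + 1 >= retry:
                error_info = "timeout"   # would become ChannelDataEcode.TIMEOUT
            else:
                print("warn: timeout, retry({})".format(i + 1))
        except Exception as e:
            error_info = str(e)          # would become ChannelDataEcode.UNKNOW
            break
        else:
            break
    print(result, error_info)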
@@ -682,6 +728,7 @@ class GeneralPythonService(
                 if resp_channeldata.pbdata.ecode == 0:
                     break
-                logging.warn("retry({}): {}".format(
-                    i + 1, resp_channeldata.pbdata.error_info))
+                if i + 1 < self._retry:
+                    logging.warn("retry({}): {}".format(
+                        i + 1, resp_channeldata.pbdata.error_info))
@@ -874,8 +921,6 @@ class PyServer(object):
         server.start()
         server.wait_for_termination()
         self._stop_ops()  # TODO
-        for th in self._op_threads:
-            th.join()

     def prepare_serving(self, op):
         model_path = op._server_model
......