提交 5312cdb7 编写于 作者: F felixhjh

add process check feed_list and fetch_list feature

上级 5d912c2a
......@@ -157,13 +157,13 @@ def ParamChecker(function):
# if there are invalid arguments, raise the error.
if len(invalid_argument_list) > 0:
raise CustomException(CustomExceptionCode.INPUT_PARAMS_ERROR, "invalid arg list: {}".format(invalid_argument_list))
raise CustomException(CustomExceptionCode.INPUT_PARAMS_ERROR, "invalid arg list: {}".format(invalid_argument_list), True)
# check the result.
result = function(*args, **kwargs)
checker = inspect.signature(function).return_annotation
if not check('return', result, checker, function):
raise CustomException(CustomExceptionCode.INPUT_PARAMS_ERROR, "invalid return type")
raise CustomException(CustomExceptionCode.INPUT_PARAMS_ERROR, "invalid return type", True)
# return the result.
return result
......@@ -211,9 +211,21 @@ class ParamVerify(object):
if len(feed_dict.keys()) != len(feed_list):
return False
for key in feed_list:
if key in feed_dict.keys():
if key not in feed_dict.keys():
return False
return True
@staticmethod
def check_fetch_list(fetch_list, right_fetch_list):
if not isinstance(fetch_list, list):
return False
# read model config, try catch and
if len(fetch_list) != len(right_fetch_list):
return False
for key in fetch_list:
if key not in right_fetch_list:
return False
return True
ErrorCatch = ErrorCatch()
......@@ -34,7 +34,9 @@ elif sys.version_info.major == 3:
else:
raise Exception("Error Python version")
from .error_catch import ErrorCatch, CustomException, CustomExceptionCode
from .error_catch import ErrorCatch, CustomException, CustomExceptionCode, ParamChecker, ParamVerify
check_feed_dict=ParamVerify.check_feed_dict
check_fetch_list=ParamVerify.check_fetch_list
from .proto import pipeline_service_pb2
from .channel import (ThreadChannel, ProcessChannel,ChannelData,
ChannelDataType, ChannelStopError, ChannelTimeoutError)
......@@ -560,7 +562,7 @@ class Op(object):
(_, input_dict), = input_dicts.items()
return input_dict, False, None, ""
def process(self, feed_batch, typical_logid=0):
"""
In process stage, send requests to the inference server or predict locally.
......@@ -577,7 +579,19 @@ class Op(object):
call_result = None
err_code = ChannelDataErrcode.OK.value
err_info = ""
@ErrorCatch
@ParamChecker
def feed_fetch_list_check_helper(feed_batch : lambda feed_batch: check_feed_dict(feed_batch[0], self.right_feed_names),
fetch_list : lambda fetch_list: check_fetch_list(fetch_list, self.right_fetch_names),
log_id):
return None
_, resp = feed_fetch_list_check_helper(feed_batch, self._fetch_names, log_id=typical_logid)
if resp.err_no != CustomExceptionCode.OK.value:
err_code = resp.err_no
err_info = resp.err_msg
call_result = None
return call_result, err_code, err_info
if self.client_type == "local_predictor":
err, err_info = ChannelData.check_batch_npdata(feed_batch)
if err != 0:
......@@ -1030,12 +1044,13 @@ class Op(object):
# 2 kinds of errors
if error_code != ChannelDataErrcode.OK.value or midped_batch is None:
error_info = "(log_id={}) {} failed to predict. Please check the input dict and checkout PipelineServingLogs/pipeline.log for more details.".format(
typical_logid, self.name)
error_info = "[{}] failed to predict. {}. Please check the input dict and checkout PipelineServingLogs/pipeline.log for more details.".format(
self.name, error_info)
_LOGGER.error(error_info)
for data_id in data_ids:
err_channeldata_dict[data_id] = ChannelData(
error_code=ChannelDataErrcode.CLIENT_ERROR.value,
error_code=error_code,
error_info=error_info,
data_id=data_id,
log_id=logid_dict.get(data_id))
......@@ -1354,7 +1369,6 @@ class Op(object):
_LOGGER.debug("op:{} parse_end:{}".format(op_info_prefix,
time.time()))
# print
front_cost = int(round(_time() * 1000000)) - start
for data_id, parsed_data in parsed_data_dict.items():
_LOGGER.debug(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册