diff --git a/python/pipeline/operator.py b/python/pipeline/operator.py index dd5a7aaa28b5147111531310edd21a1dd37664fd..1a6a84be773ba5340834c6fc2ff739070dfc75ae 100644 --- a/python/pipeline/operator.py +++ b/python/pipeline/operator.py @@ -156,6 +156,7 @@ class Op(object): return input_dict def process(self, feed_dict): + #TODO: check batch err, err_info = ChannelData.check_npdata(feed_dict) if err != 0: raise NotImplementedError( @@ -235,42 +236,52 @@ class Op(object): def init_op(self): pass - def _run_preprocess(self, parsed_data, data_id, log_func): - preped_data, error_channeldata = None, None - try: - preped_data = self.preprocess(parsed_data) - except NotImplementedError as e: - # preprocess function not implemented - error_info = log_func(e) - _LOGGER.error(error_info) - error_channeldata = ChannelData( - ecode=ChannelDataEcode.NOT_IMPLEMENTED.value, - error_info=error_info, - data_id=data_id) - except TypeError as e: - # Error type in channeldata.datatype - error_info = log_func(e) - _LOGGER.error(error_info) - error_channeldata = ChannelData( - ecode=ChannelDataEcode.TYPE_ERROR.value, - error_info=error_info, - data_id=data_id) - except Exception as e: - error_info = log_func(e) - _LOGGER.error(error_info) - error_channeldata = ChannelData( - ecode=ChannelDataEcode.UNKNOW.value, - error_info=error_info, - data_id=data_id) - return preped_data, error_channeldata - - def _run_process(self, preped_data, data_id, log_func): - midped_data, error_channeldata = None, None + def _run_preprocess(self, parsed_data_dict, log_func): + preped_data_dict = {} + err_channeldata_dict = {} + for data_id, parsed_data in parsed_data_dict.items(): + preped_data, error_channeldata = None, None + try: + preped_data = self.preprocess(parsed_data) + except NotImplementedError as e: + # preprocess function not implemented + error_info = log_func(e) + _LOGGER.error(error_info) + error_channeldata = ChannelData( + ecode=ChannelDataEcode.NOT_IMPLEMENTED.value, + error_info=error_info, + data_id=data_id) + except TypeError as e: + # Error type in channeldata.datatype + error_info = log_func(e) + _LOGGER.error(error_info) + error_channeldata = ChannelData( + ecode=ChannelDataEcode.TYPE_ERROR.value, + error_info=error_info, + data_id=data_id) + except Exception as e: + error_info = log_func(e) + _LOGGER.error(error_info) + error_channeldata = ChannelData( + ecode=ChannelDataEcode.UNKNOW.value, + error_info=error_info, + data_id=data_id) + if error_channeldata is not None: + err_channeldata_dict[data_id] = error_channeldata + else: + preped_data_dict[data_id] = preped_data + return preped_data_dict, err_channeldata_dict + + def _run_process(self, preped_data_dict, log_func): + midped_data_dict = {} + err_channeldata_dict = {} if self.with_serving: + data_ids = preped_data_dict.keys() + batch = [preped_data_dict[data_id] for data_id in data_ids] ecode = ChannelDataEcode.OK.value if self._timeout <= 0: try: - midped_data = self.process(preped_data) + midped_data = self.process(batch) except Exception as e: ecode = ChannelDataEcode.UNKNOW.value error_info = log_func(e) @@ -278,8 +289,8 @@ class Op(object): else: for i in range(self._retry): try: - midped_data = func_timeout.func_timeout( - self._timeout, self.process, args=(preped_data, )) + midped_batch = func_timeout.func_timeout( + self._timeout, self.process, args=(batch, )) except func_timeout.FunctionTimedOut as e: if i + 1 >= self._retry: ecode = ChannelDataEcode.TIMEOUT.value @@ -296,54 +307,73 @@ class Op(object): else: break if ecode != ChannelDataEcode.OK.value: - error_channeldata = ChannelData( - ecode=ecode, error_info=error_info, data_id=data_id) - elif midped_data is None: + for data_id in data_ids: + err_channeldata_dict[data_id] = ChannelData( + ecode=ecode, + error_info=error_info, + data_id=data_id) + elif midped_batch is None: # op client return None - error_channeldata = ChannelData( - ecode=ChannelDataEcode.CLIENT_ERROR.value, - error_info=log_func( - "predict failed. pls check the server side."), - data_id=data_id) - else: - midped_data = preped_data - return midped_data, error_channeldata - - def _run_postprocess(self, input_dict, midped_data, data_id, log_func): - output_data, error_channeldata = None, None - try: - postped_data = self.postprocess(input_dict, midped_data) - except Exception as e: - error_info = log_func(e) - _LOGGER.error(error_info) - error_channeldata = ChannelData( - ecode=ChannelDataEcode.UNKNOW.value, - error_info=error_info, - data_id=data_id) - return output_data, error_channeldata - - if not isinstance(postped_data, dict): - error_info = log_func("output of postprocess funticon must be " \ - "dict type, but get {}".format(type(postped_data))) - _LOGGER.error(error_info) - error_channeldata = ChannelData( - ecode=ChannelDataEcode.UNKNOW.value, - error_info=error_info, - data_id=data_id) - return output_data, error_channeldata - - err, _ = ChannelData.check_npdata(postped_data) - if err == 0: - output_data = ChannelData( - ChannelDataType.CHANNEL_NPDATA.value, - npdata=postped_data, - data_id=data_id) + for data_id in data_ids: + err_channeldata_dict[data_id] = ChannelData( + ecode=ChannelDataEcode.CLIENT_ERROR.value, + error_info=log_func( + "predict failed. pls check the server side."), + data_id=data_id) + else: + # transform np format to dict format + for idx, data_id in enumerate(data_ids): + midped_data_dict[data_id] = { + k: v[idx] for k, v in midped_batch.items() + } else: - output_data = ChannelData( - ChannelDataType.DICT.value, - dictdata=postped_data, - data_id=data_id) - return output_data, error_channeldata + midped_data_dict = preped_data_dict + return midped_data_dict, err_channeldata_dict + + def _run_postprocess(self, parsed_data_dict, midped_data_dict, log_func): + postped_data_dict = {} + err_channeldata_dict = {} + for data_id, midped_data in mided_data_dict.items(): + postped_data, err_channeldata = None, None + try: + postped_data = self.postprocess( + parsed_data_dict[data_id], midped_data) + except Exception as e: + error_info = log_func(e) + _LOGGER.error(error_info) + err_channeldata = ChannelData( + ecode=ChannelDataEcode.UNKNOW.value, + error_info=error_info, + data_id=data_id) + if err_channeldata is not None: + err_channeldata_dict[data_id] = err_channeldata + continue + else: + if not isinstance(postped_data, dict): + error_info = log_func("output of postprocess funticon must be " \ + "dict type, but get {}".format(type(postped_data))) + _LOGGER.error(error_info) + err_channeldata = ChannelData( + ecode=ChannelDataEcode.UNKNOW.value, + error_info=error_info, + data_id=data_id) + err_channeldata_dict[data_id] = err_channeldata + continue + + output_data = None + err, _ = ChannelData.check_npdata(postped_data) + if err == 0: + output_data = ChannelData( + ChannelDataType.CHANNEL_NPDATA.value, + npdata=postped_data, + data_id=data_id) + else: + output_data = ChannelData( + ChannelDataType.DICT.value, + dictdata=postped_data, + data_id=data_id) + postped_data_dict[data_id] = output_data + return postped_data_dict, err_channeldata_dict def _auto_batching_generator(self, input_channel, op_name, batch_size, timeout): while True: @@ -369,8 +399,28 @@ class Op(object): break yield batch - def _run(self, concurrency_idx, input_channel, output_channels, client_type, - is_thread_op): + def _parse_channeldata_batch(self, batch, output_channels): + parsed_data_dict = {} + need_profile_dict = {} + profile_dict = {} + for channeldata_dict in channeldata_dict_batch: + (data_id, error_channeldata, parsed_data, + client_need_profile, profile_set) = \ + self._parse_channeldata(channeldata_dict) + if error_channeldata is None: + parsed_data_dict[data_id] = parsed_data + need_profile_dict[data_id] = client_need_profile + profile_dict[data_id] = profile_set + else: + # error data in predecessor Op + # (error_channeldata with profile info) + self._push_to_output_channels( + error_channeldata, output_channels) + + return parsed_data_dict, need_profile_dict, profile_dict + + def _run(self, concurrency_idx, input_channel, output_channels, + client_type, is_thread_op): def get_log_func(op_info_prefix): def log_func(info_str): return "{} {}".format(op_info_prefix, info_str) @@ -395,7 +445,6 @@ class Op(object): timeout=self._auto_batching_timeout) while True: - channeldata_dict_batch = None try: channeldata_dict_batch = next(batch_generator) except ChannelStopError: @@ -405,95 +454,86 @@ class Op(object): # parse channeldata batch try: - # parse channeldata batch + parsed_data_dict, need_profile_dict, profile_dict \ + = self._parse_channeldata_batch( + channeldata_dict_batch, output_channels) except ChannelStopError: _LOGGER.debug(log("stop.")) + self._finalize(is_thread_op) break - nor_dataid_list = [] - err_dataid_list = [] - nor_datas = {} - err_datas = {} - for channeldata_dict in channeldata_dict_batch: - (data_id, error_channeldata, parsed_data, - client_need_profile, profile_set) = \ - self._parse_channeldata(channeldata_dict) - if error_channeldata is None: - nor_dataid_list.append(data_id) - nor_datas[data_id] = { - "pd": parsed_data, - "np": client_need_profile, - "ps": profile_set, - } - else: - # error data in predecessor Op - try: - # error_channeldata with profile info - self._push_to_output_channels(error_channeldata, - output_channels) - except ChannelStopError: - _LOGGER.debug(log("stop.")) - break + if len(parsed_data_dict) == 0: + # data in the whole batch is all error data + continue # preprecess self._profiler_record("prep#{}_0".format(op_info_prefix)) - preped_data, error_channeldata = self._run_preprocess(parsed_data, - data_id, log) + preped_data_dict, err_channeldata_dict \ + = self._run_preprocess(parsed_data_dict, log) self._profiler_record("prep#{}_1".format(op_info_prefix)) - if error_channeldata is not None: - try: + try: + for data_id, err_channeldata in err_channeldata_dict.items(): self._push_to_output_channels( - error_channeldata, + err_channeldata, output_channels, - client_need_profile=client_need_profile, - profile_set=profile_set) - except ChannelStopError: - _LOGGER.debug(log("stop.")) - break + client_need_profile=need_profile_dict[data_id], + profile_set=profile_dict[data_id]) + except ChannelStopError: + _LOGGER.debug(log("stop.")) + self._finalize(is_thread_op) + break + if len(parsed_data_dict) == 0: continue # process self._profiler_record("midp#{}_0".format(op_info_prefix)) - midped_data, error_channeldata = self._run_process(preped_data, - data_id, log) + midped_data_dict, err_channeldata_dict \ + = self._run_process(preped_data_dict, log) self._profiler_record("midp#{}_1".format(op_info_prefix)) - if error_channeldata is not None: - try: + try: + for data_id, err_channeldata in err_channeldata_dict.items(): self._push_to_output_channels( - error_channeldata, - output_channels, - client_need_profile=client_need_profile, - profile_set=profile_set) - except ChannelStopError: - _LOGGER.debug(log("stop.")) - break + err_channeldata, + output_channels, + client_need_profile=need_profile_dict[data_id], + profile_set=profile_dict[data_id]) + except ChannelStopError: + _LOGGER.debug(log("stop.")) + self._finalize(is_thread_op) + break + if len(midped_data_dict) == 0: continue # postprocess self._profiler_record("postp#{}_0".format(op_info_prefix)) - output_data, error_channeldata = self._run_postprocess( - parsed_data, midped_data, data_id, log) + postped_data_dict, err_channeldata_dict \ + = self._run_postprocess( + parsed_data_dict, midped_data_dict, log) self._profiler_record("postp#{}_1".format(op_info_prefix)) - if error_channeldata is not None: - try: + try: + for data_id, err_channeldata in err_channeldata_dict.items(): self._push_to_output_channels( - error_channeldata, - output_channels, - client_need_profile=client_need_profile, - profile_set=profile_set) - except ChannelStopError: - _LOGGER.debug(log("stop.")) - break + error_channeldata, + output_channels, + client_need_profile=need_profile_dict[data_id], + profile_set=profile_dict[data_id]) + except ChannelStopError: + _LOGGER.debug(log("stop.")) + self._finalize(is_thread_op) + break + if len(postped_data_dict) == 0: continue # push data to channel (if run succ) try: - self._push_to_output_channels( - output_data, - output_channels, - client_need_profile=client_need_profile, - profile_set=profile_set) + for data_id, postped_data in postped_data_dict.items(): + self._push_to_output_channels( + postped_data, + output_channels, + client_need_profile=need_profile_dict[data_id], + profile_set=profile_dict[data_id]) except ChannelStopError: _LOGGER.debug(log("stop.")) + self._finalize(is_thread_op) break def _initialize(self, is_thread_op):