From 6ecf9211fccf3a5b8f8fd991145df847b0103c74 Mon Sep 17 00:00:00 2001 From: barriery Date: Thu, 30 Jul 2020 04:03:32 +0000 Subject: [PATCH] bug fix --- python/pipeline/channel.py | 28 ++++++++---- python/pipeline/operator.py | 69 ++++++++++++++++-------------- python/pipeline/pipeline_client.py | 7 ++- 3 files changed, 62 insertions(+), 42 deletions(-) diff --git a/python/pipeline/channel.py b/python/pipeline/channel.py index 6c7004c0..cae20e00 100644 --- a/python/pipeline/channel.py +++ b/python/pipeline/channel.py @@ -117,6 +117,16 @@ class ChannelData(object): "be dict, but get {}.".format(type(dictdata)) return ecode, error_info + @staticmethod + def check_batch_npdata(batch): + ecode = ChannelDataEcode.OK.value + error_info = None + for npdata in batch: + ecode, error_info = ChannelData.check_npdata(npdata) + if ecode != ChannelDataEcode.OK.value: + break + return ecode, error_info + @staticmethod def check_npdata(npdata): ecode = ChannelDataEcode.OK.value @@ -329,10 +339,11 @@ class ProcessChannel(object): def front(self, op_name=None, timeout=None): endtime = None - if timeout is not None and timeout <= 0: - timeout = None - else: - endtime = _time() + timeout + if timeout is not None: + if timeout <= 0: + timeout = None + else: + endtime = _time() + timeout _LOGGER.debug(self._log("{} try to get data...".format(op_name))) if len(self._consumer_cursors) == 0: @@ -600,10 +611,11 @@ class ThreadChannel(Queue.Queue): def front(self, op_name=None, timeout=None): endtime = None - if timeout is not None and timeout <= 0: - timeout = None - else: - endtime = _time() + timeout + if timeout is not None: + if timeout <= 0: + timeout = None + else: + endtime = _time() + timeout _LOGGER.debug(self._log("{} try to get data".format(op_name))) if len(self._consumer_cursors) == 0: diff --git a/python/pipeline/operator.py b/python/pipeline/operator.py index 1a6a84be..a933d21b 100644 --- a/python/pipeline/operator.py +++ b/python/pipeline/operator.py @@ -26,7 +26,8 @@ from numpy import * from .proto import pipeline_service_pb2 from .channel import (ThreadChannel, ProcessChannel, ChannelDataEcode, - ChannelData, ChannelDataType, ChannelStopError) + ChannelData, ChannelDataType, ChannelStopError, + ChannelTimeoutError) from .util import NameGenerator from .profiler import TimeProfiler @@ -45,7 +46,7 @@ class Op(object): timeout=-1, retry=1, batch_size=1, - auto_batchint_timeout=None): + auto_batching_timeout=None): if name is None: name = _op_name_gen.next() self.name = name # to identify the type of OP, it must be globally unique @@ -65,9 +66,9 @@ class Op(object): self._outputs = [] self._batch_size = batch_size - self._auto_batchint_timeout = auto_batchint_timeout - if self._auto_batchint_timeout is not None and self._auto_batchint_timeout <= 0: - self._auto_batchint_timeout = None + self._auto_batching_timeout = auto_batching_timeout + if self._auto_batching_timeout is not None and self._auto_batching_timeout <= 0: + self._auto_batching_timeout = None self._server_use_profile = False @@ -155,14 +156,13 @@ class Op(object): (_, input_dict), = input_dicts.items() return input_dict - def process(self, feed_dict): - #TODO: check batch - err, err_info = ChannelData.check_npdata(feed_dict) + def process(self, feed_batch): + err, err_info = ChannelData.check_batch_npdata(feed_batch) if err != 0: raise NotImplementedError( "{} Please override preprocess func.".format(err_info)) call_result = self.client.predict( - feed=feed_dict, fetch=self._fetch_names) + feed=feed_batch, fetch=self._fetch_names) _LOGGER.debug(self._log("get call_result")) return call_result @@ -277,11 +277,12 @@ class Op(object): err_channeldata_dict = {} if self.with_serving: data_ids = preped_data_dict.keys() - batch = [preped_data_dict[data_id] for data_id in data_ids] + feed_batch = [preped_data_dict[data_id] for data_id in data_ids] + midped_batch = None ecode = ChannelDataEcode.OK.value if self._timeout <= 0: try: - midped_data = self.process(batch) + midped_batch = self.process(feed_batch) except Exception as e: ecode = ChannelDataEcode.UNKNOW.value error_info = log_func(e) @@ -290,7 +291,7 @@ class Op(object): for i in range(self._retry): try: midped_batch = func_timeout.func_timeout( - self._timeout, self.process, args=(batch, )) + self._timeout, self.process, args=(feed_batch, )) except func_timeout.FunctionTimedOut as e: if i + 1 >= self._retry: ecode = ChannelDataEcode.TIMEOUT.value @@ -333,7 +334,7 @@ class Op(object): def _run_postprocess(self, parsed_data_dict, midped_data_dict, log_func): postped_data_dict = {} err_channeldata_dict = {} - for data_id, midped_data in mided_data_dict.items(): + for data_id, midped_data in midped_data_dict.items(): postped_data, err_channeldata = None, None try: postped_data = self.postprocess( @@ -403,7 +404,7 @@ class Op(object): parsed_data_dict = {} need_profile_dict = {} profile_dict = {} - for channeldata_dict in channeldata_dict_batch: + for channeldata_dict in batch: (data_id, error_channeldata, parsed_data, client_need_profile, profile_set) = \ self._parse_channeldata(channeldata_dict) @@ -419,21 +420,23 @@ class Op(object): return parsed_data_dict, need_profile_dict, profile_dict - def _run(self, concurrency_idx, input_channel, output_channels, + def _run(self, concurrency_idx, input_channel, output_channels, client_type, is_thread_op): def get_log_func(op_info_prefix): def log_func(info_str): return "{} {}".format(op_info_prefix, info_str) - return log_func op_info_prefix = "[{}|{}]".format(self.name, concurrency_idx) log = get_log_func(op_info_prefix) + preplog = get_log_func(op_info_prefix + "(prep)") + midplog = get_log_func(op_info_prefix + "(midp)") + postplog = get_log_func(op_info_prefix + "(postp)") tid = threading.current_thread().ident # init op try: - self._initialize(is_thread_op) + self._initialize(is_thread_op, client_type) except Exception as e: _LOGGER.error(log(e)) os._exit(-1) @@ -468,7 +471,7 @@ class Op(object): # preprecess self._profiler_record("prep#{}_0".format(op_info_prefix)) preped_data_dict, err_channeldata_dict \ - = self._run_preprocess(parsed_data_dict, log) + = self._run_preprocess(parsed_data_dict, preplog) self._profiler_record("prep#{}_1".format(op_info_prefix)) try: for data_id, err_channeldata in err_channeldata_dict.items(): @@ -487,7 +490,7 @@ class Op(object): # process self._profiler_record("midp#{}_0".format(op_info_prefix)) midped_data_dict, err_channeldata_dict \ - = self._run_process(preped_data_dict, log) + = self._run_process(preped_data_dict, midplog) self._profiler_record("midp#{}_1".format(op_info_prefix)) try: for data_id, err_channeldata in err_channeldata_dict.items(): @@ -507,7 +510,7 @@ class Op(object): self._profiler_record("postp#{}_0".format(op_info_prefix)) postped_data_dict, err_channeldata_dict \ = self._run_postprocess( - parsed_data_dict, midped_data_dict, log) + parsed_data_dict, midped_data_dict, postplog) self._profiler_record("postp#{}_1".format(op_info_prefix)) try: for data_id, err_channeldata in err_channeldata_dict.items(): @@ -536,7 +539,7 @@ class Op(object): self._finalize(is_thread_op) break - def _initialize(self, is_thread_op): + def _initialize(self, is_thread_op, client_type): if is_thread_op: with self._for_init_op_lock: if not self._succ_init_op: @@ -553,18 +556,18 @@ class Op(object): self.init_op() self._succ_init_op = True self._succ_close_op = False - else: - self.concurrency_idx = concurrency_idx - # init profiler - self._profiler = TimeProfiler() - self._profiler.enable(True) - # init client - self.client = self.init_client( - client_type, self._client_config, - self._server_endpoints, - self._fetch_names) - # user defined - self.init_op() + else: + self.concurrency_idx = concurrency_idx + # init profiler + self._profiler = TimeProfiler() + self._profiler.enable(True) + # init client + self.client = self.init_client( + client_type, self._client_config, + self._server_endpoints, + self._fetch_names) + # user defined + self.init_op() def _finalize(self, is_thread_op): if is_thread_op: diff --git a/python/pipeline/pipeline_client.py b/python/pipeline/pipeline_client.py index 6d96b926..330fa0c6 100644 --- a/python/pipeline/pipeline_client.py +++ b/python/pipeline/pipeline_client.py @@ -18,6 +18,7 @@ import numpy as np from numpy import * import logging import functools +from .channel import ChannelDataEcode from .proto import pipeline_service_pb2 from .proto import pipeline_service_pb2_grpc @@ -59,7 +60,11 @@ class PipelineClient(object): def _unpack_response_package(self, resp, fetch): if resp.ecode != 0: - return {"ecode": resp.ecode, "error_info": resp.error_info} + return { + "ecode": resp.ecode, + "ecode_desc": ChannelDataEcode(resp.ecode), + "error_info": resp.error_info, + } fetch_map = {"ecode": resp.ecode} for idx, key in enumerate(resp.key): if key == self._profile_key: -- GitLab