diff --git a/python/pipeline/channel.py b/python/pipeline/channel.py
index ce55b187e66ae02916d04a57732391de01f4ece5..24be6d15e21a5c74d4993ba3fd2eac9eec7b6c81 100644
--- a/python/pipeline/channel.py
+++ b/python/pipeline/channel.py
@@ -27,7 +27,7 @@ import logging
 import enum
 import copy
 
-_LOGGER = logging.getLogger(__name__)
+_LOGGER = logging.getLogger()
 
 
 class ChannelDataEcode(enum.Enum):
@@ -97,17 +97,37 @@ class ChannelData(object):
                 error_info = "the value of data must " \
                         "be dict, but get {}.".format(type(dictdata))
         return ecode, error_info
-
+
     @staticmethod
     def check_npdata(npdata):
         ecode = ChannelDataEcode.OK.value
         error_info = None
-        for _, value in npdata.items():
-            if not isinstance(value, np.ndarray):
-                ecode = ChannelDataEcode.TYPE_ERROR.value
-                error_info = "the value of data must " \
-                        "be np.ndarray, but get {}.".format(type(value))
-                break
+        if isinstance(npdata, list):
+            # batch data
+            for sample in npdata:
+                if not isinstance(sample, dict):
+                    ecode = ChannelDataEcode.TYPE_ERROR.value
+                    error_info = "the value of data must " \
+                            "be dict, but get {}.".format(type(sample))
+                    break
+                for _, value in sample.items():
+                    if not isinstance(value, np.ndarray):
+                        ecode = ChannelDataEcode.TYPE_ERROR.value
+                        error_info = "the value of data must " \
+                                "be np.ndarray, but get {}.".format(type(value))
+                        return ecode, error_info
+        elif isinstance(npdata, dict):
+            # batch_size = 1
+            for _, value in npdata.items():
+                if not isinstance(value, np.ndarray):
+                    ecode = ChannelDataEcode.TYPE_ERROR.value
+                    error_info = "the value of data must " \
+                            "be np.ndarray, but get {}.".format(type(value))
+                    break
+        else:
+            ecode = ChannelDataEcode.TYPE_ERROR.value
+            error_info = "the value of data must " \
+                    "be dict, but get {}.".format(type(npdata))
         return ecode, error_info
 
     def parse(self):
diff --git a/python/pipeline/operator.py b/python/pipeline/operator.py
index d82cac888298f83a1c8412f742adbf7de3932471..9a8e8f39078cd2ede42166814c8defdc12feef78 100644
--- a/python/pipeline/operator.py
+++ b/python/pipeline/operator.py
@@ -25,7 +25,7 @@ from .proto import pipeline_service_pb2
 from .channel import ThreadChannel, ProcessChannel, ChannelDataEcode, ChannelData, ChannelDataType
 from .util import NameGenerator
 
-_LOGGER = logging.getLogger(__name__)
+_LOGGER = logging.getLogger()
 _op_name_gen = NameGenerator("Op")
 
 
@@ -142,7 +142,7 @@ class Op(object):
         _LOGGER.debug(self._log("get call_result"))
         return call_result
 
-    def postprocess(self, fetch_dict):
+    def postprocess(self, input_dict, fetch_dict):
         return fetch_dict
 
     def stop(self):
@@ -267,10 +267,10 @@ class Op(object):
             midped_data = preped_data
         return midped_data, error_channeldata
 
-    def _run_postprocess(self, midped_data, data_id, log_func):
+    def _run_postprocess(self, input_dict, midped_data, data_id, log_func):
         output_data, error_channeldata = None, None
         try:
-            postped_data = self.postprocess(midped_data)
+            postped_data = self.postprocess(input_dict, midped_data)
         except Exception as e:
             error_info = log_func(e)
             _LOGGER.error(error_info)
@@ -359,8 +359,8 @@ class Op(object):
             # postprocess
             self._profiler_record("{}-postp#{}_0".format(op_info_prefix, tid))
-            output_data, error_channeldata = self._run_postprocess(midped_data,
-                                                                    data_id, log)
+            output_data, error_channeldata = self._run_postprocess(
+                parsed_data, midped_data, data_id, log)
             self._profiler_record("{}-postp#{}_1".format(op_info_prefix, tid))
             if error_channeldata is not None:
                 self._push_to_output_channels(error_channeldata,
diff --git a/python/pipeline/pipeline_client.py b/python/pipeline/pipeline_client.py
index 4ad05b5a953d5084ffda360c0a1ac561463898a4..891a2b2a2fe759d0f392ad043dcc6f9173bd0e3a 100644
--- a/python/pipeline/pipeline_client.py
+++ b/python/pipeline/pipeline_client.py
@@ -20,7 +20,7 @@ import functools
 from .proto import pipeline_service_pb2
 from .proto import pipeline_service_pb2_grpc
 
-_LOGGER = logging.getLogger(__name__)
+_LOGGER = logging.getLogger()
 
 
 class PipelineClient(object):
@@ -52,7 +52,7 @@ class PipelineClient(object):
             return {"ecode": resp.ecode, "error_info": resp.error_info}
         fetch_map = {"ecode": resp.ecode}
         for idx, key in enumerate(resp.key):
-            if key not in fetch:
+            if fetch is not None and key not in fetch:
                 continue
             data = resp.value[idx]
             try:
@@ -62,16 +62,16 @@ class PipelineClient(object):
             fetch_map[key] = data
         return fetch_map
 
-    def predict(self, feed_dict, fetch, asyn=False):
+    def predict(self, feed_dict, fetch=None, asyn=False):
         if not isinstance(feed_dict, dict):
             raise TypeError(
                 "feed must be dict type with format: {name: value}.")
-        if not isinstance(fetch, list):
+        if fetch is not None and not isinstance(fetch, list):
             raise TypeError("fetch must be list type with format: [name].")
         req = self._pack_request_package(feed_dict)
         if not asyn:
             resp = self._stub.inference(req)
-            return self._unpack_response_package(resp)
+            return self._unpack_response_package(resp, fetch)
         else:
             call_future = self._stub.inference.future(req)
             return PipelinePredictFuture(
diff --git a/python/pipeline/pipeline_server.py b/python/pipeline/pipeline_server.py
index 55289eeca42e02bb979d4a21791fdde44e0aff02..dabb1bcdad75fd93104c3c2944bc937996d0184b 100644
--- a/python/pipeline/pipeline_server.py
+++ b/python/pipeline/pipeline_server.py
@@ -45,7 +45,7 @@ from .channel import ThreadChannel, ProcessChannel, ChannelData, ChannelDataEcod
 from .profiler import TimeProfiler
 from .util import NameGenerator
 
-_LOGGER = logging.getLogger(__name__)
+_LOGGER = logging.getLogger()
 _profiler = TimeProfiler()
 
 
@@ -384,7 +384,7 @@ class PipelineServer(object):
     def prepare_server(self, yml_file):
         with open(yml_file) as f:
-            yml_config = yaml.load(f.read())
+            yml_config = yaml.load(f.read(), Loader=yaml.FullLoader)
         self._port = yml_config.get('port', 8080)
         if not self._port_is_available(self._port):
             raise SystemExit("Prot {} is already used".format(self._port))
diff --git a/python/pipeline/profiler.py b/python/pipeline/profiler.py
index 146203f7c184b506bb8fd70dadac1d89166a2de9..bf96bcbd43c0c0397427939b4f2737b68b2be8a6 100644
--- a/python/pipeline/profiler.py
+++ b/python/pipeline/profiler.py
@@ -24,7 +24,7 @@ else:
     raise Exception("Error Python version")
 import time
 
-_LOGGER = logging.getLogger(__name__)
+_LOGGER = logging.getLogger()
 
 
 class TimeProfiler(object):
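
Note on the channel.py change: ChannelData.check_npdata now accepts either a single dict (batch_size = 1) or a list of dicts (batch data). A minimal sketch of the two accepted shapes, not part of the patch; the "image" key and array shapes are illustrative only.

import numpy as np

# batch_size = 1: one dict mapping names to np.ndarray values
single = {"image": np.zeros((3, 224, 224), dtype=np.float32)}

# batch data: a list of such dicts, one per sample
batch = [
    {"image": np.zeros((3, 224, 224), dtype=np.float32)},
    {"image": np.ones((3, 224, 224), dtype=np.float32)},
]

# Both forms now pass the check; any other type falls into the new else branch.
# ChannelData.check_npdata(single)  -> (ChannelDataEcode.OK.value, None)
# ChannelData.check_npdata(batch)   -> (ChannelDataEcode.OK.value, None)
# ChannelData.check_npdata("oops")  -> (ChannelDataEcode.TYPE_ERROR.value, "...")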
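Note on the operator.py change: postprocess now receives the parsed input dict alongside the fetch dict, so a user Op can combine request fields with model outputs. A minimal sketch of an override under the new signature; the class name MyOp and the "request_id" field are hypothetical, and the import path follows the usual pipeline examples and may differ by release.

from paddle_serving_server.pipeline import Op


class MyOp(Op):
    def postprocess(self, input_dict, fetch_dict):
        # input_dict: the parsed request data handed to this Op
        # fetch_dict: the model outputs produced by the process step
        result = dict(fetch_dict)
        # echo a request field back into the response (hypothetical field)
        result["request_id"] = input_dict.get("request_id")
        return result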
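Note on the pipeline_client.py change: fetch is now optional in predict, and when it is omitted _unpack_response_package keeps every key the server returns. A rough usage sketch; the endpoint, feed key, fetch name, and the import/connect details are placeholders based on the typical pipeline examples and may vary between releases.

from paddle_serving_server.pipeline import PipelineClient

client = PipelineClient()
client.connect('127.0.0.1:18080')  # endpoint is a placeholder

# explicit fetch list: only the named keys are unpacked into the result
partial = client.predict(feed_dict={"words": "hello"}, fetch=["prediction"])

# fetch omitted (defaults to None): every key in the response is unpacked
everything = client.predict(feed_dict={"words": "hello"})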