提交 6ecf9211 编写于 作者: B barriery

bug fix

上级 0727cf98
...@@ -117,6 +117,16 @@ class ChannelData(object): ...@@ -117,6 +117,16 @@ class ChannelData(object):
"be dict, but get {}.".format(type(dictdata)) "be dict, but get {}.".format(type(dictdata))
return ecode, error_info return ecode, error_info
@staticmethod
def check_batch_npdata(batch):
    """Validate a batch of numpy-data samples.

    Runs ChannelData.check_npdata on each sample in *batch* and stops at
    the first sample that does not validate, so the reported error refers
    to the earliest offending entry.

    Args:
        batch: iterable of npdata samples, each checked individually.

    Returns:
        (ecode, error_info): the result of the first failing check, or the
        result of the last check (OK) when every sample passes. For an
        empty batch this is (ChannelDataEcode.OK.value, None).
    """
    ecode, error_info = ChannelDataEcode.OK.value, None
    for sample in batch:
        ecode, error_info = ChannelData.check_npdata(sample)
        if ecode != ChannelDataEcode.OK.value:
            # stop at the first invalid sample; its error describes the batch
            break
    return ecode, error_info
@staticmethod @staticmethod
def check_npdata(npdata): def check_npdata(npdata):
ecode = ChannelDataEcode.OK.value ecode = ChannelDataEcode.OK.value
...@@ -329,10 +339,11 @@ class ProcessChannel(object): ...@@ -329,10 +339,11 @@ class ProcessChannel(object):
def front(self, op_name=None, timeout=None): def front(self, op_name=None, timeout=None):
endtime = None endtime = None
if timeout is not None and timeout <= 0: if timeout is not None:
timeout = None if timeout <= 0:
else: timeout = None
endtime = _time() + timeout else:
endtime = _time() + timeout
_LOGGER.debug(self._log("{} try to get data...".format(op_name))) _LOGGER.debug(self._log("{} try to get data...".format(op_name)))
if len(self._consumer_cursors) == 0: if len(self._consumer_cursors) == 0:
...@@ -600,10 +611,11 @@ class ThreadChannel(Queue.Queue): ...@@ -600,10 +611,11 @@ class ThreadChannel(Queue.Queue):
def front(self, op_name=None, timeout=None): def front(self, op_name=None, timeout=None):
endtime = None endtime = None
if timeout is not None and timeout <= 0: if timeout is not None:
timeout = None if timeout <= 0:
else: timeout = None
endtime = _time() + timeout else:
endtime = _time() + timeout
_LOGGER.debug(self._log("{} try to get data".format(op_name))) _LOGGER.debug(self._log("{} try to get data".format(op_name)))
if len(self._consumer_cursors) == 0: if len(self._consumer_cursors) == 0:
......
...@@ -26,7 +26,8 @@ from numpy import * ...@@ -26,7 +26,8 @@ from numpy import *
from .proto import pipeline_service_pb2 from .proto import pipeline_service_pb2
from .channel import (ThreadChannel, ProcessChannel, ChannelDataEcode, from .channel import (ThreadChannel, ProcessChannel, ChannelDataEcode,
ChannelData, ChannelDataType, ChannelStopError) ChannelData, ChannelDataType, ChannelStopError,
ChannelTimeoutError)
from .util import NameGenerator from .util import NameGenerator
from .profiler import TimeProfiler from .profiler import TimeProfiler
...@@ -45,7 +46,7 @@ class Op(object): ...@@ -45,7 +46,7 @@ class Op(object):
timeout=-1, timeout=-1,
retry=1, retry=1,
batch_size=1, batch_size=1,
auto_batchint_timeout=None): auto_batching_timeout=None):
if name is None: if name is None:
name = _op_name_gen.next() name = _op_name_gen.next()
self.name = name # to identify the type of OP, it must be globally unique self.name = name # to identify the type of OP, it must be globally unique
...@@ -65,9 +66,9 @@ class Op(object): ...@@ -65,9 +66,9 @@ class Op(object):
self._outputs = [] self._outputs = []
self._batch_size = batch_size self._batch_size = batch_size
self._auto_batchint_timeout = auto_batchint_timeout self._auto_batching_timeout = auto_batching_timeout
if self._auto_batchint_timeout is not None and self._auto_batchint_timeout <= 0: if self._auto_batching_timeout is not None and self._auto_batching_timeout <= 0:
self._auto_batchint_timeout = None self._auto_batching_timeout = None
self._server_use_profile = False self._server_use_profile = False
...@@ -155,14 +156,13 @@ class Op(object): ...@@ -155,14 +156,13 @@ class Op(object):
(_, input_dict), = input_dicts.items() (_, input_dict), = input_dicts.items()
return input_dict return input_dict
def process(self, feed_dict): def process(self, feed_batch):
#TODO: check batch err, err_info = ChannelData.check_batch_npdata(feed_batch)
err, err_info = ChannelData.check_npdata(feed_dict)
if err != 0: if err != 0:
raise NotImplementedError( raise NotImplementedError(
"{} Please override preprocess func.".format(err_info)) "{} Please override preprocess func.".format(err_info))
call_result = self.client.predict( call_result = self.client.predict(
feed=feed_dict, fetch=self._fetch_names) feed=feed_batch, fetch=self._fetch_names)
_LOGGER.debug(self._log("get call_result")) _LOGGER.debug(self._log("get call_result"))
return call_result return call_result
...@@ -277,11 +277,12 @@ class Op(object): ...@@ -277,11 +277,12 @@ class Op(object):
err_channeldata_dict = {} err_channeldata_dict = {}
if self.with_serving: if self.with_serving:
data_ids = preped_data_dict.keys() data_ids = preped_data_dict.keys()
batch = [preped_data_dict[data_id] for data_id in data_ids] feed_batch = [preped_data_dict[data_id] for data_id in data_ids]
midped_batch = None
ecode = ChannelDataEcode.OK.value ecode = ChannelDataEcode.OK.value
if self._timeout <= 0: if self._timeout <= 0:
try: try:
midped_data = self.process(batch) midped_batch = self.process(feed_batch)
except Exception as e: except Exception as e:
ecode = ChannelDataEcode.UNKNOW.value ecode = ChannelDataEcode.UNKNOW.value
error_info = log_func(e) error_info = log_func(e)
...@@ -290,7 +291,7 @@ class Op(object): ...@@ -290,7 +291,7 @@ class Op(object):
for i in range(self._retry): for i in range(self._retry):
try: try:
midped_batch = func_timeout.func_timeout( midped_batch = func_timeout.func_timeout(
self._timeout, self.process, args=(batch, )) self._timeout, self.process, args=(feed_batch, ))
except func_timeout.FunctionTimedOut as e: except func_timeout.FunctionTimedOut as e:
if i + 1 >= self._retry: if i + 1 >= self._retry:
ecode = ChannelDataEcode.TIMEOUT.value ecode = ChannelDataEcode.TIMEOUT.value
...@@ -333,7 +334,7 @@ class Op(object): ...@@ -333,7 +334,7 @@ class Op(object):
def _run_postprocess(self, parsed_data_dict, midped_data_dict, log_func): def _run_postprocess(self, parsed_data_dict, midped_data_dict, log_func):
postped_data_dict = {} postped_data_dict = {}
err_channeldata_dict = {} err_channeldata_dict = {}
for data_id, midped_data in mided_data_dict.items(): for data_id, midped_data in midped_data_dict.items():
postped_data, err_channeldata = None, None postped_data, err_channeldata = None, None
try: try:
postped_data = self.postprocess( postped_data = self.postprocess(
...@@ -403,7 +404,7 @@ class Op(object): ...@@ -403,7 +404,7 @@ class Op(object):
parsed_data_dict = {} parsed_data_dict = {}
need_profile_dict = {} need_profile_dict = {}
profile_dict = {} profile_dict = {}
for channeldata_dict in channeldata_dict_batch: for channeldata_dict in batch:
(data_id, error_channeldata, parsed_data, (data_id, error_channeldata, parsed_data,
client_need_profile, profile_set) = \ client_need_profile, profile_set) = \
self._parse_channeldata(channeldata_dict) self._parse_channeldata(channeldata_dict)
...@@ -419,21 +420,23 @@ class Op(object): ...@@ -419,21 +420,23 @@ class Op(object):
return parsed_data_dict, need_profile_dict, profile_dict return parsed_data_dict, need_profile_dict, profile_dict
def _run(self, concurrency_idx, input_channel, output_channels, def _run(self, concurrency_idx, input_channel, output_channels,
client_type, is_thread_op): client_type, is_thread_op):
def get_log_func(op_info_prefix): def get_log_func(op_info_prefix):
def log_func(info_str): def log_func(info_str):
return "{} {}".format(op_info_prefix, info_str) return "{} {}".format(op_info_prefix, info_str)
return log_func return log_func
op_info_prefix = "[{}|{}]".format(self.name, concurrency_idx) op_info_prefix = "[{}|{}]".format(self.name, concurrency_idx)
log = get_log_func(op_info_prefix) log = get_log_func(op_info_prefix)
preplog = get_log_func(op_info_prefix + "(prep)")
midplog = get_log_func(op_info_prefix + "(midp)")
postplog = get_log_func(op_info_prefix + "(postp)")
tid = threading.current_thread().ident tid = threading.current_thread().ident
# init op # init op
try: try:
self._initialize(is_thread_op) self._initialize(is_thread_op, client_type)
except Exception as e: except Exception as e:
_LOGGER.error(log(e)) _LOGGER.error(log(e))
os._exit(-1) os._exit(-1)
...@@ -468,7 +471,7 @@ class Op(object): ...@@ -468,7 +471,7 @@ class Op(object):
# preprecess # preprecess
self._profiler_record("prep#{}_0".format(op_info_prefix)) self._profiler_record("prep#{}_0".format(op_info_prefix))
preped_data_dict, err_channeldata_dict \ preped_data_dict, err_channeldata_dict \
= self._run_preprocess(parsed_data_dict, log) = self._run_preprocess(parsed_data_dict, preplog)
self._profiler_record("prep#{}_1".format(op_info_prefix)) self._profiler_record("prep#{}_1".format(op_info_prefix))
try: try:
for data_id, err_channeldata in err_channeldata_dict.items(): for data_id, err_channeldata in err_channeldata_dict.items():
...@@ -487,7 +490,7 @@ class Op(object): ...@@ -487,7 +490,7 @@ class Op(object):
# process # process
self._profiler_record("midp#{}_0".format(op_info_prefix)) self._profiler_record("midp#{}_0".format(op_info_prefix))
midped_data_dict, err_channeldata_dict \ midped_data_dict, err_channeldata_dict \
= self._run_process(preped_data_dict, log) = self._run_process(preped_data_dict, midplog)
self._profiler_record("midp#{}_1".format(op_info_prefix)) self._profiler_record("midp#{}_1".format(op_info_prefix))
try: try:
for data_id, err_channeldata in err_channeldata_dict.items(): for data_id, err_channeldata in err_channeldata_dict.items():
...@@ -507,7 +510,7 @@ class Op(object): ...@@ -507,7 +510,7 @@ class Op(object):
self._profiler_record("postp#{}_0".format(op_info_prefix)) self._profiler_record("postp#{}_0".format(op_info_prefix))
postped_data_dict, err_channeldata_dict \ postped_data_dict, err_channeldata_dict \
= self._run_postprocess( = self._run_postprocess(
parsed_data_dict, midped_data_dict, log) parsed_data_dict, midped_data_dict, postplog)
self._profiler_record("postp#{}_1".format(op_info_prefix)) self._profiler_record("postp#{}_1".format(op_info_prefix))
try: try:
for data_id, err_channeldata in err_channeldata_dict.items(): for data_id, err_channeldata in err_channeldata_dict.items():
...@@ -536,7 +539,7 @@ class Op(object): ...@@ -536,7 +539,7 @@ class Op(object):
self._finalize(is_thread_op) self._finalize(is_thread_op)
break break
def _initialize(self, is_thread_op): def _initialize(self, is_thread_op, client_type):
if is_thread_op: if is_thread_op:
with self._for_init_op_lock: with self._for_init_op_lock:
if not self._succ_init_op: if not self._succ_init_op:
...@@ -553,18 +556,18 @@ class Op(object): ...@@ -553,18 +556,18 @@ class Op(object):
self.init_op() self.init_op()
self._succ_init_op = True self._succ_init_op = True
self._succ_close_op = False self._succ_close_op = False
else: else:
self.concurrency_idx = concurrency_idx self.concurrency_idx = concurrency_idx
# init profiler # init profiler
self._profiler = TimeProfiler() self._profiler = TimeProfiler()
self._profiler.enable(True) self._profiler.enable(True)
# init client # init client
self.client = self.init_client( self.client = self.init_client(
client_type, self._client_config, client_type, self._client_config,
self._server_endpoints, self._server_endpoints,
self._fetch_names) self._fetch_names)
# user defined # user defined
self.init_op() self.init_op()
def _finalize(self, is_thread_op): def _finalize(self, is_thread_op):
if is_thread_op: if is_thread_op:
......
...@@ -18,6 +18,7 @@ import numpy as np ...@@ -18,6 +18,7 @@ import numpy as np
from numpy import * from numpy import *
import logging import logging
import functools import functools
from .channel import ChannelDataEcode
from .proto import pipeline_service_pb2 from .proto import pipeline_service_pb2
from .proto import pipeline_service_pb2_grpc from .proto import pipeline_service_pb2_grpc
...@@ -59,7 +60,11 @@ class PipelineClient(object): ...@@ -59,7 +60,11 @@ class PipelineClient(object):
def _unpack_response_package(self, resp, fetch): def _unpack_response_package(self, resp, fetch):
if resp.ecode != 0: if resp.ecode != 0:
return {"ecode": resp.ecode, "error_info": resp.error_info} return {
"ecode": resp.ecode,
"ecode_desc": ChannelDataEcode(resp.ecode),
"error_info": resp.error_info,
}
fetch_map = {"ecode": resp.ecode} fetch_map = {"ecode": resp.ecode}
for idx, key in enumerate(resp.key): for idx, key in enumerate(resp.key):
if key == self._profile_key: if key == self._profile_key:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册