提交 6ecf9211 编写于 作者: B barriery

bug fix

上级 0727cf98
......@@ -117,6 +117,16 @@ class ChannelData(object):
"be dict, but get {}.".format(type(dictdata))
return ecode, error_info
@staticmethod
def check_batch_npdata(batch):
ecode = ChannelDataEcode.OK.value
error_info = None
for npdata in batch:
ecode, error_info = ChannelData.check_npdata(npdata)
if ecode != ChannelDataEcode.OK.value:
break
return ecode, error_info
@staticmethod
def check_npdata(npdata):
ecode = ChannelDataEcode.OK.value
......@@ -329,7 +339,8 @@ class ProcessChannel(object):
def front(self, op_name=None, timeout=None):
endtime = None
if timeout is not None and timeout <= 0:
if timeout is not None:
if timeout <= 0:
timeout = None
else:
endtime = _time() + timeout
......@@ -600,7 +611,8 @@ class ThreadChannel(Queue.Queue):
def front(self, op_name=None, timeout=None):
endtime = None
if timeout is not None and timeout <= 0:
if timeout is not None:
if timeout <= 0:
timeout = None
else:
endtime = _time() + timeout
......
......@@ -26,7 +26,8 @@ from numpy import *
from .proto import pipeline_service_pb2
from .channel import (ThreadChannel, ProcessChannel, ChannelDataEcode,
ChannelData, ChannelDataType, ChannelStopError)
ChannelData, ChannelDataType, ChannelStopError,
ChannelTimeoutError)
from .util import NameGenerator
from .profiler import TimeProfiler
......@@ -45,7 +46,7 @@ class Op(object):
timeout=-1,
retry=1,
batch_size=1,
auto_batchint_timeout=None):
auto_batching_timeout=None):
if name is None:
name = _op_name_gen.next()
self.name = name # to identify the type of OP, it must be globally unique
......@@ -65,9 +66,9 @@ class Op(object):
self._outputs = []
self._batch_size = batch_size
self._auto_batchint_timeout = auto_batchint_timeout
if self._auto_batchint_timeout is not None and self._auto_batchint_timeout <= 0:
self._auto_batchint_timeout = None
self._auto_batching_timeout = auto_batching_timeout
if self._auto_batching_timeout is not None and self._auto_batching_timeout <= 0:
self._auto_batching_timeout = None
self._server_use_profile = False
......@@ -155,14 +156,13 @@ class Op(object):
(_, input_dict), = input_dicts.items()
return input_dict
def process(self, feed_dict):
#TODO: check batch
err, err_info = ChannelData.check_npdata(feed_dict)
def process(self, feed_batch):
err, err_info = ChannelData.check_batch_npdata(feed_batch)
if err != 0:
raise NotImplementedError(
"{} Please override preprocess func.".format(err_info))
call_result = self.client.predict(
feed=feed_dict, fetch=self._fetch_names)
feed=feed_batch, fetch=self._fetch_names)
_LOGGER.debug(self._log("get call_result"))
return call_result
......@@ -277,11 +277,12 @@ class Op(object):
err_channeldata_dict = {}
if self.with_serving:
data_ids = preped_data_dict.keys()
batch = [preped_data_dict[data_id] for data_id in data_ids]
feed_batch = [preped_data_dict[data_id] for data_id in data_ids]
midped_batch = None
ecode = ChannelDataEcode.OK.value
if self._timeout <= 0:
try:
midped_data = self.process(batch)
midped_batch = self.process(feed_batch)
except Exception as e:
ecode = ChannelDataEcode.UNKNOW.value
error_info = log_func(e)
......@@ -290,7 +291,7 @@ class Op(object):
for i in range(self._retry):
try:
midped_batch = func_timeout.func_timeout(
self._timeout, self.process, args=(batch, ))
self._timeout, self.process, args=(feed_batch, ))
except func_timeout.FunctionTimedOut as e:
if i + 1 >= self._retry:
ecode = ChannelDataEcode.TIMEOUT.value
......@@ -333,7 +334,7 @@ class Op(object):
def _run_postprocess(self, parsed_data_dict, midped_data_dict, log_func):
postped_data_dict = {}
err_channeldata_dict = {}
for data_id, midped_data in mided_data_dict.items():
for data_id, midped_data in midped_data_dict.items():
postped_data, err_channeldata = None, None
try:
postped_data = self.postprocess(
......@@ -403,7 +404,7 @@ class Op(object):
parsed_data_dict = {}
need_profile_dict = {}
profile_dict = {}
for channeldata_dict in channeldata_dict_batch:
for channeldata_dict in batch:
(data_id, error_channeldata, parsed_data,
client_need_profile, profile_set) = \
self._parse_channeldata(channeldata_dict)
......@@ -424,16 +425,18 @@ class Op(object):
def get_log_func(op_info_prefix):
def log_func(info_str):
return "{} {}".format(op_info_prefix, info_str)
return log_func
op_info_prefix = "[{}|{}]".format(self.name, concurrency_idx)
log = get_log_func(op_info_prefix)
preplog = get_log_func(op_info_prefix + "(prep)")
midplog = get_log_func(op_info_prefix + "(midp)")
postplog = get_log_func(op_info_prefix + "(postp)")
tid = threading.current_thread().ident
# init op
try:
self._initialize(is_thread_op)
self._initialize(is_thread_op, client_type)
except Exception as e:
_LOGGER.error(log(e))
os._exit(-1)
......@@ -468,7 +471,7 @@ class Op(object):
# preprecess
self._profiler_record("prep#{}_0".format(op_info_prefix))
preped_data_dict, err_channeldata_dict \
= self._run_preprocess(parsed_data_dict, log)
= self._run_preprocess(parsed_data_dict, preplog)
self._profiler_record("prep#{}_1".format(op_info_prefix))
try:
for data_id, err_channeldata in err_channeldata_dict.items():
......@@ -487,7 +490,7 @@ class Op(object):
# process
self._profiler_record("midp#{}_0".format(op_info_prefix))
midped_data_dict, err_channeldata_dict \
= self._run_process(preped_data_dict, log)
= self._run_process(preped_data_dict, midplog)
self._profiler_record("midp#{}_1".format(op_info_prefix))
try:
for data_id, err_channeldata in err_channeldata_dict.items():
......@@ -507,7 +510,7 @@ class Op(object):
self._profiler_record("postp#{}_0".format(op_info_prefix))
postped_data_dict, err_channeldata_dict \
= self._run_postprocess(
parsed_data_dict, midped_data_dict, log)
parsed_data_dict, midped_data_dict, postplog)
self._profiler_record("postp#{}_1".format(op_info_prefix))
try:
for data_id, err_channeldata in err_channeldata_dict.items():
......@@ -536,7 +539,7 @@ class Op(object):
self._finalize(is_thread_op)
break
def _initialize(self, is_thread_op):
def _initialize(self, is_thread_op, client_type):
if is_thread_op:
with self._for_init_op_lock:
if not self._succ_init_op:
......
......@@ -18,6 +18,7 @@ import numpy as np
from numpy import *
import logging
import functools
from .channel import ChannelDataEcode
from .proto import pipeline_service_pb2
from .proto import pipeline_service_pb2_grpc
......@@ -59,7 +60,11 @@ class PipelineClient(object):
def _unpack_response_package(self, resp, fetch):
if resp.ecode != 0:
return {"ecode": resp.ecode, "error_info": resp.error_info}
return {
"ecode": resp.ecode,
"ecode_desc": ChannelDataEcode(resp.ecode),
"error_info": resp.error_info,
}
fetch_map = {"ecode": resp.ecode}
for idx, key in enumerate(resp.key):
if key == self._profile_key:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册