Commit 4049eb8a authored by barriery

fix codestyle

Parent de58d900
@@ -258,7 +258,8 @@ class ProcessChannel(object):
     def push(self, channeldata, op_name=None):
         _LOGGER.debug(
-            self._log("{} try to push data[{}]".format(op_name, channeldata.id)))
+            self._log("{} try to push data[{}]".format(op_name,
+                                                       channeldata.id)))
         if len(self._producers) == 0:
             raise Exception(
                 self._log(
@@ -275,8 +276,9 @@ class ProcessChannel(object):
                 if self._stop.value == 1:
                     raise ChannelStopError()
                 self._cv.notify_all()
-                _LOGGER.debug(self._log("{} notify all".format(op_name)))
-            _LOGGER.debug(self._log("{} push data succ!".format(op_name)))
+            _LOGGER.debug(
+                self._log("{} succ push data[{}] into internal queue.".format(
+                    op_name, channeldata.id)))
             return True
         elif op_name is None:
             raise Exception(
@@ -287,7 +289,6 @@ class ProcessChannel(object):
         data_id = channeldata.id
         put_data = None
         with self._cv:
-            _LOGGER.debug(self._log("{} get lock".format(op_name)))
             if data_id not in self._input_buf:
                 self._input_buf[data_id] = {
                     name: None
@@ -309,14 +310,11 @@ class ProcessChannel(object):
             if put_data is None:
                 _LOGGER.debug(
-                    self._log("{} push data succ, but not push to queue.".
-                              format(op_name)))
+                    self._log("{} succ push data[{}] into input_buffer.".format(
+                        op_name, data_id)))
             else:
                 while self._stop.value == 0:
                     try:
-                        _LOGGER.debug(
-                            self._log("{} push data succ: {}".format(
-                                op_name, put_data.__str__())))
                         self._que.put(put_data, timeout=0)
                         break
                     except Queue.Empty:
@@ -325,11 +323,15 @@ class ProcessChannel(object):
                    raise ChannelStopError()
            _LOGGER.debug(
-                self._log("multi | {} push data succ!".format(op_name)))
+                self._log("{} succ push data[{}] into internal queue.".
+                          format(op_name, data_id)))
            self._cv.notify_all()
        return True

    def front(self, op_name=None, timeout=None):
+        _LOGGER.debug(
+            self._log("{} try to get data[?]; timeout={}".format(op_name,
+                                                                 timeout)))
        endtime = None
        if timeout is not None:
            if timeout <= 0:
@@ -337,7 +339,6 @@ class ProcessChannel(object):
             else:
                 endtime = _time() + timeout
-        _LOGGER.debug(self._log("{} try to get data...".format(op_name)))
         if len(self._consumer_cursors) == 0:
             raise Exception(
                 self._log(
@@ -348,21 +349,24 @@ class ProcessChannel(object):
             with self._cv:
                 while self._stop.value == 0 and resp is None:
                     try:
-                        _LOGGER.debug(
-                            self._log("{} try to get(with channel empty: {})".
-                                      format(op_name, self._que.empty())))
                         resp = self._que.get(timeout=0)
                         break
                     except Queue.Empty:
                         if timeout is not None:
                             remaining = endtime - _time()
                             if remaining <= 0.0:
+                                _LOGGER.debug(
+                                    self._log("{} get data[?] timeout".format(
+                                        op_name)))
                                 raise ChannelTimeoutError()
                             self._cv.wait(remaining)
                         else:
                             self._cv.wait()
                 if self._stop.value == 1:
                     raise ChannelStopError()
+            _LOGGER.debug(
+                self._log("{} succ get data[{}]".format(op_name,
+                                                        resp.values()[0].id)))
             return resp
         elif op_name is None:
             raise Exception(
@@ -384,22 +388,20 @@ class ProcessChannel(object):
             # it is necessary to obtain a data from queue and add it to output_buf.
             while self._stop.value == 0 and self._consumer_cursors[
                     op_name] - self._base_cursor.value >= len(self._output_buf):
-                _LOGGER.debug(
-                    self._log(
-                        "({}) B self._consumer_cursors: {}, self._base_cursor: {}, len(self._output_buf): {}".
-                        format(op_name, self._consumer_cursors,
-                               self._base_cursor.value, len(self._output_buf))))
                 try:
-                    _LOGGER.debug(
-                        self._log("{} try to get(with channel size: {})".format(
-                            op_name, self._que.qsize())))
                     channeldata = self._que.get(timeout=0)
                     self._output_buf.append(channeldata)
+                    _LOGGER.debug(
+                        self._log("pop ready item[{}] into output_buffer".
+                                  format(channeldata.values()[0].id)))
                     break
                 except Queue.Empty:
                     if timeout is not None:
                         remaining = endtime - _time()
                         if remaining <= 0.0:
+                            _LOGGER.debug(
+                                self._log("{} get data[?] timeout".format(
+                                    op_name)))
                             raise ChannelTimeoutError()
                         self._cv.wait(remaining)
                     else:
@@ -411,7 +413,6 @@ class ProcessChannel(object):
             base_cursor = self._base_cursor.value
             data_idx = consumer_cursor - base_cursor
             resp = self._output_buf[data_idx]
-            _LOGGER.debug(self._log("{} get data: {}".format(op_name, resp)))
             self._cursor_count[consumer_cursor] -= 1
             if consumer_cursor == base_cursor and self._cursor_count[
@@ -423,6 +424,7 @@ class ProcessChannel(object):
                 self._base_cursor.value += 1
                 # to avoid cursor overflow
                 if self._base_cursor.value >= self._reset_max_cursor:
+                    _LOGGER.info(self._log("reset cursor in Channel"))
                     self._base_cursor.value -= self._reset_max_cursor
                     for name in self._consumer_cursors.keys():
                         self._consumer_cursors[name] -= self._reset_max_cursor
@@ -440,16 +442,12 @@ class ProcessChannel(object):
                     self._cursor_count[new_consumer_cursor] = 0
                 self._cursor_count[new_consumer_cursor] += 1

-            _LOGGER.debug(
-                self._log(
-                    "({}) A self._consumer_cursors: {}, self._base_cursor: {}, len(self._output_buf): {}".
-                    format(op_name, self._consumer_cursors,
-                           self._base_cursor.value, len(self._output_buf))))
-            _LOGGER.debug(self._log("{} notify all".format(op_name)))
             self._cv.notify_all()

-        _LOGGER.debug(self._log("multi | {} get data succ!".format(op_name)))
-        return resp  # reference, read only
+        _LOGGER.debug(
+            self._log("{} succ get data[{}] from output_buffer".format(
+                op_name, resp.values()[0].id)))
+        return resp

     def stop(self):
         _LOGGER.debug(self._log("stop."))
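Note on the cursor arithmetic in this hunk: `_base_cursor` marks the head of `_output_buf`, each consumer cursor counts the items that consumer has taken, and `consumer_cursor - base_cursor` indexes the consumer's next item. When every consumer has moved past the head, the head is popped, and both kinds of cursor are periodically shifted down by `_reset_max_cursor` so the integers stay bounded. A standalone sketch of that bookkeeping (illustrative only; the real channel also blocks on a condition variable and tracks per-cursor reference counts rather than taking a min):

# Standalone sketch of ProcessChannel's cursor bookkeeping (no locking,
# no blocking): output_buf holds items not yet consumed by the slowest
# consumer; consumer_cursor - base_cursor indexes a consumer's next item.
RESET_MAX_CURSOR = 1000000

class CursorBufferSketch(object):
    def __init__(self, consumers):
        self.output_buf = []
        self.base_cursor = 0
        self.consumer_cursors = {name: 0 for name in consumers}

    def put(self, item):
        self.output_buf.append(item)

    def front(self, name):
        # index of this consumer's next item inside output_buf
        idx = self.consumer_cursors[name] - self.base_cursor
        item = self.output_buf[idx]
        self.consumer_cursors[name] += 1
        # drop the head once every consumer has moved past it
        if min(self.consumer_cursors.values()) > self.base_cursor:
            self.output_buf.pop(0)
            self.base_cursor += 1
            # to avoid cursor overflow, shift all cursors down together
            if self.base_cursor >= RESET_MAX_CURSOR:
                self.base_cursor -= RESET_MAX_CURSOR
                for n in self.consumer_cursors:
                    self.consumer_cursors[n] -= RESET_MAX_CURSOR
        return item

buf = CursorBufferSketch(["op_a", "op_b"])
buf.put("x")
assert buf.front("op_a") == "x"
assert buf.front("op_b") == "x"
assert buf.output_buf == []  # both consumers passed the head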
@@ -538,7 +536,8 @@ class ThreadChannel(Queue.Queue):
     def push(self, channeldata, op_name=None):
         _LOGGER.debug(
-            self._log("{} try to push data[{}]".format(op_name, channeldata.id)))
+            self._log("{} try to push data[{}]".format(op_name,
+                                                       channeldata.id)))
         if len(self._producers) == 0:
             raise Exception(
                 self._log(
@@ -556,9 +555,8 @@ class ThreadChannel(Queue.Queue):
                     raise ChannelStopError()
                 self._cv.notify_all()
             _LOGGER.debug(
-                self._log(
-                    "{} succ push data[{}] into internal queue.".format(
-                        op_name, channeldata.id)))
+                self._log("{} succ push data[{}] into internal queue.".format(
+                    op_name, channeldata.id)))
             return True
         elif op_name is None:
             raise Exception(
@@ -585,9 +583,8 @@ class ThreadChannel(Queue.Queue):
             if put_data is None:
                 _LOGGER.debug(
-                    self._log(
-                        "{} succ push data[{}] into input_buffer.".format(
-                            op_name, data_id)))
+                    self._log("{} succ push data[{}] into input_buffer.".format(
+                        op_name, data_id)))
             else:
                 while self._stop is False:
                     try:
@@ -599,17 +596,15 @@ class ThreadChannel(Queue.Queue):
                    raise ChannelStopError()
            _LOGGER.debug(
-                self._log(
-                    "{} succ push data[{}] into internal queue.".format(
-                        op_name, data_id)))
+                self._log("{} succ push data[{}] into internal queue.".
+                          format(op_name, data_id)))
            self._cv.notify_all()
        return True

    def front(self, op_name=None, timeout=None):
        _LOGGER.debug(
-            self._log(
-                "{} try to get data[?]; timeout={}".format(
-                    op_name, timeout)))
+            self._log("{} try to get data[?]; timeout={}".format(op_name,
+                                                                 timeout)))
        endtime = None
        if timeout is not None:
            if timeout <= 0:
@@ -634,8 +629,8 @@ class ThreadChannel(Queue.Queue):
                             remaining = endtime - _time()
                             if remaining <= 0.0:
                                 _LOGGER.debug(
-                                    self._log(
-                                        "{} get data[?] timeout".format(op_name)))
+                                    self._log("{} get data[?] timeout".format(
+                                        op_name)))
                                 raise ChannelTimeoutError()
                             self._cv.wait(remaining)
                         else:
@@ -643,8 +638,8 @@ class ThreadChannel(Queue.Queue):
                 if self._stop:
                     raise ChannelStopError()
             _LOGGER.debug(
-                self._log("{} succ get data[{}]".format(
-                    op_name, resp.values()[0].id)))
+                self._log("{} succ get data[{}]".format(op_name,
+                                                        resp.values()[0].id)))
             return resp
         elif op_name is None:
             raise Exception(
@@ -670,17 +665,16 @@ class ThreadChannel(Queue.Queue):
                     channeldata = self.get(timeout=0)
                     self._output_buf.append(channeldata)
                     _LOGGER.debug(
-                        self._log(
-                            "pop ready item[{}] into output_buffer".format(
-                                channeldata.values()[0].id)))
+                        self._log("pop ready item[{}] into output_buffer".
+                                  format(channeldata.values()[0].id)))
                     break
                 except Queue.Empty:
                     if timeout is not None:
                         remaining = endtime - _time()
                         if remaining <= 0.0:
                             _LOGGER.debug(
-                                self._log(
-                                    "{} get data[?] timeout".format(op_name)))
+                                self._log("{} get data[?] timeout".format(
+                                    op_name)))
                             raise ChannelTimeoutError()
                         self._cv.wait(remaining)
                     else:
@@ -704,8 +698,7 @@ class ThreadChannel(Queue.Queue):
                 self._base_cursor += 1
                 # to avoid cursor overflow
                 if self._base_cursor >= self._reset_max_cursor:
-                    _LOGGER.info(
-                        self._log("reset cursor in Channel"))
+                    _LOGGER.info(self._log("reset cursor in Channel"))
                     self._base_cursor -= self._reset_max_cursor
                     for name in self._consumer_cursors:
                         self._consumer_cursors[name] -= self._reset_max_cursor
@@ -725,9 +718,8 @@ class ThreadChannel(Queue.Queue):
             self._cv.notify_all()

         _LOGGER.debug(
-            self._log(
-                "{} succ get data[{}] from output_buffer".format(
-                    op_name, resp.values()[0].id)))
+            self._log("{} succ get data[{}] from output_buffer".format(
+                op_name, resp.values()[0].id)))
         return resp

     def stop(self):
@@ -736,10 +728,12 @@ class ThreadChannel(Queue.Queue):
         with self._cv:
             self._cv.notify_all()


 class ChannelTimeoutError(RuntimeError):
     def __init__(self):
         pass


 class ChannelStopError(RuntimeError):
     def __init__(self):
         pass
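Note: these two exceptions are the channel's only control-flow signals. `ChannelTimeoutError` means `front()` gave up waiting; `ChannelStopError` means the channel was stopped. A hedged sketch of the consumer loop this protocol implies (`channel`, `op_name`, and `handle` are placeholders, and the exception classes are the ones defined above):

# Consumer-side protocol sketch: poll the channel, treat timeout as
# "try again" and stop as a clean shutdown. `channel` is assumed to be
# a ThreadChannel/ProcessChannel with `op_name` registered as consumer.
def consume_loop(channel, op_name, handle):
    while True:
        try:
            # front() raises instead of returning None
            channeldata_dict = channel.front(op_name, timeout=1.0)
        except ChannelTimeoutError:
            continue  # nothing arrived within the timeout, poll again
        except ChannelStopError:
            break     # channel stopped, exit the worker loop
        handle(channeldata_dict)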
@@ -47,7 +47,7 @@ class DAGExecutor(object):
         for key, val in default_conf.items():
             if dag_config.get(key) is None:
                 _LOGGER.warning("[CONF] {} not set, use default: {}"
                                 .format(key, val))
                 dag_config[key] = val

         self._retry = dag_config["retry"]
@@ -60,7 +60,7 @@ class DAGExecutor(object):
         if show_info:
             _LOGGER.info("=============== DAGExecutor ===============")
             for key in default_conf.keys():
                 _LOGGER.info("{}: {}".format(key, dag_config[key]))
             _LOGGER.info("-------------------------------------------")

         self.name = "@G"
@@ -110,17 +110,15 @@ class DAGExecutor(object):
     def _set_in_channel(self, in_channel):
         if not isinstance(in_channel, (ThreadChannel, ProcessChannel)):
-            raise TypeError(
-                "in_channel must be Channel type, but get {}".format(
-                    type(in_channel)))
+            raise TypeError("in_channel must be Channel type, but get {}".
+                            format(type(in_channel)))
         in_channel.add_producer(self.name)
         self._in_channel = in_channel

     def _set_out_channel(self, out_channel):
         if not isinstance(out_channel, (ThreadChannel, ProcessChannel)):
-            raise TypeError(
-                "iout_channel must be Channel type, but get {}".format(
-                    type(out_channel)))
+            raise TypeError("iout_channel must be Channel type, but get {}".
+                            format(type(out_channel)))
         out_channel.add_consumer(self.name)
         self._out_channel = out_channel
@@ -143,12 +141,14 @@ class DAGExecutor(object):
                 break
             if len(channeldata_dict) != 1:
-                _LOGGER.error("[DAG Executor] out_channel cannot have multiple input ops")
+                _LOGGER.error(
+                    "[DAG Executor] out_channel cannot have multiple input ops")
                 os._exit(-1)
             (_, channeldata), = channeldata_dict.items()
             if not isinstance(channeldata, ChannelData):
-                _LOGGER.error('[DAG Executor] data must be ChannelData type, but get {}'
-                              .format(type(channeldata)))
+                _LOGGER.error(
+                    '[DAG Executor] data must be ChannelData type, but get {}'
+                    .format(type(channeldata)))
                 os._exit(-1)
             data_id = channeldata.id
@@ -178,7 +178,7 @@ class DAGExecutor(object):
             dictdata = self._unpack_rpc_func(rpc_request)
         except Exception as e:
             _LOGGER.error("parse RPC package to data[{}] Error: {}"
                           .format(data_id, e))
             return ChannelData(
                 ecode=ChannelDataEcode.RPC_PACKAGE_ERROR.value,
                 error_info="rpc package error: {}".format(e),
@@ -192,7 +192,8 @@ class DAGExecutor(object):
                 profile_value = rpc_request.value[idx]
                 break
         client_need_profile = (profile_value == self._client_profile_value)
-        _LOGGER.debug("request[{}] need profile: {}".format(data_id, client_need_profile))
+        _LOGGER.debug("request[{}] need profile: {}".format(
+            data_id, client_need_profile))
         return ChannelData(
             datatype=ChannelDataType.DICT.value,
             dictdata=dictdata,
@@ -208,7 +209,8 @@ class DAGExecutor(object):
         else:
             self._profiler.record("call_{}#DAG_0".format(data_id))

-        _LOGGER.debug("try parse RPC package to channeldata[{}]".format(data_id))
+        _LOGGER.debug("try parse RPC package to channeldata[{}]".format(
+            data_id))
         self._profiler.record("prepack_{}#{}_0".format(data_id, self.name))
         req_channeldata = self._pack_channeldata(rpc_request, data_id)
         self._profiler.record("prepack_{}#{}_1".format(data_id, self.name))
@@ -226,23 +228,26 @@ class DAGExecutor(object):
                         error_info="dag closed.",
                         data_id=data_id))

-            _LOGGER.debug("wait for Graph engine for data[{}]...".format(data_id))
+            _LOGGER.debug("wait for Graph engine for data[{}]...".format(
+                data_id))
             resp_channeldata = self._get_channeldata_from_fetch_buffer(data_id)

             if resp_channeldata.ecode == ChannelDataEcode.OK.value:
-                _LOGGER.debug("Graph engine predict data[{}] succ".format(data_id))
+                _LOGGER.debug("Graph engine predict data[{}] succ".format(
+                    data_id))
                 break
             else:
                 _LOGGER.warn("Graph engine predict data[{}] failed: {}"
                              .format(data_id, resp_channeldata.error_info))
                 if resp_channeldata.ecode != ChannelDataEcode.TIMEOUT.value:
                     break
             if i + 1 < self._retry:
-                _LOGGER.warn("retry({}/{}) data[{}]".format(
-                    i + 1, self._retry, data_id))
+                _LOGGER.warn("retry({}/{}) data[{}]".format(i + 1, self._retry,
+                                                            data_id))

-        _LOGGER.debug("unpack channeldata[{}] into RPC resp package".format(data_id))
+        _LOGGER.debug("unpack channeldata[{}] into RPC resp package".format(
+            data_id))
         self._profiler.record("postpack_{}#{}_0".format(data_id, self.name))
         rpc_resp = self._pack_for_rpc_resp(resp_channeldata)
         self._profiler.record("postpack_{}#{}_1".format(data_id, self.name))
@@ -380,8 +385,8 @@ class DAG(object):
                 )
         if self._build_dag_each_worker:
             _LOGGER.info("Because `build_dag_each_worker` mode is used, "
                          "Auto-batching is set to the default config: "
                          "batch_size=1, auto_batching_timeout=None")
             for op in used_ops:
                 op.use_default_auto_batching_config()
@@ -439,7 +444,7 @@ class DAG(object):
                 channel = self._gen_channel(channel_name_gen)
                 channels.append(channel)
                 _LOGGER.debug("[DAG] Channel({}) => Op({})"
                               .format(channel.name, op.name))
                 op.add_input_channel(channel)
             pred_ops = pred_op_of_next_view_op[op.name]
             if v_idx == 0:
@@ -448,7 +453,7 @@ class DAG(object):
                 # if pred_op is virtual op, it will use ancestors as producers to channel
                 for pred_op in pred_ops:
                     _LOGGER.debug("[DAG] Op({}) => Channel({})"
                                   .format(pred_op.name, channel.name))
                     pred_op.add_output_channel(channel)
             processed_op.add(op.name)
             # find same input op to combine channel
# find same input op to combine channel # find same input op to combine channel
...@@ -465,7 +470,7 @@ class DAG(object): ...@@ -465,7 +470,7 @@ class DAG(object):
break break
if same_flag: if same_flag:
_LOGGER.debug("[DAG] Channel({}) => Op({})" _LOGGER.debug("[DAG] Channel({}) => Op({})"
.format(channel.name, other_op.name)) .format(channel.name, other_op.name))
other_op.add_input_channel(channel) other_op.add_input_channel(channel)
processed_op.add(other_op.name) processed_op.add(other_op.name)
output_channel = self._gen_channel(channel_name_gen) output_channel = self._gen_channel(channel_name_gen)
@@ -484,7 +489,7 @@ class DAG(object):
         for c in channels:
             _LOGGER.debug("Channel({}):\n -producers: {}\n -consumers: {}"
                           .format(c.name, c.get_producers(), c.get_consumers()))
         return (actual_ops, channels, input_channel, output_channel, pack_func,
                 unpack_func)
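Note: the `[DAG] Op(..) => Channel(..)` and `[DAG] Channel(..) => Op(..)` logs in the hunks above describe the wiring model: a channel records its producer ops and consumer ops, and an op whose predecessors match an already-built channel is attached to that existing channel (the "combine channel" step). A minimal sketch of that bookkeeping, with illustrative names:

# Sketch of the wiring the "[DAG] ... => ..." logs record: a channel
# knows its producer ops and consumer ops; a second op with the same
# predecessors is attached to the existing channel instead of a new one.
class ChannelSketch(object):
    def __init__(self, name):
        self.name = name
        self.producers = []   # Op(..) => Channel(..)
        self.consumers = []   # Channel(..) => Op(..)

def connect(channel, pred_ops, succ_ops):
    for op in pred_ops:
        channel.producers.append(op)
    for op in succ_ops:
        channel.consumers.append(op)
    return channel

chan = connect(ChannelSketch("channel_0"), ["read_op"],
               ["cls_op", "det_op"])  # hypothetical op names
assert chan.consumers == ["cls_op", "det_op"]  # fan-out over one channel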
......
@@ -81,14 +81,12 @@ class Op(object):
     def use_default_auto_batching_config(self):
         if self._batch_size != 1:
-            _LOGGER.warn(
-                "Op({}) reset batch_size=1 (original: {})"
-                .format(self.name, self._batch_size))
+            _LOGGER.warn("Op({}) reset batch_size=1 (original: {})"
+                         .format(self.name, self._batch_size))
             self._batch_size = 1
         if self._auto_batching_timeout != None:
-            _LOGGER.warn(
-                "Op({}) reset auto_batching_timeout=1 (original: {})"
-                .format(self.name, self._auto_batching_timeout))
+            _LOGGER.warn("Op({}) reset auto_batching_timeout=1 (original: {})"
+                         .format(self.name, self._auto_batching_timeout))
             self._auto_batching_timeout = None

     def use_profiler(self, use_profile):
@@ -104,10 +102,12 @@ class Op(object):
         if self.with_serving == False:
             _LOGGER.info("Op({}) no client".format(self.name))
             return None
-        _LOGGER.info("Op({}) service endpoints: {}".format(self.name, server_endpoints))
+        _LOGGER.info("Op({}) service endpoints: {}".format(self.name,
+                                                           server_endpoints))
         _LOGGER.debug("Op({}) fetch_names: {}".format(self.name, fetch_names))
         if client_type == 'brpc':
-            _LOGGER.debug("Op({}) client_config: {}".format(self.name, client_config))
+            _LOGGER.debug("Op({}) client_config: {}".format(self.name,
+                                                            client_config))
             client = Client()
             client.load_client_config(client_config)
         elif client_type == 'grpc':
@@ -259,27 +259,24 @@ class Op(object):
             preped_data = self.preprocess(parsed_data)
         except NotImplementedError as e:
             # preprocess function not implemented
-            error_info = log_func(
-                "preprocess data[{}] failed: {}".format(
-                    data_id, e))
+            error_info = log_func("preprocess data[{}] failed: {}".format(
+                data_id, e))
             error_channeldata = ChannelData(
                 ecode=ChannelDataEcode.NOT_IMPLEMENTED.value,
                 error_info=error_info,
                 data_id=data_id)
         except TypeError as e:
             # Error type in channeldata.datatype
-            error_info = log_func(
-                "preprocess data[{}] failed: {}".format(
-                    data_id, e))
+            error_info = log_func("preprocess data[{}] failed: {}".format(
+                data_id, e))
             _LOGGER.error(error_info)
             error_channeldata = ChannelData(
                 ecode=ChannelDataEcode.TYPE_ERROR.value,
                 error_info=error_info,
                 data_id=data_id)
         except Exception as e:
-            error_info = log_func(
-                "preprocess data[{}] failed: {}".format(
-                    data_id, e))
+            error_info = log_func("preprocess data[{}] failed: {}".format(
+                data_id, e))
             _LOGGER.error(error_info)
             error_channeldata = ChannelData(
                 ecode=ChannelDataEcode.UNKNOW.value,
@@ -321,10 +318,11 @@ class Op(object):
                     else:
                         _LOGGER.warn(
                             log_func("timeout, retry({}/{})"
                                      .format(i + 1, self._retry)))
                 except Exception as e:
                     ecode = ChannelDataEcode.UNKNOW.value
-                    error_info = log_func("process batch failed: {}".format(e))
+                    error_info = log_func("process batch failed: {}".format(
+                        e))
                     _LOGGER.error(error_info)
                     break
                 else:
@@ -332,24 +330,23 @@ class Op(object):
             if ecode != ChannelDataEcode.OK.value:
                 for data_id in data_ids:
                     err_channeldata_dict[data_id] = ChannelData(
-                        ecode=ecode,
-                        error_info=error_info,
-                        data_id=data_id)
+                        ecode=ecode, error_info=error_info, data_id=data_id)
             elif midped_batch is None:
                 # op client return None
-                error_info=log_func(
+                error_info = log_func(
                     "predict failed. pls check the server side.")
                 _LOGGER.error(error_info)
                 for data_id in data_ids:
                     err_channeldata_dict[data_id] = ChannelData(
                         ecode=ChannelDataEcode.CLIENT_ERROR.value,
                         error_info=error_info,
                         data_id=data_id)
             else:
                 # transform np format to dict format
                 for idx, data_id in enumerate(data_ids):
                     midped_data_dict[data_id] = {
-                        k: v[idx] for k, v in midped_batch.items()
+                        k: v[idx]
+                        for k, v in midped_batch.items()
                     }
         else:
             midped_data_dict = preped_data_dict
@@ -363,11 +360,11 @@ class Op(object):
         for data_id, midped_data in midped_data_dict.items():
             postped_data, err_channeldata = None, None
             try:
-                postped_data = self.postprocess(
-                    parsed_data_dict[data_id], midped_data)
+                postped_data = self.postprocess(parsed_data_dict[data_id],
+                                                midped_data)
             except Exception as e:
                 error_info = log_func("postprocess data[{}] failed: {}"
                                       .format(data_id, e))
                 _LOGGER.error(error_info)
                 err_channeldata = ChannelData(
                     ecode=ChannelDataEcode.UNKNOW.value,
@@ -403,15 +400,14 @@ class Op(object):
                     postped_data_dict[data_id] = output_data
         _LOGGER.debug(log_func("succ run postprocess"))
         return postped_data_dict, err_channeldata_dict

-    def _auto_batching_generator(self, input_channel, op_name,
-                                 batch_size, timeout, log_func):
+    def _auto_batching_generator(self, input_channel, op_name, batch_size,
+                                 timeout, log_func):
         while True:
             batch = []
             _LOGGER.debug(
-                log_func(
-                    "Auto-batching expect size: {}; timeout: {}".format(
-                        batch_size, timeout)))
+                log_func("Auto-batching expect size: {}; timeout: {}".format(
+                    batch_size, timeout)))
             while len(batch) == 0:
                 endtime = None
                 if timeout is not None:
@@ -424,14 +420,16 @@ class Op(object):
                             if remaining <= 0.0:
                                 _LOGGER.debug(log_func("Auto-batching timeout"))
                                 break
-                            channeldata_dict = input_channel.front(op_name, timeout)
+                            channeldata_dict = input_channel.front(op_name,
+                                                                   timeout)
                         else:
                             channeldata_dict = input_channel.front(op_name)
                         batch.append(channeldata_dict)
                     except ChannelTimeoutError:
                         _LOGGER.debug(log_func("Auto-batching timeout"))
                         break
-            _LOGGER.debug(log_func("Auto-batching actual size: {}".format(len(batch))))
+            _LOGGER.debug(
+                log_func("Auto-batching actual size: {}".format(len(batch))))
             yield batch

     def _parse_channeldata_batch(self, batch, output_channels):
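Note: `_auto_batching_generator` drains the input channel into batches of at most `batch_size`, yielding whatever has arrived once `timeout` expires. The same behavior over a plain `queue.Queue`, as a self-contained Python 3 sketch (the pipeline code itself targets Python 2-style `Queue`):

import queue
import time

# Auto-batching sketch: collect up to batch_size items, but yield a
# partial batch once `timeout` seconds pass; with timeout=None, block.
def auto_batching_generator(q, batch_size, timeout):
    while True:
        batch = []
        endtime = time.time() + timeout if timeout is not None else None
        while len(batch) < batch_size:
            try:
                if endtime is not None:
                    remaining = endtime - time.time()
                    if remaining <= 0.0:
                        break  # timeout: yield whatever we have
                    batch.append(q.get(timeout=remaining))
                else:
                    batch.append(q.get())
            except queue.Empty:
                break
        if batch:
            yield batch

q = queue.Queue()
for i in range(3):
    q.put(i)
gen = auto_batching_generator(q, batch_size=2, timeout=0.01)
assert next(gen) == [0, 1]
assert next(gen) == [2]  # partial batch after the timeout expired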
@@ -449,16 +447,17 @@ class Op(object):
             else:
                 # error data in predecessor Op
                 # (error_channeldata with profile info)
-                self._push_to_output_channels(
-                    error_channeldata, output_channels)
+                self._push_to_output_channels(error_channeldata,
+                                              output_channels)
         return parsed_data_dict, need_profile_dict, profile_dict

-    def _run(self, concurrency_idx, input_channel, output_channels,
-             client_type, is_thread_op):
+    def _run(self, concurrency_idx, input_channel, output_channels, client_type,
+             is_thread_op):
         def get_log_func(op_info_prefix):
             def log_func(info_str):
                 return "{} {}".format(op_info_prefix, info_str)
+
             return log_func

         op_info_prefix = "[{}|{}]".format(self.name, concurrency_idx)
@@ -474,12 +473,12 @@ class Op(object):
         _LOGGER.info(log("succ init"))

         batch_generator = self._auto_batching_generator(
             input_channel=input_channel,
             op_name=self.name,
             batch_size=self._batch_size,
             timeout=self._auto_batching_timeout,
             log_func=log)

         while True:
             try:
                 channeldata_dict_batch = next(batch_generator)
@@ -528,10 +527,10 @@ class Op(object):
                 try:
                     for data_id, err_channeldata in err_channeldata_dict.items():
                         self._push_to_output_channels(
                             err_channeldata,
                             output_channels,
                             client_need_profile=need_profile_dict[data_id],
                             profile_set=profile_dict[data_id])
                 except ChannelStopError:
                     _LOGGER.debug(log("channel stop."))
                     self._finalize(is_thread_op)
@@ -548,10 +547,10 @@ class Op(object):
                 try:
                     for data_id, err_channeldata in err_channeldata_dict.items():
                         self._push_to_output_channels(
                             error_channeldata,
                             output_channels,
                             client_need_profile=need_profile_dict[data_id],
                             profile_set=profile_dict[data_id])
                 except ChannelStopError:
                     _LOGGER.debug(log("channel stop."))
                     self._finalize(is_thread_op)
@@ -563,10 +562,10 @@ class Op(object):
             try:
                 for data_id, postped_data in postped_data_dict.items():
                     self._push_to_output_channels(
                         postped_data,
                         output_channels,
                         client_need_profile=need_profile_dict[data_id],
                         profile_set=profile_dict[data_id])
             except ChannelStopError:
                 _LOGGER.debug(log("channel stop."))
                 self._finalize(is_thread_op)
@@ -583,8 +582,8 @@ class Op(object):
                     self._profiler.enable(True)
                     # init client
                     self.client = self.init_client(
                         client_type, self._client_config,
                         self._server_endpoints, self._fetch_names)
                     # user defined
                     self.init_op()
                     self._succ_init_op = True
@@ -595,13 +594,12 @@ class Op(object):
             self._profiler = TimeProfiler()
             self._profiler.enable(True)
             # init client
-            self.client = self.init_client(
-                client_type, self._client_config,
-                self._server_endpoints,
-                self._fetch_names)
+            self.client = self.init_client(client_type, self._client_config,
+                                           self._server_endpoints,
+                                           self._fetch_names)
             # user defined
             self.init_op()

     def _finalize(self, is_thread_op):
         if is_thread_op:
             with self._for_close_op_lock:
@@ -625,7 +623,7 @@ class RequestOp(Op):
         try:
             self.init_op()
         except Exception as e:
-            _LOGGER.error(e)
+            _LOGGER.error("Op(Request) init op failed: {}".format(e))
             os._exit(-1)

     def unpack_request_package(self, request):
@@ -649,7 +647,7 @@ class ResponseOp(Op):
         try:
             self.init_op()
         except Exception as e:
-            _LOGGER.error(e)
+            _LOGGER.error("Op(ResponseOp) init op failed: {}".format(e))
             os._exit(-1)

     def pack_response_package(self, channeldata):
@@ -730,7 +728,7 @@ class VirtualOp(Op):
             try:
                 channeldata_dict = input_channel.front(self.name)
             except ChannelStopError:
-                _LOGGER.debug(log("stop."))
+                _LOGGER.debug(log("Channel stop."))
                 break
             try:
@@ -738,5 +736,5 @@ class VirtualOp(Op):
                     self._push_to_output_channels(
                         data, channels=output_channels, name=name)
             except ChannelStopError:
-                _LOGGER.debug(log("stop."))
+                _LOGGER.debug(log("Channel stop."))
                 break
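Note: the `get_log_func` closure shown in the `_run` hunk above is what produces the `[op_name|concurrency_idx]` prefix seen on every Op worker log line. A tiny runnable illustration (`"uci"` is a hypothetical op name):

# The closure from Op._run(): every log line gets "[op|idx]" prefixed.
def get_log_func(op_info_prefix):
    def log_func(info_str):
        return "{} {}".format(op_info_prefix, info_str)

    return log_func

log = get_log_func("[uci|0]")  # "uci" is a hypothetical op name
assert log("succ init") == "[uci|0] succ init"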
@@ -61,7 +61,7 @@ class PipelineClient(object):
     def _unpack_response_package(self, resp, fetch):
         if resp.ecode != 0:
             return {
                 "ecode": resp.ecode,
                 "ecode_desc": ChannelDataEcode(resp.ecode),
                 "error_info": resp.error_info,
             }
......
@@ -90,7 +90,7 @@ class PipelineServer(object):
         for key, val in default_config.items():
             if yml_config.get(key) is None:
                 _LOGGER.warning("[CONF] {} not set, use default: {}"
                                 .format(key, val))
                 yml_config[key] = val

         self._port = yml_config["port"]
@@ -98,12 +98,13 @@ class PipelineServer(object):
             raise SystemExit("Prot {} is already used".format(self._port))
         self._worker_num = yml_config["worker_num"]
         self._build_dag_each_worker = yml_config["build_dag_each_worker"]

         _LOGGER.info("============= PIPELINE SERVER =============")
         for key in default_config.keys():
             _LOGGER.info("{}: {}".format(key, yml_config[key]))
         if self._build_dag_each_worker is True:
-            _LOGGER.info("(Make sure that install grpcio whl with --no-binary flag)")
+            _LOGGER.info(
+                "(Make sure that install grpcio whl with --no-binary flag)")
         _LOGGER.info("-------------------------------------------")

         self._dag_config = yml_config.get("dag", {})
......
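Note: the keys read from `yml_config` in this last hunk (`port`, `worker_num`, `build_dag_each_worker`, plus the nested `dag` section whose `retry` key DAGExecutor reads earlier) imply a server config of roughly this shape, shown here as the dict the code would see; the values, and any `dag` keys beyond `retry`, are illustrative assumptions rather than part of the commit:

# Hypothetical server config, expressed as the dict PipelineServer would
# see after defaults are filled in. Key names come from the hunks above;
# the values are assumptions.
yml_config = {
    "port": 18080,
    "worker_num": 4,
    "build_dag_each_worker": False,  # if True, install grpcio with --no-binary
    "dag": {
        "retry": 1,
    },
}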