Commit 29caa6d2 authored by barriery

fix bug in dag-executor and update logging

Parent 94ea6590
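
The core fix in the DAG executor replaces the single shared `_fetch_buffer` with a per-request buffer keyed by `data_id`, so concurrent requests waiting on their own condition variables can no longer pick up (or overwrite) each other's responses. Below is a minimal sketch of that pattern, simplified from the `DAGExecutor` changes in this commit; the `FetchBuffer` class and its method names are illustrative only.

```python
import threading

class FetchBuffer(object):
    """Minimal sketch of the per-request fetch buffer introduced by this fix."""

    def __init__(self):
        self._pool_lock = threading.Condition()  # guards _buffer and _cv_pool
        self._buffer = {}   # data_id -> response (None while still pending)
        self._cv_pool = {}  # data_id -> per-request Condition

    def register(self, data_id):
        cond_v = threading.Condition()
        with self._pool_lock:
            self._cv_pool[data_id] = cond_v
            self._buffer[data_id] = None
        return cond_v

    def put(self, data_id, channeldata):
        # called by the receive thread when a response arrives
        with self._pool_lock:
            cond_v = self._cv_pool[data_id]
        with cond_v:
            self._buffer[data_id] = channeldata
            cond_v.notify_all()

    def get(self, data_id, cond_v):
        # called by the request thread; blocks until its own response is ready
        with cond_v:
            with self._pool_lock:
                if self._buffer[data_id] is not None:
                    self._cv_pool.pop(data_id)
                    return self._buffer.pop(data_id)
            cond_v.wait()
            with self._pool_lock:
                self._cv_pool.pop(data_id)
                return self._buffer.pop(data_id)
```

In the actual diff, `register` corresponds to `_get_next_data_id`, `put` to the receive thread (`_recive_func`), and `get` to `_get_channeldata_from_fetch_buffer`.
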
......@@ -26,6 +26,7 @@ else:
import numpy as np
import logging
import enum
import os
import copy
_LOGGER = logging.getLogger()
......@@ -69,7 +70,8 @@ class ChannelData(object):
'''
if ecode is not None:
if data_id is None or error_info is None:
raise ValueError("data_id and error_info cannot be None")
_LOGGER.critical("data_id and error_info cannot be None")
os._exit(-1)
datatype = ChannelDataType.ERROR.value
else:
if datatype == ChannelDataType.CHANNEL_NPDATA.value:
......@@ -83,7 +85,8 @@ class ChannelData(object):
datatype = ChannelDataType.ERROR.value
_LOGGER.error(error_info)
else:
raise ValueError("datatype not match")
_LOGGER.critical("datatype not match")
os._exit(-1)
self.datatype = datatype
self.npdata = npdata
self.dictdata = dictdata
......@@ -168,7 +171,9 @@ class ChannelData(object):
# return dict
feed = self.dictdata
else:
raise TypeError("Error type({}) in datatype.".format(self.datatype))
_LOGGER.critical("Error type({}) in datatype.".format(
self.datatype))
os._exit(-1)
return feed
def __str__(self):
......@@ -241,30 +246,35 @@ class ProcessChannel(object):
def add_producer(self, op_name):
""" not thread safe, and can only be called during initialization. """
if op_name in self._producers:
raise ValueError(
_LOGGER.critical(
self._log("producer({}) is already in channel".format(op_name)))
os._exit(-1)
self._producers.append(op_name)
_LOGGER.debug(self._log("add a producer: {}".format(op_name)))
def add_consumer(self, op_name):
""" not thread safe, and can only be called during initialization. """
if op_name in self._consumer_cursors:
raise ValueError(
_LOGGER.critical(
self._log("consumer({}) is already in channel".format(op_name)))
os._exit(-1)
self._consumer_cursors[op_name] = 0
if self._cursor_count.get(0) is None:
self._cursor_count[0] = 0
self._cursor_count[0] += 1
_LOGGER.debug(self._log("add a consumer: {}".format(op_name)))
def push(self, channeldata, op_name=None):
_LOGGER.debug(
self._log("{} try to push data[{}]".format(op_name,
channeldata.id)))
if len(self._producers) == 0:
raise Exception(
_LOGGER.critical(
self._log(
"expected number of producers to be greater than 0, but the it is 0."
))
os._exit(-1)
elif len(self._producers) == 1:
with self._cv:
while self._stop.value == 0:
......@@ -281,9 +291,10 @@ class ProcessChannel(object):
op_name, channeldata.id)))
return True
elif op_name is None:
raise Exception(
_LOGGER.critical(
self._log(
"There are multiple producers, so op_name cannot be None."))
os._exit(-1)
producer_num = len(self._producers)
data_id = channeldata.id
......@@ -340,10 +351,11 @@ class ProcessChannel(object):
endtime = _time() + timeout
if len(self._consumer_cursors) == 0:
raise Exception(
_LOGGER.critical(
self._log(
"expected number of consumers to be greater than 0, but the it is 0."
))
os._exit(-1)
elif len(self._consumer_cursors) == 1:
resp = None
with self._cv:
......@@ -369,9 +381,10 @@ class ProcessChannel(object):
resp.values()[0].id)))
return resp
elif op_name is None:
raise Exception(
_LOGGER.critical(
self._log(
"There are multiple consumers, so op_name cannot be None."))
os._exit(-1)
# In output_buf, different Ops (according to op_name) have different
# cursors. In addition, there is a base_cursor. Their difference is
......@@ -450,7 +463,7 @@ class ProcessChannel(object):
return resp
def stop(self):
_LOGGER.debug(self._log("stop."))
_LOGGER.info(self._log("stop."))
self._stop.value = 1
with self._cv:
self._cv.notify_all()
......@@ -512,37 +525,38 @@ class ThreadChannel(Queue.Queue):
def _log(self, info_str):
return "[{}] {}".format(self.name, info_str)
def debug(self):
return self._log("p: {}, c: {}".format(self.get_producers(),
self.get_consumers()))
def add_producer(self, op_name):
""" not thread safe, and can only be called during initialization. """
if op_name in self._producers:
raise ValueError(
_LOGGER.critical(
self._log("producer({}) is already in channel".format(op_name)))
os._exit(-1)
self._producers.append(op_name)
_LOGGER.debug(self._log("add a producer: {}".format(op_name)))
def add_consumer(self, op_name):
""" not thread safe, and can only be called during initialization. """
if op_name in self._consumer_cursors:
raise ValueError(
_LOGGER.critical(
self._log("consumer({}) is already in channel".format(op_name)))
os._exit(-1)
self._consumer_cursors[op_name] = 0
if self._cursor_count.get(0) is None:
self._cursor_count[0] = 0
self._cursor_count[0] += 1
_LOGGER.debug(self._log("add a consumer: {}".format(op_name)))
def push(self, channeldata, op_name=None):
_LOGGER.debug(
self._log("{} try to push data[{}]".format(op_name,
channeldata.id)))
if len(self._producers) == 0:
raise Exception(
_LOGGER.critical(
self._log(
"expected number of producers to be greater than 0, but the it is 0."
))
os._exit(-1)
elif len(self._producers) == 1:
with self._cv:
while self._stop is False:
......@@ -559,9 +573,10 @@ class ThreadChannel(Queue.Queue):
op_name, channeldata.id)))
return True
elif op_name is None:
raise Exception(
_LOGGER.critical(
self._log(
"There are multiple producers, so op_name cannot be None."))
os._exit(-1)
producer_num = len(self._producers)
data_id = channeldata.id
......@@ -613,10 +628,11 @@ class ThreadChannel(Queue.Queue):
endtime = _time() + timeout
if len(self._consumer_cursors) == 0:
raise Exception(
_LOGGER.critical(
self._log(
"expected number of consumers to be greater than 0, but the it is 0."
))
os._exit(-1)
elif len(self._consumer_cursors) == 1:
resp = None
with self._cv:
......@@ -642,9 +658,10 @@ class ThreadChannel(Queue.Queue):
resp.values()[0].id)))
return resp
elif op_name is None:
raise Exception(
_LOGGER.critical(
self._log(
"There are multiple consumers, so op_name cannot be None."))
os._exit(-1)
# In output_buf, different Ops (according to op_name) have different
# cursors. In addition, there is a base_cursor. Their difference is
......@@ -723,7 +740,7 @@ class ThreadChannel(Queue.Queue):
return resp
def stop(self):
_LOGGER.debug(self._log("stop."))
_LOGGER.info(self._log("stop."))
self._stop = True
with self._cv:
self._cv.notify_all()
......
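
Across channel.py (and the other files in this commit), unrecoverable setup errors no longer raise exceptions; they are logged at CRITICAL level and the process exits immediately, so a mis-built channel cannot linger in a worker thread or subprocess. A hedged sketch of the pattern as a standalone, hypothetical helper:

```python
import logging
import os

_LOGGER = logging.getLogger()

def add_producer(producers, op_name):
    # Sketch of the fail-fast style now used in ProcessChannel/ThreadChannel:
    # configuration errors terminate the whole server instead of raising.
    if op_name in producers:
        _LOGGER.critical("producer({}) is already in channel".format(op_name))
        os._exit(-1)
    producers.append(op_name)
```
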
......@@ -35,33 +35,13 @@ _LOGGER = logging.getLogger()
class DAGExecutor(object):
def __init__(self, response_op, dag_config, show_info):
default_conf = {
"retry": 1,
"client_type": "brpc",
"use_profile": False,
"channel_size": 0,
"is_thread_op": True
}
for key, val in default_conf.items():
if dag_config.get(key) is None:
_LOGGER.warning("[CONF] {} not set, use default: {}"
.format(key, val))
dag_config[key] = val
self._retry = dag_config["retry"]
client_type = dag_config["client_type"]
self._server_use_profile = dag_config["use_profile"]
channel_size = dag_config["channel_size"]
self._is_thread_op = dag_config["is_thread_op"]
build_dag_each_worker = dag_config["build_dag_each_worker"]
if show_info:
_LOGGER.info("=============== DAGExecutor ===============")
for key in default_conf.keys():
_LOGGER.info("{}: {}".format(key, dag_config[key]))
_LOGGER.info("-------------------------------------------")
def __init__(self, response_op, dag_conf):
self._retry = dag_conf["retry"]
client_type = dag_conf["client_type"]
self._server_use_profile = dag_conf["use_profile"]
channel_size = dag_conf["channel_size"]
self._is_thread_op = dag_conf["is_thread_op"]
build_dag_each_worker = dag_conf["build_dag_each_worker"]
self.name = "@G"
self._profiler = TimeProfiler()
......@@ -69,7 +49,7 @@ class DAGExecutor(object):
self._dag = DAG(self.name, response_op, self._server_use_profile,
self._is_thread_op, client_type, channel_size,
show_info, build_dag_each_worker)
build_dag_each_worker)
(in_channel, out_channel, pack_rpc_func,
unpack_rpc_func) = self._dag.build()
self._dag.start()
......@@ -84,7 +64,7 @@ class DAGExecutor(object):
self._reset_max_id = 1000000000000000000
self._cv_pool = {}
self._cv_for_cv_pool = threading.Condition()
self._fetch_buffer = None
self._fetch_buffer = {}
self._recive_func = None
self._client_profile_key = "pipeline.profile"
......@@ -111,19 +91,22 @@ class DAGExecutor(object):
cond_v = threading.Condition()
with self._cv_for_cv_pool:
self._cv_pool[data_id] = cond_v
self._fetch_buffer[data_id] = None
return data_id, cond_v
def _set_in_channel(self, in_channel):
if not isinstance(in_channel, (ThreadChannel, ProcessChannel)):
raise TypeError("in_channel must be Channel type, but get {}".
format(type(in_channel)))
_LOGGER.critical("[DAG Executor] in_channel must be Channel"
" type, but get {}".format(type(in_channel)))
os._exit(-1)
in_channel.add_producer(self.name)
self._in_channel = in_channel
def _set_out_channel(self, out_channel):
if not isinstance(out_channel, (ThreadChannel, ProcessChannel)):
raise TypeError("iout_channel must be Channel type, but get {}".
format(type(out_channel)))
_LOGGER.critical("[DAG Executor]iout_channel must be Channel"
" type, but get {}".format(type(out_channel)))
os._exit(-1)
out_channel.add_consumer(self.name)
self._out_channel = out_channel
......@@ -133,7 +116,7 @@ class DAGExecutor(object):
try:
channeldata_dict = self._out_channel.front(self.name)
except ChannelStopError:
_LOGGER.debug("[DAG Executor] channel stop.")
_LOGGER.info("[DAG Executor] channel stop.")
with self._cv_for_cv_pool:
for data_id, cv in self._cv_pool.items():
closed_errror_data = ChannelData(
......@@ -141,17 +124,17 @@ class DAGExecutor(object):
error_info="dag closed.",
data_id=data_id)
with cv:
self._fetch_buffer = closed_errror_data
self._fetch_buffer[data_id] = closed_errror_data
cv.notify_all()
break
if len(channeldata_dict) != 1:
_LOGGER.error(
_LOGGER.critical(
"[DAG Executor] out_channel cannot have multiple input ops")
os._exit(-1)
(_, channeldata), = channeldata_dict.items()
if not isinstance(channeldata, ChannelData):
_LOGGER.error(
_LOGGER.critical(
'[DAG Executor] data must be ChannelData type, but get {}'
.format(type(channeldata)))
os._exit(-1)
......@@ -159,20 +142,30 @@ class DAGExecutor(object):
data_id = channeldata.id
_LOGGER.debug("recive thread fetch data[{}]".format(data_id))
with self._cv_for_cv_pool:
cv = self._cv_pool[data_id]
with cv:
self._fetch_buffer = channeldata
cv.notify_all()
cond_v = self._cv_pool[data_id]
with cond_v:
self._fetch_buffer[data_id] = channeldata
cond_v.notify_all()
def _get_channeldata_from_fetch_buffer(self, data_id, cond_v):
resp = None
ready_data = None
with cond_v:
with self._cv_for_cv_pool:
if self._fetch_buffer[data_id] is not None:
# The requested data is already ready
ready_data = self._fetch_buffer[data_id]
self._cv_pool.pop(data_id)
self._fetch_buffer.pop(data_id)
if ready_data is None:
# Wait for data ready
cond_v.wait()
with self._cv_for_cv_pool:
resp = copy.deepcopy(self._fetch_buffer)
_LOGGER.debug("resp thread get resp data[{}]".format(data_id))
ready_data = self._fetch_buffer[data_id]
self._cv_pool.pop(data_id)
return resp
self._fetch_buffer.pop(data_id)
_LOGGER.debug("resp thread get resp data[{}]".format(data_id))
return ready_data
def _pack_channeldata(self, rpc_request, data_id):
dictdata = None
......@@ -204,14 +197,14 @@ class DAGExecutor(object):
def call(self, rpc_request):
data_id, cond_v = self._get_next_data_id()
_LOGGER.debug("generate id: {}".format(data_id))
_LOGGER.debug("generate Request id: {}".format(data_id))
if not self._is_thread_op:
self._profiler.record("call_{}#DAG-{}_0".format(data_id, data_id))
else:
self._profiler.record("call_{}#DAG_0".format(data_id))
_LOGGER.debug("try parse RPC package to channeldata[{}]".format(
_LOGGER.debug("try parse RPC request to channeldata[{}]".format(
data_id))
self._profiler.record("prepack_{}#{}_0".format(data_id, self.name))
req_channeldata = self._pack_channeldata(rpc_request, data_id)
......@@ -232,26 +225,24 @@ class DAGExecutor(object):
error_info="dag closed.",
data_id=data_id))
_LOGGER.debug("wait for Graph engine for data[{}]...".format(
data_id))
_LOGGER.debug("wait Graph engine for data[{}]...".format(data_id))
resp_channeldata = self._get_channeldata_from_fetch_buffer(data_id,
cond_v)
if resp_channeldata.ecode == ChannelDataEcode.OK.value:
_LOGGER.debug("Graph engine predict data[{}] succ".format(
data_id))
_LOGGER.debug("request[{}] succ predict".format(data_id))
break
else:
_LOGGER.warn("Graph engine predict data[{}] failed: {}"
_LOGGER.warning("request[{}] predict failed: {}"
.format(data_id, resp_channeldata.error_info))
if resp_channeldata.ecode != ChannelDataEcode.TIMEOUT.value:
break
if i + 1 < self._retry:
_LOGGER.warn("retry({}/{}) data[{}]".format(i + 1, self._retry,
data_id))
_LOGGER.warning("retry({}/{}) data[{}]".format(
i + 1, self._retry, data_id))
_LOGGER.debug("unpack channeldata[{}] into RPC resp package".format(
_LOGGER.debug("unpack channeldata[{}] into RPC response".format(
data_id))
self._profiler.record("postpack_{}#{}_0".format(data_id, self.name))
rpc_resp = self._pack_for_rpc_resp(resp_channeldata)
......@@ -282,14 +273,13 @@ class DAGExecutor(object):
class DAG(object):
def __init__(self, request_name, response_op, use_profile, is_thread_op,
client_type, channel_size, show_info, build_dag_each_worker):
client_type, channel_size, build_dag_each_worker):
self._request_name = request_name
self._response_op = response_op
self._use_profile = use_profile
self._is_thread_op = is_thread_op
self._channel_size = channel_size
self._client_type = client_type
self._show_info = show_info
self._build_dag_each_worker = build_dag_each_worker
if not self._is_thread_op:
self._manager = multiprocessing.Manager()
......@@ -313,8 +303,9 @@ class DAG(object):
used_ops.add(pred_op)
# check the name of op is globally unique
if pred_op.name in unique_names:
raise Exception("the name of Op must be unique: {}".
_LOGGER.critical("the name of Op must be unique: {}".
format(pred_op.name))
os._exit(-1)
unique_names.add(pred_op.name)
return used_ops, succ_ops_of_use_op
......@@ -346,7 +337,8 @@ class DAG(object):
if len(op.get_input_ops()) == 0:
zero_indegree_num += 1
if zero_indegree_num != 1:
raise Exception("DAG contains multiple input Ops")
_LOGGER.critical("DAG contains multiple RequestOps")
os._exit(-1)
last_op = response_op.get_input_ops()[0]
ques[que_idx].put(last_op)
......@@ -370,24 +362,27 @@ class DAG(object):
break
que_idx = (que_idx + 1) % 2
if sorted_op_num < len(used_ops):
raise Exception("not legal DAG")
_LOGGER.critical("not legal DAG")
os._exit(-1)
return dag_views, last_op
def _build_dag(self, response_op):
if response_op is None:
raise Exception("response_op has not been set.")
_LOGGER.critical("ResponseOp has not been set.")
os._exit(-1)
used_ops, out_degree_ops = self.get_use_ops(response_op)
if self._show_info:
if not self._build_dag_each_worker:
_LOGGER.info("================= USED OP =================")
for op in used_ops:
if op.name != self._request_name:
_LOGGER.info(op.name)
_LOGGER.info("-------------------------------------------")
if len(used_ops) <= 1:
raise Exception(
_LOGGER.critical(
"Besides RequestOp and ResponseOp, there should be at least one Op in DAG."
)
os._exit(-1)
if self._build_dag_each_worker:
_LOGGER.info("Because `build_dag_each_worker` mode is used, "
"Auto-batching is set to the default config: "
......@@ -398,15 +393,15 @@ class DAG(object):
dag_views, last_op = self._topo_sort(used_ops, response_op,
out_degree_ops)
dag_views = list(reversed(dag_views))
if self._show_info:
_LOGGER.info("================== DAG ====================")
if not self._build_dag_each_worker:
_LOGGER.debug("================== DAG ====================")
for idx, view in enumerate(dag_views):
_LOGGER.info("(VIEW {})".format(idx))
_LOGGER.debug("(VIEW {})".format(idx))
for op in view:
_LOGGER.info(" [{}]".format(op.name))
_LOGGER.debug(" [{}]".format(op.name))
for out_op in out_degree_ops[op.name]:
_LOGGER.info(" - {}".format(out_op.name))
_LOGGER.info("-------------------------------------------")
_LOGGER.debug(" - {}".format(out_op.name))
_LOGGER.debug("-------------------------------------------")
# create channels and virtual ops
virtual_op_name_gen = NameGenerator("vir")
......@@ -493,7 +488,7 @@ class DAG(object):
actual_ops.append(op)
for c in channels:
_LOGGER.debug("Channel({}):\n -producers: {}\n -consumers: {}"
_LOGGER.debug("Channel({}):\n\t-producers: {}\n\t-consumers: {}"
.format(c.name, c.get_producers(), c.get_consumers()))
return (actual_ops, channels, input_channel, output_channel, pack_func,
......
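
With the default/validation logic moved into `ServerYamlConfChecker` (see pipeline_server.py below), `DAGExecutor.__init__` now expects a fully filled-in conf dict. A sketch of the keys it reads; the values shown are the documented defaults, and `response_op` is assumed to have been built elsewhere:

```python
dag_conf = {
    "retry": 1,                      # predict retries per request (on timeout)
    "client_type": "brpc",           # "brpc" or "grpc"
    "use_profile": False,
    "channel_size": 0,               # maxsize passed to the channels
    "is_thread_op": True,            # thread Ops vs. process Ops
    "build_dag_each_worker": False,  # injected by PipelineServer.prepare_server
}
# executor = DAGExecutor(response_op, dag_conf)
# executor.start()
```
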
......@@ -60,7 +60,10 @@ class Op(object):
self._client_config = client_config
self._fetch_names = fetch_list
self._timeout = timeout
if timeout > 0:
self._timeout = timeout / 1000.0
else:
self._timeout = -1
self._retry = max(1, retry)
self._input = None
self._outputs = []
......@@ -69,13 +72,32 @@ class Op(object):
self._auto_batching_timeout = auto_batching_timeout
if self._auto_batching_timeout is not None:
if self._auto_batching_timeout <= 0 or self._batch_size == 1:
_LOGGER.warning(
"Because auto_batching_timeout <= 0 or batch_size == 1,"
" set auto_batching_timeout to None.")
self._auto_batching_timeout = None
else:
self._auto_batching_timeout = self._auto_batching_timeout / 1000.0
if not isinstance(self, RequestOp) and not isinstance(self, ResponseOp):
_LOGGER.info(
self._log("\n\tinput_ops: {},"
"\n\tserver_endpoints: {}"
"\n\tfetch_list: {}"
"\n\tclient_config: {}"
"\n\tconcurrency: {},"
"\n\ttimeout(s): {},"
"\n\tretry: {},"
"\n\tbatch_size: {},"
"\n\tauto_batching_timeout(s): {}".format(
", ".join([op.name for op in input_ops
]), self._server_endpoints,
self._fetch_names, self._client_config,
self.concurrency, self._timeout, self._retry,
self._batch_size, self._auto_batching_timeout)))
self._server_use_profile = False
# only for multithread
# only for thread op
self._for_init_op_lock = threading.Lock()
self._for_close_op_lock = threading.Lock()
self._succ_init_op = False
......@@ -83,11 +105,11 @@ class Op(object):
def use_default_auto_batching_config(self):
if self._batch_size != 1:
_LOGGER.warn("Op({}) reset batch_size=1 (original: {})"
_LOGGER.warning("Op({}) reset batch_size=1 (original: {})"
.format(self.name, self._batch_size))
self._batch_size = 1
if self._auto_batching_timeout != None:
_LOGGER.warn(
_LOGGER.warning(
"Op({}) reset auto_batching_timeout=None (original: {})"
.format(self.name, self._auto_batching_timeout))
self._auto_batching_timeout = None
......@@ -100,12 +122,7 @@ class Op(object):
if self.with_serving == False:
_LOGGER.info("Op({}) no client".format(self.name))
return None
_LOGGER.info("Op({}) service endpoints: {}".format(self.name,
server_endpoints))
_LOGGER.debug("Op({}) fetch_names: {}".format(self.name, fetch_names))
if client_type == 'brpc':
_LOGGER.debug("Op({}) client_config: {}".format(self.name,
client_config))
client = Client()
client.load_client_config(client_config)
elif client_type == 'grpc':
......@@ -125,16 +142,18 @@ class Op(object):
self._input_ops = []
for op in ops:
if not isinstance(op, Op):
raise TypeError(
self._log('input op must be Op type, not {}'.format(
type(op))))
_LOGGER.critical(
self._log("input op must be Op type, not {}"
.format(type(op))))
os._exit(-1)
self._input_ops.append(op)
def add_input_channel(self, channel):
if not isinstance(channel, (ThreadChannel, ProcessChannel)):
raise TypeError(
self._log('input channel must be Channel type, not {}'.format(
type(channel))))
_LOGGER.critical(
self._log("input channel must be Channel type, not {}"
.format(type(channel))))
os._exit(-1)
channel.add_consumer(self.name)
self._input = channel
......@@ -146,9 +165,10 @@ class Op(object):
def add_output_channel(self, channel):
if not isinstance(channel, (ThreadChannel, ProcessChannel)):
raise TypeError(
self._log('output channel must be Channel type, not {}'.format(
type(channel))))
_LOGGER.critical(
self._log("output channel must be Channel type, not {}"
.format(type(channel))))
os._exit(-1)
channel.add_producer(self.name)
self._outputs.append(channel)
......@@ -161,9 +181,11 @@ class Op(object):
def preprocess(self, input_dicts):
# multiple previous Op
if len(input_dicts) != 1:
raise NotImplementedError(
'this Op has multiple previous inputs. Please override this func.'
)
_LOGGER.critical(
self._log(
"this Op has multiple previous inputs. Please override this func."
))
os._exit(-1)
(_, input_dict), = input_dicts.items()
return input_dict
......@@ -171,8 +193,10 @@ class Op(object):
def process(self, feed_batch):
err, err_info = ChannelData.check_batch_npdata(feed_batch)
if err != 0:
raise NotImplementedError(
"{} Please override preprocess func.".format(err_info))
_LOGGER.critical(
self._log("{}, Please override preprocess func.".format(
err_info)))
os._exit(-1)
call_result = self.client.predict(
feed=feed_batch, fetch=self._fetch_names)
if isinstance(self.client, MultiLangClient):
......@@ -258,26 +282,18 @@ class Op(object):
preped_data, error_channeldata = None, None
try:
preped_data = self.preprocess(parsed_data)
except NotImplementedError as e:
# preprocess function not implemented
error_info = log_func("preprocess data[{}] failed: {}".format(
data_id, e))
error_channeldata = ChannelData(
ecode=ChannelDataEcode.NOT_IMPLEMENTED.value,
error_info=error_info,
data_id=data_id)
except TypeError as e:
# Error type in channeldata.datatype
error_info = log_func("preprocess data[{}] failed: {}".format(
data_id, e))
error_info = log_func("preprocess data[{}] failed: {}"
.format(data_id, e))
_LOGGER.error(error_info)
error_channeldata = ChannelData(
ecode=ChannelDataEcode.TYPE_ERROR.value,
error_info=error_info,
data_id=data_id)
except Exception as e:
error_info = log_func("preprocess data[{}] failed: {}".format(
data_id, e))
error_info = log_func("preprocess data[{}] failed: {}"
.format(data_id, e))
_LOGGER.error(error_info)
error_channeldata = ChannelData(
ecode=ChannelDataEcode.UNKNOW.value,
......@@ -317,7 +333,7 @@ class Op(object):
error_info = log_func(e)
_LOGGER.error(error_info)
else:
_LOGGER.warn(
_LOGGER.warning(
log_func("PaddleService timeout, retry({}/{})"
.format(i + 1, self._retry)))
except Exception as e:
......@@ -376,7 +392,8 @@ class Op(object):
continue
else:
if not isinstance(postped_data, dict):
error_info = log_func("output of postprocess funticon must be " \
error_info = log_func(
"output of postprocess funticon must be "
"dict type, but get {}".format(type(postped_data)))
_LOGGER.error(error_info)
err_channeldata = ChannelData(
......@@ -471,7 +488,7 @@ class Op(object):
profiler = self._initialize(is_thread_op, client_type,
concurrency_idx)
except Exception as e:
_LOGGER.error(log("init op failed: {}".format(e)))
_LOGGER.critical(log("init op failed: {}".format(e)))
os._exit(-1)
_LOGGER.info(log("succ init"))
......@@ -629,7 +646,7 @@ class RequestOp(Op):
try:
self.init_op()
except Exception as e:
_LOGGER.error("Op(Request) init op failed: {}".format(e))
_LOGGER.critical("Op(Request) init op failed: {}".format(e))
os._exit(-1)
def unpack_request_package(self, request):
......@@ -653,7 +670,7 @@ class ResponseOp(Op):
try:
self.init_op()
except Exception as e:
_LOGGER.error("Op(ResponseOp) init op failed: {}".format(e))
_LOGGER.critical("Op(ResponseOp) init op failed: {}".format(e))
os._exit(-1)
def pack_response_package(self, channeldata):
......@@ -710,9 +727,10 @@ class VirtualOp(Op):
def add_output_channel(self, channel):
if not isinstance(channel, (ThreadChannel, ProcessChannel)):
raise TypeError(
self._log('output channel must be Channel type, not {}'.format(
type(channel))))
_LOGGER.critical(
self._log("output channel must be Channel type, not {}"
.format(type(channel))))
os._exit(-1)
for op in self._virtual_pred_ops:
for op_name in self._actual_pred_op_names(op):
channel.add_producer(op_name)
......@@ -730,17 +748,27 @@ class VirtualOp(Op):
log = get_log_func(op_info_prefix)
tid = threading.current_thread().ident
batch_generator = self._auto_batching_generator(
input_channel=input_channel,
op_name=self.name,
batch_size=1,
timeout=None,
log_func=log)
while True:
try:
channeldata_dict = input_channel.front(self.name)
channeldata_dict_batch = next(batch_generator)
except ChannelStopError:
_LOGGER.debug(log("Channel stop."))
_LOGGER.debug(log("channel stop."))
self._finalize(is_thread_op)
break
try:
for channeldata_dict in channeldata_dict_batch:
for name, data in channeldata_dict.items():
self._push_to_output_channels(
data, channels=output_channels, name=name)
except ChannelStopError:
_LOGGER.debug(log("Channel stop."))
self._finalize(is_thread_op)
break
......@@ -15,6 +15,7 @@
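
`Op.__init__` now normalizes both timeouts up front: user-supplied values are milliseconds, internal values are seconds, and non-positive values (or `batch_size == 1` for auto-batching) disable the feature. A small sketch of that conversion, pulled out into a hypothetical helper for illustration:

```python
def normalize_op_timeouts(timeout_ms, auto_batching_timeout_ms, batch_size):
    # RPC timeout: milliseconds in the config, seconds internally, -1 disables it.
    timeout = timeout_ms / 1000.0 if timeout_ms > 0 else -1
    # Auto-batching timeout is meaningless when non-positive or batch_size is 1.
    abt = auto_batching_timeout_ms
    if abt is not None:
        if abt <= 0 or batch_size == 1:
            abt = None
        else:
            abt = abt / 1000.0
    return timeout, abt
```
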
from concurrent import futures
import grpc
import logging
import json
import socket
import contextlib
from contextlib import closing
......@@ -29,11 +30,10 @@ _LOGGER = logging.getLogger()
class PipelineServicer(pipeline_service_pb2_grpc.PipelineServiceServicer):
def __init__(self, response_op, dag_config, show_info):
def __init__(self, response_op, dag_conf):
super(PipelineServicer, self).__init__()
# init dag executor
self._dag_executor = DAGExecutor(
response_op, dag_config, show_info=show_info)
self._dag_executor = DAGExecutor(response_op, dag_conf)
self._dag_executor.start()
_LOGGER.info("[PipelineServicer] succ init")
......@@ -79,36 +79,25 @@ class PipelineServer(object):
return result != 0
def prepare_server(self, yml_file):
with open(yml_file) as f:
yml_config = yaml.load(f.read())
default_config = {
"port": 9292,
"worker_num": 1,
"build_dag_each_worker": False,
}
conf = ServerYamlConfChecker.load_server_yaml_conf(yml_file)
for key, val in default_config.items():
if yml_config.get(key) is None:
_LOGGER.warning("[CONF] {} not set, use default: {}"
.format(key, val))
yml_config[key] = val
self._port = yml_config["port"]
self._port = conf["port"]
if not self._port_is_available(self._port):
raise SystemExit("Prot {} is already used".format(self._port))
self._worker_num = yml_config["worker_num"]
self._build_dag_each_worker = yml_config["build_dag_each_worker"]
self._worker_num = conf["worker_num"]
self._build_dag_each_worker = conf["build_dag_each_worker"]
_LOGGER.info("============= PIPELINE SERVER =============")
for key in default_config.keys():
_LOGGER.info("{}: {}".format(key, yml_config[key]))
_LOGGER.info("\n{}".format(
json.dumps(
conf, indent=4, separators=(',', ':'))))
if self._build_dag_each_worker is True:
_LOGGER.info(
"(Make sure that install grpcio whl with --no-binary flag)")
_LOGGER.info("-------------------------------------------")
self._dag_config = yml_config.get("dag", {})
self._dag_config["build_dag_each_worker"] = self._build_dag_each_worker
self._dag_conf = conf["dag"]
self._dag_conf["build_dag_each_worker"] = self._build_dag_each_worker
def run_server(self):
if self._build_dag_each_worker:
......@@ -119,8 +108,7 @@ class PipelineServer(object):
show_info = (i == 0)
worker = multiprocessing.Process(
target=self._run_server_func,
args=(bind_address, self._response_op,
self._dag_config))
args=(bind_address, self._response_op, self._dag_conf))
worker.start()
workers.append(worker)
for worker in workers:
......@@ -129,19 +117,140 @@ class PipelineServer(object):
server = grpc.server(
futures.ThreadPoolExecutor(max_workers=self._worker_num))
pipeline_service_pb2_grpc.add_PipelineServiceServicer_to_server(
PipelineServicer(self._response_op, self._dag_config, True),
server)
PipelineServicer(self._response_op, self._dag_conf), server)
server.add_insecure_port('[::]:{}'.format(self._port))
server.start()
server.wait_for_termination()
def _run_server_func(self, bind_address, response_op, dag_config):
def _run_server_func(self, bind_address, response_op, dag_conf):
options = (('grpc.so_reuseport', 1), )
server = grpc.server(
futures.ThreadPoolExecutor(
max_workers=1, ), options=options)
pipeline_service_pb2_grpc.add_PipelineServiceServicer_to_server(
PipelineServicer(response_op, dag_config, False), server)
PipelineServicer(response_op, dag_conf), server)
server.add_insecure_port(bind_address)
server.start()
server.wait_for_termination()
class ServerYamlConfChecker(object):
def __init__(self):
pass
@staticmethod
def load_server_yaml_conf(yml_file):
with open(yml_file) as f:
conf = yaml.load(f.read())
ServerYamlConfChecker.check_server_conf(conf)
ServerYamlConfChecker.check_dag_conf(conf["dag"])
return conf
@staticmethod
def check_server_conf(conf):
default_conf = {
"port": 9292,
"worker_num": 1,
"build_dag_each_worker": False,
"dag": {},
}
ServerYamlConfChecker.fill_with_default_conf(conf, default_conf)
conf_type = {
"port": int,
"worker_num": int,
"build_dag_each_worker": bool,
}
ServerYamlConfChecker.check_conf_type(conf, conf_type)
conf_qualification = {
"port": [(">=", 1024), ("<=", 65535)],
"worker_num": (">=", 1),
}
ServerYamlConfChecker.check_conf_qualification(conf, conf_qualification)
@staticmethod
def check_dag_conf(conf):
default_conf = {
"retry": 1,
"client_type": "brpc",
"use_profile": False,
"channel_size": 0,
"is_thread_op": True
}
ServerYamlConfChecker.fill_with_default_conf(conf, default_conf)
conf_type = {
"retry": int,
"client_type": str,
"use_profile": bool,
"channel_size": int,
"is_thread_op": bool,
}
ServerYamlConfChecker.check_conf_type(conf, conf_type)
conf_qualification = {
"retry": (">=", 1),
"client_type": ("in", ["brpc", "grpc"]),
"channel_size": (">=", 0),
}
ServerYamlConfChecker.check_conf_qualification(conf, conf_qualification)
@staticmethod
def fill_with_default_conf(conf, default_conf):
for key, val in default_conf.items():
if conf.get(key) is None:
_LOGGER.warning("[CONF] {} not set, use default: {}"
.format(key, val))
conf[key] = val
@staticmethod
def check_conf_type(conf, conf_type):
for key, val in conf_type.items():
if not isinstance(conf[key], val):
raise SystemExit("[CONF] {} must be {} type, but get {}."
.format(key, val, type(conf[key])))
@staticmethod
def check_conf_qualification(conf, conf_qualification):
for key, qualification in conf_qualification.items():
if not isinstance(qualification, list):
qualification = [qualification]
if not ServerYamlConfChecker.qualification_check(conf[key],
qualification):
raise SystemExit("[CONF] {} must be {}, but get {}."
.format(key, ", ".join([
"{} {}"
.format(q[0], q[1]) for q in qualification
]), conf[key]))
@staticmethod
def qualification_check(value, qualifications):
if not isinstance(qualifications, list):
qualifications = [qualifications]
ok = True
for q in qualifications:
operator, limit = q
if operator == "<":
ok = value < limit
elif operator == "==":
ok = value == limit
elif operator == ">":
ok = value > limit
elif operator == "<=":
ok = value <= limit
elif operator == ">=":
ok = value >= limit
elif operator == "in":
ok = value in limit
else:
raise SystemExit("unknow operator: {}".format(operator))
if ok == False:
break
return ok
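
For reference, a server YAML accepted by the new checker could look like the hypothetical example below; any omitted key is filled with the defaults from `check_server_conf` / `check_dag_conf` and then type- and range-checked:

```python
import yaml

example_yaml = """
port: 18080
worker_num: 4
build_dag_each_worker: false
dag:
    is_thread_op: false
    retry: 2
    client_type: brpc
"""
conf = yaml.load(example_yaml)  # same call style as load_server_yaml_conf
# check_server_conf / check_dag_conf would fill in use_profile, channel_size,
# etc., and reject e.g. a port outside [1024, 65535] or retry < 1.
```
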
......@@ -29,6 +29,8 @@ _LOGGER = logging.getLogger()
class UnsafeTimeProfiler(object):
""" thread unsafe profiler """
def __init__(self):
self.pid = os.getpid()
self.print_head = 'PROFILE\tpid:{}\t'.format(self.pid)
......