Commit 27ecb875 authored by: B barriery

update code

Parent 89f32061
@@ -103,7 +103,7 @@ The meaning of each parameter is as follows:
 | def process(self, feed_dict) | The RPC prediction process is based on the Paddle Serving Client, and the processed data will be used as the input of the **postprocess** function. |
 | def postprocess(self, input_dicts, fetch_dict) | After processing the prediction results, the processed data will be put into the subsequent Channel to be obtained by the subsequent OP. |
 | def init_op(self) | Used to load resources (such as word dictionary). |
-| self.concurrency_idx | Concurrency index of current thread / process (different kinds of OP are calculated separately). |
+| self.concurrency_idx | Concurrency index of the current process (not thread); different kinds of OP are counted separately. |

 In a running cycle, OP will execute three operations: preprocess, process, and postprocess (when the `server_endpoints` parameter is not set, the process operation is not executed). Users can rewrite these three functions. The default implementation is as follows:
......
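The default bodies sit in the collapsed part of this file. As a rough, hypothetical sketch of their shape (assuming an Op subclass whose `self.client` and `self._fetch_names` were set up by `init_client`, as in the operator code later in this commit; the real defaults may differ in detail):

```python
from paddle_serving_server.pipeline import Op

class ExampleOp(Op):
    # Sketch of the default hook behaviour; override any of the three as needed.

    def preprocess(self, input_dicts):
        # Default: exactly one upstream OP is expected; its dict is passed through.
        if len(input_dicts) != 1:
            raise NotImplementedError(
                "This OP has multiple previous OPs, please override preprocess")
        (_, input_dict), = input_dicts.items()
        return input_dict

    def process(self, feed_dict):
        # Default: RPC prediction through the Paddle Serving client.
        return self.client.predict(feed=feed_dict, fetch=self._fetch_names)

    def postprocess(self, input_dicts, fetch_dict):
        # Default: hand the fetch results to the output Channel unchanged.
        return fetch_dict
```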
@@ -103,7 +103,7 @@ def __init__(name=None,
 | def process(self, feed_dict) | Performs RPC prediction based on the Paddle Serving Client; the processed data will be used as the input of the **postprocess** function. |
 | def postprocess(self, input_dicts, fetch_dict) | Processes the prediction results; the processed data will be put into the subsequent Channel to be fetched by subsequent OPs. |
 | def init_op(self) | Used to load resources (such as dictionaries). |
-| self.concurrency_idx | Concurrency index of the current thread (process); different kinds of OP are counted separately. |
+| self.concurrency_idx | Concurrency index of the current process (not thread); different kinds of OP are counted separately. |

 In a running cycle, an OP executes preprocess, process, and postprocess in order (when the `server_endpoints` parameter is not set, process is not executed). Users can override these three functions. The default implementation is as follows:
......
@@ -13,19 +13,19 @@
 # limitations under the License.
 # pylint: disable=doc-string-missing

-from paddle_serving_server.pipeline import Op, RequestOp, ResponseOp
-from paddle_serving_server.pipeline import PipelineServer
-from paddle_serving_server.pipeline.proto import pipeline_service_pb2
-from paddle_serving_server.pipeline.channel import ChannelDataEcode
-import numpy as np
 import logging
-from paddle_serving_app.reader import IMDBDataset
-
-logging.basicConfig(level=logging.DEBUG)
+logging.basicConfig(format='%(levelname)s:%(asctime)s:%(message)s',
+                    level=logging.INFO)
+from paddle_serving_server_gpu.pipeline import Op, RequestOp, ResponseOp
+from paddle_serving_server_gpu.pipeline import PipelineServer
+from paddle_serving_server_gpu.pipeline.proto import pipeline_service_pb2
+from paddle_serving_server_gpu.pipeline.channel import ChannelDataEcode
+import numpy as np
+from paddle_serving_app.reader import IMDBDataset

 _LOGGER = logging.getLogger()


 class ImdbRequestOp(RequestOp):
     def init_op(self):
         self.imdb_dataset = IMDBDataset()
@@ -51,7 +51,6 @@ class CombineOp(Op):
         data = {"prediction": combined_prediction / 2}
         return data

 class ImdbResponseOp(ResponseOp):
     # Here ImdbResponseOp is consistent with the default ResponseOp implementation
     def pack_response_package(self, channeldata):
......
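For context, the OPs defined in this file are typically wired into a PipelineServer roughly as below; the endpoints, client config paths, fetch names, and other constructor arguments are placeholders, not values taken from this commit:

```python
# Hypothetical wiring of ImdbRequestOp / CombineOp / ImdbResponseOp into a server.
read_op = ImdbRequestOp()
bow_op = Op(name="bow",
            input_ops=[read_op],
            server_endpoints=["127.0.0.1:9393"],
            fetch_list=["prediction"],
            client_config="imdb_bow_client_conf/serving_client_conf.prototxt",
            concurrency=1)
cnn_op = Op(name="cnn",
            input_ops=[read_op],
            server_endpoints=["127.0.0.1:9292"],
            fetch_list=["prediction"],
            client_config="imdb_cnn_client_conf/serving_client_conf.prototxt",
            concurrency=1)
combine_op = CombineOp(name="combine", input_ops=[bow_op, cnn_op], concurrency=1)
response_op = ImdbResponseOp(input_ops=[combine_op])

server = PipelineServer()
server.set_response_op(response_op)
server.prepare_server("config.yml")
server.run_server()
```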
@@ -238,10 +238,6 @@ class ProcessChannel(object):
     def _log(self, info_str):
         return "[{}] {}".format(self.name, info_str)

-    def debug(self):
-        return self._log("p: {}, c: {}".format(self.get_producers(),
-                                                self.get_consumers()))
-
     def add_producer(self, op_name):
         """ not thread safe, and can only be called during initialization. """
         if op_name in self._producers:
@@ -262,8 +258,7 @@ class ProcessChannel(object):
     def push(self, channeldata, op_name=None):
         _LOGGER.debug(
-            self._log("{} try to push data: {}".format(op_name,
-                                                       channeldata.__str__())))
+            self._log("{} try to push data[{}]".format(op_name, channeldata.id)))
         if len(self._producers) == 0:
             raise Exception(
                 self._log(
@@ -279,9 +274,6 @@ class ProcessChannel(object):
                     self._cv.wait()
                     if self._stop.value == 1:
                         raise ChannelStopError()
-            _LOGGER.debug(
-                self._log("{} channel size: {}".format(op_name,
-                                                       self._que.qsize())))
             self._cv.notify_all()
             _LOGGER.debug(self._log("{} notify all".format(op_name)))
         _LOGGER.debug(self._log("{} push data succ!".format(op_name)))
@@ -546,8 +538,7 @@ class ThreadChannel(Queue.Queue):
     def push(self, channeldata, op_name=None):
         _LOGGER.debug(
-            self._log("{} try to push data: {}".format(op_name,
-                                                       channeldata.__str__())))
+            self._log("{} try to push data[{}]".format(op_name, channeldata.id)))
         if len(self._producers) == 0:
             raise Exception(
                 self._log(
@@ -564,7 +555,10 @@ class ThreadChannel(Queue.Queue):
                     if self._stop:
                         raise ChannelStopError()
             self._cv.notify_all()
-            _LOGGER.debug(self._log("{} push data succ!".format(op_name)))
+            _LOGGER.debug(
+                self._log(
+                    "{} succ push data[{}] into internal queue.".format(
+                        op_name, channeldata.id)))
             return True
         elif op_name is None:
             raise Exception(
@@ -575,7 +569,6 @@ class ThreadChannel(Queue.Queue):
         data_id = channeldata.id
         put_data = None
         with self._cv:
-            _LOGGER.debug(self._log("{} get lock".format(op_name)))
             if data_id not in self._input_buf:
                 self._input_buf[data_id] = {
                     name: None
@@ -592,8 +585,9 @@ class ThreadChannel(Queue.Queue):
             if put_data is None:
                 _LOGGER.debug(
-                    self._log("{} push data succ, but not push to queue.".
-                              format(op_name)))
+                    self._log(
+                        "{} succ push data[{}] into input_buffer.".format(
+                            op_name, data_id)))
             else:
                 while self._stop is False:
                     try:
@@ -605,11 +599,17 @@ class ThreadChannel(Queue.Queue):
                         raise ChannelStopError()
                 _LOGGER.debug(
-                    self._log("multi | {} push data succ!".format(op_name)))
+                    self._log(
+                        "{} succ push data[{}] into internal queue.".format(
+                            op_name, data_id)))
             self._cv.notify_all()
         return True

     def front(self, op_name=None, timeout=None):
+        _LOGGER.debug(
+            self._log(
+                "{} try to get data[?]; timeout={}".format(
+                    op_name, timeout)))
         endtime = None
         if timeout is not None:
             if timeout <= 0:
@@ -617,7 +617,6 @@ class ThreadChannel(Queue.Queue):
             else:
                 endtime = _time() + timeout

-        _LOGGER.debug(self._log("{} try to get data".format(op_name)))
         if len(self._consumer_cursors) == 0:
             raise Exception(
                 self._log(
@@ -634,6 +633,9 @@ class ThreadChannel(Queue.Queue):
                     if timeout is not None:
                         remaining = endtime - _time()
                         if remaining <= 0.0:
+                            _LOGGER.debug(
+                                self._log(
+                                    "{} get data[?] timeout".format(op_name)))
                             raise ChannelTimeoutError()
                         self._cv.wait(remaining)
                     else:
@@ -641,8 +643,8 @@ class ThreadChannel(Queue.Queue):
                 if self._stop:
                     raise ChannelStopError()
             _LOGGER.debug(
-                self._log("{} get data succ: {}".format(op_name, resp.__str__(
-                ))))
+                self._log("{} succ get data[{}]".format(
+                    op_name, resp.values()[0].id)))
             return resp
         elif op_name is None:
             raise Exception(
@@ -667,11 +669,18 @@ class ThreadChannel(Queue.Queue):
                 try:
                     channeldata = self.get(timeout=0)
                     self._output_buf.append(channeldata)
+                    _LOGGER.debug(
+                        self._log(
+                            "pop ready item[{}] into output_buffer".format(
+                                channeldata.values()[0].id)))
                     break
                 except Queue.Empty:
                     if timeout is not None:
                         remaining = endtime - _time()
                         if remaining <= 0.0:
+                            _LOGGER.debug(
+                                self._log(
+                                    "{} get data[?] timeout".format(op_name)))
                             raise ChannelTimeoutError()
                         self._cv.wait(remaining)
                     else:
@@ -695,6 +704,8 @@ class ThreadChannel(Queue.Queue):
                 self._base_cursor += 1
                 # to avoid cursor overflow
                 if self._base_cursor >= self._reset_max_cursor:
+                    _LOGGER.info(
+                        self._log("reset cursor in Channel"))
                     self._base_cursor -= self._reset_max_cursor
                     for name in self._consumer_cursors:
                         self._consumer_cursors[name] -= self._reset_max_cursor
@@ -704,7 +715,6 @@ class ThreadChannel(Queue.Queue):
                 }
         else:
             resp = copy.deepcopy(self._output_buf[data_idx])
-            _LOGGER.debug(self._log("{} get data: {}".format(op_name, resp)))

         self._consumer_cursors[op_name] += 1
         new_consumer_cursor = self._consumer_cursors[op_name]
@@ -714,7 +724,10 @@ class ThreadChannel(Queue.Queue):

         self._cv.notify_all()
-        _LOGGER.debug(self._log("multi | {} get data succ!".format(op_name)))
+        _LOGGER.debug(
+            self._log(
+                "{} succ get data[{}] from output_buffer".format(
+                    op_name, resp.values()[0].id)))
         return resp

     def stop(self):
......
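The cursor reset logged above keeps each consumer's offset into the output buffer valid while preventing unbounded growth of the absolute counters; a small standalone illustration of that arithmetic (the constant and cursor values here are made up):

```python
# Toy illustration of the Channel cursor reset: shrinking base_cursor and all
# consumer cursors by the same amount leaves every relative offset unchanged.
reset_max_cursor = 10          # the real channel uses a much larger constant
base_cursor = 12
consumer_cursors = {"op_a": 12, "op_b": 15}

offsets_before = {k: v - base_cursor for k, v in consumer_cursors.items()}

if base_cursor >= reset_max_cursor:
    base_cursor -= reset_max_cursor
    for name in consumer_cursors:
        consumer_cursors[name] -= reset_max_cursor

offsets_after = {k: v - base_cursor for k, v in consumer_cursors.items()}
assert offsets_before == offsets_after   # {'op_a': 0, 'op_b': 3}
```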
@@ -36,21 +36,31 @@ _LOGGER = logging.getLogger()

 class DAGExecutor(object):
     def __init__(self, response_op, dag_config, show_info):
-        self._retry = dag_config.get('retry', 1)
-
-        client_type = dag_config.get('client_type', 'brpc')
-        self._server_use_profile = dag_config.get('use_profile', False)
-        channel_size = dag_config.get('channel_size', 0)
-        self._is_thread_op = dag_config.get('is_thread_op', True)
-
-        if show_info and self._server_use_profile:
-            _LOGGER.info("================= PROFILER ================")
-            if self._is_thread_op:
-                _LOGGER.info("op: thread")
-                _LOGGER.info("profile mode: sync")
-            else:
-                _LOGGER.info("op: process")
-                _LOGGER.info("profile mode: asyn")
+        default_conf = {
+            "retry": 1,
+            "client_type": "brpc",
+            "use_profile": False,
+            "channel_size": 0,
+            "is_thread_op": True
+        }
+        for key, val in default_conf.items():
+            if dag_config.get(key) is None:
+                _LOGGER.warning("[CONF] {} not set, use default: {}"
+                                .format(key, val))
+                dag_config[key] = val
+
+        self._retry = dag_config["retry"]
+        client_type = dag_config["client_type"]
+        self._server_use_profile = dag_config["use_profile"]
+        channel_size = dag_config["channel_size"]
+        self._is_thread_op = dag_config["is_thread_op"]
+        build_dag_each_worker = dag_config["build_dag_each_worker"]
+
+        if show_info:
+            _LOGGER.info("=============== DAGExecutor ===============")
+            for key in default_conf.keys():
+                _LOGGER.info("{}: {}".format(key, dag_config[key]))
             _LOGGER.info("-------------------------------------------")

         self.name = "@G"
@@ -59,7 +69,7 @@ class DAGExecutor(object):
         self._dag = DAG(self.name, response_op, self._server_use_profile,
                         self._is_thread_op, client_type, channel_size,
-                        show_info)
+                        show_info, build_dag_each_worker)
         (in_channel, out_channel, pack_rpc_func,
          unpack_rpc_func) = self._dag.build()
         self._dag.start()
@@ -69,9 +79,6 @@ class DAGExecutor(object):
         self._pack_rpc_func = pack_rpc_func
         self._unpack_rpc_func = unpack_rpc_func

-        _LOGGER.debug(self._log(in_channel.debug()))
-        _LOGGER.debug(self._log(out_channel.debug()))
-
         self._id_lock = threading.Lock()
         self._id_counter = 0
         self._reset_max_id = 1000000000000000000
@@ -87,10 +94,12 @@ class DAGExecutor(object):
         self._recive_func = threading.Thread(
             target=DAGExecutor._recive_out_channel_func, args=(self, ))
         self._recive_func.start()
+        _LOGGER.debug("[DAG Executor] start recive thread")

     def stop(self):
         self._dag.stop()
         self._dag.join()
+        _LOGGER.info("[DAG Executor] succ stop")

     def _get_next_data_id(self):
         with self._id_lock:
@@ -102,7 +111,7 @@ class DAGExecutor(object):
     def _set_in_channel(self, in_channel):
         if not isinstance(in_channel, (ThreadChannel, ProcessChannel)):
             raise TypeError(
-                self._log('in_channel must be Channel type, but get {}'.format(
-                    type(in_channel))))
+                "in_channel must be Channel type, but get {}".format(
+                    type(in_channel)))
         in_channel.add_producer(self.name)
         self._in_channel = in_channel
@@ -110,7 +119,7 @@ class DAGExecutor(object):
     def _set_out_channel(self, out_channel):
         if not isinstance(out_channel, (ThreadChannel, ProcessChannel)):
             raise TypeError(
-                self._log('out_channel must be Channel type, but get {}'.format(
-                    type(out_channel))))
+                "out_channel must be Channel type, but get {}".format(
+                    type(out_channel)))
         out_channel.add_consumer(self.name)
         self._out_channel = out_channel
@@ -121,7 +130,7 @@ class DAGExecutor(object):
             try:
                 channeldata_dict = self._out_channel.front(self.name)
             except ChannelStopError:
-                _LOGGER.debug(self._log("stop."))
+                _LOGGER.debug("[DAG Executor] channel stop.")
                 with self._cv_for_cv_pool:
                     for data_id, cv in self._cv_pool.items():
                         closed_errror_data = ChannelData(
@@ -134,16 +143,16 @@ class DAGExecutor(object):
                 break

             if len(channeldata_dict) != 1:
-                _LOGGER.error("out_channel cannot have multiple input ops")
+                _LOGGER.error("[DAG Executor] out_channel cannot have multiple input ops")
                 os._exit(-1)
             (_, channeldata), = channeldata_dict.items()
             if not isinstance(channeldata, ChannelData):
-                raise TypeError(
-                    self._log('data must be ChannelData type, but get {}'.
-                              format(type(channeldata))))
+                _LOGGER.error('[DAG Executor] data must be ChannelData type, but get {}'
+                              .format(type(channeldata)))
+                os._exit(-1)

             data_id = channeldata.id
-            _LOGGER.debug("recive thread fetch data: {}".format(data_id))
+            _LOGGER.debug("recive thread fetch data[{}]".format(data_id))
             with self._cv_for_cv_pool:
                 cv = self._cv_pool[data_id]
                 with cv:
@@ -157,18 +166,19 @@ class DAGExecutor(object):
             self._cv_pool[data_id] = cv
         with cv:
             cv.wait()
-        _LOGGER.debug("resp func get lock (data_id: {})".format(data_id))
-        resp = copy.deepcopy(self._fetch_buffer)
         with self._cv_for_cv_pool:
+            resp = copy.deepcopy(self._fetch_buffer)
+            _LOGGER.debug("resp thread get resp data[{}]".format(data_id))
             self._cv_pool.pop(data_id)
         return resp

     def _pack_channeldata(self, rpc_request, data_id):
-        _LOGGER.debug(self._log('start inferce'))
         dictdata = None
         try:
             dictdata = self._unpack_rpc_func(rpc_request)
         except Exception as e:
+            _LOGGER.error("parse RPC package to data[{}] Error: {}"
+                          .format(data_id, e))
             return ChannelData(
                 ecode=ChannelDataEcode.RPC_PACKAGE_ERROR.value,
                 error_info="rpc package error: {}".format(e),
@@ -181,50 +191,58 @@ class DAGExecutor(object):
                 if key == self._client_profile_key:
                     profile_value = rpc_request.value[idx]
                     break
+        client_need_profile = (profile_value == self._client_profile_value)
+        _LOGGER.debug("request[{}] need profile: {}".format(data_id, client_need_profile))
         return ChannelData(
             datatype=ChannelDataType.DICT.value,
             dictdata=dictdata,
             data_id=data_id,
-            client_need_profile=(
-                profile_value == self._client_profile_value))
+            client_need_profile=client_need_profile)

     def call(self, rpc_request):
         data_id = self._get_next_data_id()
+        _LOGGER.debug("generate id: {}".format(data_id))

         if not self._is_thread_op:
             self._profiler.record("call_{}#DAG-{}_0".format(data_id, data_id))
         else:
             self._profiler.record("call_{}#DAG_0".format(data_id))

+        _LOGGER.debug("try parse RPC package to channeldata[{}]".format(data_id))
         self._profiler.record("prepack_{}#{}_0".format(data_id, self.name))
         req_channeldata = self._pack_channeldata(rpc_request, data_id)
         self._profiler.record("prepack_{}#{}_1".format(data_id, self.name))

         resp_channeldata = None
         for i in range(self._retry):
-            _LOGGER.debug(self._log('push data'))
-            #self._profiler.record("push_{}#{}_0".format(data_id, self.name))
+            _LOGGER.debug("push data[{}] into Graph engine".format(data_id))
             try:
                 self._in_channel.push(req_channeldata, self.name)
             except ChannelStopError:
-                _LOGGER.debug(self._log("stop."))
+                _LOGGER.debug("[DAG Executor] channel stop.")
                 return self._pack_for_rpc_resp(
                     ChannelData(
                         ecode=ChannelDataEcode.CLOSED_ERROR.value,
                         error_info="dag closed.",
                         data_id=data_id))
-            #self._profiler.record("push_{}#{}_1".format(data_id, self.name))

-            _LOGGER.debug(self._log('wait for infer'))
-            #self._profiler.record("fetch_{}#{}_0".format(data_id, self.name))
+            _LOGGER.debug("wait for Graph engine for data[{}]...".format(data_id))
             resp_channeldata = self._get_channeldata_from_fetch_buffer(data_id)
-            #self._profiler.record("fetch_{}#{}_1".format(data_id, self.name))

             if resp_channeldata.ecode == ChannelDataEcode.OK.value:
+                _LOGGER.debug("Graph engine predict data[{}] succ".format(data_id))
                 break
+            else:
+                _LOGGER.warn("Graph engine predict data[{}] failed: {}"
+                             .format(data_id, resp_channeldata.error_info))
+                if resp_channeldata.ecode != ChannelDataEcode.TIMEOUT.value:
+                    break
             if i + 1 < self._retry:
-                _LOGGER.warn("retry({}): {}".format(
-                    i + 1, resp_channeldata.error_info))
+                _LOGGER.warn("retry({}/{}) data[{}]".format(
+                    i + 1, self._retry, data_id))

+        _LOGGER.debug("unpack channeldata[{}] into RPC resp package".format(data_id))
         self._profiler.record("postpack_{}#{}_0".format(data_id, self.name))
         rpc_resp = self._pack_for_rpc_resp(resp_channeldata)
         self._profiler.record("postpack_{}#{}_1".format(data_id, self.name))
@@ -232,7 +250,6 @@ class DAGExecutor(object):
             self._profiler.record("call_{}#DAG-{}_1".format(data_id, data_id))
         else:
             self._profiler.record("call_{}#DAG_1".format(data_id))
-        #self._profiler.print_profile()
         profile_str = self._profiler.gen_profile_str()
         if self._server_use_profile:
@@ -250,16 +267,12 @@ class DAGExecutor(object):
         return rpc_resp

     def _pack_for_rpc_resp(self, channeldata):
-        _LOGGER.debug(self._log('get channeldata'))
         return self._pack_rpc_func(channeldata)

-    def _log(self, info_str):
-        return "[{}] {}".format(self.name, info_str)
-

 class DAG(object):
     def __init__(self, request_name, response_op, use_profile, is_thread_op,
-                 client_type, channel_size, show_info):
+                 client_type, channel_size, show_info, build_dag_each_worker):
         self._request_name = request_name
         self._response_op = response_op
         self._use_profile = use_profile
@@ -267,8 +280,10 @@ class DAG(object):
         self._channel_size = channel_size
         self._client_type = client_type
         self._show_info = show_info
+        self._build_dag_each_worker = build_dag_each_worker
         if not self._is_thread_op:
             self._manager = multiprocessing.Manager()
+        _LOGGER.info("[DAG] succ init")

     def get_use_ops(self, response_op):
         unique_names = set()
@@ -301,10 +316,13 @@ class DAG(object):
         else:
             channel = ProcessChannel(
                 self._manager, name=name_gen.next(), maxsize=self._channel_size)
+        _LOGGER.debug("[DAG] gen Channel: {}".format(channel.name))
         return channel

     def _gen_virtual_op(self, name_gen):
-        return VirtualOp(name=name_gen.next())
+        vir_op = VirtualOp(name=name_gen.next())
+        _LOGGER.debug("[DAG] gen VirtualOp: {}".format(vir_op.name))
+        return vir_op

     def _topo_sort(self, used_ops, response_op, out_degree_ops):
         out_degree_num = {
@@ -360,6 +378,11 @@ class DAG(object):
             raise Exception(
                 "Besides RequestOp and ResponseOp, there should be at least one Op in DAG."
             )
+        if self._build_dag_each_worker:
+            _LOGGER.info("Because `build_dag_each_worker` mode is used, "
+                         "Auto-batching is set to the default config.")
+            for op in used_ops:
+                op.use_default_auto_batching_config()

         dag_views, last_op = self._topo_sort(used_ops, response_op,
                                              out_degree_ops)
@@ -414,7 +437,8 @@ class DAG(object):
                     continue
                 channel = self._gen_channel(channel_name_gen)
                 channels.append(channel)
-                _LOGGER.debug("{} => {}".format(channel.name, op.name))
+                _LOGGER.debug("[DAG] Channel({}) => Op({})"
+                              .format(channel.name, op.name))
                 op.add_input_channel(channel)
                 pred_ops = pred_op_of_next_view_op[op.name]
                 if v_idx == 0:
@@ -422,8 +446,8 @@ class DAG(object):
                 else:
                     # if pred_op is virtual op, it will use ancestors as producers to channel
                     for pred_op in pred_ops:
-                        _LOGGER.debug("{} => {}".format(pred_op.name,
-                                                        channel.name))
+                        _LOGGER.debug("[DAG] Op({}) => Channel({})"
+                                      .format(pred_op.name, channel.name))
                         pred_op.add_output_channel(channel)
                 processed_op.add(op.name)
                 # find same input op to combine channel
@@ -439,8 +463,8 @@ class DAG(object):
                             same_flag = False
                             break
                     if same_flag:
-                        _LOGGER.debug("{} => {}".format(channel.name,
-                                                        other_op.name))
+                        _LOGGER.debug("[DAG] Channel({}) => Op({})"
+                                      .format(channel.name, other_op.name))
                         other_op.add_input_channel(channel)
                         processed_op.add(other_op.name)
         output_channel = self._gen_channel(channel_name_gen)
@@ -458,7 +482,8 @@ class DAG(object):
                 actual_ops.append(op)

         for c in channels:
-            _LOGGER.debug(c.debug())
+            _LOGGER.debug("Channel({}):\n -producers: {}\n -consumers: {}"
+                          .format(c.name, c.get_producers(), c.get_consumers()))

         return (actual_ops, channels, input_channel, output_channel, pack_func,
                 unpack_func)
@@ -466,6 +491,7 @@ class DAG(object):
     def build(self):
         (actual_ops, channels, input_channel, output_channel, pack_func,
          unpack_func) = self._build_dag(self._response_op)
+        _LOGGER.info("[DAG] succ build dag")

         self._actual_ops = actual_ops
         self._channels = channels
@@ -486,6 +512,8 @@ class DAG(object):
             else:
                 self._threads_or_proces.extend(
                     op.start_with_process(self._client_type))
+        _LOGGER.info("[DAG] start")
+
         # not join yet
         return self._threads_or_proces
......
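The call path in this file hands each request an id, pushes it into the input Channel, and blocks on a per-request condition variable until the receive thread delivers the result from the output Channel. A minimal, self-contained sketch of that rendezvous pattern (simplified: this toy keeps results keyed by id, whereas the real code uses a shared fetch buffer plus per-id condition variables):

```python
import threading

class ToyRendezvous(object):
    """Toy version of the DAGExecutor hand-off between the caller thread
    (call / _get_channeldata_from_fetch_buffer) and the receive thread."""

    def __init__(self):
        self._lock = threading.Lock()
        self._cv_pool = {}     # data_id -> Condition the caller waits on
        self._results = {}     # data_id -> delivered response

    def wait_for(self, data_id):
        cv = threading.Condition()
        with self._lock:
            self._cv_pool[data_id] = cv
        with cv:
            while data_id not in self._results:
                cv.wait()
        with self._lock:
            self._cv_pool.pop(data_id)
            return self._results.pop(data_id)

    def deliver(self, data_id, result):
        # Called from the receive thread when the out Channel yields a response.
        with self._lock:
            cv = self._cv_pool[data_id]
        with cv:
            self._results[data_id] = result
            cv.notify_all()
```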
@@ -67,8 +67,9 @@ class Op(object):
         self._batch_size = batch_size
         self._auto_batching_timeout = auto_batching_timeout
-        if self._auto_batching_timeout is not None and self._auto_batching_timeout <= 0:
-            self._auto_batching_timeout = None
+        if self._auto_batching_timeout is not None:
+            if self._auto_batching_timeout <= 0 or self._batch_size == 1:
+                self._auto_batching_timeout = None

         self._server_use_profile = False
@@ -78,6 +79,10 @@ class Op(object):
         self._succ_init_op = False
         self._succ_close_op = False

+    def use_default_auto_batching_config(self):
+        self._batch_size = 1
+        self._auto_batching_timeout = None
+
     def use_profiler(self, use_profile):
         self._server_use_profile = use_profile
@@ -89,11 +94,12 @@ class Op(object):
     def init_client(self, client_type, client_config, server_endpoints,
                     fetch_names):
         if self.with_serving == False:
-            _LOGGER.debug("{} no client".format(self.name))
+            _LOGGER.info("Op({}) no client".format(self.name))
             return None
-        _LOGGER.debug("{} client_config: {}".format(self.name, client_config))
-        _LOGGER.debug("{} fetch_names: {}".format(self.name, fetch_names))
+        _LOGGER.info("Op({}) service endpoints: {}".format(self.name, server_endpoints))
+        _LOGGER.debug("Op({}) fetch_names: {}".format(self.name, fetch_names))
         if client_type == 'brpc':
+            _LOGGER.debug("Op({}) client_config: {}".format(self.name, client_config))
             client = Client()
             client.load_client_config(client_config)
         elif client_type == 'grpc':
@@ -163,7 +169,6 @@ class Op(object):
                 "{} Please override preprocess func.".format(err_info))
         call_result = self.client.predict(
             feed=feed_batch, fetch=self._fetch_names)
-        _LOGGER.debug(self._log("get call_result"))
         return call_result

     def postprocess(self, input_dict, fetch_dict):
@@ -237,6 +242,7 @@ class Op(object):
         pass

     def _run_preprocess(self, parsed_data_dict, log_func):
+        _LOGGER.debug(log_func("try to run preprocess"))
         preped_data_dict = {}
         err_channeldata_dict = {}
         for data_id, parsed_data in parsed_data_dict.items():
@@ -245,22 +251,27 @@ class Op(object):
                 preped_data = self.preprocess(parsed_data)
             except NotImplementedError as e:
                 # preprocess function not implemented
-                error_info = log_func(e)
-                _LOGGER.error(error_info)
+                error_info = log_func(
+                    "preprocess data[{}] failed: {}".format(
+                        data_id, e))
                 error_channeldata = ChannelData(
                     ecode=ChannelDataEcode.NOT_IMPLEMENTED.value,
                     error_info=error_info,
                     data_id=data_id)
             except TypeError as e:
                 # Error type in channeldata.datatype
-                error_info = log_func(e)
+                error_info = log_func(
+                    "preprocess data[{}] failed: {}".format(
+                        data_id, e))
                 _LOGGER.error(error_info)
                 error_channeldata = ChannelData(
                     ecode=ChannelDataEcode.TYPE_ERROR.value,
                     error_info=error_info,
                     data_id=data_id)
             except Exception as e:
-                error_info = log_func(e)
+                error_info = log_func(
+                    "preprocess data[{}] failed: {}".format(
+                        data_id, e))
                 _LOGGER.error(error_info)
                 error_channeldata = ChannelData(
                     ecode=ChannelDataEcode.UNKNOW.value,
@@ -270,9 +281,11 @@ class Op(object):
                 err_channeldata_dict[data_id] = error_channeldata
             else:
                 preped_data_dict[data_id] = preped_data
+        _LOGGER.debug(log_func("succ run preprocess"))
         return preped_data_dict, err_channeldata_dict

     def _run_process(self, preped_data_dict, log_func):
+        _LOGGER.debug(log_func("try to run process"))
         midped_data_dict = {}
         err_channeldata_dict = {}
         if self.with_serving:
@@ -285,7 +298,7 @@ class Op(object):
                     midped_batch = self.process(feed_batch)
                 except Exception as e:
                     ecode = ChannelDataEcode.UNKNOW.value
-                    error_info = log_func(e)
+                    error_info = log_func("process batch failed: {}".format(e))
                     _LOGGER.error(error_info)
                 else:
                     for i in range(self._retry):
@@ -299,10 +312,11 @@ class Op(object):
                                 _LOGGER.error(error_info)
                             else:
                                 _LOGGER.warn(
-                                    log_func("timeout, retry({})".format(i + 1)))
+                                    log_func("timeout, retry({}/{})"
+                                             .format(i + 1, self._retry)))
                         except Exception as e:
                             ecode = ChannelDataEcode.UNKNOW.value
-                            error_info = log_func(e)
+                            error_info = log_func("process batch failed: {}".format(e))
                             _LOGGER.error(error_info)
                             break
                         else:
@@ -315,11 +329,13 @@ class Op(object):
                             data_id=data_id)
                 elif midped_batch is None:
                     # op client return None
+                    error_info = log_func(
+                        "predict failed. pls check the server side.")
+                    _LOGGER.error(error_info)
                     for data_id in data_ids:
                         err_channeldata_dict[data_id] = ChannelData(
                             ecode=ChannelDataEcode.CLIENT_ERROR.value,
-                            error_info=log_func(
-                                "predict failed. pls check the server side."),
+                            error_info=error_info,
                             data_id=data_id)
                 else:
                     # transform np format to dict format
@@ -329,9 +345,11 @@ class Op(object):
                     }
         else:
             midped_data_dict = preped_data_dict
+        _LOGGER.debug(log_func("succ run process"))
         return midped_data_dict, err_channeldata_dict
     def _run_postprocess(self, parsed_data_dict, midped_data_dict, log_func):
+        _LOGGER.debug(log_func("try to run postprocess"))
         postped_data_dict = {}
         err_channeldata_dict = {}
         for data_id, midped_data in midped_data_dict.items():
@@ -340,7 +358,8 @@ class Op(object):
                 postped_data = self.postprocess(
                     parsed_data_dict[data_id], midped_data)
             except Exception as e:
-                error_info = log_func(e)
+                error_info = log_func("postprocess data[{}] failed: {}"
                                      .format(data_id, e))
                 _LOGGER.error(error_info)
                 err_channeldata = ChannelData(
                     ecode=ChannelDataEcode.UNKNOW.value,
@@ -374,11 +393,17 @@ class Op(object):
                     dictdata=postped_data,
                     data_id=data_id)
                 postped_data_dict[data_id] = output_data
+        _LOGGER.debug(log_func("succ run postprocess"))
         return postped_data_dict, err_channeldata_dict

-    def _auto_batching_generator(self, input_channel, op_name, batch_size, timeout):
+    def _auto_batching_generator(self, input_channel, op_name,
+                                 batch_size, timeout, log_func):
         while True:
             batch = []
+            _LOGGER.debug(
+                log_func(
+                    "Auto-batching expect size: {}; timeout: {}".format(
+                        batch_size, timeout)))
             while len(batch) == 0:
                 endtime = None
                 if timeout is not None:
@@ -389,15 +414,16 @@ class Op(object):
                         if timeout is not None:
                             remaining = endtime - _time()
                             if remaining <= 0.0:
-                                _LOGGER.info(log("auto-batching timeout"))
+                                _LOGGER.debug(log_func("Auto-batching timeout"))
                                 break
                             channeldata_dict = input_channel.front(op_name, timeout)
                         else:
                             channeldata_dict = input_channel.front(op_name)
                         batch.append(channeldata_dict)
                 except ChannelTimeoutError:
-                    _LOGGER.info(log("auto-batching timeout"))
+                    _LOGGER.debug(log_func("Auto-batching timeout"))
                     break
+            _LOGGER.debug(log_func("Auto-batching actual size: {}".format(len(batch))))
             yield batch
     def _parse_channeldata_batch(self, batch, output_channels):
@@ -429,29 +455,28 @@ class Op(object):
         op_info_prefix = "[{}|{}]".format(self.name, concurrency_idx)
         log = get_log_func(op_info_prefix)
-        preplog = get_log_func(op_info_prefix + "(prep)")
-        midplog = get_log_func(op_info_prefix + "(midp)")
-        postplog = get_log_func(op_info_prefix + "(postp)")
         tid = threading.current_thread().ident

         # init op
         try:
             self._initialize(is_thread_op, client_type)
         except Exception as e:
-            _LOGGER.error(log(e))
+            _LOGGER.error(log("init op failed: {}".format(e)))
             os._exit(-1)
+        _LOGGER.info(log("succ init"))

         batch_generator = self._auto_batching_generator(
             input_channel=input_channel,
             op_name=self.name,
             batch_size=self._batch_size,
-            timeout=self._auto_batching_timeout)
+            timeout=self._auto_batching_timeout,
+            log_func=log)

         while True:
             try:
                 channeldata_dict_batch = next(batch_generator)
             except ChannelStopError:
-                _LOGGER.debug(log("stop."))
+                _LOGGER.debug(log("channel stop."))
                 self._finalize(is_thread_op)
                 break
@@ -461,7 +486,7 @@ class Op(object):
                     = self._parse_channeldata_batch(
                         channeldata_dict_batch, output_channels)
             except ChannelStopError:
-                _LOGGER.debug(log("stop."))
+                _LOGGER.debug(log("channel stop."))
                 self._finalize(is_thread_op)
                 break
             if len(parsed_data_dict) == 0:
@@ -471,7 +496,7 @@ class Op(object):
             # preprecess
             self._profiler_record("prep#{}_0".format(op_info_prefix))
             preped_data_dict, err_channeldata_dict \
-                    = self._run_preprocess(parsed_data_dict, preplog)
+                    = self._run_preprocess(parsed_data_dict, log)
             self._profiler_record("prep#{}_1".format(op_info_prefix))
             try:
                 for data_id, err_channeldata in err_channeldata_dict.items():
@@ -481,7 +506,7 @@ class Op(object):
                         client_need_profile=need_profile_dict[data_id],
                         profile_set=profile_dict[data_id])
             except ChannelStopError:
-                _LOGGER.debug(log("stop."))
+                _LOGGER.debug(log("channel stop."))
                 self._finalize(is_thread_op)
                 break
             if len(parsed_data_dict) == 0:
@@ -490,7 +515,7 @@ class Op(object):
             # process
             self._profiler_record("midp#{}_0".format(op_info_prefix))
             midped_data_dict, err_channeldata_dict \
-                    = self._run_process(preped_data_dict, midplog)
+                    = self._run_process(preped_data_dict, log)
             self._profiler_record("midp#{}_1".format(op_info_prefix))
             try:
                 for data_id, err_channeldata in err_channeldata_dict.items():
@@ -500,7 +525,7 @@ class Op(object):
                         client_need_profile=need_profile_dict[data_id],
                         profile_set=profile_dict[data_id])
             except ChannelStopError:
-                _LOGGER.debug(log("stop."))
+                _LOGGER.debug(log("channel stop."))
                 self._finalize(is_thread_op)
                 break
             if len(midped_data_dict) == 0:
@@ -510,7 +535,7 @@ class Op(object):
             self._profiler_record("postp#{}_0".format(op_info_prefix))
             postped_data_dict, err_channeldata_dict \
                     = self._run_postprocess(
-                            parsed_data_dict, midped_data_dict, postplog)
+                            parsed_data_dict, midped_data_dict, log)
             self._profiler_record("postp#{}_1".format(op_info_prefix))
             try:
                 for data_id, err_channeldata in err_channeldata_dict.items():
@@ -520,7 +545,7 @@ class Op(object):
                         client_need_profile=need_profile_dict[data_id],
                         profile_set=profile_dict[data_id])
             except ChannelStopError:
-                _LOGGER.debug(log("stop."))
+                _LOGGER.debug(log("channel stop."))
                 self._finalize(is_thread_op)
                 break
             if len(postped_data_dict) == 0:
@@ -535,7 +560,7 @@ class Op(object):
                         client_need_profile=need_profile_dict[data_id],
                         profile_set=profile_dict[data_id])
             except ChannelStopError:
-                _LOGGER.debug(log("stop."))
+                _LOGGER.debug(log("channel stop."))
                 self._finalize(is_thread_op)
                 break
......
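The auto-batching generator above keeps pulling from the input Channel until it has `batch_size` items or the timeout expires, then yields whatever it collected (never an empty batch). A self-contained sketch of the same loop over a plain `queue.Queue`, with illustrative names, assuming nothing beyond the standard library:

```python
import queue
import time

def auto_batching(q, batch_size, timeout):
    """Toy counterpart of Op._auto_batching_generator: collect up to
    batch_size items, flush early when the timeout expires, and retry
    rather than yield an empty batch."""
    while True:
        batch = []
        while len(batch) == 0:
            endtime = None if timeout is None else time.time() + timeout
            try:
                while len(batch) < batch_size:
                    if endtime is not None:
                        remaining = endtime - time.time()
                        if remaining <= 0.0:
                            break                      # auto-batching timeout
                        batch.append(q.get(timeout=remaining))
                    else:
                        batch.append(q.get())
            except queue.Empty:
                pass                                   # channel-style timeout
        yield batch
```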
@@ -28,13 +28,14 @@ from .dag import DAGExecutor
 _LOGGER = logging.getLogger()


-class PipelineService(pipeline_service_pb2_grpc.PipelineServiceServicer):
+class PipelineServicer(pipeline_service_pb2_grpc.PipelineServiceServicer):
     def __init__(self, response_op, dag_config, show_info):
-        super(PipelineService, self).__init__()
+        super(PipelineServicer, self).__init__()
         # init dag executor
         self._dag_executor = DAGExecutor(
             response_op, dag_config, show_info=show_info)
         self._dag_executor.start()
+        _LOGGER.info("[PipelineServicer] succ init")

     def inference(self, request, context):
         resp = self._dag_executor.call(request)
@@ -80,26 +81,33 @@ class PipelineServer(object):
     def prepare_server(self, yml_file):
         with open(yml_file) as f:
             yml_config = yaml.load(f.read())
-        self._port = yml_config.get('port')
-        if self._port is None:
-            raise SystemExit("Please set *port* in [{}] yaml file.".format(
-                yml_file))
+        default_config = {
+            "port": 9292,
+            "worker_num": 1,
+            "build_dag_each_worker": False,
+        }
+        for key, val in default_config.items():
+            if yml_config.get(key) is None:
+                _LOGGER.warning("[CONF] {} not set, use default: {}"
+                                .format(key, val))
+                yml_config[key] = val
+
+        self._port = yml_config["port"]
         if not self._port_is_available(self._port):
             raise SystemExit("Prot {} is already used".format(self._port))
-        self._worker_num = yml_config.get('worker_num', 1)
-        self._build_dag_each_worker = yml_config.get('build_dag_each_worker',
-                                                      False)
+        self._worker_num = yml_config["worker_num"]
+        self._build_dag_each_worker = yml_config["build_dag_each_worker"]

         _LOGGER.info("============= PIPELINE SERVER =============")
-        _LOGGER.info("port: {}".format(self._port))
-        _LOGGER.info("worker_num: {}".format(self._worker_num))
-        servicer_info = "build_dag_each_worker: {}".format(
-            self._build_dag_each_worker)
+        for key in default_config.keys():
+            _LOGGER.info("{}: {}".format(key, yml_config[key]))
         if self._build_dag_each_worker is True:
-            servicer_info += " (Make sure that install grpcio whl with --no-binary flag)"
-        _LOGGER.info(servicer_info)
+            _LOGGER.info("(Make sure that install grpcio whl with --no-binary flag)")
         _LOGGER.info("-------------------------------------------")

         self._dag_config = yml_config.get("dag", {})
+        self._dag_config["build_dag_each_worker"] = self._build_dag_each_worker
     def run_server(self):
         if self._build_dag_each_worker:
@@ -120,7 +128,7 @@ class PipelineServer(object):
             server = grpc.server(
                 futures.ThreadPoolExecutor(max_workers=self._worker_num))
             pipeline_service_pb2_grpc.add_PipelineServiceServicer_to_server(
-                PipelineService(self._response_op, self._dag_config, True),
+                PipelineServicer(self._response_op, self._dag_config, True),
                 server)
             server.add_insecure_port('[::]:{}'.format(self._port))
             server.start()
@@ -132,7 +140,7 @@ class PipelineServer(object):
             futures.ThreadPoolExecutor(
                 max_workers=1, ), options=options)
         pipeline_service_pb2_grpc.add_PipelineServiceServicer_to_server(
-            PipelineService(response_op, dag_config, False), server)
+            PipelineServicer(response_op, dag_config, False), server)
         server.add_insecure_port(bind_address)
         server.start()
         server.wait_for_termination()
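For reference, a minimal launch sketch that exercises the renamed PipelineServicer through the public PipelineServer API; `response_op` and the config file name are placeholders, not values from this commit:

```python
# Hypothetical launch script; build response_op from your own DAG of OPs.
from paddle_serving_server.pipeline import PipelineServer

server = PipelineServer()
server.set_response_op(response_op)   # a ResponseOp at the end of the user DAG
server.prepare_server("config.yml")   # reads port / worker_num / build_dag_each_worker / dag
server.run_server()
```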
@@ -22,7 +22,7 @@ elif sys.version_info.major == 3:
     import queue as Queue
 else:
     raise Exception("Error Python version")
-import time
+from time import time as _time
 import threading

 _LOGGER = logging.getLogger()
@@ -42,7 +42,7 @@ class TimeProfiler(object):
     def record(self, name_with_tag):
         if self._enable is False:
             return
-        timestamp = int(round(time.time() * 1000000))
+        timestamp = int(round(_time() * 1000000))
         name_with_tag = name_with_tag.split("_")
         tag = name_with_tag[-1]
         name = '_'.join(name_with_tag[:-1])
......
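The record names used throughout this commit (for example `prepack_{data_id}#{name}_0` / `..._1`) follow a `<name>_<tag>` convention: record() splits on `_`, treats the last token as the tag, and the `0`/`1` pair brackets a timed span measured in microseconds. A tiny illustration of that parsing:

```python
# How a record name decomposes under the split in record() above.
name_with_tag = "prepack_0#@G_1".split("_")
tag = name_with_tag[-1]                 # "1"  -> this record closes the span
name = "_".join(name_with_tag[:-1])     # "prepack_0#@G"
print(name, tag)
```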