提交 27ecb875 编写于 作者: B barriery

update code

上级 89f32061
......@@ -103,7 +103,7 @@ The meaning of each parameter is as follows:
| def process(self, feed_dict) | The RPC prediction process is based on the Paddle Serving Client, and the processed data will be used as the input of the **postprocess** function. |
| def postprocess(self, input_dicts, fetch_dict) | After processing the prediction results, the processed data will be put into the subsequent Channel to be obtained by the subsequent OP. |
| def init_op(self) | Used to load resources (such as word dictionary). |
| self.concurrency_idx | Concurrency index of current thread / process (different kinds of OP are calculated separately). |
| self.concurrency_idx | Concurrency index of current process(not thread) (different kinds of OP are calculated separately). |
In a running cycle, OP will execute three operations: preprocess, process, and postprocess (when the `server_endpoints` parameter is not set, the process operation is not executed). Users can rewrite these three functions. The default implementation is as follows:
......
......@@ -103,7 +103,7 @@ def __init__(name=None,
| def process(self, feed_dict) | 基于 Paddle Serving Client 进行 RPC 预测,处理完的数据将作为 **postprocess** 函数的输入。 |
| def postprocess(self, input_dicts, fetch_dict) | 处理预测结果,处理完的数据将被放入后继 Channel 中,以被后继 OP 获取。 |
| def init_op(self) | 用于加载资源(如字典等)。 |
| self.concurrency_idx | 当前线程(进程)的并发数索引(不同种类的 OP 单独计算)。 |
| self.concurrency_idx | 当前进程(非线程)的并发数索引(不同种类的 OP 单独计算)。 |
OP 在一个运行周期中会依次执行 preprocess,process,postprocess 三个操作(当不设置 `server_endpoints` 参数时,不执行 process 操作),用户可以对这三个函数进行重写,默认实现如下:
......
......@@ -13,19 +13,19 @@
# limitations under the License.
# pylint: disable=doc-string-missing
from paddle_serving_server.pipeline import Op, RequestOp, ResponseOp
from paddle_serving_server.pipeline import PipelineServer
from paddle_serving_server.pipeline.proto import pipeline_service_pb2
from paddle_serving_server.pipeline.channel import ChannelDataEcode
import numpy as np
import logging
from paddle_serving_app.reader import IMDBDataset
logging.basicConfig(format='%(levelname)s:%(asctime)s:%(message)s',
level=logging.INFO)
logging.basicConfig(level=logging.DEBUG)
from paddle_serving_server_gpu.pipeline import Op, RequestOp, ResponseOp
from paddle_serving_server_gpu.pipeline import PipelineServer
from paddle_serving_server_gpu.pipeline.proto import pipeline_service_pb2
from paddle_serving_server_gpu.pipeline.channel import ChannelDataEcode
import numpy as np
from paddle_serving_app.reader import IMDBDataset
_LOGGER = logging.getLogger()
class ImdbRequestOp(RequestOp):
def init_op(self):
self.imdb_dataset = IMDBDataset()
......@@ -51,7 +51,6 @@ class CombineOp(Op):
data = {"prediction": combined_prediction / 2}
return data
class ImdbResponseOp(ResponseOp):
# Here ImdbResponseOp is consistent with the default ResponseOp implementation
def pack_response_package(self, channeldata):
......
......@@ -238,10 +238,6 @@ class ProcessChannel(object):
def _log(self, info_str):
return "[{}] {}".format(self.name, info_str)
def debug(self):
return self._log("p: {}, c: {}".format(self.get_producers(),
self.get_consumers()))
def add_producer(self, op_name):
""" not thread safe, and can only be called during initialization. """
if op_name in self._producers:
......@@ -262,8 +258,7 @@ class ProcessChannel(object):
def push(self, channeldata, op_name=None):
_LOGGER.debug(
self._log("{} try to push data: {}".format(op_name,
channeldata.__str__())))
self._log("{} try to push data[{}]".format(op_name, channeldata.id)))
if len(self._producers) == 0:
raise Exception(
self._log(
......@@ -279,9 +274,6 @@ class ProcessChannel(object):
self._cv.wait()
if self._stop.value == 1:
raise ChannelStopError()
_LOGGER.debug(
self._log("{} channel size: {}".format(op_name,
self._que.qsize())))
self._cv.notify_all()
_LOGGER.debug(self._log("{} notify all".format(op_name)))
_LOGGER.debug(self._log("{} push data succ!".format(op_name)))
......@@ -546,8 +538,7 @@ class ThreadChannel(Queue.Queue):
def push(self, channeldata, op_name=None):
_LOGGER.debug(
self._log("{} try to push data: {}".format(op_name,
channeldata.__str__())))
self._log("{} try to push data[{}]".format(op_name, channeldata.id)))
if len(self._producers) == 0:
raise Exception(
self._log(
......@@ -564,7 +555,10 @@ class ThreadChannel(Queue.Queue):
if self._stop:
raise ChannelStopError()
self._cv.notify_all()
_LOGGER.debug(self._log("{} push data succ!".format(op_name)))
_LOGGER.debug(
self._log(
"{} succ push data[{}] into internal queue.".format(
op_name, channeldata.id)))
return True
elif op_name is None:
raise Exception(
......@@ -575,7 +569,6 @@ class ThreadChannel(Queue.Queue):
data_id = channeldata.id
put_data = None
with self._cv:
_LOGGER.debug(self._log("{} get lock".format(op_name)))
if data_id not in self._input_buf:
self._input_buf[data_id] = {
name: None
......@@ -592,8 +585,9 @@ class ThreadChannel(Queue.Queue):
if put_data is None:
_LOGGER.debug(
self._log("{} push data succ, but not push to queue.".
format(op_name)))
self._log(
"{} succ push data[{}] into input_buffer.".format(
op_name, data_id)))
else:
while self._stop is False:
try:
......@@ -605,11 +599,17 @@ class ThreadChannel(Queue.Queue):
raise ChannelStopError()
_LOGGER.debug(
self._log("multi | {} push data succ!".format(op_name)))
self._log(
"{} succ push data[{}] into internal queue.".format(
op_name, data_id)))
self._cv.notify_all()
return True
def front(self, op_name=None, timeout=None):
_LOGGER.debug(
self._log(
"{} try to get data[?]; timeout={}".format(
op_name, timeout)))
endtime = None
if timeout is not None:
if timeout <= 0:
......@@ -617,7 +617,6 @@ class ThreadChannel(Queue.Queue):
else:
endtime = _time() + timeout
_LOGGER.debug(self._log("{} try to get data".format(op_name)))
if len(self._consumer_cursors) == 0:
raise Exception(
self._log(
......@@ -634,6 +633,9 @@ class ThreadChannel(Queue.Queue):
if timeout is not None:
remaining = endtime - _time()
if remaining <= 0.0:
_LOGGER.debug(
self._log(
"{} get data[?] timeout".format(op_name)))
raise ChannelTimeoutError()
self._cv.wait(remaining)
else:
......@@ -641,8 +643,8 @@ class ThreadChannel(Queue.Queue):
if self._stop:
raise ChannelStopError()
_LOGGER.debug(
self._log("{} get data succ: {}".format(op_name, resp.__str__(
))))
self._log("{} succ get data[{}]".format(
op_name, resp.values()[0].id)))
return resp
elif op_name is None:
raise Exception(
......@@ -667,11 +669,18 @@ class ThreadChannel(Queue.Queue):
try:
channeldata = self.get(timeout=0)
self._output_buf.append(channeldata)
_LOGGER.debug(
self._log(
"pop ready item[{}] into output_buffer".format(
channeldata.values()[0].id)))
break
except Queue.Empty:
if timeout is not None:
remaining = endtime - _time()
if remaining <= 0.0:
_LOGGER.debug(
self._log(
"{} get data[?] timeout".format(op_name)))
raise ChannelTimeoutError()
self._cv.wait(remaining)
else:
......@@ -695,6 +704,8 @@ class ThreadChannel(Queue.Queue):
self._base_cursor += 1
# to avoid cursor overflow
if self._base_cursor >= self._reset_max_cursor:
_LOGGER.info(
self._log("reset cursor in Channel"))
self._base_cursor -= self._reset_max_cursor
for name in self._consumer_cursors:
self._consumer_cursors[name] -= self._reset_max_cursor
......@@ -704,7 +715,6 @@ class ThreadChannel(Queue.Queue):
}
else:
resp = copy.deepcopy(self._output_buf[data_idx])
_LOGGER.debug(self._log("{} get data: {}".format(op_name, resp)))
self._consumer_cursors[op_name] += 1
new_consumer_cursor = self._consumer_cursors[op_name]
......@@ -714,7 +724,10 @@ class ThreadChannel(Queue.Queue):
self._cv.notify_all()
_LOGGER.debug(self._log("multi | {} get data succ!".format(op_name)))
_LOGGER.debug(
self._log(
"{} succ get data[{}] from output_buffer".format(
op_name, resp.values()[0].id)))
return resp
def stop(self):
......
......@@ -36,21 +36,31 @@ _LOGGER = logging.getLogger()
class DAGExecutor(object):
def __init__(self, response_op, dag_config, show_info):
self._retry = dag_config.get('retry', 1)
client_type = dag_config.get('client_type', 'brpc')
self._server_use_profile = dag_config.get('use_profile', False)
channel_size = dag_config.get('channel_size', 0)
self._is_thread_op = dag_config.get('is_thread_op', True)
default_conf = {
"retry": 1,
"client_type": "brpc",
"use_profile": False,
"channel_size": 0,
"is_thread_op": True
}
if show_info and self._server_use_profile:
_LOGGER.info("================= PROFILER ================")
if self._is_thread_op:
_LOGGER.info("op: thread")
_LOGGER.info("profile mode: sync")
else:
_LOGGER.info("op: process")
_LOGGER.info("profile mode: asyn")
for key, val in default_conf.items():
if dag_config.get(key) is None:
_LOGGER.warning("[CONF] {} not set, use default: {}"
.format(key, val))
dag_config[key] = val
self._retry = dag_config["retry"]
client_type = dag_config["client_type"]
self._server_use_profile = dag_config["use_profile"]
channel_size = dag_config["channel_size"]
self._is_thread_op = dag_config["is_thread_op"]
build_dag_each_worker = dag_config["build_dag_each_worker"]
if show_info:
_LOGGER.info("=============== DAGExecutor ===============")
for key in default_conf.keys():
_LOGGER.info("{}: {}".format(key, dag_config[key]))
_LOGGER.info("-------------------------------------------")
self.name = "@G"
......@@ -59,7 +69,7 @@ class DAGExecutor(object):
self._dag = DAG(self.name, response_op, self._server_use_profile,
self._is_thread_op, client_type, channel_size,
show_info)
show_info, build_dag_each_worker)
(in_channel, out_channel, pack_rpc_func,
unpack_rpc_func) = self._dag.build()
self._dag.start()
......@@ -69,9 +79,6 @@ class DAGExecutor(object):
self._pack_rpc_func = pack_rpc_func
self._unpack_rpc_func = unpack_rpc_func
_LOGGER.debug(self._log(in_channel.debug()))
_LOGGER.debug(self._log(out_channel.debug()))
self._id_lock = threading.Lock()
self._id_counter = 0
self._reset_max_id = 1000000000000000000
......@@ -87,10 +94,12 @@ class DAGExecutor(object):
self._recive_func = threading.Thread(
target=DAGExecutor._recive_out_channel_func, args=(self, ))
self._recive_func.start()
_LOGGER.debug("[DAG Executor] start recive thread")
def stop(self):
self._dag.stop()
self._dag.join()
_LOGGER.info("[DAG Executor] succ stop")
def _get_next_data_id(self):
with self._id_lock:
......@@ -102,7 +111,7 @@ class DAGExecutor(object):
def _set_in_channel(self, in_channel):
if not isinstance(in_channel, (ThreadChannel, ProcessChannel)):
raise TypeError(
self._log('in_channel must be Channel type, but get {}'.format(
"in_channel must be Channel type, but get {}".format(
type(in_channel))))
in_channel.add_producer(self.name)
self._in_channel = in_channel
......@@ -110,7 +119,7 @@ class DAGExecutor(object):
def _set_out_channel(self, out_channel):
if not isinstance(out_channel, (ThreadChannel, ProcessChannel)):
raise TypeError(
self._log('out_channel must be Channel type, but get {}'.format(
"iout_channel must be Channel type, but get {}".format(
type(out_channel))))
out_channel.add_consumer(self.name)
self._out_channel = out_channel
......@@ -121,7 +130,7 @@ class DAGExecutor(object):
try:
channeldata_dict = self._out_channel.front(self.name)
except ChannelStopError:
_LOGGER.debug(self._log("stop."))
_LOGGER.debug("[DAG Executor] channel stop.")
with self._cv_for_cv_pool:
for data_id, cv in self._cv_pool.items():
closed_errror_data = ChannelData(
......@@ -134,16 +143,16 @@ class DAGExecutor(object):
break
if len(channeldata_dict) != 1:
_LOGGER.error("out_channel cannot have multiple input ops")
_LOGGER.error("[DAG Executor] out_channel cannot have multiple input ops")
os._exit(-1)
(_, channeldata), = channeldata_dict.items()
if not isinstance(channeldata, ChannelData):
raise TypeError(
self._log('data must be ChannelData type, but get {}'.
format(type(channeldata))))
_LOGGER.error('[DAG Executor] data must be ChannelData type, but get {}'
.format(type(channeldata))))
os._exit(-1)
data_id = channeldata.id
_LOGGER.debug("recive thread fetch data: {}".format(data_id))
_LOGGER.debug("recive thread fetch data[{}]".format(data_id))
with self._cv_for_cv_pool:
cv = self._cv_pool[data_id]
with cv:
......@@ -157,18 +166,19 @@ class DAGExecutor(object):
self._cv_pool[data_id] = cv
with cv:
cv.wait()
_LOGGER.debug("resp func get lock (data_id: {})".format(data_id))
resp = copy.deepcopy(self._fetch_buffer)
with self._cv_for_cv_pool:
resp = copy.deepcopy(self._fetch_buffer)
_LOGGER.debug("resp thread get resp data[{}]".format(data_id))
self._cv_pool.pop(data_id)
return resp
def _pack_channeldata(self, rpc_request, data_id):
_LOGGER.debug(self._log('start inferce'))
dictdata = None
try:
dictdata = self._unpack_rpc_func(rpc_request)
except Exception as e:
_LOGGER.error("parse RPC package to data[{}] Error: {}"
.format(data_id, e))
return ChannelData(
ecode=ChannelDataEcode.RPC_PACKAGE_ERROR.value,
error_info="rpc package error: {}".format(e),
......@@ -181,50 +191,58 @@ class DAGExecutor(object):
if key == self._client_profile_key:
profile_value = rpc_request.value[idx]
break
client_need_profile = (profile_value == self._client_profile_value)
_LOGGER.debug("request[{}] need profile: {}".format(data_id, client_need_profile))
return ChannelData(
datatype=ChannelDataType.DICT.value,
dictdata=dictdata,
data_id=data_id,
client_need_profile=(
profile_value == self._client_profile_value))
client_need_profile=client_need_profile)
def call(self, rpc_request):
data_id = self._get_next_data_id()
_LOGGER.debug("generate id: {}".format(data_id))
if not self._is_thread_op:
self._profiler.record("call_{}#DAG-{}_0".format(data_id, data_id))
else:
self._profiler.record("call_{}#DAG_0".format(data_id))
_LOGGER.debug("try parse RPC package to channeldata[{}]".format(data_id))
self._profiler.record("prepack_{}#{}_0".format(data_id, self.name))
req_channeldata = self._pack_channeldata(rpc_request, data_id)
self._profiler.record("prepack_{}#{}_1".format(data_id, self.name))
resp_channeldata = None
for i in range(self._retry):
_LOGGER.debug(self._log('push data'))
#self._profiler.record("push_{}#{}_0".format(data_id, self.name))
_LOGGER.debug("push data[{}] into Graph engine".format(data_id))
try:
self._in_channel.push(req_channeldata, self.name)
except ChannelStopError:
_LOGGER.debug(self._log("stop."))
_LOGGER.debug("[DAG Executor] channel stop.")
return self._pack_for_rpc_resp(
ChannelData(
ecode=ChannelDataEcode.CLOSED_ERROR.value,
error_info="dag closed.",
data_id=data_id))
#self._profiler.record("push_{}#{}_1".format(data_id, self.name))
_LOGGER.debug(self._log('wait for infer'))
#self._profiler.record("fetch_{}#{}_0".format(data_id, self.name))
_LOGGER.debug("wait for Graph engine for data[{}]...".format(data_id))
resp_channeldata = self._get_channeldata_from_fetch_buffer(data_id)
#self._profiler.record("fetch_{}#{}_1".format(data_id, self.name))
if resp_channeldata.ecode == ChannelDataEcode.OK.value:
_LOGGER.debug("Graph engine predict data[{}] succ".format(data_id))
break
else:
_LOGGER.warn("Graph engine predict data[{}] failed: {}"
.format(data_id, resp_channeldata.error_info))
if resp_channeldata.ecode != ChannelDataEcode.TIMEOUT.value:
break
if i + 1 < self._retry:
_LOGGER.warn("retry({}): {}".format(
i + 1, resp_channeldata.error_info))
_LOGGER.warn("retry({}/{}) data[{}]".format(
i + 1, self._retry, data_id))
_LOGGER.debug("unpack channeldata[{}] into RPC resp package".format(data_id))
self._profiler.record("postpack_{}#{}_0".format(data_id, self.name))
rpc_resp = self._pack_for_rpc_resp(resp_channeldata)
self._profiler.record("postpack_{}#{}_1".format(data_id, self.name))
......@@ -232,7 +250,6 @@ class DAGExecutor(object):
self._profiler.record("call_{}#DAG-{}_1".format(data_id, data_id))
else:
self._profiler.record("call_{}#DAG_1".format(data_id))
#self._profiler.print_profile()
profile_str = self._profiler.gen_profile_str()
if self._server_use_profile:
......@@ -250,16 +267,12 @@ class DAGExecutor(object):
return rpc_resp
def _pack_for_rpc_resp(self, channeldata):
_LOGGER.debug(self._log('get channeldata'))
return self._pack_rpc_func(channeldata)
def _log(self, info_str):
return "[{}] {}".format(self.name, info_str)
class DAG(object):
def __init__(self, request_name, response_op, use_profile, is_thread_op,
client_type, channel_size, show_info):
client_type, channel_size, show_info, build_dag_each_worker):
self._request_name = request_name
self._response_op = response_op
self._use_profile = use_profile
......@@ -267,8 +280,10 @@ class DAG(object):
self._channel_size = channel_size
self._client_type = client_type
self._show_info = show_info
self._build_dag_each_worker = build_dag_each_worker
if not self._is_thread_op:
self._manager = multiprocessing.Manager()
_LOGGER.info("[DAG] succ init")
def get_use_ops(self, response_op):
unique_names = set()
......@@ -301,10 +316,13 @@ class DAG(object):
else:
channel = ProcessChannel(
self._manager, name=name_gen.next(), maxsize=self._channel_size)
_LOGGER.debug("[DAG] gen Channel: {}".format(channel.name))
return channel
def _gen_virtual_op(self, name_gen):
return VirtualOp(name=name_gen.next())
vir_op = VirtualOp(name=name_gen.next())
_LOGGER.debug("[DAG] gen VirtualOp: {}".format(vir_op.name))
return vir_op
def _topo_sort(self, used_ops, response_op, out_degree_ops):
out_degree_num = {
......@@ -360,6 +378,11 @@ class DAG(object):
raise Exception(
"Besides RequestOp and ResponseOp, there should be at least one Op in DAG."
)
if self._build_dag_each_worker:
_LOGGER.info("Because `build_dag_each_worker` mode is used, "
"Auto-batching is set to the default config.")
for op in used_ops:
op.use_default_auto_batching_config()
dag_views, last_op = self._topo_sort(used_ops, response_op,
out_degree_ops)
......@@ -414,7 +437,8 @@ class DAG(object):
continue
channel = self._gen_channel(channel_name_gen)
channels.append(channel)
_LOGGER.debug("{} => {}".format(channel.name, op.name))
_LOGGER.debug("[DAG] Channel({}) => Op({})"
.format(channel.name, op.name))
op.add_input_channel(channel)
pred_ops = pred_op_of_next_view_op[op.name]
if v_idx == 0:
......@@ -422,8 +446,8 @@ class DAG(object):
else:
# if pred_op is virtual op, it will use ancestors as producers to channel
for pred_op in pred_ops:
_LOGGER.debug("{} => {}".format(pred_op.name,
channel.name))
_LOGGER.debug("[DAG] Op({}) => Channel({})"
.format(pred_op.name, channel.name))
pred_op.add_output_channel(channel)
processed_op.add(op.name)
# find same input op to combine channel
......@@ -439,8 +463,8 @@ class DAG(object):
same_flag = False
break
if same_flag:
_LOGGER.debug("{} => {}".format(channel.name,
other_op.name))
_LOGGER.debug("[DAG] Channel({}) => Op({})"
.format(channel.name, other_op.name))
other_op.add_input_channel(channel)
processed_op.add(other_op.name)
output_channel = self._gen_channel(channel_name_gen)
......@@ -458,7 +482,8 @@ class DAG(object):
actual_ops.append(op)
for c in channels:
_LOGGER.debug(c.debug())
_LOGGER.debug("Channel({}):\n -producers: {}\n -consumers: {}"
.format(c.name, c.get_producers(), c.get_consumers()))
return (actual_ops, channels, input_channel, output_channel, pack_func,
unpack_func)
......@@ -466,6 +491,7 @@ class DAG(object):
def build(self):
(actual_ops, channels, input_channel, output_channel, pack_func,
unpack_func) = self._build_dag(self._response_op)
_LOGGER.info("[DAG] succ build dag")
self._actual_ops = actual_ops
self._channels = channels
......@@ -486,6 +512,8 @@ class DAG(object):
else:
self._threads_or_proces.extend(
op.start_with_process(self._client_type))
_LOGGER.info("[DAG] start")
# not join yet
return self._threads_or_proces
......
......@@ -67,7 +67,8 @@ class Op(object):
self._batch_size = batch_size
self._auto_batching_timeout = auto_batching_timeout
if self._auto_batching_timeout is not None and self._auto_batching_timeout <= 0:
if self._auto_batching_timeout is not None:
if self._auto_batching_timeout <= 0 or self._batch_size == 1:
self._auto_batching_timeout = None
self._server_use_profile = False
......@@ -78,6 +79,10 @@ class Op(object):
self._succ_init_op = False
self._succ_close_op = False
def use_default_auto_batching_config(self):
self._batch_size = 1
self._auto_batching_timeout = None
def use_profiler(self, use_profile):
self._server_use_profile = use_profile
......@@ -89,11 +94,12 @@ class Op(object):
def init_client(self, client_type, client_config, server_endpoints,
fetch_names):
if self.with_serving == False:
_LOGGER.debug("{} no client".format(self.name))
_LOGGER.info("Op({}) no client".format(self.name))
return None
_LOGGER.debug("{} client_config: {}".format(self.name, client_config))
_LOGGER.debug("{} fetch_names: {}".format(self.name, fetch_names))
_LOGGER.info("Op({}) service endpoints: {}".format(self.name, server_endpoints))
_LOGGER.debug("Op({}) fetch_names: {}".format(self.name, fetch_names))
if client_type == 'brpc':
_LOGGER.debug("Op({}) client_config: {}".format(self.name, client_config))
client = Client()
client.load_client_config(client_config)
elif client_type == 'grpc':
......@@ -163,7 +169,6 @@ class Op(object):
"{} Please override preprocess func.".format(err_info))
call_result = self.client.predict(
feed=feed_batch, fetch=self._fetch_names)
_LOGGER.debug(self._log("get call_result"))
return call_result
def postprocess(self, input_dict, fetch_dict):
......@@ -237,6 +242,7 @@ class Op(object):
pass
def _run_preprocess(self, parsed_data_dict, log_func):
_LOGGER.debug(log_func("try to run preprocess"))
preped_data_dict = {}
err_channeldata_dict = {}
for data_id, parsed_data in parsed_data_dict.items():
......@@ -245,22 +251,27 @@ class Op(object):
preped_data = self.preprocess(parsed_data)
except NotImplementedError as e:
# preprocess function not implemented
error_info = log_func(e)
_LOGGER.error(error_info)
error_info = log_func(
"preprocess data[{}] failed: {}".format(
data_id, e))
error_channeldata = ChannelData(
ecode=ChannelDataEcode.NOT_IMPLEMENTED.value,
error_info=error_info,
data_id=data_id)
except TypeError as e:
# Error type in channeldata.datatype
error_info = log_func(e)
error_info = log_func(
"preprocess data[{}] failed: {}".format(
data_id, e))
_LOGGER.error(error_info)
error_channeldata = ChannelData(
ecode=ChannelDataEcode.TYPE_ERROR.value,
error_info=error_info,
data_id=data_id)
except Exception as e:
error_info = log_func(e)
error_info = log_func(
"preprocess data[{}] failed: {}".format(
data_id, e))
_LOGGER.error(error_info)
error_channeldata = ChannelData(
ecode=ChannelDataEcode.UNKNOW.value,
......@@ -270,9 +281,11 @@ class Op(object):
err_channeldata_dict[data_id] = error_channeldata
else:
preped_data_dict[data_id] = preped_data
_LOGGER.debug(log_func("succ run preprocess"))
return preped_data_dict, err_channeldata_dict
def _run_process(self, preped_data_dict, log_func):
_LOGGER.debug(log_func("try to run process"))
midped_data_dict = {}
err_channeldata_dict = {}
if self.with_serving:
......@@ -285,7 +298,7 @@ class Op(object):
midped_batch = self.process(feed_batch)
except Exception as e:
ecode = ChannelDataEcode.UNKNOW.value
error_info = log_func(e)
error_info = log_func("process batch failed: {}".format(e))
_LOGGER.error(error_info)
else:
for i in range(self._retry):
......@@ -299,10 +312,11 @@ class Op(object):
_LOGGER.error(error_info)
else:
_LOGGER.warn(
log_func("timeout, retry({})".format(i + 1)))
log_func("timeout, retry({}/{})"
.format(i + 1, self._retry)))
except Exception as e:
ecode = ChannelDataEcode.UNKNOW.value
error_info = log_func(e)
error_info = log_func("process batch failed: {}".format(e))
_LOGGER.error(error_info)
break
else:
......@@ -315,11 +329,13 @@ class Op(object):
data_id=data_id)
elif midped_batch is None:
# op client return None
error_info=log_func(
"predict failed. pls check the server side.")
_LOGGER.error(error_info)
for data_id in data_ids:
err_channeldata_dict[data_id] = ChannelData(
ecode=ChannelDataEcode.CLIENT_ERROR.value,
error_info=log_func(
"predict failed. pls check the server side."),
error_info=error_info,
data_id=data_id)
else:
# transform np format to dict format
......@@ -329,9 +345,11 @@ class Op(object):
}
else:
midped_data_dict = preped_data_dict
_LOGGER.debug(log_func("succ run process"))
return midped_data_dict, err_channeldata_dict
def _run_postprocess(self, parsed_data_dict, midped_data_dict, log_func):
_LOGGER.debug(log_func("try to run postprocess"))
postped_data_dict = {}
err_channeldata_dict = {}
for data_id, midped_data in midped_data_dict.items():
......@@ -340,7 +358,8 @@ class Op(object):
postped_data = self.postprocess(
parsed_data_dict[data_id], midped_data)
except Exception as e:
error_info = log_func(e)
error_info = log_func("postprocess data[{}] failed: {}"
.format(data_id, e))
_LOGGER.error(error_info)
err_channeldata = ChannelData(
ecode=ChannelDataEcode.UNKNOW.value,
......@@ -374,11 +393,17 @@ class Op(object):
dictdata=postped_data,
data_id=data_id)
postped_data_dict[data_id] = output_data
_LOGGER.debug(log_func("succ run postprocess"))
return postped_data_dict, err_channeldata_dict
def _auto_batching_generator(self, input_channel, op_name, batch_size, timeout):
def _auto_batching_generator(self, input_channel, op_name,
batch_size, timeout, log_func):
while True:
batch = []
_LOGGER.debug(
log_func(
"Auto-batching expect size: {}; timeout: {}".format(
batch_size, timeout)))
while len(batch) == 0:
endtime = None
if timeout is not None:
......@@ -389,15 +414,16 @@ class Op(object):
if timeout is not None:
remaining = endtime - _time()
if remaining <= 0.0:
_LOGGER.info(log("auto-batching timeout"))
_LOGGER.debug(log_func("Auto-batching timeout"))
break
channeldata_dict = input_channel.front(op_name, timeout)
else:
channeldata_dict = input_channel.front(op_name)
batch.append(channeldata_dict)
except ChannelTimeoutError:
_LOGGER.info(log("auto-batching timeout"))
_LOGGER.debug(log_func("Auto-batching timeout"))
break
_LOGGER.debug(log_func("Auto-batching actual size: {}".format(len(batch))))
yield batch
def _parse_channeldata_batch(self, batch, output_channels):
......@@ -429,29 +455,28 @@ class Op(object):
op_info_prefix = "[{}|{}]".format(self.name, concurrency_idx)
log = get_log_func(op_info_prefix)
preplog = get_log_func(op_info_prefix + "(prep)")
midplog = get_log_func(op_info_prefix + "(midp)")
postplog = get_log_func(op_info_prefix + "(postp)")
tid = threading.current_thread().ident
# init op
try:
self._initialize(is_thread_op, client_type)
except Exception as e:
_LOGGER.error(log(e))
_LOGGER.error(log("init op failed: {}".format(e)))
os._exit(-1)
_LOGGER.info(log("succ init"))
batch_generator = self._auto_batching_generator(
input_channel=input_channel,
op_name=self.name,
batch_size=self._batch_size,
timeout=self._auto_batching_timeout)
timeout=self._auto_batching_timeout,
log_func=log)
while True:
try:
channeldata_dict_batch = next(batch_generator)
except ChannelStopError:
_LOGGER.debug(log("stop."))
_LOGGER.debug(log("channel stop."))
self._finalize(is_thread_op)
break
......@@ -461,7 +486,7 @@ class Op(object):
= self._parse_channeldata_batch(
channeldata_dict_batch, output_channels)
except ChannelStopError:
_LOGGER.debug(log("stop."))
_LOGGER.debug(log("channel stop."))
self._finalize(is_thread_op)
break
if len(parsed_data_dict) == 0:
......@@ -471,7 +496,7 @@ class Op(object):
# preprecess
self._profiler_record("prep#{}_0".format(op_info_prefix))
preped_data_dict, err_channeldata_dict \
= self._run_preprocess(parsed_data_dict, preplog)
= self._run_preprocess(parsed_data_dict, log)
self._profiler_record("prep#{}_1".format(op_info_prefix))
try:
for data_id, err_channeldata in err_channeldata_dict.items():
......@@ -481,7 +506,7 @@ class Op(object):
client_need_profile=need_profile_dict[data_id],
profile_set=profile_dict[data_id])
except ChannelStopError:
_LOGGER.debug(log("stop."))
_LOGGER.debug(log("channel stop."))
self._finalize(is_thread_op)
break
if len(parsed_data_dict) == 0:
......@@ -490,7 +515,7 @@ class Op(object):
# process
self._profiler_record("midp#{}_0".format(op_info_prefix))
midped_data_dict, err_channeldata_dict \
= self._run_process(preped_data_dict, midplog)
= self._run_process(preped_data_dict, log)
self._profiler_record("midp#{}_1".format(op_info_prefix))
try:
for data_id, err_channeldata in err_channeldata_dict.items():
......@@ -500,7 +525,7 @@ class Op(object):
client_need_profile=need_profile_dict[data_id],
profile_set=profile_dict[data_id])
except ChannelStopError:
_LOGGER.debug(log("stop."))
_LOGGER.debug(log("channel stop."))
self._finalize(is_thread_op)
break
if len(midped_data_dict) == 0:
......@@ -510,7 +535,7 @@ class Op(object):
self._profiler_record("postp#{}_0".format(op_info_prefix))
postped_data_dict, err_channeldata_dict \
= self._run_postprocess(
parsed_data_dict, midped_data_dict, postplog)
parsed_data_dict, midped_data_dict, log)
self._profiler_record("postp#{}_1".format(op_info_prefix))
try:
for data_id, err_channeldata in err_channeldata_dict.items():
......@@ -520,7 +545,7 @@ class Op(object):
client_need_profile=need_profile_dict[data_id],
profile_set=profile_dict[data_id])
except ChannelStopError:
_LOGGER.debug(log("stop."))
_LOGGER.debug(log("channel stop."))
self._finalize(is_thread_op)
break
if len(postped_data_dict) == 0:
......@@ -535,7 +560,7 @@ class Op(object):
client_need_profile=need_profile_dict[data_id],
profile_set=profile_dict[data_id])
except ChannelStopError:
_LOGGER.debug(log("stop."))
_LOGGER.debug(log("channel stop."))
self._finalize(is_thread_op)
break
......
......@@ -28,13 +28,14 @@ from .dag import DAGExecutor
_LOGGER = logging.getLogger()
class PipelineService(pipeline_service_pb2_grpc.PipelineServiceServicer):
class PipelineServicer(pipeline_service_pb2_grpc.PipelineServiceServicer):
def __init__(self, response_op, dag_config, show_info):
super(PipelineService, self).__init__()
super(PipelineServicer, self).__init__()
# init dag executor
self._dag_executor = DAGExecutor(
response_op, dag_config, show_info=show_info)
self._dag_executor.start()
_LOGGER.info("[PipelineServicer] succ init")
def inference(self, request, context):
resp = self._dag_executor.call(request)
......@@ -80,26 +81,33 @@ class PipelineServer(object):
def prepare_server(self, yml_file):
with open(yml_file) as f:
yml_config = yaml.load(f.read())
self._port = yml_config.get('port')
if self._port is None:
raise SystemExit("Please set *port* in [{}] yaml file.".format(
yml_file))
default_config = {
"port": 9292,
"worker_num": 1,
"build_dag_each_worker": False,
}
for key, val in default_config.items():
if yml_config.get(key) is None:
_LOGGER.warning("[CONF] {} not set, use default: {}"
.format(key, val))
yml_config[key] = val
self._port = yml_config["port"]
if not self._port_is_available(self._port):
raise SystemExit("Prot {} is already used".format(self._port))
self._worker_num = yml_config.get('worker_num', 1)
self._build_dag_each_worker = yml_config.get('build_dag_each_worker',
False)
self._worker_num = yml_config["worker_num"]
self._build_dag_each_worker = yml_config["build_dag_each_worker"]
_LOGGER.info("============= PIPELINE SERVER =============")
_LOGGER.info("port: {}".format(self._port))
_LOGGER.info("worker_num: {}".format(self._worker_num))
servicer_info = "build_dag_each_worker: {}".format(
self._build_dag_each_worker)
for key in default_config.keys():
_LOGGER.info("{}: {}".format(key, yml_config[key]))
if self._build_dag_each_worker is True:
servicer_info += " (Make sure that install grpcio whl with --no-binary flag)"
_LOGGER.info(servicer_info)
_LOGGER.info("(Make sure that install grpcio whl with --no-binary flag)")
_LOGGER.info("-------------------------------------------")
self._dag_config = yml_config.get("dag", {})
self._dag_config["build_dag_each_worker"] = self._build_dag_each_worker
def run_server(self):
if self._build_dag_each_worker:
......@@ -120,7 +128,7 @@ class PipelineServer(object):
server = grpc.server(
futures.ThreadPoolExecutor(max_workers=self._worker_num))
pipeline_service_pb2_grpc.add_PipelineServiceServicer_to_server(
PipelineService(self._response_op, self._dag_config, True),
PipelineServicer(self._response_op, self._dag_config, True),
server)
server.add_insecure_port('[::]:{}'.format(self._port))
server.start()
......@@ -132,7 +140,7 @@ class PipelineServer(object):
futures.ThreadPoolExecutor(
max_workers=1, ), options=options)
pipeline_service_pb2_grpc.add_PipelineServiceServicer_to_server(
PipelineService(response_op, dag_config, False), server)
PipelineServicer(response_op, dag_config, False), server)
server.add_insecure_port(bind_address)
server.start()
server.wait_for_termination()
......@@ -22,7 +22,7 @@ elif sys.version_info.major == 3:
import queue as Queue
else:
raise Exception("Error Python version")
import time
from time import time as _time
import threading
_LOGGER = logging.getLogger()
......@@ -42,7 +42,7 @@ class TimeProfiler(object):
def record(self, name_with_tag):
if self._enable is False:
return
timestamp = int(round(time.time() * 1000000))
timestamp = int(round(_time() * 1000000))
name_with_tag = name_with_tag.split("_")
tag = name_with_tag[-1]
name = '_'.join(name_with_tag[:-1])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册