Commit 51e21362 authored by barrierye

add ProcessPipelineServicer

Parent 23664ff8
@@ -37,7 +37,8 @@ class ChannelDataEcode(enum.Enum):
     TYPE_ERROR = 3
     RPC_PACKAGE_ERROR = 4
     CLIENT_ERROR = 5
-    UNKNOW = 6
+    CLOSED_ERROR = 6
+    UNKNOW = 7


 class ChannelDataType(enum.Enum):
@@ -258,6 +259,8 @@ class ProcessChannel(object):
                     break
                 except Queue.Full:
                     self._cv.wait()
+                    if self._stop:
+                        raise ChannelStopError()
             _LOGGER.debug(
                 self._log("{} channel size: {}".format(op_name,
                                                        self._que.qsize())))
@@ -308,6 +311,8 @@ class ProcessChannel(object):
                     break
                 except Queue.Empty:
                     self._cv.wait()
+                    if self._stop:
+                        raise ChannelStopError()
             _LOGGER.debug(
                 self._log("multi | {} push data succ!".format(op_name)))
@@ -337,6 +342,8 @@ class ProcessChannel(object):
                             "{} wait for empty queue(with channel empty: {})".
                             format(op_name, self._que.empty())))
                 self._cv.wait()
+                if self._stop:
+                    raise ChannelStopError()
             _LOGGER.debug(
                 self._log("{} get data succ: {}".format(op_name, resp.__str__(
                 ))))
@@ -379,6 +386,8 @@ class ProcessChannel(object):
                             "{} wait for empty queue(with channel size: {})".
                             format(op_name, self._que.qsize())))
                 self._cv.wait()
+                if self._stop:
+                    raise ChannelStopError()
             consumer_cursor = self._consumer_cursors[op_name]
             base_cursor = self._base_cursor.value
@@ -425,10 +434,10 @@ class ProcessChannel(object):
         return resp  # reference, read only

     def stop(self):
-        #TODO
-        self.close()
+        _LOGGER.info(self._log("stop."))
         self._stop = True
-        self._cv.notify_all()
+        with self._cv:
+            self._cv.notify_all()


 class ThreadChannel(Queue.Queue):
@@ -527,6 +536,8 @@ class ThreadChannel(Queue.Queue):
                     break
                 except Queue.Full:
                     self._cv.wait()
+                    if self._stop:
+                        raise ChannelStopError()
             self._cv.notify_all()
         _LOGGER.debug(self._log("{} push data succ!".format(op_name)))
         return True
@@ -565,6 +576,8 @@ class ThreadChannel(Queue.Queue):
                     break
                 except Queue.Empty:
                     self._cv.wait()
+                    if self._stop:
+                        raise ChannelStopError()
             _LOGGER.debug(
                 self._log("multi | {} push data succ!".format(op_name)))
@@ -587,6 +600,8 @@ class ThreadChannel(Queue.Queue):
                     break
                 except Queue.Empty:
                     self._cv.wait()
+                    if self._stop:
+                        raise ChannelStopError()
             _LOGGER.debug(
                 self._log("{} get data succ: {}".format(op_name, resp.__str__(
                 ))))
@@ -617,6 +632,8 @@ class ThreadChannel(Queue.Queue):
                     break
                 except Queue.Empty:
                     self._cv.wait()
+                    if self._stop:
+                        raise ChannelStopError()
             consumer_cursor = self._consumer_cursors[op_name]
             base_cursor = self._base_cursor
@@ -657,7 +674,12 @@ class ThreadChannel(Queue.Queue):
         return resp

     def stop(self):
-        #TODO
-        self.close()
+        _LOGGER.info(self._log("stop."))
         self._stop = True
-        self._cv.notify_all()
+        with self._cv:
+            self._cv.notify_all()
+
+
+class ChannelStopError(RuntimeError):
+    def __init__(self):
+        pass
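The pattern these hunks add is the same in every wait loop: check the stop flag immediately after `cv.wait()` returns, because `stop()` may be the reason the thread woke up. Note also why `notify_all()` is now wrapped in `with self._cv`: `threading.Condition.notify_all()` raises `RuntimeError` if the underlying lock is not held. A condensed sketch of the pattern, with simplified names (this is not the actual Channel class):

```python
import threading

class ChannelStopError(RuntimeError):  # as defined in the diff above
    pass

class StoppableQueue(object):
    """Minimal sketch of the stop-aware wait loop introduced above."""

    def __init__(self):
        self._cv = threading.Condition()
        self._stop = False
        self._items = []

    def get(self):
        with self._cv:
            while not self._items:
                self._cv.wait()               # releases the lock while blocked
                if self._stop:                # re-check right after waking:
                    raise ChannelStopError()  # stop() may be why we woke
            return self._items.pop(0)

    def stop(self):
        self._stop = True
        with self._cv:                        # notify_all() requires the lock
            self._cv.notify_all()
```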
@@ -26,7 +26,8 @@ import os
 import logging

 from .operator import Op, RequestOp, ResponseOp, VirtualOp
-from .channel import ThreadChannel, ProcessChannel, ChannelData, ChannelDataEcode, ChannelDataType
+from .channel import (ThreadChannel, ProcessChannel, ChannelData,
+                      ChannelDataEcode, ChannelDataType, ChannelStopError)
 from .profiler import TimeProfiler
 from .util import NameGenerator
@@ -34,33 +35,29 @@ _LOGGER = logging.getLogger()

 class DAGExecutor(object):
-    def __init__(self, response_op, yml_config, show_info):
-        self._retry = yml_config.get('retry', 1)
-
-        client_type = yml_config.get('client_type', 'brpc')
-        use_multithread = yml_config.get('use_multithread', True)
-        use_profile = yml_config.get('profile', False)
-        channel_size = yml_config.get('channel_size', 0)
-        self._asyn_profile = yml_config.get('asyn_profile', False)
+    def __init__(self, response_op, dag_config, show_info):
+        self._retry = dag_config.get('retry', 1)
+
+        client_type = dag_config.get('client_type', 'brpc')
+        use_profile = dag_config.get('use_profile', False)
+        channel_size = dag_config.get('channel_size', 0)
+        self._is_thread_op = dag_config.get('is_thread_op', True)

         if show_info and use_profile:
             _LOGGER.info("================= PROFILER ================")
-            if use_multithread:
+            if self._is_thread_op:
                 _LOGGER.info("op: thread")
+                _LOGGER.info("profile mode: sync")
             else:
                 _LOGGER.info("op: process")
-                if self._asyn_profile:
-                    _LOGGER.info("profile mode: asyn (This mode is only used"
-                                 " when using the process version Op)")
-                else:
-                    _LOGGER.info("profile mode: sync")
+                _LOGGER.info("profile mode: asyn")
             _LOGGER.info("-------------------------------------------")

         self.name = "@G"
         self._profiler = TimeProfiler()
         self._profiler.enable(use_profile)

-        self._dag = DAG(self.name, response_op, use_profile, use_multithread,
+        self._dag = DAG(self.name, response_op, use_profile, self._is_thread_op,
                         client_type, channel_size, show_info)
         (in_channel, out_channel, pack_rpc_func,
          unpack_rpc_func) = self._dag.build()
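After this change, `DAGExecutor` reads everything from a dedicated `dag` sub-config rather than from the whole yml file. A hypothetical `dag_config`, mirroring only the keys the constructor reads above (values shown are the code's own defaults):

```python
# Hypothetical dag_config: keys come from the .get() calls above,
# values are the defaults in the code.
dag_config = {
    "retry": 1,               # push retries in DAGExecutor.call()
    "client_type": "brpc",    # client used by Ops to reach serving backends
    "use_profile": False,     # enable TimeProfiler records
    "channel_size": 0,        # channel maxsize (0 falls through to the
                              # library default, i.e. unbounded)
    "is_thread_op": True,     # thread Ops (sync profile) vs process Ops (asyn)
}
```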
@@ -80,17 +77,14 @@ class DAGExecutor(object):
         self._cv_pool = {}
         self._cv_for_cv_pool = threading.Condition()
         self._fetch_buffer = None
-        self._is_run = False
         self._recive_func = None

     def start(self):
-        self._is_run = True
         self._recive_func = threading.Thread(
             target=DAGExecutor._recive_out_channel_func, args=(self, ))
         self._recive_func.start()

     def stop(self):
-        self._is_run = False
         self._dag.stop()
         self._dag.join()
@@ -119,8 +113,22 @@ class DAGExecutor(object):
     def _recive_out_channel_func(self):
         cv = None
-        while self._is_run:
-            channeldata_dict = self._out_channel.front(self.name)
+        while True:
+            try:
+                channeldata_dict = self._out_channel.front(self.name)
+            except ChannelStopError:
+                _LOGGER.info(self._log("stop."))
+                with self._cv_for_cv_pool:
+                    for data_id, cv in self._cv_pool.items():
+                        closed_error_data = ChannelData(
+                            ecode=ChannelDataEcode.CLOSED_ERROR.value,
+                            error_info="dag closed.",
+                            data_id=data_id)
+                        with cv:
+                            self._fetch_buffer = closed_error_data
+                            cv.notify_all()
+                break
             if len(channeldata_dict) != 1:
                 _LOGGER.error("out_channel cannot have multiple input ops")
                 os._exit(-1)
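The receive loop above is the producer half of a small rendezvous: every in-flight request parks on its own condition variable in `_cv_pool`, while `_fetch_buffer` is a shared one-slot mailbox that the woken waiter must copy out immediately (which is why the fetch path below deep-copies it). A hedged sketch of the waiter half, with simplified names (the real method body is in the hunk that follows):

```python
import copy
import threading

def wait_for(data_id, cv_pool, cv_for_cv_pool, executor):
    """Sketch of the consumer half of the rendezvous (names simplified)."""
    cv = threading.Condition()
    with cv_for_cv_pool:
        cv_pool[data_id] = cv      # register before waiting
    with cv:
        cv.wait()                  # woken by the receive loop, or by stop()
        # copy the shared slot while still holding cv, before anyone
        # else overwrites it
        resp = copy.deepcopy(executor._fetch_buffer)
    with cv_for_cv_pool:
        cv_pool.pop(data_id)
    return resp
```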
@@ -147,7 +155,6 @@ class DAGExecutor(object):
                 cv.wait()
             _LOGGER.debug("resp func get lock (data_id: {})".format(data_id))
             resp = copy.deepcopy(self._fetch_buffer)
-            # cv.notify_all()
         with self._cv_for_cv_pool:
             self._cv_pool.pop(data_id)
         return resp
@@ -170,7 +177,7 @@ class DAGExecutor(object):
     def call(self, rpc_request):
         data_id = self._get_next_data_id()

-        if self._asyn_profile:
+        if not self._is_thread_op:
             self._profiler.record("call_{}#DAG-{}_0".format(data_id, data_id))
         else:
             self._profiler.record("call_{}#DAG_0".format(data_id))
@@ -183,7 +190,15 @@ class DAGExecutor(object):
         for i in range(self._retry):
             _LOGGER.debug(self._log('push data'))
             #self._profiler.record("push_{}#{}_0".format(data_id, self.name))
-            self._in_channel.push(req_channeldata, self.name)
+            try:
+                self._in_channel.push(req_channeldata, self.name)
+            except ChannelStopError:
+                _LOGGER.info(self._log("stop."))
+                return self._pack_for_rpc_resp(
+                    ChannelData(
+                        ecode=ChannelDataEcode.CLOSED_ERROR.value,
+                        error_info="dag closed.",
+                        data_id=data_id))
             #self._profiler.record("push_{}#{}_1".format(data_id, self.name))

             _LOGGER.debug(self._log('wait for infer'))
@@ -201,7 +216,7 @@ class DAGExecutor(object):
         rpc_resp = self._pack_for_rpc_resp(resp_channeldata)
         self._profiler.record("postpack_{}#{}_1".format(data_id, self.name))

-        if self._asyn_profile:
+        if not self._is_thread_op:
             self._profiler.record("call_{}#DAG-{}_1".format(data_id, data_id))
         else:
             self._profiler.record("call_{}#DAG_1".format(data_id))
@@ -217,16 +232,16 @@ class DAGExecutor(object):

 class DAG(object):
-    def __init__(self, request_name, response_op, use_profile, use_multithread,
+    def __init__(self, request_name, response_op, use_profile, is_thread_op,
                  client_type, channel_size, show_info):
         self._request_name = request_name
         self._response_op = response_op
         self._use_profile = use_profile
-        self._use_multithread = use_multithread
+        self._is_thread_op = is_thread_op
         self._channel_size = channel_size
         self._client_type = client_type
         self._show_info = show_info
-        if not self._use_multithread:
+        if not self._is_thread_op:
             self._manager = multiprocessing.Manager()

     def get_use_ops(self, response_op):
@@ -254,7 +269,7 @@ class DAG(object):
     def _gen_channel(self, name_gen):
         channel = None
-        if self._use_multithread:
+        if self._is_thread_op:
             channel = ThreadChannel(
                 name=name_gen.next(), maxsize=self._channel_size)
         else:
@@ -439,7 +454,7 @@ class DAG(object):
         self._threads_or_proces = []
         for op in self._actual_ops:
             op.use_profiler(self._use_profile)
-            if self._use_multithread:
+            if self._is_thread_op:
                 self._threads_or_proces.extend(
                     op.start_with_thread(self._client_type))
             else:
@@ -453,7 +468,5 @@ class DAG(object):
             x.join()

     def stop(self):
-        for op in self._actual_ops:
-            op.stop()
         for chl in self._channels:
             chl.stop()
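With per-Op `stop()` gone, shutdown is driven entirely through the channels: stopping every channel makes each blocked worker raise `ChannelStopError` and exit its loop, and `join()` then reaps the threads or processes. A hypothetical end-to-end sketch of the new shutdown path (`response_op` and `dag_config` are assumed to exist):

```python
# Hypothetical wiring showing the new shutdown path end to end.
executor = DAGExecutor(response_op, dag_config, show_info=True)
executor.start()
try:
    pass  # serve requests via executor.call(rpc_request)
finally:
    executor.stop()  # DAG.stop(): chl.stop() on every channel ->
                     # every Op blocked in front()/push() wakes, sees
                     # _stop, raises ChannelStopError, and breaks;
                     # DAG.join() then waits for the workers to exit
```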
@@ -24,7 +24,8 @@ import sys
 from numpy import *

 from .proto import pipeline_service_pb2
-from .channel import ThreadChannel, ProcessChannel, ChannelDataEcode, ChannelData, ChannelDataType
+from .channel import (ThreadChannel, ProcessChannel, ChannelDataEcode,
+                      ChannelData, ChannelDataType, ChannelStopError)
 from .util import NameGenerator
 from .profiler import TimeProfiler
@@ -44,7 +45,6 @@ class Op(object):
                 retry=1):
         if name is None:
             name = _op_name_gen.next()
-        self._is_run = False
         self.name = name  # to identify the type of OP, it must be globally unique
         self.concurrency = concurrency  # amount of concurrency
         self.set_input_ops(input_ops)
@@ -65,7 +65,9 @@ class Op(object):

         # only for multithread
         self._for_init_op_lock = threading.Lock()
+        self._for_close_op_lock = threading.Lock()
         self._succ_init_op = False
+        self._succ_close_op = False

     def use_profiler(self, use_profile):
         self._use_profile = use_profile
@@ -93,9 +95,6 @@ class Op(object):
             self._fetch_names = fetch_names
         return client

-    def _get_input_channel(self):
-        return self._input
-
     def get_input_ops(self):
         return self._input_ops
@@ -118,8 +117,11 @@ class Op(object):
         channel.add_consumer(self.name)
         self._input = channel

-    def _get_output_channels(self):
-        return self._outputs
+    def _clean_input_channel(self):
+        self._input = None
+
+    def _get_input_channel(self):
+        return self._input

     def add_output_channel(self, channel):
         if not isinstance(channel, (ThreadChannel, ProcessChannel)):
@@ -129,6 +131,12 @@ class Op(object):
         channel.add_producer(self.name)
         self._outputs.append(channel)

+    def _clean_output_channels(self):
+        self._outputs = []
+
+    def _get_output_channels(self):
+        return self._outputs
+
     def preprocess(self, input_dicts):
         # multiple previous Op
         if len(input_dicts) != 1:
@@ -144,7 +152,7 @@ class Op(object):
         if err != 0:
             raise NotImplementedError(
                 "{} Please override preprocess func.".format(err_info))
-        call_result = self._client.predict(
+        call_result = self.client.predict(
             feed=feed_dict, fetch=self._fetch_names)
         _LOGGER.debug(self._log("get call_result"))
         return call_result
@@ -152,9 +160,6 @@ class Op(object):
     def postprocess(self, input_dict, fetch_dict):
         return fetch_dict

-    def stop(self):
-        self._is_run = False
-
     def _parse_channeldata(self, channeldata_dict):
         data_id, error_channeldata = None, None
         parsed_data = {}
@@ -332,30 +337,42 @@ class Op(object):
                     self._profiler = TimeProfiler()
                     self._profiler.enable(self._use_profile)
                     # init client
-                    self._client = self.init_client(
+                    self.client = self.init_client(
                         client_type, self._client_config,
                         self._server_endpoints, self._fetch_names)
                     # user defined
                     self.init_op()
                     self._succ_init_op = True
+                    self._succ_close_op = False
         else:
             # init profiler
             self._profiler = TimeProfiler()
             self._profiler.enable(self._use_profile)
             # init client
-            self._client = self.init_client(
-                client_type, self._client_config, self._server_endpoints,
-                self._fetch_names)
+            self.client = self.init_client(client_type, self._client_config,
+                                           self._server_endpoints,
+                                           self._fetch_names)
             # user defined
             self.init_op()
         except Exception as e:
             _LOGGER.error(log(e))
             os._exit(-1)

-        self._is_run = True
-        while self._is_run:
+        while True:
             #self._profiler_record("get#{}_0".format(op_info_prefix))
-            channeldata_dict = input_channel.front(self.name)
+            try:
+                channeldata_dict = input_channel.front(self.name)
+            except ChannelStopError:
+                _LOGGER.info(log("stop."))
+                with self._for_close_op_lock:
+                    if not self._succ_close_op:
+                        self._clean_input_channel()
+                        self._clean_output_channels()
+                        self._profiler = None
+                        self.client = None
+                        self._succ_init_op = False
+                        self._succ_close_op = True
+                break
             #self._profiler_record("get#{}_1".format(op_info_prefix))

             _LOGGER.debug(log("input_data: {}".format(channeldata_dict)))
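The cleanup in the `except ChannelStopError` branch above is guarded so it runs exactly once even though many concurrent Op threads may hit the stop at the same time. A condensed sketch of the idempotent-close pattern, generalized (class and attribute names are simplified stand-ins for `_for_close_op_lock` and `_succ_close_op`):

```python
import threading

class Closable(object):
    """Sketch of the run-once close pattern used in the hunk above."""

    def __init__(self):
        self._close_lock = threading.Lock()  # mirrors _for_close_op_lock
        self._closed = False                 # mirrors _succ_close_op

    def close_once(self):
        with self._close_lock:
            if not self._closed:
                # release resources here (channels, client, profiler);
                # only the first thread to arrive executes this block
                self._closed = True
```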
@@ -363,8 +380,11 @@ class Op(object):
                                                               channeldata_dict)
             # error data in predecessor Op
             if error_channeldata is not None:
-                self._push_to_output_channels(error_channeldata,
-                                              output_channels)
+                try:
+                    self._push_to_output_channels(error_channeldata,
+                                                  output_channels)
+                except ChannelStopError:
+                    _LOGGER.info(log("stop."))
                 continue

             # preprocess
@@ -373,8 +393,11 @@ class Op(object):
                                                                data_id, log)
             self._profiler_record("prep#{}_1".format(op_info_prefix))
             if error_channeldata is not None:
-                self._push_to_output_channels(error_channeldata,
-                                              output_channels)
+                try:
+                    self._push_to_output_channels(error_channeldata,
+                                                  output_channels)
+                except ChannelStopError:
+                    _LOGGER.info(log("stop."))
                 continue

             # process
@@ -383,8 +406,11 @@ class Op(object):
                                                            data_id, log)
             self._profiler_record("midp#{}_1".format(op_info_prefix))
             if error_channeldata is not None:
-                self._push_to_output_channels(error_channeldata,
-                                              output_channels)
+                try:
+                    self._push_to_output_channels(error_channeldata,
+                                                  output_channels)
+                except ChannelStopError:
+                    _LOGGER.info(log("stop."))
                 continue

             # postprocess
@@ -393,8 +419,11 @@ class Op(object):
                 parsed_data, midped_data, data_id, log)
             self._profiler_record("postp#{}_1".format(op_info_prefix))
             if error_channeldata is not None:
-                self._push_to_output_channels(error_channeldata,
-                                              output_channels)
+                try:
+                    self._push_to_output_channels(error_channeldata,
+                                                  output_channels)
+                except ChannelStopError:
+                    _LOGGER.info(log("stop."))
                 continue

             if self._use_profile:
@@ -405,7 +434,11 @@ class Op(object):
             # push data to channel (if run succ)
             #self._profiler_record("push#{}_0".format(op_info_prefix))
-            self._push_to_output_channels(output_data, output_channels)
+            try:
+                self._push_to_output_channels(output_data, output_channels)
+            except ChannelStopError:
+                _LOGGER.info(log("stop."))
+                break
             #self._profiler_record("push#{}_1".format(op_info_prefix))
             #self._profiler.print_profile()
@@ -525,10 +558,17 @@ class VirtualOp(Op):
         log = get_log_func(op_info_prefix)
         tid = threading.current_thread().ident

-        self._is_run = True
-        while self._is_run:
-            channeldata_dict = input_channel.front(self.name)
+        while True:
+            try:
+                channeldata_dict = input_channel.front(self.name)
+            except ChannelStopError:
+                _LOGGER.info(log("stop."))
+                break

-            for name, data in channeldata_dict.items():
-                self._push_to_output_channels(
-                    data, channels=output_channels, name=name)
+            try:
+                for name, data in channeldata_dict.items():
+                    self._push_to_output_channels(
+                        data, channels=output_channels, name=name)
+            except ChannelStopError:
+                _LOGGER.info(log("stop."))
+                break
@@ -29,16 +29,36 @@ _LOGGER = logging.getLogger()

 class PipelineService(pipeline_service_pb2_grpc.PipelineServiceServicer):
-    def __init__(self, response_op, yml_config, show_info=True):
+    def __init__(self, response_op, dag_config):
         super(PipelineService, self).__init__()
         # init dag executor
-        self._dag_executor = DAGExecutor(response_op, yml_config, show_info)
+        self._dag_executor = DAGExecutor(
+            response_op, dag_config, show_info=True)
         self._dag_executor.start()

     def inference(self, request, context):
         resp = self._dag_executor.call(request)
         return resp

+    def __del__(self):
+        self._dag_executor.stop()
+
+
+class ProcessPipelineService(pipeline_service_pb2_grpc.PipelineServiceServicer):
+    def __init__(self, response_op, dag_config):
+        super(ProcessPipelineService, self).__init__()
+        self._response_op = response_op
+        self._dag_config = dag_config
+
+    def inference(self, request, context):
+        # init dag executor
+        dag_executor = DAGExecutor(
+            self._response_op, self._dag_config, show_info=False)
+        dag_executor.start()
+        resp = dag_executor.call(request)
+        dag_executor.stop()
+        return resp
+
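The two servicers trade throughput for isolation: `PipelineService` keeps one long-lived `DAGExecutor` shared by all requests, while `ProcessPipelineService` builds, starts, and stops an executor per call, which is what makes it safe to run inside many forked gRPC worker processes at the cost of rebuilding the DAG every time. A hedged sketch of the selection logic (`_choose_servicer` is a hypothetical helper; `run_server` below does the same thing inline):

```python
def _choose_servicer(build_dag_each_request, response_op, dag_config):
    """Hypothetical helper mirroring the branch in run_server below."""
    if build_dag_each_request:
        # short-lived DAGExecutor per request, safe across forked workers
        return ProcessPipelineService(response_op, dag_config)
    # one long-lived DAGExecutor shared by all requests of this server
    return PipelineService(response_op, dag_config)
```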
 @contextlib.contextmanager
 def _reserve_port(port):
@@ -75,25 +95,30 @@ class PipelineServer(object):

     def prepare_server(self, yml_file):
         with open(yml_file) as f:
-            self._yml_config = yaml.load(f.read())
-        self._port = self._yml_config.get('port', 8080)
+            yml_config = yaml.load(f.read())

+        self._port = yml_config.get('port')
+        if self._port is None:
+            raise SystemExit("Please set *port* in [{}] yaml file.".format(
+                yml_file))
         if not self._port_is_available(self._port):
             raise SystemExit("Port {} is already used".format(self._port))
-        self._worker_num = self._yml_config.get('worker_num', 2)
-        self._multiprocess_servicer = self._yml_config.get(
-            'multiprocess_servicer', False)
+        self._worker_num = yml_config.get('worker_num', 1)
+        self._build_dag_each_request = yml_config.get('build_dag_each_request',
+                                                      False)

         _LOGGER.info("============= PIPELINE SERVER =============")
         _LOGGER.info("port: {}".format(self._port))
         _LOGGER.info("worker_num: {}".format(self._worker_num))
-        servicer_info = "multiprocess_servicer: {}".format(
-            self._multiprocess_servicer)
-        if self._multiprocess_servicer is True:
+        servicer_info = "build_dag_each_request: {}".format(
+            self._build_dag_each_request)
+        if self._build_dag_each_request is True:
             servicer_info += " (Make sure that install grpcio whl with --no-binary flag)"
         _LOGGER.info(servicer_info)
         _LOGGER.info("-------------------------------------------")

+        self._dag_config = yml_config.get("dag", {})
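Put together, the parsed server config that `prepare_server` expects after this change looks like the following (a sketch: key names come from the `get()` calls above, values are illustrative; `port` is the only required key):

```python
# What yaml.load(...) is expected to return after this change.
yml_config = {
    "port": 18080,                    # required, validated above
    "worker_num": 1,
    "build_dag_each_request": False,  # True forks one gRPC worker process
                                      # per worker, each with its own servicer
    "dag": {                          # passed straight to DAGExecutor
        "is_thread_op": True,
        "client_type": "brpc",
        "retry": 1,
        "use_profile": False,
        "channel_size": 0,
    },
}
```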
     def run_server(self):
-        if self._multiprocess_servicer:
+        if self._build_dag_each_request:
             with _reserve_port(self._port) as port:
                 bind_address = 'localhost:{}'.format(port)
                 workers = []
@@ -101,8 +126,8 @@ class PipelineServer(object):
                     show_info = (i == 0)
                     worker = multiprocessing.Process(
                         target=self._run_server_func,
-                        args=(bind_address, self._response_op, self._yml_config,
-                              self._worker_num, show_info))
+                        args=(bind_address, self._response_op, self._dag_config,
+                              self._worker_num))
                     worker.start()
                     workers.append(worker)
                 for worker in workers:
@@ -111,20 +136,20 @@ class PipelineServer(object):
             server = grpc.server(
                 futures.ThreadPoolExecutor(max_workers=self._worker_num))
             pipeline_service_pb2_grpc.add_PipelineServiceServicer_to_server(
-                PipelineService(self._response_op, self._yml_config), server)
+                PipelineService(self._response_op, self._dag_config), server)
             server.add_insecure_port('[::]:{}'.format(self._port))
             server.start()
             server.wait_for_termination()

-    def _run_server_func(self, bind_address, response_op, yml_config,
-                         worker_num, show_info):
+    def _run_server_func(self, bind_address, response_op, dag_config,
+                         worker_num):
         options = (('grpc.so_reuseport', 1), )
         server = grpc.server(
             futures.ThreadPoolExecutor(
                 max_workers=worker_num, ),
             options=options)
         pipeline_service_pb2_grpc.add_PipelineServiceServicer_to_server(
-            PipelineService(response_op, yml_config, show_info), server)
+            ProcessPipelineService(response_op, dag_config), server)
         server.add_insecure_port(bind_address)
         server.start()
         server.wait_for_termination()
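In the multiprocess branch, every forked worker binds the same port thanks to the `('grpc.so_reuseport', 1)` channel option; this requires a grpcio build with SO_REUSEPORT support, which is why the log message above recommends installing grpcio with the `--no-binary` flag. The body of `_reserve_port` is not part of this diff, but a typical implementation, following the upstream gRPC multiprocessing example, looks like this (a sketch, not necessarily the repo's exact code):

```python
import contextlib
import socket

@contextlib.contextmanager
def _reserve_port(port):
    """Create a SO_REUSEPORT socket so every forked worker can bind
    the same port (sketch of the helper used by run_server above)."""
    sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
    if sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT) == 0:
        raise RuntimeError("Failed to set SO_REUSEPORT.")
    sock.bind(('', port))
    try:
        yield sock.getsockname()[1]
    finally:
        sock.close()
```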