提交 51e21362 编写于 作者: B barrierye

add ProcessPipelineServicer

上级 23664ff8
......@@ -37,7 +37,8 @@ class ChannelDataEcode(enum.Enum):
TYPE_ERROR = 3
RPC_PACKAGE_ERROR = 4
CLIENT_ERROR = 5
UNKNOW = 6
CLOSED_ERROR = 6
UNKNOW = 7
class ChannelDataType(enum.Enum):
......@@ -258,6 +259,8 @@ class ProcessChannel(object):
break
except Queue.Full:
self._cv.wait()
if self._stop:
raise ChannelStopError()
_LOGGER.debug(
self._log("{} channel size: {}".format(op_name,
self._que.qsize())))
......@@ -308,6 +311,8 @@ class ProcessChannel(object):
break
except Queue.Empty:
self._cv.wait()
if self._stop:
raise ChannelStopError()
_LOGGER.debug(
self._log("multi | {} push data succ!".format(op_name)))
......@@ -337,6 +342,8 @@ class ProcessChannel(object):
"{} wait for empty queue(with channel empty: {})".
format(op_name, self._que.empty())))
self._cv.wait()
if self._stop:
raise ChannelStopError()
_LOGGER.debug(
self._log("{} get data succ: {}".format(op_name, resp.__str__(
))))
......@@ -379,6 +386,8 @@ class ProcessChannel(object):
"{} wait for empty queue(with channel size: {})".
format(op_name, self._que.qsize())))
self._cv.wait()
if self._stop:
raise ChannelStopError()
consumer_cursor = self._consumer_cursors[op_name]
base_cursor = self._base_cursor.value
......@@ -425,10 +434,10 @@ class ProcessChannel(object):
return resp # reference, read only
def stop(self):
#TODO
self.close()
_LOGGER.info(self._log("stop."))
self._stop = True
self._cv.notify_all()
with self._cv:
self._cv.notify_all()
class ThreadChannel(Queue.Queue):
......@@ -527,6 +536,8 @@ class ThreadChannel(Queue.Queue):
break
except Queue.Full:
self._cv.wait()
if self._stop:
raise ChannelStopError()
self._cv.notify_all()
_LOGGER.debug(self._log("{} push data succ!".format(op_name)))
return True
......@@ -565,6 +576,8 @@ class ThreadChannel(Queue.Queue):
break
except Queue.Empty:
self._cv.wait()
if self._stop:
raise ChannelStopError()
_LOGGER.debug(
self._log("multi | {} push data succ!".format(op_name)))
......@@ -587,6 +600,8 @@ class ThreadChannel(Queue.Queue):
break
except Queue.Empty:
self._cv.wait()
if self._stop:
raise ChannelStopError()
_LOGGER.debug(
self._log("{} get data succ: {}".format(op_name, resp.__str__(
))))
......@@ -617,6 +632,8 @@ class ThreadChannel(Queue.Queue):
break
except Queue.Empty:
self._cv.wait()
if self._stop:
raise ChannelStopError()
consumer_cursor = self._consumer_cursors[op_name]
base_cursor = self._base_cursor
......@@ -657,7 +674,12 @@ class ThreadChannel(Queue.Queue):
return resp
def stop(self):
#TODO
self.close()
_LOGGER.info(self._log("stop."))
self._stop = True
self._cv.notify_all()
with self._cv:
self._cv.notify_all()
class ChannelStopError(RuntimeError):
def __init__(self):
pass
......@@ -26,7 +26,8 @@ import os
import logging
from .operator import Op, RequestOp, ResponseOp, VirtualOp
from .channel import ThreadChannel, ProcessChannel, ChannelData, ChannelDataEcode, ChannelDataType
from .channel import (ThreadChannel, ProcessChannel, ChannelData,
ChannelDataEcode, ChannelDataType, ChannelStopError)
from .profiler import TimeProfiler
from .util import NameGenerator
......@@ -34,33 +35,29 @@ _LOGGER = logging.getLogger()
class DAGExecutor(object):
def __init__(self, response_op, yml_config, show_info):
self._retry = yml_config.get('retry', 1)
def __init__(self, response_op, dag_config, show_info):
self._retry = dag_config.get('retry', 1)
client_type = yml_config.get('client_type', 'brpc')
use_multithread = yml_config.get('use_multithread', True)
use_profile = yml_config.get('profile', False)
channel_size = yml_config.get('channel_size', 0)
self._asyn_profile = yml_config.get('asyn_profile', False)
client_type = dag_config.get('client_type', 'brpc')
use_profile = dag_config.get('use_profile', False)
channel_size = dag_config.get('channel_size', 0)
self._is_thread_op = dag_config.get('is_thread_op', True)
if show_info and use_profile:
_LOGGER.info("================= PROFILER ================")
if use_multithread:
if self._is_thread_op:
_LOGGER.info("op: thread")
_LOGGER.info("profile mode: sync")
else:
_LOGGER.info("op: process")
if self._asyn_profile:
_LOGGER.info("profile mode: asyn (This mode is only used"
" when using the process version Op)")
else:
_LOGGER.info("profile mode: sync")
_LOGGER.info("profile mode: asyn")
_LOGGER.info("-------------------------------------------")
self.name = "@G"
self._profiler = TimeProfiler()
self._profiler.enable(use_profile)
self._dag = DAG(self.name, response_op, use_profile, use_multithread,
self._dag = DAG(self.name, response_op, use_profile, self._is_thread_op,
client_type, channel_size, show_info)
(in_channel, out_channel, pack_rpc_func,
unpack_rpc_func) = self._dag.build()
......@@ -80,17 +77,14 @@ class DAGExecutor(object):
self._cv_pool = {}
self._cv_for_cv_pool = threading.Condition()
self._fetch_buffer = None
self._is_run = False
self._recive_func = None
def start(self):
self._is_run = True
self._recive_func = threading.Thread(
target=DAGExecutor._recive_out_channel_func, args=(self, ))
self._recive_func.start()
def stop(self):
self._is_run = False
self._dag.stop()
self._dag.join()
......@@ -119,8 +113,22 @@ class DAGExecutor(object):
def _recive_out_channel_func(self):
cv = None
while self._is_run:
channeldata_dict = self._out_channel.front(self.name)
while True:
try:
channeldata_dict = self._out_channel.front(self.name)
except ChannelStopError:
_LOGGER.info(self._log("stop."))
with self._cv_for_cv_pool:
for data_id, cv in self._cv_pool.items():
closed_errror_data = ChannelData(
ecode=ChannelDataEcode.CLOSED_ERROR.value,
error_info="dag closed.",
data_id=data_id)
with cv:
self._fetch_buffer = closed_errror_data
cv.notify_all()
break
if len(channeldata_dict) != 1:
_LOGGER.error("out_channel cannot have multiple input ops")
os._exit(-1)
......@@ -147,7 +155,6 @@ class DAGExecutor(object):
cv.wait()
_LOGGER.debug("resp func get lock (data_id: {})".format(data_id))
resp = copy.deepcopy(self._fetch_buffer)
# cv.notify_all()
with self._cv_for_cv_pool:
self._cv_pool.pop(data_id)
return resp
......@@ -170,7 +177,7 @@ class DAGExecutor(object):
def call(self, rpc_request):
data_id = self._get_next_data_id()
if self._asyn_profile:
if not self._is_thread_op:
self._profiler.record("call_{}#DAG-{}_0".format(data_id, data_id))
else:
self._profiler.record("call_{}#DAG_0".format(data_id))
......@@ -183,7 +190,15 @@ class DAGExecutor(object):
for i in range(self._retry):
_LOGGER.debug(self._log('push data'))
#self._profiler.record("push_{}#{}_0".format(data_id, self.name))
self._in_channel.push(req_channeldata, self.name)
try:
self._in_channel.push(req_channeldata, self.name)
except ChannelStopError:
_LOGGER.info(self._log("stop."))
return self._pack_for_rpc_resp(
ChannelData(
ecode=ChannelDataEcode.CLOSED_ERROR.value,
error_info="dag closed.",
data_id=data_id))
#self._profiler.record("push_{}#{}_1".format(data_id, self.name))
_LOGGER.debug(self._log('wait for infer'))
......@@ -201,7 +216,7 @@ class DAGExecutor(object):
rpc_resp = self._pack_for_rpc_resp(resp_channeldata)
self._profiler.record("postpack_{}#{}_1".format(data_id, self.name))
if self._asyn_profile:
if not self._is_thread_op:
self._profiler.record("call_{}#DAG-{}_1".format(data_id, data_id))
else:
self._profiler.record("call_{}#DAG_1".format(data_id))
......@@ -217,16 +232,16 @@ class DAGExecutor(object):
class DAG(object):
def __init__(self, request_name, response_op, use_profile, use_multithread,
def __init__(self, request_name, response_op, use_profile, is_thread_op,
client_type, channel_size, show_info):
self._request_name = request_name
self._response_op = response_op
self._use_profile = use_profile
self._use_multithread = use_multithread
self._is_thread_op = is_thread_op
self._channel_size = channel_size
self._client_type = client_type
self._show_info = show_info
if not self._use_multithread:
if not self._is_thread_op:
self._manager = multiprocessing.Manager()
def get_use_ops(self, response_op):
......@@ -254,7 +269,7 @@ class DAG(object):
def _gen_channel(self, name_gen):
channel = None
if self._use_multithread:
if self._is_thread_op:
channel = ThreadChannel(
name=name_gen.next(), maxsize=self._channel_size)
else:
......@@ -439,7 +454,7 @@ class DAG(object):
self._threads_or_proces = []
for op in self._actual_ops:
op.use_profiler(self._use_profile)
if self._use_multithread:
if self._is_thread_op:
self._threads_or_proces.extend(
op.start_with_thread(self._client_type))
else:
......@@ -453,7 +468,5 @@ class DAG(object):
x.join()
def stop(self):
for op in self._actual_ops:
op.stop()
for chl in self._channels:
chl.stop()
......@@ -24,7 +24,8 @@ import sys
from numpy import *
from .proto import pipeline_service_pb2
from .channel import ThreadChannel, ProcessChannel, ChannelDataEcode, ChannelData, ChannelDataType
from .channel import (ThreadChannel, ProcessChannel, ChannelDataEcode,
ChannelData, ChannelDataType, ChannelStopError)
from .util import NameGenerator
from .profiler import TimeProfiler
......@@ -44,7 +45,6 @@ class Op(object):
retry=1):
if name is None:
name = _op_name_gen.next()
self._is_run = False
self.name = name # to identify the type of OP, it must be globally unique
self.concurrency = concurrency # amount of concurrency
self.set_input_ops(input_ops)
......@@ -65,7 +65,9 @@ class Op(object):
# only for multithread
self._for_init_op_lock = threading.Lock()
self._for_close_op_lock = threading.Lock()
self._succ_init_op = False
self._succ_close_op = False
def use_profiler(self, use_profile):
self._use_profile = use_profile
......@@ -93,9 +95,6 @@ class Op(object):
self._fetch_names = fetch_names
return client
def _get_input_channel(self):
return self._input
def get_input_ops(self):
return self._input_ops
......@@ -118,8 +117,11 @@ class Op(object):
channel.add_consumer(self.name)
self._input = channel
def _get_output_channels(self):
return self._outputs
def _clean_input_channel(self):
self._input = None
def _get_input_channel(self):
return self._input
def add_output_channel(self, channel):
if not isinstance(channel, (ThreadChannel, ProcessChannel)):
......@@ -129,6 +131,12 @@ class Op(object):
channel.add_producer(self.name)
self._outputs.append(channel)
def _clean_output_channels(self):
self._outputs = []
def _get_output_channels(self):
return self._outputs
def preprocess(self, input_dicts):
# multiple previous Op
if len(input_dicts) != 1:
......@@ -144,7 +152,7 @@ class Op(object):
if err != 0:
raise NotImplementedError(
"{} Please override preprocess func.".format(err_info))
call_result = self._client.predict(
call_result = self.client.predict(
feed=feed_dict, fetch=self._fetch_names)
_LOGGER.debug(self._log("get call_result"))
return call_result
......@@ -152,9 +160,6 @@ class Op(object):
def postprocess(self, input_dict, fetch_dict):
return fetch_dict
def stop(self):
self._is_run = False
def _parse_channeldata(self, channeldata_dict):
data_id, error_channeldata = None, None
parsed_data = {}
......@@ -332,30 +337,42 @@ class Op(object):
self._profiler = TimeProfiler()
self._profiler.enable(self._use_profile)
# init client
self._client = self.init_client(
self.client = self.init_client(
client_type, self._client_config,
self._server_endpoints, self._fetch_names)
# user defined
self.init_op()
self._succ_init_op = True
self._succ_close_op = False
else:
# init profiler
self._profiler = TimeProfiler()
self._profiler.enable(self._use_profile)
# init client
self._client = self.init_client(
client_type, self._client_config, self._server_endpoints,
self._fetch_names)
self.client = self.init_client(client_type, self._client_config,
self._server_endpoints,
self._fetch_names)
# user defined
self.init_op()
except Exception as e:
_LOGGER.error(log(e))
os._exit(-1)
self._is_run = True
while self._is_run:
while True:
#self._profiler_record("get#{}_0".format(op_info_prefix))
channeldata_dict = input_channel.front(self.name)
try:
channeldata_dict = input_channel.front(self.name)
except ChannelStopError:
_LOGGER.info(log("stop."))
with self._for_close_op_lock:
if not self._succ_close_op:
self._clean_input_channel()
self._clean_output_channels()
self._profiler = None
self.client = None
self._succ_init_op = False
self._succ_close_op = True
break
#self._profiler_record("get#{}_1".format(op_info_prefix))
_LOGGER.debug(log("input_data: {}".format(channeldata_dict)))
......@@ -363,8 +380,11 @@ class Op(object):
channeldata_dict)
# error data in predecessor Op
if error_channeldata is not None:
self._push_to_output_channels(error_channeldata,
output_channels)
try:
self._push_to_output_channels(error_channeldata,
output_channels)
except ChannelStopError:
_LOGGER.info(log("stop."))
continue
# preprecess
......@@ -373,8 +393,11 @@ class Op(object):
data_id, log)
self._profiler_record("prep#{}_1".format(op_info_prefix))
if error_channeldata is not None:
self._push_to_output_channels(error_channeldata,
output_channels)
try:
self._push_to_output_channels(error_channeldata,
output_channels)
except ChannelStopError:
_LOGGER.info(log("stop."))
continue
# process
......@@ -383,8 +406,11 @@ class Op(object):
data_id, log)
self._profiler_record("midp#{}_1".format(op_info_prefix))
if error_channeldata is not None:
self._push_to_output_channels(error_channeldata,
output_channels)
try:
self._push_to_output_channels(error_channeldata,
output_channels)
except ChannelStopError:
_LOGGER.info(log("stop."))
continue
# postprocess
......@@ -393,8 +419,11 @@ class Op(object):
parsed_data, midped_data, data_id, log)
self._profiler_record("postp#{}_1".format(op_info_prefix))
if error_channeldata is not None:
self._push_to_output_channels(error_channeldata,
output_channels)
try:
self._push_to_output_channels(error_channeldata,
output_channels)
except ChannelStopError:
_LOGGER.info(log("stop."))
continue
if self._use_profile:
......@@ -405,7 +434,11 @@ class Op(object):
# push data to channel (if run succ)
#self._profiler_record("push#{}_0".format(op_info_prefix))
self._push_to_output_channels(output_data, output_channels)
try:
self._push_to_output_channels(output_data, output_channels)
except ChannelStopError:
_LOGGER.info(log("stop."))
break
#self._profiler_record("push#{}_1".format(op_info_prefix))
#self._profiler.print_profile()
......@@ -525,10 +558,17 @@ class VirtualOp(Op):
log = get_log_func(op_info_prefix)
tid = threading.current_thread().ident
self._is_run = True
while self._is_run:
channeldata_dict = input_channel.front(self.name)
while True:
try:
channeldata_dict = input_channel.front(self.name)
except ChannelStopError:
_LOGGER.info(log("stop."))
break
for name, data in channeldata_dict.items():
self._push_to_output_channels(
data, channels=output_channels, name=name)
try:
for name, data in channeldata_dict.items():
self._push_to_output_channels(
data, channels=output_channels, name=name)
except ChannelStopError:
_LOGGER.info(log("stop."))
break
......@@ -29,16 +29,36 @@ _LOGGER = logging.getLogger()
class PipelineService(pipeline_service_pb2_grpc.PipelineServiceServicer):
def __init__(self, response_op, yml_config, show_info=True):
def __init__(self, response_op, dag_config):
super(PipelineService, self).__init__()
# init dag executor
self._dag_executor = DAGExecutor(response_op, yml_config, show_info)
self._dag_executor = DAGExecutor(
response_op, dag_config, show_info=True)
self._dag_executor.start()
def inference(self, request, context):
resp = self._dag_executor.call(request)
return resp
def __del__(self):
self._dag_executor.stop()
class ProcessPipelineService(pipeline_service_pb2_grpc.PipelineServiceServicer):
def __init__(self, response_op, dag_config):
super(ProcessPipelineService, self).__init__()
self._response_op = response_op
self._dag_config = dag_config
def inference(self, request, context):
# init dag executor
dag_executor = DAGExecutor(
self._response_op, self._dag_config, show_info=False)
dag_executor.start()
resp = dag_executor.call(request)
dag_executor.stop()
return resp
@contextlib.contextmanager
def _reserve_port(port):
......@@ -75,25 +95,30 @@ class PipelineServer(object):
def prepare_server(self, yml_file):
with open(yml_file) as f:
self._yml_config = yaml.load(f.read())
self._port = self._yml_config.get('port', 8080)
yml_config = yaml.load(f.read())
self._port = yml_config.get('port')
if self._port is None:
raise SystemExit("Please set *port* in [{}] yaml file.".format(
yml_file))
if not self._port_is_available(self._port):
raise SystemExit("Prot {} is already used".format(self._port))
self._worker_num = self._yml_config.get('worker_num', 2)
self._multiprocess_servicer = self._yml_config.get(
'multiprocess_servicer', False)
self._worker_num = yml_config.get('worker_num', 1)
self._build_dag_each_request = yml_config.get('build_dag_each_request',
False)
_LOGGER.info("============= PIPELINE SERVER =============")
_LOGGER.info("port: {}".format(self._port))
_LOGGER.info("worker_num: {}".format(self._worker_num))
servicer_info = "multiprocess_servicer: {}".format(
self._multiprocess_servicer)
if self._multiprocess_servicer is True:
servicer_info = "build_dag_each_request: {}".format(
self._build_dag_each_request)
if self._build_dag_each_request is True:
servicer_info += " (Make sure that install grpcio whl with --no-binary flag)"
_LOGGER.info(servicer_info)
_LOGGER.info("-------------------------------------------")
self._dag_config = yml_config.get("dag", {})
def run_server(self):
if self._multiprocess_servicer:
if self._build_dag_each_request:
with _reserve_port(self._port) as port:
bind_address = 'localhost:{}'.format(port)
workers = []
......@@ -101,8 +126,8 @@ class PipelineServer(object):
show_info = (i == 0)
worker = multiprocessing.Process(
target=self._run_server_func,
args=(bind_address, self._response_op, self._yml_config,
self._worker_num, show_info))
args=(bind_address, self._response_op, self._dag_config,
self._worker_num))
worker.start()
workers.append(worker)
for worker in workers:
......@@ -111,20 +136,20 @@ class PipelineServer(object):
server = grpc.server(
futures.ThreadPoolExecutor(max_workers=self._worker_num))
pipeline_service_pb2_grpc.add_PipelineServiceServicer_to_server(
PipelineService(self._response_op, self._yml_config), server)
PipelineService(self._response_op, self._dag_config), server)
server.add_insecure_port('[::]:{}'.format(self._port))
server.start()
server.wait_for_termination()
def _run_server_func(self, bind_address, response_op, yml_config,
worker_num, show_info):
def _run_server_func(self, bind_address, response_op, dag_config,
worker_num):
options = (('grpc.so_reuseport', 1), )
server = grpc.server(
futures.ThreadPoolExecutor(
max_workers=worker_num, ),
options=options)
pipeline_service_pb2_grpc.add_PipelineServiceServicer_to_server(
PipelineService(response_op, yml_config, show_info), server)
ProcessPipelineService(response_op, dag_config), server)
server.add_insecure_port(bind_address)
server.start()
server.wait_for_termination()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册