From 048b22bbd49050bbf53ffe824e585af7f717430f Mon Sep 17 00:00:00 2001
From: barrierye
Date: Wed, 8 Jul 2020 03:30:13 +0800
Subject: [PATCH] support: start servicer with process

---
 python/pipeline/dag.py             | 42 ++++++++++------
 python/pipeline/pipeline_server.py | 78 +++++++++++++++++++++++++-----
 2 files changed, 93 insertions(+), 27 deletions(-)

diff --git a/python/pipeline/dag.py b/python/pipeline/dag.py
index f162d25d..deeb87ed 100644
--- a/python/pipeline/dag.py
+++ b/python/pipeline/dag.py
@@ -34,7 +34,7 @@ _LOGGER = logging.getLogger()


 class DAGExecutor(object):
-    def __init__(self, response_op, yml_config):
+    def __init__(self, response_op, yml_config, show_info):
         self._retry = yml_config.get('retry', 1)

         client_type = yml_config.get('client_type', 'brpc')
@@ -43,24 +43,25 @@ class DAGExecutor(object):
         channel_size = yml_config.get('channel_size', 0)
         self._asyn_profile = yml_config.get('asyn_profile', False)

-        if use_profile:
-            _LOGGER.info("====> profiler <====")
+        if show_info and use_profile:
+            _LOGGER.info("================= PROFILER ================")
             if use_multithread:
                 _LOGGER.info("op: thread")
             else:
                 _LOGGER.info("op: process")
             if self._asyn_profile:
-                _LOGGER.info("profile mode: asyn")
+                _LOGGER.info("profile mode: asyn (this mode only takes"
+                             " effect when Ops run as processes)")
             else:
                 _LOGGER.info("profile mode: sync")
-            _LOGGER.info("====================")
+            _LOGGER.info("-------------------------------------------")

         self.name = "@G"
         self._profiler = TimeProfiler()
         self._profiler.enable(use_profile)

-        self._dag = DAG(response_op, use_profile, use_multithread, client_type,
-                        channel_size)
+        self._dag = DAG(self.name, response_op, use_profile, use_multithread,
+                        client_type, channel_size, show_info)
         (in_channel, out_channel, pack_rpc_func,
          unpack_rpc_func) = self._dag.build()
         self._dag.start()
@@ -216,13 +217,15 @@ class DAGExecutor(object):


 class DAG(object):
-    def __init__(self, response_op, use_profile, use_multithread, client_type,
-                 channel_size):
+    def __init__(self, request_name, response_op, use_profile, use_multithread,
+                 client_type, channel_size, show_info):
+        self._request_name = request_name
         self._response_op = response_op
         self._use_profile = use_profile
         self._use_multithread = use_multithread
         self._channel_size = channel_size
         self._client_type = client_type
+        self._show_info = show_info
         if not self._use_multithread:
             self._manager = multiprocessing.Manager()
@@ -306,10 +309,12 @@ class DAG(object):
         if response_op is None:
             raise Exception("response_op has not been set.")
         used_ops, out_degree_ops = self.get_use_ops(response_op)
-        _LOGGER.info("================= use op ==================")
-        for op in used_ops:
-            _LOGGER.info(op.name)
-        _LOGGER.info("===========================================")
+        if self._show_info:
+            _LOGGER.info("================= USED OP =================")
+            for op in used_ops:
+                if op.name != self._request_name:
+                    _LOGGER.info(op.name)
+            _LOGGER.info("-------------------------------------------")
         if len(used_ops) <= 1:
             raise Exception(
                 "Besides RequestOp and ResponseOp, there should be at least one Op in DAG."
@@ -317,6 +322,16 @@ class DAG(object):
         dag_views, last_op = self._topo_sort(used_ops, response_op,
                                              out_degree_ops)
+        dag_views = list(reversed(dag_views))
+        if self._show_info:
+            _LOGGER.info("================== DAG ====================")
+            for idx, view in enumerate(dag_views):
+                _LOGGER.info("(VIEW {})".format(idx))
+                for op in view:
+                    _LOGGER.info("  [{}]".format(op.name))
+                    for out_op in out_degree_ops[op.name]:
+                        _LOGGER.info("    - {}".format(out_op.name))
+            _LOGGER.info("-------------------------------------------")

         # create channels and virtual ops
         virtual_op_name_gen = NameGenerator("vir")
@@ -325,7 +340,6 @@ class DAG(object):
         channels = []
         input_channel = None
         actual_view = None
-        dag_views = list(reversed(dag_views))
         for v_idx, view in enumerate(dag_views):
             if v_idx + 1 >= len(dag_views):
                 break
diff --git a/python/pipeline/pipeline_server.py b/python/pipeline/pipeline_server.py
index f1a07f7c..e41556e0 100644
--- a/python/pipeline/pipeline_server.py
+++ b/python/pipeline/pipeline_server.py
@@ -16,7 +16,9 @@ from concurrent import futures
 import grpc
 import logging
 import socket
+import contextlib
 from contextlib import closing
+import multiprocessing
 import yaml

 from .proto import pipeline_service_pb2_grpc
@@ -27,9 +29,10 @@ _LOGGER = logging.getLogger()


 class PipelineService(pipeline_service_pb2_grpc.PipelineServiceServicer):
-    def __init__(self, dag_executor):
+    def __init__(self, response_op, yml_config, show_info=True):
         super(PipelineService, self).__init__()
-        self._dag_executor = dag_executor
+        # init dag executor
+        self._dag_executor = DAGExecutor(response_op, yml_config, show_info)
         self._dag_executor.start()

     def inference(self, request, context):
@@ -37,6 +40,20 @@
         return resp


+@contextlib.contextmanager
+def _reserve_port(port):
+    """Find and reserve a port for all subprocesses to use."""
+    sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
+    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
+    if sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT) == 0:
+        raise RuntimeError("Failed to set SO_REUSEPORT.")
+    sock.bind(('', port))
+    try:
+        yield sock.getsockname()[1]
+    finally:
+        sock.close()
+
+
 class PipelineServer(object):
     def __init__(self):
         self._port = None
@@ -58,21 +75,56 @@ class PipelineServer(object):

     def prepare_server(self, yml_file):
         with open(yml_file) as f:
-            yml_config = yaml.load(f.read())
-        self._port = yml_config.get('port', 8080)
+            self._yml_config = yaml.load(f.read())
+        self._port = self._yml_config.get('port', 8080)
         if not self._port_is_available(self._port):
             raise SystemExit("Port {} is already used".format(self._port))
-        self._worker_num = yml_config.get('worker_num', 2)
-
-        # init dag executor
-        self._dag_executor = DAGExecutor(self._response_op, yml_config)
+        self._worker_num = self._yml_config.get('worker_num', 2)
+        self._multiprocess_servicer = self._yml_config.get(
+            'multiprocess_servicer', False)
+        _LOGGER.info("============= PIPELINE SERVER =============")
+        _LOGGER.info("port: {}".format(self._port))
+        _LOGGER.info("worker_num: {}".format(self._worker_num))
+        servicer_info = "multiprocess_servicer: {}".format(
+            self._multiprocess_servicer)
+        if self._multiprocess_servicer is True:
+            servicer_info += " (make sure grpcio is installed with the --no-binary flag)"
+        _LOGGER.info(servicer_info)
+        _LOGGER.info("-------------------------------------------")

     def run_server(self):
-        service = PipelineService(self._dag_executor)
+        if self._multiprocess_servicer:
+            with _reserve_port(self._port) as port:
+                bind_address = 'localhost:{}'.format(port)
+                workers = []
+                for i in range(self._worker_num):
+                    show_info = (i == 0)
+                    worker = multiprocessing.Process(
+                        target=self._run_server_func,
+                        args=(bind_address, self._response_op, self._yml_config,
+                              self._worker_num, show_info))
+                    worker.start()
+                    workers.append(worker)
+                for worker in workers:
+                    worker.join()
+        else:
+            server = grpc.server(
+                futures.ThreadPoolExecutor(max_workers=self._worker_num))
+            pipeline_service_pb2_grpc.add_PipelineServiceServicer_to_server(
+                PipelineService(self._response_op, self._yml_config), server)
+            server.add_insecure_port('[::]:{}'.format(self._port))
+            server.start()
+            server.wait_for_termination()
+
+    def _run_server_func(self, bind_address, response_op, yml_config,
+                         worker_num, show_info):
+        options = (('grpc.so_reuseport', 1), )
         server = grpc.server(
-            futures.ThreadPoolExecutor(max_workers=self._worker_num))
-        pipeline_service_pb2_grpc.add_PipelineServiceServicer_to_server(service,
-                                                                        server)
-        server.add_insecure_port('[::]:{}'.format(self._port))
+            futures.ThreadPoolExecutor(
+                max_workers=worker_num, ),
+            options=options)
+        pipeline_service_pb2_grpc.add_PipelineServiceServicer_to_server(
+            PipelineService(response_op, yml_config, show_info), server)
+        server.add_insecure_port(bind_address)
         server.start()
         server.wait_for_termination()
-- 
GitLab
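
Usage note (editor's sketch, not part of the patch): the snippet below shows
how the multiprocess_servicer switch introduced above might be driven from
user code. The YAML keys (port, worker_num, multiprocess_servicer) are the
ones read by prepare_server(); the import path and the set_response_op() call
are assumptions based on this repo's python/pipeline package layout, and
build_response_op() is a hypothetical stand-in for the user code that wires up
the Op graph ending in a ResponseOp.

    # config.yml, as consumed by prepare_server():
    #   port: 18080
    #   worker_num: 4
    #   multiprocess_servicer: true
    from pipeline.pipeline_server import PipelineServer  # assumed import path

    response_op = build_response_op()  # hypothetical user DAG construction

    server = PipelineServer()
    server.set_response_op(response_op)  # assumed setter for self._response_op
    server.prepare_server("config.yml")
    # With multiprocess_servicer enabled, run_server() reserves the port via
    # _reserve_port(), starts worker_num processes that each build their own
    # PipelineService (and therefore their own DAGExecutor) on the shared
    # SO_REUSEPORT socket, and joins them; only the first worker
    # (show_info=True) prints the PROFILER / USED OP / DAG banners.
    # With it disabled, a single ThreadPoolExecutor-backed gRPC server runs.
    server.run_server()

The --no-binary reminder in the PIPELINE SERVER banner matters because
prebuilt grpcio wheels may be compiled without SO_REUSEPORT support;
installing grpcio from source (pip install --no-binary grpcio grpcio) lets the
grpc.so_reuseport server option actually share one port across the worker
processes.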