Commit 048b22bb authored by barrierye

support: start servicer with process

Parent 23cd3548
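This commit lets the pipeline servicer run as multiple gRPC server processes that share one port via SO_REUSEPORT, instead of a single process with a thread pool. The behavior is switched on by a new multiprocess_servicer key in the server YAML config. A minimal config sketch, assuming only the port and worker_num keys that prepare_server below already reads (values are illustrative):

    import yaml

    # Illustrative config for the updated server; these are the three keys
    # that prepare_server() reads in this commit.
    conf = yaml.load("""
    port: 8080
    worker_num: 4
    multiprocess_servicer: true
    """)
    assert conf['multiprocess_servicer'] is True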
@@ -34,7 +34,7 @@ _LOGGER = logging.getLogger()
 class DAGExecutor(object):
-    def __init__(self, response_op, yml_config):
+    def __init__(self, response_op, yml_config, show_info):
         self._retry = yml_config.get('retry', 1)
         client_type = yml_config.get('client_type', 'brpc')
@@ -43,24 +43,25 @@ class DAGExecutor(object):
         channel_size = yml_config.get('channel_size', 0)
         self._asyn_profile = yml_config.get('asyn_profile', False)
-        if use_profile:
-            _LOGGER.info("====> profiler <====")
+        if show_info and use_profile:
+            _LOGGER.info("================= PROFILER ================")
             if use_multithread:
                 _LOGGER.info("op: thread")
             else:
                 _LOGGER.info("op: process")
             if self._asyn_profile:
-                _LOGGER.info("profile mode: asyn")
+                _LOGGER.info("profile mode: asyn (This mode is only used"
+                             " when using the process version Op)")
             else:
                 _LOGGER.info("profile mode: sync")
-            _LOGGER.info("====================")
+            _LOGGER.info("-------------------------------------------")
         self.name = "@G"
         self._profiler = TimeProfiler()
         self._profiler.enable(use_profile)
-        self._dag = DAG(response_op, use_profile, use_multithread, client_type,
-                        channel_size)
+        self._dag = DAG(self.name, response_op, use_profile, use_multithread,
+                        client_type, channel_size, show_info)
         (in_channel, out_channel, pack_rpc_func,
          unpack_rpc_func) = self._dag.build()
         self._dag.start()
@@ -216,13 +217,15 @@ class DAGExecutor(object):
 class DAG(object):
-    def __init__(self, response_op, use_profile, use_multithread, client_type,
-                 channel_size):
+    def __init__(self, request_name, response_op, use_profile, use_multithread,
+                 client_type, channel_size, show_info):
+        self._request_name = request_name
         self._response_op = response_op
         self._use_profile = use_profile
         self._use_multithread = use_multithread
         self._channel_size = channel_size
         self._client_type = client_type
+        self._show_info = show_info
         if not self._use_multithread:
             self._manager = multiprocessing.Manager()
@@ -306,10 +309,12 @@ class DAG(object):
         if response_op is None:
             raise Exception("response_op has not been set.")
         used_ops, out_degree_ops = self.get_use_ops(response_op)
-        _LOGGER.info("================= use op ==================")
-        for op in used_ops:
-            _LOGGER.info(op.name)
-        _LOGGER.info("===========================================")
+        if self._show_info:
+            _LOGGER.info("================= USED OP =================")
+            for op in used_ops:
+                if op.name != self._request_name:
+                    _LOGGER.info(op.name)
+            _LOGGER.info("-------------------------------------------")
         if len(used_ops) <= 1:
             raise Exception(
                 "Besides RequestOp and ResponseOp, there should be at least one Op in DAG."
@@ -317,6 +322,16 @@ class DAG(object):
         dag_views, last_op = self._topo_sort(used_ops, response_op,
                                              out_degree_ops)
+        dag_views = list(reversed(dag_views))
+        if self._show_info:
+            _LOGGER.info("================== DAG ====================")
+            for idx, view in enumerate(dag_views):
+                _LOGGER.info("(VIEW {})".format(idx))
+                for op in view:
+                    _LOGGER.info(" [{}]".format(op.name))
+                    for out_op in out_degree_ops[op.name]:
+                        _LOGGER.info(" - {}".format(out_op.name))
+            _LOGGER.info("-------------------------------------------")

         # create channels and virtual ops
         virtual_op_name_gen = NameGenerator("vir")
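For reference, with a hypothetical three-op chain (read -> infer -> combine; the names are invented for illustration), the new block would log roughly:

    ================== DAG ====================
    (VIEW 0)
     [read]
     - infer
    (VIEW 1)
     [infer]
     - combine
    (VIEW 2)
     [combine]
    -------------------------------------------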
@@ -325,7 +340,6 @@ class DAG(object):
         channels = []
         input_channel = None
         actual_view = None
-        dag_views = list(reversed(dag_views))
         for v_idx, view in enumerate(dag_views):
             if v_idx + 1 >= len(dag_views):
                 break
...
@@ -16,7 +16,9 @@ from concurrent import futures
 import grpc
 import logging
 import socket
+import contextlib
 from contextlib import closing
+import multiprocessing
 import yaml
 from .proto import pipeline_service_pb2_grpc
@@ -27,9 +29,10 @@ _LOGGER = logging.getLogger()
 class PipelineService(pipeline_service_pb2_grpc.PipelineServiceServicer):
-    def __init__(self, dag_executor):
+    def __init__(self, response_op, yml_config, show_info=True):
         super(PipelineService, self).__init__()
-        self._dag_executor = dag_executor
+        # init dag executor
+        self._dag_executor = DAGExecutor(response_op, yml_config, show_info)
         self._dag_executor.start()

     def inference(self, request, context):
@@ -37,6 +40,20 @@ class PipelineService(pipeline_service_pb2_grpc.PipelineServiceServicer):
         return resp

+
+@contextlib.contextmanager
+def _reserve_port(port):
+    """Find and reserve a port for all subprocesses to use."""
+    sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
+    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
+    if sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT) == 0:
+        raise RuntimeError("Failed to set SO_REUSEPORT.")
+    sock.bind(('', port))
+    try:
+        yield sock.getsockname()[1]
+    finally:
+        sock.close()
+

 class PipelineServer(object):
     def __init__(self):
         self._port = None
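_reserve_port follows the standard pattern for multi-process gRPC servers: the parent binds the port once with SO_REUSEPORT so that each forked worker can bind the same port afterwards. This needs SO_REUSEPORT support in both the kernel and the grpcio build, which is why the server log below warns about installing grpcio with the --no-binary flag (prebuilt wheels may lack it). A quick sanity check, as a sketch only:

    import socket

    # SO_REUSEPORT is exposed on Linux (kernel >= 3.9) but not on Windows;
    # without it the multiprocess servicer cannot share a port.
    if hasattr(socket, "SO_REUSEPORT"):
        print("SO_REUSEPORT available: multiprocess servicer can be used")
    else:
        print("SO_REUSEPORT unavailable: fall back to the threaded servicer")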
@@ -58,21 +75,56 @@ class PipelineServer(object):
     def prepare_server(self, yml_file):
         with open(yml_file) as f:
-            yml_config = yaml.load(f.read())
-        self._port = yml_config.get('port', 8080)
+            self._yml_config = yaml.load(f.read())
+        self._port = self._yml_config.get('port', 8080)
         if not self._port_is_available(self._port):
             raise SystemExit("Port {} is already used".format(self._port))
-        self._worker_num = yml_config.get('worker_num', 2)
-
-        # init dag executor
-        self._dag_executor = DAGExecutor(self._response_op, yml_config)
+        self._worker_num = self._yml_config.get('worker_num', 2)
+        self._multiprocess_servicer = self._yml_config.get(
+            'multiprocess_servicer', False)
+        _LOGGER.info("============= PIPELINE SERVER =============")
+        _LOGGER.info("port: {}".format(self._port))
+        _LOGGER.info("worker_num: {}".format(self._worker_num))
+        servicer_info = "multiprocess_servicer: {}".format(
+            self._multiprocess_servicer)
+        if self._multiprocess_servicer is True:
+            servicer_info += " (Make sure grpcio is installed with the --no-binary flag)"
+        _LOGGER.info(servicer_info)
+        _LOGGER.info("-------------------------------------------")

     def run_server(self):
-        service = PipelineService(self._dag_executor)
-        server = grpc.server(
-            futures.ThreadPoolExecutor(max_workers=self._worker_num))
-        pipeline_service_pb2_grpc.add_PipelineServiceServicer_to_server(service,
-                                                                        server)
-        server.add_insecure_port('[::]:{}'.format(self._port))
-        server.start()
-        server.wait_for_termination()
+        if self._multiprocess_servicer:
+            with _reserve_port(self._port) as port:
+                bind_address = 'localhost:{}'.format(port)
+                workers = []
+                for i in range(self._worker_num):
+                    show_info = (i == 0)
+                    worker = multiprocessing.Process(
+                        target=self._run_server_func,
+                        args=(bind_address, self._response_op, self._yml_config,
+                              self._worker_num, show_info))
+                    worker.start()
+                    workers.append(worker)
+                for worker in workers:
+                    worker.join()
+        else:
+            server = grpc.server(
+                futures.ThreadPoolExecutor(max_workers=self._worker_num))
+            pipeline_service_pb2_grpc.add_PipelineServiceServicer_to_server(
+                PipelineService(self._response_op, self._yml_config), server)
+            server.add_insecure_port('[::]:{}'.format(self._port))
+            server.start()
+            server.wait_for_termination()
+
+    def _run_server_func(self, bind_address, response_op, yml_config,
+                         worker_num, show_info):
+        options = (('grpc.so_reuseport', 1), )
+        server = grpc.server(
+            futures.ThreadPoolExecutor(
+                max_workers=worker_num, ),
+            options=options)
+        pipeline_service_pb2_grpc.add_PipelineServiceServicer_to_server(
+            PipelineService(response_op, yml_config, show_info), server)
+        server.add_insecure_port(bind_address)
+        server.start()
+        server.wait_for_termination()
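Putting it together: when multiprocess_servicer is enabled, run_server forks worker_num processes, each hosting its own thread-pool gRPC server on the shared port; otherwise the old single-process path is kept. A usage sketch under assumptions (the import path and the set_response_op setter are not part of this diff and may differ):

    from paddle_serving_server.pipeline import PipelineServer  # import path assumed

    server = PipelineServer()
    server.set_response_op(response_op)  # assumed setter for self._response_op
    server.prepare_server("config.yml")  # reads port / worker_num / multiprocess_servicer
    server.run_server()                  # forks workers when multiprocess_servicer is true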