提交 b92e2b77 编写于 作者: B barrierye

change `build DAG each request` to `build DAG each worker`

上级 b4029edb
port: 18080
worker_num: 1
build_dag_each_request: false
build_dag_each_worker: false
dag:
is_thread_op: true
client_type: brpc
......
......@@ -29,11 +29,11 @@ _LOGGER = logging.getLogger()
class PipelineService(pipeline_service_pb2_grpc.PipelineServiceServicer):
def __init__(self, response_op, dag_config):
def __init__(self, response_op, dag_config, show_info):
super(PipelineService, self).__init__()
# init dag executor
self._dag_executor = DAGExecutor(
response_op, dag_config, show_info=True)
response_op, dag_config, show_info=show_info)
self._dag_executor.start()
def inference(self, request, context):
......@@ -44,22 +44,6 @@ class PipelineService(pipeline_service_pb2_grpc.PipelineServiceServicer):
self._dag_executor.stop()
class ProcessPipelineService(pipeline_service_pb2_grpc.PipelineServiceServicer):
def __init__(self, response_op, dag_config):
super(ProcessPipelineService, self).__init__()
self._response_op = response_op
self._dag_config = dag_config
def inference(self, request, context):
# init dag executor
dag_executor = DAGExecutor(
self._response_op, self._dag_config, show_info=False)
dag_executor.start()
resp = dag_executor.call(request)
dag_executor.stop()
return resp
@contextlib.contextmanager
def _reserve_port(port):
"""Find and reserve a port for all subprocesses to use."""
......@@ -103,14 +87,14 @@ class PipelineServer(object):
if not self._port_is_available(self._port):
raise SystemExit("Prot {} is already used".format(self._port))
self._worker_num = yml_config.get('worker_num', 1)
self._build_dag_each_request = yml_config.get('build_dag_each_request',
False)
self._build_dag_each_worker = yml_config.get('build_dag_each_worker',
False)
_LOGGER.info("============= PIPELINE SERVER =============")
_LOGGER.info("port: {}".format(self._port))
_LOGGER.info("worker_num: {}".format(self._worker_num))
servicer_info = "build_dag_each_request: {}".format(
self._build_dag_each_request)
if self._build_dag_each_request is True:
servicer_info = "build_dag_each_worker: {}".format(
self._build_dag_each_worker)
if self._build_dag_each_worker is True:
servicer_info += " (Make sure that install grpcio whl with --no-binary flag)"
_LOGGER.info(servicer_info)
_LOGGER.info("-------------------------------------------")
......@@ -118,7 +102,7 @@ class PipelineServer(object):
self._dag_config = yml_config.get("dag", {})
def run_server(self):
if self._build_dag_each_request:
if self._build_dag_each_worker:
with _reserve_port(self._port) as port:
bind_address = 'localhost:{}'.format(port)
workers = []
......@@ -136,7 +120,8 @@ class PipelineServer(object):
server = grpc.server(
futures.ThreadPoolExecutor(max_workers=self._worker_num))
pipeline_service_pb2_grpc.add_PipelineServiceServicer_to_server(
PipelineService(self._response_op, self._dag_config), server)
PipelineService(self._response_op, self._dag_config, True),
server)
server.add_insecure_port('[::]:{}'.format(self._port))
server.start()
server.wait_for_termination()
......@@ -147,7 +132,7 @@ class PipelineServer(object):
futures.ThreadPoolExecutor(
max_workers=1, ), options=options)
pipeline_service_pb2_grpc.add_PipelineServiceServicer_to_server(
ProcessPipelineService(response_op, dag_config), server)
PipelineService(response_op, dag_config, False), server)
server.add_insecure_port(bind_address)
server.start()
server.wait_for_termination()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册