pipeline_server.py 5.8 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
15 16 17
from concurrent import futures
import grpc
import logging
B
barrierye 已提交
18
import socket
B
barrierye 已提交
19
import contextlib
B
barrierye 已提交
20
from contextlib import closing
B
barrierye 已提交
21
import multiprocessing
B
barrierye 已提交
22
import yaml
23

B
barrierye 已提交
24
from .proto import pipeline_service_pb2_grpc
B
barrierye 已提交
25
from .operator import ResponseOp
26
from .dag import DAGExecutor
27

W
wangjiawei04 已提交
28
# Module-wide logger. getLogger with no name hands back the root logger.
_LOGGER = logging.getLogger(name=None)
29 30


B
barrierye 已提交
31
class PipelineService(pipeline_service_pb2_grpc.PipelineServiceServicer):
    """gRPC servicer that shares one long-lived DAGExecutor across requests.

    The executor is built and started once in ``__init__`` and stopped in
    ``__del__``.
    """

    def __init__(self, response_op, dag_config):
        super(PipelineService, self).__init__()
        # init dag executor once; every request below reuses it
        self._dag_executor = DAGExecutor(
            response_op, dag_config, show_info=True)
        self._dag_executor.start()

    def inference(self, request, context):
        """Run ``request`` through the shared DAG and return its response.

        ``context`` (the gRPC call context) is not used.
        """
        resp = self._dag_executor.call(request)
        return resp

    def __del__(self):
        # If __init__ raised before the executor was assigned, there is
        # nothing to stop; avoid a secondary AttributeError during GC.
        dag_executor = getattr(self, "_dag_executor", None)
        if dag_executor is not None:
            dag_executor.stop()


class ProcessPipelineService(pipeline_service_pb2_grpc.PipelineServiceServicer):
    """gRPC servicer that builds a fresh DAGExecutor for every request.

    Only the response op and DAG config are stored. Each ``inference`` call
    creates, starts, uses and stops its own executor.
    """

    def __init__(self, response_op, dag_config):
        super(ProcessPipelineService, self).__init__()
        self._response_op = response_op
        self._dag_config = dag_config

    def inference(self, request, context):
        """Build a per-request DAG, run ``request`` through it, then tear it down.

        ``context`` (the gRPC call context) is not used.
        """
        # init dag executor for this request only
        dag_executor = DAGExecutor(
            self._response_op, self._dag_config, show_info=False)
        dag_executor.start()
        try:
            return dag_executor.call(request)
        finally:
            # Stop the executor even when call() raises; otherwise every
            # failed request would leak a running executor.
            dag_executor.stop()

62

B
barrierye 已提交
63 64 65 66 67 68 69 70 71 72 73 74 75 76
@contextlib.contextmanager
def _reserve_port(port):
    """Find and reserve a port for all subprocesses to use.

    Opens an IPv6 TCP socket with SO_REUSEPORT enabled, binds it to ``port``
    on all addresses and keeps it open for the duration of the ``with`` block.

    Yields:
        int: the port number the socket is actually bound to.

    Raises:
        RuntimeError: if SO_REUSEPORT could not be enabled on the socket.
    """
    sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
    try:
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        if sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT) == 0:
            raise RuntimeError("Failed to set SO_REUSEPORT.")
        sock.bind(('', port))
        yield sock.getsockname()[1]
    finally:
        # Close on every path, including setup failures such as bind()
        # raising, not only after the caller's with-block finishes.
        sock.close()


77
class PipelineServer(object):
    """Configure a pipeline gRPC server from a YAML file and run it.

    Expected call order: ``set_response_op(op)``, ``prepare_server(yml_file)``,
    then ``run_server()``.
    """

    def __init__(self):
        self._port = None
        self._worker_num = None
        self._response_op = None
        # Defaults so that run_server() reports a clear error (instead of an
        # AttributeError) when prepare_server() was never called.
        self._build_dag_each_request = False
        self._dag_config = {}

    def set_response_op(self, response_op):
        """Set the final op of the DAG.

        Raises:
            Exception: if ``response_op`` is not a ``ResponseOp`` or does not
                have exactly one input op.
        """
        if not isinstance(response_op, ResponseOp):
            raise Exception("response_op must be ResponseOp type.")
        if len(response_op.get_input_ops()) != 1:
            raise Exception("response_op can only have one previous op.")
        self._response_op = response_op

    def _port_is_available(self, port):
        """Return True if a local TCP connect to ``port`` fails (nothing listening)."""
        with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
            sock.settimeout(2)
            # connect_ex returns 0 only when the connection succeeds.
            result = sock.connect_ex(('0.0.0.0', port))
        return result != 0

    def prepare_server(self, yml_file):
        """Read server settings from ``yml_file`` and log them.

        Keys read: ``port`` (required), ``worker_num`` (default 1),
        ``build_dag_each_request`` (default False), ``dag`` (default {}).

        Raises:
            SystemExit: if ``port`` is missing or already in use.
        """
        with open(yml_file) as f:
            # safe_load: the config is plain data. yaml.load() without an
            # explicit Loader is unsafe and is a TypeError on PyYAML >= 6.
            # An empty file yields None; treat it as an empty mapping.
            yml_config = yaml.safe_load(f) or {}
        self._port = yml_config.get('port')
        if self._port is None:
            raise SystemExit("Please set *port* in [{}] yaml file.".format(
                yml_file))
        if not self._port_is_available(self._port):
            raise SystemExit("Port {} is already used".format(self._port))
        self._worker_num = yml_config.get('worker_num', 1)
        self._build_dag_each_request = yml_config.get('build_dag_each_request',
                                                      False)
        _LOGGER.info("============= PIPELINE SERVER =============")
        _LOGGER.info("port: %s", self._port)
        _LOGGER.info("worker_num: %s", self._worker_num)
        servicer_info = "build_dag_each_request: {}".format(
            self._build_dag_each_request)
        # Same truthiness test run_server() uses to pick the serving mode.
        if self._build_dag_each_request:
            servicer_info += " (Make sure that install grpcio whl with --no-binary flag)"
        _LOGGER.info(servicer_info)
        _LOGGER.info("-------------------------------------------")

        self._dag_config = yml_config.get("dag", {})

    def run_server(self):
        """Start serving and block until the server terminates.

        With ``build_dag_each_request`` set, ``worker_num`` processes each run
        their own gRPC server on the shared port. Otherwise a single
        server uses a thread pool of ``worker_num`` threads.

        Raises:
            RuntimeError: if ``prepare_server()`` has not completed.
        """
        if self._port is None or self._worker_num is None:
            raise RuntimeError(
                "Please call prepare_server() before run_server().")
        if self._build_dag_each_request:
            with _reserve_port(self._port) as port:
                # NOTE(review): this mode binds 'localhost' while the threaded
                # mode below binds '[::]' -- confirm that is intended.
                bind_address = 'localhost:{}'.format(port)
                workers = []
                for _ in range(self._worker_num):
                    worker = multiprocessing.Process(
                        target=self._run_server_func,
                        args=(bind_address, self._response_op,
                              self._dag_config))
                    worker.start()
                    workers.append(worker)
                for worker in workers:
                    worker.join()
        else:
            server = grpc.server(
                futures.ThreadPoolExecutor(max_workers=self._worker_num))
            pipeline_service_pb2_grpc.add_PipelineServiceServicer_to_server(
                PipelineService(self._response_op, self._dag_config), server)
            server.add_insecure_port('[::]:{}'.format(self._port))
            server.start()
            server.wait_for_termination()

    def _run_server_func(self, bind_address, response_op, dag_config):
        """Worker-process entry point: serve on ``bind_address`` until terminated.

        Runs a single-threaded gRPC server with ``grpc.so_reuseport`` enabled,
        backed by a ``ProcessPipelineService``.
        """
        options = (('grpc.so_reuseport', 1), )
        server = grpc.server(
            futures.ThreadPoolExecutor(max_workers=1), options=options)
        pipeline_service_pb2_grpc.add_PipelineServiceServicer_to_server(
            ProcessPipelineService(response_op, dag_config), server)
        server.add_insecure_port(bind_address)
        server.start()
        server.wait_for_termination()