# pipeline_server.py
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
from concurrent import futures
import contextlib
from contextlib import closing
import logging
import multiprocessing
import socket

import grpc
import yaml

from .proto import pipeline_service_pb2_grpc
from .operator import ResponseOp
from .dag import DAGExecutor

# Module logger. Calling getLogger() without a name returns the root logger,
# so these records are handled by whatever is configured on the root logger.
_LOGGER = logging.getLogger(None)


class PipelineService(pipeline_service_pb2_grpc.PipelineServiceServicer):
    """gRPC servicer that forwards every request to a DAG executor.

    The executor is built from ``response_op`` and ``yml_config`` and is
    started immediately, while the servicer is being constructed.
    """

    def __init__(self, response_op, yml_config, show_info=True):
        super(PipelineService, self).__init__()
        executor = DAGExecutor(response_op, yml_config, show_info)
        self._dag_executor = executor
        executor.start()

    def inference(self, request, context):
        """Answer one ``inference`` RPC with the DAG executor's result.

        ``context`` is part of the servicer method signature but unused.
        """
        return self._dag_executor.call(request)


@contextlib.contextmanager
def _reserve_port(port):
    """Bind ``port`` with SO_REUSEPORT and yield the bound port number.

    The socket stays bound while the ``with`` block runs, so processes
    started inside it can bind the same port with SO_REUSEPORT as well.
    Passing ``0`` lets the OS choose a free port; the port actually bound
    is what gets yielded.

    Raises:
        RuntimeError: if SO_REUSEPORT is not available or cannot be enabled.
    """
    if not hasattr(socket, "SO_REUSEPORT"):
        raise RuntimeError("SO_REUSEPORT is not supported on this platform.")
    sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
    # Everything after socket creation sits inside try/finally so that a
    # failing setsockopt/check/bind does not leak the open socket.
    try:
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        if sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT) == 0:
            raise RuntimeError("Failed to set SO_REUSEPORT.")
        sock.bind(('', port))
        yield sock.getsockname()[1]
    finally:
        sock.close()


class PipelineServer(object):
    """Configure and run a gRPC server hosting ``PipelineService``.

    Call ``set_response_op`` and ``prepare_server`` before ``run_server``.
    """

    def __init__(self):
        self._port = None
        self._worker_num = None
        self._response_op = None
        # Filled in by prepare_server(); defined here so run_server() can
        # report a missing prepare_server() call instead of failing with
        # an AttributeError.
        self._yml_config = None
        self._multiprocess_servicer = None

    def set_response_op(self, response_op):
        """Set the op whose output is returned to clients.

        Args:
            response_op: a ``ResponseOp`` with exactly one input op.

        Raises:
            Exception: if ``response_op`` is not a ``ResponseOp`` or does
                not have exactly one input op.
        """
        if not isinstance(response_op, ResponseOp):
            raise Exception("response_op must be ResponseOp type.")
        if len(response_op.get_input_ops()) != 1:
            raise Exception("response_op can only have one previous op.")
        self._response_op = response_op

    def _port_is_available(self, port):
        """Return True if a local TCP connect to ``port`` does not succeed."""
        with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
            sock.settimeout(2)
            result = sock.connect_ex(('0.0.0.0', port))
        # connect_ex() returns 0 only when the connection succeeded, i.e.
        # something is already listening on that port.
        return result != 0

    def prepare_server(self, yml_file):
        """Load server settings from ``yml_file`` and log a summary.

        Keys read (default in parentheses): ``port`` (8080), ``worker_num``
        (2) and ``multiprocess_servicer`` (False). The whole mapping is
        kept and later passed to every ``PipelineService``.

        Args:
            yml_file: path of the YAML configuration file.

        Raises:
            SystemExit: if the configured port is already in use.
        """
        with open(yml_file) as f:
            # safe_load() instead of a Loader-less yaml.load(): PyYAML >= 6.0
            # rejects load() without an explicit Loader, and safe_load()
            # does not build arbitrary Python objects from YAML tags.
            # An empty file parses to None; treat it as an empty mapping.
            self._yml_config = yaml.safe_load(f) or {}
        self._port = self._yml_config.get('port', 8080)
        if not self._port_is_available(self._port):
            raise SystemExit("Port {} is already used".format(self._port))
        self._worker_num = self._yml_config.get('worker_num', 2)
        self._multiprocess_servicer = self._yml_config.get(
            'multiprocess_servicer', False)
        _LOGGER.info("============= PIPELINE SERVER =============")
        _LOGGER.info("port: {}".format(self._port))
        _LOGGER.info("worker_num: {}".format(self._worker_num))
        servicer_info = "multiprocess_servicer: {}".format(
            self._multiprocess_servicer)
        # Same truthiness test run_server() uses to pick the multi-process
        # path, so the hint appears whenever that path will be taken.
        if self._multiprocess_servicer:
            servicer_info += " (Make sure that install grpcio whl with --no-binary flag)"
        _LOGGER.info(servicer_info)
        _LOGGER.info("-------------------------------------------")

    def run_server(self):
        """Serve requests and block until the server(s) terminate.

        When ``multiprocess_servicer`` is enabled, ``worker_num`` processes
        are started. Each runs its own gRPC server on the same port, which
        is held with SO_REUSEPORT while they run. Otherwise a single gRPC
        server with a ``worker_num``-thread pool runs in this process.

        Raises:
            Exception: if ``prepare_server`` has not been called yet.
        """
        if self._yml_config is None:
            raise Exception("Please call prepare_server before run_server.")
        if self._multiprocess_servicer:
            with _reserve_port(self._port) as port:
                # NOTE(review): worker servers bind 'localhost', while the
                # single-process branch below listens on '[::]' (all
                # interfaces). Confirm that the multi-process server is
                # meant to be reachable only from the local host.
                bind_address = 'localhost:{}'.format(port)
                workers = []
                for i in range(self._worker_num):
                    # Only the first worker gets show_info=True (presumably
                    # so executor info is shown once, not per process).
                    show_info = (i == 0)
                    worker = multiprocessing.Process(
                        target=self._run_server_func,
                        args=(bind_address, self._response_op,
                              self._yml_config, self._worker_num,
                              show_info))
                    worker.start()
                    workers.append(worker)
                for worker in workers:
                    worker.join()
        else:
            server = grpc.server(
                futures.ThreadPoolExecutor(max_workers=self._worker_num))
            pipeline_service_pb2_grpc.add_PipelineServiceServicer_to_server(
                PipelineService(self._response_op, self._yml_config), server)
            server.add_insecure_port('[::]:{}'.format(self._port))
            server.start()
            server.wait_for_termination()

    def _run_server_func(self, bind_address, response_op, yml_config,
                         worker_num, show_info):
        """Body of one worker process: serve on ``bind_address`` until done.

        The ``grpc.so_reuseport`` option lets every worker bind the same
        address concurrently.
        """
        options = (('grpc.so_reuseport', 1), )
        server = grpc.server(
            futures.ThreadPoolExecutor(max_workers=worker_num),
            options=options)
        pipeline_service_pb2_grpc.add_PipelineServiceServicer_to_server(
            PipelineService(response_op, yml_config, show_info), server)
        server.add_insecure_port(bind_address)
        server.start()
        server.wait_for_termination()