pipeline_server.py 5.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
15 16 17
from concurrent import futures
import grpc
import logging
B
barrierye 已提交
18
import socket
B
barrierye 已提交
19
import contextlib
B
barrierye 已提交
20
from contextlib import closing
B
barrierye 已提交
21
import multiprocessing
B
barrierye 已提交
22
import yaml
23

B
barrierye 已提交
24
from .proto import pipeline_service_pb2_grpc
B
barrierye 已提交
25
from .operator import ResponseOp
26
from .dag import DAGExecutor
27

W
wangjiawei04 已提交
28
_LOGGER = logging.getLogger()
29 30


B
barrierye 已提交
31
class PipelineService(pipeline_service_pb2_grpc.PipelineServiceServicer):
    """gRPC servicer that forwards every request to a DAGExecutor."""

    def __init__(self, response_op, dag_config, show_info):
        """Build and start the DAG executor that will serve requests.

        Args:
            response_op: the ResponseOp the DAG is built from.
            dag_config: the ``dag`` section of the server YAML config.
            show_info: forwarded to DAGExecutor as ``show_info``.
        """
        super(PipelineService, self).__init__()
        # init dag executor
        self._dag_executor = DAGExecutor(
            response_op, dag_config, show_info=show_info)
        self._dag_executor.start()

    def inference(self, request, context):
        """Handle one ``inference`` RPC by delegating to the DAG executor.

        ``context`` is the gRPC servicer context; it is not used here.
        """
        return self._dag_executor.call(request)

    def __del__(self):
        # If __init__ raised before _dag_executor was assigned, there is
        # nothing to stop; avoid a spurious AttributeError during GC.
        dag_executor = getattr(self, "_dag_executor", None)
        if dag_executor is not None:
            dag_executor.stop()


B
barrierye 已提交
47 48 49 50 51 52 53 54 55 56 57 58 59 60
@contextlib.contextmanager
def _reserve_port(port):
    """Find and reserve a port for all subprocesses to use."""
    sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
    if sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT) == 0:
        raise RuntimeError("Failed to set SO_REUSEPORT.")
    sock.bind(('', port))
    try:
        yield sock.getsockname()[1]
    finally:
        sock.close()


61
class PipelineServer(object):
    """Configure and run the pipeline gRPC server.

    Usage: ``set_response_op`` -> ``prepare_server(yml_file)`` -> ``run_server``.
    """

    def __init__(self):
        self._port = None
        self._worker_num = None
        self._response_op = None
        # Filled in by prepare_server(); initialised here so that every
        # attribute read by run_server() exists on a fresh instance.
        self._build_dag_each_worker = False
        self._dag_config = {}

    def set_response_op(self, response_op):
        """Register the ResponseOp the served DAG is built from.

        Raises:
            Exception: if ``response_op`` is not a ResponseOp, or it does not
                have exactly one input op.
        """
        if not isinstance(response_op, ResponseOp):
            raise Exception("response_op must be ResponseOp type.")
        if len(response_op.get_input_ops()) != 1:
            raise Exception("response_op can only have one previous op.")
        self._response_op = response_op

    def _port_is_available(self, port):
        """Return True if no local TCP listener accepts connections on ``port``."""
        with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
            sock.settimeout(2)
            # Use the loopback address explicitly: connecting to 0.0.0.0
            # only reaches local listeners on some platforms (e.g. Linux).
            result = sock.connect_ex(('127.0.0.1', port))
        return result != 0

    def prepare_server(self, yml_file):
        """Load server settings from the YAML file ``yml_file`` and log them.

        Reads ``port`` (required), ``worker_num`` (default 1),
        ``build_dag_each_worker`` (default False) and ``dag`` (default {}).

        Raises:
            SystemExit: if ``port`` is missing or already in use.
        """
        with open(yml_file) as f:
            # safe_load: yaml.load() without an explicit Loader is deprecated
            # (PyYAML 5.1), a TypeError on PyYAML >= 6, and can construct
            # arbitrary Python objects. An empty file loads as None.
            yml_config = yaml.safe_load(f) or {}
        self._port = yml_config.get('port')
        if self._port is None:
            raise SystemExit("Please set *port* in [{}] yaml file.".format(
                yml_file))
        if not self._port_is_available(self._port):
            raise SystemExit("Port {} is already used".format(self._port))
        self._worker_num = yml_config.get('worker_num', 1)
        self._build_dag_each_worker = yml_config.get('build_dag_each_worker',
                                                     False)

        _LOGGER.info("============= PIPELINE SERVER =============")
        _LOGGER.info("port: {}".format(self._port))
        _LOGGER.info("worker_num: {}".format(self._worker_num))
        servicer_info = "build_dag_each_worker: {}".format(
            self._build_dag_each_worker)
        if self._build_dag_each_worker:
            servicer_info += " (Make sure that install grpcio whl with --no-binary flag)"
        _LOGGER.info(servicer_info)
        _LOGGER.info("-------------------------------------------")

        # An explicit `dag:` key with no value loads as None; treat as empty.
        self._dag_config = yml_config.get("dag") or {}

    def run_server(self):
        """Start serving and block until the server(s) terminate.

        Raises:
            RuntimeError: if prepare_server() has not been called.
        """
        if self._port is None:
            raise RuntimeError(
                "prepare_server() must be called before run_server().")
        if self._build_dag_each_worker:
            # One process per worker, each building its own DAG; the port is
            # held via SO_REUSEPORT while the workers run.
            with _reserve_port(self._port) as port:
                # NOTE(review): workers bind loopback only, while the
                # single-process branch binds [::] -- confirm intended.
                bind_address = 'localhost:{}'.format(port)
                workers = []
                for i in range(self._worker_num):
                    # Only the first worker logs DAG info.
                    show_info = (i == 0)
                    worker = multiprocessing.Process(
                        target=self._run_server_func,
                        args=(bind_address, self._response_op,
                              self._dag_config, show_info))
                    worker.start()
                    workers.append(worker)
                for worker in workers:
                    worker.join()
        else:
            server = grpc.server(
                futures.ThreadPoolExecutor(max_workers=self._worker_num))
            pipeline_service_pb2_grpc.add_PipelineServiceServicer_to_server(
                PipelineService(self._response_op, self._dag_config, True),
                server)
            server.add_insecure_port('[::]:{}'.format(self._port))
            server.start()
            server.wait_for_termination()

    def _run_server_func(self, bind_address, response_op, dag_config,
                         show_info=False):
        """Worker-process entry: serve on ``bind_address`` with its own DAG."""
        # Let several worker processes bind the same address.
        options = (('grpc.so_reuseport', 1), )
        server = grpc.server(
            futures.ThreadPoolExecutor(max_workers=1), options=options)
        pipeline_service_pb2_grpc.add_PipelineServiceServicer_to_server(
            PipelineService(response_op, dag_config, show_info), server)
        server.add_insecure_port(bind_address)
        server.start()
        server.wait_for_termination()