# pipeline_server.py
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
from concurrent import futures
import contextlib
from contextlib import closing
import logging
import multiprocessing
import socket

import grpc
import yaml

from .proto import pipeline_service_pb2_grpc
from .operator import ResponseOp
from .dag import DAGExecutor

_LOGGER = logging.getLogger()
29 30


B
barriery 已提交
31
class PipelineServicer(pipeline_service_pb2_grpc.PipelineServiceServicer):
    """gRPC servicer that forwards every ``inference`` call to a DAGExecutor.

    Args:
        response_op: the ResponseOp that terminates the pipeline DAG.
        dag_config: DAG configuration passed to DAGExecutor (a dict when
            built by PipelineServer.prepare_server).
        show_info: forwarded to DAGExecutor as ``show_info``.
    """

    def __init__(self, response_op, dag_config, show_info):
        super(PipelineServicer, self).__init__()
        # Build the DAG executor and start it before any RPC is served.
        self._dag_executor = DAGExecutor(
            response_op, dag_config, show_info=show_info)
        self._dag_executor.start()
        _LOGGER.info("[PipelineServicer] succ init")

    def inference(self, request, context):
        """Handle one ``inference`` RPC by running ``request`` through the DAG.

        Returns:
            Whatever ``DAGExecutor.call`` produces for the request.
        """
        return self._dag_executor.call(request)

    def __del__(self):
        # If __init__ raised before the executor was assigned there is nothing
        # to stop; dereferencing the missing attribute would only emit a
        # confusing "Exception ignored in __del__" AttributeError.
        dag_executor = getattr(self, "_dag_executor", None)
        if dag_executor is not None:
            dag_executor.stop()


@contextlib.contextmanager
def _reserve_port(port):
    """Find and reserve a port for all subprocesses to use."""
    sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
    if sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT) == 0:
        raise RuntimeError("Failed to set SO_REUSEPORT.")
    sock.bind(('', port))
    try:
        yield sock.getsockname()[1]
    finally:
        sock.close()


62
class PipelineServer(object):
    """Serve a pipeline DAG over gRPC, configured from a YAML file.

    Usage: ``set_response_op(op)``, then ``prepare_server(yml_file)``, then
    ``run_server()`` (which blocks until the server terminates).
    """

    def __init__(self):
        self._port = None
        self._worker_num = None
        self._response_op = None
        # Filled in by prepare_server(); initialised here so the attributes
        # always exist.
        self._build_dag_each_worker = None
        self._dag_config = None

    def set_response_op(self, response_op):
        """Set the ResponseOp that terminates the DAG.

        Raises:
            Exception: if ``response_op`` is not a ResponseOp, or if it does
                not have exactly one input op.
        """
        if not isinstance(response_op, ResponseOp):
            raise Exception("response_op must be ResponseOp type.")
        if len(response_op.get_input_ops()) != 1:
            raise Exception("response_op can only have one previous op.")
        self._response_op = response_op

    def _port_is_available(self, port):
        """Return True if nothing accepts TCP connections on local ``port``."""
        with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
            sock.settimeout(2)
            # Probe loopback explicitly: 0.0.0.0 is not a portable
            # destination address for connect().
            result = sock.connect_ex(('127.0.0.1', port))
        return result != 0

    def prepare_server(self, yml_file):
        """Load and validate the server configuration from ``yml_file``.

        Top-level keys ``port``, ``worker_num`` and ``build_dag_each_worker``
        fall back to 9292, 1 and False (with a warning) when missing or null.
        The optional ``dag`` section is kept as the DAG configuration, with
        ``build_dag_each_worker`` copied into it.

        Raises:
            SystemExit: if the configured port is already in use.
        """
        with open(yml_file) as f:
            # safe_load: the config is plain data. yaml.load() without an
            # explicit Loader is unsafe and is rejected by PyYAML >= 6.
            yml_config = yaml.safe_load(f)
        if yml_config is None:
            # An empty YAML document parses to None.
            yml_config = {}

        default_config = {
            "port": 9292,
            "worker_num": 1,
            "build_dag_each_worker": False,
        }
        for key, val in default_config.items():
            if yml_config.get(key) is None:
                _LOGGER.warning("[CONF] {} not set, use default: {}"
                                .format(key, val))
                yml_config[key] = val

        self._port = yml_config["port"]
        if not self._port_is_available(self._port):
            raise SystemExit("Port {} is already used".format(self._port))
        self._worker_num = yml_config["worker_num"]
        self._build_dag_each_worker = yml_config["build_dag_each_worker"]

        _LOGGER.info("============= PIPELINE SERVER =============")
        for key in default_config:
            _LOGGER.info("{}: {}".format(key, yml_config[key]))
        # Same truthiness test run_server() uses to pick multi-process mode.
        if self._build_dag_each_worker:
            _LOGGER.info("(Make sure that install grpcio whl with --no-binary flag)")
        _LOGGER.info("-------------------------------------------")

        # A present-but-empty "dag:" section parses to None; treat it as {}.
        self._dag_config = yml_config.get("dag") or {}
        self._dag_config["build_dag_each_worker"] = self._build_dag_each_worker

    def run_server(self):
        """Start serving and block until the server terminates.

        With ``build_dag_each_worker`` set, ``worker_num`` processes are
        started, each running its own gRPC server and PipelineServicer on the
        reserved port (shared through SO_REUSEPORT). Otherwise one gRPC
        server with a ``worker_num``-thread pool and a single
        PipelineServicer is used.

        Raises:
            RuntimeError: if prepare_server() has not been called.
        """
        if self._port is None:
            raise RuntimeError(
                "prepare_server() must be called before run_server()")
        if self._build_dag_each_worker:
            with _reserve_port(self._port) as port:
                # NOTE(review): workers bind localhost only, while the
                # single-process branch binds '[::]' — confirm this is
                # intended.
                bind_address = 'localhost:{}'.format(port)
                workers = []
                for i in range(self._worker_num):
                    # Only the first worker's DAG executor shows info.
                    worker = multiprocessing.Process(
                        target=self._run_server_func,
                        args=(bind_address, self._response_op,
                              self._dag_config, i == 0))
                    worker.start()
                    workers.append(worker)
                for worker in workers:
                    worker.join()
        else:
            server = grpc.server(
                futures.ThreadPoolExecutor(max_workers=self._worker_num))
            pipeline_service_pb2_grpc.add_PipelineServiceServicer_to_server(
                PipelineServicer(self._response_op, self._dag_config, True),
                server)
            server.add_insecure_port('[::]:{}'.format(self._port))
            server.start()
            server.wait_for_termination()

    def _run_server_func(self, bind_address, response_op, dag_config,
                         show_info=False):
        """Worker-process entry point: serve one gRPC server on ``bind_address``.

        Each worker builds its own PipelineServicer (and hence its own DAG
        executor) behind a single-thread pool. ``grpc.so_reuseport`` lets all
        workers bind the same port.

        Args:
            bind_address: address passed to ``add_insecure_port``.
            response_op: the DAG's ResponseOp.
            dag_config: DAG configuration for the servicer.
            show_info: forwarded to PipelineServicer; run_server() passes
                True only for the first worker. Defaults to False.
        """
        options = (('grpc.so_reuseport', 1), )
        server = grpc.server(
            futures.ThreadPoolExecutor(max_workers=1), options=options)
        pipeline_service_pb2_grpc.add_PipelineServiceServicer_to_server(
            PipelineServicer(response_op, dag_config, show_info), server)
        server.add_insecure_port(bind_address)
        server.start()
        server.wait_for_termination()