pipeline_server.py 3.6 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
15 16 17
from concurrent import futures
import grpc
import logging
B
barrierye 已提交
18 19 20
import socket
from contextlib import closing
import yaml
21

B
barrierye 已提交
22
from .proto import pipeline_service_pb2_grpc
23
from .operator import Op
24
from .profiler import TimeProfiler
25
from .dag import DAGExecutor
26

W
wangjiawei04 已提交
27
# Module-wide logger; getLogger with no name (None) is the root logger.
_LOGGER = logging.getLogger(None)
# Single profiler instance for this module: PipelineServer.prepare_server
# enables/disables it and passes it to the DAG executor.
_profiler = TimeProfiler()


B
barrierye 已提交
31
class PipelineService(pipeline_service_pb2_grpc.PipelineServiceServicer):
    """gRPC servicer that delegates each ``inference`` call to a DAG executor."""

    def __init__(self, dag_executor):
        super(PipelineService, self).__init__()
        # Object whose ``call(request)`` produces the RPC response.
        self._dag_executor = dag_executor

    def inference(self, request, context):
        """Serve one ``inference`` RPC.

        The request is forwarded unchanged to the DAG executor and its result
        is returned as the response; ``context`` is not used here.
        """
        return self._dag_executor.call(request)


class PipelineServer(object):
    """Configure and run the gRPC front end of a pipeline DAG.

    Expected call order: ``set_response_op`` -> ``prepare_server`` ->
    ``run_server``.
    """

    def __init__(self):
        self._port = None
        self._worker_num = None
        self._response_op = None
        # Created by prepare_server(); initialised here so run_server() can
        # report a missing prepare step instead of raising AttributeError.
        self._dag_executor = None

    def gen_desc(self):
        """Placeholder for generating a PAAS service description; only logs."""
        _LOGGER.info('here will generate desc for PAAS')

    def set_response_op(self, response_op):
        """Register the last Op of the DAG, whose output is sent to clients.

        Raises:
            Exception: if ``response_op`` is not an ``Op`` or does not have
                exactly one input op.
        """
        if not isinstance(response_op, Op):
            raise Exception("response_op must be Op type.")
        if len(response_op.get_input_ops()) != 1:
            raise Exception("response_op can only have one previous op.")
        self._response_op = response_op

    def _port_is_available(self, port):
        """Return True if nothing accepts TCP connections on local ``port``."""
        with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
            sock.settimeout(2)
            # Probe loopback explicitly: a connect to 0.0.0.0 only means
            # "this host" on some platforms (it fails elsewhere, e.g. Windows),
            # which would make every port look free.
            result = sock.connect_ex(('127.0.0.1', port))
        return result != 0

    def prepare_server(self, yml_file):
        """Load the YAML config, validate it and start the DAG executor.

        Recognised keys (defaults): ``port`` (8080), ``worker_num`` (2),
        ``retry`` (1), ``client_type`` ('brpc'), ``use_multithread`` (True),
        ``profile`` (False), ``channel_size`` (0).

        Raises:
            Exception: if no response op was set, or if profiling is
                requested together with the multiprocess mode.
            SystemExit: if the configured port is already in use.
        """
        if self._response_op is None:
            raise Exception(
                "response_op is not set, call set_response_op() first.")

        with open(yml_file) as f:
            # safe_load: the config only holds plain maps/scalars, and
            # yaml.load without an explicit Loader can construct arbitrary
            # objects (and is rejected outright by PyYAML >= 6).
            # An empty file yields None, so fall back to all defaults.
            yml_config = yaml.safe_load(f) or {}

        self._port = yml_config.get('port', 8080)
        if not self._port_is_available(self._port):
            raise SystemExit("Port {} is already used".format(self._port))
        self._worker_num = yml_config.get('worker_num', 2)

        retry = yml_config.get('retry', 1)
        client_type = yml_config.get('client_type', 'brpc')
        use_multithread = yml_config.get('use_multithread', True)
        use_profile = yml_config.get('profile', False)
        channel_size = yml_config.get('channel_size', 0)

        if use_profile and not use_multithread:
            raise Exception(
                "profile cannot be used in multiprocess version temporarily")
        _profiler.enable(use_profile)

        # init dag executor
        self._dag_executor = DAGExecutor(self._response_op, _profiler,
                                         use_multithread, retry, client_type,
                                         channel_size)
        self._dag_executor.start()

        self.gen_desc()

    def run_server(self):
        """Serve ``PipelineService`` over insecure gRPC until termination.

        Blocks in ``wait_for_termination``.

        Raises:
            Exception: if ``prepare_server`` has not been called yet.
        """
        if self._dag_executor is None:
            raise Exception(
                "DAG executor is not ready, call prepare_server() first.")
        service = PipelineService(self._dag_executor)
        server = grpc.server(
            futures.ThreadPoolExecutor(max_workers=self._worker_num))
        pipeline_service_pb2_grpc.add_PipelineServiceServicer_to_server(service,
                                                                        server)
        server.add_insecure_port('[::]:{}'.format(self._port))
        server.start()
        server.wait_for_termination()