# pyserver.py
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
import logging
import multiprocessing
import os
import threading
import time
from concurrent import futures
try:
    import Queue  # Python 2
except ImportError:
    import queue as Queue  # Python 3

import grpc
import numpy as np
import paddle_serving_server
from paddle_serving_client import Client

import general_python_service_pb2
import general_python_service_pb2_grpc
import python_service_channel_pb2


class Channel(Queue.Queue):
    """Queue of batches that can be shared by several consumer Ops.

    :meth:`push` groups items into lists of ``batchsize`` before putting them
    on the underlying queue. With a single registered consumer, :meth:`front`
    just takes the next batch off the queue. With several consumers every
    consumer reads every batch: batches are cached in ``_frontbatch`` until
    the slowest consumer has read them.

    Args:
        maxsize: capacity of the underlying queue; ``Queue`` treats values
            <= 0 (the default -1) as unbounded.
        timeout: timeout in seconds for the blocking ``put``/``get`` calls;
            ``None`` blocks indefinitely.
        batchsize: number of pushed items grouped into one queue entry.
    """

    def __init__(self, maxsize=-1, timeout=None, batchsize=1):
        Queue.Queue.__init__(self, maxsize=maxsize)
        self._maxsize = maxsize
        self._timeout = timeout
        self._batchsize = batchsize
        self._pushlock = threading.Lock()  # guards _pushbatch
        self._frontlock = threading.Lock()  # guards consumer bookkeeping
        self._pushbatch = []  # items collected towards the next batch

        # Multi-consumer bookkeeping. Batches are numbered in arrival order;
        # _frontbatch[i] holds batch number _consumer_base_idx + i.
        self._consumer = {}  # {op_name: number of the next batch to read}
        self._consumer_base_idx = 0
        self._frontbatch = []
        self._idx_consumer_num = {}  # {batch number: consumers positioned there}

    def add_consumer(self, op_name):
        """Register ``op_name`` as a consumer of this channel.

        Not thread safe: call it while the pipeline is being built, before
        any thread calls :meth:`front`.

        Raises:
            ValueError: if ``op_name`` is already registered.
        """
        if op_name in self._consumer:
            raise ValueError("op_name({}) is already in channel".format(
                op_name))
        # A new consumer starts at the oldest batch still cached (batch 0
        # while the pipeline is being built).
        start_idx = self._consumer_base_idx
        self._consumer[op_name] = start_idx
        self._idx_consumer_num[start_idx] = \
            self._idx_consumer_num.get(start_idx, 0) + 1

    def push(self, item):
        """Append ``item`` to the pending batch; enqueue the batch once full.

        Items stay buffered until ``batchsize`` of them have been pushed.
        ``put`` uses the channel's timeout, so ``Queue.Full`` may be raised
        when a bounded queue stays full past that timeout.
        """
        with self._pushlock:
            self._pushbatch.append(item)
            if len(self._pushbatch) == self._batchsize:
                self.put(self._pushbatch, timeout=self._timeout)
                self._pushbatch = []

    def front(self, op_name=None):
        """Return the next batch (a list of pushed items) for ``op_name``.

        ``op_name`` may be omitted when the channel has exactly one consumer.
        With several consumers the returned list is shared between them and
        must be treated as read-only. Reads use the channel's timeout, so
        ``Queue.Empty`` may be raised when it expires.

        Raises:
            Exception: if no consumer has been registered.
            ValueError: if there are several consumers and ``op_name`` is not
                one of them.
        """
        if not self._consumer:
            raise Exception(
                "expected number of consumers to be greater than 0, but it is 0."
            )
        if len(self._consumer) == 1:
            # Nobody else needs the batch, so no caching is required.
            return self.get(timeout=self._timeout)
        if op_name not in self._consumer:
            raise ValueError(
                "op_name({}) is not a consumer of this channel".format(
                    op_name))

        with self._frontlock:
            consumer_idx = self._consumer[op_name]
            base_idx = self._consumer_base_idx
            data_idx = consumer_idx - base_idx

            if data_idx >= len(self._frontbatch):
                # This consumer is the furthest ahead: fetch a new batch.
                # NOTE(review): get() blocks while _frontlock is held, so
                # lagging consumers also wait even though their batch is cached.
                batch_data = self.get(timeout=self._timeout)
                self._frontbatch.append(batch_data)

            resp = self._frontbatch[data_idx]

            self._idx_consumer_num[consumer_idx] -= 1
            if consumer_idx == base_idx and \
                    self._idx_consumer_num[consumer_idx] == 0:
                # The last consumer of the oldest cached batch has read it.
                self._idx_consumer_num.pop(consumer_idx)
                self._frontbatch.pop(0)
                self._consumer_base_idx += 1

            new_consumer_idx = consumer_idx + 1
            self._consumer[op_name] = new_consumer_idx
            self._idx_consumer_num[new_consumer_idx] = \
                self._idx_consumer_num.get(new_consumer_idx, 0) + 1

        return resp  # reference, read only


class Op(object):
    """One stage of the pipeline.

    A worker thread running :meth:`start` repeatedly reads one batch from each
    input channel, runs ``preprocess`` -> ``midprocess`` (only when a serving
    client is configured) -> ``postprocess`` and pushes the result onto every
    output channel.

    Args:
        name: op name; also used as the consumer name on the input channels.
        inputs: list of input channels; the op registers itself on each.
        in_dtype: numpy dtype used by ``preprocess`` to decode inst bytes.
        outputs: list of output channels.
        out_dtype: stored as ``_out_dtype``; not read by this class.
        batchsize: number of insts ``preprocess`` expects per element.
        server_model, server_port, device: read by
            ``PyServer.prepare_serving``.
        client_config, server_name, fetch_names: when all three are given a
            Paddle Serving client is created (see :meth:`set_client`).
        concurrency: number of worker threads ``PyServer`` starts for this op.
    """

    def __init__(self,
                 name,
                 inputs,
                 in_dtype,
                 outputs,
                 out_dtype,
                 batchsize=1,
                 server_model=None,
                 server_port=None,
                 device=None,
                 client_config=None,
                 server_name=None,
                 fetch_names=None,
                 concurrency=1):
        self._run = False
        # TODO: globally unique check
        self._name = name  # identifies the OP type; must be globally unique
        self._concurrency = concurrency  # number of worker threads
        self.set_inputs(inputs)
        self._in_dtype = in_dtype
        self.set_outputs(outputs)
        self._out_dtype = out_dtype
        self._batch_size = batchsize
        self._client = None
        self._fetch_names = None  # filled in by set_client()
        if client_config is not None and \
                server_name is not None and \
                fetch_names is not None:
            self.set_client(client_config, server_name, fetch_names)
        self._server_model = server_model
        self._server_port = server_port
        self._device = device
        self._data_ids = []

    def set_client(self, client_config, server_name, fetch_names):
        """Create a Paddle Serving client connected to ``server_name``."""
        self._client = Client()
        self._client.load_client_config(client_config)
        self._client.connect([server_name])
        self._fetch_names = fetch_names

    def with_serving(self):
        """Return True when a serving client has been configured."""
        return self._client is not None

    def get_inputs(self):
        return self._inputs

    def set_inputs(self, channels):
        """Set the input channels and register this op on each of them.

        Raises:
            TypeError: if ``channels`` is not a list.
        """
        if not isinstance(channels, list):
            raise TypeError('channels must be list type')
        for channel in channels:
            channel.add_consumer(self._name)
        self._inputs = channels

    def get_outputs(self):
        return self._outputs

    def set_outputs(self, channels):
        """Set the output channels.

        Raises:
            TypeError: if ``channels`` is not a list.
        """
        if not isinstance(channels, list):
            raise TypeError('channels must be list type')
        self._outputs = channels

    def get_data_ids(self):
        return self._data_ids

    def clear_data_ids(self):
        self._data_ids = []

    def append_id_to_data_ids(self, data_id):
        self._data_ids.append(data_id)

    def preprocess(self, input_data):
        """Decode one input batch into a list of numpy feed dicts.

        ``start`` passes the batch read from the single input channel; ops
        with several input channels receive a list of batches and must
        override this method. Each element is expected to carry ``id`` and
        ``insts`` (objects with ``name`` and raw ``data`` bytes); every inst
        is decoded with ``np.frombuffer`` using ``in_dtype``. The element ids
        are recorded in ``_data_ids``.

        NOTE(review): the ``len(input_data) != 1`` check also fires for a
        single channel whose batchsize is > 1, and ``_batch_size`` is compared
        with the number of insts rather than the number of elements -- confirm
        this is intended. ``_data_ids`` is shared by all worker threads of
        this op when concurrency > 1.

        Raises:
            Exception: if ``input_data`` does not hold exactly one element, or
                an element does not hold ``batchsize`` insts.
        """
        if len(input_data) != 1:
            raise Exception(
                'this Op has multiple previous channels. Please override this method'
            )
        feed_batch = []
        self.clear_data_ids()
        for data in input_data:
            if len(data.insts) != self._batch_size:
                raise Exception('len(data_insts) != self._batch_size')
            feed = {}
            for inst in data.insts:
                feed[inst.name] = np.frombuffer(inst.data, dtype=self._in_dtype)
            feed_batch.append(feed)
            self.append_id_to_data_ids(data.id)
        return feed_batch

    def midprocess(self, data):
        """Run remote inference on ``data`` (the output of ``preprocess``).

        Returns whatever the serving client's ``predict`` returns for the
        configured ``fetch_names``. Only called when :meth:`with_serving` is
        True.
        """
        logging.debug('data: %s', data)
        logging.debug('fetch: %s', self._fetch_names)
        return self._client.predict(feed=data, fetch=self._fetch_names)

    def postprocess(self, output_data):
        """Hook that shapes results before they are pushed downstream.

        Identity by default. NOTE(review): GeneralPythonService reads ``.id``
        and ``.insts`` from each element it receives, so the op feeding the
        pipeline's final channel presumably has to override this.
        """
        return output_data

    def stop(self):
        """Make :meth:`start` exit at the top of its next loop iteration.

        A worker blocked inside ``Channel.front`` only notices after it
        receives data.
        """
        self._run = False

    def start(self):
        """Worker loop: read, process and forward batches until :meth:`stop`.

        Raises:
            ValueError: if the op has no input channel.
        """
        if not self._inputs:
            raise ValueError("Op({}) has no input channel".format(self._name))
        self._run = True
        while self._run:
            input_data = [
                channel.front(self._name) for channel in self._inputs
            ]
            if len(input_data) > 1:
                data = self.preprocess(input_data)
            else:
                data = self.preprocess(input_data[0])

            if self.with_serving():
                output_data = self.postprocess(self.midprocess(data))
            else:
                output_data = self.postprocess(data)

            for channel in self._outputs:
                channel.push(output_data)

    def get_concurrency(self):
        return self._concurrency


class GeneralPythonService(
        general_python_service_pb2_grpc.GeneralPythonService):
    """gRPC servicer that feeds client requests into the Op pipeline.

    ``inference`` packs each request into a ``ChannelData`` tagged with a
    fresh id and pushes it onto ``in_channel``. A background thread reads
    result batches from ``out_channel`` and stores them by id in
    ``_globel_resp_dict``; the request handler waits on ``_cv`` until its id
    appears.
    """

    # Consumer name this service registers on, and reads with from, the
    # output channel.
    _CONSUMER_NAME = "__GeneralPythonService__"

    def __init__(self, in_channel, out_channel):
        super(GeneralPythonService, self).__init__()
        self.set_in_channel(in_channel)
        self.set_out_channel(out_channel)
        # TODO:
        #  multi-lock for different clients
        #  different lock for server and client
        self._id_lock = threading.Lock()  # guards _id_counter
        self._cv = threading.Condition()  # guards _globel_resp_dict
        self._globel_resp_dict = {}  # {data_id: result batch}
        self._id_counter = 0
        self._recive_func = threading.Thread(
            target=self._recive_out_channel_func)
        # The receiver loops forever; as a daemon it does not keep the
        # process alive once the gRPC server has been stopped.
        self._recive_func.daemon = True
        self._recive_func.start()
        logging.debug('succ init')

    def set_in_channel(self, in_channel):
        self._in_channel = in_channel

    def set_out_channel(self, out_channel):
        """Register this service as a consumer of ``out_channel``.

        Raises:
            TypeError: if ``out_channel`` is a list (a single channel is
                expected).
        """
        if isinstance(out_channel, list):
            raise TypeError('out_channel can not be list type')
        out_channel.add_consumer(self._CONSUMER_NAME)
        self._out_channel = out_channel

    def _recive_out_channel_func(self):
        """Background loop: publish every result batch under its data id.

        Raises:
            Exception: if the elements of one batch carry different ids. This
                ends the thread, after which waiting requests never receive a
                result.
        """
        while True:
            # Channel.front needs the consumer name whenever the channel has
            # more than one consumer.
            data = self._out_channel.front(self._CONSUMER_NAME)
            data_id = None
            for d in data:
                if data_id is None:
                    data_id = d.id
                if data_id != d.id:
                    raise Exception("id not match: {} vs {}".format(data_id,
                                                                    d.id))
            with self._cv:
                self._globel_resp_dict[data_id] = data
                self._cv.notify_all()

    def _get_next_id(self):
        """Return a new request id (0, 1, 2, ...), thread-safely."""
        with self._id_lock:
            self._id_counter += 1
            return self._id_counter - 1

    def _get_data_in_globel_resp_dict(self, data_id):
        """Block until the result for ``data_id`` arrives, then remove and return it."""
        with self._cv:
            while data_id not in self._globel_resp_dict:
                self._cv.wait()
            return self._globel_resp_dict.pop(data_id)

    def _pack_data_for_infer(self, request):
        """Convert a gRPC request into a ``ChannelData`` with a fresh id.

        Returns:
            (data, data_id) tuple.

        Raises:
            ValueError: if ``feed_var_names`` and ``feed_insts`` differ in
                length.
        """
        logging.debug('start inferce')
        if len(request.feed_var_names) != len(request.feed_insts):
            raise ValueError(
                "feed_var_names and feed_insts differ in length: {} vs {}".
                format(len(request.feed_var_names), len(request.feed_insts)))
        data = python_service_channel_pb2.ChannelData()
        data_id = self._get_next_id()
        data.id = data_id
        for name, feed_inst in zip(request.feed_var_names,
                                   request.feed_insts):
            logging.debug('name: %s', name)
            logging.debug('data: %s', feed_inst)
            inst = python_service_channel_pb2.Inst()
            inst.data = feed_inst
            inst.name = name
            data.insts.append(inst)
        return data, data_id

    def _pack_data_for_resp(self, data):
        """Build a gRPC ``Response`` from the first element of a result batch."""
        data = data[0]  # TODO: only batchsize = 1 is supported
        logging.debug('get data')
        resp = general_python_service_pb2.Response()
        logging.debug('gen resp')
        logging.debug('%s', data)
        for inst in data.insts:
            logging.debug('append data')
            resp.fetch_insts.append(inst.data)
            logging.debug('append name')
            resp.fetch_var_names.append(inst.name)
        return resp

    def inference(self, request, context):
        """gRPC handler: push the request into the pipeline and wait for its result."""
        data, data_id = self._pack_data_for_infer(request)
        logging.debug('push data')
        self._in_channel.push(data)
        logging.debug('wait for infer')
        resp_data = self._get_data_in_globel_resp_dict(data_id)
        return self._pack_data_for_resp(resp_data)


class PyServer(object):
    """Wires Ops and Channels together and serves the pipeline over gRPC.

    Usage: ``add_op`` for every op, ``prepare_server(port, worker_num)``,
    then ``run_server()``.
    """

    # Sleep interval of the fallback wait loop; any long value works because
    # KeyboardInterrupt interrupts time.sleep.
    _IDLE_SLEEP_SECONDS = 60 * 60 * 24

    def __init__(self):
        self._channels = []
        self._ops = []
        self._op_threads = []
        self._port = None
        self._worker_num = None
        self._in_channel = None
        self._out_channel = None

    def add_channel(self, channel):
        # NOTE(review): _channels is not read anywhere in this class; the
        # pipeline's in/out channels are derived from the ops instead.
        self._channels.append(channel)

    def add_op(self, op):
        self._ops.append(op)

    def gen_desc(self):
        """Placeholder for generating a PaaS description; only logs for now."""
        logging.info('here will generate desc for paas')

    def prepare_server(self, port, worker_num):
        """Record server settings and find the pipeline's entry/exit channels.

        The entry channel is the one channel that is some op's input but no
        op's output; the exit channel is the reverse.

        Raises:
            Exception: unless there is exactly one entry and one exit channel.
        """
        self._port = port
        self._worker_num = worker_num
        inputs = set()
        outputs = set()
        for op in self._ops:
            inputs |= set(op.get_inputs())
            outputs |= set(op.get_outputs())
            if op.with_serving():
                self.prepare_serving(op)
        in_channel = inputs - outputs
        out_channel = outputs - inputs
        if len(in_channel) != 1 or len(out_channel) != 1:
            raise Exception(
                "in_channel(out_channel) more than 1 or no in_channel(out_channel)"
            )
        self._in_channel = in_channel.pop()
        self._out_channel = out_channel.pop()
        self.gen_desc()

    def _op_start_wrapper(self, op):
        return op.start()

    def _run_ops(self):
        """Start ``op.get_concurrency()`` worker threads for every op."""
        for op in self._ops:
            for _ in range(op.get_concurrency()):
                th = threading.Thread(
                    target=self._op_start_wrapper, args=(op, ))
                # Workers block forever in Channel.front; as daemons they do
                # not keep the process alive after the server shuts down.
                th.daemon = True
                th.start()
                self._op_threads.append(th)

    @staticmethod
    def _wait_for_termination(server):
        """Block until the gRPC server terminates (or forever on old grpc)."""
        wait = getattr(server, "wait_for_termination", None)
        if wait is not None:
            wait()
            return
        while True:
            time.sleep(PyServer._IDLE_SLEEP_SECONDS)

    def run_server(self):
        """Start the op workers and the gRPC server; block until Ctrl-C."""
        self._run_ops()
        server = grpc.server(
            futures.ThreadPoolExecutor(max_workers=self._worker_num))
        general_python_service_pb2_grpc.add_GeneralPythonServiceServicer_to_server(
            GeneralPythonService(self._in_channel, self._out_channel), server)
        server.add_insecure_port('[::]:{}'.format(self._port))
        server.start()
        try:
            for th in self._op_threads:
                # Join with a timeout: an untimed join cannot be interrupted
                # by KeyboardInterrupt on Python 2.
                while th.is_alive():
                    th.join(1)
            self._wait_for_termination(server)
        except KeyboardInterrupt:
            for op in self._ops:
                op.stop()
            server.stop(0)

    def prepare_serving(self, op):
        """Build the command that would launch ``op``'s Paddle Serving server.

        The command is only logged; launching it (``os.system(cmd)``) is
        currently disabled. ``device == "cpu"`` selects the CPU server, any
        other value the GPU server.
        """
        model_path = op._server_model
        port = op._server_port
        device = op._device

        # run a server (not in PyServing)
        if device == "cpu":
            cmd = "python -m paddle_serving_server.serve --model {} --thread 4 --port {} &>/dev/null &".format(
                model_path, port)
        else:
            cmd = "python -m paddle_serving_server_gpu.serve --model {} --thread 4 --port {} &>/dev/null &".format(
                model_path, port)
        logging.info(cmd)