# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
import threading
import multiprocessing
import Queue
import os
import paddle_serving_server
from paddle_serving_client import Client
from concurrent import futures
import numpy as np
import grpc
import general_python_service_pb2
import general_python_service_pb2_grpc
import python_service_channel_pb2
import logging
import random
import time


class Channel(Queue.Queue):
    """
    The channel used for communication between Ops.

    1. Support multiple different Op feed data (multiple producer)
        Different types of data will be packaged through the data ID
    2. Support multiple different Op fetch data (multiple consumer)
        Only when all types of Ops get the data of the same ID,
        the data will be popped; The Op of the same type will not
        get the data of the same ID.
    3. (TODO) Timeout and BatchSize are not fully supported.

    Note:
    1. The ID of the data in the channel must be different.
    2. The function add_producer() and add_consumer() are not thread safe,
       and can only be called during initialization.
    """

    def __init__(self, name=None, maxsize=-1, timeout=None, batchsize=1):
        """
        Args:
            name: label used as a prefix in log and error messages.
            maxsize: forwarded to Queue.Queue; values <= 0 mean unbounded.
            timeout: stored but not used by this class yet.
            batchsize: accepted for interface compatibility, not used yet.
        """
        Queue.Queue.__init__(self, maxsize=maxsize)
        self._maxsize = maxsize
        self._timeout = timeout
        self._name = name

        # Serializes push() and front(); waiters are woken by notify_all()
        # whenever the queue contents or the consumer cursors change.
        self._cv = threading.Condition()

        self._producers = []
        self._producer_res_count = {}  # {data_id: count}
        self._push_res = {}  # {data_id: {op_name: data}}

        self._front_wait_interval = 0.1  # second
        self._consumers = {}  # {op_name: idx}
        self._idx_consumer_num = {}  # {idx: num}
        self._consumer_base_idx = 0
        self._front_res = []

    def get_producers(self):
        """ Return the list of registered producer names. """
        return self._producers

    def get_consumers(self):
        """ Return the names of the registered consumers. """
        return self._consumers.keys()

    def _log(self, info_str):
        """ Prefix info_str with the channel name. """
        return "[{}] {}".format(self._name, info_str)

    def debug(self):
        """ Return a one-line summary of the producers and consumers. """
        return self._log("p: {}, c: {}".format(self.get_producers(),
                                               self.get_consumers()))

    def add_producer(self, op_name):
        """ not thread safe, and can only be called during initialization """
        if op_name in self._producers:
            raise ValueError(
                self._log("producer({}) is already in channel".format(op_name)))
        self._producers.append(op_name)

    def add_consumer(self, op_name):
        """ not thread safe, and can only be called during initialization """
        if op_name in self._consumers:
            raise ValueError(
                self._log("consumer({}) is already in channel".format(op_name)))
        # Every consumer starts at cursor 0.
        self._consumers[op_name] = 0
        self._idx_consumer_num[0] = self._idx_consumer_num.get(0, 0) + 1

    def _put_or_wait(self, item):
        """ Enqueue item, waiting on the condition while the queue is full.

        Must be called with self._cv held.
        """
        while True:
            try:
                self.put_nowait(item)
                return
            except Queue.Full:
                # A consumer notifies after it removes an element.
                self._cv.wait()

    def _get_or_wait(self):
        """ Dequeue one item, waiting on the condition while the queue is empty.

        Must be called with self._cv held.
        """
        while True:
            try:
                return self.get_nowait()
            except Queue.Empty:
                # A producer notifies after it adds an element.
                self._cv.wait()

    def push(self, data, op_name=None):
        """
        Feed data into the channel.

        With a single producer the data is enqueued directly. With several
        producers the pieces are grouped by data.id, and one
        {op_name: data} dict is enqueued once every producer has pushed
        its piece for that id.

        Args:
            data: the object to push; must expose an id attribute when
                the channel has more than one producer.
            op_name: name of the pushing producer; required when the
                channel has more than one producer.

        Returns:
            True once the data has been accepted.

        Raises:
            Exception: if the channel has no producer, or op_name is None
                while several producers are registered.
        """
        logging.debug(
            self._log("{} try to push data: {}".format(op_name, data)))
        if len(self._producers) == 0:
            raise Exception(
                self._log(
                    "expected number of producers to be greater than 0, but the it is 0."
                ))
        elif len(self._producers) == 1:
            with self._cv:
                self._put_or_wait(data)
                self._cv.notify_all()
            logging.debug(self._log("{} push data succ!".format(op_name)))
            return True
        elif op_name is None:
            raise Exception(
                self._log(
                    "There are multiple producers, so op_name cannot be None."))

        producer_num = len(self._producers)
        data_id = data.id
        put_data = None
        with self._cv:
            logging.debug(self._log("{} get lock ~".format(op_name)))
            if data_id not in self._push_res:
                self._push_res[data_id] = {
                    name: None
                    for name in self._producers
                }
                self._producer_res_count[data_id] = 0
            self._push_res[data_id][op_name] = data
            if self._producer_res_count[data_id] + 1 == producer_num:
                # Last missing piece: the merged dict is ready to be queued.
                put_data = self._push_res.pop(data_id)
                self._producer_res_count.pop(data_id)
            else:
                self._producer_res_count[data_id] += 1

            if put_data is None:
                logging.debug(
                    self._log("{} push data succ, not not push to queue.".
                              format(op_name)))
            else:
                self._put_or_wait(put_data)
                logging.debug(
                    self._log("multi | {} push data succ!".format(op_name)))
            self._cv.notify_all()
        return True

    def front(self, op_name=None):
        """
        Fetch the next piece of data for a consumer, blocking until one exists.

        With a single consumer the element is simply dequeued. With several
        consumers each dequeued element is cached in self._front_res until
        every consumer has read it; each consumer keeps its own cursor in
        self._consumers.

        Args:
            op_name: name of the fetching consumer; required when the
                channel has more than one consumer.

        Returns:
            The data object. In the multi-consumer case the same reference
            is handed to every consumer, so treat it as read only.

        Raises:
            Exception: if the channel has no consumer, or op_name is None
                while several consumers are registered.
        """
        logging.debug(self._log("{} try to get data".format(op_name)))
        if len(self._consumers) == 0:
            raise Exception(
                self._log(
                    "expected number of consumers to be greater than 0, but the it is 0."
                ))
        elif len(self._consumers) == 1:
            with self._cv:
                resp = self._get_or_wait()
                # A slot was freed; wake producers waiting on a full queue.
                self._cv.notify_all()
            logging.debug(self._log("{} get data succ!".format(op_name)))
            return resp
        elif op_name is None:
            raise Exception(
                self._log(
                    "There are multiple consumers, so op_name cannot be None."))

        with self._cv:
            # data_idx = consumer_idx - base_idx
            while self._consumers[op_name] - self._consumer_base_idx >= len(
                    self._front_res):
                try:
                    data = self.get_nowait()
                    self._front_res.append(data)
                    break
                except Queue.Empty:
                    self._cv.wait()

            consumer_idx = self._consumers[op_name]
            base_idx = self._consumer_base_idx
            data_idx = consumer_idx - base_idx
            resp = self._front_res[data_idx]
            logging.debug(self._log("{} get data: {}".format(op_name, resp)))

            # This consumer leaves position consumer_idx; once the oldest
            # cached element has no reader left it can be dropped.
            self._idx_consumer_num[consumer_idx] -= 1
            if consumer_idx == base_idx and self._idx_consumer_num[
                    consumer_idx] == 0:
                self._idx_consumer_num.pop(consumer_idx)
                self._front_res.pop(0)
                self._consumer_base_idx += 1

            # Advance the cursor and register it at the new position.
            self._consumers[op_name] += 1
            new_consumer_idx = self._consumers[op_name]
            self._idx_consumer_num[new_consumer_idx] = \
                self._idx_consumer_num.get(new_consumer_idx, 0) + 1

            self._cv.notify_all()

        logging.debug(self._log("multi | {} get data succ!".format(op_name)))
        return resp  # reference, read only


class Op(object):
    """
    A processing node of the pipeline.

    An Op consumes data from one input Channel, runs it through
    preprocess() -> midprocess() (only when a serving client is configured)
    -> postprocess(), and pushes the result to each of its output Channels.
    """

    def __init__(self,
                 name,
                 input,
                 in_dtype,
                 outputs,
                 out_dtype,
                 batchsize=1,
                 server_model=None,
                 server_port=None,
                 device=None,
                 client_config=None,
                 server_name=None,
                 fetch_names=None,
                 concurrency=1):
        """
        Register this Op on its channels and optionally connect a client.

        The client is only created when client_config, server_name and
        fetch_names are all given. batchsize is accepted but not used.
        """
        self._run = False
        # TODO: globally unique check
        # The name identifies the type of OP, it must be globally unique.
        self._name = name
        # amount of concurrency
        self._concurrency = concurrency
        self.set_input(input)
        self._in_dtype = in_dtype
        self.set_outputs(outputs)
        self._out_dtype = out_dtype
        self._client = None
        client_args = (client_config, server_name, fetch_names)
        if all(arg is not None for arg in client_args):
            self.set_client(client_config, server_name, fetch_names)
        self._server_model = server_model
        self._server_port = server_port
        self._device = device

    def set_client(self, client_config, server_name, fetch_names):
        """ Create a Client, load its config and connect it to server_name. """
        client = Client()
        self._client = client
        client.load_client_config(client_config)
        client.connect([server_name])
        self._fetch_names = fetch_names

    def with_serving(self):
        """ Tell whether a serving client has been configured. """
        return self._client is not None

    def get_input(self):
        """ Return the input channel. """
        return self._input

    def set_input(self, channel):
        """ Register this Op as a consumer of channel and use it as input.

        Raises:
            TypeError: if channel is not a Channel.
        """
        if isinstance(channel, Channel):
            channel.add_consumer(self._name)
            self._input = channel
            return
        raise TypeError(
            self._log('input channel must be Channel type, not {}'.format(
                type(channel))))

    def get_outputs(self):
        """ Return the list of output channels. """
        return self._outputs

    def set_outputs(self, channels):
        """ Register this Op as a producer on every channel in channels.

        Raises:
            TypeError: if channels is not a list.
        """
        if not isinstance(channels, list):
            raise TypeError(
                self._log('output channels must be list type, not {}'.format(
                    type(channels))))
        for out_channel in channels:
            out_channel.add_producer(self._name)
        self._outputs = channels

    def preprocess(self, data):
        """
        Convert a channel item into a feed dict.

        Each inst of data.insts becomes one entry, keyed by inst.name,
        whose value is inst.data decoded with np.frombuffer using the
        dtype passed to the constructor as in_dtype.

        Raises:
            Exception: if data is a dict (several previous Ops); such Ops
                must override this method.
        """
        if isinstance(data, dict):
            raise Exception(
                self._log(
                    'this Op has multiple previous inputs. Please override this method'
                ))
        return {
            inst.name: np.frombuffer(
                inst.data, dtype=self._in_dtype)
            for inst in data.insts
        }

    def midprocess(self, data):
        """
        Run the prediction on the serving client.

        Args:
            data: the feed dict produced by preprocess().

        Returns:
            Whatever the client's predict() returns.

        Raises:
            Exception: if data is not a dict.
        """
        if not isinstance(data, dict):
            raise Exception(
                self._log(
                    'data must be dict type(the output of preprocess()), but get {}'.
                    format(type(data))))
        for label, value in (('data', data), ('fetch', self._fetch_names)):
            logging.debug(self._log('{}: {}'.format(label, value)))
        result = self._client.predict(feed=data, fetch=self._fetch_names)
        logging.debug(self._log("finish predict"))
        return result

    def postprocess(self, output_data):
        """ Convert the processed data to ChannelData; must be overridden. """
        raise Exception(
            self._log(
                'Please override this method to convert data to the format in channel.'
            ))

    def stop(self):
        """ Ask the start() loop to end after its current iteration. """
        self._run = False

    def start(self):
        """
        Process input data in a loop until stop() is called.

        The id of the incoming data is copied onto the outgoing data.

        Raises:
            TypeError: if postprocess() does not return ChannelData.
        """
        self._run = True
        while self._run:
            input_data = self._input.front(self._name)
            logging.debug(self._log("input_data: {}".format(input_data)))
            if isinstance(input_data, dict):
                # Several upstream Ops: take the id from the first entry.
                first_key = input_data.keys()[0]
                data_id = input_data[first_key].id
            else:
                data_id = input_data.id

            processed = self.preprocess(input_data)
            if self.with_serving():
                processed = self.midprocess(processed)
            output_data = self.postprocess(processed)

            if not isinstance(output_data,
                              python_service_channel_pb2.ChannelData):
                raise TypeError(
                    self._log(
                        'output_data must be ChannelData type, but get {}'.
                        format(type(output_data))))
            output_data.id = data_id

            for out_channel in self._outputs:
                out_channel.push(output_data, self._name)

    def _log(self, info_str):
        """ Prefix info_str with the Op name. """
        return "[{}] {}".format(self._name, info_str)

    def get_concurrency(self):
        """ Return the concurrency passed to the constructor. """
        return self._concurrency


class GeneralPythonService(
        general_python_service_pb2_grpc.GeneralPythonService):
    """
    gRPC servicer that feeds requests into the Op graph.

    Each request is given an id, packed as ChannelData and pushed into the
    input channel. A background thread drains the output channel into a
    dict keyed by id, from which inference() picks up its response.
    """

    def __init__(self, in_channel, out_channel):
        super(GeneralPythonService, self).__init__()
        self._name = "__GeneralPythonService__"
        self.set_in_channel(in_channel)
        self.set_out_channel(out_channel)
        for channel in (in_channel, out_channel):
            logging.debug(self._log(channel.debug()))
        #TODO:
        #  multi-lock for different clients
        #  different lock for server and client
        self._id_lock = threading.Lock()
        self._cv = threading.Condition()
        self._globel_resp_dict = {}
        self._id_counter = 0
        # Receiver thread: moves finished results from the output channel
        # into self._globel_resp_dict.
        receiver = threading.Thread(
            target=GeneralPythonService._recive_out_channel_func, args=(self, ))
        self._recive_func = receiver
        receiver.start()

    def _log(self, info_str):
        """ Prefix info_str with the service name. """
        return "[{}] {}".format(self._name, info_str)

    def set_in_channel(self, in_channel):
        """ Register this service as a producer of in_channel.

        Raises:
            TypeError: if in_channel is not a Channel.
        """
        if isinstance(in_channel, Channel):
            in_channel.add_producer(self._name)
            self._in_channel = in_channel
            return
        raise TypeError(
            self._log('in_channel must be Channel type, but get {}'.format(
                type(in_channel))))

    def set_out_channel(self, out_channel):
        """ Register this service as a consumer of out_channel.

        Raises:
            TypeError: if out_channel is not a Channel.
        """
        if isinstance(out_channel, Channel):
            out_channel.add_consumer(self._name)
            self._out_channel = out_channel
            return
        raise TypeError(
            self._log('out_channel must be Channel type, but get {}'.format(
                type(out_channel))))

    def _recive_out_channel_func(self):
        """ Loop forever, storing each output item under its id.

        Raises:
            TypeError: if an item from the output channel is not ChannelData.
        """
        while True:
            item = self._out_channel.front(self._name)
            if not isinstance(item, python_service_channel_pb2.ChannelData):
                raise TypeError(
                    self._log('data must be ChannelData type, but get {}'.
                              format(type(item))))
            with self._cv:
                self._globel_resp_dict[item.id] = item
                # Wake the inference() calls waiting for their id.
                self._cv.notify_all()

    def _get_next_id(self):
        """ Return the current counter value and increment the counter. """
        with self._id_lock:
            next_id = self._id_counter
            self._id_counter = next_id + 1
            return next_id

    def _get_data_in_globel_resp_dict(self, data_id):
        """ Block until the result for data_id is stored, then remove and return it. """
        with self._cv:
            while data_id not in self._globel_resp_dict:
                self._cv.wait()
            result = self._globel_resp_dict.pop(data_id)
            self._cv.notify_all()
        return result

    def _pack_data_for_infer(self, request):
        """ Convert a request into ChannelData carrying a new id.

        Returns:
            A (data, data_id) tuple.
        """
        logging.debug(self._log('start inferce'))
        packed = python_service_channel_pb2.ChannelData()
        data_id = self._get_next_id()
        packed.id = data_id
        for idx, name in enumerate(request.feed_var_names):
            logging.debug(
                self._log('name: {}'.format(request.feed_var_names[idx])))
            logging.debug(self._log('data: {}'.format(request.feed_insts[idx])))
            inst = python_service_channel_pb2.Inst()
            inst.data = request.feed_insts[idx]
            inst.name = name
            packed.insts.append(inst)
        return packed, data_id

    def _pack_data_for_resp(self, data):
        """ Convert ChannelData into a Response message. """
        logging.debug(self._log('get data'))
        response = general_python_service_pb2.Response()
        logging.debug(self._log('gen resp'))
        logging.debug(data)
        for inst in data.insts:
            logging.debug(self._log('append data'))
            response.fetch_insts.append(inst.data)
            logging.debug(self._log('append name'))
            response.fetch_var_names.append(inst.name)
        return response

    def inference(self, request, context):
        """ Handle one request: push it in, wait for its result, reply. """
        data, data_id = self._pack_data_for_infer(request)
        logging.debug(self._log('push data'))
        self._in_channel.push(data, self._name)
        logging.debug(self._log('wait for infer'))
        resp_data = self._get_data_in_globel_resp_dict(data_id)
        return self._pack_data_for_resp(resp_data)


class PyServer(object):
    """
    Owns the Ops of a pipeline and exposes them through a gRPC server.

    Typical use: add_op() for every Op, then prepare_server(), then
    run_server().
    """

    def __init__(self):
        self._channels = []
        self._ops = []
        self._op_threads = []
        self._port = None
        self._worker_num = None
        self._in_channel = None
        self._out_channel = None

    def add_channel(self, channel):
        """ Record a channel. """
        self._channels.append(channel)

    def add_op(self, op):
        """ Record an Op to be run by run_server(). """
        self._ops.append(op)

    def gen_desc(self):
        """ Placeholder: only logs that a description would be generated. """
        logging.info('here will generate desc for paas')

    def prepare_server(self, port, worker_num):
        """
        Store the server settings and find the entry and exit channels.

        The entry channel is the one that is an Op input but no Op output;
        the exit channel is the one that is an Op output but no Op input.

        Args:
            port: port the gRPC server will listen on.
            worker_num: size of the gRPC thread pool.

        Raises:
            Exception: unless there is exactly one entry channel and
                exactly one exit channel.
        """
        self._port = port
        self._worker_num = worker_num
        inputs = set()
        outputs = set()
        for op in self._ops:
            inputs.add(op.get_input())
            outputs.update(op.get_outputs())
            if op.with_serving():
                self.prepare_serving(op)
        in_channel = inputs - outputs
        out_channel = outputs - inputs
        if len(in_channel) != 1 or len(out_channel) != 1:
            raise Exception(
                "in_channel(out_channel) more than 1 or no in_channel(out_channel)"
            )
        self._in_channel = in_channel.pop()
        self._out_channel = out_channel.pop()
        self.gen_desc()

    def _op_start_wrapper(self, op):
        """ Thread target that runs the processing loop of op. """
        return op.start()

    def _run_ops(self):
        """ Start get_concurrency() threads for every registered Op. """
        for op in self._ops:
            op_concurrency = op.get_concurrency()
            logging.debug("run op: {}, op_concurrency: {}".format(
                op._name, op_concurrency))
            for _ in range(op_concurrency):
                th = threading.Thread(
                    target=self._op_start_wrapper, args=(op, ))
                th.start()
                self._op_threads.append(th)

    def run_server(self):
        """
        Start the Op threads and the gRPC server, then block.

        A KeyboardInterrupt stops the gRPC server immediately.
        """
        self._run_ops()
        server = grpc.server(
            futures.ThreadPoolExecutor(max_workers=self._worker_num))
        general_python_service_pb2_grpc.add_GeneralPythonServiceServicer_to_server(
            GeneralPythonService(self._in_channel, self._out_channel), server)
        server.add_insecure_port('[::]:{}'.format(self._port))
        server.start()
        try:
            for th in self._op_threads:
                th.join()
            # grpc.Server has no join(); wait_for_termination() is the
            # blocking call it provides.
            server.wait_for_termination()
        except KeyboardInterrupt:
            server.stop(0)

    def prepare_serving(self, op):
        """
        Build and log the command that would launch the model server of op.

        The command is only logged: launching it (os.system(cmd)) is
        currently disabled, so the server has to be started outside of
        PyServing.
        """
        model_path = op._server_model
        port = op._server_port
        device = op._device

        # Any device other than "cpu" selects the GPU serving module.
        if device == "cpu":
            module = "paddle_serving_server"
        else:
            module = "paddle_serving_server_gpu"
        cmd = "python -m {}.serve --model {} --thread 4 --port {} &>/dev/null &".format(
            module, model_path, port)
        # run a server (not in PyServing)
        logging.info("run a server (not in PyServing): {}".format(cmd))
        return
        # os.system(cmd)