pyserver.py 19.1 KB
Newer Older
B
barrierye 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
import threading
import multiprocessing
B
barrierye 已提交
17
import Queue
B
barrierye 已提交
18 19 20 21
import os
import paddle_serving_server
from paddle_serving_client import Client
from concurrent import futures
B
barrierye 已提交
22
import numpy as np
B
barrierye 已提交
23 24 25
import grpc
import general_python_service_pb2
import general_python_service_pb2_grpc
B
barrierye 已提交
26
import python_service_channel_pb2
B
barrierye 已提交
27
import logging
28
import random
B
barrierye 已提交
29
import time
B
barrierye 已提交
30 31


B
barrierye 已提交
32
class Channel(Queue.Queue):
    """
    The channel used for communication between Ops.

    1. Support multiple different Op feed data (multiple producer)
        Different types of data will be packaged through the data ID
    2. Support multiple different Op fetch data (multiple consumer)
        Only when all types of Ops get the data of the same ID,
        the data will be popped; The Op of the same type will not
        get the data of the same ID.
    3. (TODO) Timeout and BatchSize are not fully supported.

    Note:
    1. The ID of the data in the channel must be different.
    2. The function add_producer() and add_consumer() are not thread safe,
       and can only be called during initialization.
    3. Blocking (queue full in push(), queue empty in front()) is done by
       waiting on ``self._cv``; every successful put/get notifies it so
       that waiters on the other side are woken up.
    """

    def __init__(self, name=None, maxsize=-1, timeout=None):
        """
        Args:
            name: channel name, used as a prefix in log/error messages.
            maxsize: capacity of the underlying queue (Queue semantics:
                <= 0 means unbounded).
            timeout: stored but not used yet (see TODO above).
        """
        Queue.Queue.__init__(self, maxsize=maxsize)
        self._maxsize = maxsize
        self._timeout = timeout
        self._name = name

        # Guards all bookkeeping below and is used for every blocking wait.
        self._cv = threading.Condition()

        self._producers = []
        self._producer_res_count = {}  # {data_id: count}
        self._push_res = {}  # {data_id: {op_name: data}}

        self._front_wait_interval = 0.1  # second (currently unused)
        self._consumers = {}  # {op_name: idx}
        self._idx_consumer_num = {}  # {idx: num}
        self._consumer_base_idx = 0  # idx of self._front_res[0]
        self._front_res = []  # data fetched but not yet read by all consumers

    def get_producers(self):
        return self._producers

    def get_consumers(self):
        return self._consumers.keys()

    def _log(self, info_str):
        return "[{}] {}".format(self._name, info_str)

    def debug(self):
        return self._log("p: {}, c: {}".format(self.get_producers(),
                                               self.get_consumers()))

    def add_producer(self, op_name):
        """ not thread safe, and can only be called during initialization. """
        if op_name in self._producers:
            raise ValueError(
                self._log("producer({}) is already in channel".format(op_name)))
        self._producers.append(op_name)

    def add_consumer(self, op_name):
        """ not thread safe, and can only be called during initialization. """
        if op_name in self._consumers:
            raise ValueError(
                self._log("consumer({}) is already in channel".format(op_name)))
        # every new consumer starts reading at index 0
        self._consumers[op_name] = 0
        self._idx_consumer_num[0] = self._idx_consumer_num.get(0, 0) + 1

    def _put_blocking(self, item):
        """Put ``item`` into the queue, waiting on ``self._cv`` while full.

        Must be called with ``self._cv`` held. A non-blocking put raises
        ``Queue.Full`` (not ``Queue.Empty``) when there is no free slot.
        """
        while True:
            try:
                self.put(item, block=False)
                return
            except Queue.Full:
                self._cv.wait()

    def push(self, data, op_name=None):
        """Push ``data`` produced by ``op_name`` into the channel.

        With a single producer, ``data`` is queued directly. With several
        producers, data is buffered by ``data.id`` until every producer has
        pushed for that id; then a dict {producer_name: data} is queued.
        Blocks while the queue is full.

        Returns:
            True once the data has been accepted.
        Raises:
            Exception: if no producer is registered, or if there are several
                producers and ``op_name`` is None.
        """
        logging.debug(
            self._log("{} try to push data: {}".format(op_name, data)))
        producer_num = len(self._producers)
        if producer_num == 0:
            raise Exception(
                self._log(
                    "expected number of producers to be greater than 0, but the it is 0."
                ))
        elif producer_num == 1:
            with self._cv:
                self._put_blocking(data)
                self._cv.notify_all()
            logging.debug(self._log("{} push data succ!".format(op_name)))
            return True
        elif op_name is None:
            raise Exception(
                self._log(
                    "There are multiple producers, so op_name cannot be None."))

        data_id = data.id
        put_data = None
        with self._cv:
            logging.debug(self._log("{} get lock ~".format(op_name)))
            if data_id not in self._push_res:
                self._push_res[data_id] = {
                    name: None
                    for name in self._producers
                }
                self._producer_res_count[data_id] = 0
            self._push_res[data_id][op_name] = data
            if self._producer_res_count[data_id] + 1 == producer_num:
                # last producer for this id: hand the whole group to the queue
                put_data = self._push_res.pop(data_id)
                self._producer_res_count.pop(data_id)
            else:
                self._producer_res_count[data_id] += 1

            if put_data is None:
                logging.debug(
                    self._log("{} push data succ, not not push to queue.".
                              format(op_name)))
            else:
                self._put_blocking(put_data)
                logging.debug(
                    self._log("multi | {} push data succ!".format(op_name)))
            self._cv.notify_all()
        return True

    def front(self, op_name=None):
        """Return the next item for consumer ``op_name``, blocking while empty.

        With several consumers, every consumer sees every item in order. An
        item is dropped only after all consumers have read it. The returned
        object is shared between consumers and must be treated as read-only.

        Raises:
            Exception: if no consumer is registered, or if there are several
                consumers and ``op_name`` is None.
        """
        logging.debug(self._log("{} try to get data".format(op_name)))
        if len(self._consumers) == 0:
            raise Exception(
                self._log(
                    "expected number of consumers to be greater than 0, but the it is 0."
                ))
        elif len(self._consumers) == 1:
            resp = None
            with self._cv:
                while True:
                    try:
                        resp = self.get(block=False)
                        break
                    except Queue.Empty:
                        self._cv.wait()
                # a slot was freed: wake producers blocked on a full queue
                self._cv.notify_all()
            logging.debug(self._log("{} get data succ!".format(op_name)))
            return resp
        elif op_name is None:
            raise Exception(
                self._log(
                    "There are multiple consumers, so op_name cannot be None."))

        with self._cv:
            # data_idx = consumer_idx - base_idx; fetch from the queue until
            # the item this consumer needs is buffered in self._front_res.
            # The condition is re-checked after each wait because another
            # consumer may have fetched the item in the meantime.
            while self._consumers[op_name] - self._consumer_base_idx >= len(
                    self._front_res):
                try:
                    self._front_res.append(self.get(block=False))
                except Queue.Empty:
                    self._cv.wait()

            consumer_idx = self._consumers[op_name]
            base_idx = self._consumer_base_idx
            data_idx = consumer_idx - base_idx
            resp = self._front_res[data_idx]
            logging.debug(self._log("{} get data: {}".format(op_name, resp)))

            self._idx_consumer_num[consumer_idx] -= 1
            if consumer_idx == base_idx and self._idx_consumer_num[
                    consumer_idx] == 0:
                # every consumer has read the oldest item: drop it
                self._idx_consumer_num.pop(consumer_idx)
                self._front_res.pop(0)
                self._consumer_base_idx += 1

            self._consumers[op_name] += 1
            new_consumer_idx = self._consumers[op_name]
            self._idx_consumer_num[new_consumer_idx] = \
                self._idx_consumer_num.get(new_consumer_idx, 0) + 1

            self._cv.notify_all()

        logging.debug(self._log("multi | {} get data succ!".format(op_name)))
        return resp  # reference, read only
B
barrierye 已提交
216 217 218 219


class Op(object):
    """One stage of the pipeline.

    An Op reads from one input Channel and runs ``preprocess``, then
    ``midprocess`` (only when a serving client is configured), then
    ``postprocess``. It stamps the result with the input's data id and
    pushes it to every output Channel.
    """

    def __init__(self,
                 name,
                 input,
                 in_dtype,
                 outputs,
                 out_dtype,
                 server_model=None,
                 server_port=None,
                 device=None,
                 client_config=None,
                 server_name=None,
                 fetch_names=None,
                 concurrency=1):
        """
        Args:
            name: Op name, also used as its producer/consumer id in channels.
            input: input Channel (this Op registers as its consumer).
            in_dtype: dtype passed to np.frombuffer in preprocess().
            outputs: list of output Channels (this Op registers as producer).
            out_dtype: stored; not used inside this class.
            server_model, server_port, device: read by PyServer.prepare_serving.
            client_config, server_name, fetch_names: when all three are given,
                a serving Client is created and midprocess() is enabled.
            concurrency: number of threads PyServer runs start() on.
        """
        self._run = False
        # TODO: globally unique check
        self._name = name  # to identify the type of OP, it must be globally unique
        self._concurrency = concurrency  # amount of concurrency
        self.set_input(input)
        self._in_dtype = in_dtype
        self.set_outputs(outputs)
        self._out_dtype = out_dtype
        self._client = None
        self._fetch_names = None  # set by set_client()
        if client_config is not None and \
                server_name is not None and \
                fetch_names is not None:
            self.set_client(client_config, server_name, fetch_names)
        self._server_model = server_model
        self._server_port = server_port
        self._device = device

    def set_client(self, client_config, server_name, fetch_names):
        """Create a serving Client used by midprocess()."""
        self._client = Client()
        self._client.load_client_config(client_config)
        # presumably server_name is an "ip:port" endpoint -- confirm
        self._client.connect([server_name])
        self._fetch_names = fetch_names

    def with_serving(self):
        """True if a serving client is configured (midprocess() runs)."""
        return self._client is not None

    def get_input(self):
        return self._input

    def set_input(self, channel):
        """Attach ``channel`` as input and register this Op as its consumer.

        Raises:
            TypeError: if ``channel`` is not a Channel.
        """
        if not isinstance(channel, Channel):
            raise TypeError(
                self._log('input channel must be Channel type, not {}'.format(
                    type(channel))))
        channel.add_consumer(self._name)
        self._input = channel

    def get_outputs(self):
        return self._outputs

    def set_outputs(self, channels):
        """Attach ``channels`` and register this Op as producer of each.

        Raises:
            TypeError: if ``channels`` is not a list.
        """
        if not isinstance(channels, list):
            raise TypeError(
                self._log('output channels must be list type, not {}'.format(
                    type(channels))))
        for channel in channels:
            channel.add_producer(self._name)
        self._outputs = channels

    def preprocess(self, data):
        """Turn input channel data into a feed dict {name: np.ndarray}.

        The default only handles a single input (an object with ``insts``,
        each carrying ``name`` and raw ``data`` bytes). A dict input means
        the input channel has several producers; override this method then.
        """
        if isinstance(data, dict):
            raise Exception(
                self._log(
                    'this Op has multiple previous inputs. Please override this method'
                ))
        feed = {}
        for inst in data.insts:
            feed[inst.name] = np.frombuffer(inst.data, dtype=self._in_dtype)
        return feed

    def midprocess(self, data):
        """Run remote prediction on the feed dict ``data``; returns fetch map."""
        if not isinstance(data, dict):
            raise Exception(
                self._log(
                    'data must be dict type(the output of preprocess()), but get {}'.
                    format(type(data))))
        logging.debug(self._log('data: {}'.format(data)))
        logging.debug(self._log('fetch: {}'.format(self._fetch_names)))
        fetch_map = self._client.predict(feed=data, fetch=self._fetch_names)
        logging.debug(self._log("finish predict"))
        return fetch_map

    def postprocess(self, output_data):
        """Must be overridden to return a ChannelData for the output channels."""
        raise Exception(
            self._log(
                'Please override this method to convert data to the format in channel.'
            ))

    def stop(self):
        """Ask start() to exit after its current iteration.

        The loop may stay blocked in the input channel's front() until more
        data arrives.
        """
        self._run = False

    @staticmethod
    def _get_data_id(input_data):
        """Return the data id carried by ``input_data``.

        ``input_data`` is either a single item with an ``id`` attribute, or
        a dict {producer_name: item} built by a multi-producer Channel for
        a single data id, so any entry's id is the group's id.
        """
        if isinstance(input_data, dict):
            # `dict.keys()[0]` only works on Python 2; iterate instead.
            return next(iter(input_data.values())).id
        return input_data.id

    def start(self):
        """Process input data until stop() is called.

        Raises:
            TypeError: if postprocess() does not return a ChannelData.
        """
        self._run = True
        while self._run:
            input_data = self._input.front(self._name)
            logging.debug(self._log("input_data: {}".format(input_data)))
            data_id = self._get_data_id(input_data)

            data = self.preprocess(input_data)
            if self.with_serving():
                data = self.midprocess(data)
            output_data = self.postprocess(data)

            if not isinstance(output_data,
                              python_service_channel_pb2.ChannelData):
                raise TypeError(
                    self._log(
                        'output_data must be ChannelData type, but get {}'.
                        format(type(output_data))))
            # keep the id so the server can match the response to its request
            output_data.id = data_id

            for channel in self._outputs:
                channel.push(output_data, self._name)

    def _log(self, info_str):
        return "[{}] {}".format(self._name, info_str)

    def get_concurrency(self):
        return self._concurrency

B
barrierye 已提交
347 348 349

class GeneralPythonService(
        general_python_service_pb2_grpc.GeneralPythonService):
    """gRPC service that feeds client requests into the Op pipeline.

    Each request becomes a ChannelData with a fresh id and is pushed into
    the input channel. A background thread drains the output channel into
    ``self._globel_resp_dict`` (keyed by id). Each request thread waits on
    ``self._cv`` until its own id appears there.
    """

    def __init__(self, in_channel, out_channel):
        super(GeneralPythonService, self).__init__()
        self._name = "__GeneralPythonService__"
        self.set_in_channel(in_channel)
        self.set_out_channel(out_channel)
        for channel in (in_channel, out_channel):
            logging.debug(self._log(channel.debug()))
        # TODO:
        #  multi-lock for different clients
        #  different lock for server and client
        self._id_lock = threading.Lock()  # guards self._id_counter
        self._cv = threading.Condition()  # guards self._globel_resp_dict
        self._globel_resp_dict = {}  # {data_id: ChannelData} not yet picked up
        self._id_counter = 0
        self._recive_func = threading.Thread(
            target=GeneralPythonService._recive_out_channel_func,
            args=(self, ))
        self._recive_func.start()

    def _log(self, info_str):
        return "[{}] {}".format(self._name, info_str)

    def set_in_channel(self, in_channel):
        """Attach ``in_channel`` and register this service as its producer.

        Raises:
            TypeError: if ``in_channel`` is not a Channel.
        """
        if isinstance(in_channel, Channel):
            in_channel.add_producer(self._name)
            self._in_channel = in_channel
        else:
            raise TypeError(
                self._log('in_channel must be Channel type, but get {}'.format(
                    type(in_channel))))

    def set_out_channel(self, out_channel):
        """Attach ``out_channel`` and register this service as its consumer.

        Raises:
            TypeError: if ``out_channel`` is not a Channel.
        """
        if isinstance(out_channel, Channel):
            out_channel.add_consumer(self._name)
            self._out_channel = out_channel
        else:
            raise TypeError(
                self._log('out_channel must be Channel type, but get {}'.format(
                    type(out_channel))))

    def _recive_out_channel_func(self):
        """Move every result from the output channel into the response dict.

        Runs forever on the thread started in __init__.
        """
        while True:
            channel_data = self._out_channel.front(self._name)
            if isinstance(channel_data,
                          python_service_channel_pb2.ChannelData):
                with self._cv:
                    self._globel_resp_dict[channel_data.id] = channel_data
                    self._cv.notify_all()
                continue
            # NOTE(review): raising here ends this thread; afterwards no
            # response is delivered and pending requests wait forever.
            raise TypeError(
                self._log('data must be ChannelData type, but get {}'.format(
                    type(channel_data))))

    def _get_next_id(self):
        """Return a new request id (0, 1, 2, ...)."""
        with self._id_lock:
            next_id = self._id_counter
            self._id_counter = next_id + 1
        return next_id

    def _get_data_in_globel_resp_dict(self, data_id):
        """Block until the result for ``data_id`` arrives, then remove and return it."""
        with self._cv:
            while True:
                if data_id in self._globel_resp_dict:
                    found = self._globel_resp_dict.pop(data_id)
                    self._cv.notify_all()
                    return found
                self._cv.wait()

    def _pack_data_for_infer(self, request):
        """Build a ChannelData with the request's feed vars and a new id.

        Returns:
            (ChannelData, data_id)
        """
        logging.debug(self._log('start inferce'))
        packed = python_service_channel_pb2.ChannelData()
        new_id = self._get_next_id()
        packed.id = new_id
        for idx, var_name in enumerate(request.feed_var_names):
            logging.debug(self._log('name: {}'.format(var_name)))
            raw = request.feed_insts[idx]
            logging.debug(self._log('data: {}'.format(raw)))
            inst = python_service_channel_pb2.Inst()
            inst.data = raw
            inst.name = var_name
            packed.insts.append(inst)
        return packed, new_id

    def _pack_data_for_resp(self, data):
        """Convert the pipeline output ``data`` into a gRPC Response."""
        logging.debug(self._log('get data'))
        response = general_python_service_pb2.Response()
        logging.debug(self._log('gen resp'))
        logging.debug(data)
        for item in data.insts:
            logging.debug(self._log('append data'))
            response.fetch_insts.append(item.data)
            logging.debug(self._log('append name'))
            response.fetch_var_names.append(item.name)
        return response

    def inference(self, request, context):
        """gRPC handler: send ``request`` through the pipeline and wait for the result."""
        channel_data, data_id = self._pack_data_for_infer(request)
        logging.debug(self._log('push data'))
        self._in_channel.push(channel_data, self._name)
        logging.debug(self._log('wait for infer'))
        return self._pack_data_for_resp(
            self._get_data_in_globel_resp_dict(data_id))

B
barrierye 已提交
449 450 451 452 453 454 455 456

class PyServer(object):
    """Assemble Ops and Channels into a pipeline and serve it over gRPC."""

    def __init__(self):
        self._channels = []
        self._ops = []
        self._op_threads = []
        self._port = None
        self._worker_num = None
        self._in_channel = None  # the only channel no Op writes to
        self._out_channel = None  # the only channel no Op reads from

    def add_channel(self, channel):
        """Record ``channel`` (only stored; the topology comes from the Ops)."""
        self._channels.append(channel)

    def add_op(self, op):
        self._ops.append(op)

    def gen_desc(self):
        logging.info('here will generate desc for paas')

    def prepare_server(self, port, worker_num):
        """Store server settings and find the pipeline's entry/exit channels.

        Raises:
            Exception: unless exactly one channel is read but never written
                (entry) and exactly one is written but never read (exit).
        """
        self._port = port
        self._worker_num = worker_num

        inputs = set()
        outputs = set()
        for op in self._ops:
            inputs.add(op.get_input())
            outputs.update(op.get_outputs())
            if op.with_serving():
                self.prepare_serving(op)

        in_channel = inputs - outputs
        out_channel = outputs - inputs
        if len(in_channel) != 1 or len(out_channel) != 1:
            raise Exception(
                "in_channel(out_channel) more than 1 or no in_channel(out_channel)"
            )
        self._in_channel = in_channel.pop()
        self._out_channel = out_channel.pop()
        self.gen_desc()

    def _op_start_wrapper(self, op):
        return op.start()

    def _run_ops(self):
        """Start ``op.get_concurrency()`` threads running ``op.start`` per Op."""
        for op in self._ops:
            op_concurrency = op.get_concurrency()
            logging.debug("run op: {}, op_concurrency: {}".format(
                op._name, op_concurrency))
            for _ in range(op_concurrency):
                th = threading.Thread(
                    target=self._op_start_wrapper, args=(op, ))
                th.start()
                self._op_threads.append(th)

    @staticmethod
    def _wait_for_server(server):
        """Block until ``server`` terminates.

        ``grpc.Server`` has no ``join()``. Newer grpcio provides
        ``wait_for_termination()``; otherwise sleep in a loop (sleep can be
        interrupted by KeyboardInterrupt).
        """
        wait = getattr(server, "wait_for_termination", None)
        if wait is not None:
            wait()
            return
        while True:
            time.sleep(60 * 60 * 24)

    def run_server(self):
        """Start the Op threads and the gRPC server; block until interrupted."""
        self._run_ops()
        server = grpc.server(
            futures.ThreadPoolExecutor(max_workers=self._worker_num))
        general_python_service_pb2_grpc.add_GeneralPythonServiceServicer_to_server(
            GeneralPythonService(self._in_channel, self._out_channel), server)
        server.add_insecure_port('[::]:{}'.format(self._port))
        server.start()
        try:
            for th in self._op_threads:
                # An untimed join() cannot be interrupted by Ctrl-C on
                # Python 2, so poll with a short timeout instead.
                while th.is_alive():
                    th.join(1)
            self._wait_for_server(server)
        except KeyboardInterrupt:
            server.stop(0)

    def prepare_serving(self, op):
        """Build the command that would start ``op``'s serving backend.

        The command is only logged; it is not executed (returns None).
        A ``device`` other than "cpu" (including None) selects the GPU server.
        """
        model_path = op._server_model
        port = op._server_port
        device = op._device

        if device == "cpu":
            cmd = "python -m paddle_serving_server.serve --model {} --thread 4 --port {} &>/dev/null &".format(
                model_path, port)
        else:
            cmd = "python -m paddle_serving_server_gpu.serve --model {} --thread 4 --port {} &>/dev/null &".format(
                model_path, port)
        # run a server (not in PyServing)
        logging.info("run a server (not in PyServing): {}".format(cmd))
        # TODO: actually launch it (previously `os.system(cmd)`). If enabled,
        # prefer subprocess with an argument list over a shell string, since
        # model_path/port are interpolated unescaped.
        return