pyserver.py 21.3 KB
Newer Older
B
barrierye 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
import threading
import multiprocessing
B
barrierye 已提交
17
import Queue
B
barrierye 已提交
18
import os
B
barrierye 已提交
19
import sys
B
barrierye 已提交
20 21 22
import paddle_serving_server
from paddle_serving_client import Client
from concurrent import futures
B
barrierye 已提交
23
import numpy as np
B
barrierye 已提交
24 25 26
import grpc
import general_python_service_pb2
import general_python_service_pb2_grpc
B
barrierye 已提交
27
import python_service_channel_pb2
B
barrierye 已提交
28
import logging
29
import random
B
barrierye 已提交
30
import time
B
barrierye 已提交
31 32


B
barrierye 已提交
33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68
class _TimeProfiler(object):
    """Collect named timestamps and print them to stderr in pairs.

    Callers record events named ``"<name>_<tag>"``.  Two events that share
    the same ``<name>`` are printed together by ``print_profile()``; an
    event whose partner has not been recorded yet is kept for the next
    call.  Nothing is recorded or printed until ``enable(True)`` is called.
    """

    def __init__(self):
        self._pid = os.getpid()
        self._print_head = 'PROFILE\tpid:{}\t'.format(self._pid)
        # Queue of (name, tag, timestamp in microseconds).
        self._time_record = Queue.Queue()
        self._enable = False

    def enable(self, enable):
        """Turn profiling on (truthy) or off (falsy)."""
        self._enable = enable

    def record(self, name_with_tag):
        """Store the current time under ``name_with_tag``.

        The text after the last ``"_"`` is the tag and everything before it
        is the name.  If there is no ``"_"`` the name is empty and the whole
        string is the tag.  Does nothing while profiling is disabled.
        """
        if not self._enable:
            return
        name, _, tag = name_with_tag.rpartition("_")
        self._time_record.put((name, tag, int(round(time.time() * 1000000))))

    def print_profile(self):
        """Write every completed pair of records to stderr on one line.

        Records without a partner are put back into the queue.  Does
        nothing while profiling is disabled.
        """
        if not self._enable:
            return
        sys.stderr.write(self._print_head)
        pending = {}  # {name: (tag, timestamp)} still waiting for a partner
        while not self._time_record.empty():
            name, tag, timestamp = self._time_record.get()
            if name in pending:
                ptag, ptimestamp = pending.pop(name)
                sys.stderr.write("{}_{}:{} ".format(name, ptag, ptimestamp))
                sys.stderr.write("{}_{}:{} ".format(name, tag, timestamp))
            else:
                pending[name] = (tag, timestamp)
        sys.stderr.write('\n')
        for name, (tag, timestamp) in pending.items():
            self._time_record.put((name, tag, timestamp))


_profiler = _TimeProfiler()


B
barrierye 已提交
69
class Channel(Queue.Queue):
    """
    The channel used for communication between Ops.

    1. Support multiple different Op feed data (multiple producer)
        Different types of data will be packaged through the data ID
    2. Support multiple different Op fetch data (multiple consumer)
        Only when all types of Ops get the data of the same ID,
        the data will be popped; The Op of the same type will not
        get the data of the same ID.
    3. (TODO) Timeout and BatchSize are not fully supported.

    Note:
    1. The ID of the data in the channel must be different.
    2. The function add_producer() and add_consumer() are not thread safe,
       and can only be called during initialization.
    """

    def __init__(self, name=None, maxsize=-1, timeout=None):
        # Explicit base-class call (not super()): Queue.Queue is an
        # old-style class on Python 2.
        Queue.Queue.__init__(self, maxsize=maxsize)
        self._maxsize = maxsize
        self._timeout = timeout
        self._name = name

        # Guards all bookkeeping below and is used to wait for / announce
        # changes of the underlying queue.
        self._cv = threading.Condition()

        self._producers = []
        self._producer_res_count = {}  # {data_id: count}
        self._push_res = {}  # {data_id: {op_name: data}}

        self._front_wait_interval = 0.1  # second
        self._consumers = {}  # {op_name: idx}
        self._idx_consumer_num = {}  # {idx: num}
        self._consumer_base_idx = 0
        self._front_res = []

    def get_producers(self):
        """Return the list of registered producer names."""
        return self._producers

    def get_consumers(self):
        """Return the names of the registered consumers."""
        return self._consumers.keys()

    def _log(self, info_str):
        """Prefix ``info_str`` with the channel name."""
        return "[{}] {}".format(self._name, info_str)

    def debug(self):
        """Return a string listing the producers and consumers."""
        return self._log("p: {}, c: {}".format(self.get_producers(),
                                               self.get_consumers()))

    def add_producer(self, op_name):
        """ not thread safe, and can only be called during initialization. """
        if op_name in self._producers:
            raise ValueError(
                self._log("producer({}) is already in channel".format(op_name)))
        self._producers.append(op_name)

    def add_consumer(self, op_name):
        """ not thread safe, and can only be called during initialization. """
        if op_name in self._consumers:
            raise ValueError(
                self._log("consumer({}) is already in channel".format(op_name)))
        self._consumers[op_name] = 0

        # Every new consumer starts at index 0.
        self._idx_consumer_num[0] = self._idx_consumer_num.get(0, 0) + 1

    def _put_or_wait(self, item):
        """Put ``item`` into the queue, waiting on the condition while full.

        Must be called with ``self._cv`` held.  A full queue raises
        ``Queue.Full`` (not ``Queue.Empty``) from a non-blocking put.
        """
        while True:
            try:
                self.put(item, block=False)
                return
            except Queue.Full:
                self._cv.wait()

    def push(self, data, op_name=None):
        """Feed ``data`` into the channel on behalf of producer ``op_name``.

        With one producer the data is queued directly.  With several
        producers the pieces are collected per ``data.id`` and a dict
        ``{op_name: data}`` is queued once every producer has delivered
        its piece.

        Returns:
            True.

        Raises:
            Exception: if no producer is registered, or if there are
                several producers and ``op_name`` is None.
        """
        logging.debug(
            self._log("{} try to push data: {}".format(op_name, data)))
        producer_num = len(self._producers)
        if producer_num == 0:
            raise Exception(
                self._log(
                    "expected number of producers to be greater than 0, but the it is 0."
                ))
        elif producer_num == 1:
            with self._cv:
                self._put_or_wait(data)
                self._cv.notify_all()
            logging.debug(self._log("{} push data succ!".format(op_name)))
            return True
        elif op_name is None:
            raise Exception(
                self._log(
                    "There are multiple producers, so op_name cannot be None."))

        data_id = data.id
        put_data = None
        with self._cv:
            logging.debug(self._log("{} get lock ~".format(op_name)))
            if data_id not in self._push_res:
                self._push_res[data_id] = {
                    name: None
                    for name in self._producers
                }
                self._producer_res_count[data_id] = 0
            self._push_res[data_id][op_name] = data
            if self._producer_res_count[data_id] + 1 == producer_num:
                # Last missing piece: hand the complete dict to the queue.
                put_data = self._push_res.pop(data_id)
                self._producer_res_count.pop(data_id)
            else:
                self._producer_res_count[data_id] += 1

            if put_data is None:
                logging.debug(
                    self._log("{} push data succ, not not push to queue.".
                              format(op_name)))
            else:
                self._put_or_wait(put_data)
                logging.debug(
                    self._log("multi | {} push data succ!".format(op_name)))
            self._cv.notify_all()
        return True

    def front(self, op_name=None):
        """Return the next item for consumer ``op_name``, waiting if needed.

        With one consumer the item is taken straight from the queue.  With
        several consumers each item is kept in an internal buffer until
        every consumer has read it; the returned object is shared and must
        be treated as read only.

        Raises:
            Exception: if no consumer is registered, or if there are
                several consumers and ``op_name`` is None.
        """
        logging.debug(self._log("{} try to get data".format(op_name)))
        if len(self._consumers) == 0:
            raise Exception(
                self._log(
                    "expected number of consumers to be greater than 0, but the it is 0."
                ))
        elif len(self._consumers) == 1:
            with self._cv:
                while True:
                    try:
                        resp = self.get(block=False)
                        break
                    except Queue.Empty:
                        self._cv.wait()
                # A slot was freed: wake producers waiting on a full queue.
                self._cv.notify_all()
            logging.debug(self._log("{} get data succ!".format(op_name)))
            return resp
        elif op_name is None:
            raise Exception(
                self._log(
                    "There are multiple consumers, so op_name cannot be None."))

        with self._cv:
            # data_idx = consumer_idx - base_idx
            while self._consumers[op_name] - self._consumer_base_idx >= len(
                    self._front_res):
                try:
                    self._front_res.append(self.get(block=False))
                except Queue.Empty:
                    self._cv.wait()

            consumer_idx = self._consumers[op_name]
            base_idx = self._consumer_base_idx
            data_idx = consumer_idx - base_idx
            resp = self._front_res[data_idx]
            logging.debug(self._log("{} get data: {}".format(op_name, resp)))

            # This consumer leaves its current index; once nobody is left
            # on the oldest buffered item it can be dropped.
            self._idx_consumer_num[consumer_idx] -= 1
            if consumer_idx == base_idx and self._idx_consumer_num[
                    consumer_idx] == 0:
                self._idx_consumer_num.pop(consumer_idx)
                self._front_res.pop(0)
                self._consumer_base_idx += 1

            self._consumers[op_name] += 1
            new_consumer_idx = self._consumers[op_name]
            self._idx_consumer_num[new_consumer_idx] = \
                self._idx_consumer_num.get(new_consumer_idx, 0) + 1

            self._cv.notify_all()

        logging.debug(self._log("multi | {} get data succ!".format(op_name)))
        return resp  # reference, read only
B
barrierye 已提交
253 254 255 256


class Op(object):
    """One processing stage that reads from a Channel and writes to Channels.

    ``start()`` repeatedly takes an item from the input channel, runs
    ``preprocess()``, optionally ``midprocess()`` (only when a serving
    client was configured) and ``postprocess()``, then pushes the result to
    every output channel.  Subclasses must override ``postprocess()``.
    """

    def __init__(self,
                 name,
                 input,
                 in_dtype,
                 outputs,
                 out_dtype,
                 server_model=None,
                 server_port=None,
                 device=None,
                 client_config=None,
                 server_name=None,
                 fetch_names=None,
                 concurrency=1):
        """
        Args:
            name: identifier of this Op; registered as consumer/producer
                name on the channels.
            input: the Channel this Op reads from.
            in_dtype: dtype passed to ``np.frombuffer`` in ``preprocess()``.
            outputs: list of Channels this Op writes to.
            out_dtype: stored on the instance; not used inside this class.
            server_model, server_port, device: stored on the instance and
                read by ``PyServer.prepare_serving()``.
            client_config, server_name, fetch_names: when all three are
                given, a serving client is created via ``set_client()``.
            concurrency: value returned by ``get_concurrency()``.
        """
        self._run = False
        # TODO: globally unique check
        self._name = name  # to identify the type of OP, it must be globally unique
        self._concurrency = concurrency  # amount of concurrency
        self.set_input(input)
        self._in_dtype = in_dtype
        self.set_outputs(outputs)
        self._out_dtype = out_dtype
        self._client = None
        # Defined up front so the attribute exists even without a client.
        self._fetch_names = None
        if client_config is not None and \
                server_name is not None and \
                fetch_names is not None:
            self.set_client(client_config, server_name, fetch_names)
        self._server_model = server_model
        self._server_port = server_port
        self._device = device

    def set_client(self, client_config, server_name, fetch_names):
        """Create a Client, load ``client_config`` and connect to ``server_name``."""
        self._client = Client()
        self._client.load_client_config(client_config)
        self._client.connect([server_name])
        self._fetch_names = fetch_names

    def with_serving(self):
        """Return True when a serving client has been set."""
        return self._client is not None

    def get_input(self):
        """Return the input channel."""
        return self._input

    def set_input(self, channel):
        """Register this Op as a consumer of ``channel`` and use it as input.

        Raises:
            TypeError: if ``channel`` is not a Channel.
        """
        if not isinstance(channel, Channel):
            raise TypeError(
                self._log('input channel must be Channel type, not {}'.format(
                    type(channel))))
        channel.add_consumer(self._name)
        self._input = channel

    def get_outputs(self):
        """Return the list of output channels."""
        return self._outputs

    def set_outputs(self, channels):
        """Register this Op as a producer of every channel in ``channels``.

        All elements are validated before any of them is modified, so a bad
        element does not leave the Op registered on only some channels.

        Raises:
            TypeError: if ``channels`` is not a list or an element is not a
                Channel.
        """
        if not isinstance(channels, list):
            raise TypeError(
                self._log('output channels must be list type, not {}'.format(
                    type(channels))))
        for channel in channels:
            if not isinstance(channel, Channel):
                raise TypeError(
                    self._log('output channel must be Channel type, not {}'.
                              format(type(channel))))
        for channel in channels:
            channel.add_producer(self._name)
        self._outputs = channels

    def preprocess(self, data):
        """Convert one channel item into a feed dict ``{inst.name: ndarray}``.

        Raises:
            Exception: if ``data`` is a dict (input coming from several
                producers); such Ops must override this method.
        """
        if isinstance(data, dict):
            raise Exception(
                self._log(
                    'this Op has multiple previous inputs. Please override this method'
                ))
        return {
            inst.name: np.frombuffer(
                inst.data, dtype=self._in_dtype)
            for inst in data.insts
        }

    def midprocess(self, data):
        """Send the feed dict ``data`` to the serving client and return its result.

        Raises:
            Exception: if ``data`` is not a dict.
        """
        if not isinstance(data, dict):
            raise Exception(
                self._log(
                    'data must be dict type(the output of preprocess()), but get {}'.
                    format(type(data))))
        logging.debug(self._log('data: {}'.format(data)))
        logging.debug(self._log('fetch: {}'.format(self._fetch_names)))
        fetch_map = self._client.predict(feed=data, fetch=self._fetch_names)
        logging.debug(self._log("finish predict"))
        return fetch_map

    def postprocess(self, output_data):
        """Convert ``output_data`` to ChannelData; must be overridden."""
        raise Exception(
            self._log(
                'Please override this method to convert data to the format in channel.'
            ))

    def stop(self):
        """Ask the ``start()`` loop to end after its current iteration."""
        self._run = False

    def start(self):
        """Process items from the input channel until ``stop()`` is called.

        Raises:
            TypeError: if ``postprocess()`` does not return ChannelData.
        """
        self._run = True
        while self._run:
            _profiler.record("{}-get_0".format(self._name))
            input_data = self._input.front(self._name)
            _profiler.record("{}-get_1".format(self._name))
            logging.debug(self._log("input_data: {}".format(input_data)))
            if isinstance(input_data, dict):
                # Every entry is read for its id; take it from the first
                # one (works on both Python 2 and 3).
                key = next(iter(input_data))
                data_id = input_data[key].id
            else:
                data_id = input_data.id

            _profiler.record("{}-prep_0".format(self._name))
            data = self.preprocess(input_data)
            _profiler.record("{}-prep_1".format(self._name))

            if self.with_serving():
                _profiler.record("{}-midp_0".format(self._name))
                data = self.midprocess(data)
                _profiler.record("{}-midp_1".format(self._name))

            _profiler.record("{}-postp_0".format(self._name))
            output_data = self.postprocess(data)
            _profiler.record("{}-postp_1".format(self._name))

            if not isinstance(output_data,
                              python_service_channel_pb2.ChannelData):
                raise TypeError(
                    self._log(
                        'output_data must be ChannelData type, but get {}'.
                        format(type(output_data))))
            # The result keeps the id of the request it belongs to.
            output_data.id = data_id

            _profiler.record("{}-push_0".format(self._name))
            for channel in self._outputs:
                channel.push(output_data, self._name)
            _profiler.record("{}-push_1".format(self._name))

    def _log(self, info_str):
        """Prefix ``info_str`` with the Op name."""
        return "[{}] {}".format(self._name, info_str)

    def get_concurrency(self):
        """Return the concurrency value given at construction."""
        return self._concurrency

B
barrierye 已提交
396 397 398

class GeneralPythonService(
        general_python_service_pb2_grpc.GeneralPythonService):
    """gRPC service that feeds requests into a channel and waits for results.

    ``inference()`` packs the request into ChannelData with a fresh id,
    pushes it into the input channel and blocks until a background thread
    has received the item with the same id from the output channel.
    """

    def __init__(self, in_channel, out_channel):
        """
        Args:
            in_channel: Channel this service produces into.
            out_channel: Channel this service consumes from.
        """
        super(GeneralPythonService, self).__init__()
        self._name = "#G"
        self.set_in_channel(in_channel)
        self.set_out_channel(out_channel)
        logging.debug(self._log(in_channel.debug()))
        logging.debug(self._log(out_channel.debug()))
        #TODO:
        #  multi-lock for different clients
        #  different lock for server and client
        self._id_lock = threading.Lock()
        self._cv = threading.Condition()
        self._globel_resp_dict = {}  # {data_id: ChannelData}
        self._id_counter = 0
        self._recive_func = threading.Thread(
            target=GeneralPythonService._recive_out_channel_func, args=(self, ))
        # The receiver loops forever; as a daemon it does not keep the
        # interpreter alive once the main thread is done.
        self._recive_func.daemon = True
        self._recive_func.start()

    def _log(self, info_str):
        """Prefix ``info_str`` with the service name."""
        return "[{}] {}".format(self._name, info_str)

    def set_in_channel(self, in_channel):
        """Register this service as a producer of ``in_channel``.

        Raises:
            TypeError: if ``in_channel`` is not a Channel.
        """
        if not isinstance(in_channel, Channel):
            raise TypeError(
                self._log('in_channel must be Channel type, but get {}'.format(
                    type(in_channel))))
        in_channel.add_producer(self._name)
        self._in_channel = in_channel

    def set_out_channel(self, out_channel):
        """Register this service as a consumer of ``out_channel``.

        Raises:
            TypeError: if ``out_channel`` is not a Channel.
        """
        if not isinstance(out_channel, Channel):
            raise TypeError(
                self._log('out_channel must be Channel type, but get {}'.format(
                    type(out_channel))))
        out_channel.add_consumer(self._name)
        self._out_channel = out_channel

    def _recive_out_channel_func(self):
        """Thread body: move items from the output channel into the response dict.

        Raises:
            TypeError: if an item is not ChannelData (this ends the thread).
        """
        while True:
            data = self._out_channel.front(self._name)
            if not isinstance(data, python_service_channel_pb2.ChannelData):
                raise TypeError(
                    self._log('data must be ChannelData type, but get {}'.
                              format(type(data))))
            with self._cv:
                self._globel_resp_dict[data.id] = data
                self._cv.notify_all()

    def _get_next_id(self):
        """Return the current counter value and increment it, under a lock."""
        with self._id_lock:
            next_id = self._id_counter
            self._id_counter += 1
            return next_id

    def _get_data_in_globel_resp_dict(self, data_id):
        """Block until the response for ``data_id`` arrived, then remove and return it."""
        with self._cv:
            while data_id not in self._globel_resp_dict:
                self._cv.wait()
            resp = self._globel_resp_dict.pop(data_id)
            self._cv.notify_all()
        return resp

    def _pack_data_for_infer(self, request):
        """Build ChannelData from ``request``.

        Returns:
            (data, data_id): the ChannelData and the id assigned to it.
        """
        logging.debug(self._log('start inferce'))
        data = python_service_channel_pb2.ChannelData()
        data_id = self._get_next_id()
        data.id = data_id
        for idx, name in enumerate(request.feed_var_names):
            logging.debug(self._log('name: {}'.format(name)))
            logging.debug(self._log('data: {}'.format(request.feed_insts[idx])))
            inst = python_service_channel_pb2.Inst()
            inst.data = request.feed_insts[idx]
            inst.name = name
            data.insts.append(inst)
        return data, data_id

    def _pack_data_for_resp(self, data):
        """Build a Response holding the data and name of every inst in ``data``."""
        logging.debug(self._log('get data'))
        resp = general_python_service_pb2.Response()
        logging.debug(self._log('gen resp'))
        logging.debug(data)
        for inst in data.insts:
            logging.debug(self._log('append data'))
            resp.fetch_insts.append(inst.data)
            logging.debug(self._log('append name'))
            resp.fetch_var_names.append(inst.name)
        return resp

    def inference(self, request, context):
        """Handle one request: pack, push, wait for the result, pack response."""
        _profiler.record("{}-prepack_0".format(self._name))
        data, data_id = self._pack_data_for_infer(request)
        _profiler.record("{}-prepack_1".format(self._name))

        logging.debug(self._log('push data'))
        _profiler.record("{}-push_0".format(self._name))
        self._in_channel.push(data, self._name)
        _profiler.record("{}-push_1".format(self._name))

        logging.debug(self._log('wait for infer'))
        _profiler.record("{}-fetch_0".format(self._name))
        resp_data = self._get_data_in_globel_resp_dict(data_id)
        _profiler.record("{}-fetch_1".format(self._name))

        _profiler.record("{}-postpack_0".format(self._name))
        resp = self._pack_data_for_resp(resp_data)
        _profiler.record("{}-postpack_1".format(self._name))
        _profiler.print_profile()
        return resp

B
barrierye 已提交
510 511

class PyServer(object):
    """Wire Ops and Channels together and expose them through a gRPC server."""

    def __init__(self, profile=False):
        """
        Args:
            profile: passed to the module-level profiler's ``enable()``.
        """
        self._channels = []
        self._ops = []
        self._op_threads = []
        self._port = None
        self._worker_num = None
        self._in_channel = None
        self._out_channel = None
        _profiler.enable(profile)

    def add_channel(self, channel):
        """Remember ``channel``."""
        self._channels.append(channel)

    def add_op(self, op):
        """Remember ``op``; it is started by ``run_server()``."""
        self._ops.append(op)

    def gen_desc(self):
        """Placeholder: only logs a message for now."""
        logging.info('here will generate desc for paas')

    def prepare_server(self, port, worker_num):
        """Store the server settings and find the entry and exit channels.

        The entry channel is the one that is an input of some Op but an
        output of none; the exit channel is the opposite.

        Raises:
            Exception: unless there is exactly one entry and one exit channel.
        """
        self._port = port
        self._worker_num = worker_num
        inputs = set()
        outputs = set()
        for op in self._ops:
            inputs.add(op.get_input())
            outputs.update(op.get_outputs())
            if op.with_serving():
                self.prepare_serving(op)
        in_channel = inputs - outputs
        out_channel = outputs - inputs
        if len(in_channel) != 1 or len(out_channel) != 1:
            raise Exception(
                "in_channel(out_channel) more than 1 or no in_channel(out_channel)"
            )
        self._in_channel = in_channel.pop()
        self._out_channel = out_channel.pop()
        self.gen_desc()

    def _op_start_wrapper(self, op):
        """Thread target: run ``op.start()``."""
        return op.start()

    def _run_ops(self):
        """Start ``get_concurrency()`` threads for every registered Op."""
        for op in self._ops:
            op_concurrency = op.get_concurrency()
            logging.debug("run op: {}, op_concurrency: {}".format(
                op._name, op_concurrency))
            for _ in range(op_concurrency):
                th = threading.Thread(
                    target=self._op_start_wrapper, args=(op, ))
                th.start()
                self._op_threads.append(th)

    def run_server(self):
        """Start the Op threads and the gRPC server, then block.

        Blocks until every Op thread has finished and the server has
        terminated; a KeyboardInterrupt stops the server immediately.
        """
        self._run_ops()
        server = grpc.server(
            futures.ThreadPoolExecutor(max_workers=self._worker_num))
        general_python_service_pb2_grpc.add_GeneralPythonServiceServicer_to_server(
            GeneralPythonService(self._in_channel, self._out_channel), server)
        server.add_insecure_port('[::]:{}'.format(self._port))
        server.start()
        try:
            for th in self._op_threads:
                th.join()
            # grpc.Server has no join(); wait_for_termination() is its
            # blocking call where the installed grpc version provides it.
            wait_for_termination = getattr(server, 'wait_for_termination',
                                           None)
            if wait_for_termination is not None:
                wait_for_termination()
        except KeyboardInterrupt:
            server.stop(0)

    def prepare_serving(self, op):
        """Build and log the command that would start a serving process for ``op``.

        NOTE: the command is only logged, not executed.
        """
        model_path = op._server_model
        port = op._server_port
        device = op._device

        if device == "cpu":
            module = "paddle_serving_server"
        else:
            module = "paddle_serving_server_gpu"
        cmd = "python -m {}.serve --model {} --thread 4 --port {} &>/dev/null &".format(
            module, model_path, port)
        # run a server (not in PyServing)
        logging.info("run a server (not in PyServing): {}".format(cmd))
        return