pyserver.py 21.7 KB
Newer Older
B
barrierye 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
import threading
import multiprocessing
B
barrierye 已提交
17
import Queue
B
barrierye 已提交
18
import os
B
barrierye 已提交
19
import sys
B
barrierye 已提交
20 21 22
import paddle_serving_server
from paddle_serving_client import Client
from concurrent import futures
B
barrierye 已提交
23
import numpy as np
B
barrierye 已提交
24 25 26
import grpc
import general_python_service_pb2
import general_python_service_pb2_grpc
B
barrierye 已提交
27
import python_service_channel_pb2
B
barrierye 已提交
28
import logging
29
import random
B
barrierye 已提交
30
import time
B
barrierye 已提交
31 32


B
barrierye 已提交
33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68
class _TimeProfiler(object):
    """Collect tagged timestamps and print them to stderr as one line.

    Recording and printing are no-ops until ``enable(True)`` is called, so
    a disabled profiler neither grows its queue nor writes to stderr.
    """

    def __init__(self):
        self._pid = os.getpid()
        self._print_head = 'PROFILE\tpid:{}\t'.format(self._pid)
        # (name, tag, timestamp-in-microseconds) tuples waiting to be printed
        self._time_record = Queue.Queue()
        self._enable = False

    def enable(self, enable):
        """Switch profiling on (truthy) or off (falsy)."""
        self._enable = enable

    def record(self, name_with_tag):
        """Store the current time under ``name_with_tag``.

        The text after the last underscore is the tag and everything
        before it is the name (e.g. ``"op0-get_1"`` gives name ``"op0-get"``
        and tag ``"1"``). Does nothing while profiling is disabled.
        """
        if not self._enable:
            return
        name, _, tag = name_with_tag.rpartition("_")
        self._time_record.put((name, tag, int(round(time.time() * 1000000))))

    def print_profile(self):
        """Write every completed pair of records to stderr.

        Two records with the same name form a pair and are printed
        together. Records that have no partner yet are put back on the
        queue for a later call. Does nothing while profiling is disabled.
        """
        if not self._enable:
            return
        pieces = [self._print_head]
        unpaired = {}  # {name: (tag, timestamp)}
        while True:
            # get_nowait() instead of empty()+get(): another thread may
            # drain the queue between the two calls and block us forever.
            try:
                name, tag, timestamp = self._time_record.get_nowait()
            except Queue.Empty:
                break
            if name in unpaired:
                ptag, ptimestamp = unpaired.pop(name)
                pieces.append("{}_{}:{} ".format(name, ptag, ptimestamp))
                pieces.append("{}_{}:{} ".format(name, tag, timestamp))
            else:
                unpaired[name] = (tag, timestamp)
        pieces.append('\n')
        # one write per profile line keeps concurrent lines from interleaving
        sys.stderr.write(''.join(pieces))
        for name, (tag, timestamp) in unpaired.items():
            self._time_record.put((name, tag, timestamp))


_profiler = _TimeProfiler()


B
barrierye 已提交
69
class Channel(Queue.Queue):
    """
    The channel used for communication between Ops.

    1. Support multiple different Op feed data (multiple producer)
        Different types of data will be packaged through the data ID
    2. Support multiple different Op fetch data (multiple consumer)
        Only when all types of Ops get the data of the same ID,
        the data will be popped; The Op of the same type will not
        get the data of the same ID.
    3. (TODO) Timeout and BatchSize are not fully supported.

    Note:
    1. The ID of the data in the channel must be different.
    2. The function add_producer() and add_consumer() are not thread safe,
       and can only be called during initialization.
    """

    def __init__(self, name=None, maxsize=-1, timeout=None):
        Queue.Queue.__init__(self, maxsize=maxsize)
        self._maxsize = maxsize
        self._timeout = timeout
        self._name = name

        # guards all bookkeeping below; also used to wait for room / data
        self._cv = threading.Condition()

        self._producers = []
        self._producer_res_count = {}  # {data_id: count}
        self._push_res = {}  # {data_id: {op_name: data}}

        self._front_wait_interval = 0.1  # second
        self._consumers = {}  # {op_name: idx}
        self._idx_consumer_num = {}  # {idx: num}
        self._consumer_base_idx = 0
        self._front_res = []

    def get_producers(self):
        """Return the list of registered producer names."""
        return self._producers

    def get_consumers(self):
        """Return the names of the registered consumers."""
        return self._consumers.keys()

    def _log(self, info_str):
        """Prefix ``info_str`` with the channel name."""
        return "[{}] {}".format(self._name, info_str)

    def debug(self):
        """Return a string listing the channel's producers and consumers."""
        return self._log("p: {}, c: {}".format(self.get_producers(),
                                               self.get_consumers()))

    def add_producer(self, op_name):
        """ not thread safe, and can only be called during initialization. """
        if op_name in self._producers:
            raise ValueError(
                self._log("producer({}) is already in channel".format(op_name)))
        self._producers.append(op_name)

    def add_consumer(self, op_name):
        """ not thread safe, and can only be called during initialization. """
        if op_name in self._consumers:
            raise ValueError(
                self._log("consumer({}) is already in channel".format(op_name)))
        # every consumer starts at data index 0
        self._consumers[op_name] = 0

        if self._idx_consumer_num.get(0) is None:
            self._idx_consumer_num[0] = 0
        self._idx_consumer_num[0] += 1

    def _put_or_wait(self, item):
        """Enqueue ``item``, waiting on the condition while the queue is full.

        The caller must hold ``self._cv``.
        """
        while True:
            try:
                self.put_nowait(item)
                return
            except Queue.Full:
                # a full queue raises Full (not Empty); wait until a
                # consumer takes something out and notifies us
                self._cv.wait()

    def _get_or_wait(self):
        """Dequeue one item, waiting on the condition while the queue is empty.

        The caller must hold ``self._cv``.
        """
        while True:
            try:
                return self.get_nowait()
            except Queue.Empty:
                self._cv.wait()

    def push(self, data, op_name=None):
        """Feed ``data`` into the channel on behalf of producer ``op_name``.

        With a single producer the data is enqueued as is. With several
        producers the pieces sharing the same ``data.id`` are gathered in a
        ``{op_name: data}`` dict, which is enqueued once every producer has
        contributed. Blocks while the underlying queue is full.

        Returns True. Raises Exception when the channel has no producer, or
        when it has several producers and ``op_name`` is None.
        """
        logging.debug(
            self._log("{} try to push data: {}".format(op_name, data)))
        if len(self._producers) == 0:
            raise Exception(
                self._log(
                    "expected number of producers to be greater than 0, but the it is 0."
                ))
        elif len(self._producers) == 1:
            with self._cv:
                self._put_or_wait(data)
                self._cv.notify_all()
            logging.debug(self._log("{} push data succ!".format(op_name)))
            return True
        elif op_name is None:
            raise Exception(
                self._log(
                    "There are multiple producers, so op_name cannot be None."))

        producer_num = len(self._producers)
        data_id = data.id
        put_data = None
        with self._cv:
            logging.debug(self._log("{} get lock ~".format(op_name)))
            if data_id not in self._push_res:
                self._push_res[data_id] = {
                    name: None
                    for name in self._producers
                }
                self._producer_res_count[data_id] = 0
            self._push_res[data_id][op_name] = data
            if self._producer_res_count[data_id] + 1 == producer_num:
                # last missing piece: the gathered dict is ready to enqueue
                put_data = self._push_res.pop(data_id)
                self._producer_res_count.pop(data_id)
            else:
                self._producer_res_count[data_id] += 1

            if put_data is None:
                logging.debug(
                    self._log("{} push data succ, not not push to queue.".
                              format(op_name)))
            else:
                self._put_or_wait(put_data)
                logging.debug(
                    self._log("multi | {} push data succ!".format(op_name)))
            self._cv.notify_all()
        return True

    def front(self, op_name=None):
        """Fetch the next piece of data for consumer ``op_name``.

        With a single consumer the item is simply dequeued. With several
        consumers each dequeued item is kept in ``_front_res`` until every
        consumer has read it, so the returned object is shared and must be
        treated as read only. Blocks while no data is available.

        Raises Exception when the channel has no consumer, or when it has
        several consumers and ``op_name`` is None.
        """
        logging.debug(self._log("{} try to get data".format(op_name)))
        if len(self._consumers) == 0:
            raise Exception(
                self._log(
                    "expected number of consumers to be greater than 0, but the it is 0."
                ))
        elif len(self._consumers) == 1:
            with self._cv:
                resp = self._get_or_wait()
                # room was freed: wake producers waiting on a full queue
                self._cv.notify_all()
            logging.debug(self._log("{} get data succ!".format(op_name)))
            return resp
        elif op_name is None:
            raise Exception(
                self._log(
                    "There are multiple consumers, so op_name cannot be None."))

        with self._cv:
            # data_idx = consumer_idx - base_idx
            while self._consumers[op_name] - self._consumer_base_idx >= len(
                    self._front_res):
                try:
                    data = self.get_nowait()
                    self._front_res.append(data)
                    break
                except Queue.Empty:
                    self._cv.wait()

            consumer_idx = self._consumers[op_name]
            base_idx = self._consumer_base_idx
            data_idx = consumer_idx - base_idx
            resp = self._front_res[data_idx]
            logging.debug(self._log("{} get data: {}".format(op_name, resp)))

            # this consumer leaves consumer_idx; drop the oldest buffered
            # item once nobody is left waiting to read it
            self._idx_consumer_num[consumer_idx] -= 1
            if consumer_idx == base_idx and self._idx_consumer_num[
                    consumer_idx] == 0:
                self._idx_consumer_num.pop(consumer_idx)
                self._front_res.pop(0)
                self._consumer_base_idx += 1

            self._consumers[op_name] += 1
            new_consumer_idx = self._consumers[op_name]
            if self._idx_consumer_num.get(new_consumer_idx) is None:
                self._idx_consumer_num[new_consumer_idx] = 0
            self._idx_consumer_num[new_consumer_idx] += 1

            self._cv.notify_all()

        logging.debug(self._log("multi | {} get data succ!".format(op_name)))
        return resp  # reference, read only
B
barrierye 已提交
253 254 255 256


class Op(object):
    """One processing stage that reads from a Channel and feeds others.

    ``start()`` runs a loop that takes data from the input channel, passes
    it through ``preprocess`` -> ``midprocess`` (only when a serving client
    is configured) -> ``postprocess`` and pushes the result to every output
    channel. ``postprocess`` must be overridden by subclasses.
    """

    def __init__(self,
                 name,
                 input,
                 in_dtype,
                 outputs,
                 out_dtype,
                 server_model=None,
                 server_port=None,
                 device=None,
                 client_config=None,
                 server_name=None,
                 fetch_names=None,
                 concurrency=1):
        self._run = False
        # TODO: globally unique check
        self._name = name  # to identify the type of OP, it must be globally unique
        self._concurrency = concurrency  # amount of concurrency
        self.set_input(input)
        self._in_dtype = in_dtype
        self.set_outputs(outputs)
        self._out_dtype = out_dtype
        self._client = None
        # always defined, even when no client is configured
        self._fetch_names = None
        if client_config is not None and \
                server_name is not None and \
                fetch_names is not None:
            self.set_client(client_config, server_name, fetch_names)
        self._server_model = server_model
        self._server_port = server_port
        self._device = device

    def set_client(self, client_config, server_name, fetch_names):
        """Create a Client, load its config and connect it to ``server_name``."""
        self._client = Client()
        self._client.load_client_config(client_config)
        self._client.connect([server_name])
        self._fetch_names = fetch_names

    def with_serving(self):
        """Return True when a serving client has been configured."""
        return self._client is not None

    def get_input(self):
        """Return the input channel."""
        return self._input

    def set_input(self, channel):
        """Register this Op as a consumer of ``channel`` and use it as input.

        Raises TypeError when ``channel`` is not a Channel.
        """
        if not isinstance(channel, Channel):
            raise TypeError(
                self._log('input channel must be Channel type, not {}'.format(
                    type(channel))))
        channel.add_consumer(self._name)
        self._input = channel

    def get_outputs(self):
        """Return the list of output channels."""
        return self._outputs

    def set_outputs(self, channels):
        """Register this Op as a producer of every channel in ``channels``.

        Raises TypeError when ``channels`` is not a list or when one of its
        elements is not a Channel.
        """
        if not isinstance(channels, list):
            raise TypeError(
                self._log('output channels must be list type, not {}'.format(
                    type(channels))))
        # validate everything first so a bad element cannot leave the Op
        # registered as producer on only some of the channels
        for channel in channels:
            if not isinstance(channel, Channel):
                raise TypeError(
                    self._log('output channel must be Channel type, not {}'.
                              format(type(channel))))
        for channel in channels:
            channel.add_producer(self._name)
        self._outputs = channels

    def preprocess(self, data):
        """Turn channel data into a ``{inst.name: ndarray}`` feed dict.

        Each ``inst.data`` buffer in ``data.insts`` is decoded with the Op's
        input dtype. Raises Exception when ``data`` is a dict (the Op has
        several upstream producers); override this method in that case.
        """
        if isinstance(data, dict):
            raise Exception(
                self._log(
                    'this Op has multiple previous inputs. Please override this method'
                ))
        return {
            inst.name: np.frombuffer(
                inst.data, dtype=self._in_dtype)
            for inst in data.insts
        }

    def midprocess(self, data):
        """Run prediction through the serving client and return its fetch map.

        Raises Exception when ``data`` is not a dict.
        """
        if not isinstance(data, dict):
            raise Exception(
                self._log(
                    'data must be dict type(the output of preprocess()), but get {}'.
                    format(type(data))))
        logging.debug(self._log('data: {}'.format(data)))
        logging.debug(self._log('fetch: {}'.format(self._fetch_names)))
        fetch_map = self._client.predict(feed=data, fetch=self._fetch_names)
        logging.debug(self._log("finish predict"))
        return fetch_map

    def postprocess(self, output_data):
        """Convert the processed data to ChannelData; must be overridden."""
        raise Exception(
            self._log(
                'Please override this method to convert data to the format in channel.'
            ))

    def stop(self):
        """Ask the ``start()`` loop to exit after its current iteration."""
        self._run = False

    def _get_data_id(self, input_data):
        """Return the id carried by ``input_data``.

        A dict comes from a channel with several producers; its id is read
        from the first value of the dict.
        """
        if isinstance(input_data, dict):
            # works with both list-returning and view-returning values()
            return next(iter(input_data.values())).id
        return input_data.id

    def start(self, concurrency_idx):
        """Run the fetch/process/push loop until ``stop()`` is called.

        ``concurrency_idx`` only labels the profiler records of this worker.
        Raises TypeError when ``postprocess`` does not return ChannelData.
        """
        self._run = True
        tag = "{}{}".format(self._name, concurrency_idx)
        while self._run:
            _profiler.record("{}-get_0".format(tag))
            input_data = self._input.front(self._name)
            _profiler.record("{}-get_1".format(tag))
            logging.debug(self._log("input_data: {}".format(input_data)))
            data_id = self._get_data_id(input_data)

            _profiler.record("{}-prep_0".format(tag))
            data = self.preprocess(input_data)
            _profiler.record("{}-prep_1".format(tag))

            if self.with_serving():
                _profiler.record("{}-midp_0".format(tag))
                data = self.midprocess(data)
                _profiler.record("{}-midp_1".format(tag))

            _profiler.record("{}-postp_0".format(tag))
            output_data = self.postprocess(data)
            _profiler.record("{}-postp_1".format(tag))

            if not isinstance(output_data,
                              python_service_channel_pb2.ChannelData):
                raise TypeError(
                    self._log(
                        'output_data must be ChannelData type, but get {}'.
                        format(type(output_data))))
            # the result keeps the id of the request it was computed from
            output_data.id = data_id

            _profiler.record("{}-push_0".format(tag))
            for channel in self._outputs:
                channel.push(output_data, self._name)
            _profiler.record("{}-push_1".format(tag))

    def _log(self, info_str):
        """Prefix ``info_str`` with the Op name."""
        return "[{}] {}".format(self._name, info_str)

    def get_concurrency(self):
        """Return the number of concurrent workers requested for this Op."""
        return self._concurrency

B
barrierye 已提交
398 399 400

class GeneralPythonService(
        general_python_service_pb2_grpc.GeneralPythonService):
    """gRPC service that feeds requests into the Op graph and returns results.

    Requests are packed into ChannelData, pushed to the input channel and
    matched by id with the results a background thread collects from the
    output channel.
    """

    def __init__(self, in_channel, out_channel):
        super(GeneralPythonService, self).__init__()
        self._name = "#G"
        self.set_in_channel(in_channel)
        self.set_out_channel(out_channel)
        for channel in (in_channel, out_channel):
            logging.debug(self._log(channel.debug()))
        #TODO:
        #  multi-lock for different clients
        #  different lock for server and client
        self._id_lock = threading.Lock()
        self._cv = threading.Condition()
        self._globel_resp_dict = {}  # {data_id: ChannelData}
        self._id_counter = 0
        # collector thread: moves results from the output channel into
        # _globel_resp_dict; it is started right away
        self._recive_func = threading.Thread(
            target=GeneralPythonService._recive_out_channel_func, args=(self, ))
        self._recive_func.start()

    def _log(self, info_str):
        """Prefix ``info_str`` with the service name."""
        return "[{}] {}".format(self._name, info_str)

    def set_in_channel(self, in_channel):
        """Register the service as producer of ``in_channel`` and keep it.

        Raises TypeError when ``in_channel`` is not a Channel.
        """
        if isinstance(in_channel, Channel):
            in_channel.add_producer(self._name)
            self._in_channel = in_channel
            return
        raise TypeError(
            self._log('in_channel must be Channel type, but get {}'.format(
                type(in_channel))))

    def set_out_channel(self, out_channel):
        """Register the service as consumer of ``out_channel`` and keep it.

        Raises TypeError when ``out_channel`` is not a Channel.
        """
        if isinstance(out_channel, Channel):
            out_channel.add_consumer(self._name)
            self._out_channel = out_channel
            return
        raise TypeError(
            self._log('out_channel must be Channel type, but get {}'.format(
                type(out_channel))))

    def _recive_out_channel_func(self):
        """Loop forever, storing each result from the output channel by id.

        Raises TypeError when the channel yields something that is not
        ChannelData.
        """
        while True:
            result = self._out_channel.front(self._name)
            if not isinstance(result, python_service_channel_pb2.ChannelData):
                raise TypeError(
                    self._log('data must be ChannelData type, but get {}'.
                              format(type(result))))
            with self._cv:
                self._globel_resp_dict[result.id] = result
                self._cv.notify_all()

    def _get_next_id(self):
        """Return the current request id and advance the counter."""
        with self._id_lock:
            current = self._id_counter
            self._id_counter = current + 1
        return current

    def _get_data_in_globel_resp_dict(self, data_id):
        """Wait until the result for ``data_id`` is stored, then remove and return it."""
        with self._cv:
            while data_id not in self._globel_resp_dict:
                self._cv.wait()
            found = self._globel_resp_dict.pop(data_id)
            self._cv.notify_all()
        return found

    def _pack_data_for_infer(self, request):
        """Build ChannelData with a fresh id from ``request``.

        Returns a ``(data, data_id)`` tuple; one Inst is appended per entry
        of ``request.feed_var_names`` with the matching ``feed_insts`` item.
        """
        logging.debug(self._log('start inferce'))
        packed = python_service_channel_pb2.ChannelData()
        data_id = self._get_next_id()
        packed.id = data_id
        for idx, name in enumerate(request.feed_var_names):
            logging.debug(self._log('name: {}'.format(name)))
            raw = request.feed_insts[idx]
            logging.debug(self._log('data: {}'.format(raw)))
            inst = python_service_channel_pb2.Inst()
            inst.data = raw
            inst.name = name
            packed.insts.append(inst)
        return packed, data_id

    def _pack_data_for_resp(self, data):
        """Copy the name and data of every Inst in ``data`` into a Response."""
        logging.debug(self._log('get data'))
        response = general_python_service_pb2.Response()
        logging.debug(self._log('gen resp'))
        logging.debug(data)
        for inst in data.insts:
            logging.debug(self._log('append data'))
            response.fetch_insts.append(inst.data)
            logging.debug(self._log('append name'))
            response.fetch_var_names.append(inst.name)
        return response

    def inference(self, request, context):
        """Handle one request: pack, push, wait for the result, unpack.

        Every stage is bracketed by profiler records, and the profile is
        printed before the response is returned.
        """
        name = self._name

        _profiler.record("{}-prepack_0".format(name))
        data, data_id = self._pack_data_for_infer(request)
        _profiler.record("{}-prepack_1".format(name))

        logging.debug(self._log('push data'))
        _profiler.record("{}-push_0".format(name))
        self._in_channel.push(data, name)
        _profiler.record("{}-push_1".format(name))

        logging.debug(self._log('wait for infer'))
        _profiler.record("{}-fetch_0".format(name))
        resp_data = self._get_data_in_globel_resp_dict(data_id)
        _profiler.record("{}-fetch_1".format(name))

        _profiler.record("{}-postpack_0".format(name))
        resp = self._pack_data_for_resp(resp_data)
        _profiler.record("{}-postpack_1".format(name))
        _profiler.print_profile()
        return resp

B
barrierye 已提交
512 513

class PyServer(object):
    """Wire Ops together and expose them through a gRPC server.

    Typical use: ``add_op()`` for every Op, ``prepare_server()`` to locate
    the graph's entry and exit channels, then ``run_server()``.
    """

    def __init__(self, profile=False):
        self._channels = []
        self._ops = []
        self._op_threads = []
        self._port = None
        self._worker_num = None
        self._in_channel = None
        self._out_channel = None
        _profiler.enable(profile)

    def add_channel(self, channel):
        """Remember ``channel``."""
        self._channels.append(channel)

    def add_op(self, op):
        """Remember ``op``; it is started by ``run_server()``."""
        self._ops.append(op)

    def gen_desc(self):
        """Placeholder: only logs that a description would be generated."""
        logging.info('here will generate desc for paas')

    def prepare_server(self, port, worker_num):
        """Store the server settings and find the graph's boundary channels.

        The entry channel is the one that is an input of some Op but an
        output of none; the exit channel is the opposite. Raises Exception
        unless there is exactly one of each.
        """
        self._port = port
        self._worker_num = worker_num
        inputs = set()
        outputs = set()
        for op in self._ops:
            inputs.add(op.get_input())
            outputs.update(op.get_outputs())
            if op.with_serving():
                self.prepare_serving(op)
        in_channel = inputs - outputs
        out_channel = outputs - inputs
        if len(in_channel) != 1 or len(out_channel) != 1:
            raise Exception(
                "in_channel(out_channel) more than 1 or no in_channel(out_channel)"
            )
        self._in_channel = in_channel.pop()
        self._out_channel = out_channel.pop()
        self.gen_desc()

    def _op_start_wrapper(self, op, concurrency_idx):
        """Thread target: run the loop of ``op`` for one worker."""
        return op.start(concurrency_idx)

    def _run_ops(self):
        """Start one thread per requested concurrency slot of every Op."""
        for op in self._ops:
            op_concurrency = op.get_concurrency()
            logging.debug("run op: {}, op_concurrency: {}".format(
                op._name, op_concurrency))
            for c in range(op_concurrency):
                th = threading.Thread(
                    target=self._op_start_wrapper, args=(op, c))
                th.start()
                self._op_threads.append(th)

    def run_server(self):
        """Start the Op threads and the gRPC server, then block.

        A KeyboardInterrupt raised while waiting stops the gRPC server
        immediately.
        """
        self._run_ops()
        server = grpc.server(
            futures.ThreadPoolExecutor(max_workers=self._worker_num))
        general_python_service_pb2_grpc.add_GeneralPythonServiceServicer_to_server(
            GeneralPythonService(self._in_channel, self._out_channel), server)
        server.add_insecure_port('[::]:{}'.format(self._port))
        server.start()
        try:
            for th in self._op_threads:
                th.join()
            # grpc.Server has no join(); wait_for_termination() is the
            # blocking call it provides
            server.wait_for_termination()
        except KeyboardInterrupt:
            server.stop(0)

    def prepare_serving(self, op):
        """Build and log the command that would launch a model server for ``op``.

        The command is only logged; actually launching it is currently
        disabled.
        """
        model_path = op._server_model
        port = op._server_port
        device = op._device

        if device == "cpu":
            module = "paddle_serving_server"
        else:
            module = "paddle_serving_server_gpu"
        cmd = "python -m {}.serve --model {} --thread 4 --port {} &>/dev/null &".format(
            module, model_path, port)
        # run a server (not in PyServing)
        logging.info("run a server (not in PyServing): {}".format(cmd))
        # os.system(cmd)