pyserver.py 35.2 KB
Newer Older
B
barrierye 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
import threading
import multiprocessing
B
barrierye 已提交
17
import Queue
B
barrierye 已提交
18
import os
B
barrierye 已提交
19
import sys
B
barrierye 已提交
20
import paddle_serving_server
21
from paddle_serving_client import MultiLangClient as Client
B
barrierye 已提交
22
from concurrent import futures
B
barrierye 已提交
23
import numpy as np
B
barrierye 已提交
24
import grpc
25 26 27 28
from .proto import general_model_config_pb2 as m_config
from .proto import general_python_service_pb2 as pyservice_pb2
from .proto import pyserving_channel_pb2 as channel_pb2
from .proto import general_python_service_pb2_grpc
B
barrierye 已提交
29
import logging
30
import random
B
barrierye 已提交
31
import time
B
barrierye 已提交
32
import func_timeout
33
import enum
34
import collections
B
barrierye 已提交
35 36


B
barrierye 已提交
37 38 39 40 41 42 43 44 45 46 47
class _TimeProfiler(object):
    def __init__(self):
        self._pid = os.getpid()
        self._print_head = 'PROFILE\tpid:{}\t'.format(self._pid)
        self._time_record = Queue.Queue()
        self._enable = False

    def enable(self, enable):
        self._enable = enable

    def record(self, name_with_tag):
B
bug fix  
barrierye 已提交
48 49
        if self._enable is False:
            return
B
barrierye 已提交
50 51 52 53 54 55
        name_with_tag = name_with_tag.split("_")
        tag = name_with_tag[-1]
        name = '_'.join(name_with_tag[:-1])
        self._time_record.put((name, tag, int(round(time.time() * 1000000))))

    def print_profile(self):
B
bug fix  
barrierye 已提交
56 57
        if self._enable is False:
            return
B
barrierye 已提交
58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76
        sys.stderr.write(self._print_head)
        tmp = {}
        while not self._time_record.empty():
            name, tag, timestamp = self._time_record.get()
            if name in tmp:
                ptag, ptimestamp = tmp.pop(name)
                sys.stderr.write("{}_{}:{} ".format(name, ptag, ptimestamp))
                sys.stderr.write("{}_{}:{} ".format(name, tag, timestamp))
            else:
                tmp[name] = (tag, timestamp)
        sys.stderr.write('\n')
        for name, item in tmp.items():
            tag, timestamp = item
            self._time_record.put((name, tag, timestamp))


_profiler = _TimeProfiler()


77 78 79
class ChannelDataEcode(enum.Enum):
    """Error codes stored in the ``ecode`` field of channel data."""
    OK = 0  # no error
    TIMEOUT = 1  # midprocess timed out on every retry
    NOT_IMPLEMENTED = 2  # preprocess raised NotImplementedError
    TYPE_ERROR = 3  # postprocess did not return a dict
    UNKNOW = 4  # any other exception raised by midprocess
83 84 85 86 87 88 89 90 91 92 93 94


class ChannelDataType(enum.Enum):
    """Payload kinds stored in the ``type`` field of channel data."""
    # the tensors are serialized inside the protobuf message itself
    CHANNEL_PBDATA = 0
    # the result has to be obtained from the attached future object
    CHANNEL_FUTURE = 1


class ChannelData(object):
    """One unit of data travelling through a Channel.

    Wraps a ``channel_pb2.ChannelData`` protobuf message (``pbdata``) and,
    for future-typed data, the future object plus an optional callback that
    is applied to the future's result.
    """

    def __init__(self,
                 future=None,
                 pbdata=None,
                 data_id=None,
                 callback_func=None,
                 ecode=None,
                 error_info=None):
        '''
        There are several ways to use it:
        
        - ChannelData(future, pbdata[, callback_func])
        - ChannelData(future, data_id[, callback_func])
        - ChannelData(pbdata)
        - ChannelData(ecode, error_info, data_id)

        Raises:
            ValueError: a required data_id/error_info is missing.
            TypeError: pbdata is given but is not a channel_pb2.ChannelData.
        '''
        if ecode is not None:
            # error message: only id, ecode and error_info are filled in
            if data_id is None or error_info is None:
                raise ValueError("data_id and error_info cannot be None")
            pbdata = channel_pb2.ChannelData()
            pbdata.ecode = ecode
            pbdata.id = data_id
            pbdata.error_info = error_info
        else:
            if pbdata is None:
                # no message supplied: build a future-typed one from data_id
                if data_id is None:
                    raise ValueError("data_id cannot be None")
                pbdata = channel_pb2.ChannelData()
                pbdata.type = ChannelDataType.CHANNEL_FUTURE.value
                pbdata.ecode = ChannelDataEcode.OK.value
                pbdata.id = data_id
            elif not isinstance(pbdata, channel_pb2.ChannelData):
                raise TypeError(
                    "pbdata must be pyserving_channel_pb2.ChannelData type({})".
                    format(type(pbdata)))
        self.future = future
        self.pbdata = pbdata
        self.callback_func = callback_func

    def parse(self):
        """Return the payload as a feed object.

        For CHANNEL_PBDATA a dict ``{inst.name: numpy array}`` is rebuilt
        from the serialized instances.  For CHANNEL_FUTURE the result of the
        future is returned, passed through ``callback_func`` when one is set.

        Raises:
            TypeError: ``pbdata.type`` is neither of the known types.
        """
        # return narray
        feed = {}
        if self.pbdata.type == ChannelDataType.CHANNEL_PBDATA.value:
            for inst in self.pbdata.insts:
                feed[inst.name] = np.frombuffer(inst.data, dtype=inst.type)
                feed[inst.name].shape = np.frombuffer(inst.shape, dtype="int32")
        elif self.pbdata.type == ChannelDataType.CHANNEL_FUTURE.value:
            feed = self.future.result()
            if self.callback_func is not None:
                feed = self.callback_func(feed)
        else:
            # ChannelData has no _log() helper, so format the message directly
            raise TypeError("Error type({}) in pbdata.type.".format(
                self.pbdata.type))
        return feed


B
barrierye 已提交
147
class Channel(Queue.Queue):
    """
    The channel used for communication between Ops.

    1. Support multiple different Op feed data (multiple producer)
        Different types of data will be packaged through the data ID
    2. Support multiple different Op fetch data (multiple consumer)
        Only when all types of Ops get the data of the same ID,
        the data will be poped; The Op of the same type will not
        get the data of the same ID.
    3. (TODO) Timeout and BatchSize are not fully supported.

    Note:
    1. The ID of the data in the channel must be different.
    2. The function add_producer() and add_consumer() are not thread safe,
       and can only be called during initialization.
    """

    def __init__(self, name=None, maxsize=-1, timeout=None):
        Queue.Queue.__init__(self, maxsize=maxsize)
        self._maxsize = maxsize
        self._timeout = timeout
        self.name = name
        self._stop = False

        # guards all bookkeeping below and is used to block/wake Ops
        self._cv = threading.Condition()

        self._producers = []
        self._producer_res_count = {}  # {data_id: count}
        self._push_res = {}  # {data_id: {op_name: data}}

        self._consumers = {}  # {op_name: idx}
        self._idx_consumer_num = {}  # {idx: num}
        self._consumer_base_idx = 0
        self._front_res = []

    def get_producers(self):
        """Return the list of registered producer names."""
        return self._producers

    def get_consumers(self):
        """Return the names of the registered consumers."""
        return self._consumers.keys()

    def _log(self, info_str):
        """Prefix a message with the channel name."""
        return "[{}] {}".format(self.name, info_str)

    def debug(self):
        """Return a string listing the producers and consumers."""
        return self._log("p: {}, c: {}".format(self.get_producers(),
                                               self.get_consumers()))

    def add_producer(self, op_name):
        """ not thread safe, and can only be called during initialization. """
        if op_name in self._producers:
            raise ValueError(
                self._log("producer({}) is already in channel".format(op_name)))
        self._producers.append(op_name)

    def add_consumer(self, op_name):
        """ not thread safe, and can only be called during initialization. """
        if op_name in self._consumers:
            raise ValueError(
                self._log("consumer({}) is already in channel".format(op_name)))
        # every consumer starts at data index 0
        self._consumers[op_name] = 0

        if self._idx_consumer_num.get(0) is None:
            self._idx_consumer_num[0] = 0
        self._idx_consumer_num[0] += 1

    def _put_or_wait(self, item):
        """Put item into the queue, waiting on the condition while it is full.

        The caller must hold self._cv.  Gives up silently once the channel
        has been stopped.
        """
        while self._stop is False:
            try:
                self.put(item, timeout=0)
                break
            except Queue.Full:  # put() signals a full queue with Full
                self._cv.wait()

    def push(self, channeldata, op_name=None):
        """Feed data produced by op_name into the channel.

        With a single producer the data is queued directly.  With several
        producers the data is held back until every producer has pushed data
        with the same id; then the dict {op_name: channeldata} is queued.

        Returns:
            True.

        Raises:
            Exception: no producer is registered, or op_name is None although
                there are multiple producers.
        """
        logging.debug(
            self._log("{} try to push data: {}".format(op_name,
                                                       channeldata.pbdata)))
        if len(self._producers) == 0:
            raise Exception(
                self._log(
                    "expected number of producers to be greater than 0, but the it is 0."
                ))
        elif len(self._producers) == 1:
            with self._cv:
                self._put_or_wait(channeldata)
                self._cv.notify_all()
            logging.debug(self._log("{} push data succ!".format(op_name)))
            return True
        elif op_name is None:
            raise Exception(
                self._log(
                    "There are multiple producers, so op_name cannot be None."))

        producer_num = len(self._producers)
        data_id = channeldata.pbdata.id
        put_data = None
        with self._cv:
            logging.debug(self._log("{} get lock".format(op_name)))
            if data_id not in self._push_res:
                self._push_res[data_id] = {
                    name: None
                    for name in self._producers
                }
                self._producer_res_count[data_id] = 0
            self._push_res[data_id][op_name] = channeldata
            if self._producer_res_count[data_id] + 1 == producer_num:
                # last producer for this id: the group is complete
                put_data = self._push_res[data_id]
                self._push_res.pop(data_id)
                self._producer_res_count.pop(data_id)
            else:
                self._producer_res_count[data_id] += 1

            if put_data is None:
                logging.debug(
                    self._log("{} push data succ, but not push to queue.".
                              format(op_name)))
            else:
                self._put_or_wait(put_data)
                logging.debug(
                    self._log("multi | {} push data succ!".format(op_name)))
            self._cv.notify_all()
        return True

    def front(self, op_name=None):
        """Fetch the next piece of data for consumer op_name.

        With a single consumer the data is simply taken from the queue.  With
        several consumers each one sees every piece of data once; the data is
        dropped after all consumers have read it.

        Returns:
            The data (a shared reference, treat it as read only), or None
            when the channel was stopped before data became available.

        Raises:
            Exception: no consumer is registered, or op_name is None although
                there are multiple consumers.
        """
        logging.debug(self._log("{} try to get data".format(op_name)))
        if len(self._consumers) == 0:
            raise Exception(
                self._log(
                    "expected number of consumers to be greater than 0, but the it is 0."
                ))
        elif len(self._consumers) == 1:
            resp = None
            with self._cv:
                while self._stop is False and resp is None:
                    try:
                        resp = self.get(timeout=0)
                        break
                    except Queue.Empty:
                        self._cv.wait()
                # a slot may have been freed: wake producers waiting for space
                self._cv.notify_all()
            logging.debug(self._log("{} get data succ!".format(op_name)))
            return resp
        elif op_name is None:
            raise Exception(
                self._log(
                    "There are multiple consumers, so op_name cannot be None."))

        with self._cv:
            # data_idx = consumer_idx - base_idx
            while self._stop is False and self._consumers[
                    op_name] - self._consumer_base_idx >= len(self._front_res):
                try:
                    channeldata = self.get(timeout=0)
                    self._front_res.append(channeldata)
                    break
                except Queue.Empty:
                    self._cv.wait()

            consumer_idx = self._consumers[op_name]
            base_idx = self._consumer_base_idx
            data_idx = consumer_idx - base_idx
            if data_idx >= len(self._front_res):
                # the channel was stopped before new data arrived
                self._cv.notify_all()
                return None
            resp = self._front_res[data_idx]
            logging.debug(self._log("{} get data: {}".format(op_name, resp)))

            self._idx_consumer_num[consumer_idx] -= 1
            if consumer_idx == base_idx and self._idx_consumer_num[
                    consumer_idx] == 0:
                # every consumer has read the oldest entry: drop it
                self._idx_consumer_num.pop(consumer_idx)
                self._front_res.pop(0)
                self._consumer_base_idx += 1

            self._consumers[op_name] += 1
            new_consumer_idx = self._consumers[op_name]
            if self._idx_consumer_num.get(new_consumer_idx) is None:
                self._idx_consumer_num[new_consumer_idx] = 0
            self._idx_consumer_num[new_consumer_idx] += 1

            self._cv.notify_all()

        logging.debug(self._log("multi | {} get data succ!".format(op_name)))
        return resp  # reference, read only

    def stop(self):
        """Mark the channel as stopped and wake every waiting thread."""
        #TODO
        with self._cv:
            self._stop = True
            self._cv.notify_all()

B
barrierye 已提交
338 339 340

class Op(object):
    """One processing node: reads from an input channel, writes to outputs.

    start() runs the loop preprocess -> midprocess -> postprocess for every
    piece of data fetched from the input channel and pushes the result (or
    an error ChannelData) to every output channel.  midprocess is only run
    when a client has been configured (see set_client()).
    """

    def __init__(self,
                 name,
                 inputs,
                 server_model=None,
                 server_port=None,
                 device=None,
                 client_config=None,
                 server_name=None,
                 fetch_names=None,
                 concurrency=1,
                 timeout=-1,
                 retry=2):
        self._run = False
        self.name = name  # to identify the type of OP, it must be globally unique
        self._concurrency = concurrency  # amount of concurrency
        self.set_input_ops(inputs)
        self.set_client(client_config, server_name, fetch_names)
        self._server_model = server_model
        self._server_port = server_port
        self._device = device
        self._timeout = timeout  # seconds per midprocess call; <= 0: no limit
        self._retry = retry  # number of midprocess attempts when timing out
        self._input = None
        self._outputs = []

    def set_client(self, client_config, server_name, fetch_names):
        """Create and connect the client; no-op if any argument is None."""
        self._client = None
        self._fetch_names = None
        if client_config is None or \
                server_name is None or \
                fetch_names is None:
            return
        self._client = Client()
        self._client.load_client_config(client_config)
        self._client.connect([server_name])
        self._fetch_names = fetch_names

    def with_serving(self):
        """Return True when a client has been set up for this Op."""
        return self._client is not None

    def get_input_channel(self):
        """Return the input channel (None until add_input_channel())."""
        return self._input

    def get_input_ops(self):
        """Return the list of predecessor Ops."""
        return self._input_ops

    def set_input_ops(self, ops):
        """Set the predecessor Ops from None, a single Op or a list of Ops.

        Raises:
            TypeError: an element is not an Op.
        """
        if not isinstance(ops, list):
            ops = [] if ops is None else [ops]
        self._input_ops = []
        for op in ops:
            if not isinstance(op, Op):
                raise TypeError(
                    self._log('input op must be Op type, not {}'.format(
                        type(op))))
            self._input_ops.append(op)

    def add_input_channel(self, channel):
        """Register this Op as consumer of channel and use it as input.

        Raises:
            TypeError: channel is not a Channel.
        """
        if not isinstance(channel, Channel):
            raise TypeError(
                self._log('input channel must be Channel type, not {}'.format(
                    type(channel))))
        channel.add_consumer(self.name)
        self._input = channel

    def get_output_channels(self):
        """Return the list of output channels."""
        return self._outputs

    def add_output_channel(self, channel):
        """Register this Op as producer of channel and add it as an output.

        Raises:
            TypeError: channel is not a Channel.
        """
        if not isinstance(channel, Channel):
            raise TypeError(
                self._log('output channel must be Channel type, not {}'.format(
                    type(channel))))
        channel.add_producer(self.name)
        self._outputs.append(channel)

    def preprocess(self, channeldata):
        """Turn the input into the feed handed to midprocess/postprocess.

        The default only supports a single predecessor; a dict (data from
        several predecessors) requires the method to be overridden.

        Raises:
            NotImplementedError: channeldata is a dict.
        """
        if isinstance(channeldata, dict):
            raise NotImplementedError(
                'this Op has multiple previous inputs. Please override this method'
            )
        feed = channeldata.parse()
        return feed

    def midprocess(self, data):
        """Send data to the client with asyn=True and return its result.

        Raises:
            Exception: data is not a dict.
        """
        if not isinstance(data, dict):
            raise Exception(
                self._log(
                    'data must be dict type(the output of preprocess()), but get {}'.
                    format(type(data))))
        logging.debug(self._log('data: {}'.format(data)))
        logging.debug(self._log('fetch: {}'.format(self._fetch_names)))
        call_future = self._client.predict(
            feed=data, fetch=self._fetch_names, asyn=True)
        logging.debug(self._log("get call_future"))
        return call_future

    def postprocess(self, output_data):
        """Hook for post-processing; the default returns its input as is."""
        return output_data

    def stop(self):
        """Stop the attached channels and end the start() loop."""
        if self._input is not None:  # may not have been attached yet
            self._input.stop()
        for channel in self._outputs:
            channel.stop()
        self._run = False

    def _parse_channeldata(self, channeldata):
        """Return (data_id, error_pbdata) for the given input.

        channeldata is either one ChannelData or a dict {op_name:
        ChannelData}.  error_pbdata is the pbdata of the first element whose
        ecode is non-zero, or None when there is no error.
        """
        data_id, error_data = None, None
        if isinstance(channeldata, dict):
            # take the id from the first entry of the dict
            key = next(iter(channeldata))
            data_id = channeldata[key].pbdata.id
            for data in channeldata.values():
                if data.pbdata.ecode != 0:
                    error_data = data.pbdata
                    break
        else:
            data_id = channeldata.pbdata.id
            if channeldata.pbdata.ecode != 0:
                error_data = channeldata.pbdata
        return data_id, error_data

    def _push_to_output_channels(self, data):
        """Push data to every output channel under this Op's name."""
        for channel in self._outputs:
            channel.push(data, self.name)

    def _push_error_to_output_channels(self, ecode, error_info, data_id):
        """Push an error ChannelData to every output channel."""
        self._push_to_output_channels(
            ChannelData(
                ecode=ecode, error_info=error_info, data_id=data_id))

    def start(self, concurrency_idx):
        """Run the processing loop until stop() is called.

        concurrency_idx is only used to build the log/profile prefix.
        """
        op_info_prefix = "[{}{}]".format(self.name, concurrency_idx)
        log = self._get_log_func(op_info_prefix)
        self._run = True
        while self._run:
            _profiler.record("{}-get_0".format(op_info_prefix))
            input_data = self._input.front(self.name)
            _profiler.record("{}-get_1".format(op_info_prefix))
            logging.debug(log("input_data: {}".format(input_data)))

            if input_data is None:
                # front() yields None only after the channel was stopped
                break

            data_id, error_data = self._parse_channeldata(input_data)

            # predecessor Op error
            if error_data is not None:
                self._push_to_output_channels(ChannelData(pbdata=error_data))
                continue

            # preprocess function not implemented
            try:
                _profiler.record("{}-prep_0".format(op_info_prefix))
                data = self.preprocess(input_data)
                _profiler.record("{}-prep_1".format(op_info_prefix))
            except NotImplementedError as e:
                error_info = log(e)
                logging.error(error_info)
                self._push_error_to_output_channels(
                    ChannelDataEcode.NOT_IMPLEMENTED.value, error_info,
                    data_id)
                continue

            # midprocess
            call_future = None
            ecode = 0
            error_info = None
            if self.with_serving():
                _profiler.record("{}-midp_0".format(op_info_prefix))
                if self._timeout <= 0:
                    # no time limit: a single attempt
                    try:
                        call_future = self.midprocess(data)
                    except Exception as e:
                        ecode = ChannelDataEcode.UNKNOW.value
                        error_info = log(e)
                        logging.error(error_info)
                else:
                    for i in range(self._retry):
                        try:
                            call_future = func_timeout.func_timeout(
                                self._timeout, self.midprocess, args=(data, ))
                        except func_timeout.FunctionTimedOut:
                            if i + 1 >= self._retry:
                                # the last attempt timed out as well
                                ecode = ChannelDataEcode.TIMEOUT.value
                                error_info = "{} timeout".format(op_info_prefix)
                            else:
                                logging.warning(
                                    log("warn: timeout, retry({})".format(i +
                                                                          1)))
                        except Exception as e:
                            ecode = ChannelDataEcode.UNKNOW.value
                            error_info = log(e)
                            logging.error(error_info)
                            break
                        else:
                            break
                if ecode != 0:
                    self._push_error_to_output_channels(ecode, error_info,
                                                        data_id)
                    continue
                _profiler.record("{}-midp_1".format(op_info_prefix))

            # postprocess
            output_data = None
            _profiler.record("{}-postp_0".format(op_info_prefix))
            if self.with_serving():  # use call_future
                # postprocess runs later, when the consumer parses the data
                output_data = ChannelData(
                    future=call_future,
                    data_id=data_id,
                    callback_func=self.postprocess)
            else:
                post_data = self.postprocess(data)
                if not isinstance(post_data, dict):
                    ecode = ChannelDataEcode.TYPE_ERROR.value
                    error_info = log("output of postprocess funticon must be " \
                            "dict type, but get {}".format(type(post_data)))
                    logging.error(error_info)
                    self._push_error_to_output_channels(ecode, error_info,
                                                        data_id)
                    continue
                # serialize every array of the result into the message
                pbdata = channel_pb2.ChannelData()
                for name, value in post_data.items():
                    inst = channel_pb2.Inst()
                    inst.data = value.tobytes()
                    inst.name = name
                    inst.shape = np.array(value.shape, dtype="int32").tobytes()
                    inst.type = str(value.dtype)
                    pbdata.insts.append(inst)
                pbdata.ecode = 0
                pbdata.id = data_id
                output_data = ChannelData(pbdata=pbdata)
            _profiler.record("{}-postp_1".format(op_info_prefix))

            # push data to channel (if run succ)
            _profiler.record("{}-push_0".format(op_info_prefix))
            self._push_to_output_channels(output_data)
            _profiler.record("{}-push_1".format(op_info_prefix))

    def _log(self, info):
        """Prefix a message with the Op name."""
        return "{} {}".format(self.name, info)

    def _get_log_func(self, op_info_prefix):
        """Return a function that prefixes messages with op_info_prefix."""

        def log_func(info_str):
            return "{} {}".format(op_info_prefix, info_str)

        return log_func

    def get_concurrency(self):
        """Return the configured amount of concurrency."""
        return self._concurrency

B
barrierye 已提交
589 590 591

class GeneralPythonService(
        general_python_service_pb2_grpc.GeneralPythonService):
    """gRPC service that feeds requests into the Op graph and returns results.

    Requests are packed into ChannelData and pushed into in_channel.  A
    background thread keeps reading out_channel and stores every result by
    its data id, where the waiting inference() call picks it up.
    """

    def __init__(self, in_channel, out_channel, retry=2):
        super(GeneralPythonService, self).__init__()
        self.name = "#G"
        self.set_in_channel(in_channel)
        self.set_out_channel(out_channel)
        logging.debug(self._log(in_channel.debug()))
        logging.debug(self._log(out_channel.debug()))
        #TODO:
        #  multi-lock for different clients
        #  different lock for server and client
        self._id_lock = threading.Lock()  # guards _id_counter
        self._cv = threading.Condition()  # guards _globel_resp_dict
        self._globel_resp_dict = {}  # {data_id: ChannelData}
        self._id_counter = 0
        self._retry = retry
        self._recive_func = threading.Thread(
            target=GeneralPythonService._recive_out_channel_func, args=(self, ))
        self._recive_func.start()

    def _log(self, info_str):
        """Prefix a message with the service name."""
        return "[{}] {}".format(self.name, info_str)

    def set_in_channel(self, in_channel):
        """Register the service as producer of in_channel.

        Raises:
            TypeError: in_channel is not a Channel.
        """
        if not isinstance(in_channel, Channel):
            raise TypeError(
                self._log('in_channel must be Channel type, but get {}'.format(
                    type(in_channel))))
        in_channel.add_producer(self.name)
        self._in_channel = in_channel

    def set_out_channel(self, out_channel):
        """Register the service as consumer of out_channel.

        Raises:
            TypeError: out_channel is not a Channel.
        """
        if not isinstance(out_channel, Channel):
            raise TypeError(
                self._log('out_channel must be Channel type, but get {}'.format(
                    type(out_channel))))
        out_channel.add_consumer(self.name)
        self._out_channel = out_channel

    def _recive_out_channel_func(self):
        """Thread body: move results from out_channel into the result dict.

        Raises:
            TypeError: the channel yields something that is not ChannelData.
        """
        while True:
            channeldata = self._out_channel.front(self.name)
            if not isinstance(channeldata, ChannelData):
                raise TypeError(
                    self._log('data must be ChannelData type, but get {}'.
                              format(type(channeldata))))
            with self._cv:
                data_id = channeldata.pbdata.id
                self._globel_resp_dict[data_id] = channeldata
                self._cv.notify_all()

    def _get_next_id(self):
        """Return a new data id (a counter starting at 0)."""
        with self._id_lock:
            self._id_counter += 1
            return self._id_counter - 1

    def _get_data_in_globel_resp_dict(self, data_id):
        """Block until the result for data_id is available and return it."""
        resp = None
        with self._cv:
            while data_id not in self._globel_resp_dict:
                self._cv.wait()
            resp = self._globel_resp_dict.pop(data_id)
            self._cv.notify_all()
        return resp

    def _pack_data_for_infer(self, request):
        """Convert a request into (ChannelData, data_id)."""
        logging.debug(self._log('start inferce'))
        pbdata = channel_pb2.ChannelData()
        data_id = self._get_next_id()
        pbdata.id = data_id
        for idx, name in enumerate(request.feed_var_names):
            logging.debug(
                self._log('name: {}'.format(request.feed_var_names[idx])))
            logging.debug(self._log('data: {}'.format(request.feed_insts[idx])))
            inst = channel_pb2.Inst()
            inst.data = request.feed_insts[idx]
            inst.shape = request.shape[idx]
            inst.name = name
            inst.type = request.type[idx]
            pbdata.insts.append(inst)
        pbdata.ecode = 0  #TODO: parse request error
        return ChannelData(pbdata=pbdata), data_id

    def _pack_data_for_resp(self, channeldata):
        """Convert a result ChannelData into a pyservice_pb2.Response.

        On success the response carries the serialized variables; otherwise
        only ecode and error_info are filled in.

        Raises:
            TypeError: pbdata.type is neither of the known types.
        """
        logging.debug(self._log('get channeldata'))
        logging.debug(self._log('gen resp'))
        resp = pyservice_pb2.Response()
        resp.ecode = channeldata.pbdata.ecode
        if resp.ecode == 0:
            if channeldata.pbdata.type == ChannelDataType.CHANNEL_PBDATA.value:
                for inst in channeldata.pbdata.insts:
                    logging.debug(self._log('append data'))
                    resp.fetch_insts.append(inst.data)
                    logging.debug(self._log('append name'))
                    resp.fetch_var_names.append(inst.name)
                    logging.debug(self._log('append shape'))
                    resp.shape.append(inst.shape)
                    logging.debug(self._log('append type'))
                    resp.type.append(inst.type)
            elif channeldata.pbdata.type == ChannelDataType.CHANNEL_FUTURE.value:
                # ChannelData stores the future in its `future` attribute
                feed = channeldata.future.result()
                if channeldata.callback_func is not None:
                    feed = channeldata.callback_func(feed)
                # iterate (name, array) pairs, not just the dict keys
                for name, var in feed.items():
                    logging.debug(self._log('append data'))
                    resp.fetch_insts.append(var.tobytes())
                    logging.debug(self._log('append name'))
                    resp.fetch_var_names.append(name)
                    logging.debug(self._log('append shape'))
                    resp.shape.append(
                        np.array(
                            var.shape, dtype="int32").tobytes())
                    resp.type.append(str(var.dtype))
            else:
                raise TypeError(
                    self._log("Error type({}) in pbdata.type.".format(
                        channeldata.pbdata.type)))
        else:
            resp.error_info = channeldata.pbdata.error_info
        return resp

    def inference(self, request, context):
        """Handle one request: push it, wait for the result, build the reply.

        The request is pushed again (up to self._retry times in total) while
        the result carries a non-zero ecode.
        """
        _profiler.record("{}-prepack_0".format(self.name))
        data, data_id = self._pack_data_for_infer(request)
        _profiler.record("{}-prepack_1".format(self.name))

        resp_channeldata = None
        for i in range(self._retry):
            logging.debug(self._log('push data'))
            _profiler.record("{}-push_0".format(self.name))
            self._in_channel.push(data, self.name)
            _profiler.record("{}-push_1".format(self.name))

            logging.debug(self._log('wait for infer'))
            _profiler.record("{}-fetch_0".format(self.name))
            resp_channeldata = self._get_data_in_globel_resp_dict(data_id)
            _profiler.record("{}-fetch_1".format(self.name))

            if resp_channeldata.pbdata.ecode == 0:
                break
            if i + 1 < self._retry:
                logging.warning("retry({}): {}".format(
                    i + 1, resp_channeldata.pbdata.error_info))

        _profiler.record("{}-postpack_0".format(self.name))
        resp = self._pack_data_for_resp(resp_channeldata)
        _profiler.record("{}-postpack_1".format(self.name))
        _profiler.print_profile()
        return resp

B
barrierye 已提交
741 742

class PyServer(object):
B
barrierye 已提交
743
    def __init__(self, retry=2, profile=False):
        """Create an empty server; ops are added later and wired on prepare.

        Args:
            retry: number of attempts made per request; it is passed to the
                GeneralPythonService created in ``run_server``.
            profile: value handed to the module-level profiler's ``enable``.
        """
        self._channels = []
        self._user_ops = []
        self._total_ops = []
        # Ops that are actually run. _topo_sort() replaces this list; it is
        # initialised here so _run_ops()/_stop_ops() do not fail with
        # AttributeError when called before prepare_server().
        self._ops = []
        self._op_threads = []
        self._port = None
        self._worker_num = None
        self._in_channel = None
        self._out_channel = None
        self._retry = retry
        _profiler.enable(profile)
B
barrierye 已提交
754 755 756 757 758

    def add_channel(self, channel):
        """Append ``channel`` to this server's channel list."""
        registered = self._channels
        registered.append(channel)

    def add_op(self, op):
        """Append a single op to the list of user-defined ops."""
        user_ops = self._user_ops
        user_ops.append(op)

    def add_ops(self, ops):
        """Append every op from the iterable ``ops`` to the user-defined ops."""
        user_ops = self._user_ops
        user_ops.extend(ops)
B
barrierye 已提交
763 764

    def gen_desc(self):
        """Placeholder for PAAS description generation; only logs a message."""
        message = 'here will generate desc for PAAS'
        logging.info(message)

768 769
    def _topo_sort(self):
        """Sort the user ops topologically and connect them with channels.

        The work is done in three steps:

        1. Count the input ops of every user op and record, per op name,
           which ops consume its output.
        2. Split the ops into "views" (layers): the first view holds the ops
           without input ops, each following view holds the ops whose
           remaining input count dropped to zero while the previous view was
           processed.
        3. For every pair of adjacent views create the channels feeding the
           later view. Ops fed by the same set of predecessors share one
           channel. A successor that is not located in the directly
           following view is represented there by a virtual Op.

        Finally an output channel is attached to the single op of the last
        view.

        Side effects:
            ``self._ops`` is set to the virtual ops followed by every user op
            that has at least one input op; ``self._channels`` is set to all
            channels created here.

        Returns:
            Tuple ``(input_channel, output_channel)``. ``input_channel`` is
            the channel created for the successors of the first view, or
            ``None`` when there is only one view.

        Raises:
            Exception: if two ops share a name, if not every op could be
                sorted (e.g. the graph contains a cycle), or if the first or
                the last view does not contain exactly one op.
        """
        # --- step 1: in-degrees and successor lists -------------------------
        indeg_num = {}
        outdegs = {op.name: [] for op in self._user_ops}
        current_view = []  # ops that have no unprocessed input op left
        for op in self._user_ops:
            # check the name of op is globally unique
            if op.name in indeg_num:
                raise Exception("the name of Op must be unique")
            input_ops = op.get_input_ops()
            indeg_num[op.name] = len(input_ops)
            if indeg_num[op.name] == 0:
                current_view.append(op)
            for pred_op in input_ops:
                outdegs[pred_op.name].append(op)

        # --- step 2: get dag_views ------------------------------------------
        # Plain lists are enough here: the layers are built by this thread
        # only, so no thread-safe queue is required.
        dag_views = []
        sorted_op_num = 0
        while True:
            following_view = []
            for op in current_view:
                sorted_op_num += 1
                for succ_op in outdegs[op.name]:
                    indeg_num[succ_op.name] -= 1
                    if indeg_num[succ_op.name] == 0:
                        following_view.append(succ_op)
            dag_views.append(current_view)
            if not following_view:
                break
            current_view = following_view
        if sorted_op_num < len(self._user_ops):
            raise Exception("not legal DAG")
        if len(dag_views[0]) != 1:
            raise Exception("DAG contains multiple input Ops")
        if len(dag_views[-1]) != 1:
            raise Exception("DAG contains multiple output Ops")

        # --- step 3: create channels and virtual ops ------------------------
        virtual_op_idx = 0
        channel_idx = 0
        virtual_ops = []
        channels = []
        input_channel = None
        for v_idx, view in enumerate(dag_views[:-1]):
            next_view = dag_views[v_idx + 1]
            actual_next_view = []
            pred_op_of_next_view_op = {}
            for op in view:
                for succ_op in outdegs[op.name]:
                    if succ_op in next_view:
                        actual_next_view.append(succ_op)
                        pred_op_of_next_view_op.setdefault(
                            succ_op.name, []).append(op)
                    else:
                        # create virtual op standing in for a successor that
                        # lives in a later view
                        vop = Op(name="vir{}".format(virtual_op_idx), inputs=[])
                        virtual_op_idx += 1
                        virtual_ops.append(vop)
                        outdegs[vop.name] = [succ_op]
                        actual_next_view.append(vop)
                        # TODO: combine vop
                        # NOTE(review): the next iteration walks
                        # dag_views[v_idx + 1], which does not contain vop, so
                        # vop never receives an output channel in this
                        # method - confirm whether that is intended.
                        pred_op_of_next_view_op[vop.name] = [op]

            # create channel
            processed_op = set()
            for o_idx, op in enumerate(actual_next_view):
                op_name = op.name
                if op_name in processed_op:
                    continue
                channel = Channel(name="chl{}".format(channel_idx))
                channel_idx += 1
                channels.append(channel)
                op.add_input_channel(channel)
                pred_ops = pred_op_of_next_view_op[op_name]
                if v_idx == 0:
                    # The channel after the first view is the entry point of
                    # the whole DAG; the op of the first view is not given an
                    # output channel.
                    input_channel = channel
                else:
                    for pred_op in pred_ops:
                        pred_op.add_output_channel(channel)
                processed_op.add(op_name)

                # combine channel: later ops fed by exactly the same
                # predecessors read from the same channel
                for other_op in actual_next_view[o_idx + 1:]:
                    if other_op.name in processed_op:
                        continue
                    other_pred_ops = pred_op_of_next_view_op[other_op.name]
                    if len(other_pred_ops) != len(pred_ops):
                        continue
                    if all(pred_op in other_pred_ops for pred_op in pred_ops):
                        other_op.add_input_channel(channel)
                        processed_op.add(other_op.name)

        output_channel = Channel(name="Ochl")
        channels.append(output_channel)
        last_op = dag_views[-1][0]
        last_op.add_output_channel(output_channel)

        # Ops without input ops are left out of the list of ops to run.
        self._ops = virtual_ops
        for op in self._user_ops:
            if len(op.get_input_ops()) == 0:
                continue
            self._ops.append(op)
        self._channels = channels
        return input_channel, output_channel

B
barrierye 已提交
883 884 885
    def prepare_server(self, port, worker_num):
        """Store the server settings and build the op/channel graph.

        Args:
            port: port stored for use by ``run_server``.
            worker_num: worker count stored for use by ``run_server``.

        After ``_topo_sort`` has produced the input and output channels,
        ``prepare_serving`` is called for each op in ``self._ops`` whose
        ``with_serving()`` is true, and finally ``gen_desc`` is invoked.
        """
        self._port, self._worker_num = port, worker_num

        self._in_channel, self._out_channel = self._topo_sort()
        for candidate in self._ops:
            if not candidate.with_serving():
                continue
            self.prepare_serving(candidate)
        self.gen_desc()

B
barrierye 已提交
895 896
    def _op_start_wrapper(self, op, concurrency_idx):
        return op.start(concurrency_idx)
B
barrierye 已提交
897

898
    def _run_ops(self):
B
barrierye 已提交
899
        for op in self._ops:
900
            op_concurrency = op.get_concurrency()
901
            logging.debug("run op: {}, op_concurrency: {}".format(
902
                op.name, op_concurrency))
903 904
            for c in range(op_concurrency):
                th = threading.Thread(
B
barrierye 已提交
905
                    target=self._op_start_wrapper, args=(op, c))
906 907 908
                th.start()
                self._op_threads.append(th)

909 910 911 912
    def _stop_ops(self):
        for op in self._ops:
            op.stop()

913 914
    def run_server(self):
        """Start the op threads and serve gRPC requests until termination.

        A gRPC server backed by a thread pool of ``self._worker_num`` workers
        is created, a GeneralPythonService bound to the input and output
        channels is registered on it, and the server listens on
        ``[::]:<self._port>``. The call blocks in ``wait_for_termination``.

        The ops are stopped in a ``finally`` clause so that they are also
        stopped when the server setup fails or the wait is interrupted
        (e.g. by KeyboardInterrupt); the original exception still propagates.
        """
        self._run_ops()
        try:
            server = grpc.server(
                futures.ThreadPoolExecutor(max_workers=self._worker_num))
            general_python_service_pb2_grpc.add_GeneralPythonServiceServicer_to_server(
                GeneralPythonService(self._in_channel, self._out_channel,
                                     self._retry), server)
            server.add_insecure_port('[::]:{}'.format(self._port))
            server.start()
            server.wait_for_termination()
        finally:
            self._stop_ops()  # TODO
B
barrierye 已提交
924 925 926 927 928 929 930

    def prepare_serving(self, op):
        """Build and log the command line for the model server behind ``op``.

        The command uses ``paddle_serving_server.serve`` when the op's device
        is ``"cpu"`` and ``paddle_serving_server_gpu.serve`` otherwise. The
        command is only logged: the ``os.system`` call that would execute it
        is commented out.

        Args:
            op: op providing ``_server_model``, ``_server_port`` and
                ``_device``.
        """
        model_path = op._server_model
        port = op._server_port
        if op._device == "cpu":
            serve_module = "paddle_serving_server.serve"
        else:
            serve_module = "paddle_serving_server_gpu.serve"

        template = ("(Use MultiLangServer) python -m " + serve_module +
                    " --model {} --thread 4 --port {} --use_multilang"
                    " &>/dev/null &")
        cmd = template.format(model_path, port)
        # run a server (not in PyServing)
        logging.info("run a server (not in PyServing): {}".format(cmd))
        # os.system(cmd)