pyserver.py 42.1 KB
Newer Older
B
barrierye 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
import threading
import multiprocessing
B
barrierye 已提交
17
import Queue
B
barrierye 已提交
18
import os
B
barrierye 已提交
19
import sys
B
barrierye 已提交
20
import paddle_serving_server
21
from paddle_serving_client import MultiLangClient as Client
B
barrierye 已提交
22
from concurrent import futures
B
barrierye 已提交
23
import numpy as np
B
barrierye 已提交
24
import grpc
25 26 27 28
from .proto import general_model_config_pb2 as m_config
from .proto import general_python_service_pb2 as pyservice_pb2
from .proto import pyserving_channel_pb2 as channel_pb2
from .proto import general_python_service_pb2_grpc
B
barrierye 已提交
29
import logging
30
import random
B
barrierye 已提交
31
import time
B
barrierye 已提交
32
import func_timeout
33
import enum
34
import collections
B
barrierye 已提交
35 36


B
barrierye 已提交
37 38 39 40 41 42 43 44 45 46 47
class _TimeProfiler(object):
    """Collects (name, tag, timestamp) records and dumps them to stderr.

    Recording is off until ``enable(True)`` is called. Timestamps are
    integer microseconds derived from ``time.time()``.
    """

    def __init__(self):
        self._pid = os.getpid()
        self._print_head = 'PROFILE\tpid:{}\t'.format(self._pid)
        self._time_record = Queue.Queue()
        self._enable = False

    def enable(self, enable):
        """Switch recording/printing on or off."""
        self._enable = enable

    def record(self, name_with_tag):
        """Store one timestamped record.

        ``name_with_tag`` has the form ``<name>_<tag>``; everything after
        the last underscore is the tag, everything before it is the name.
        """
        if self._enable is False:
            return
        # rpartition yields ('', '', s) when there is no underscore, which
        # gives an empty name and the whole string as tag.
        name, _, tag = name_with_tag.rpartition("_")
        now_us = int(round(time.time() * 1000000))
        self._time_record.put((name, tag, now_us))

    def print_profile(self):
        """Write all paired records to stderr on a single line.

        Two records sharing the same name form a pair and are printed
        together. Records still waiting for their partner are put back
        into the queue.
        """
        if self._enable is False:
            return
        out = sys.stderr
        out.write(self._print_head)
        pending = {}
        while not self._time_record.empty():
            name, tag, timestamp = self._time_record.get()
            if name not in pending:
                pending[name] = (tag, timestamp)
                continue
            prev_tag, prev_timestamp = pending.pop(name)
            out.write("{}_{}:{} ".format(name, prev_tag, prev_timestamp))
            out.write("{}_{}:{} ".format(name, tag, timestamp))
        out.write('\n')
        # Unpaired records go back so a later call can complete them.
        for name, (tag, timestamp) in pending.items():
            self._time_record.put((name, tag, timestamp))


# Module-level profiler instance; recording stays disabled until
# _profiler.enable(True) is called.
_profiler = _TimeProfiler()


77 78 79
@enum.unique
class ChannelDataEcode(enum.Enum):
    """Error codes stored in the ``ecode`` field of channel data."""

    # Processing succeeded.
    OK = 0
    # The operation timed out.
    TIMEOUT = 1
    # A required method was not implemented.
    NOT_IMPLEMENTED = 2
    # Data had an unexpected type.
    TYPE_ERROR = 3
    RPC_PACKAGE_ERROR = 4
    # Any other failure.
    UNKNOW = 5
84 85 86 87 88


@enum.unique
class ChannelDataType(enum.Enum):
    """Kinds of payload a ChannelData object can carry."""

    # Payload is serialized in the protobuf message.
    CHANNEL_PBDATA = 0
    # Payload is obtained from a future object.
    CHANNEL_FUTURE = 1
    # Payload is held directly as numpy data.
    CHANNEL_NPDATA = 2
90 91 92 93


class ChannelData(object):
    """Unit of data passed between Ops through a Channel.

    Wraps a ``channel_pb2.ChannelData`` protobuf message (always present in
    ``self.pbdata``; it carries the data id and error code) together with an
    optional future, an optional numpy dict and an optional callback.
    """

    def __init__(self,
                 datatype=None,
                 future=None,
                 pbdata=None,
                 npdata=None,
                 data_id=None,
                 callback_func=None,
                 ecode=None,
                 error_info=None):
        '''
        There are several ways to use it:

        1. ChannelData(ChannelDataType.CHANNEL_FUTURE.value, future, pbdata[, callback_func])
        2. ChannelData(ChannelDataType.CHANNEL_FUTURE.value, future, data_id[, callback_func])
        3. ChannelData(ChannelDataType.CHANNEL_PBDATA.value, pbdata)
        4. ChannelData(ChannelDataType.CHANNEL_PBDATA.value, npdata, data_id)
        5. ChannelData(ChannelDataType.CHANNEL_NPDATA.value, npdata, data_id)
        6. ChannelData(ecode, error_info, data_id)

        Raises:
            ValueError: a required argument is missing or datatype is unknown.
            TypeError: the resulting pbdata is not a channel_pb2.ChannelData.
        '''
        if ecode is not None:
            # Error-only message: no payload, just id + error description.
            if data_id is None or error_info is None:
                raise ValueError("data_id and error_info cannot be None")
            pbdata = channel_pb2.ChannelData()
            pbdata.ecode = ecode
            pbdata.id = data_id
            pbdata.error_info = error_info
        else:
            if datatype == ChannelDataType.CHANNEL_FUTURE.value:
                if pbdata is None:
                    if data_id is None:
                        raise ValueError("data_id cannot be None")
                    pbdata = channel_pb2.ChannelData()
                    pbdata.ecode = ChannelDataEcode.OK.value
                    pbdata.id = data_id
            elif datatype == ChannelDataType.CHANNEL_PBDATA.value:
                if pbdata is None:
                    if data_id is None:
                        raise ValueError("data_id cannot be None")
                    pbdata = channel_pb2.ChannelData()
                    pbdata.id = data_id
                    ecode, error_info = self._check_npdata(npdata)
                    pbdata.ecode = ecode
                    if pbdata.ecode != ChannelDataEcode.OK.value:
                        pbdata.error_info = error_info
                        logging.error(pbdata.error_info)
                    else:
                        # Serialize every ndarray of npdata into the message.
                        for name, value in npdata.items():
                            inst = channel_pb2.Inst()
                            inst.data = value.tobytes()
                            inst.name = name
                            inst.shape = np.array(
                                value.shape, dtype="int32").tobytes()
                            inst.type = str(value.dtype)
                            pbdata.insts.append(inst)
            elif datatype == ChannelDataType.CHANNEL_NPDATA.value:
                ecode, error_info = self._check_npdata(npdata)
                pbdata = channel_pb2.ChannelData()
                pbdata.id = data_id
                pbdata.ecode = ecode
                if pbdata.ecode != ChannelDataEcode.OK.value:
                    pbdata.error_info = error_info
                    logging.error(pbdata.error_info)
            else:
                raise ValueError("datatype not match")
        if not isinstance(pbdata, channel_pb2.ChannelData):
            raise TypeError(
                "pbdata must be pyserving_channel_pb2.ChannelData type({})".
                format(type(pbdata)))
        self.future = future
        self.pbdata = pbdata
        self.npdata = npdata
        self.datatype = datatype
        self.callback_func = callback_func

    def _check_npdata(self, npdata):
        """Validate that npdata is a dict of {str: np.ndarray}.

        Returns:
            (ecode, error_info): ecode is ChannelDataEcode.OK.value and
            error_info is None when valid; otherwise ecode is
            ChannelDataEcode.TYPE_ERROR.value and error_info describes the
            first offending item.
        """
        ecode = ChannelDataEcode.OK.value
        error_info = None
        if not isinstance(npdata, dict):
            ecode = ChannelDataEcode.TYPE_ERROR.value
            error_info = "postped_data must be dict type, " \
                    "but get {}".format(type(npdata))
            return ecode, error_info
        for name, value in npdata.items():
            if not isinstance(name, (str, unicode)):
                ecode = ChannelDataEcode.TYPE_ERROR.value
                error_info = "the key of postped_data must " \
                        "be str, but get {}".format(type(name))
                break
            if not isinstance(value, np.ndarray):
                ecode = ChannelDataEcode.TYPE_ERROR.value
                error_info = "the value of postped_data must " \
                        "be np.ndarray, but get {}".format(type(value))
                break
        return ecode, error_info

    def parse(self):
        """Return the payload according to ``self.datatype``.

        * CHANNEL_PBDATA: dict of ndarrays rebuilt from ``pbdata.insts``.
        * CHANNEL_FUTURE: ``future.result()``, passed through
          ``callback_func`` when one was given.
        * CHANNEL_NPDATA: the stored ``npdata`` object itself.

        Raises:
            TypeError: ``self.datatype`` is none of the known types.
        """
        feed = None
        if self.datatype == ChannelDataType.CHANNEL_PBDATA.value:
            feed = {}
            for inst in self.pbdata.insts:
                feed[inst.name] = np.frombuffer(inst.data, dtype=inst.type)
                feed[inst.name].shape = np.frombuffer(inst.shape, dtype="int32")
        elif self.datatype == ChannelDataType.CHANNEL_FUTURE.value:
            feed = self.future.result()
            if self.callback_func is not None:
                feed = self.callback_func(feed)
        elif self.datatype == ChannelDataType.CHANNEL_NPDATA.value:
            feed = self.npdata
        else:
            # The dispatch is on self.datatype, so report that value.
            raise TypeError("Error type({}) in datatype.".format(
                self.datatype))
        return feed

    def __str__(self):
        # Error-only instances are built with datatype=None, which is not a
        # valid ChannelDataType value; fall back to the raw value then.
        try:
            type_name = ChannelDataType(self.datatype).name
        except ValueError:
            type_name = str(self.datatype)
        return "type[{}], ecode[{}]".format(type_name, self.pbdata.ecode)

206

B
barrierye 已提交
207
class Channel(Queue.Queue):
    """
    The channel used for communication between Ops.

    1. Support multiple different Op feed data (multiple producer)
        Different types of data will be packaged through the data ID
    2. Support multiple different Op fetch data (multiple consumer)
        Only when all types of Ops get the data of the same ID,
        the data will be poped; The Op of the same type will not
        get the data of the same ID.
    3. (TODO) Timeout and BatchSize are not fully supported.

    Note:
    1. The ID of the data in the channel must be different.
    2. The function add_producer() and add_consumer() are not thread safe,
       and can only be called during initialization.
    """

    def __init__(self, name=None, maxsize=-1, timeout=None):
        Queue.Queue.__init__(self, maxsize=maxsize)
        self._maxsize = maxsize
        self._timeout = timeout
        self.name = name
        self._stop = False

        # Guards all bookkeeping below and is used to block/wake callers.
        self._cv = threading.Condition()

        self._producers = []
        self._producer_res_count = {}  # {data_id: count}
        self._push_res = {}  # {data_id: {op_name: data}}

        self._consumers = {}  # {op_name: idx}
        self._idx_consumer_num = {}  # {idx: num}
        self._consumer_base_idx = 0
        self._front_res = []

    def get_producers(self):
        return self._producers

    def get_consumers(self):
        return self._consumers.keys()

    def _log(self, info_str):
        return "[{}] {}".format(self.name, info_str)

    def debug(self):
        """Return a string listing the registered producers and consumers."""
        return self._log("p: {}, c: {}".format(self.get_producers(),
                                               self.get_consumers()))

    def add_producer(self, op_name):
        """ not thread safe, and can only be called during initialization. """
        if op_name in self._producers:
            raise ValueError(
                self._log("producer({}) is already in channel".format(op_name)))
        self._producers.append(op_name)

    def add_consumer(self, op_name):
        """ not thread safe, and can only be called during initialization. """
        if op_name in self._consumers:
            raise ValueError(
                self._log("consumer({}) is already in channel".format(op_name)))
        self._consumers[op_name] = 0

        if self._idx_consumer_num.get(0) is None:
            self._idx_consumer_num[0] = 0
        self._idx_consumer_num[0] += 1

    def _put_or_wait(self, item):
        """Put item into the queue, waiting on the condition while it is full.

        Must be called with ``self._cv`` held. Returns without putting when
        the channel has been stopped.
        """
        while self._stop is False:
            try:
                self.put(item, timeout=0)
                break
            except Queue.Full:
                # put() signals a full queue with Queue.Full (not Empty).
                self._cv.wait()

    def push(self, channeldata, op_name=None):
        """Feed one piece of data into the channel.

        With a single producer the data is queued directly. With several
        producers the data is held back until every producer has pushed data
        with the same id; then a dict {op_name: data} is queued.

        Returns:
            True.
        Raises:
            Exception: no producer is registered, or op_name is None while
                several producers are registered.
        """
        logging.debug(
            self._log("{} try to push data: {}".format(op_name,
                                                       channeldata.__str__())))
        if len(self._producers) == 0:
            raise Exception(
                self._log(
                    "expected number of producers to be greater than 0, but the it is 0."
                ))
        elif len(self._producers) == 1:
            with self._cv:
                self._put_or_wait(channeldata)
                self._cv.notify_all()
            logging.debug(self._log("{} push data succ!".format(op_name)))
            return True
        elif op_name is None:
            raise Exception(
                self._log(
                    "There are multiple producers, so op_name cannot be None."))

        producer_num = len(self._producers)
        data_id = channeldata.pbdata.id
        put_data = None
        with self._cv:
            logging.debug(self._log("{} get lock".format(op_name)))
            if data_id not in self._push_res:
                self._push_res[data_id] = {
                    name: None
                    for name in self._producers
                }
                self._producer_res_count[data_id] = 0
            self._push_res[data_id][op_name] = channeldata
            if self._producer_res_count[data_id] + 1 == producer_num:
                # Last producer for this id: release the merged dict.
                put_data = self._push_res[data_id]
                self._push_res.pop(data_id)
                self._producer_res_count.pop(data_id)
            else:
                self._producer_res_count[data_id] += 1

            if put_data is None:
                logging.debug(
                    self._log("{} push data succ, but not push to queue.".
                              format(op_name)))
            else:
                self._put_or_wait(put_data)
                logging.debug(
                    self._log("multi | {} push data succ!".format(op_name)))
            self._cv.notify_all()
        return True

    def front(self, op_name=None):
        """Fetch the next piece of data for consumer ``op_name``.

        With a single consumer the data is taken straight from the queue.
        With several consumers each element is handed to every consumer once
        and dropped after the last of them has read it.

        Returns:
            The data (treat it as read only), or None when the channel was
            stopped before data became available.
        Raises:
            Exception: no consumer is registered, or op_name is None while
                several consumers are registered.
        """
        logging.debug(self._log("{} try to get data".format(op_name)))
        if len(self._consumers) == 0:
            raise Exception(
                self._log(
                    "expected number of consumers to be greater than 0, but the it is 0."
                ))
        elif len(self._consumers) == 1:
            resp = None
            with self._cv:
                while self._stop is False and resp is None:
                    try:
                        resp = self.get(timeout=0)
                        break
                    except Queue.Empty:
                        self._cv.wait()
            logging.debug(
                self._log("{} get data succ: {}".format(op_name, resp.__str__(
                ))))
            return resp
        elif op_name is None:
            raise Exception(
                self._log(
                    "There are multiple consumers, so op_name cannot be None."))

        with self._cv:
            # data_idx = consumer_idx - base_idx
            while self._stop is False and self._consumers[
                    op_name] - self._consumer_base_idx >= len(self._front_res):
                try:
                    channeldata = self.get(timeout=0)
                    self._front_res.append(channeldata)
                    break
                except Queue.Empty:
                    self._cv.wait()

            consumer_idx = self._consumers[op_name]
            base_idx = self._consumer_base_idx
            data_idx = consumer_idx - base_idx
            if data_idx >= len(self._front_res):
                # Only reachable when the channel was stopped while waiting;
                # mirror the single-consumer path and hand back None.
                return None
            resp = self._front_res[data_idx]
            logging.debug(self._log("{} get data: {}".format(op_name, resp)))

            self._idx_consumer_num[consumer_idx] -= 1
            if consumer_idx == base_idx and self._idx_consumer_num[
                    consumer_idx] == 0:
                # Every consumer has read the oldest element: drop it.
                self._idx_consumer_num.pop(consumer_idx)
                self._front_res.pop(0)
                self._consumer_base_idx += 1

            self._consumers[op_name] += 1
            new_consumer_idx = self._consumers[op_name]
            if self._idx_consumer_num.get(new_consumer_idx) is None:
                self._idx_consumer_num[new_consumer_idx] = 0
            self._idx_consumer_num[new_consumer_idx] += 1

            self._cv.notify_all()

        logging.debug(self._log("multi | {} get data succ!".format(op_name)))
        return resp  # reference, read only

    def stop(self):
        """Mark the channel as stopped and wake every blocked caller."""
        #TODO
        # Queue.Queue has no close() method, and Condition.notify_all()
        # requires the condition's lock to be held by the caller.
        with self._cv:
            self._stop = True
            self._cv.notify_all()
400

B
barrierye 已提交
401 402 403

class Op(object):
    """One processing node: reads from an input Channel, runs
    preprocess -> midprocess (only when a serving client is configured) ->
    postprocess, and pushes the result to every output Channel.
    """

    def __init__(self,
                 name,
                 inputs,
                 server_model=None,
                 server_port=None,
                 device=None,
                 client_config=None,
                 server_name=None,
                 fetch_names=None,
                 concurrency=1,
                 timeout=-1,
                 retry=2):
        self._run = False
        self.name = name  # to identify the type of OP, it must be globally unique
        self._concurrency = concurrency  # amount of concurrency
        self.set_input_ops(inputs)
        self.set_client(client_config, server_name, fetch_names)
        self._server_model = server_model
        self._server_port = server_port
        self._device = device
        self._timeout = timeout
        # At least one attempt is always made.
        self._retry = max(1, retry)
        self._input = None
        self._outputs = []

    def set_client(self, client_config, server_name, fetch_names):
        """Create and connect the serving client.

        No client is created (``with_serving()`` stays False) unless all
        three arguments are given.
        """
        self._client = None
        # Always defined so later reads never hit a missing attribute.
        self._fetch_names = None
        if client_config is None or \
                server_name is None or \
                fetch_names is None:
            return
        self._client = Client()
        self._client.load_client_config(client_config)
        self._client.connect([server_name])
        self._fetch_names = fetch_names

    def with_serving(self):
        return self._client is not None

    def get_input_channel(self):
        return self._input

    def get_input_ops(self):
        return self._input_ops

    def set_input_ops(self, ops):
        """Set the predecessor Ops; accepts None, a single Op or a list.

        Raises:
            TypeError: an element is not an Op.
        """
        if not isinstance(ops, list):
            ops = [] if ops is None else [ops]
        self._input_ops = []
        for op in ops:
            if not isinstance(op, Op):
                raise TypeError(
                    self._log('input op must be Op type, not {}'.format(
                        type(op))))
            self._input_ops.append(op)

    def add_input_channel(self, channel):
        """Register this Op as consumer of channel and use it as input."""
        if not isinstance(channel, Channel):
            raise TypeError(
                self._log('input channel must be Channel type, not {}'.format(
                    type(channel))))
        channel.add_consumer(self.name)
        self._input = channel

    def get_output_channels(self):
        return self._outputs

    def add_output_channel(self, channel):
        """Register this Op as producer of channel and add it as output."""
        if not isinstance(channel, Channel):
            raise TypeError(
                self._log('output channel must be Channel type, not {}'.format(
                    type(channel))))
        channel.add_producer(self.name)
        self._outputs.append(channel)

    def preprocess(self, channeldata):
        """Turn the input into the feed for midprocess/postprocess.

        The default handles a single ChannelData only; a dict (several
        predecessors) requires an override.
        """
        if isinstance(channeldata, dict):
            raise NotImplementedError(
                'this Op has multiple previous inputs. Please override this method'
            )
        feed = channeldata.parse()
        return feed

    def midprocess(self, data):
        """Send data to the serving client asynchronously; return the future."""
        if not isinstance(data, dict):
            raise Exception(
                self._log(
                    'data must be dict type(the output of preprocess()), but get {}'.
                    format(type(data))))
        logging.debug(self._log('data: {}'.format(data)))
        logging.debug(self._log('fetch: {}'.format(self._fetch_names)))
        call_future = self._client.predict(
            feed=data, fetch=self._fetch_names, asyn=True)
        logging.debug(self._log("get call_future"))
        return call_future

    def postprocess(self, output_data):
        return output_data

    def stop(self):
        """Stop the input/output channels and end the start() loop."""
        if self._input is not None:
            self._input.stop()
        for channel in self._outputs:
            channel.stop()
        self._run = False

    def _parse_channeldata(self, channeldata):
        """Extract the data id and, if any, the first erroneous pbdata.

        channeldata is either one ChannelData or a dict {op_name: ChannelData}.

        Returns:
            (data_id, error_pbdata): error_pbdata is None when every item
            has ecode OK.
        """
        data_id, error_pbdata = None, None
        if isinstance(channeldata, dict):
            # All items of one dict share the id; read it from any of them.
            data_id = next(iter(channeldata.values())).pbdata.id
            for data in channeldata.values():
                if data.pbdata.ecode != ChannelDataEcode.OK.value:
                    error_pbdata = data.pbdata
                    break
        else:
            data_id = channeldata.pbdata.id
            if channeldata.pbdata.ecode != ChannelDataEcode.OK.value:
                error_pbdata = channeldata.pbdata
        return data_id, error_pbdata

    def _push_to_output_channels(self, data, name=None):
        if name is None:
            name = self.name
        for channel in self._outputs:
            channel.push(data, name)

    def _push_error(self, ecode, error_info, data_id):
        """Push an error-only ChannelData to every output channel."""
        self._push_to_output_channels(
            ChannelData(
                ecode=ecode, error_info=error_info, data_id=data_id))

    def start(self, concurrency_idx):
        """Worker loop; runs until stop() is called or the input channel
        is stopped.

        Any failure in one stage is reported downstream as an error-only
        ChannelData and the loop moves on to the next input.
        """
        op_info_prefix = "[{}|{}]".format(self.name, concurrency_idx)
        log = self._get_log_func(op_info_prefix)
        self._run = True
        while self._run:
            _profiler.record("{}-get_0".format(op_info_prefix))
            channeldata = self._input.front(self.name)
            _profiler.record("{}-get_1".format(op_info_prefix))
            if channeldata is None:
                # front() hands back None only after the channel was stopped.
                break
            logging.debug(log("input_data: {}".format(channeldata)))

            data_id, error_pbdata = self._parse_channeldata(channeldata)

            # error data in predecessor Op
            if error_pbdata is not None:
                self._push_to_output_channels(
                    ChannelData(
                        datatype=ChannelDataType.CHANNEL_PBDATA.value,
                        pbdata=error_pbdata))
                continue

            # preprocess
            try:
                _profiler.record("{}-prep_0".format(op_info_prefix))
                preped_data = self.preprocess(channeldata)
                _profiler.record("{}-prep_1".format(op_info_prefix))
            except NotImplementedError as e:
                # preprocess function not implemented
                error_info = log(e)
                logging.error(error_info)
                self._push_error(ChannelDataEcode.NOT_IMPLEMENTED.value,
                                 error_info, data_id)
                continue
            except TypeError as e:
                # Error type in channeldata.pbdata.type
                error_info = log(e)
                logging.error(error_info)
                self._push_error(ChannelDataEcode.TYPE_ERROR.value,
                                 error_info, data_id)
                continue
            except Exception as e:
                # Anything else is not a type problem: report it as UNKNOW,
                # like the generic handlers of the later stages do.
                error_info = log(e)
                logging.error(error_info)
                self._push_error(ChannelDataEcode.UNKNOW.value, error_info,
                                 data_id)
                continue

            # midprocess
            call_future = None
            if self.with_serving():
                ecode = ChannelDataEcode.OK.value
                _profiler.record("{}-midp_0".format(op_info_prefix))
                if self._timeout <= 0:
                    try:
                        call_future = self.midprocess(preped_data)
                    except Exception as e:
                        ecode = ChannelDataEcode.UNKNOW.value
                        error_info = log(e)
                        logging.error(error_info)
                else:
                    for i in range(self._retry):
                        try:
                            call_future = func_timeout.func_timeout(
                                self._timeout,
                                self.midprocess,
                                args=(preped_data, ))
                        except func_timeout.FunctionTimedOut as e:
                            if i + 1 >= self._retry:
                                # Out of attempts: give up with TIMEOUT.
                                ecode = ChannelDataEcode.TIMEOUT.value
                                error_info = log(e)
                                logging.error(error_info)
                            else:
                                logging.warning(
                                    log("timeout, retry({})".format(i + 1)))
                        except Exception as e:
                            ecode = ChannelDataEcode.UNKNOW.value
                            error_info = log(e)
                            logging.error(error_info)
                            break
                        else:
                            break
                if ecode != ChannelDataEcode.OK.value:
                    self._push_error(ecode, error_info, data_id)
                    continue
                _profiler.record("{}-midp_1".format(op_info_prefix))

            # postprocess
            output_data = None
            _profiler.record("{}-postp_0".format(op_info_prefix))
            if self.with_serving():
                # use call_future
                output_data = ChannelData(
                    datatype=ChannelDataType.CHANNEL_FUTURE.value,
                    future=call_future,
                    data_id=data_id,
                    callback_func=self.postprocess)
            else:
                try:
                    postped_data = self.postprocess(preped_data)
                except Exception as e:
                    ecode = ChannelDataEcode.UNKNOW.value
                    error_info = log(e)
                    logging.error(error_info)
                    self._push_error(ecode, error_info, data_id)
                    continue
                if not isinstance(postped_data, dict):
                    ecode = ChannelDataEcode.TYPE_ERROR.value
                    error_info = log("output of postprocess funticon must be " \
                            "dict type, but get {}".format(type(postped_data)))
                    logging.error(error_info)
                    self._push_error(ecode, error_info, data_id)
                    continue

                output_data = ChannelData(
                    ChannelDataType.CHANNEL_NPDATA.value,
                    npdata=postped_data,
                    data_id=data_id)
            _profiler.record("{}-postp_1".format(op_info_prefix))

            # push data to channel (if run succ)
            _profiler.record("{}-push_0".format(op_info_prefix))
            self._push_to_output_channels(output_data)
            _profiler.record("{}-push_1".format(op_info_prefix))

    def _log(self, info):
        return "{} {}".format(self.name, info)

    def _get_log_func(self, op_info_prefix):
        """Return a function that prefixes messages with op_info_prefix."""
        def log_func(info_str):
            return "{} {}".format(op_info_prefix, info_str)

        return log_func

    def get_concurrency(self):
        return self._concurrency

B
barrierye 已提交
683

B
bug fix  
barrierye 已提交
684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722
class VirtualOp(Op):
    ''' For connecting two channels.

    A VirtualOp does no processing of its own: ``start`` reads items from the
    op's input and pushes them unchanged to its output channels.  The ops
    registered through ``add_virtual_pred_op`` are the ones whose names are
    announced as producers on every output channel.
    '''

    def __init__(self, name, concurrency=1):
        """Create a pass-through op with no input ops.

        Args:
            name: name of this op.
            concurrency: forwarded to ``Op.__init__``.
        """
        super(VirtualOp, self).__init__(
            name=name, inputs=None, concurrency=concurrency)
        # Ops this virtual op stands in for on its output channels.
        self._virtual_pred_ops = []

    def add_virtual_pred_op(self, op):
        """Register ``op`` as a predecessor represented by this virtual op."""
        self._virtual_pred_ops.append(op)

    def add_output_channel(self, channel):
        """Attach ``channel`` as an output of this op.

        Every registered virtual predecessor is added to ``channel`` as a
        producer (by name), instead of this op itself.

        Raises:
            TypeError: if ``channel`` is not a ``Channel``.
        """
        if not isinstance(channel, Channel):
            raise TypeError(
                self._log('output channel must be Channel type, not {}'.format(
                    type(channel))))
        for op in self._virtual_pred_ops:
            channel.add_producer(op.name)
        self._outputs.append(channel)

    def start(self, concurrency_idx):
        """Forward items from the input to the output channels until stopped.

        Args:
            concurrency_idx: index of this worker; only used in profiler tags.
        """
        op_info_prefix = "[{}|{}]".format(self.name, concurrency_idx)
        # The profiler tags never change, so format them once up front.
        get_begin = "{}-get_0".format(op_info_prefix)
        get_end = "{}-get_1".format(op_info_prefix)
        push_begin = "{}-push_0".format(op_info_prefix)
        push_end = "{}-push_1".format(op_info_prefix)

        self._run = True
        while self._run:
            _profiler.record(get_begin)
            channeldata = self._input.front(self.name)
            _profiler.record(get_end)

            _profiler.record(push_begin)
            if isinstance(channeldata, dict):
                # A dict maps a name to its data: forward each entry under
                # that name.
                for name, data in channeldata.items():
                    self._push_to_output_channels(data, name=name)
            else:
                # A single item is forwarded under the name of the first
                # registered predecessor (at least one must be registered).
                self._push_to_output_channels(channeldata,
                                              self._virtual_pred_ops[0].name)
            _profiler.record(push_end)


B
barrierye 已提交
723 724
class GeneralPythonService(
        general_python_service_pb2_grpc.GeneralPythonService):
    """gRPC servicer that feeds requests into a channel and returns results.

    ``inference`` packs each request into a ``ChannelData`` tagged with a
    fresh id and pushes it to the input channel.  A background thread started
    in ``__init__`` reads the output channel and stores every result by id;
    ``inference`` blocks until the result for its own id is available.
    """

    def __init__(self, in_channel, out_channel, retry=2):
        """Bind the service to its channels and start the receiver thread.

        Args:
            in_channel: ``Channel`` that requests are pushed into.
            out_channel: ``Channel`` that results are read from.
            retry: maximum number of attempts per request.

        Raises:
            TypeError: if either channel is not a ``Channel``.
        """
        super(GeneralPythonService, self).__init__()
        self.name = "#G"
        self.set_in_channel(in_channel)
        self.set_out_channel(out_channel)
        logging.debug(self._log(in_channel.debug()))
        logging.debug(self._log(out_channel.debug()))
        #TODO:
        #  multi-lock for different clients
        #  different lock for server and client
        self._id_lock = threading.Lock()  # guards _id_counter
        self._cv = threading.Condition()  # guards _globel_resp_dict
        self._globel_resp_dict = {}  # data id -> ChannelData result
        self._id_counter = 0
        self._retry = retry
        self._recive_func = threading.Thread(
            target=self._recive_out_channel_func)
        self._recive_func.start()

    def _log(self, info_str):
        """Return ``info_str`` prefixed with ``[<service name>]``."""
        return "[{}] {}".format(self.name, info_str)

    def set_in_channel(self, in_channel):
        """Register this service as a producer of ``in_channel`` and keep it.

        Raises:
            TypeError: if ``in_channel`` is not a ``Channel``.
        """
        if not isinstance(in_channel, Channel):
            raise TypeError(
                self._log('in_channel must be Channel type, but get {}'.format(
                    type(in_channel))))
        in_channel.add_producer(self.name)
        self._in_channel = in_channel

    def set_out_channel(self, out_channel):
        """Register this service as a consumer of ``out_channel`` and keep it.

        Raises:
            TypeError: if ``out_channel`` is not a ``Channel``.
        """
        if not isinstance(out_channel, Channel):
            raise TypeError(
                self._log('out_channel must be Channel type, but get {}'.format(
                    type(out_channel))))
        out_channel.add_consumer(self.name)
        self._out_channel = out_channel

    def _recive_out_channel_func(self):
        """Receiver loop: store every result of the output channel by its id.

        Runs forever in the thread created by ``__init__`` and wakes all
        waiters whenever a new result has been stored.

        Raises:
            TypeError: if the channel yields something other than
                ``ChannelData``.
        """
        while True:
            channeldata = self._out_channel.front(self.name)
            if not isinstance(channeldata, ChannelData):
                raise TypeError(
                    self._log('data must be ChannelData type, but get {}'.
                              format(type(channeldata))))
            with self._cv:
                data_id = channeldata.pbdata.id
                self._globel_resp_dict[data_id] = channeldata
                self._cv.notify_all()

    def _get_next_id(self):
        """Return the next request id (0, 1, 2, ...), allocated under a lock."""
        with self._id_lock:
            data_id = self._id_counter
            self._id_counter += 1
        return data_id

    def _get_data_in_globel_resp_dict(self, data_id):
        """Block until the result for ``data_id`` arrives, then remove and return it."""
        with self._cv:
            while data_id not in self._globel_resp_dict:
                self._cv.wait()
            resp = self._globel_resp_dict.pop(data_id)
            self._cv.notify_all()
        return resp

    def _pack_data_for_infer(self, request):
        """Convert an RPC request into ``ChannelData`` carrying a fresh id.

        Any failure while copying the request fields is logged and reported
        through the packed data (``RPC_PACKAGE_ERROR`` ecode plus error info)
        rather than raised.

        Returns:
            tuple: ``(channeldata, data_id)``.
        """
        logging.debug(self._log('start inferce'))
        pbdata = channel_pb2.ChannelData()
        data_id = self._get_next_id()
        pbdata.id = data_id
        pbdata.ecode = ChannelDataEcode.OK.value
        try:
            for idx, name in enumerate(request.feed_var_names):
                logging.debug(self._log('name: {}'.format(name)))
                logging.debug(
                    self._log('data: {}'.format(request.feed_insts[idx])))
                inst = channel_pb2.Inst()
                inst.data = request.feed_insts[idx]
                inst.shape = request.shape[idx]
                inst.name = name
                inst.type = request.type[idx]
                pbdata.insts.append(inst)
        except Exception as e:
            # Keep the best-effort contract (error travels inside the data),
            # but do not lose the root cause.
            logging.error(self._log('rpc package error: {}'.format(e)))
            pbdata.ecode = ChannelDataEcode.RPC_PACKAGE_ERROR.value
            pbdata.error_info = "rpc package error"
        return ChannelData(
            datatype=ChannelDataType.CHANNEL_PBDATA.value,
            pbdata=pbdata), data_id

    def _pack_data_for_resp(self, channeldata):
        """Convert a result ``ChannelData`` into an RPC ``Response``.

        On a non-OK ecode only ``ecode`` and ``error_info`` are filled in.

        Raises:
            TypeError: if the data type of an OK result is not supported.
        """
        logging.debug(self._log('get channeldata'))
        resp = pyservice_pb2.Response()
        resp.ecode = channeldata.pbdata.ecode
        if resp.ecode != ChannelDataEcode.OK.value:
            resp.error_info = channeldata.pbdata.error_info
            return resp

        if channeldata.datatype == ChannelDataType.CHANNEL_PBDATA.value:
            # Protobuf payload: copy the instances through unchanged.
            for inst in channeldata.pbdata.insts:
                resp.fetch_insts.append(inst.data)
                resp.fetch_var_names.append(inst.name)
                resp.shape.append(inst.shape)
                resp.type.append(inst.type)
        elif channeldata.datatype in (ChannelDataType.CHANNEL_FUTURE.value,
                                      ChannelDataType.CHANNEL_NPDATA.value):
            # parse() yields name -> array-like; serialise data, shape (as
            # int32 bytes) and dtype name.
            feed = channeldata.parse()
            for name, var in feed.items():
                resp.fetch_insts.append(var.tobytes())
                resp.fetch_var_names.append(name)
                resp.shape.append(
                    np.array(
                        var.shape, dtype="int32").tobytes())
                resp.type.append(str(var.dtype))
        else:
            raise TypeError(
                self._log("Error type({}) in pbdata.type.".format(
                    channeldata.datatype)))
        return resp

    def inference(self, request, context):
        """Handle one RPC: push the request, wait for its result, reply.

        The same packed data is pushed again while the result carries a
        non-OK ecode, up to ``retry`` attempts.  At least one attempt is
        always made, so a non-positive ``retry`` no longer ends with a
        ``None`` result being unpacked.

        Args:
            request: the RPC request message.
            context: gRPC call context (unused).

        Returns:
            The ``Response`` built from the last result received.
        """
        _profiler.record("{}-prepack_0".format(self.name))
        data, data_id = self._pack_data_for_infer(request)
        _profiler.record("{}-prepack_1".format(self.name))

        attempts = max(1, self._retry)
        resp_channeldata = None
        for i in range(attempts):
            logging.debug(self._log('push data'))
            _profiler.record("{}-push_0".format(self.name))
            self._in_channel.push(data, self.name)
            _profiler.record("{}-push_1".format(self.name))

            logging.debug(self._log('wait for infer'))
            _profiler.record("{}-fetch_0".format(self.name))
            resp_channeldata = self._get_data_in_globel_resp_dict(data_id)
            _profiler.record("{}-fetch_1".format(self.name))

            if resp_channeldata.pbdata.ecode == ChannelDataEcode.OK.value:
                break
            if i + 1 < attempts:
                logging.warning("retry({}): {}".format(
                    i + 1, resp_channeldata.pbdata.error_info))

        _profiler.record("{}-postpack_0".format(self.name))
        resp = self._pack_data_for_resp(resp_channeldata)
        _profiler.record("{}-postpack_1".format(self.name))
        _profiler.print_profile()
        return resp

B
barrierye 已提交
872 873

class PyServer(object):
B
barrierye 已提交
874
    def __init__(self, retry=2, profile=False):
        """Create an empty server; ops are added later and wired on prepare.

        Args:
            retry: number of attempts handed to the gRPC service.
            profile: passed to the module-level profiler's ``enable``.
        """
        self._channels = []
        self._user_ops = []  # ops registered through add_op / add_ops
        self._actual_ops = []  # ops to run; filled by _topo_sort
        self._op_threads = []
        self._port = None
        self._worker_num = None
        self._in_channel = None
        self._out_channel = None
        self._retry = retry
        _profiler.enable(profile)
B
barrierye 已提交
885 886 887 888 889

    def add_channel(self, channel):
        """Append ``channel`` to the list of channels held by this server."""
        registered = self._channels
        registered.append(channel)

    def add_op(self, op):
        """Register a single user op with the server."""
        self._user_ops.append(op)

    def add_ops(self, ops):
        """Register every op of the iterable ``ops``, keeping their order."""
        self._user_ops.extend(ops)
B
barrierye 已提交
894 895

    def gen_desc(self):
        """Placeholder: only logs that a description would be generated."""
        logging.info('here will generate desc for PAAS')

899 900 901
    def _topo_sort(self):
        indeg_num = {}
        que_idx = 0  # scroll queue 
B
fix bug  
barrierye 已提交
902
        ques = [Queue.Queue() for _ in range(2)]
B
bug fix  
barrierye 已提交
903 904 905 906 907
        for op in self._user_ops:
            if len(op.get_input_ops()) == 0:
                op.name = "#G"  # update read_op.name
                break
        outdegs = {op.name: [] for op in self._user_ops}
908 909 910 911 912 913 914 915
        for idx, op in enumerate(self._user_ops):
            # check the name of op is globally unique
            if op.name in indeg_num:
                raise Exception("the name of Op must be unique")
            indeg_num[op.name] = len(op.get_input_ops())
            if indeg_num[op.name] == 0:
                ques[que_idx].put(op)
            for pred_op in op.get_input_ops():
B
fix bug  
barrierye 已提交
916
                outdegs[pred_op.name].append(op)
917

B
bug fix  
barrierye 已提交
918
        # topo sort to get dag_views
919 920 921 922 923 924 925 926 927 928
        dag_views = []
        sorted_op_num = 0
        while True:
            que = ques[que_idx]
            next_que = ques[(que_idx + 1) % 2]
            dag_view = []
            while que.qsize() != 0:
                op = que.get()
                dag_view.append(op)
                sorted_op_num += 1
B
bug fix  
barrierye 已提交
929
                for succ_op in outdegs[op.name]:
B
fix bug  
barrierye 已提交
930
                    indeg_num[succ_op.name] -= 1
931 932 933 934 935 936 937 938 939 940 941 942 943 944
                    if indeg_num[succ_op.name] == 0:
                        next_que.put(succ_op)
            dag_views.append(dag_view)
            if next_que.qsize() == 0:
                break
            que_idx = (que_idx + 1) % 2
        if sorted_op_num < len(self._user_ops):
            raise Exception("not legal DAG")
        if len(dag_views[0]) != 1:
            raise Exception("DAG contains multiple input Ops")
        if len(dag_views[-1]) != 1:
            raise Exception("DAG contains multiple output Ops")

        # create channels and virtual ops
B
bug fix  
barrierye 已提交
945 946 947 948 949 950 951 952 953 954 955
        def name_generator(prefix):
            def number_generator():
                idx = 0
                while True:
                    yield "{}{}".format(prefix, idx)
                    idx += 1

            return number_generator()

        virtual_op_name_gen = name_generator("vir")
        channel_name_gen = name_generator("chl")
956 957 958
        virtual_ops = []
        channels = []
        input_channel = None
B
bug fix  
barrierye 已提交
959
        actual_view = None
960 961 962 963
        for v_idx, view in enumerate(dag_views):
            if v_idx + 1 >= len(dag_views):
                break
            next_view = dag_views[v_idx + 1]
B
bug fix  
barrierye 已提交
964 965
            if actual_view is None:
                actual_view = view
966 967
            actual_next_view = []
            pred_op_of_next_view_op = {}
B
bug fix  
barrierye 已提交
968 969
            for op in actual_view:
                # find actual succ op in next view and create virtual op
970 971
                for succ_op in outdegs[op.name]:
                    if succ_op in next_view:
B
bug fix  
barrierye 已提交
972 973
                        if succ_op not in actual_next_view:
                            actual_next_view.append(succ_op)
974 975 976 977
                        if succ_op.name not in pred_op_of_next_view_op:
                            pred_op_of_next_view_op[succ_op.name] = []
                        pred_op_of_next_view_op[succ_op.name].append(op)
                    else:
B
bug fix  
barrierye 已提交
978 979 980
                        # create virtual op
                        virtual_op = None
                        virtual_op = VirtualOp(name=virtual_op_name_gen.next())
981
                        virtual_ops.append(virtual_op)
B
bug fix  
barrierye 已提交
982 983 984 985 986
                        outdegs[virtual_op.name] = [succ_op]
                        actual_next_view.append(virtual_op)
                        pred_op_of_next_view_op[virtual_op.name] = [op]
                        virtual_op.add_virtual_pred_op(op)
            actual_view = actual_next_view
987 988 989
            # create channel
            processed_op = set()
            for o_idx, op in enumerate(actual_next_view):
B
bug fix  
barrierye 已提交
990
                if op.name in processed_op:
991
                    continue
B
bug fix  
barrierye 已提交
992
                channel = Channel(name=channel_name_gen.next())
993
                channels.append(channel)
B
bug fix  
barrierye 已提交
994
                logging.debug("{} => {}".format(channel.name, op.name))
995
                op.add_input_channel(channel)
B
bug fix  
barrierye 已提交
996
                pred_ops = pred_op_of_next_view_op[op.name]
997 998 999
                if v_idx == 0:
                    input_channel = channel
                else:
B
bug fix  
barrierye 已提交
1000
                    # if pred_op is virtual op, it will use ancestors as producers to channel
1001
                    for pred_op in pred_ops:
B
bug fix  
barrierye 已提交
1002 1003
                        logging.debug("{} => {}".format(pred_op.name,
                                                        channel.name))
1004
                        pred_op.add_output_channel(channel)
B
bug fix  
barrierye 已提交
1005 1006 1007
                processed_op.add(op.name)
                # find same input op to combine channel
                for other_op in actual_next_view[o_idx + 1:]:
1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018
                    if other_op.name in processed_op:
                        continue
                    other_pred_ops = pred_op_of_next_view_op[other_op.name]
                    if len(other_pred_ops) != len(pred_ops):
                        continue
                    same_flag = True
                    for pred_op in pred_ops:
                        if pred_op not in other_pred_ops:
                            same_flag = False
                            break
                    if same_flag:
B
bug fix  
barrierye 已提交
1019 1020
                        logging.debug("{} => {}".format(channel.name,
                                                        other_op.name))
1021 1022
                        other_op.add_input_channel(channel)
                        processed_op.add(other_op.name)
B
bug fix  
barrierye 已提交
1023
        output_channel = Channel(name=channel_name_gen.next())
1024 1025 1026 1027
        channels.append(output_channel)
        last_op = dag_views[-1][0]
        last_op.add_output_channel(output_channel)

B
bug fix  
barrierye 已提交
1028
        self._actual_ops = virtual_ops
B
fix bug  
barrierye 已提交
1029 1030
        for op in self._user_ops:
            if len(op.get_input_ops()) == 0:
B
bug fix  
barrierye 已提交
1031
                # pass read op
B
fix bug  
barrierye 已提交
1032
                continue
B
bug fix  
barrierye 已提交
1033
            self._actual_ops.append(op)
1034
        self._channels = channels
B
bug fix  
barrierye 已提交
1035 1036
        for c in channels:
            logging.debug(c.debug())
1037 1038
        return input_channel, output_channel

B
barrierye 已提交
1039 1040 1041
    def prepare_server(self, port, worker_num):
        """Record the serving parameters and build the op/channel graph.

        Runs ``_topo_sort`` to obtain the input and output channels, calls
        ``prepare_serving`` for every actual op whose ``with_serving()`` is
        true, and finally calls ``gen_desc``.

        Args:
            port: port the gRPC server will listen on.
            worker_num: worker count for the gRPC thread pool.
        """
        self._port = port
        self._worker_num = worker_num

        input_channel, output_channel = self._topo_sort()
        self._in_channel = input_channel
        self._out_channel = output_channel
        for op in self._actual_ops:
            if op.with_serving():
                self.prepare_serving(op)
        self.gen_desc()

B
barrierye 已提交
1051 1052
    def _op_start_wrapper(self, op, concurrency_idx):
        return op.start(concurrency_idx)
B
barrierye 已提交
1053

1054
    def _run_ops(self):
B
bug fix  
barrierye 已提交
1055
        for op in self._actual_ops:
1056
            op_concurrency = op.get_concurrency()
1057
            logging.debug("run op: {}, op_concurrency: {}".format(
1058
                op.name, op_concurrency))
1059 1060
            for c in range(op_concurrency):
                th = threading.Thread(
B
barrierye 已提交
1061
                    target=self._op_start_wrapper, args=(op, c))
1062 1063 1064
                th.start()
                self._op_threads.append(th)

1065
    def _stop_ops(self):
B
bug fix  
barrierye 已提交
1066
        for op in self._actual_ops:
1067 1068
            op.stop()

1069 1070
    def run_server(self):
        """Start the op threads and serve gRPC requests until termination.

        Expects ``prepare_server`` to have set the port, worker count and
        channels.  Blocks in ``wait_for_termination``; afterwards the ops are
        asked to stop and their threads are joined.
        """
        self._run_ops()
        server = grpc.server(
            futures.ThreadPoolExecutor(max_workers=self._worker_num))
        general_python_service_pb2_grpc.add_GeneralPythonServiceServicer_to_server(
            GeneralPythonService(self._in_channel, self._out_channel,
                                 self._retry), server)
        server.add_insecure_port('[::]:{}'.format(self._port))
        server.start()
        server.wait_for_termination()
        self._stop_ops()  # TODO
        for th in self._op_threads:
            th.join()
B
barrierye 已提交
1082 1083 1084 1085 1086 1087 1088

    def prepare_serving(self, op):
        model_path = op._server_model
        port = op._server_port
        device = op._device

        if device == "cpu":
1089 1090
            cmd = "(Use MultiLangServer) python -m paddle_serving_server.serve" \
                  " --model {} --thread 4 --port {} --use_multilang &>/dev/null &".format(model_path, port)
B
barrierye 已提交
1091
        else:
1092 1093
            cmd = "(Use MultiLangServer) python -m paddle_serving_server_gpu.serve" \
                  " --model {} --thread 4 --port {} --use_multilang &>/dev/null &".format(model_path, port)
1094 1095
        # run a server (not in PyServing)
        logging.info("run a server (not in PyServing): {}".format(cmd))