pyserver.py 40.0 KB
Newer Older
B
barrierye 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
import threading
import multiprocessing
B
barrierye 已提交
17
import Queue
B
barrierye 已提交
18
import os
B
barrierye 已提交
19
import sys
B
barrierye 已提交
20
import paddle_serving_server
21
from paddle_serving_client import MultiLangClient as Client
B
barrierye 已提交
22
from concurrent import futures
B
barrierye 已提交
23
import numpy as np
B
barrierye 已提交
24
import grpc
25 26 27 28
from .proto import general_model_config_pb2 as m_config
from .proto import general_python_service_pb2 as pyservice_pb2
from .proto import pyserving_channel_pb2 as channel_pb2
from .proto import general_python_service_pb2_grpc
B
barrierye 已提交
29
import logging
30
import random
B
barrierye 已提交
31
import time
B
barrierye 已提交
32
import func_timeout
33
import enum
34
import collections
B
barrierye 已提交
35 36


B
barrierye 已提交
37 38 39 40 41 42 43 44 45 46 47
class _TimeProfiler(object):
    """Collect named timestamps and dump them to stderr on request.

    Records are stored as ``(name, tag, timestamp)`` tuples, where the
    timestamp is the wall-clock time in microseconds. Profiling is off by
    default; while it is off, ``record()`` and ``print_profile()`` do nothing.
    """

    def __init__(self):
        self._pid = os.getpid()
        self._print_head = 'PROFILE\tpid:{}\t'.format(self._pid)
        self._time_record = Queue.Queue()
        self._enable = False

    def enable(self, enable):
        """Turn profiling on (truthy value) or off (falsy value)."""
        self._enable = enable

    def record(self, name_with_tag):
        """Store the current time under ``name_with_tag``.

        ``name_with_tag`` is split at its last underscore: the part before it
        is the record name, the part after it is the tag (e.g. ``"op-get_0"``
        gives name ``"op-get"`` and tag ``"0"``).
        """
        # Truthiness check: `is False` would treat enable(0)/enable(None)
        # as "enabled".
        if not self._enable:
            return
        name, _, tag = name_with_tag.rpartition("_")
        self._time_record.put((name, tag, int(round(time.time() * 1000000))))

    def print_profile(self):
        """Write every completed pair of records to stderr as one line.

        Two records with the same name form a pair and are printed together.
        Records whose partner has not been recorded yet are put back into the
        queue so that a later call can print them.
        """
        if not self._enable:
            return
        sys.stderr.write(self._print_head)
        pending = {}  # {name: (tag, timestamp)} of records still unpaired
        while not self._time_record.empty():
            name, tag, timestamp = self._time_record.get()
            if name in pending:
                ptag, ptimestamp = pending.pop(name)
                sys.stderr.write("{}_{}:{} ".format(name, ptag, ptimestamp))
                sys.stderr.write("{}_{}:{} ".format(name, tag, timestamp))
            else:
                pending[name] = (tag, timestamp)
        sys.stderr.write('\n')
        for name, (tag, timestamp) in pending.items():
            self._time_record.put((name, tag, timestamp))


_profiler = _TimeProfiler()


77 78 79
class ChannelDataEcode(enum.Enum):
    """Status codes carried in the ``ecode`` field of a channel message."""
    OK = 0  # no error
    TIMEOUT = 1  # midprocess timed out on every retry
    NOT_IMPLEMENTED = 2  # preprocess raised NotImplementedError
    TYPE_ERROR = 3  # data of an unexpected type was encountered
    UNKNOW = 4  # any other exception
83 84 85 86 87 88 89 90 91 92 93 94


class ChannelDataType(enum.Enum):
    """Payload kinds stored in ``pbdata.type`` and dispatched on by
    ``ChannelData.parse()``."""
    CHANNEL_PBDATA = 0  # arrays are serialized in pbdata.insts
    CHANNEL_FUTURE = 1  # the result is read from ChannelData.future


class ChannelData(object):
    """A message passed between Ops through a Channel.

    It wraps a ``channel_pb2.ChannelData`` protobuf (``self.pbdata``) and,
    for future-typed messages, the future that will yield the result plus an
    optional callback applied to that result.
    """

    def __init__(self,
                 future=None,
                 pbdata=None,
                 data_id=None,
                 callback_func=None,
                 ecode=None,
                 error_info=None):
        '''
        There are several ways to use it (all arguments are keywords):

        1. ChannelData(future=..., pbdata=...[, callback_func=...])
        2. ChannelData(future=..., data_id=...[, callback_func=...])
        3. ChannelData(pbdata=...)
        4. ChannelData(ecode=..., error_info=..., data_id=...)

        Raises:
            ValueError: ecode is given without data_id/error_info, or
                neither pbdata nor data_id is given.
            TypeError: pbdata is not a channel_pb2.ChannelData.
        '''
        if ecode is not None:
            # Error message: build a fresh protobuf that carries the error;
            # any pbdata passed by the caller is ignored.
            if data_id is None or error_info is None:
                raise ValueError("data_id and error_info cannot be None")
            pbdata = channel_pb2.ChannelData()
            pbdata.ecode = ecode
            pbdata.id = data_id
            pbdata.error_info = error_info
        elif pbdata is None:
            # Future message: only the id is known, the payload comes later
            # from `future`.
            if data_id is None:
                raise ValueError("data_id cannot be None")
            pbdata = channel_pb2.ChannelData()
            pbdata.type = ChannelDataType.CHANNEL_FUTURE.value
            pbdata.ecode = ChannelDataEcode.OK.value
            pbdata.id = data_id
        elif not isinstance(pbdata, channel_pb2.ChannelData):
            raise TypeError(
                "pbdata must be pyserving_channel_pb2.ChannelData type({})".
                format(type(pbdata)))
        self.future = future
        self.pbdata = pbdata
        self.callback_func = callback_func

    def parse(self):
        """Return the payload of this message.

        For CHANNEL_PBDATA the result is a dict ``{inst.name: ndarray}``
        rebuilt from the serialized bytes, dtype and shape of every inst.
        For CHANNEL_FUTURE the result is ``future.result()``, passed through
        ``callback_func`` when one was given.

        Raises:
            TypeError: ``pbdata.type`` is neither of the two known types.
        """
        if self.pbdata.type == ChannelDataType.CHANNEL_PBDATA.value:
            feed = {}
            for inst in self.pbdata.insts:
                feed[inst.name] = np.frombuffer(inst.data, dtype=inst.type)
                # the shape is serialized as raw int32 bytes
                feed[inst.name].shape = np.frombuffer(inst.shape, dtype="int32")
        elif self.pbdata.type == ChannelDataType.CHANNEL_FUTURE.value:
            feed = self.future.result()
            if self.callback_func is not None:
                feed = self.callback_func(feed)
        else:
            raise TypeError("Error type({}) in pbdata.type.".format(
                self.pbdata.type))
        return feed


B
barrierye 已提交
146
class Channel(Queue.Queue):
    """ 
    The channel used for communication between Ops.

    1. Support multiple different Op feed data (multiple producer)
        Different types of data will be packaged through the data ID
    2. Support multiple different Op fetch data (multiple consumer)
        Only when all types of Ops get the data of the same ID,
        the data will be poped; The Op of the same type will not
        get the data of the same ID.
    3. (TODO) Timeout and BatchSize are not fully supported.

    Note:
    1. The ID of the data in the channel must be different.
    2. The function add_producer() and add_consumer() are not thread safe,
       and can only be called during initialization.
    """

    def __init__(self, name=None, maxsize=-1, timeout=None):
        Queue.Queue.__init__(self, maxsize=maxsize)
        self._maxsize = maxsize
        self._timeout = timeout
        self.name = name
        self._stop = False

        # guards all bookkeeping below and is used to wait for queue space
        # (push) or queue items (front)
        self._cv = threading.Condition()

        self._producers = []
        self._producer_res_count = {}  # {data_id: count}
        self._push_res = {}  # {data_id: {op_name: data}}

        self._consumers = {}  # {op_name: idx}
        self._idx_consumer_num = {}  # {idx: num}
        self._consumer_base_idx = 0
        self._front_res = []

    def get_producers(self):
        """Return the list of registered producer names."""
        return self._producers

    def get_consumers(self):
        """Return the names of the registered consumers."""
        return self._consumers.keys()

    def _log(self, info_str):
        """Prefix ``info_str`` with the channel name."""
        return "[{}] {}".format(self.name, info_str)

    def debug(self):
        """Return a string listing the producers and consumers."""
        return self._log("p: {}, c: {}".format(self.get_producers(),
                                               self.get_consumers()))

    def add_producer(self, op_name):
        """ not thread safe, and can only be called during initialization. """
        if op_name in self._producers:
            raise ValueError(
                self._log("producer({}) is already in channel".format(op_name)))
        self._producers.append(op_name)

    def add_consumer(self, op_name):
        """ not thread safe, and can only be called during initialization. """
        if op_name in self._consumers:
            raise ValueError(
                self._log("consumer({}) is already in channel".format(op_name)))
        self._consumers[op_name] = 0

        if self._idx_consumer_num.get(0) is None:
            self._idx_consumer_num[0] = 0
        self._idx_consumer_num[0] += 1

    def _put_until_stopped(self, item):
        """Put ``item`` into the queue, waiting on the condition while the
        queue is full. Gives up once the channel is stopped.

        Must be called with ``self._cv`` held.
        """
        while not self._stop:
            try:
                self.put_nowait(item)
                break
            except Queue.Full:
                # put raises Full (not Empty) when there is no free slot;
                # wait until a consumer notifies that it took an item.
                self._cv.wait()

    def push(self, channeldata, op_name=None):
        """Feed ``channeldata`` into the channel on behalf of ``op_name``.

        With a single producer the data is queued directly. With several
        producers the data is held back until every producer has pushed
        data with the same id; then the dict ``{op_name: channeldata}`` is
        queued as one item.

        Returns:
            True.

        Raises:
            Exception: no producer is registered, or there are several
                producers and ``op_name`` is None.
        """
        logging.debug(
            self._log("{} try to push data: {}".format(op_name,
                                                       channeldata.pbdata)))
        if len(self._producers) == 0:
            raise Exception(
                self._log(
                    "expected number of producers to be greater than 0, but the it is 0."
                ))
        elif len(self._producers) == 1:
            with self._cv:
                self._put_until_stopped(channeldata)
                self._cv.notify_all()
            logging.debug(self._log("{} push data succ!".format(op_name)))
            return True
        elif op_name is None:
            raise Exception(
                self._log(
                    "There are multiple producers, so op_name cannot be None."))

        producer_num = len(self._producers)
        data_id = channeldata.pbdata.id
        put_data = None
        with self._cv:
            logging.debug(self._log("{} get lock".format(op_name)))
            if data_id not in self._push_res:
                self._push_res[data_id] = {
                    name: None
                    for name in self._producers
                }
                self._producer_res_count[data_id] = 0
            self._push_res[data_id][op_name] = channeldata
            if self._producer_res_count[data_id] + 1 == producer_num:
                # the last producer has arrived: release the packaged data
                put_data = self._push_res[data_id]
                self._push_res.pop(data_id)
                self._producer_res_count.pop(data_id)
            else:
                self._producer_res_count[data_id] += 1

            if put_data is None:
                logging.debug(
                    self._log("{} push data succ, but not push to queue.".
                              format(op_name)))
            else:
                self._put_until_stopped(put_data)
                logging.debug(
                    self._log("multi | {} push data succ!".format(op_name)))
            self._cv.notify_all()
        return True

    def front(self, op_name=None):
        """Fetch the next item for consumer ``op_name``.

        With a single consumer the item is taken straight from the queue.
        With several consumers each item is handed to every consumer once;
        it is dropped after all consumers have read it. The returned object
        is shared between consumers and must be treated as read only.

        Returns:
            The next item, or None if the channel was stopped before an
            item became available.

        Raises:
            Exception: no consumer is registered, or there are several
                consumers and ``op_name`` is None.
        """
        logging.debug(self._log("{} try to get data".format(op_name)))
        if len(self._consumers) == 0:
            raise Exception(
                self._log(
                    "expected number of consumers to be greater than 0, but the it is 0."
                ))
        elif len(self._consumers) == 1:
            resp = None
            with self._cv:
                while not self._stop and resp is None:
                    try:
                        resp = self.get_nowait()
                        break
                    except Queue.Empty:
                        self._cv.wait()
                # wake producers that wait for a free slot in the queue
                self._cv.notify_all()
            logging.debug(self._log("{} get data succ!".format(op_name)))
            return resp
        elif op_name is None:
            raise Exception(
                self._log(
                    "There are multiple consumers, so op_name cannot be None."))

        with self._cv:
            # data_idx = consumer_idx - base_idx
            while not self._stop and self._consumers[
                    op_name] - self._consumer_base_idx >= len(self._front_res):
                try:
                    channeldata = self.get_nowait()
                    self._front_res.append(channeldata)
                    break
                except Queue.Empty:
                    self._cv.wait()

            consumer_idx = self._consumers[op_name]
            base_idx = self._consumer_base_idx
            data_idx = consumer_idx - base_idx
            if data_idx >= len(self._front_res):
                # only reachable when the channel was stopped while waiting
                return None
            resp = self._front_res[data_idx]
            logging.debug(self._log("{} get data: {}".format(op_name, resp)))

            self._idx_consumer_num[consumer_idx] -= 1
            if consumer_idx == base_idx and self._idx_consumer_num[
                    consumer_idx] == 0:
                # every consumer has read the oldest item: drop it
                self._idx_consumer_num.pop(consumer_idx)
                self._front_res.pop(0)
                self._consumer_base_idx += 1

            self._consumers[op_name] += 1
            new_consumer_idx = self._consumers[op_name]
            if self._idx_consumer_num.get(new_consumer_idx) is None:
                self._idx_consumer_num[new_consumer_idx] = 0
            self._idx_consumer_num[new_consumer_idx] += 1

            self._cv.notify_all()

        logging.debug(self._log("multi | {} get data succ!".format(op_name)))
        return resp  # reference, read only

    def stop(self):
        """Mark the channel as stopped and wake every waiting thread."""
        #TODO
        # Condition.notify_all() requires the lock to be held.
        with self._cv:
            self._stop = True
            self._cv.notify_all()
337

B
barrierye 已提交
338 339 340

class Op(object):
    """A processing node that reads from one input Channel, runs
    preprocess -> midprocess -> postprocess and pushes the result to its
    output Channels.

    midprocess is only run when a serving client was configured (see
    ``set_client``); otherwise the output of preprocess is passed straight
    to postprocess.
    """

    def __init__(self,
                 name,
                 inputs,
                 server_model=None,
                 server_port=None,
                 device=None,
                 client_config=None,
                 server_name=None,
                 fetch_names=None,
                 concurrency=1,
                 timeout=-1,
                 retry=2):
        """
        Args:
            name: identifies the type of Op; must be globally unique.
            inputs: predecessor Op, list of Ops, or None.
            server_model, server_port, device: stored on the instance; not
                used by the methods of this class.
            client_config, server_name, fetch_names: serving client
                settings; the client is created only if all three are given.
            concurrency: amount of concurrency.
            timeout: per-call timeout for midprocess; a value <= 0 disables
                the timeout (and the retries).
            retry: number of midprocess attempts on timeout (at least 1).
        """
        self._run = False
        self.name = name  # to identify the type of OP, it must be globally unique
        self._concurrency = concurrency  # amount of concurrency
        self.set_input_ops(inputs)
        self.set_client(client_config, server_name, fetch_names)
        self._server_model = server_model
        self._server_port = server_port
        self._device = device
        self._timeout = timeout
        self._retry = max(1, retry)
        self._input = None
        self._outputs = []

    def set_client(self, client_config, server_name, fetch_names):
        """Create and connect the serving client.

        If any argument is None no client is created and the Op runs
        without serving.
        """
        self._client = None
        if client_config is None or \
                server_name is None or \
                fetch_names is None:
            return
        self._client = Client()
        self._client.load_client_config(client_config)
        self._client.connect([server_name])
        self._fetch_names = fetch_names

    def with_serving(self):
        """Return True if a serving client was configured."""
        return self._client is not None

    def get_input_channel(self):
        return self._input

    def get_input_ops(self):
        return self._input_ops

    def set_input_ops(self, ops):
        """Set the predecessor Ops.

        ``ops`` may be None (no predecessor), a single Op or a list of Ops.

        Raises:
            TypeError: an element is not an Op.
        """
        if not isinstance(ops, list):
            ops = [] if ops is None else [ops]
        self._input_ops = []
        for op in ops:
            if not isinstance(op, Op):
                raise TypeError(
                    self._log('input op must be Op type, not {}'.format(
                        type(op))))
            self._input_ops.append(op)

    def add_input_channel(self, channel):
        """Register this Op as consumer of ``channel`` and use it as input."""
        if not isinstance(channel, Channel):
            raise TypeError(
                self._log('input channel must be Channel type, not {}'.format(
                    type(channel))))
        channel.add_consumer(self.name)
        self._input = channel

    def get_output_channels(self):
        return self._outputs

    def add_output_channel(self, channel):
        """Register this Op as producer of ``channel`` and add it to the
        outputs."""
        if not isinstance(channel, Channel):
            raise TypeError(
                self._log('output channel must be Channel type, not {}'.format(
                    type(channel))))
        channel.add_producer(self.name)
        self._outputs.append(channel)

    def preprocess(self, channeldata):
        """Turn the input message into the feed for midprocess.

        The default implementation handles a single predecessor only; an Op
        with several predecessors receives a dict and must override this.
        """
        if isinstance(channeldata, dict):
            raise NotImplementedError(
                'this Op has multiple previous inputs. Please override this method'
            )
        feed = channeldata.parse()
        return feed

    def midprocess(self, data):
        """Send ``data`` to the serving client and return the call future."""
        if not isinstance(data, dict):
            raise Exception(
                self._log(
                    'data must be dict type(the output of preprocess()), but get {}'.
                    format(type(data))))
        logging.debug(self._log('data: {}'.format(data)))
        logging.debug(self._log('fetch: {}'.format(self._fetch_names)))
        call_future = self._client.predict(
            feed=data, fetch=self._fetch_names, asyn=True)
        logging.debug(self._log("get call_future"))
        return call_future

    def postprocess(self, output_data):
        """Hook for subclasses; the default returns its input unchanged."""
        return output_data

    def stop(self):
        """Stop the input/output channels and end the ``start`` loop."""
        # the input channel is None until add_input_channel() was called
        if self._input is not None:
            self._input.stop()
        for channel in self._outputs:
            channel.stop()
        self._run = False

    def _parse_channeldata(self, channeldata):
        """Return ``(data_id, error_pbdata)`` for an input message.

        ``channeldata`` is either a single message or, for several
        predecessors, a dict of messages. ``error_pbdata`` is the pbdata of
        the first message whose ecode is not OK, or None.
        """
        data_id, error_pbdata = None, None
        if isinstance(channeldata, dict):
            datas = list(channeldata.values())
            data_id = datas[0].pbdata.id
            for data in datas:
                if data.pbdata.ecode != ChannelDataEcode.OK.value:
                    error_pbdata = data.pbdata
                    break
        else:
            data_id = channeldata.pbdata.id
            if channeldata.pbdata.ecode != ChannelDataEcode.OK.value:
                error_pbdata = channeldata.pbdata
        return data_id, error_pbdata

    def _push_to_output_channels(self, data, name=None):
        """Push ``data`` to every output channel as producer ``name``
        (defaults to this Op's name)."""
        if name is None:
            name = self.name
        for channel in self._outputs:
            channel.push(data, name)

    def _push_error(self, ecode, error_info, data_id):
        """Log ``error_info`` and push an error message downstream."""
        logging.error(error_info)
        self._push_to_output_channels(
            ChannelData(
                ecode=ecode, error_info=error_info, data_id=data_id))

    def _call_midprocess(self, preped_data, log):
        """Run midprocess, with timeout and retries if a timeout is set.

        Returns ``(call_future, ecode, error_info)``; ``error_info`` is
        None when ``ecode`` is OK.
        """
        ok = ChannelDataEcode.OK.value
        if self._timeout <= 0:
            try:
                return self.midprocess(preped_data), ok, None
            except Exception as e:
                return None, ChannelDataEcode.UNKNOW.value, log(e)

        for i in range(self._retry):
            try:
                call_future = func_timeout.func_timeout(
                    self._timeout, self.midprocess, args=(preped_data, ))
            except func_timeout.FunctionTimedOut as e:
                if i + 1 >= self._retry:
                    return None, ChannelDataEcode.TIMEOUT.value, log(e)
                logging.warning(log("timeout, retry({})".format(i + 1)))
            except Exception as e:
                return None, ChannelDataEcode.UNKNOW.value, log(e)
            else:
                return call_future, ok, None
        # only reached if _retry was changed to a value < 1 after __init__
        return None, ok, None

    def _pack_postped_data(self, postped_data, data_id, log):
        """Serialize the dict returned by postprocess into a ChannelData.

        Returns ``(output_data, ecode, error_info)``; ``output_data`` is
        None when ``ecode`` is not OK.
        """
        type_error = ChannelDataEcode.TYPE_ERROR.value
        if not isinstance(postped_data, dict):
            return None, type_error, log(
                "output of postprocess funticon must be " \
                "dict type, but get {}".format(type(postped_data)))

        pbdata = channel_pb2.ChannelData()
        for name, value in postped_data.items():
            if not isinstance(name, (str, unicode)):
                return None, type_error, log(
                    "the key of postped_data must " \
                    "be str, but get {}".format(type(name)))
            if not isinstance(value, np.ndarray):
                return None, type_error, log(
                    "the value of postped_data must " \
                    "be np.ndarray, but get {}".format(type(value)))
            inst = channel_pb2.Inst()
            inst.data = value.tobytes()
            inst.name = name
            inst.shape = np.array(value.shape, dtype="int32").tobytes()
            inst.type = str(value.dtype)
            pbdata.insts.append(inst)
        pbdata.ecode = ChannelDataEcode.OK.value
        pbdata.id = data_id
        return ChannelData(pbdata=pbdata), ChannelDataEcode.OK.value, None

    def start(self, concurrency_idx):
        """Run the processing loop until ``stop()`` is called or the input
        channel is stopped.

        Every failure is reported downstream as an error message carrying
        the id of the input data, so that the loop keeps running.
        """
        op_info_prefix = "[{}|{}]".format(self.name, concurrency_idx)
        log = self._get_log_func(op_info_prefix)
        ok = ChannelDataEcode.OK.value
        self._run = True
        while self._run:
            _profiler.record("{}-get_0".format(op_info_prefix))
            channeldata = self._input.front(self.name)
            _profiler.record("{}-get_1".format(op_info_prefix))
            logging.debug(log("input_data: {}".format(channeldata)))

            if channeldata is None:
                # front() returns None once the channel has been stopped
                break

            data_id, error_pbdata = self._parse_channeldata(channeldata)

            # error data in predecessor Op
            if error_pbdata is not None:
                self._push_to_output_channels(ChannelData(pbdata=error_pbdata))
                continue

            # preprecess
            try:
                _profiler.record("{}-prep_0".format(op_info_prefix))
                preped_data = self.preprocess(channeldata)
                _profiler.record("{}-prep_1".format(op_info_prefix))
            except NotImplementedError as e:
                # preprocess function not implemented
                self._push_error(ChannelDataEcode.NOT_IMPLEMENTED.value,
                                 log(e), data_id)
                continue
            except TypeError as e:
                # Error type in channeldata.pbdata.type
                self._push_error(ChannelDataEcode.TYPE_ERROR.value,
                                 log(e), data_id)
                continue
            except Exception as e:
                # anything else is not a type problem: report it as unknown
                self._push_error(ChannelDataEcode.UNKNOW.value,
                                 log(e), data_id)
                continue

            # midprocess
            call_future = None
            if self.with_serving():
                _profiler.record("{}-midp_0".format(op_info_prefix))
                call_future, ecode, error_info = self._call_midprocess(
                    preped_data, log)
                if ecode != ok:
                    self._push_error(ecode, error_info, data_id)
                    continue
                _profiler.record("{}-midp_1".format(op_info_prefix))

            # postprocess
            output_data = None
            _profiler.record("{}-postp_0".format(op_info_prefix))
            if self.with_serving():
                # use call_future; postprocess runs when the data is parsed
                output_data = ChannelData(
                    future=call_future,
                    data_id=data_id,
                    callback_func=self.postprocess)
            else:
                try:
                    postped_data = self.postprocess(preped_data)
                except Exception as e:
                    self._push_error(ChannelDataEcode.UNKNOW.value,
                                     log(e), data_id)
                    continue
                output_data, ecode, error_info = self._pack_postped_data(
                    postped_data, data_id, log)
                if ecode != ok:
                    self._push_error(ecode, error_info, data_id)
                    continue
            _profiler.record("{}-postp_1".format(op_info_prefix))

            # push data to channel (if run succ)
            _profiler.record("{}-push_0".format(op_info_prefix))
            self._push_to_output_channels(output_data)
            _profiler.record("{}-push_1".format(op_info_prefix))

    def _log(self, info):
        """Prefix ``info`` with the Op name."""
        return "{} {}".format(self.name, info)

    def _get_log_func(self, op_info_prefix):
        """Return a function that prefixes its argument with
        ``op_info_prefix``."""
        def log_func(info_str):
            return "{} {}".format(op_info_prefix, info_str)

        return log_func

    def get_concurrency(self):
        return self._concurrency

B
barrierye 已提交
642

B
bug fix  
barrierye 已提交
643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681
class VirtualOp(Op):
    ''' For connecting two channels. '''

    def __init__(self, name, concurrency=1):
        super(VirtualOp, self).__init__(
            name=name, inputs=None, concurrency=concurrency)
        # Ops on whose behalf this Op forwards data
        self._virtual_pred_ops = []

    def add_virtual_pred_op(self, op):
        ''' Remember `op` as one of the Ops this virtual Op stands in for. '''
        self._virtual_pred_ops.append(op)

    def add_output_channel(self, channel):
        ''' Add an output channel; the virtual predecessor Ops (not this Op)
        are registered as its producers. '''
        if not isinstance(channel, Channel):
            message = 'output channel must be Channel type, not {}'.format(
                type(channel))
            raise TypeError(self._log(message))
        for pred_op in self._virtual_pred_ops:
            channel.add_producer(pred_op.name)
        self._outputs.append(channel)

    def start(self, concurrency_idx):
        ''' Forward every input item to the output channels, unchanged,
        under the name of the predecessor Op it belongs to. '''
        op_info_prefix = "[{}|{}]".format(self.name, concurrency_idx)
        log = self._get_log_func(op_info_prefix)
        self._run = True
        while self._run:
            _profiler.record("{}-get_0".format(op_info_prefix))
            channeldata = self._input.front(self.name)
            _profiler.record("{}-get_1".format(op_info_prefix))

            _profiler.record("{}-push_0".format(op_info_prefix))
            if not isinstance(channeldata, dict):
                # single item: attribute it to the first predecessor
                first_pred = self._virtual_pred_ops[0]
                self._push_to_output_channels(channeldata, first_pred.name)
            else:
                # packaged items: forward each under its own producer name
                for pred_name, item in channeldata.items():
                    self._push_to_output_channels(item, name=pred_name)
            _profiler.record("{}-push_1".format(op_info_prefix))


B
barrierye 已提交
682 683
class GeneralPythonService(
        general_python_service_pb2_grpc.GeneralPythonService):
    ''' gRPC servicer bridging client requests and the op channels.

    A request is packed into a ChannelData, pushed into `in_channel`, and
    the handler then blocks until a background thread, which continuously
    reads `out_channel`, has stored the ChannelData carrying the same id.
    '''

    def __init__(self, in_channel, out_channel, retry=2):
        ''' Register on both channels and start the receiver thread.

        Args:
            in_channel: Channel that requests are pushed into.
            out_channel: Channel that results are read from.
            retry: number of attempts `inference` makes per request.

        Raises:
            TypeError: if either channel is not a Channel instance.
        '''
        super(GeneralPythonService, self).__init__()
        self.name = "#G"
        self.set_in_channel(in_channel)
        self.set_out_channel(out_channel)
        logging.debug(self._log(in_channel.debug()))
        logging.debug(self._log(out_channel.debug()))
        #TODO:
        #  multi-lock for different clients
        #  different lock for server and client
        self._id_lock = threading.Lock()
        self._cv = threading.Condition()
        # data id -> ChannelData received from the out channel and not yet
        # claimed by the request handler waiting for it
        self._globel_resp_dict = {}
        self._id_counter = 0
        self._retry = retry
        self._recive_func = threading.Thread(
            target=GeneralPythonService._recive_out_channel_func, args=(self, ))
        self._recive_func.start()

    def _log(self, info_str):
        ''' Prefix `info_str` with this service's name in brackets. '''
        return "[{}] {}".format(self.name, info_str)

    def set_in_channel(self, in_channel):
        ''' Store `in_channel` and register this service as its producer.

        Raises:
            TypeError: if `in_channel` is not a Channel instance.
        '''
        if not isinstance(in_channel, Channel):
            raise TypeError(
                self._log('in_channel must be Channel type, but get {}'.format(
                    type(in_channel))))
        in_channel.add_producer(self.name)
        self._in_channel = in_channel

    def set_out_channel(self, out_channel):
        ''' Store `out_channel` and register this service as its consumer.

        Raises:
            TypeError: if `out_channel` is not a Channel instance.
        '''
        if not isinstance(out_channel, Channel):
            raise TypeError(
                self._log('out_channel must be Channel type, but get {}'.format(
                    type(out_channel))))
        out_channel.add_consumer(self.name)
        self._out_channel = out_channel

    def _recive_out_channel_func(self):
        ''' Receiver loop: move results from the out channel into the
        response dict, keyed by data id, and wake up waiting handlers.

        Raises:
            TypeError: if the channel yields something that is not a
                ChannelData.
        '''
        while True:
            channeldata = self._out_channel.front(self.name)
            if not isinstance(channeldata, ChannelData):
                raise TypeError(
                    self._log('data must be ChannelData type, but get {}'.
                              format(type(channeldata))))
            with self._cv:
                data_id = channeldata.pbdata.id
                self._globel_resp_dict[data_id] = channeldata
                self._cv.notify_all()

    def _get_next_id(self):
        ''' Return the current counter value and advance it, under
        `self._id_lock`. The first id handed out is 0. '''
        with self._id_lock:
            self._id_counter += 1
            return self._id_counter - 1

    def _get_data_in_globel_resp_dict(self, data_id):
        ''' Block until a response for `data_id` has been stored, then
        remove it from the response dict and return it. '''
        resp = None
        with self._cv:
            while data_id not in self._globel_resp_dict:
                self._cv.wait()
            resp = self._globel_resp_dict.pop(data_id)
            self._cv.notify_all()
        return resp

    def _pack_data_for_infer(self, request):
        ''' Convert a request into a ChannelData with a freshly drawn id.

        One Inst is built per entry of `request.feed_var_names`, taking
        data, shape and type from the same index of `request.feed_insts`,
        `request.shape` and `request.type`.

        Returns:
            A `(ChannelData, data_id)` tuple.
        '''
        logging.debug(self._log('start inferce'))
        pbdata = channel_pb2.ChannelData()
        data_id = self._get_next_id()
        pbdata.id = data_id
        for idx, name in enumerate(request.feed_var_names):
            logging.debug(
                self._log('name: {}'.format(request.feed_var_names[idx])))
            logging.debug(self._log('data: {}'.format(request.feed_insts[idx])))
            inst = channel_pb2.Inst()
            inst.data = request.feed_insts[idx]
            inst.shape = request.shape[idx]
            inst.name = name
            inst.type = request.type[idx]
            pbdata.insts.append(inst)
        pbdata.ecode = ChannelDataEcode.OK.value  #TODO: parse request error
        return ChannelData(pbdata=pbdata), data_id

    def _pack_data_for_resp(self, channeldata):
        ''' Convert a result ChannelData into a Response message.

        The error code is always copied. When it is OK the fetch fields
        are filled from the protobuf insts or from the resolved future,
        depending on `channeldata.pbdata.type`; otherwise only
        `error_info` is copied.

        Raises:
            TypeError: if the code is OK but `pbdata.type` is neither
                CHANNEL_PBDATA nor CHANNEL_FUTURE.
        '''
        logging.debug(self._log('get channeldata'))
        resp = pyservice_pb2.Response()
        resp.ecode = channeldata.pbdata.ecode
        if resp.ecode == ChannelDataEcode.OK.value:
            if channeldata.pbdata.type == ChannelDataType.CHANNEL_PBDATA.value:
                for inst in channeldata.pbdata.insts:
                    resp.fetch_insts.append(inst.data)
                    resp.fetch_var_names.append(inst.name)
                    resp.shape.append(inst.shape)
                    resp.type.append(inst.type)
            elif channeldata.pbdata.type == ChannelDataType.CHANNEL_FUTURE.value:
                feed = channeldata.futures.result()
                if channeldata.callback_func is not None:
                    feed = channeldata.callback_func(feed)
                # NOTE(review): each element of `feed` is unpacked as a
                # (name, var) pair; if `feed` is a dict this iterates its
                # keys only -- confirm the shape produced by the future /
                # callback.
                for name, var in feed:
                    resp.fetch_insts.append(var.tobytes())
                    resp.fetch_var_names.append(name)
                    resp.shape.append(
                        np.array(
                            var.shape, dtype="int32").tobytes())
                    resp.type.append(str(var.dtype))
            else:
                # The offending type lives on the channel data, not on the
                # service object.
                raise TypeError(
                    self._log("Error type({}) in pbdata.type.".format(
                        channeldata.pbdata.type)))
        else:
            resp.error_info = channeldata.pbdata.error_info
        return resp

    def inference(self, request, context):
        ''' Handle one request: pack it, push it into the in channel, wait
        for the matching result, and convert that result to a Response.

        The push/wait cycle is repeated while the result's error code is
        not OK, for at most `retry` attempts (at least one attempt is
        always made). The last result obtained is the one returned.
        '''
        _profiler.record("{}-prepack_0".format(self.name))
        data, data_id = self._pack_data_for_infer(request)
        _profiler.record("{}-prepack_1".format(self.name))

        resp_channeldata = None
        # Guarantee one attempt so that a non-positive retry setting cannot
        # leave resp_channeldata as None.
        attempts = max(1, self._retry)
        for i in range(attempts):
            logging.debug(self._log('push data'))
            _profiler.record("{}-push_0".format(self.name))
            self._in_channel.push(data, self.name)
            _profiler.record("{}-push_1".format(self.name))

            logging.debug(self._log('wait for infer'))
            _profiler.record("{}-fetch_0".format(self.name))
            resp_channeldata = self._get_data_in_globel_resp_dict(data_id)
            _profiler.record("{}-fetch_1".format(self.name))

            if resp_channeldata.pbdata.ecode == ChannelDataEcode.OK.value:
                break
            if i + 1 < attempts:
                logging.warning("retry({}): {}".format(
                    i + 1, resp_channeldata.pbdata.error_info))

        _profiler.record("{}-postpack_0".format(self.name))
        resp = self._pack_data_for_resp(resp_channeldata)
        _profiler.record("{}-postpack_1".format(self.name))
        _profiler.print_profile()
        return resp

B
barrierye 已提交
825 826

class PyServer(object):
B
barrierye 已提交
827
    def __init__(self, retry=2, profile=False):
        ''' Create an empty server.

        Args:
            retry: stored as `self._retry` and later handed to the
                GeneralPythonService created by `run_server`.
            profile: passed to the module-level profiler's `enable`.
        '''
        _profiler.enable(profile)
        self._retry = retry
        # filled in by prepare_server
        self._port = None
        self._worker_num = None
        self._in_channel = None
        self._out_channel = None
        # bookkeeping containers, all initially empty
        self._channels = []
        self._user_ops = []
        self._actual_ops = []
        self._op_threads = []
B
barrierye 已提交
838 839 840 841 842

    def add_channel(self, channel):
        ''' Append `channel` to the server's channel list. '''
        known_channels = self._channels
        known_channels.append(channel)

    def add_op(self, op):
        ''' Append a single op to the list of user ops. '''
        registered_ops = self._user_ops
        registered_ops.append(op)

    def add_ops(self, ops):
        ''' Append every op of the iterable `ops` to the list of user ops. '''
        registered_ops = self._user_ops
        registered_ops.extend(ops)
B
barrierye 已提交
847 848

    def gen_desc(self):
        ''' Placeholder for description generation; currently it only
        emits an info-level log line. '''
        logging.info('here will generate desc for PAAS')

852 853 854
    def _topo_sort(self):
        ''' Layer the user ops topologically and wire them with channels.

        Steps:
          1. The first user op without input ops is treated as the read op
             and renamed to "#G".
          2. Ops are grouped into "views" (layers) by repeatedly taking
             the ops whose remaining in-degree is zero.
          3. For each pair of adjacent views, an edge that skips the next
             view is bridged with a VirtualOp, and one Channel is created
             per distinct set of predecessor ops (ops that share the same
             predecessors share the channel).
          4. A final channel is attached as the output of the single op in
             the last view.

        Side effects: sets `self._actual_ops` (virtual ops followed by
        every user op that has input ops) and `self._channels`.

        Returns:
            `(input_channel, output_channel)`. `input_channel` stays None
            when the graph has only one view.

        Raises:
            Exception: on duplicate op names, when not every op could be
                sorted, or when the first / last view does not contain
                exactly one op.
        '''
        indeg_num = {}
        que_idx = 0  # scroll queue
        ques = [Queue.Queue() for _ in range(2)]
        for op in self._user_ops:
            if len(op.get_input_ops()) == 0:
                op.name = "#G"  # update read_op.name
                break
        outdegs = {op.name: [] for op in self._user_ops}
        for op in self._user_ops:
            # check the name of op is globally unique
            if op.name in indeg_num:
                raise Exception("the name of Op must be unique")
            indeg_num[op.name] = len(op.get_input_ops())
            if indeg_num[op.name] == 0:
                ques[que_idx].put(op)
            for pred_op in op.get_input_ops():
                outdegs[pred_op.name].append(op)

        # topo sort to get dag_views
        dag_views = []
        sorted_op_num = 0
        while True:
            que = ques[que_idx]
            next_que = ques[(que_idx + 1) % 2]
            dag_view = []
            while que.qsize() != 0:
                op = que.get()
                dag_view.append(op)
                sorted_op_num += 1
                for succ_op in outdegs[op.name]:
                    indeg_num[succ_op.name] -= 1
                    if indeg_num[succ_op.name] == 0:
                        next_que.put(succ_op)
            dag_views.append(dag_view)
            if next_que.qsize() == 0:
                break
            que_idx = (que_idx + 1) % 2
        if sorted_op_num < len(self._user_ops):
            raise Exception("not legal DAG")
        if len(dag_views[0]) != 1:
            raise Exception("DAG contains multiple input Ops")
        if len(dag_views[-1]) != 1:
            raise Exception("DAG contains multiple output Ops")

        # create channels and virtual ops
        def name_generator(prefix):
            def number_generator():
                idx = 0
                while True:
                    yield "{}{}".format(prefix, idx)
                    idx += 1

            return number_generator()

        virtual_op_name_gen = name_generator("vir")
        channel_name_gen = name_generator("chl")
        virtual_ops = []
        channels = []
        input_channel = None
        actual_view = None
        for v_idx, view in enumerate(dag_views):
            if v_idx + 1 >= len(dag_views):
                break
            next_view = dag_views[v_idx + 1]
            if actual_view is None:
                actual_view = view
            actual_next_view = []
            pred_op_of_next_view_op = {}
            for op in actual_view:
                # find actual succ op in next view and create virtual op
                for succ_op in outdegs[op.name]:
                    if succ_op in next_view:
                        if succ_op not in actual_next_view:
                            actual_next_view.append(succ_op)
                        if succ_op.name not in pred_op_of_next_view_op:
                            pred_op_of_next_view_op[succ_op.name] = []
                        pred_op_of_next_view_op[succ_op.name].append(op)
                    else:
                        # create virtual op
                        # (builtin next() instead of the generator's
                        # Python-2-only .next() method)
                        virtual_op = VirtualOp(name=next(virtual_op_name_gen))
                        virtual_ops.append(virtual_op)
                        outdegs[virtual_op.name] = [succ_op]
                        actual_next_view.append(virtual_op)
                        pred_op_of_next_view_op[virtual_op.name] = [op]
                        virtual_op.add_virtual_pred_op(op)
            actual_view = actual_next_view
            # create channel
            processed_op = set()
            for o_idx, op in enumerate(actual_next_view):
                if op.name in processed_op:
                    continue
                channel = Channel(name=next(channel_name_gen))
                channels.append(channel)
                logging.debug("{} => {}".format(channel.name, op.name))
                op.add_input_channel(channel)
                pred_ops = pred_op_of_next_view_op[op.name]
                if v_idx == 0:
                    # The channel fed from the first view is the graph's
                    # input channel; no op is attached as its producer
                    # here.
                    input_channel = channel
                else:
                    # if pred_op is virtual op, it will use ancestors as producers to channel
                    for pred_op in pred_ops:
                        logging.debug("{} => {}".format(pred_op.name,
                                                        channel.name))
                        pred_op.add_output_channel(channel)
                processed_op.add(op.name)
                # find same input op to combine channel
                for other_op in actual_next_view[o_idx + 1:]:
                    if other_op.name in processed_op:
                        continue
                    other_pred_ops = pred_op_of_next_view_op[other_op.name]
                    if len(other_pred_ops) != len(pred_ops):
                        continue
                    same_flag = True
                    for pred_op in pred_ops:
                        if pred_op not in other_pred_ops:
                            same_flag = False
                            break
                    if same_flag:
                        logging.debug("{} => {}".format(channel.name,
                                                        other_op.name))
                        other_op.add_input_channel(channel)
                        processed_op.add(other_op.name)
        output_channel = Channel(name=next(channel_name_gen))
        channels.append(output_channel)
        last_op = dag_views[-1][0]
        last_op.add_output_channel(output_channel)

        self._actual_ops = virtual_ops
        for op in self._user_ops:
            if len(op.get_input_ops()) == 0:
                # pass read op
                continue
            self._actual_ops.append(op)
        self._channels = channels
        for c in channels:
            logging.debug(c.debug())
        return input_channel, output_channel

B
barrierye 已提交
992 993 994
    def prepare_server(self, port, worker_num):
        ''' Record the server settings, build the channel graph, and
        prepare every op that reports `with_serving()`.

        Args:
            port: stored as `self._port`.
            worker_num: stored as `self._worker_num`.
        '''
        self._port, self._worker_num = port, worker_num

        # _topo_sort returns the (input, output) channel pair.
        self._in_channel, self._out_channel = self._topo_sort()
        for actual_op in self._actual_ops:
            if not actual_op.with_serving():
                continue
            self.prepare_serving(actual_op)
        self.gen_desc()

B
barrierye 已提交
1004 1005
    def _op_start_wrapper(self, op, concurrency_idx):
        return op.start(concurrency_idx)
B
barrierye 已提交
1006

1007
    def _run_ops(self):
B
bug fix  
barrierye 已提交
1008
        for op in self._actual_ops:
1009
            op_concurrency = op.get_concurrency()
1010
            logging.debug("run op: {}, op_concurrency: {}".format(
1011
                op.name, op_concurrency))
1012 1013
            for c in range(op_concurrency):
                th = threading.Thread(
B
barrierye 已提交
1014
                    target=self._op_start_wrapper, args=(op, c))
1015 1016 1017
                th.start()
                self._op_threads.append(th)

1018
    def _stop_ops(self):
B
bug fix  
barrierye 已提交
1019
        for op in self._actual_ops:
1020 1021
            op.stop()

1022 1023
    def run_server(self):
        ''' Start the op threads and serve gRPC requests until the server
        terminates, then stop the ops and join their threads.

        The server uses a thread pool of `self._worker_num` workers and
        listens on an insecure port at `[::]:<self._port>`.
        '''
        self._run_ops()
        executor = futures.ThreadPoolExecutor(max_workers=self._worker_num)
        server = grpc.server(executor)
        servicer = GeneralPythonService(self._in_channel, self._out_channel,
                                        self._retry)
        general_python_service_pb2_grpc.add_GeneralPythonServiceServicer_to_server(
            servicer, server)
        address = '[::]:{}'.format(self._port)
        server.add_insecure_port(address)
        server.start()
        server.wait_for_termination()
        self._stop_ops()  # TODO
        for op_thread in self._op_threads:
            op_thread.join()
B
barrierye 已提交
1035 1036 1037 1038 1039 1040 1041

    def prepare_serving(self, op):
        model_path = op._server_model
        port = op._server_port
        device = op._device

        if device == "cpu":
1042 1043
            cmd = "(Use MultiLangServer) python -m paddle_serving_server.serve" \
                  " --model {} --thread 4 --port {} --use_multilang &>/dev/null &".format(model_path, port)
B
barrierye 已提交
1044
        else:
1045 1046
            cmd = "(Use MultiLangServer) python -m paddle_serving_server_gpu.serve" \
                  " --model {} --thread 4 --port {} --use_multilang &>/dev/null &".format(model_path, port)
1047 1048
        # run a server (not in PyServing)
        logging.info("run a server (not in PyServing): {}".format(cmd))