pyserver.py 42.1 KB
Newer Older
B
barrierye 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
import threading
import multiprocessing
B
barrierye 已提交
17
import Queue
B
barrierye 已提交
18
import os
B
barrierye 已提交
19
import sys
B
barrierye 已提交
20
import paddle_serving_server
21
from paddle_serving_client import MultiLangClient as Client
B
barrierye 已提交
22
from concurrent import futures
B
barrierye 已提交
23
import numpy as np
B
barrierye 已提交
24
import grpc
25 26 27 28
from .proto import general_model_config_pb2 as m_config
from .proto import general_python_service_pb2 as pyservice_pb2
from .proto import pyserving_channel_pb2 as channel_pb2
from .proto import general_python_service_pb2_grpc
B
barrierye 已提交
29
import logging
30
import random
B
barrierye 已提交
31
import time
B
barrierye 已提交
32
import func_timeout
33
import enum
34
import collections
B
barrierye 已提交
35 36


B
barrierye 已提交
37 38 39 40 41 42 43 44 45 46 47
class _TimeProfiler(object):
    """Collect timestamped events and print them to stderr, paired by name.

    Events are recorded under a ``<name>_<tag>`` label; the text after the
    last underscore is the tag and everything before it is the name.
    Nothing is recorded or printed until ``enable()`` is called with a
    truthy value.
    """

    def __init__(self):
        self._pid = os.getpid()
        # Prefix written at the start of every profile line.
        self._print_head = 'PROFILE\tpid:{}\t'.format(self._pid)
        # Pending (name, tag, timestamp in microseconds) records.
        self._time_record = Queue.Queue()
        self._enable = False

    def enable(self, enable):
        """Switch profiling on (truthy ``enable``) or off (falsy)."""
        self._enable = enable

    def record(self, name_with_tag):
        """Store the current time under ``name_with_tag``.

        Args:
            name_with_tag: label of the form ``<name>_<tag>``. A label
                without an underscore is stored with an empty name and the
                whole label as its tag.
        """
        if not self._enable:
            return
        name, _, tag = name_with_tag.rpartition("_")
        self._time_record.put((name, tag, int(round(time.time() * 1000000))))

    def print_profile(self):
        """Write one line to stderr containing every completed name pair.

        Records are drained from the queue; the second record seen for a
        name completes the pair and both are printed. Records that have no
        partner yet are put back into the queue for a later call.
        """
        if not self._enable:
            return
        sys.stderr.write(self._print_head)
        unmatched = {}
        while not self._time_record.empty():
            name, tag, timestamp = self._time_record.get()
            if name not in unmatched:
                unmatched[name] = (tag, timestamp)
                continue
            ptag, ptimestamp = unmatched.pop(name)
            sys.stderr.write("{}_{}:{} ".format(name, ptag, ptimestamp))
            sys.stderr.write("{}_{}:{} ".format(name, tag, timestamp))
        sys.stderr.write('\n')
        # Keep the half-finished pairs so that a later call can complete them.
        for name, (tag, timestamp) in unmatched.items():
            self._time_record.put((name, tag, timestamp))


# Module-level profiler used by the Op classes in this file; it records
# nothing until enable() is called on it with a true value.
_profiler = _TimeProfiler()


77 78 79
class ChannelDataEcode(enum.Enum):
    """Error codes stored in the ``ecode`` field of a channel message.

    The integer values are written into protobuf messages, so they must
    stay stable.
    """
    OK = 0
    TIMEOUT = 1
    NOT_IMPLEMENTED = 2
    TYPE_ERROR = 3
    RPC_PACKAGE_ERROR = 4
    UNKNOW = 5  # spelling kept as-is: the name is referenced by other code
84 85 86 87 88


class ChannelDataType(enum.Enum):
    """Kinds of payload a ChannelData object can carry."""
    CHANNEL_PBDATA = 0  # payload serialized inside the protobuf message
    CHANNEL_FUTURE = 1  # payload is obtained from a future's result()
    CHANNEL_NPDATA = 2  # payload is kept as a dict of numpy arrays
    ERROR = 3  # no payload; the message only carries an error code and text
91 92 93 94


class ChannelData(object):
    """Container for one item passed between Ops through a Channel.

    Every instance holds a ``channel_pb2.ChannelData`` message in ``pbdata``
    (with the data id and an error code) and, depending on ``datatype``,
    either a future or a dict of numpy arrays.
    """

    def __init__(self,
                 datatype=None,
                 future=None,
                 pbdata=None,
                 npdata=None,
                 data_id=None,
                 callback_func=None,
                 ecode=None,
                 error_info=None):
        '''
        Supported combinations (pass everything by keyword):

        1. datatype=CHANNEL_FUTURE, future, pbdata[, callback_func]
        2. datatype=CHANNEL_FUTURE, future, data_id[, callback_func]
        3. datatype=CHANNEL_PBDATA, pbdata
        4. datatype=CHANNEL_PBDATA, npdata, data_id
           (npdata is serialized into a new protobuf message)
        5. datatype=CHANNEL_NPDATA, npdata, data_id
        6. ecode, error_info, data_id
           (error message; datatype is forced to ERROR)

        ``datatype`` and ``ecode`` are the integer ``.value`` of
        ChannelDataType / ChannelDataEcode members.

        Raises:
            ValueError: a required argument is missing or ``datatype`` is
                not one of the supported values.
            TypeError: ``pbdata`` is not a channel_pb2.ChannelData.
        '''
        if ecode is not None:
            if data_id is None or error_info is None:
                raise ValueError("data_id and error_info cannot be None")
            pbdata = channel_pb2.ChannelData()
            pbdata.ecode = ecode
            pbdata.id = data_id
            pbdata.error_info = error_info
            datatype = ChannelDataType.ERROR.value
        elif datatype == ChannelDataType.CHANNEL_FUTURE.value:
            if pbdata is None:
                pbdata = self._new_pbdata(data_id)
                pbdata.ecode = ChannelDataEcode.OK.value
        elif datatype == ChannelDataType.CHANNEL_PBDATA.value:
            if pbdata is None:
                pbdata = self._new_pbdata(data_id)
                if self._store_check_result(pbdata, npdata):
                    # Serialize every array as raw bytes plus shape and dtype
                    # so that parse() can rebuild it.
                    for name, value in npdata.items():
                        inst = channel_pb2.Inst()
                        inst.data = value.tobytes()
                        inst.name = name
                        inst.shape = np.array(
                            value.shape, dtype="int32").tobytes()
                        inst.type = str(value.dtype)
                        pbdata.insts.append(inst)
        elif datatype == ChannelDataType.CHANNEL_NPDATA.value:
            # The arrays stay in self.npdata; the message only carries the
            # id and the validation result.
            pbdata = self._new_pbdata(data_id)
            self._store_check_result(pbdata, npdata)
        else:
            raise ValueError("datatype not match")
        if not isinstance(pbdata, channel_pb2.ChannelData):
            raise TypeError(
                "pbdata must be pyserving_channel_pb2.ChannelData type({})".
                format(type(pbdata)))
        self.future = future
        self.pbdata = pbdata
        self.npdata = npdata
        self.datatype = datatype
        self.callback_func = callback_func

    @staticmethod
    def _new_pbdata(data_id):
        """Return a fresh protobuf message whose id is ``data_id``."""
        if data_id is None:
            raise ValueError("data_id cannot be None")
        pbdata = channel_pb2.ChannelData()
        pbdata.id = data_id
        return pbdata

    def _store_check_result(self, pbdata, npdata):
        """Validate ``npdata`` and record the outcome in ``pbdata``.

        Returns True when ``npdata`` is valid; otherwise the error text is
        stored in ``pbdata.error_info`` and logged, and False is returned.
        """
        ecode, error_info = self._check_npdata(npdata)
        pbdata.ecode = ecode
        if ecode == ChannelDataEcode.OK.value:
            return True
        pbdata.error_info = error_info
        logging.error(pbdata.error_info)
        return False

    def _check_npdata(self, npdata):
        """Check that ``npdata`` is a dict mapping str to np.ndarray.

        Returns:
            (ecode, error_info): ``(OK, None)`` when valid, otherwise
            ``(TYPE_ERROR, message)`` describing the first problem found.
        """
        try:
            string_types = (str, unicode)  # Python 2
        except NameError:
            string_types = (str, )  # Python 3 has no separate unicode type

        ecode = ChannelDataEcode.OK.value
        error_info = None
        if not isinstance(npdata, dict):
            return (ChannelDataEcode.TYPE_ERROR.value,
                    "postped_data must be dict, but get {}".format(
                        type(npdata)))
        for name, value in npdata.items():
            if not isinstance(name, string_types):
                ecode = ChannelDataEcode.TYPE_ERROR.value
                error_info = "the key of postped_data must " \
                        "be str, but get {}".format(type(name))
                break
            if not isinstance(value, np.ndarray):
                ecode = ChannelDataEcode.TYPE_ERROR.value
                error_info = "the value of postped_data must " \
                        "be np.ndarray, but get {}".format(type(value))
                break
        return ecode, error_info

    def parse(self):
        """Return the payload in the form the next Op consumes.

        * CHANNEL_PBDATA: dict of name -> ndarray rebuilt from the message.
        * CHANNEL_FUTURE: ``future.result()``, passed through
          ``callback_func`` when one was given.
        * CHANNEL_NPDATA: the stored ``npdata`` dict.

        Raises:
            TypeError: for any other datatype (including ERROR).
        """
        feed = None
        if self.datatype == ChannelDataType.CHANNEL_PBDATA.value:
            feed = {}
            for inst in self.pbdata.insts:
                feed[inst.name] = np.frombuffer(inst.data, dtype=inst.type)
                feed[inst.name].shape = np.frombuffer(inst.shape, dtype="int32")
        elif self.datatype == ChannelDataType.CHANNEL_FUTURE.value:
            feed = self.future.result()
            if self.callback_func is not None:
                feed = self.callback_func(feed)
        elif self.datatype == ChannelDataType.CHANNEL_NPDATA.value:
            feed = self.npdata
        else:
            # The branch is chosen from self.datatype, so report that value.
            raise TypeError("Error type({}) in datatype.".format(
                self.datatype))
        return feed

    def __str__(self):
        return "type[{}], ecode[{}]".format(
            ChannelDataType(self.datatype).name, self.pbdata.ecode)

208

B
barrierye 已提交
209
class Channel(Queue.Queue):
    """
    The channel used for communication between Ops.

    1. Support multiple different Op feed data (multiple producer)
        Different types of data will be packaged through the data ID
    2. Support multiple different Op fetch data (multiple consumer)
        Only when all types of Ops get the data of the same ID,
        the data will be popped; The Op of the same type will not
        get the data of the same ID.
    3. (TODO) Timeout and BatchSize are not fully supported.

    Note:
    1. The ID of the data in the channel must be different.
    2. The function add_producer() and add_consumer() are not thread safe,
       and can only be called during initialization.
    """

    def __init__(self, name=None, maxsize=-1, timeout=None):
        Queue.Queue.__init__(self, maxsize=maxsize)
        self._maxsize = maxsize
        self._timeout = timeout
        self.name = name
        self._stop = False

        # Guards all bookkeeping below and is used to wake blocked
        # producers/consumers.
        self._cv = threading.Condition()

        self._producers = []
        self._producer_res_count = {}  # {data_id: count}
        self._push_res = {}  # {data_id: {op_name: data}}

        self._consumers = {}  # {op_name: idx}
        self._idx_consumer_num = {}  # {idx: num}
        self._consumer_base_idx = 0
        self._front_res = []

    def get_producers(self):
        return self._producers

    def get_consumers(self):
        return self._consumers.keys()

    def _log(self, info_str):
        return "[{}] {}".format(self.name, info_str)

    def debug(self):
        return self._log("p: {}, c: {}".format(self.get_producers(),
                                               self.get_consumers()))

    def add_producer(self, op_name):
        """ not thread safe, and can only be called during initialization. """
        if op_name in self._producers:
            raise ValueError(
                self._log("producer({}) is already in channel".format(op_name)))
        self._producers.append(op_name)

    def add_consumer(self, op_name):
        """ not thread safe, and can only be called during initialization. """
        if op_name in self._consumers:
            raise ValueError(
                self._log("consumer({}) is already in channel".format(op_name)))
        self._consumers[op_name] = 0

        if self._idx_consumer_num.get(0) is None:
            self._idx_consumer_num[0] = 0
        self._idx_consumer_num[0] += 1

    def _put_or_wait(self, item):
        """Enqueue ``item``, waiting on the condition while the queue is full.

        Must be called with ``self._cv`` held. Gives up without enqueuing
        once the channel has been stopped.
        """
        while self._stop is False:
            try:
                self.put(item, timeout=0)
                break
            except Queue.Full:
                # put() signals a full queue with Full (not Empty); wait
                # until a consumer or stop() notifies.
                self._cv.wait()

    def push(self, channeldata, op_name=None):
        """Feed ``channeldata`` produced by ``op_name`` into the channel.

        With a single producer the item is queued directly. With several
        producers the items sharing one data id are collected into a
        ``{op_name: channeldata}`` dict that is queued once every producer
        has pushed its part.

        Returns:
            True (also when the channel was stopped before queuing).

        Raises:
            Exception: no producer is registered, or ``op_name`` is None
                while several producers are registered.
        """
        logging.debug(
            self._log("{} try to push data: {}".format(op_name,
                                                       channeldata.__str__())))
        if not self._producers:
            raise Exception(
                self._log(
                    "expected number of producers to be greater than 0, but the it is 0."
                ))
        elif len(self._producers) == 1:
            with self._cv:
                self._put_or_wait(channeldata)
                self._cv.notify_all()
            logging.debug(self._log("{} push data succ!".format(op_name)))
            return True
        elif op_name is None:
            raise Exception(
                self._log(
                    "There are multiple producers, so op_name cannot be None."))

        producer_num = len(self._producers)
        data_id = channeldata.pbdata.id
        put_data = None
        with self._cv:
            logging.debug(self._log("{} get lock".format(op_name)))
            if data_id not in self._push_res:
                self._push_res[data_id] = dict.fromkeys(self._producers)
                self._producer_res_count[data_id] = 0
            self._push_res[data_id][op_name] = channeldata
            if self._producer_res_count[data_id] + 1 == producer_num:
                # Last missing part arrived: hand the complete set over.
                put_data = self._push_res.pop(data_id)
                self._producer_res_count.pop(data_id)
            else:
                self._producer_res_count[data_id] += 1

            if put_data is None:
                logging.debug(
                    self._log("{} push data succ, but not push to queue.".
                              format(op_name)))
            else:
                self._put_or_wait(put_data)
                logging.debug(
                    self._log("multi | {} push data succ!".format(op_name)))
            self._cv.notify_all()
        return True

    def front(self, op_name=None):
        """Fetch the next item for consumer ``op_name``, blocking until ready.

        With several consumers every consumer receives each item once; an
        item is dropped after all consumers have fetched it. The returned
        object is shared and must be treated as read only.

        Returns:
            The next item, or None when the channel was stopped before an
            item became available.

        Raises:
            Exception: no consumer is registered, or ``op_name`` is None
                while several consumers are registered.
        """
        logging.debug(self._log("{} try to get data".format(op_name)))
        if not self._consumers:
            raise Exception(
                self._log(
                    "expected number of consumers to be greater than 0, but the it is 0."
                ))
        elif len(self._consumers) == 1:
            resp = None
            with self._cv:
                while self._stop is False and resp is None:
                    try:
                        resp = self.get(timeout=0)
                        break
                    except Queue.Empty:
                        self._cv.wait()
                # A slot may have been freed: wake producers waiting in
                # push() on a full queue.
                self._cv.notify_all()
            logging.debug(
                self._log("{} get data succ: {}".format(op_name, resp.__str__(
                ))))
            return resp
        elif op_name is None:
            raise Exception(
                self._log(
                    "There are multiple consumers, so op_name cannot be None."))

        with self._cv:
            # data_idx = consumer_idx - base_idx
            while self._stop is False and self._consumers[
                    op_name] - self._consumer_base_idx >= len(self._front_res):
                try:
                    channeldata = self.get(timeout=0)
                    self._front_res.append(channeldata)
                    break
                except Queue.Empty:
                    self._cv.wait()

            consumer_idx = self._consumers[op_name]
            base_idx = self._consumer_base_idx
            data_idx = consumer_idx - base_idx
            if data_idx >= len(self._front_res):
                # The wait loop ended because of stop(), not because data
                # arrived; mirror the single-consumer path and return None.
                return None
            resp = self._front_res[data_idx]
            logging.debug(self._log("{} get data: {}".format(op_name, resp)))

            self._idx_consumer_num[consumer_idx] -= 1
            if consumer_idx == base_idx and self._idx_consumer_num[
                    consumer_idx] == 0:
                # Every consumer has seen the oldest item: drop it.
                self._idx_consumer_num.pop(consumer_idx)
                self._front_res.pop(0)
                self._consumer_base_idx += 1

            self._consumers[op_name] += 1
            new_consumer_idx = self._consumers[op_name]
            if self._idx_consumer_num.get(new_consumer_idx) is None:
                self._idx_consumer_num[new_consumer_idx] = 0
            self._idx_consumer_num[new_consumer_idx] += 1

            self._cv.notify_all()

        logging.debug(self._log("multi | {} get data succ!".format(op_name)))
        return resp  # reference, read only

    def stop(self):
        """Mark the channel as stopped and wake every blocked push()/front().

        Queue.Queue offers no close(), and a Condition may only be notified
        while its lock is held, so both the flag update and the notification
        happen under the lock.
        """
        with self._cv:
            self._stop = True
            self._cv.notify_all()
402

B
barrierye 已提交
403 404 405

class Op(object):
    """One processing stage of the pipeline.

    An Op reads items from its input Channel, runs preprocess(), optionally
    midprocess() (a prediction through the serving client) and postprocess(),
    then pushes the result to all of its output Channels. Failures are
    forwarded downstream as error ChannelData instead of being raised.
    """

    def __init__(self,
                 name,
                 inputs,
                 server_model=None,
                 server_port=None,
                 device=None,
                 client_config=None,
                 server_name=None,
                 fetch_names=None,
                 concurrency=1,
                 timeout=-1,
                 retry=2):
        """
        Args:
            name: identifies the type of Op; must be globally unique.
            inputs: predecessor Op, list of Ops, or None.
            server_model, server_port, device: stored on the instance only.
            client_config, server_name, fetch_names: serving client setup;
                the client is created only when all three are given.
            concurrency: amount of concurrency, returned by
                get_concurrency().
            timeout: limit passed to func_timeout for one midprocess() call;
                values <= 0 disable the limit.
            retry: number of midprocess() attempts when a timeout is set
                (at least 1).
        """
        self._run = False
        self.name = name  # to identify the type of OP, it must be globally unique
        self._concurrency = concurrency  # amount of concurrency
        self.set_input_ops(inputs)
        self.set_client(client_config, server_name, fetch_names)
        self._server_model = server_model
        self._server_port = server_port
        self._device = device
        self._timeout = timeout
        self._retry = max(1, retry)
        self._input = None
        self._outputs = []

    def set_client(self, client_config, server_name, fetch_names):
        """Create and connect the serving client if fully configured."""
        self._client = None
        if client_config is None or \
                server_name is None or \
                fetch_names is None:
            return
        self._client = Client()
        self._client.load_client_config(client_config)
        self._client.connect([server_name])
        self._fetch_names = fetch_names

    def with_serving(self):
        """Return True when this Op owns a serving client."""
        return self._client is not None

    def get_input_channel(self):
        return self._input

    def get_input_ops(self):
        return self._input_ops

    def set_input_ops(self, ops):
        """Store the predecessor Ops; accepts None, one Op or a list of Ops.

        Raises:
            TypeError: an element is not an Op.
        """
        if not isinstance(ops, list):
            ops = [] if ops is None else [ops]
        self._input_ops = []
        for op in ops:
            if not isinstance(op, Op):
                raise TypeError(
                    self._log('input op must be Op type, not {}'.format(
                        type(op))))
            self._input_ops.append(op)

    def add_input_channel(self, channel):
        """Register this Op as a consumer of ``channel`` and use it as input."""
        if not isinstance(channel, Channel):
            raise TypeError(
                self._log('input channel must be Channel type, not {}'.format(
                    type(channel))))
        channel.add_consumer(self.name)
        self._input = channel

    def get_output_channels(self):
        return self._outputs

    def add_output_channel(self, channel):
        """Register this Op as a producer of ``channel`` and add it as output."""
        if not isinstance(channel, Channel):
            raise TypeError(
                self._log('output channel must be Channel type, not {}'.format(
                    type(channel))))
        channel.add_producer(self.name)
        self._outputs.append(channel)

    def preprocess(self, channeldata):
        """Turn the incoming item into the data handed to the next stage.

        The default handles a single predecessor by returning
        ``channeldata.parse()``. Ops with several predecessors receive a
        dict and must override this method.
        """
        if isinstance(channeldata, dict):
            raise NotImplementedError(
                'this Op has multiple previous inputs. Please override this method'
            )
        feed = channeldata.parse()
        return feed

    def midprocess(self, data):
        """Send ``data`` to the serving client and return the call future.

        Raises:
            Exception: ``data`` is not a dict.
        """
        if not isinstance(data, dict):
            raise Exception(
                self._log(
                    'data must be dict type(the output of preprocess()), but get {}'.
                    format(type(data))))
        logging.debug(self._log('data: {}'.format(data)))
        logging.debug(self._log('fetch: {}'.format(self._fetch_names)))
        call_future = self._client.predict(
            feed=data, fetch=self._fetch_names, asyn=True)
        logging.debug(self._log("get call_future"))
        return call_future

    def postprocess(self, output_data):
        """Hook for transforming the result; the default returns it as is."""
        return output_data

    def stop(self):
        """Ask the worker loop to finish and stop all attached channels."""
        # Clear the flag first so a worker woken by the channel sees it.
        self._run = False
        if self._input is not None:
            self._input.stop()
        for channel in self._outputs:
            channel.stop()

    def _parse_channeldata(self, channeldata):
        """Extract the data id and the first upstream error, if any.

        Args:
            channeldata: a ChannelData, or a dict of producer name ->
                ChannelData when the input channel has several producers.

        Returns:
            (data_id, error_pbdata); error_pbdata is None when every part
            has ecode OK.
        """
        if isinstance(channeldata, dict):
            parts = list(channeldata.values())
        else:
            parts = [channeldata]
        data_id = parts[0].pbdata.id
        error_pbdata = None
        for data in parts:
            if data.pbdata.ecode != ChannelDataEcode.OK.value:
                error_pbdata = data.pbdata
                break
        return data_id, error_pbdata

    def _push_to_output_channels(self, data, name=None):
        """Push ``data`` to every output channel as producer ``name``."""
        if name is None:
            name = self.name
        for channel in self._outputs:
            channel.push(data, name)

    def _push_error(self, ecode, error_info, data_id):
        """Log ``error_info`` and forward it downstream as an error item."""
        logging.error(error_info)
        self._push_to_output_channels(
            ChannelData(
                ecode=ecode, error_info=error_info, data_id=data_id))

    def _call_midprocess(self, preped_data, log):
        """Run midprocess(), applying the timeout/retry policy.

        Returns:
            (call_future, ecode, error_info); ecode is OK on success,
            TIMEOUT when the last attempt timed out and UNKNOW for any other
            exception.
        """
        call_future = None
        ecode = ChannelDataEcode.OK.value
        error_info = None
        if self._timeout <= 0:
            try:
                call_future = self.midprocess(preped_data)
            except Exception as e:
                ecode = ChannelDataEcode.UNKNOW.value
                error_info = log(e)
            return call_future, ecode, error_info

        for i in range(self._retry):
            try:
                call_future = func_timeout.func_timeout(
                    self._timeout, self.midprocess, args=(preped_data, ))
            except func_timeout.FunctionTimedOut as e:
                if i + 1 >= self._retry:
                    ecode = ChannelDataEcode.TIMEOUT.value
                    error_info = log(e)
                else:
                    logging.warning(log("timeout, retry({})".format(i + 1)))
            except Exception as e:
                ecode = ChannelDataEcode.UNKNOW.value
                error_info = log(e)
                break
            else:
                break
        return call_future, ecode, error_info

    def start(self, concurrency_idx):
        """Worker loop: process items from the input channel until stopped.

        Args:
            concurrency_idx: index of this worker, used in log/profile tags.
        """
        op_info_prefix = "[{}|{}]".format(self.name, concurrency_idx)
        log = self._get_log_func(op_info_prefix)
        self._run = True
        while self._run:
            _profiler.record("{}-get_0".format(op_info_prefix))
            channeldata = self._input.front(self.name)
            _profiler.record("{}-get_1".format(op_info_prefix))
            logging.debug(log("input_data: {}".format(channeldata)))

            if channeldata is None:
                # Channel.front() returns None once the channel is stopped.
                break

            data_id, error_pbdata = self._parse_channeldata(channeldata)

            # error data in predecessor Op
            if error_pbdata is not None:
                self._push_to_output_channels(
                    ChannelData(
                        datatype=ChannelDataType.CHANNEL_PBDATA.value,
                        pbdata=error_pbdata))
                continue

            # preprocess
            try:
                _profiler.record("{}-prep_0".format(op_info_prefix))
                preped_data = self.preprocess(channeldata)
                _profiler.record("{}-prep_1".format(op_info_prefix))
            except NotImplementedError as e:
                # preprocess function not implemented
                self._push_error(ChannelDataEcode.NOT_IMPLEMENTED.value,
                                 log(e), data_id)
                continue
            except TypeError as e:
                # ChannelData.parse() rejected the data type
                self._push_error(ChannelDataEcode.TYPE_ERROR.value,
                                 log(e), data_id)
                continue
            except Exception as e:
                # Any other failure is not a type problem; label it the same
                # way midprocess/postprocess failures are labelled.
                self._push_error(ChannelDataEcode.UNKNOW.value,
                                 log(e), data_id)
                continue

            # midprocess
            call_future = None
            if self.with_serving():
                _profiler.record("{}-midp_0".format(op_info_prefix))
                call_future, ecode, error_info = self._call_midprocess(
                    preped_data, log)
                if ecode != ChannelDataEcode.OK.value:
                    self._push_error(ecode, error_info, data_id)
                    continue
                _profiler.record("{}-midp_1".format(op_info_prefix))

            # postprocess
            output_data = None
            _profiler.record("{}-postp_0".format(op_info_prefix))
            if self.with_serving():
                # postprocess runs later, as the callback applied to the
                # future's result when the consumer parses the item
                output_data = ChannelData(
                    datatype=ChannelDataType.CHANNEL_FUTURE.value,
                    future=call_future,
                    data_id=data_id,
                    callback_func=self.postprocess)
            else:
                try:
                    postped_data = self.postprocess(preped_data)
                except Exception as e:
                    self._push_error(ChannelDataEcode.UNKNOW.value,
                                     log(e), data_id)
                    continue
                if not isinstance(postped_data, dict):
                    error_info = log("output of postprocess funticon must be " \
                            "dict type, but get {}".format(type(postped_data)))
                    self._push_error(ChannelDataEcode.TYPE_ERROR.value,
                                     error_info, data_id)
                    continue

                output_data = ChannelData(
                    ChannelDataType.CHANNEL_NPDATA.value,
                    npdata=postped_data,
                    data_id=data_id)
            _profiler.record("{}-postp_1".format(op_info_prefix))

            # push data to channel (if run succ)
            _profiler.record("{}-push_0".format(op_info_prefix))
            self._push_to_output_channels(output_data)
            _profiler.record("{}-push_1".format(op_info_prefix))

    def _log(self, info):
        return "{} {}".format(self.name, info)

    def _get_log_func(self, op_info_prefix):
        """Return a function that prefixes messages with ``op_info_prefix``."""

        def log_func(info_str):
            return "{} {}".format(op_info_prefix, info_str)

        return log_func

    def get_concurrency(self):
        return self._concurrency

B
barrierye 已提交
685

B
bug fix  
barrierye 已提交
686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724
class VirtualOp(Op):
    """Pass-through op used to connect two channels.

    It does no processing of its own: whatever it reads from its input
    channel is forwarded unchanged to its output channels. The ops it stands
    in for are recorded with `add_virtual_pred_op` and are registered as the
    producers of every output channel attached to this op.
    """

    def __init__(self, name, concurrency=1):
        """Create a virtual op with no input ops.

        Args:
            name: name of this op.
            concurrency: forwarded to `Op.__init__`.
        """
        super(VirtualOp, self).__init__(
            name=name, inputs=None, concurrency=concurrency)
        self._virtual_pred_ops = []

    def add_virtual_pred_op(self, op):
        """Record `op` as a predecessor this virtual op forwards data for."""
        self._virtual_pred_ops.append(op)

    def add_output_channel(self, channel):
        """Attach `channel` as an output of this op.

        Every predecessor recorded so far is registered on the channel as a
        producer, by name, before the channel is stored.

        Raises:
            TypeError: if `channel` is not a `Channel`.
        """
        if not isinstance(channel, Channel):
            raise TypeError(
                self._log('output channel must be Channel type, not {}'.format(
                    type(channel))))
        for op in self._virtual_pred_ops:
            channel.add_producer(op.name)
        self._outputs.append(channel)

    def start(self, concurrency_idx):
        """Forward data from the input channel until `self._run` turns false.

        Args:
            concurrency_idx: index of this worker, used only in the
                profiler tags.
        """
        op_info_prefix = "[{}|{}]".format(self.name, concurrency_idx)
        self._run = True
        while self._run:
            _profiler.record("{}-get_0".format(op_info_prefix))
            channeldata = self._input.front(self.name)
            _profiler.record("{}-get_1".format(op_info_prefix))

            _profiler.record("{}-push_0".format(op_info_prefix))
            if isinstance(channeldata, dict):
                # A dict is forwarded entry by entry, each value pushed under
                # its own key.
                for name, data in channeldata.items():
                    self._push_to_output_channels(data, name=name)
            else:
                # A single item is pushed under the name of the first
                # recorded predecessor.
                self._push_to_output_channels(channeldata,
                                              self._virtual_pred_ops[0].name)
            _profiler.record("{}-push_1".format(op_info_prefix))


B
barrierye 已提交
725 726
class GeneralPythonService(
        general_python_service_pb2_grpc.GeneralPythonService):
    """Service object that bridges incoming requests and the op pipeline.

    `inference` packs a request into a `ChannelData`, pushes it into the
    input channel under the producer name "#G", and then blocks until the
    background receiver thread has stored the result carrying the same id
    that it read from the output channel.
    """

    def __init__(self, in_channel, out_channel, retry=2):
        """Register on both channels and start the receiver thread.

        Args:
            in_channel: `Channel` that requests are pushed into.
            out_channel: `Channel` that results are read from.
            retry: number of attempts `inference` makes per request; values
                below 1 are treated as 1.

        Raises:
            TypeError: if either channel is not a `Channel`.
        """
        super(GeneralPythonService, self).__init__()
        self.name = "#G"
        self.set_in_channel(in_channel)
        self.set_out_channel(out_channel)
        logging.debug(self._log(in_channel.debug()))
        logging.debug(self._log(out_channel.debug()))
        # TODO:
        #  multi-lock for different clients
        #  different lock for server and client
        self._id_lock = threading.Lock()
        self._cv = threading.Condition()
        self._globel_resp_dict = {}
        self._id_counter = 0
        self._retry = retry
        self._recive_func = threading.Thread(
            target=GeneralPythonService._recive_out_channel_func, args=(self, ))
        # The receiver loops forever; as a daemon it cannot keep the process
        # alive once every other thread has finished.
        self._recive_func.daemon = True
        self._recive_func.start()

    def _log(self, info_str):
        """Return `info_str` prefixed with "[<service name>] "."""
        return "[{}] {}".format(self.name, info_str)

    def set_in_channel(self, in_channel):
        """Store `in_channel` and register this service as one of its producers.

        Raises:
            TypeError: if `in_channel` is not a `Channel`.
        """
        if not isinstance(in_channel, Channel):
            raise TypeError(
                self._log('in_channel must be Channel type, but get {}'.format(
                    type(in_channel))))
        in_channel.add_producer(self.name)
        self._in_channel = in_channel

    def set_out_channel(self, out_channel):
        """Store `out_channel` and register this service as one of its consumers.

        Raises:
            TypeError: if `out_channel` is not a `Channel`.
        """
        if not isinstance(out_channel, Channel):
            raise TypeError(
                self._log('out_channel must be Channel type, but get {}'.format(
                    type(out_channel))))
        out_channel.add_consumer(self.name)
        self._out_channel = out_channel

    def _recive_out_channel_func(self):
        """Receiver loop: move results from the output channel into the
        response dict, keyed by id, and wake the waiting request threads.
        """
        while True:
            channeldata = self._out_channel.front(self.name)
            if not isinstance(channeldata, ChannelData):
                # NOTE(review): raising here ends the receiver thread, after
                # which no waiting request is ever woken again.
                raise TypeError(
                    self._log('data must be ChannelData type, but get {}'.
                              format(type(channeldata))))
            with self._cv:
                data_id = channeldata.pbdata.id
                self._globel_resp_dict[data_id] = channeldata
                self._cv.notify_all()

    def _get_next_id(self):
        """Return the next request id (0, 1, 2, ...), serialized by a lock."""
        with self._id_lock:
            data_id = self._id_counter
            self._id_counter += 1
            return data_id

    def _get_data_in_globel_resp_dict(self, data_id):
        """Block until a result for `data_id` is available, then remove and
        return it.
        """
        with self._cv:
            while data_id not in self._globel_resp_dict:
                self._cv.wait()
            # No notify is needed after the pop: waiters only wait for an id
            # to appear, never for one to be removed.
            return self._globel_resp_dict.pop(data_id)

    def _pack_data_for_infer(self, request):
        """Convert `request` into a `ChannelData` carrying a fresh id.

        Returns:
            A `(channeldata, data_id)` tuple. If copying the request fields
            fails, the returned data carries the RPC_PACKAGE_ERROR code
            instead of raising.
        """
        logging.debug(self._log('start inferce'))
        pbdata = channel_pb2.ChannelData()
        data_id = self._get_next_id()
        pbdata.id = data_id
        pbdata.ecode = ChannelDataEcode.OK.value
        try:
            for idx, name in enumerate(request.feed_var_names):
                logging.debug(self._log('name: {}'.format(name)))
                logging.debug(
                    self._log('data: {}'.format(request.feed_insts[idx])))
                inst = channel_pb2.Inst()
                inst.data = request.feed_insts[idx]
                inst.shape = request.shape[idx]
                inst.name = name
                inst.type = request.type[idx]
                pbdata.insts.append(inst)
        except Exception as e:
            # Keep the best-effort contract (report through ecode), but do
            # not lose the cause of the failure.
            logging.error(self._log("rpc package error: {}".format(e)))
            pbdata.ecode = ChannelDataEcode.RPC_PACKAGE_ERROR.value
            pbdata.error_info = "rpc package error"
        return ChannelData(
            datatype=ChannelDataType.CHANNEL_PBDATA.value,
            pbdata=pbdata), data_id

    def _pack_data_for_resp(self, channeldata):
        """Convert a pipeline result into a `pyservice_pb2.Response`.

        On a non-OK ecode only `ecode` and `error_info` are filled in.

        Raises:
            TypeError: if the ecode is OK but `channeldata.datatype` is not
                one of the handled data types.
        """
        logging.debug(self._log('get channeldata'))
        resp = pyservice_pb2.Response()
        resp.ecode = channeldata.pbdata.ecode
        if resp.ecode == ChannelDataEcode.OK.value:
            if channeldata.datatype == ChannelDataType.CHANNEL_PBDATA.value:
                for inst in channeldata.pbdata.insts:
                    resp.fetch_insts.append(inst.data)
                    resp.fetch_var_names.append(inst.name)
                    resp.shape.append(inst.shape)
                    resp.type.append(inst.type)
            elif channeldata.datatype in (ChannelDataType.CHANNEL_FUTURE.value,
                                          ChannelDataType.CHANNEL_NPDATA.value):
                feed = channeldata.parse()
                for name, var in feed.items():
                    # Each value is serialized as raw bytes plus an int32
                    # shape and the dtype name.
                    resp.fetch_insts.append(var.tobytes())
                    resp.fetch_var_names.append(name)
                    resp.shape.append(
                        np.array(
                            var.shape, dtype="int32").tobytes())
                    resp.type.append(str(var.dtype))
            else:
                raise TypeError(
                    self._log("Error type({}) in pbdata.type.".format(
                        channeldata.datatype)))
        else:
            resp.error_info = channeldata.pbdata.error_info
        return resp

    def inference(self, request, context):
        """Run one request through the pipeline and return its response.

        The packed request is pushed again, under the same id, while the
        result comes back with a non-OK ecode and attempts remain.

        Args:
            request: incoming request message.
            context: not used by this method.
        """
        _profiler.record("{}-prepack_0".format(self.name))
        data, data_id = self._pack_data_for_infer(request)
        _profiler.record("{}-prepack_1".format(self.name))

        # Always make at least one attempt: with retry <= 0 the loop would
        # never run and there would be no result to pack.
        attempts = max(1, self._retry)
        resp_channeldata = None
        for i in range(attempts):
            logging.debug(self._log('push data'))
            _profiler.record("{}-push_0".format(self.name))
            self._in_channel.push(data, self.name)
            _profiler.record("{}-push_1".format(self.name))

            logging.debug(self._log('wait for infer'))
            _profiler.record("{}-fetch_0".format(self.name))
            resp_channeldata = self._get_data_in_globel_resp_dict(data_id)
            _profiler.record("{}-fetch_1".format(self.name))

            if resp_channeldata.pbdata.ecode == ChannelDataEcode.OK.value:
                break
            if i + 1 < attempts:
                logging.warning("retry({}): {}".format(
                    i + 1, resp_channeldata.pbdata.error_info))

        _profiler.record("{}-postpack_0".format(self.name))
        resp = self._pack_data_for_resp(resp_channeldata)
        _profiler.record("{}-postpack_1".format(self.name))
        _profiler.print_profile()
        return resp

B
barrierye 已提交
874 875

class PyServer(object):
B
barrierye 已提交
876
    def __init__(self, retry=2, profile=False):
        """Create a server with no ops, channels or network settings yet.

        Args:
            retry: stored as `self._retry`.
            profile: passed to the module-level profiler's `enable()`.
        """
        # Four independent, initially empty containers.
        self._channels, self._user_ops = [], []
        self._actual_ops, self._op_threads = [], []
        # Network settings and end-point channels start unset.
        self._port = self._worker_num = None
        self._in_channel = self._out_channel = None
        self._retry = retry
        _profiler.enable(profile)
B
barrierye 已提交
887 888 889 890 891

    def add_channel(self, channel):
        """Append `channel` to this server's channel list."""
        registered = self._channels
        registered.append(channel)

    def add_op(self, op):
        """Append a single op to the list of user-defined ops."""
        user_ops = self._user_ops
        user_ops.append(op)

    def add_ops(self, ops):
        """Append every op in the iterable `ops` to the user-defined ops."""
        user_ops = self._user_ops
        user_ops.extend(ops)
B
barrierye 已提交
896 897

    def gen_desc(self):
        """Placeholder: only logs that a description would be generated."""
        logging.info('here will generate desc for PAAS')

901 902 903
    def _topo_sort(self):
        indeg_num = {}
        que_idx = 0  # scroll queue 
B
fix bug  
barrierye 已提交
904
        ques = [Queue.Queue() for _ in range(2)]
B
bug fix  
barrierye 已提交
905 906 907 908 909
        for op in self._user_ops:
            if len(op.get_input_ops()) == 0:
                op.name = "#G"  # update read_op.name
                break
        outdegs = {op.name: [] for op in self._user_ops}
910 911 912 913 914 915 916 917
        for idx, op in enumerate(self._user_ops):
            # check the name of op is globally unique
            if op.name in indeg_num:
                raise Exception("the name of Op must be unique")
            indeg_num[op.name] = len(op.get_input_ops())
            if indeg_num[op.name] == 0:
                ques[que_idx].put(op)
            for pred_op in op.get_input_ops():
B
fix bug  
barrierye 已提交
918
                outdegs[pred_op.name].append(op)
919

B
bug fix  
barrierye 已提交
920
        # topo sort to get dag_views
921 922 923 924 925 926 927 928 929 930
        dag_views = []
        sorted_op_num = 0
        while True:
            que = ques[que_idx]
            next_que = ques[(que_idx + 1) % 2]
            dag_view = []
            while que.qsize() != 0:
                op = que.get()
                dag_view.append(op)
                sorted_op_num += 1
B
bug fix  
barrierye 已提交
931
                for succ_op in outdegs[op.name]:
B
fix bug  
barrierye 已提交
932
                    indeg_num[succ_op.name] -= 1
933 934 935 936 937 938 939 940 941 942 943 944 945 946
                    if indeg_num[succ_op.name] == 0:
                        next_que.put(succ_op)
            dag_views.append(dag_view)
            if next_que.qsize() == 0:
                break
            que_idx = (que_idx + 1) % 2
        if sorted_op_num < len(self._user_ops):
            raise Exception("not legal DAG")
        if len(dag_views[0]) != 1:
            raise Exception("DAG contains multiple input Ops")
        if len(dag_views[-1]) != 1:
            raise Exception("DAG contains multiple output Ops")

        # create channels and virtual ops
B
bug fix  
barrierye 已提交
947 948 949 950 951 952 953 954 955 956 957
        def name_generator(prefix):
            def number_generator():
                idx = 0
                while True:
                    yield "{}{}".format(prefix, idx)
                    idx += 1

            return number_generator()

        virtual_op_name_gen = name_generator("vir")
        channel_name_gen = name_generator("chl")
958 959 960
        virtual_ops = []
        channels = []
        input_channel = None
B
bug fix  
barrierye 已提交
961
        actual_view = None
962 963 964 965
        for v_idx, view in enumerate(dag_views):
            if v_idx + 1 >= len(dag_views):
                break
            next_view = dag_views[v_idx + 1]
B
bug fix  
barrierye 已提交
966 967
            if actual_view is None:
                actual_view = view
968 969
            actual_next_view = []
            pred_op_of_next_view_op = {}
B
bug fix  
barrierye 已提交
970 971
            for op in actual_view:
                # find actual succ op in next view and create virtual op
972 973
                for succ_op in outdegs[op.name]:
                    if succ_op in next_view:
B
bug fix  
barrierye 已提交
974 975
                        if succ_op not in actual_next_view:
                            actual_next_view.append(succ_op)
976 977 978 979
                        if succ_op.name not in pred_op_of_next_view_op:
                            pred_op_of_next_view_op[succ_op.name] = []
                        pred_op_of_next_view_op[succ_op.name].append(op)
                    else:
B
bug fix  
barrierye 已提交
980 981 982
                        # create virtual op
                        virtual_op = None
                        virtual_op = VirtualOp(name=virtual_op_name_gen.next())
983
                        virtual_ops.append(virtual_op)
B
bug fix  
barrierye 已提交
984 985 986 987 988
                        outdegs[virtual_op.name] = [succ_op]
                        actual_next_view.append(virtual_op)
                        pred_op_of_next_view_op[virtual_op.name] = [op]
                        virtual_op.add_virtual_pred_op(op)
            actual_view = actual_next_view
989 990 991
            # create channel
            processed_op = set()
            for o_idx, op in enumerate(actual_next_view):
B
bug fix  
barrierye 已提交
992
                if op.name in processed_op:
993
                    continue
B
bug fix  
barrierye 已提交
994
                channel = Channel(name=channel_name_gen.next())
995
                channels.append(channel)
B
bug fix  
barrierye 已提交
996
                logging.debug("{} => {}".format(channel.name, op.name))
997
                op.add_input_channel(channel)
B
bug fix  
barrierye 已提交
998
                pred_ops = pred_op_of_next_view_op[op.name]
999 1000 1001
                if v_idx == 0:
                    input_channel = channel
                else:
B
bug fix  
barrierye 已提交
1002
                    # if pred_op is virtual op, it will use ancestors as producers to channel
1003
                    for pred_op in pred_ops:
B
bug fix  
barrierye 已提交
1004 1005
                        logging.debug("{} => {}".format(pred_op.name,
                                                        channel.name))
1006
                        pred_op.add_output_channel(channel)
B
bug fix  
barrierye 已提交
1007 1008 1009
                processed_op.add(op.name)
                # find same input op to combine channel
                for other_op in actual_next_view[o_idx + 1:]:
1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020
                    if other_op.name in processed_op:
                        continue
                    other_pred_ops = pred_op_of_next_view_op[other_op.name]
                    if len(other_pred_ops) != len(pred_ops):
                        continue
                    same_flag = True
                    for pred_op in pred_ops:
                        if pred_op not in other_pred_ops:
                            same_flag = False
                            break
                    if same_flag:
B
bug fix  
barrierye 已提交
1021 1022
                        logging.debug("{} => {}".format(channel.name,
                                                        other_op.name))
1023 1024
                        other_op.add_input_channel(channel)
                        processed_op.add(other_op.name)
B
bug fix  
barrierye 已提交
1025
        output_channel = Channel(name=channel_name_gen.next())
1026 1027 1028 1029
        channels.append(output_channel)
        last_op = dag_views[-1][0]
        last_op.add_output_channel(output_channel)

B
bug fix  
barrierye 已提交
1030
        self._actual_ops = virtual_ops
B
fix bug  
barrierye 已提交
1031 1032
        for op in self._user_ops:
            if len(op.get_input_ops()) == 0:
B
bug fix  
barrierye 已提交
1033
                # pass read op
B
fix bug  
barrierye 已提交
1034
                continue
B
bug fix  
barrierye 已提交
1035
            self._actual_ops.append(op)
1036
        self._channels = channels
B
bug fix  
barrierye 已提交
1037 1038
        for c in channels:
            logging.debug(c.debug())
1039 1040
        return input_channel, output_channel

B
barrierye 已提交
1041 1042 1043
    def prepare_server(self, port, worker_num):
        """Record the network settings and build the op/channel graph.

        Args:
            port: stored as `self._port`.
            worker_num: stored as `self._worker_num`.
        """
        self._port, self._worker_num = port, worker_num

        # _topo_sort wires the ops and returns the two end-point channels.
        self._in_channel, self._out_channel = self._topo_sort()
        for op in self._actual_ops:
            if not op.with_serving():
                continue
            self.prepare_serving(op)
        self.gen_desc()

B
barrierye 已提交
1053 1054
    def _op_start_wrapper(self, op, concurrency_idx):
        return op.start(concurrency_idx)
B
barrierye 已提交
1055

1056
    def _run_ops(self):
B
bug fix  
barrierye 已提交
1057
        for op in self._actual_ops:
1058
            op_concurrency = op.get_concurrency()
1059
            logging.debug("run op: {}, op_concurrency: {}".format(
1060
                op.name, op_concurrency))
1061 1062
            for c in range(op_concurrency):
                th = threading.Thread(
B
barrierye 已提交
1063
                    target=self._op_start_wrapper, args=(op, c))
1064 1065 1066
                th.start()
                self._op_threads.append(th)

1067
    def _stop_ops(self):
B
bug fix  
barrierye 已提交
1068
        for op in self._actual_ops:
1069 1070
            op.stop()

1071 1072
    def run_server(self):
        """Start the op threads and serve requests until termination.

        After the gRPC server terminates, the ops are asked to stop and
        their threads are joined.
        """
        self._run_ops()

        executor = futures.ThreadPoolExecutor(max_workers=self._worker_num)
        server = grpc.server(executor)
        servicer = GeneralPythonService(self._in_channel, self._out_channel,
                                        self._retry)
        general_python_service_pb2_grpc.add_GeneralPythonServiceServicer_to_server(
            servicer, server)

        address = '[::]:{}'.format(self._port)
        server.add_insecure_port(address)
        server.start()
        server.wait_for_termination()

        self._stop_ops()  # TODO
        for op_thread in self._op_threads:
            op_thread.join()
B
barrierye 已提交
1084 1085 1086 1087 1088 1089 1090

    def prepare_serving(self, op):
        model_path = op._server_model
        port = op._server_port
        device = op._device

        if device == "cpu":
1091 1092
            cmd = "(Use MultiLangServer) python -m paddle_serving_server.serve" \
                  " --model {} --thread 4 --port {} --use_multilang &>/dev/null &".format(model_path, port)
B
barrierye 已提交
1093
        else:
1094 1095
            cmd = "(Use MultiLangServer) python -m paddle_serving_server_gpu.serve" \
                  " --model {} --thread 4 --port {} --use_multilang &>/dev/null &".format(model_path, port)
1096 1097
        # run a server (not in PyServing)
        logging.info("run a server (not in PyServing): {}".format(cmd))