pyserver.py 48.3 KB
Newer Older
B
barrierye 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
import threading
import multiprocessing
B
barrierye 已提交
17
import multiprocessing.queues
B
barrierye 已提交
18
import sys
B
barrierye 已提交
19 20 21 22 23 24 25
if sys.version_info.major == 2:
    import Queue
elif sys.version_info.major == 3:
    import queue as Queue
else:
    raise Exception("Error Python version")
import os
B
barrierye 已提交
26
import paddle_serving_server
27
from paddle_serving_client import MultiLangClient as Client
B
barrierye 已提交
28
from paddle_serving_client import MultiLangPredictFuture
B
barrierye 已提交
29
from concurrent import futures
B
barrierye 已提交
30
import numpy as np
B
barrierye 已提交
31
import grpc
32 33 34 35
from .proto import general_model_config_pb2 as m_config
from .proto import general_python_service_pb2 as pyservice_pb2
from .proto import pyserving_channel_pb2 as channel_pb2
from .proto import general_python_service_pb2_grpc
B
barrierye 已提交
36
import logging
37
import random
B
barrierye 已提交
38
import time
B
barrierye 已提交
39
import func_timeout
40
import enum
41
import collections
B
barrierye 已提交
42 43


B
barrierye 已提交
44 45 46 47 48 49 50 51 52 53 54
class _TimeProfiler(object):
    def __init__(self):
        self._pid = os.getpid()
        self._print_head = 'PROFILE\tpid:{}\t'.format(self._pid)
        self._time_record = Queue.Queue()
        self._enable = False

    def enable(self, enable):
        self._enable = enable

    def record(self, name_with_tag):
B
bug fix  
barrierye 已提交
55 56
        if self._enable is False:
            return
B
barrierye 已提交
57 58 59 60 61 62
        name_with_tag = name_with_tag.split("_")
        tag = name_with_tag[-1]
        name = '_'.join(name_with_tag[:-1])
        self._time_record.put((name, tag, int(round(time.time() * 1000000))))

    def print_profile(self):
B
bug fix  
barrierye 已提交
63 64
        if self._enable is False:
            return
B
barrierye 已提交
65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83
        sys.stderr.write(self._print_head)
        tmp = {}
        while not self._time_record.empty():
            name, tag, timestamp = self._time_record.get()
            if name in tmp:
                ptag, ptimestamp = tmp.pop(name)
                sys.stderr.write("{}_{}:{} ".format(name, ptag, ptimestamp))
                sys.stderr.write("{}_{}:{} ".format(name, tag, timestamp))
            else:
                tmp[name] = (tag, timestamp)
        sys.stderr.write('\n')
        for name, item in tmp.items():
            tag, timestamp = item
            self._time_record.put((name, tag, timestamp))


_profiler = _TimeProfiler()


84 85 86
class ChannelDataEcode(enum.Enum):
    """Error codes for ChannelData; instances store the int ``.value``."""
    # Data is valid.
    OK = 0
    # Used when midprocess exceeds its timeout on the final retry.
    TIMEOUT = 1
    # Used when preprocess raises NotImplementedError.
    NOT_IMPLEMENTED = 2
    # Used for invalid npdata and for TypeError raised in preprocess.
    TYPE_ERROR = 3
    RPC_PACKAGE_ERROR = 4
    # Catch-all for other exceptions (spelling kept for compatibility).
    UNKNOW = 5
91 92 93 94 95


class ChannelDataType(enum.Enum):
    """Payload kinds for ChannelData; instances store the int ``.value``."""
    # Payload is a protobuf message (pbdata).
    CHANNEL_PBDATA = 0
    # Payload is a prediction future resolved in ChannelData.parse().
    CHANNEL_FUTURE = 1
    # Payload is a dict of numpy arrays (npdata).
    CHANNEL_NPDATA = 2
    # Set whenever ChannelData is built with an explicit ecode.
    ERROR = 3
98 99 100 101


class ChannelData(object):
    def __init__(self,
                 datatype=None,
                 future=None,
                 pbdata=None,
                 npdata=None,
                 data_id=None,
                 callback_func=None,
                 ecode=None,
                 error_info=None):
        '''
        A single unit of data passed between Ops through a Channel.

        Typical constructions (keyword form):

        1. ChannelData(datatype=ChannelDataType.CHANNEL_FUTURE.value,
                       future=future, data_id=data_id[, callback_func=f])
        2. ChannelData(datatype=ChannelDataType.CHANNEL_PBDATA.value,
                       pbdata=pbdata)
        3. ChannelData(datatype=ChannelDataType.CHANNEL_PBDATA.value,
                       npdata=npdata, data_id=data_id)
        4. ChannelData(datatype=ChannelDataType.CHANNEL_NPDATA.value,
                       npdata=npdata, data_id=data_id)
        5. ChannelData(ecode=ecode, error_info=error_info, data_id=data_id)

        Passing ``ecode`` always produces an ERROR item, whatever
        ``datatype`` is.  Invalid npdata is not raised: ``ecode`` is set to
        TYPE_ERROR, ``error_info`` describes the problem and it is logged.

        Protobufs are not pickle-able:
        https://stackoverflow.com/questions/55344376/how-to-import-protobuf-module

        Raises:
            ValueError: a required ``data_id`` / ``error_info`` is missing, or
                ``datatype`` is not a ChannelDataType value.
        '''
        if ecode is not None:
            if data_id is None or error_info is None:
                raise ValueError("data_id and error_info cannot be None")
            datatype = ChannelDataType.ERROR.value
        else:
            if datatype == ChannelDataType.CHANNEL_FUTURE.value:
                if data_id is None:
                    raise ValueError("data_id cannot be None")
                ecode = ChannelDataEcode.OK.value
            elif datatype == ChannelDataType.CHANNEL_PBDATA.value:
                if pbdata is None:
                    if data_id is None:
                        raise ValueError("data_id cannot be None")
                    pbdata = channel_pb2.ChannelData()
                    ecode, error_info = self._check_npdata(npdata)
                    if ecode != ChannelDataEcode.OK.value:
                        logging.error(error_info)
                    else:
                        # Serialize every array as raw bytes plus its int32
                        # shape and dtype name so parse() can rebuild it.
                        for name, value in npdata.items():
                            inst = channel_pb2.Inst()
                            inst.data = value.tobytes()
                            inst.name = name
                            inst.shape = np.array(
                                value.shape, dtype="int32").tobytes()
                            inst.type = str(value.dtype)
                            pbdata.insts.append(inst)
                else:
                    # A ready-made protobuf is accepted as valid.  Previously
                    # ecode stayed None here, so any `ecode != OK` check
                    # treated this data as an error.
                    ecode = ChannelDataEcode.OK.value
            elif datatype == ChannelDataType.CHANNEL_NPDATA.value:
                ecode, error_info = self._check_npdata(npdata)
                if ecode != ChannelDataEcode.OK.value:
                    logging.error(error_info)
            else:
                raise ValueError("datatype not match")
        self.future = future
        self.pbdata = pbdata
        self.npdata = npdata
        self.datatype = datatype  # int value of ChannelDataType
        self.callback_func = callback_func
        self.id = data_id
        self.ecode = ecode  # int value of ChannelDataEcode
        self.error_info = error_info

    def _check_npdata(self, npdata):
        """Check that ``npdata`` maps string names to numpy arrays.

        Returns:
            ``(ChannelDataEcode.OK.value, None)`` when valid, otherwise
            ``(ChannelDataEcode.TYPE_ERROR.value, message)`` for the first
            offending entry.
        """
        # `unicode` only exists on Python 2; referencing it unconditionally
        # raised NameError on Python 3 for every non-empty dict.
        if sys.version_info.major == 2:
            string_types = (str, unicode)  # noqa: F821
        else:
            string_types = (str, )
        if not hasattr(npdata, "items"):
            return (ChannelDataEcode.TYPE_ERROR.value,
                    "postped_data must be dict, but get {}".format(
                        type(npdata)))
        for name, value in npdata.items():
            if not isinstance(name, string_types):
                return (ChannelDataEcode.TYPE_ERROR.value,
                        "the key of postped_data must "
                        "be str, but get {}".format(type(name)))
            if not isinstance(value, np.ndarray):
                return (ChannelDataEcode.TYPE_ERROR.value,
                        "the value of postped_data must "
                        "be np.ndarray, but get {}".format(type(value)))
        return ChannelDataEcode.OK.value, None

    def parse(self):
        """Return the payload as feed data.

        - CHANNEL_PBDATA: rebuilds ``{name: ndarray}`` from the protobuf
          insts (dtype from ``inst.type``, shape from ``inst.shape``).
        - CHANNEL_FUTURE: returns ``future.result()``, passed through
          ``callback_func`` when one is set.
        - CHANNEL_NPDATA: returns ``npdata`` unchanged.

        Raises:
            TypeError: for any other datatype (including ERROR).
        """
        feed = None
        if self.datatype == ChannelDataType.CHANNEL_PBDATA.value:
            feed = {}
            for inst in self.pbdata.insts:
                feed[inst.name] = np.frombuffer(inst.data, dtype=inst.type)
                feed[inst.name].shape = np.frombuffer(inst.shape, dtype="int32")
        elif self.datatype == ChannelDataType.CHANNEL_FUTURE.value:
            feed = self.future.result()
            if self.callback_func is not None:
                feed = self.callback_func(feed)
        elif self.datatype == ChannelDataType.CHANNEL_NPDATA.value:
            feed = self.npdata
        else:
            raise TypeError("Error type({}) in datatype.".format(self.datatype))
        return feed

    def __str__(self):
        return "type[{}], ecode[{}], id[{}]".format(
            ChannelDataType(self.datatype).name, self.ecode, self.id)
B
barrierye 已提交
201

202

B
barrierye 已提交
203
class Channel(multiprocessing.queues.Queue):
    """
    The channel used for communication between Ops.

    1. Support multiple different Op feed data (multiple producers).
       Data from different producers is grouped by data ID; the group (a
       dict {producer_name: ChannelData}) is put on the queue only once
       every producer has pushed its part.
    2. Support multiple different Op fetch data (multiple consumers).
       A fetched item is kept until every consumer Op has fetched it; each
       consumer advances through the items in order, so the same consumer
       never gets the same item twice.
    3. (TODO) Timeout and BatchSize are not fully supported.

    Note:
    1. The ID of the data in the channel must be different.
    2. The function add_producer() and add_consumer() are not thread safe,
       and can only be called during initialization.
    """

    def __init__(self, manager, name=None, maxsize=0, timeout=None):
        # https://stackoverflow.com/questions/39496554/cannot-subclass-multiprocessing-queue-in-python-3-5/
        if sys.version_info.major == 2:
            super(Channel, self).__init__(maxsize=maxsize)
        elif sys.version_info.major == 3:
            super(Channel, self).__init__(
                maxsize=maxsize, ctx=multiprocessing.get_context())
        else:
            raise Exception("Error Python version")
        self._maxsize = maxsize
        self._timeout = timeout
        self.name = name
        # NOTE(review): plain attribute, unlike the manager-backed state
        # below, so it is not shared between processes -- confirm stop() is
        # reached in every process that blocks on this channel.
        self._stop = False

        self._cv = multiprocessing.Condition()

        self._producers = []
        # Shared state; manager proxies are used so it works across processes.
        self._producer_res_count = manager.dict()  # {data_id: count}
        self._push_res = manager.dict()  # {data_id: {op_name: data}}

        self._consumers = manager.dict()  # {op_name: idx}
        self._idx_consumer_num = manager.dict()  # {idx: num}
        self._consumer_base_idx = manager.Value('i', 0)
        # Items fetched from the queue but not yet consumed by every consumer;
        # _front_res[0] corresponds to index _consumer_base_idx.
        self._front_res = manager.list()

    def get_producers(self):
        return self._producers

    def get_consumers(self):
        return self._consumers.keys()

    def _log(self, info_str):
        return "[{}] {}".format(self.name, info_str)

    def debug(self):
        return self._log("p: {}, c: {}".format(self.get_producers(),
                                               self.get_consumers()))

    def add_producer(self, op_name):
        """ not thread safe, and can only be called during initialization. """
        if op_name in self._producers:
            raise ValueError(
                self._log("producer({}) is already in channel".format(op_name)))
        self._producers.append(op_name)

    def add_consumer(self, op_name):
        """ not thread safe, and can only be called during initialization. """
        if op_name in self._consumers:
            raise ValueError(
                self._log("consumer({}) is already in channel".format(op_name)))
        self._consumers[op_name] = 0

        if self._idx_consumer_num.get(0) is None:
            self._idx_consumer_num[0] = 0
        self._idx_consumer_num[0] += 1

    def push(self, channeldata, op_name=None):
        """Push ``channeldata`` from producer ``op_name``.

        With one producer the data goes straight to the queue.  With several
        producers ``op_name`` is required; the data is buffered by its id and
        the combined dict is queued once all producers have pushed.  Blocks
        (waiting on the condition) while the queue is full, until stopped.
        Always returns True.
        """
        logging.debug(
            self._log("{} try to push data: {}".format(op_name,
                                                       channeldata.__str__())))
        if len(self._producers) == 0:
            raise Exception(
                self._log(
                    "expected number of producers to be greater than 0, but the it is 0."
                ))
        elif len(self._producers) == 1:
            with self._cv:
                while self._stop is False:
                    try:
                        self.put(channeldata, timeout=0)
                        break
                    except Queue.Full:
                        self._cv.wait()
                logging.debug(
                    self._log("{} channel size: {}".format(op_name,
                                                           self.qsize())))
                self._cv.notify_all()
                logging.debug(self._log("{} notify all".format(op_name)))
            logging.debug(self._log("{} push data succ!".format(op_name)))
            return True
        elif op_name is None:
            raise Exception(
                self._log(
                    "There are multiple producers, so op_name cannot be None."))

        producer_num = len(self._producers)
        data_id = channeldata.id
        put_data = None
        with self._cv:
            logging.debug(self._log("{} get lock".format(op_name)))
            if data_id not in self._push_res:
                self._push_res[data_id] = {
                    name: None
                    for name in self._producers
                }
                self._producer_res_count[data_id] = 0
            # Nested values of a manager dict are copies, so mutate a local
            # copy and write it back, see:
            # https://docs.python.org/3.6/library/multiprocessing.html?highlight=multiprocess#proxy-objects
            tmp_push_res = self._push_res[data_id]
            tmp_push_res[op_name] = channeldata
            self._push_res[data_id] = tmp_push_res

            if self._producer_res_count[data_id] + 1 == producer_num:
                put_data = self._push_res[data_id]
                self._push_res.pop(data_id)
                self._producer_res_count.pop(data_id)
            else:
                self._producer_res_count[data_id] += 1

            if put_data is None:
                logging.debug(
                    self._log("{} push data succ, but not push to queue.".
                              format(op_name)))
            else:
                while self._stop is False:
                    try:
                        logging.debug(
                            self._log("{} push data succ: {}".format(
                                op_name, put_data.__str__())))
                        self.put(put_data, timeout=0)
                        break
                    # put() signals a full queue with Queue.Full (the
                    # original caught Queue.Empty, so a full queue raised).
                    except Queue.Full:
                        self._cv.wait()

                logging.debug(
                    self._log("multi | {} push data succ!".format(op_name)))
            self._cv.notify_all()
        return True

    def front(self, op_name=None):
        """Return the next item for consumer ``op_name``.

        Blocks (waiting on the condition) until data is available.  Returns
        None if the channel was stopped before data became available.  With
        several consumers ``op_name`` is required and the returned object is
        shared with the other consumers (treat it as read-only).
        """
        logging.debug(self._log("{} try to get data...".format(op_name)))
        if len(self._consumers) == 0:
            raise Exception(
                self._log(
                    "expected number of consumers to be greater than 0, but the it is 0."
                ))
        elif len(self._consumers) == 1:
            resp = None
            with self._cv:
                while self._stop is False and resp is None:
                    try:
                        logging.debug(
                            self._log("{} try to get(with channel empty: {})".
                                      format(op_name, self.empty())))
                        # For Python2, after putting an object on an empty queue there may
                        # be an infinitessimal delay before the queue's :meth:`~Queue.empty`
                        # see more:
                        # - https://bugs.python.org/issue18277
                        # - https://hg.python.org/cpython/rev/860fc6a2bd21
                        if sys.version_info.major == 2:
                            resp = self.get(timeout=1e-3)
                        elif sys.version_info.major == 3:
                            resp = self.get(timeout=0)
                        else:
                            raise Exception("Error Python version")
                        break
                    except Queue.Empty:
                        logging.debug(
                            self._log(
                                "{} wait for empty queue(with channel empty: {})".
                                format(op_name, self.empty())))
                        self._cv.wait()
            logging.debug(
                self._log("{} get data succ: {}".format(op_name, resp.__str__(
                ))))
            return resp
        elif op_name is None:
            raise Exception(
                self._log(
                    "There are multiple consumers, so op_name cannot be None."))

        with self._cv:
            # data_idx = consumer_idx - base_idx
            while self._stop is False and self._consumers[
                    op_name] - self._consumer_base_idx.value >= len(
                        self._front_res):
                logging.debug(
                    self._log(
                        "({}) B self._consumers: {}, self._consumer_base_idx: {}, len(self._front_res): {}".
                        format(op_name, self._consumers, self.
                               _consumer_base_idx.value, len(self._front_res))))
                try:
                    logging.debug(
                        self._log("{} try to get(with channel size: {})".format(
                            op_name, self.qsize())))
                    # For Python2, after putting an object on an empty queue there may
                    # be an infinitessimal delay before the queue's :meth:`~Queue.empty`
                    # see more:
                    # - https://bugs.python.org/issue18277
                    # - https://hg.python.org/cpython/rev/860fc6a2bd21
                    if sys.version_info.major == 2:
                        channeldata = self.get(timeout=1e-3)
                    elif sys.version_info.major == 3:
                        channeldata = self.get(timeout=0)
                    else:
                        raise Exception("Error Python version")
                    self._front_res.append(channeldata)
                    break
                except Queue.Empty:
                    logging.debug(
                        self._log(
                            "{} wait for empty queue(with channel size: {})".
                            format(op_name, self.qsize())))
                    self._cv.wait()

            consumer_idx = self._consumers[op_name]
            base_idx = self._consumer_base_idx.value
            data_idx = consumer_idx - base_idx
            if data_idx >= len(self._front_res):
                # Only reachable when the loop exited because of stop():
                # return None like the single-consumer path instead of
                # raising IndexError.
                return None
            resp = self._front_res[data_idx]
            logging.debug(self._log("{} get data: {}".format(op_name, resp)))

            # The last consumer leaving the oldest index drops that item.
            self._idx_consumer_num[consumer_idx] -= 1
            if consumer_idx == base_idx and self._idx_consumer_num[
                    consumer_idx] == 0:
                self._idx_consumer_num.pop(consumer_idx)
                self._front_res.pop(0)
                self._consumer_base_idx.value += 1

            self._consumers[op_name] += 1
            new_consumer_idx = self._consumers[op_name]
            if self._idx_consumer_num.get(new_consumer_idx) is None:
                self._idx_consumer_num[new_consumer_idx] = 0
            self._idx_consumer_num[new_consumer_idx] += 1
            logging.debug(
                self._log(
                    "({}) A self._consumers: {}, self._consumer_base_idx: {}, len(self._front_res): {}".
                    format(op_name, self._consumers, self._consumer_base_idx.
                           value, len(self._front_res))))
            logging.debug(self._log("{} notify all".format(op_name)))
            self._cv.notify_all()

        logging.debug(self._log("multi | {} get data succ!".format(op_name)))
        return resp  # reference, read only

    def stop(self):
        """Close the queue, mark the channel stopped and wake all waiters."""
        #TODO
        self.close()
        # multiprocessing.Condition.notify_all() requires the lock to be
        # held; calling it bare raised AssertionError ('lock is not owned').
        with self._cv:
            self._stop = True
            self._cv.notify_all()
466

B
barrierye 已提交
467 468 469

class Op(object):
    def __init__(self,
470
                 name,
471
                 inputs,
B
barrierye 已提交
472 473 474 475 476
                 server_model=None,
                 server_port=None,
                 device=None,
                 client_config=None,
                 server_name=None,
477
                 fetch_names=None,
B
barrierye 已提交
478
                 concurrency=1,
B
barrierye 已提交
479 480
                 timeout=-1,
                 retry=2):
B
barrierye 已提交
481
        self._is_run = False
482
        self.name = name  # to identify the type of OP, it must be globally unique
483
        self._concurrency = concurrency  # amount of concurrency
484
        self.set_input_ops(inputs)
B
barrierye 已提交
485
        self._timeout = timeout
B
bug fix  
barrierye 已提交
486
        self._retry = max(1, retry)
487 488
        self._input = None
        self._outputs = []
B
barrierye 已提交
489

B
barrierye 已提交
490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505
        self.with_serving = False
        self._client_config = client_config
        self._server_name = server_name
        self._fetch_names = fetch_names
        self._server_model = server_model
        self._server_port = server_port
        self._device = device
        if self._client_config is not None and \
                self._server_name is not None and \
                self._fetch_names is not None and \
                self._server_model is not None and \
                self._server_port is not None and \
                self._device is not None:
            self.with_serving = True

    def init_client(self, client_config, server_name, fetch_names):
B
barrierye 已提交
506 507
        if self.with_serving == False:
            logging.debug("{} no client".format(self.name))
508
            return
B
barrierye 已提交
509 510 511
        logging.debug("{} client_config: {}".format(self.name, client_config))
        logging.debug("{} server_name: {}".format(self.name, server_name))
        logging.debug("{} fetch_names: {}".format(self.name, fetch_names))
B
barrierye 已提交
512 513 514 515 516
        self._client = Client()
        self._client.load_client_config(client_config)
        self._client.connect([server_name])
        self._fetch_names = fetch_names

517
    def get_input_channel(self):
518
        return self._input
B
barrierye 已提交
519

520 521 522 523 524 525 526 527 528 529 530 531 532 533 534
    def get_input_ops(self):
        return self._input_ops

    def set_input_ops(self, ops):
        if not isinstance(ops, list):
            ops = [] if ops is None else [ops]
        self._input_ops = []
        for op in ops:
            if not isinstance(op, Op):
                raise TypeError(
                    self._log('input op must be Op type, not {}'.format(
                        type(op))))
            self._input_ops.append(op)

    def add_input_channel(self, channel):
        """Register this Op as a consumer of ``channel`` and use it as input.

        Raises:
            TypeError: if ``channel`` is not a Channel.
        """
        if isinstance(channel, Channel):
            channel.add_consumer(self.name)
            self._input = channel
            return
        raise TypeError(
            self._log('input channel must be Channel type, not {}'.format(
                type(channel))))
B
barrierye 已提交
541

542
    def get_output_channels(self):
B
barrierye 已提交
543 544
        return self._outputs

545 546
    def add_output_channel(self, channel):
        """Register this Op as a producer of ``channel`` and append it.

        Raises:
            TypeError: if ``channel`` is not a Channel.
        """
        if isinstance(channel, Channel):
            channel.add_producer(self.name)
            self._outputs.append(channel)
            return
        raise TypeError(
            self._log('output channel must be Channel type, not {}'.format(
                type(channel))))
B
barrierye 已提交
552

553 554
    def preprocess(self, channeldata):
        if isinstance(channeldata, dict):
B
barrierye 已提交
555 556 557
            raise NotImplementedError(
                'this Op has multiple previous inputs. Please override this method'
            )
558
        feed = channeldata.parse()
559
        return feed
B
barrierye 已提交
560 561

    def midprocess(self, data):
562 563 564 565 566 567 568
        if not isinstance(data, dict):
            raise Exception(
                self._log(
                    'data must be dict type(the output of preprocess()), but get {}'.
                    format(type(data))))
        logging.debug(self._log('data: {}'.format(data)))
        logging.debug(self._log('fetch: {}'.format(self._fetch_names)))
569 570 571 572
        call_future = self._client.predict(
            feed=data, fetch=self._fetch_names, asyn=True)
        logging.debug(self._log("get call_future"))
        return call_future
B
barrierye 已提交
573 574

    def postprocess(self, output_data):
B
barrierye 已提交
575
        return output_data
B
barrierye 已提交
576 577

    def stop(self):
578 579 580
        self._input.stop()
        for channel in self._outputs:
            channel.stop()
B
barrierye 已提交
581
        self._is_run = False
B
barrierye 已提交
582

583
    def _parse_channeldata(self, channeldata):
B
barrierye 已提交
584
        data_id, error_channeldata = None, None
585 586 587
        if isinstance(channeldata, dict):
            parsed_data = {}
            key = channeldata.keys()[0]
B
barrierye 已提交
588
            data_id = channeldata[key].id
589
            for _, data in channeldata.items():
B
barrierye 已提交
590 591
                if data.ecode != ChannelDataEcode.OK.value:
                    error_channeldata = data
592 593
                    break
        else:
B
barrierye 已提交
594 595 596 597
            data_id = channeldata.id
            if channeldata.ecode != ChannelDataEcode.OK.value:
                error_channeldata = channeldata
        return data_id, error_channeldata
598

B
barrierye 已提交
599
    def _push_to_output_channels(self, data, channels, name=None):
B
bug fix  
barrierye 已提交
600 601
        if name is None:
            name = self.name
B
barrierye 已提交
602
        for channel in channels:
B
bug fix  
barrierye 已提交
603
            channel.push(data, name)
B
barrierye 已提交
604

B
barrierye 已提交
605 606 607 608 609 610 611 612 613 614 615
    def start(self):
        proces = []
        for concurrency_idx in range(self._concurrency):
            p = multiprocessing.Process(
                target=self._run,
                args=(concurrency_idx, self.get_input_channel(),
                      self.get_output_channels()))
            p.start()
            proces.append(p)
        return proces

B
barrierye 已提交
616
    def _run(self, concurrency_idx, input_channel, output_channels):
B
barrierye 已提交
617 618
        self.init_client(self._client_config, self._server_name,
                         self._fetch_names)
B
bug fix  
barrierye 已提交
619
        op_info_prefix = "[{}|{}]".format(self.name, concurrency_idx)
B
barrierye 已提交
620
        log = self._get_log_func(op_info_prefix)
B
barrierye 已提交
621 622
        self._is_run = True
        while self._is_run:
B
barrierye 已提交
623
            _profiler.record("{}-get_0".format(op_info_prefix))
B
barrierye 已提交
624
            channeldata = input_channel.front(self.name)
B
barrierye 已提交
625
            _profiler.record("{}-get_1".format(op_info_prefix))
B
bug fix  
barrierye 已提交
626
            logging.debug(log("input_data: {}".format(channeldata)))
B
barrierye 已提交
627

B
barrierye 已提交
628
            data_id, error_channeldata = self._parse_channeldata(channeldata)
629

B
bug fix  
barrierye 已提交
630
            # error data in predecessor Op
B
barrierye 已提交
631 632 633
            if error_channeldata is not None:
                self._push_to_output_channels(error_channeldata,
                                              output_channels)
B
barrierye 已提交
634 635
                continue

B
bug fix  
barrierye 已提交
636
            # preprecess
B
barrierye 已提交
637 638
            try:
                _profiler.record("{}-prep_0".format(op_info_prefix))
B
bug fix  
barrierye 已提交
639
                preped_data = self.preprocess(channeldata)
B
barrierye 已提交
640 641
                _profiler.record("{}-prep_1".format(op_info_prefix))
            except NotImplementedError as e:
B
bug fix  
barrierye 已提交
642
                # preprocess function not implemented
B
barrierye 已提交
643 644 645 646 647 648
                error_info = log(e)
                logging.error(error_info)
                self._push_to_output_channels(
                    ChannelData(
                        ecode=ChannelDataEcode.NOT_IMPLEMENTED.value,
                        error_info=error_info,
B
barrierye 已提交
649 650
                        data_id=data_id),
                    output_channels)
B
barrierye 已提交
651
                continue
B
bug fix  
barrierye 已提交
652
            except TypeError as e:
B
barrierye 已提交
653
                # Error type in channeldata.datatype
B
bug fix  
barrierye 已提交
654 655 656 657 658 659
                error_info = log(e)
                logging.error(error_info)
                self._push_to_output_channels(
                    ChannelData(
                        ecode=ChannelDataEcode.TYPE_ERROR.value,
                        error_info=error_info,
B
barrierye 已提交
660 661
                        data_id=data_id),
                    output_channels)
B
bug fix  
barrierye 已提交
662 663 664 665 666 667
                continue
            except Exception as e:
                error_info = log(e)
                logging.error(error_info)
                self._push_to_output_channels(
                    ChannelData(
B
barrierye 已提交
668
                        ecode=ChannelDataEcode.UNKNOW.value,
B
bug fix  
barrierye 已提交
669
                        error_info=error_info,
B
barrierye 已提交
670 671
                        data_id=data_id),
                    output_channels)
B
bug fix  
barrierye 已提交
672
                continue
673

B
barrierye 已提交
674 675
            # midprocess
            call_future = None
B
barrierye 已提交
676
            if self.with_serving:
B
bug fix  
barrierye 已提交
677
                ecode = ChannelDataEcode.OK.value
B
barrierye 已提交
678 679 680
                _profiler.record("{}-midp_0".format(op_info_prefix))
                if self._timeout <= 0:
                    try:
B
bug fix  
barrierye 已提交
681
                        call_future = self.midprocess(preped_data)
B
barrierye 已提交
682 683 684 685
                    except Exception as e:
                        ecode = ChannelDataEcode.UNKNOW.value
                        error_info = log(e)
                        logging.error(error_info)
B
barrierye 已提交
686
                else:
B
barrierye 已提交
687 688 689
                    for i in range(self._retry):
                        try:
                            call_future = func_timeout.func_timeout(
B
bug fix  
barrierye 已提交
690 691 692 693
                                self._timeout,
                                self.midprocess,
                                args=(preped_data, ))
                        except func_timeout.FunctionTimedOut as e:
B
barrierye 已提交
694 695
                            if i + 1 >= self._retry:
                                ecode = ChannelDataEcode.TIMEOUT.value
B
bug fix  
barrierye 已提交
696 697
                                error_info = log(e)
                                logging.error(error_info)
B
barrierye 已提交
698 699
                            else:
                                logging.warn(
B
bug fix  
barrierye 已提交
700
                                    log("timeout, retry({})".format(i + 1)))
B
barrierye 已提交
701 702 703 704 705 706 707
                        except Exception as e:
                            ecode = ChannelDataEcode.UNKNOW.value
                            error_info = log(e)
                            logging.error(error_info)
                            break
                        else:
                            break
B
bug fix  
barrierye 已提交
708
                if ecode != ChannelDataEcode.OK.value:
B
barrierye 已提交
709 710 711
                    self._push_to_output_channels(
                        ChannelData(
                            ecode=ecode, error_info=error_info,
B
barrierye 已提交
712 713
                            data_id=data_id),
                        output_channels)
B
barrierye 已提交
714 715
                    continue
                _profiler.record("{}-midp_1".format(op_info_prefix))
716

B
barrierye 已提交
717 718 719
            # postprocess
            output_data = None
            _profiler.record("{}-postp_0".format(op_info_prefix))
B
barrierye 已提交
720
            if self.with_serving:
B
bug fix  
barrierye 已提交
721
                # use call_future
B
barrierye 已提交
722
                output_data = ChannelData(
B
barrierye 已提交
723
                    datatype=ChannelDataType.CHANNEL_FUTURE.value,
B
barrierye 已提交
724 725 726
                    future=call_future,
                    data_id=data_id,
                    callback_func=self.postprocess)
B
barrierye 已提交
727 728 729 730 731 732 733 734 735
                #TODO: for future are not picklable
                npdata = self.postprocess(call_future.result())
                self._push_to_output_channels(
                    ChannelData(
                        ChannelDataType.CHANNEL_NPDATA.value,
                        npdata=npdata,
                        data_id=data_id),
                    output_channels)
                continue
B
barrierye 已提交
736
            else:
B
bug fix  
barrierye 已提交
737 738 739 740 741 742 743 744 745
                try:
                    postped_data = self.postprocess(preped_data)
                except Exception as e:
                    ecode = ChannelDataEcode.UNKNOW.value
                    error_info = log(e)
                    logging.error(error_info)
                    self._push_to_output_channels(
                        ChannelData(
                            ecode=ecode, error_info=error_info,
B
barrierye 已提交
746 747
                            data_id=data_id),
                        output_channels)
B
bug fix  
barrierye 已提交
748 749
                    continue
                if not isinstance(postped_data, dict):
B
barrierye 已提交
750 751
                    ecode = ChannelDataEcode.TYPE_ERROR.value
                    error_info = log("output of postprocess funticon must be " \
B
bug fix  
barrierye 已提交
752
                            "dict type, but get {}".format(type(postped_data)))
B
barrierye 已提交
753 754 755 756
                    logging.error(error_info)
                    self._push_to_output_channels(
                        ChannelData(
                            ecode=ecode, error_info=error_info,
B
barrierye 已提交
757 758
                            data_id=data_id),
                        output_channels)
B
barrierye 已提交
759
                    continue
B
bug fix  
barrierye 已提交
760

B
barrierye 已提交
761 762 763 764
                output_data = ChannelData(
                    ChannelDataType.CHANNEL_NPDATA.value,
                    npdata=postped_data,
                    data_id=data_id)
B
barrierye 已提交
765 766 767 768
            _profiler.record("{}-postp_1".format(op_info_prefix))

            # push data to channel (if run succ)
            _profiler.record("{}-push_0".format(op_info_prefix))
B
barrierye 已提交
769
            self._push_to_output_channels(output_data, output_channels)
B
barrierye 已提交
770 771 772 773 774 775 776 777 778 779
            _profiler.record("{}-push_1".format(op_info_prefix))

    def _log(self, info):
        return "{} {}".format(self.name, info)

    def _get_log_func(self, op_info_prefix):
        def log_func(info_str):
            return "{} {}".format(op_info_prefix, info_str)

        return log_func
B
barrierye 已提交
780

781 782 783
    def get_concurrency(self):
        return self._concurrency

B
barrierye 已提交
784

B
bug fix  
barrierye 已提交
785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804
class VirtualOp(Op):
    ''' For connecting two channels. '''

    def __init__(self, name, concurrency=1):
        super(VirtualOp, self).__init__(
            name=name, inputs=None, concurrency=concurrency)
        self._virtual_pred_ops = []

    def add_virtual_pred_op(self, op):
        self._virtual_pred_ops.append(op)

    def add_output_channel(self, channel):
        if not isinstance(channel, Channel):
            raise TypeError(
                self._log('output channel must be Channel type, not {}'.format(
                    type(channel))))
        for op in self._virtual_pred_ops:
            channel.add_producer(op.name)
        self._outputs.append(channel)

B
barrierye 已提交
805
    def _run(self, input_channel, output_channels):
B
bug fix  
barrierye 已提交
806 807
        op_info_prefix = "[{}|{}]".format(self.name, concurrency_idx)
        log = self._get_log_func(op_info_prefix)
B
barrierye 已提交
808 809
        self._is_run = True
        while self._is_run:
B
bug fix  
barrierye 已提交
810
            _profiler.record("{}-get_0".format(op_info_prefix))
B
barrierye 已提交
811
            channeldata = input_channel.front(self.name)
B
bug fix  
barrierye 已提交
812 813 814 815 816
            _profiler.record("{}-get_1".format(op_info_prefix))

            _profiler.record("{}-push_0".format(op_info_prefix))
            if isinstance(channeldata, dict):
                for name, data in channeldata.items():
B
barrierye 已提交
817 818
                    self._push_to_output_channels(
                        data, channels=output_channels, name=name)
B
bug fix  
barrierye 已提交
819
            else:
B
barrierye 已提交
820 821 822 823
                self._push_to_output_channels(
                    channeldata,
                    channels=output_channels,
                    name=self._virtual_pred_ops[0].name)
B
bug fix  
barrierye 已提交
824 825 826
            _profiler.record("{}-push_1".format(op_info_prefix))


B
barrierye 已提交
827
class GeneralPythonService(
828
        general_python_service_pb2_grpc.GeneralPythonServiceServicer):
B
barrierye 已提交
829
    def __init__(self, in_channel, out_channel, retry=2):
B
barrierye 已提交
830
        super(GeneralPythonService, self).__init__()
831
        self.name = "#G"
832 833
        self.set_in_channel(in_channel)
        self.set_out_channel(out_channel)
834 835
        logging.debug(self._log(in_channel.debug()))
        logging.debug(self._log(out_channel.debug()))
B
barrierye 已提交
836 837 838 839 840
        #TODO: 
        #  multi-lock for different clients
        #  diffenert lock for server and client
        self._id_lock = threading.Lock()
        self._cv = threading.Condition()
B
barrierye 已提交
841 842
        self._globel_resp_dict = {}
        self._id_counter = 0
B
barrierye 已提交
843
        self._retry = retry
B
barrierye 已提交
844 845 846
        self._recive_func = threading.Thread(
            target=GeneralPythonService._recive_out_channel_func, args=(self, ))
        self._recive_func.start()
847 848

    def _log(self, info_str):
849
        return "[{}] {}".format(self.name, info_str)
B
barrierye 已提交
850

851
    def set_in_channel(self, in_channel):
852 853 854 855
        if not isinstance(in_channel, Channel):
            raise TypeError(
                self._log('in_channel must be Channel type, but get {}'.format(
                    type(in_channel))))
856
        in_channel.add_producer(self.name)
857 858 859
        self._in_channel = in_channel

    def set_out_channel(self, out_channel):
860 861 862 863
        if not isinstance(out_channel, Channel):
            raise TypeError(
                self._log('out_channel must be Channel type, but get {}'.format(
                    type(out_channel))))
864
        out_channel.add_consumer(self.name)
865 866
        self._out_channel = out_channel

B
barrierye 已提交
867 868
    def _recive_out_channel_func(self):
        while True:
869
            channeldata = self._out_channel.front(self.name)
870
            if not isinstance(channeldata, ChannelData):
871 872
                raise TypeError(
                    self._log('data must be ChannelData type, but get {}'.
873
                              format(type(channeldata))))
B
barrierye 已提交
874
            with self._cv:
B
barrierye 已提交
875
                data_id = channeldata.id
876
                self._globel_resp_dict[data_id] = channeldata
B
barrierye 已提交
877
                self._cv.notify_all()
B
barrierye 已提交
878 879

    def _get_next_id(self):
B
barrierye 已提交
880
        with self._id_lock:
B
barrierye 已提交
881 882 883 884
            self._id_counter += 1
            return self._id_counter - 1

    def _get_data_in_globel_resp_dict(self, data_id):
B
barrierye 已提交
885 886 887 888 889 890
        resp = None
        with self._cv:
            while data_id not in self._globel_resp_dict:
                self._cv.wait()
            resp = self._globel_resp_dict.pop(data_id)
            self._cv.notify_all()
B
barrierye 已提交
891
        return resp
B
barrierye 已提交
892 893

    def _pack_data_for_infer(self, request):
894
        logging.debug(self._log('start inferce'))
B
barrierye 已提交
895
        data_id = self._get_next_id()
B
barrierye 已提交
896
        npdata = {}
B
barrierye 已提交
897 898 899 900 901 902
        try:
            for idx, name in enumerate(request.feed_var_names):
                logging.debug(
                    self._log('name: {}'.format(request.feed_var_names[idx])))
                logging.debug(
                    self._log('data: {}'.format(request.feed_insts[idx])))
B
barrierye 已提交
903 904 905 906
                npdata[name] = np.frombuffer(
                    request.feed_insts[idx], dtype=request.type[idx])
                npdata[name].shape = np.frombuffer(
                    request.shape[idx], dtype="int32")
B
barrierye 已提交
907
        except Exception as e:
B
barrierye 已提交
908 909 910 911 912 913 914 915 916
            return ChannelData(
                ecode=ChannelDataEcode.RPC_PACKAGE_ERROR.value,
                error_info="rpc package error",
                data_id=data_id), data_id
        else:
            return ChannelData(
                datatype=ChannelDataType.CHANNEL_NPDATA.value,
                npdata=npdata,
                data_id=data_id), data_id
B
barrierye 已提交
917

918 919 920
    def _pack_data_for_resp(self, channeldata):
        logging.debug(self._log('get channeldata'))
        resp = pyservice_pb2.Response()
B
barrierye 已提交
921
        resp.ecode = channeldata.ecode
B
bug fix  
barrierye 已提交
922
        if resp.ecode == ChannelDataEcode.OK.value:
B
barrierye 已提交
923
            if channeldata.datatype == ChannelDataType.CHANNEL_PBDATA.value:
924 925 926 927 928
                for inst in channeldata.pbdata.insts:
                    resp.fetch_insts.append(inst.data)
                    resp.fetch_var_names.append(inst.name)
                    resp.shape.append(inst.shape)
                    resp.type.append(inst.type)
B
barrierye 已提交
929 930 931
            elif channeldata.datatype in (ChannelDataType.CHANNEL_FUTURE.value,
                                          ChannelDataType.CHANNEL_NPDATA.value):
                feed = channeldata.parse()
B
bug fix  
barrierye 已提交
932
                for name, var in feed.items():
933 934 935 936 937 938 939 940
                    resp.fetch_insts.append(var.tobytes())
                    resp.fetch_var_names.append(name)
                    resp.shape.append(
                        np.array(
                            var.shape, dtype="int32").tobytes())
                    resp.type.append(str(var.dtype))
            else:
                raise TypeError(
B
barrierye 已提交
941
                    self._log("Error type({}) in datatype.".format(
B
barrierye 已提交
942
                        channeldata.datatype)))
B
barrierye 已提交
943
        else:
B
barrierye 已提交
944
            resp.error_info = channeldata.error_info
B
barrierye 已提交
945
        return resp
B
barrierye 已提交
946

B
barrierye 已提交
947
    def inference(self, request, context):
948
        _profiler.record("{}-prepack_0".format(self.name))
B
barrierye 已提交
949
        data, data_id = self._pack_data_for_infer(request)
950
        _profiler.record("{}-prepack_1".format(self.name))
B
barrierye 已提交
951

952
        resp_channeldata = None
B
barrierye 已提交
953 954
        for i in range(self._retry):
            logging.debug(self._log('push data'))
955 956 957
            _profiler.record("{}-push_0".format(self.name))
            self._in_channel.push(data, self.name)
            _profiler.record("{}-push_1".format(self.name))
B
barrierye 已提交
958 959

            logging.debug(self._log('wait for infer'))
960
            _profiler.record("{}-fetch_0".format(self.name))
961
            resp_channeldata = self._get_data_in_globel_resp_dict(data_id)
962
            _profiler.record("{}-fetch_1".format(self.name))
B
barrierye 已提交
963

B
barrierye 已提交
964
            if resp_channeldata.ecode == ChannelDataEcode.OK.value:
B
barrierye 已提交
965
                break
B
barrierye 已提交
966 967
            if i + 1 < self._retry:
                logging.warn("retry({}): {}".format(
B
barrierye 已提交
968
                    i + 1, resp_channeldata.error_info))
B
barrierye 已提交
969

970
        _profiler.record("{}-postpack_0".format(self.name))
971
        resp = self._pack_data_for_resp(resp_channeldata)
972
        _profiler.record("{}-postpack_1".format(self.name))
B
barrierye 已提交
973
        _profiler.print_profile()
B
barrierye 已提交
974 975
        return resp

B
barrierye 已提交
976 977

class PyServer(object):
B
barrierye 已提交
978
    def __init__(self, retry=2, profile=False):
B
barrierye 已提交
979
        self._channels = []
980
        self._user_ops = []
B
bug fix  
barrierye 已提交
981
        self._actual_ops = []
B
barrierye 已提交
982 983
        self._port = None
        self._worker_num = None
B
barrierye 已提交
984 985
        self._in_channel = None
        self._out_channel = None
B
barrierye 已提交
986
        self._retry = retry
B
barrierye 已提交
987
        self._manager = multiprocessing.Manager()
B
barrierye 已提交
988
        _profiler.enable(profile)
B
barrierye 已提交
989 990 991 992 993

    def add_channel(self, channel):
        self._channels.append(channel)

    def add_op(self, op):
994 995 996
        self._user_ops.append(op)

    def add_ops(self, ops):
B
fix bug  
barrierye 已提交
997
        self._user_ops.extend(ops)
B
barrierye 已提交
998 999

    def gen_desc(self):
1000
        logging.info('here will generate desc for PAAS')
B
barrierye 已提交
1001 1002
        pass

1003 1004 1005
    def _topo_sort(self):
        indeg_num = {}
        que_idx = 0  # scroll queue 
B
fix bug  
barrierye 已提交
1006
        ques = [Queue.Queue() for _ in range(2)]
B
bug fix  
barrierye 已提交
1007 1008 1009 1010 1011
        for op in self._user_ops:
            if len(op.get_input_ops()) == 0:
                op.name = "#G"  # update read_op.name
                break
        outdegs = {op.name: [] for op in self._user_ops}
B
bug fix  
barrierye 已提交
1012
        zero_indeg_num, zero_outdeg_num = 0, 0
1013 1014 1015 1016 1017 1018 1019
        for idx, op in enumerate(self._user_ops):
            # check the name of op is globally unique
            if op.name in indeg_num:
                raise Exception("the name of Op must be unique")
            indeg_num[op.name] = len(op.get_input_ops())
            if indeg_num[op.name] == 0:
                ques[que_idx].put(op)
B
bug fix  
barrierye 已提交
1020
                zero_indeg_num += 1
1021
            for pred_op in op.get_input_ops():
B
fix bug  
barrierye 已提交
1022
                outdegs[pred_op.name].append(op)
B
bug fix  
barrierye 已提交
1023 1024 1025 1026 1027 1028 1029
        if zero_indeg_num != 1:
            raise Exception("DAG contains multiple input Ops")
        for _, succ_list in outdegs.items():
            if len(succ_list) == 0:
                zero_outdeg_num += 1
        if zero_outdeg_num != 1:
            raise Exception("DAG contains multiple output Ops")
1030

B
bug fix  
barrierye 已提交
1031
        # topo sort to get dag_views
1032 1033 1034 1035 1036 1037 1038 1039 1040 1041
        dag_views = []
        sorted_op_num = 0
        while True:
            que = ques[que_idx]
            next_que = ques[(que_idx + 1) % 2]
            dag_view = []
            while que.qsize() != 0:
                op = que.get()
                dag_view.append(op)
                sorted_op_num += 1
B
bug fix  
barrierye 已提交
1042
                for succ_op in outdegs[op.name]:
B
fix bug  
barrierye 已提交
1043
                    indeg_num[succ_op.name] -= 1
1044 1045 1046 1047 1048 1049 1050 1051 1052 1053
                    if indeg_num[succ_op.name] == 0:
                        next_que.put(succ_op)
            dag_views.append(dag_view)
            if next_que.qsize() == 0:
                break
            que_idx = (que_idx + 1) % 2
        if sorted_op_num < len(self._user_ops):
            raise Exception("not legal DAG")

        # create channels and virtual ops
B
bug fix  
barrierye 已提交
1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064
        def name_generator(prefix):
            def number_generator():
                idx = 0
                while True:
                    yield "{}{}".format(prefix, idx)
                    idx += 1

            return number_generator()

        virtual_op_name_gen = name_generator("vir")
        channel_name_gen = name_generator("chl")
1065 1066 1067
        virtual_ops = []
        channels = []
        input_channel = None
B
bug fix  
barrierye 已提交
1068
        actual_view = None
1069 1070 1071 1072
        for v_idx, view in enumerate(dag_views):
            if v_idx + 1 >= len(dag_views):
                break
            next_view = dag_views[v_idx + 1]
B
bug fix  
barrierye 已提交
1073 1074
            if actual_view is None:
                actual_view = view
1075 1076
            actual_next_view = []
            pred_op_of_next_view_op = {}
B
bug fix  
barrierye 已提交
1077 1078
            for op in actual_view:
                # find actual succ op in next view and create virtual op
1079 1080
                for succ_op in outdegs[op.name]:
                    if succ_op in next_view:
B
bug fix  
barrierye 已提交
1081 1082
                        if succ_op not in actual_next_view:
                            actual_next_view.append(succ_op)
1083 1084 1085 1086
                        if succ_op.name not in pred_op_of_next_view_op:
                            pred_op_of_next_view_op[succ_op.name] = []
                        pred_op_of_next_view_op[succ_op.name].append(op)
                    else:
B
bug fix  
barrierye 已提交
1087 1088
                        # create virtual op
                        virtual_op = None
B
barrierye 已提交
1089 1090 1091 1092 1093 1094 1095 1096
                        if sys.version_info.major == 2:
                            virtual_op = VirtualOp(
                                name=virtual_op_name_gen.next())
                        elif sys.version_info.major == 3:
                            virtual_op = VirtualOp(
                                name=virtual_op_name_gen.__next__())
                        else:
                            raise Exception("Error Python version")
1097
                        virtual_ops.append(virtual_op)
B
bug fix  
barrierye 已提交
1098 1099 1100 1101 1102
                        outdegs[virtual_op.name] = [succ_op]
                        actual_next_view.append(virtual_op)
                        pred_op_of_next_view_op[virtual_op.name] = [op]
                        virtual_op.add_virtual_pred_op(op)
            actual_view = actual_next_view
1103 1104 1105
            # create channel
            processed_op = set()
            for o_idx, op in enumerate(actual_next_view):
B
bug fix  
barrierye 已提交
1106
                if op.name in processed_op:
1107
                    continue
B
barrierye 已提交
1108 1109 1110 1111 1112 1113 1114 1115
                if sys.version_info.major == 2:
                    channel = Channel(
                        self._manager, name=channel_name_gen.next())
                elif sys.version_info.major == 3:
                    channel = Channel(
                        self._manager, name=channel_name_gen.__next__())
                else:
                    raise Exception("Error Python version")
1116
                channels.append(channel)
B
bug fix  
barrierye 已提交
1117
                logging.debug("{} => {}".format(channel.name, op.name))
1118
                op.add_input_channel(channel)
B
bug fix  
barrierye 已提交
1119
                pred_ops = pred_op_of_next_view_op[op.name]
1120 1121 1122
                if v_idx == 0:
                    input_channel = channel
                else:
B
bug fix  
barrierye 已提交
1123
                    # if pred_op is virtual op, it will use ancestors as producers to channel
1124
                    for pred_op in pred_ops:
B
bug fix  
barrierye 已提交
1125 1126
                        logging.debug("{} => {}".format(pred_op.name,
                                                        channel.name))
1127
                        pred_op.add_output_channel(channel)
B
bug fix  
barrierye 已提交
1128 1129 1130
                processed_op.add(op.name)
                # find same input op to combine channel
                for other_op in actual_next_view[o_idx + 1:]:
1131 1132 1133 1134 1135 1136 1137 1138 1139 1140 1141
                    if other_op.name in processed_op:
                        continue
                    other_pred_ops = pred_op_of_next_view_op[other_op.name]
                    if len(other_pred_ops) != len(pred_ops):
                        continue
                    same_flag = True
                    for pred_op in pred_ops:
                        if pred_op not in other_pred_ops:
                            same_flag = False
                            break
                    if same_flag:
B
bug fix  
barrierye 已提交
1142 1143
                        logging.debug("{} => {}".format(channel.name,
                                                        other_op.name))
1144 1145
                        other_op.add_input_channel(channel)
                        processed_op.add(other_op.name)
B
barrierye 已提交
1146 1147 1148 1149 1150 1151 1152 1153
        if sys.version_info.major == 2:
            output_channel = Channel(
                self._manager, name=channel_name_gen.next())
        elif sys.version_info.major == 3:
            output_channel = Channel(
                self._manager, name=channel_name_gen.__next__())
        else:
            raise Exception("Error Python version")
1154
        channels.append(output_channel)
B
bug fix  
barrierye 已提交
1155
        last_op = dag_views[-1][0]
1156 1157
        last_op.add_output_channel(output_channel)

B
bug fix  
barrierye 已提交
1158
        self._actual_ops = virtual_ops
B
fix bug  
barrierye 已提交
1159 1160
        for op in self._user_ops:
            if len(op.get_input_ops()) == 0:
B
bug fix  
barrierye 已提交
1161
                # pass read op
B
fix bug  
barrierye 已提交
1162
                continue
B
bug fix  
barrierye 已提交
1163
            self._actual_ops.append(op)
1164
        self._channels = channels
B
bug fix  
barrierye 已提交
1165 1166
        for c in channels:
            logging.debug(c.debug())
1167 1168
        return input_channel, output_channel

B
barrierye 已提交
1169 1170 1171
    def prepare_server(self, port, worker_num):
        self._port = port
        self._worker_num = worker_num
1172 1173 1174

        input_channel, output_channel = self._topo_sort()
        self._in_channel = input_channel
B
fix bug  
barrierye 已提交
1175
        self._out_channel = output_channel
B
bug fix  
barrierye 已提交
1176
        for op in self._actual_ops:
B
barrierye 已提交
1177
            if op.with_serving:
B
fix bug  
barrierye 已提交
1178
                self.prepare_serving(op)
B
barrierye 已提交
1179 1180
        self.gen_desc()

1181
    def _run_ops(self):
B
barrierye 已提交
1182
        proces = []
B
bug fix  
barrierye 已提交
1183
        for op in self._actual_ops:
B
barrierye 已提交
1184 1185
            proces.extend(op.start())
        return proces
1186

1187
    def _stop_ops(self):
B
bug fix  
barrierye 已提交
1188
        for op in self._actual_ops:
1189 1190
            op.stop()

1191
    def run_server(self):
B
barrierye 已提交
1192
        op_proces = self._run_ops()
B
barrierye 已提交
1193 1194
        server = grpc.server(
            futures.ThreadPoolExecutor(max_workers=self._worker_num))
B
barrierye 已提交
1195
        general_python_service_pb2_grpc.add_GeneralPythonServiceServicer_to_server(
B
barrierye 已提交
1196 1197
            GeneralPythonService(self._in_channel, self._out_channel,
                                 self._retry), server)
B
barrierye 已提交
1198
        server.add_insecure_port('[::]:{}'.format(self._port))
B
barrierye 已提交
1199
        server.start()
1200 1201
        server.wait_for_termination()
        self._stop_ops()  # TODO
B
barrierye 已提交
1202 1203
        for p in op_proces:
            p.join()
B
barrierye 已提交
1204 1205 1206 1207 1208 1209 1210

    def prepare_serving(self, op):
        model_path = op._server_model
        port = op._server_port
        device = op._device

        if device == "cpu":
1211 1212
            cmd = "(Use MultiLangServer) python -m paddle_serving_server.serve" \
                  " --model {} --thread 4 --port {} --use_multilang &>/dev/null &".format(model_path, port)
B
barrierye 已提交
1213
        else:
1214 1215
            cmd = "(Use MultiLangServer) python -m paddle_serving_server_gpu.serve" \
                  " --model {} --thread 4 --port {} --use_multilang &>/dev/null &".format(model_path, port)
1216 1217
        # run a server (not in PyServing)
        logging.info("run a server (not in PyServing): {}".format(cmd))