pyserver.py 45.8 KB
Newer Older
B
barrierye 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
import threading
import multiprocessing
B
barrierye 已提交
17
import multiprocessing.queues
B
barrierye 已提交
18
import Queue
B
barrierye 已提交
19
import os
B
barrierye 已提交
20
import sys
B
barrierye 已提交
21
import paddle_serving_server
22
from paddle_serving_client import MultiLangClient as Client
B
barrierye 已提交
23
from paddle_serving_client import MultiLangPredictFuture
B
barrierye 已提交
24
from concurrent import futures
B
barrierye 已提交
25
import numpy as np
B
barrierye 已提交
26
import grpc
27 28 29 30
from .proto import general_model_config_pb2 as m_config
from .proto import general_python_service_pb2 as pyservice_pb2
from .proto import pyserving_channel_pb2 as channel_pb2
from .proto import general_python_service_pb2_grpc
B
barrierye 已提交
31
import logging
32
import random
B
barrierye 已提交
33
import time
B
barrierye 已提交
34
import func_timeout
35
import enum
36
import collections
B
barrierye 已提交
37 38


B
barrierye 已提交
39 40 41 42 43 44 45 46 47 48 49
class _TimeProfiler(object):
    """Collect named timestamps and dump matched pairs to stderr.

    Marks are recorded as ``"<name>_<tag>"`` strings; the text after the last
    underscore is the tag and everything before it is the name. Timestamps
    are microseconds since the epoch. Nothing is recorded or printed until
    ``enable(True)`` has been called.
    """

    def __init__(self):
        self._pid = os.getpid()
        self._print_head = 'PROFILE\tpid:{}\t'.format(self._pid)
        self._time_record = Queue.Queue()
        self._enable = False

    def enable(self, enable):
        """Turn recording/printing on (truthy) or off (falsy)."""
        self._enable = enable

    def record(self, name_with_tag):
        """Store the current time under ``name_with_tag`` (``"<name>_<tag>"``)."""
        # Truthiness check: any falsy value (False, 0, None) disables profiling.
        if not self._enable:
            return
        # Split on the LAST underscore so names may themselves contain "_".
        name, _, tag = name_with_tag.rpartition("_")
        self._time_record.put((name, tag, int(round(time.time() * 1000000))))

    def print_profile(self):
        """Write every completed pair of marks to stderr on one line.

        Two marks sharing the same name form a pair and are printed in the
        order they were recorded. A mark whose partner has not been recorded
        yet is put back into the queue for the next call.
        """
        if not self._enable:
            return
        sys.stderr.write(self._print_head)
        pending = {}  # name -> (tag, timestamp) still waiting for a partner
        while not self._time_record.empty():
            name, tag, timestamp = self._time_record.get()
            if name in pending:
                ptag, ptimestamp = pending.pop(name)
                sys.stderr.write("{}_{}:{} ".format(name, ptag, ptimestamp))
                sys.stderr.write("{}_{}:{} ".format(name, tag, timestamp))
            else:
                pending[name] = (tag, timestamp)
        sys.stderr.write('\n')
        # Keep unmatched marks so a later call can pair them.
        for name, (tag, timestamp) in pending.items():
            self._time_record.put((name, tag, timestamp))


# Module-level profiler instance; it records nothing until enable(True) is
# called on it (recording starts disabled).
_profiler = _TimeProfiler()


79 80 81
class ChannelDataEcode(enum.Enum):
    """Status codes carried by a ChannelData (stored as the integer ``.value``)."""
    # Data is valid.
    OK = 0
    # The Op's midprocess call timed out on every retry.
    TIMEOUT = 1
    # preprocess() raised NotImplementedError.
    NOT_IMPLEMENTED = 2
    # Wrong key/value types in numpy data, or TypeError raised in preprocess().
    TYPE_ERROR = 3
    # NOTE(review): not used in the visible part of this file; presumably a
    # malformed RPC request/response -- confirm against the service code.
    RPC_PACKAGE_ERROR = 4
    # Any other exception (member name kept as-is for compatibility).
    UNKNOW = 5
86 87 88 89 90


class ChannelDataType(enum.Enum):
    """Kind of payload held by a ChannelData (stored as the integer ``.value``)."""
    # Payload is a protobuf message (``pbdata``).
    CHANNEL_PBDATA = 0
    # Payload is a future whose result is fetched when the data is parsed.
    CHANNEL_FUTURE = 1
    # Payload is a dict of numpy arrays (``npdata``).
    CHANNEL_NPDATA = 2
    # No payload: the ChannelData only carries an error code and message.
    ERROR = 3
93 94 95 96


class ChannelData(object):
    """One unit of data (or one error) travelling through a Channel."""

    def __init__(self,
                 datatype=None,
                 future=None,
                 pbdata=None,
                 npdata=None,
                 data_id=None,
                 callback_func=None,
                 ecode=None,
                 error_info=None):
        '''
        There are several ways to use it (arguments are shown by keyword):

        1. ChannelData(datatype=CHANNEL_FUTURE, future=..., data_id=...[, callback_func=...])
        2. ChannelData(datatype=CHANNEL_PBDATA, pbdata=...)
        3. ChannelData(datatype=CHANNEL_PBDATA, npdata=..., data_id=...)
           (npdata is validated and serialized into a new protobuf message)
        4. ChannelData(datatype=CHANNEL_NPDATA, npdata=..., data_id=...)
        5. ChannelData(ecode=..., error_info=..., data_id=...)
           (datatype is forced to ERROR)

        ``datatype`` and ``ecode`` are the integer ``.value`` of
        ChannelDataType / ChannelDataEcode members.

        Raises:
            ValueError: a required argument is missing or datatype is unknown.

        Protobufs are not pickle-able:
        https://stackoverflow.com/questions/55344376/how-to-import-protobuf-module
        '''
        if ecode is not None:
            if data_id is None or error_info is None:
                raise ValueError("data_id and error_info cannot be None")
            datatype = ChannelDataType.ERROR.value
        elif datatype == ChannelDataType.CHANNEL_FUTURE.value:
            if data_id is None:
                raise ValueError("data_id cannot be None")
            ecode = ChannelDataEcode.OK.value
        elif datatype == ChannelDataType.CHANNEL_PBDATA.value:
            if pbdata is None:
                if data_id is None:
                    raise ValueError("data_id cannot be None")
                pbdata = channel_pb2.ChannelData()
                ecode, error_info = self._check_npdata(npdata)
                if ecode != ChannelDataEcode.OK.value:
                    logging.error(error_info)
                else:
                    # Serialize every array as raw bytes + shape + dtype name.
                    for name, value in npdata.items():
                        inst = channel_pb2.Inst()
                        inst.data = value.tobytes()
                        inst.name = name
                        inst.shape = np.array(
                            value.shape, dtype="int32").tobytes()
                        inst.type = str(value.dtype)
                        pbdata.insts.append(inst)
            else:
                # A ready-made protobuf was supplied: mark it valid so that
                # consumers comparing ecode with OK do not see it as an error.
                ecode = ChannelDataEcode.OK.value
        elif datatype == ChannelDataType.CHANNEL_NPDATA.value:
            ecode, error_info = self._check_npdata(npdata)
            if ecode != ChannelDataEcode.OK.value:
                logging.error(error_info)
        else:
            raise ValueError("datatype not match")

        self.future = future
        self.pbdata = pbdata
        self.npdata = npdata
        self.datatype = datatype
        self.callback_func = callback_func
        self.id = data_id
        self.ecode = ecode
        self.error_info = error_info

    def _check_npdata(self, npdata):
        """Validate that ``npdata`` maps string names to numpy arrays.

        Returns:
            (ecode, error_info): ``(OK, None)`` when valid, otherwise
            ``(TYPE_ERROR, <message>)`` describing the first problem found.
        """
        ecode = ChannelDataEcode.OK.value
        error_info = None
        if not hasattr(npdata, "items"):
            # e.g. npdata=None: report a type error instead of crashing.
            return (ChannelDataEcode.TYPE_ERROR.value,
                    "postped_data must be dict, but get {}".format(
                        type(npdata)))
        try:
            string_types = (str, unicode)  # Python 2
        except NameError:
            string_types = (str, )  # Python 3 has no separate unicode type
        for name, value in npdata.items():
            if not isinstance(name, string_types):
                ecode = ChannelDataEcode.TYPE_ERROR.value
                error_info = "the key of postped_data must " \
                        "be str, but get {}".format(type(name))
                break
            if not isinstance(value, np.ndarray):
                ecode = ChannelDataEcode.TYPE_ERROR.value
                error_info = "the value of postped_data must " \
                        "be np.ndarray, but get {}".format(type(value))
                break
        return ecode, error_info

    def parse(self):
        """Return the payload in the form expected by an Op's preprocess().

        - CHANNEL_PBDATA: dict {name: ndarray} rebuilt from the protobuf.
        - CHANNEL_FUTURE: the future's result, passed through
          ``callback_func`` when one was given.
        - CHANNEL_NPDATA: the stored ``npdata`` itself.

        Raises:
            TypeError: for any other datatype (including ERROR).
        """
        feed = None
        if self.datatype == ChannelDataType.CHANNEL_PBDATA.value:
            feed = {}
            for inst in self.pbdata.insts:
                feed[inst.name] = np.frombuffer(inst.data, dtype=inst.type)
                feed[inst.name].shape = np.frombuffer(inst.shape, dtype="int32")
        elif self.datatype == ChannelDataType.CHANNEL_FUTURE.value:
            feed = self.future.result()
            if self.callback_func is not None:
                feed = self.callback_func(feed)
        elif self.datatype == ChannelDataType.CHANNEL_NPDATA.value:
            feed = self.npdata
        else:
            raise TypeError("Error type({}) in datatype.".format(self.datatype))
        return feed

    def __str__(self):
        return "type[{}], ecode[{}], id[{}]".format(
            ChannelDataType(self.datatype).name, self.ecode, self.id)
B
barrierye 已提交
196

197

B
barrierye 已提交
198
class Channel(multiprocessing.queues.Queue):
    """ 
    The channel used for communication between Ops.

    1. Support multiple different Op feed data (multiple producer)
        Different types of data will be packaged through the data ID
    2. Support multiple different Op fetch data (multiple consumer)
        Only when all types of Ops get the data of the same ID,
        the data will be poped; The Op of the same type will not
        get the data of the same ID.
    3. (TODO) Timeout and BatchSize are not fully supported.

    Note:
    1. The ID of the data in the channel must be different.
    2. The function add_producer() and add_consumer() are not thread safe,
       and can only be called during initialization.
    """

    def __init__(self, manager, name=None, maxsize=0, timeout=None):
        """
        Args:
            manager: object providing dict(), list() and Value() factories for
                the shared bookkeeping state.
                NOTE(review): presumably a multiprocessing.Manager() -- confirm.
            name: channel name, used as prefix of log messages.
            maxsize: forwarded to the underlying queue.
            timeout: stored but not used by the visible code.
        """
        # https://stackoverflow.com/questions/39496554/cannot-subclass-multiprocessing-queue-in-python-3-5/
        multiprocessing.queues.Queue.__init__(self, maxsize=maxsize)
        self._maxsize = maxsize
        self._timeout = timeout
        self.name = name
        self._stop = False

        # Guards every push/front and is used to wake waiting producers/consumers.
        self._cv = multiprocessing.Condition()

        self._producers = []
        self._producer_res_count = manager.dict()  # {data_id: count}
        self._push_res = manager.dict()  # {data_id: {op_name: data}}

        self._consumers = manager.dict()  # {op_name: idx}
        self._idx_consumer_num = manager.dict()  # {idx: num}
        self._consumer_base_idx = manager.Value('i', 0)
        self._front_res = manager.list()

    def get_producers(self):
        return self._producers

    def get_consumers(self):
        return self._consumers.keys()

    def _log(self, info_str):
        return "[{}] {}".format(self.name, info_str)

    def debug(self):
        return self._log("p: {}, c: {}".format(self.get_producers(),
                                               self.get_consumers()))

    def add_producer(self, op_name):
        """ not thread safe, and can only be called during initialization. """
        if op_name in self._producers:
            raise ValueError(
                self._log("producer({}) is already in channel".format(op_name)))
        self._producers.append(op_name)

    def add_consumer(self, op_name):
        """ not thread safe, and can only be called during initialization. """
        if op_name in self._consumers:
            raise ValueError(
                self._log("consumer({}) is already in channel".format(op_name)))
        self._consumers[op_name] = 0

        if self._idx_consumer_num.get(0) is None:
            self._idx_consumer_num[0] = 0
        self._idx_consumer_num[0] += 1

    def push(self, channeldata, op_name=None):
        """Feed one ChannelData into the channel.

        With a single producer the data is queued directly. With several
        producers the data is held back until every producer has pushed an
        item with the same id; the queued object is then a dict
        {op_name: channeldata}.

        Returns:
            True.

        Raises:
            Exception: no producer is registered, or op_name is None while
                several producers are registered.
        """
        logging.debug(
            self._log("{} try to push data: {}".format(op_name,
                                                       channeldata.__str__())))
        if len(self._producers) == 0:
            raise Exception(
                self._log(
                    "expected number of producers to be greater than 0, but the it is 0."
                ))
        elif len(self._producers) == 1:
            with self._cv:
                while self._stop is False:
                    try:
                        self.put(channeldata, timeout=0)
                        break
                    except Queue.Full:
                        # Wait until a consumer notifies, then retry.
                        self._cv.wait()
                logging.debug(
                    self._log("{} channel size: {}".format(op_name,
                                                           self.qsize())))
                self._cv.notify_all()
                logging.debug(self._log("{} notify all".format(op_name)))
            logging.debug(self._log("{} push data succ!".format(op_name)))
            return True
        elif op_name is None:
            raise Exception(
                self._log(
                    "There are multiple producers, so op_name cannot be None."))

        producer_num = len(self._producers)
        data_id = channeldata.id
        put_data = None
        with self._cv:
            logging.debug(self._log("{} get lock".format(op_name)))
            if data_id not in self._push_res:
                self._push_res[data_id] = {
                    name: None
                    for name in self._producers
                }
                self._producer_res_count[data_id] = 0
            # Nested values of a manager dict must be re-assigned to be
            # propagated, see:
            # https://docs.python.org/3.6/library/multiprocessing.html?highlight=multiprocess#proxy-objects
            tmp_push_res = self._push_res[data_id]
            tmp_push_res[op_name] = channeldata
            self._push_res[data_id] = tmp_push_res

            if self._producer_res_count[data_id] + 1 == producer_num:
                # Last producer for this id: the merged dict can be queued.
                put_data = self._push_res[data_id]
                self._push_res.pop(data_id)
                self._producer_res_count.pop(data_id)
            else:
                self._producer_res_count[data_id] += 1

            if put_data is None:
                logging.debug(
                    self._log("{} push data succ, but not push to queue.".
                              format(op_name)))
            else:
                while self._stop is False:
                    try:
                        logging.debug(
                            self._log("{} push data succ: {}".format(
                                op_name, put_data.__str__())))
                        self.put(put_data, timeout=0)
                        break
                    except Queue.Full:
                        # put() signals a full queue with Queue.Full (not
                        # Queue.Empty); wait for a consumer and retry.
                        self._cv.wait()

                logging.debug(
                    self._log("multi | {} push data succ!".format(op_name)))
            self._cv.notify_all()
        return True

    def front(self, op_name=None):
        """Fetch the next item for consumer ``op_name``.

        With a single consumer the item is taken straight from the queue
        (None is returned if the channel was stopped first). With several
        consumers each item is kept in a shared list until every consumer has
        read it.

        Raises:
            Exception: no consumer is registered, or op_name is None while
                several consumers are registered.
        """
        logging.debug(self._log("{} try to get data...".format(op_name)))
        if len(self._consumers) == 0:
            raise Exception(
                self._log(
                    "expected number of consumers to be greater than 0, but the it is 0."
                ))
        elif len(self._consumers) == 1:
            resp = None
            with self._cv:
                while self._stop is False and resp is None:
                    try:
                        logging.debug(
                            self._log("{} try to get(with channel size: {})".
                                      format(op_name, self.qsize())))
                        #TODO: bug to fix
                        # (multiple processes) the queue is not empty, but it raise Queue.Empty
                        resp = self.get(timeout=1e-3)
                        break
                    except Queue.Empty:
                        logging.debug(
                            self._log(
                                "{} wait for empty queue(with channel size: {})".
                                format(op_name, self.qsize())))
                        self._cv.wait()
            logging.debug(
                self._log("{} get data succ: {}".format(op_name, resp.__str__(
                ))))
            return resp
        elif op_name is None:
            raise Exception(
                self._log(
                    "There are multiple consumers, so op_name cannot be None."))

        with self._cv:
            # data_idx = consumer_idx - base_idx
            while self._stop is False and self._consumers[
                    op_name] - self._consumer_base_idx.value >= len(
                        self._front_res):
                logging.debug(
                    self._log(
                        "({}) B self._consumers: {}, self._consumer_base_idx: {}, len(self._front_res): {}".
                        format(op_name, self._consumers, self.
                               _consumer_base_idx.value, len(self._front_res))))
                try:
                    logging.debug(
                        self._log("{} try to get(with channel size: {})".format(
                            op_name, self.qsize())))
                    #TODO: bug to fix
                    # (multiple processes) the queue is not empty, but it raise Queue.Empty
                    channeldata = self.get(timeout=1e-3)
                    self._front_res.append(channeldata)
                    break
                except Queue.Empty:
                    logging.debug(
                        self._log(
                            "{} wait for empty queue(with channel size: {})".
                            format(op_name, self.qsize())))
                    self._cv.wait()

            consumer_idx = self._consumers[op_name]
            base_idx = self._consumer_base_idx.value
            data_idx = consumer_idx - base_idx
            resp = self._front_res[data_idx]
            logging.debug(self._log("{} get data: {}".format(op_name, resp)))

            # One consumer fewer is still waiting at this index; once the
            # oldest item has been read by everybody it is dropped.
            self._idx_consumer_num[consumer_idx] -= 1
            if consumer_idx == base_idx and self._idx_consumer_num[
                    consumer_idx] == 0:
                self._idx_consumer_num.pop(consumer_idx)
                self._front_res.pop(0)
                self._consumer_base_idx.value += 1

            # Advance this consumer to the next index.
            self._consumers[op_name] += 1
            new_consumer_idx = self._consumers[op_name]
            if self._idx_consumer_num.get(new_consumer_idx) is None:
                self._idx_consumer_num[new_consumer_idx] = 0
            self._idx_consumer_num[new_consumer_idx] += 1
            logging.debug(
                self._log(
                    "({}) A self._consumers: {}, self._consumer_base_idx: {}, len(self._front_res): {}".
                    format(op_name, self._consumers, self._consumer_base_idx.
                           value, len(self._front_res))))
            logging.debug(self._log("{} notify all".format(op_name)))
            self._cv.notify_all()

        logging.debug(self._log("multi | {} get data succ!".format(op_name)))
        return resp  # reference, read only

    def stop(self):
        """Close the queue, set the stop flag and wake all waiters."""
        #TODO
        self.close()
        self._stop = True
        # notify_all() requires the condition's lock to be held by the caller.
        with self._cv:
            self._cv.notify_all()
439

B
barrierye 已提交
440 441 442

class Op(object):
    def __init__(self,
443
                 name,
444
                 inputs,
B
barrierye 已提交
445 446 447 448 449
                 server_model=None,
                 server_port=None,
                 device=None,
                 client_config=None,
                 server_name=None,
450
                 fetch_names=None,
B
barrierye 已提交
451
                 concurrency=1,
B
barrierye 已提交
452 453
                 timeout=-1,
                 retry=2):
B
barrierye 已提交
454
        self._is_run = False
455
        self.name = name  # to identify the type of OP, it must be globally unique
456
        self._concurrency = concurrency  # amount of concurrency
457
        self.set_input_ops(inputs)
B
barrierye 已提交
458
        self._timeout = timeout
B
bug fix  
barrierye 已提交
459
        self._retry = max(1, retry)
460 461
        self._input = None
        self._outputs = []
B
barrierye 已提交
462

B
barrierye 已提交
463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478
        self.with_serving = False
        self._client_config = client_config
        self._server_name = server_name
        self._fetch_names = fetch_names
        self._server_model = server_model
        self._server_port = server_port
        self._device = device
        if self._client_config is not None and \
                self._server_name is not None and \
                self._fetch_names is not None and \
                self._server_model is not None and \
                self._server_port is not None and \
                self._device is not None:
            self.with_serving = True

    def init_client(self, client_config, server_name, fetch_names):
B
barrierye 已提交
479 480
        if self.with_serving == False:
            logging.debug("{} no client".format(self.name))
481
            return
B
barrierye 已提交
482 483 484
        logging.debug("{} client_config: {}".format(self.name, client_config))
        logging.debug("{} server_name: {}".format(self.name, server_name))
        logging.debug("{} fetch_names: {}".format(self.name, fetch_names))
B
barrierye 已提交
485 486 487 488 489
        self._client = Client()
        self._client.load_client_config(client_config)
        self._client.connect([server_name])
        self._fetch_names = fetch_names

490
    def get_input_channel(self):
491
        return self._input
B
barrierye 已提交
492

493 494 495 496 497 498 499 500 501 502 503 504 505 506 507
    def get_input_ops(self):
        return self._input_ops

    def set_input_ops(self, ops):
        if not isinstance(ops, list):
            ops = [] if ops is None else [ops]
        self._input_ops = []
        for op in ops:
            if not isinstance(op, Op):
                raise TypeError(
                    self._log('input op must be Op type, not {}'.format(
                        type(op))))
            self._input_ops.append(op)

    def add_input_channel(self, channel):
        """Register this Op as a consumer of ``channel`` and use it as input.

        Raises:
            TypeError: ``channel`` is not a Channel.
        """
        if not isinstance(channel, Channel):
            raise TypeError(
                self._log('input channel must be Channel type, not {}'.format(
                    type(channel))))
        channel.add_consumer(self.name)
        self._input = channel
B
barrierye 已提交
514

515
    def get_output_channels(self):
B
barrierye 已提交
516 517
        return self._outputs

518 519
    def add_output_channel(self, channel):
        """Register this Op as a producer of ``channel`` and add it as output.

        Raises:
            TypeError: ``channel`` is not a Channel.
        """
        if not isinstance(channel, Channel):
            raise TypeError(
                self._log('output channel must be Channel type, not {}'.format(
                    type(channel))))
        channel.add_producer(self.name)
        self._outputs.append(channel)
B
barrierye 已提交
525

526 527
    def preprocess(self, channeldata):
        """Turn the input taken from the channel into the feed for midprocess().

        The default implementation handles a single predecessor only and
        returns ``channeldata.parse()``.

        Raises:
            NotImplementedError: ``channeldata`` is a dict, i.e. the Op has
                several predecessors and must override this method.
        """
        if isinstance(channeldata, dict):
            raise NotImplementedError(
                'this Op has multiple previous inputs. Please override this method'
            )
        feed = channeldata.parse()
        return feed
B
barrierye 已提交
533 534

    def midprocess(self, data):
        """Send ``data`` to the serving client and return the pending call.

        Args:
            data: dict produced by preprocess().

        Returns:
            The object returned by the client's asynchronous predict() call.

        Raises:
            Exception: ``data`` is not a dict.
        """
        if not isinstance(data, dict):
            raise Exception(
                self._log(
                    'data must be dict type(the output of preprocess()), but get {}'.
                    format(type(data))))
        logging.debug(self._log('data: {}'.format(data)))
        logging.debug(self._log('fetch: {}'.format(self._fetch_names)))
        call_future = self._client.predict(
            feed=data, fetch=self._fetch_names, asyn=True)
        logging.debug(self._log("get call_future"))
        return call_future
B
barrierye 已提交
546 547

    def postprocess(self, output_data):
        """Hook for post-processing results; returns ``output_data`` unchanged."""
        return output_data
B
barrierye 已提交
549 550

    def stop(self):
551 552 553
        self._input.stop()
        for channel in self._outputs:
            channel.stop()
B
barrierye 已提交
554
        self._is_run = False
B
barrierye 已提交
555

556
    def _parse_channeldata(self, channeldata):
        """Extract the data id and the first erroneous item, if any.

        Args:
            channeldata: a single ChannelData, or a dict {op_name: ChannelData}
                when the input channel has several producers.

        Returns:
            (data_id, error_channeldata): ``error_channeldata`` is the first
            item whose ecode is not OK, or None when all items are valid.
            For a dict the id of its first item is used; an empty dict yields
            ``(None, None)``.
        """
        data_id, error_channeldata = None, None
        if isinstance(channeldata, dict):
            entries = list(channeldata.values())
            if entries:
                data_id = entries[0].id
            for data in entries:
                if data.ecode != ChannelDataEcode.OK.value:
                    error_channeldata = data
                    break
        else:
            data_id = channeldata.id
            if channeldata.ecode != ChannelDataEcode.OK.value:
                error_channeldata = channeldata
        return data_id, error_channeldata
571

B
barrierye 已提交
572
    def _push_to_output_channels(self, data, channels, name=None):
B
bug fix  
barrierye 已提交
573 574
        if name is None:
            name = self.name
B
barrierye 已提交
575
        for channel in channels:
B
bug fix  
barrierye 已提交
576
            channel.push(data, name)
B
barrierye 已提交
577

B
barrierye 已提交
578 579 580 581 582 583 584 585 586 587 588
    def start(self):
        """Launch one worker process per unit of concurrency.

        Each process runs ``_run`` with its index, the input channel and the
        output channels.

        Returns:
            list of the started multiprocessing.Process objects.
        """
        workers = []
        for worker_idx in range(self._concurrency):
            run_args = (worker_idx, self.get_input_channel(),
                        self.get_output_channels())
            worker = multiprocessing.Process(target=self._run, args=run_args)
            worker.start()
            workers.append(worker)
        return workers

B
barrierye 已提交
589
    def _run(self, concurrency_idx, input_channel, output_channels):
B
barrierye 已提交
590 591
        self.init_client(self._client_config, self._server_name,
                         self._fetch_names)
B
bug fix  
barrierye 已提交
592
        op_info_prefix = "[{}|{}]".format(self.name, concurrency_idx)
B
barrierye 已提交
593
        log = self._get_log_func(op_info_prefix)
B
barrierye 已提交
594 595
        self._is_run = True
        while self._is_run:
B
barrierye 已提交
596
            _profiler.record("{}-get_0".format(op_info_prefix))
B
barrierye 已提交
597
            channeldata = input_channel.front(self.name)
B
barrierye 已提交
598
            _profiler.record("{}-get_1".format(op_info_prefix))
B
bug fix  
barrierye 已提交
599
            logging.debug(log("input_data: {}".format(channeldata)))
B
barrierye 已提交
600

B
barrierye 已提交
601
            data_id, error_channeldata = self._parse_channeldata(channeldata)
602

B
bug fix  
barrierye 已提交
603
            # error data in predecessor Op
B
barrierye 已提交
604 605 606
            if error_channeldata is not None:
                self._push_to_output_channels(error_channeldata,
                                              output_channels)
B
barrierye 已提交
607 608
                continue

B
bug fix  
barrierye 已提交
609
            # preprecess
B
barrierye 已提交
610 611
            try:
                _profiler.record("{}-prep_0".format(op_info_prefix))
B
bug fix  
barrierye 已提交
612
                preped_data = self.preprocess(channeldata)
B
barrierye 已提交
613 614
                _profiler.record("{}-prep_1".format(op_info_prefix))
            except NotImplementedError as e:
B
bug fix  
barrierye 已提交
615
                # preprocess function not implemented
B
barrierye 已提交
616 617 618 619 620 621
                error_info = log(e)
                logging.error(error_info)
                self._push_to_output_channels(
                    ChannelData(
                        ecode=ChannelDataEcode.NOT_IMPLEMENTED.value,
                        error_info=error_info,
B
barrierye 已提交
622 623
                        data_id=data_id),
                    output_channels)
B
barrierye 已提交
624
                continue
B
bug fix  
barrierye 已提交
625
            except TypeError as e:
B
barrierye 已提交
626
                # Error type in channeldata.datatype
B
bug fix  
barrierye 已提交
627 628 629 630 631 632
                error_info = log(e)
                logging.error(error_info)
                self._push_to_output_channels(
                    ChannelData(
                        ecode=ChannelDataEcode.TYPE_ERROR.value,
                        error_info=error_info,
B
barrierye 已提交
633 634
                        data_id=data_id),
                    output_channels)
B
bug fix  
barrierye 已提交
635 636 637 638 639 640
                continue
            except Exception as e:
                error_info = log(e)
                logging.error(error_info)
                self._push_to_output_channels(
                    ChannelData(
B
barrierye 已提交
641
                        ecode=ChannelDataEcode.UNKNOW.value,
B
bug fix  
barrierye 已提交
642
                        error_info=error_info,
B
barrierye 已提交
643 644
                        data_id=data_id),
                    output_channels)
B
bug fix  
barrierye 已提交
645
                continue
646

B
barrierye 已提交
647 648
            # midprocess
            call_future = None
B
barrierye 已提交
649
            if self.with_serving:
B
bug fix  
barrierye 已提交
650
                ecode = ChannelDataEcode.OK.value
B
barrierye 已提交
651 652 653
                _profiler.record("{}-midp_0".format(op_info_prefix))
                if self._timeout <= 0:
                    try:
B
bug fix  
barrierye 已提交
654
                        call_future = self.midprocess(preped_data)
B
barrierye 已提交
655 656 657 658
                    except Exception as e:
                        ecode = ChannelDataEcode.UNKNOW.value
                        error_info = log(e)
                        logging.error(error_info)
B
barrierye 已提交
659
                else:
B
barrierye 已提交
660 661 662
                    for i in range(self._retry):
                        try:
                            call_future = func_timeout.func_timeout(
B
bug fix  
barrierye 已提交
663 664 665 666
                                self._timeout,
                                self.midprocess,
                                args=(preped_data, ))
                        except func_timeout.FunctionTimedOut as e:
B
barrierye 已提交
667 668
                            if i + 1 >= self._retry:
                                ecode = ChannelDataEcode.TIMEOUT.value
B
bug fix  
barrierye 已提交
669 670
                                error_info = log(e)
                                logging.error(error_info)
B
barrierye 已提交
671 672
                            else:
                                logging.warn(
B
bug fix  
barrierye 已提交
673
                                    log("timeout, retry({})".format(i + 1)))
B
barrierye 已提交
674 675 676 677 678 679 680
                        except Exception as e:
                            ecode = ChannelDataEcode.UNKNOW.value
                            error_info = log(e)
                            logging.error(error_info)
                            break
                        else:
                            break
B
bug fix  
barrierye 已提交
681
                if ecode != ChannelDataEcode.OK.value:
B
barrierye 已提交
682 683 684
                    self._push_to_output_channels(
                        ChannelData(
                            ecode=ecode, error_info=error_info,
B
barrierye 已提交
685 686
                            data_id=data_id),
                        output_channels)
B
barrierye 已提交
687 688
                    continue
                _profiler.record("{}-midp_1".format(op_info_prefix))
689

B
barrierye 已提交
690 691 692
            # postprocess
            output_data = None
            _profiler.record("{}-postp_0".format(op_info_prefix))
B
barrierye 已提交
693
            if self.with_serving:
B
bug fix  
barrierye 已提交
694
                # use call_future
B
barrierye 已提交
695
                output_data = ChannelData(
B
barrierye 已提交
696
                    datatype=ChannelDataType.CHANNEL_FUTURE.value,
B
barrierye 已提交
697 698 699
                    future=call_future,
                    data_id=data_id,
                    callback_func=self.postprocess)
B
barrierye 已提交
700 701 702 703 704 705 706 707 708
                #TODO: for future are not picklable
                npdata = self.postprocess(call_future.result())
                self._push_to_output_channels(
                    ChannelData(
                        ChannelDataType.CHANNEL_NPDATA.value,
                        npdata=npdata,
                        data_id=data_id),
                    output_channels)
                continue
B
barrierye 已提交
709
            else:
B
bug fix  
barrierye 已提交
710 711 712 713 714 715 716 717 718
                try:
                    postped_data = self.postprocess(preped_data)
                except Exception as e:
                    ecode = ChannelDataEcode.UNKNOW.value
                    error_info = log(e)
                    logging.error(error_info)
                    self._push_to_output_channels(
                        ChannelData(
                            ecode=ecode, error_info=error_info,
B
barrierye 已提交
719 720
                            data_id=data_id),
                        output_channels)
B
bug fix  
barrierye 已提交
721 722
                    continue
                if not isinstance(postped_data, dict):
B
barrierye 已提交
723 724
                    ecode = ChannelDataEcode.TYPE_ERROR.value
                    error_info = log("output of postprocess funticon must be " \
B
bug fix  
barrierye 已提交
725
                            "dict type, but get {}".format(type(postped_data)))
B
barrierye 已提交
726 727 728 729
                    logging.error(error_info)
                    self._push_to_output_channels(
                        ChannelData(
                            ecode=ecode, error_info=error_info,
B
barrierye 已提交
730 731
                            data_id=data_id),
                        output_channels)
B
barrierye 已提交
732
                    continue
B
bug fix  
barrierye 已提交
733

B
barrierye 已提交
734 735 736 737
                output_data = ChannelData(
                    ChannelDataType.CHANNEL_NPDATA.value,
                    npdata=postped_data,
                    data_id=data_id)
B
barrierye 已提交
738 739 740 741
            _profiler.record("{}-postp_1".format(op_info_prefix))

            # push data to channel (if run succ)
            _profiler.record("{}-push_0".format(op_info_prefix))
B
barrierye 已提交
742
            self._push_to_output_channels(output_data, output_channels)
B
barrierye 已提交
743 744 745 746 747 748 749 750 751 752
            _profiler.record("{}-push_1".format(op_info_prefix))

    def _log(self, info):
        return "{} {}".format(self.name, info)

    def _get_log_func(self, op_info_prefix):
        def log_func(info_str):
            return "{} {}".format(op_info_prefix, info_str)

        return log_func
B
barrierye 已提交
753

754 755 756
    def get_concurrency(self):
        """Return the value stored in this Op's `_concurrency` attribute."""
        concurrency = self._concurrency
        return concurrency

B
barrierye 已提交
757

B
bug fix  
barrierye 已提交
758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777
class VirtualOp(Op):
    """Pass-through Op used to connect two channels.

    A VirtualOp performs no computation of its own. Its worker loop takes
    whatever arrives on its input channel and hands it to
    `_push_to_output_channels`, tagged with the name of the Op(s) it stands
    in for (its "virtual predecessors").
    """

    def __init__(self, name, concurrency=1):
        """Create a VirtualOp with no inputs.

        Args:
            name: name of this Op.
            concurrency: forwarded to the base class constructor.
        """
        super(VirtualOp, self).__init__(
            name=name, inputs=None, concurrency=concurrency)
        # Ops whose data this VirtualOp forwards.
        self._virtual_pred_ops = []

    def add_virtual_pred_op(self, op):
        """Register `op` as a predecessor whose data is forwarded."""
        self._virtual_pred_ops.append(op)

    def add_output_channel(self, channel):
        """Attach `channel` as an output of this VirtualOp.

        The channel's producers are registered under the names of the
        virtual predecessor Ops rather than under this VirtualOp's own name.

        Raises:
            TypeError: if `channel` is not a Channel.
        """
        if not isinstance(channel, Channel):
            raise TypeError(
                self._log('output channel must be Channel type, not {}'.format(
                    type(channel))))
        for op in self._virtual_pred_ops:
            channel.add_producer(op.name)
        self._outputs.append(channel)

    def _run(self, input_channel, output_channels, *extra):
        """Worker loop: forward data from `input_channel` to `output_channels`.

        The previous body built its profiling tag from `concurrency_idx`, a
        name that was never bound in this method, so the loop died with
        NameError before forwarding anything.

        NOTE(review): the code that invokes `_run` is not visible in this
        chunk. To stay compatible with the existing two-argument signature
        while also working if the caller passes the worker index first, both
        call shapes are accepted:
            _run(input_channel, output_channels)
            _run(concurrency_idx, input_channel, output_channels)
        Confirm the real call site and simplify the signature accordingly.

        Raises:
            TypeError: if more than three positional arguments are given.
        """
        if len(extra) > 1:
            raise TypeError(
                self._log('_run takes at most 3 arguments, got {}'.format(
                    2 + len(extra))))
        if extra:
            # Three-argument form: the worker index came first.
            concurrency_idx = input_channel
            input_channel = output_channels
            output_channels = extra[0]
        else:
            concurrency_idx = 0

        op_info_prefix = "[{}|{}]".format(self.name, concurrency_idx)
        self._is_run = True
        while self._is_run:
            _profiler.record("{}-get_0".format(op_info_prefix))
            channeldata = input_channel.front(self.name)
            _profiler.record("{}-get_1".format(op_info_prefix))

            _profiler.record("{}-push_0".format(op_info_prefix))
            if isinstance(channeldata, dict):
                # A dict maps a name to its data; forward each entry under
                # its own name.
                for name, data in channeldata.items():
                    self._push_to_output_channels(
                        data, channels=output_channels, name=name)
            else:
                # A single item is forwarded under the name of the first
                # registered virtual predecessor.
                self._push_to_output_channels(
                    channeldata,
                    channels=output_channels,
                    name=self._virtual_pred_ops[0].name)
            _profiler.record("{}-push_1".format(op_info_prefix))


B
barrierye 已提交
800 801
class GeneralPythonService(
        general_python_service_pb2_grpc.GeneralPythonService):
    """gRPC servicer that feeds requests into the channels and returns results.

    Each request is packed into a ChannelData carrying a fresh data id and
    pushed onto the input channel under the name "#G". A background thread
    reads the output channel and stores every result in a dict keyed by data
    id; request threads wait on a condition variable until their id shows up.
    """

    def __init__(self, in_channel, out_channel, retry=2):
        """Bind the service to its channels and start the receiver thread.

        Args:
            in_channel: Channel that request data is pushed to.
            out_channel: Channel that results are read from.
            retry: maximum number of attempts per request (at least one
                attempt is always made).

        Raises:
            TypeError: if either channel is not a Channel.
        """
        super(GeneralPythonService, self).__init__()
        self.name = "#G"
        self.set_in_channel(in_channel)
        self.set_out_channel(out_channel)
        logging.debug(self._log(in_channel.debug()))
        logging.debug(self._log(out_channel.debug()))
        #TODO:
        #  multi-lock for different clients
        #  different lock for server and client
        self._id_lock = threading.Lock()
        self._cv = threading.Condition()
        # data id -> ChannelData received from the output channel.
        self._globel_resp_dict = {}
        self._id_counter = 0
        self._retry = retry
        self._recive_func = threading.Thread(
            target=GeneralPythonService._recive_out_channel_func, args=(self, ))
        self._recive_func.start()

    def _log(self, info_str):
        """Return `info_str` prefixed with "[<service name>] "."""
        return "[{}] {}".format(self.name, info_str)

    def set_in_channel(self, in_channel):
        """Register this service as a producer of `in_channel` and keep it.

        Raises:
            TypeError: if `in_channel` is not a Channel.
        """
        if not isinstance(in_channel, Channel):
            raise TypeError(
                self._log('in_channel must be Channel type, but get {}'.format(
                    type(in_channel))))
        in_channel.add_producer(self.name)
        self._in_channel = in_channel

    def set_out_channel(self, out_channel):
        """Register this service as a consumer of `out_channel` and keep it.

        Raises:
            TypeError: if `out_channel` is not a Channel.
        """
        if not isinstance(out_channel, Channel):
            raise TypeError(
                self._log('out_channel must be Channel type, but get {}'.format(
                    type(out_channel))))
        out_channel.add_consumer(self.name)
        self._out_channel = out_channel

    def _recive_out_channel_func(self):
        """Receiver loop: move results from the output channel into the dict.

        Runs forever in the thread started by __init__. Every item read is
        stored under its data id and all waiting request threads are woken.

        Raises:
            TypeError: if the channel yields something that is not ChannelData.
        """
        while True:
            channeldata = self._out_channel.front(self.name)
            if not isinstance(channeldata, ChannelData):
                raise TypeError(
                    self._log('data must be ChannelData type, but get {}'.
                              format(type(channeldata))))
            with self._cv:
                data_id = channeldata.id
                self._globel_resp_dict[data_id] = channeldata
                self._cv.notify_all()

    def _get_next_id(self):
        """Return the next data id (0, 1, 2, ...) under the id lock."""
        with self._id_lock:
            self._id_counter += 1
            return self._id_counter - 1

    def _get_data_in_globel_resp_dict(self, data_id):
        """Block until the result for `data_id` is available, then remove and
        return it."""
        resp = None
        with self._cv:
            while data_id not in self._globel_resp_dict:
                self._cv.wait()
            resp = self._globel_resp_dict.pop(data_id)
            self._cv.notify_all()
        return resp

    def _pack_data_for_infer(self, request):
        """Convert an RPC request into a ChannelData.

        Each feed variable is rebuilt with `np.frombuffer` using the dtype in
        `request.type` and given the shape decoded from `request.shape`
        (read as int32).

        Returns:
            (channeldata, data_id). If unpacking fails, `channeldata` carries
            the RPC_PACKAGE_ERROR code instead of the data.
        """
        logging.debug(self._log('start inferce'))
        data_id = self._get_next_id()
        npdata = {}
        try:
            for idx, name in enumerate(request.feed_var_names):
                logging.debug(self._log('name: {}'.format(name)))
                logging.debug(
                    self._log('data: {}'.format(request.feed_insts[idx])))
                npdata[name] = np.frombuffer(
                    request.feed_insts[idx], dtype=request.type[idx])
                npdata[name].shape = np.frombuffer(
                    request.shape[idx], dtype="int32")
        except Exception as e:
            # Record the cause on the server side; the client only receives
            # the generic error text below.
            logging.error(self._log('rpc package error: {}'.format(e)))
            return ChannelData(
                ecode=ChannelDataEcode.RPC_PACKAGE_ERROR.value,
                error_info="rpc package error",
                data_id=data_id), data_id
        else:
            return ChannelData(
                datatype=ChannelDataType.CHANNEL_NPDATA.value,
                npdata=npdata,
                data_id=data_id), data_id

    def _pack_data_for_resp(self, channeldata):
        """Convert a result ChannelData into an RPC Response.

        On success the fetch variables are copied into the response (either
        from the protobuf instances, or from the dict returned by
        `channeldata.parse()`); otherwise only the error info is set.

        Raises:
            TypeError: if a successful result has an unsupported datatype.
        """
        logging.debug(self._log('get channeldata'))
        resp = pyservice_pb2.Response()
        resp.ecode = channeldata.ecode
        if resp.ecode == ChannelDataEcode.OK.value:
            if channeldata.datatype == ChannelDataType.CHANNEL_PBDATA.value:
                for inst in channeldata.pbdata.insts:
                    resp.fetch_insts.append(inst.data)
                    resp.fetch_var_names.append(inst.name)
                    resp.shape.append(inst.shape)
                    resp.type.append(inst.type)
            elif channeldata.datatype in (ChannelDataType.CHANNEL_FUTURE.value,
                                          ChannelDataType.CHANNEL_NPDATA.value):
                feed = channeldata.parse()
                for name, var in feed.items():
                    resp.fetch_insts.append(var.tobytes())
                    resp.fetch_var_names.append(name)
                    resp.shape.append(
                        np.array(
                            var.shape, dtype="int32").tobytes())
                    resp.type.append(str(var.dtype))
            else:
                raise TypeError(
                    self._log("Error type({}) in datatype.".format(
                        channeldata.datatype)))
        else:
            resp.error_info = channeldata.error_info
        return resp

    def inference(self, request, context):
        """Handle one RPC: push the request, wait for its result, reply.

        The request is pushed again (with the same data id) while the result
        reports an error and attempts remain.

        Args:
            request: the RPC request message.
            context: gRPC call context (unused).

        Returns:
            The Response built from the last result received.
        """
        _profiler.record("{}-prepack_0".format(self.name))
        data, data_id = self._pack_data_for_infer(request)
        _profiler.record("{}-prepack_1".format(self.name))

        resp_channeldata = None
        # Always make at least one attempt: with retry < 1 the loop would not
        # run and packing a None result below would raise AttributeError.
        for i in range(max(1, self._retry)):
            logging.debug(self._log('push data'))
            _profiler.record("{}-push_0".format(self.name))
            self._in_channel.push(data, self.name)
            _profiler.record("{}-push_1".format(self.name))

            logging.debug(self._log('wait for infer'))
            _profiler.record("{}-fetch_0".format(self.name))
            resp_channeldata = self._get_data_in_globel_resp_dict(data_id)
            _profiler.record("{}-fetch_1".format(self.name))

            if resp_channeldata.ecode == ChannelDataEcode.OK.value:
                break
            if i + 1 < self._retry:
                logging.warning("retry({}): {}".format(
                    i + 1, resp_channeldata.error_info))

        _profiler.record("{}-postpack_0".format(self.name))
        resp = self._pack_data_for_resp(resp_channeldata)
        _profiler.record("{}-postpack_1".format(self.name))
        _profiler.print_profile()
        return resp

B
barrierye 已提交
949 950

class PyServer(object):
B
barrierye 已提交
951
    def __init__(self, retry=2, profile=False):
        """Create an empty server; Ops are registered afterwards.

        Args:
            retry: stored as `_retry` and later passed to the gRPC service.
            profile: passed to the module-level profiler's `enable()`.
        """
        # Graph state.
        self._channels, self._user_ops, self._actual_ops = [], [], []
        # Unset until prepare_server() assigns them.
        self._port = self._worker_num = None
        self._in_channel = self._out_channel = None
        self._retry = retry
        self._manager = multiprocessing.Manager()
        _profiler.enable(profile)
B
barrierye 已提交
962 963 964 965 966

    def add_channel(self, channel):
        self._channels.append(channel)

    def add_op(self, op):
967 968 969
        self._user_ops.append(op)

    def add_ops(self, ops):
B
fix bug  
barrierye 已提交
970
        self._user_ops.extend(ops)
B
barrierye 已提交
971 972

    def gen_desc(self):
973
        logging.info('here will generate desc for PAAS')
B
barrierye 已提交
974 975
        pass

976 977 978
    def _topo_sort(self):
        """Layer the user Ops topologically and wire them up with channels.

        Steps:
          1. Rename the first Op that has no input Ops to "#G".
          2. Sort the Ops into layers ("views"): every Op in a layer has all
             of its predecessors in earlier layers.
          3. Walk adjacent layers. When a successor is not in the very next
             layer, a VirtualOp is inserted as a relay so that channels only
             ever connect neighbouring layers.
          4. Create one Channel per group of Ops that share exactly the same
             predecessors, plus one final output Channel for the last Op.

        Side effects:
            Sets `self._actual_ops` (the VirtualOps plus every user Op that
            has input Ops) and `self._channels`.

        Returns:
            (input_channel, output_channel): the channel feeding the first
            layer after the read Op, and the channel written by the last Op.

        Raises:
            Exception: if Op names are not unique, the graph has a cycle, or
                the graph does not have exactly one input Op and exactly one
                output Op.
        """
        indeg_num = {}
        que_idx = 0  # scroll queue
        ques = [Queue.Queue() for _ in range(2)]
        for op in self._user_ops:
            if len(op.get_input_ops()) == 0:
                op.name = "#G"  # update read_op.name
                break
        # Op name -> list of successor Ops.
        outdegs = {op.name: [] for op in self._user_ops}
        for op in self._user_ops:
            # check the name of op is globally unique
            if op.name in indeg_num:
                raise Exception("the name of Op must be unique")
            indeg_num[op.name] = len(op.get_input_ops())
            if indeg_num[op.name] == 0:
                ques[que_idx].put(op)
            for pred_op in op.get_input_ops():
                outdegs[pred_op.name].append(op)

        # topo sort to get dag_views
        # Two queues alternate: one holds the current layer while the other
        # collects the Ops whose in-degree drops to zero (the next layer).
        dag_views = []
        sorted_op_num = 0
        while True:
            que = ques[que_idx]
            next_que = ques[(que_idx + 1) % 2]
            dag_view = []
            while que.qsize() != 0:
                op = que.get()
                dag_view.append(op)
                sorted_op_num += 1
                for succ_op in outdegs[op.name]:
                    indeg_num[succ_op.name] -= 1
                    if indeg_num[succ_op.name] == 0:
                        next_que.put(succ_op)
            dag_views.append(dag_view)
            if next_que.qsize() == 0:
                break
            que_idx = (que_idx + 1) % 2
        # Ops left unsorted never reached in-degree zero, i.e. sit on a cycle.
        if sorted_op_num < len(self._user_ops):
            raise Exception("not legal DAG")
        if len(dag_views[0]) != 1:
            raise Exception("DAG contains multiple input Ops")
        if len(dag_views[-1]) != 1:
            raise Exception("DAG contains multiple output Ops")

        # create channels and virtual ops
        def name_generator(prefix):
            # Yields "<prefix>0", "<prefix>1", ... without end.
            def number_generator():
                idx = 0
                while True:
                    yield "{}{}".format(prefix, idx)
                    idx += 1

            return number_generator()

        virtual_op_name_gen = name_generator("vir")
        channel_name_gen = name_generator("chl")
        virtual_ops = []
        channels = []
        input_channel = None
        # The layer actually used as "current": the sorted layer for the
        # first step, afterwards the previous step's next layer (which may
        # contain VirtualOps).
        actual_view = None
        for v_idx, view in enumerate(dag_views):
            if v_idx + 1 >= len(dag_views):
                break
            next_view = dag_views[v_idx + 1]
            if actual_view is None:
                actual_view = view
            actual_next_view = []
            pred_op_of_next_view_op = {}
            for op in actual_view:
                # find actual succ op in next view and create virtual op
                for succ_op in outdegs[op.name]:
                    if succ_op in next_view:
                        if succ_op not in actual_next_view:
                            actual_next_view.append(succ_op)
                        if succ_op.name not in pred_op_of_next_view_op:
                            pred_op_of_next_view_op[succ_op.name] = []
                        pred_op_of_next_view_op[succ_op.name].append(op)
                    else:
                        # create virtual op
                        # The built-in next() is used instead of the
                        # generator's .next() method, which only exists on
                        # Python 2.
                        virtual_op = VirtualOp(name=next(virtual_op_name_gen))
                        virtual_ops.append(virtual_op)
                        outdegs[virtual_op.name] = [succ_op]
                        actual_next_view.append(virtual_op)
                        pred_op_of_next_view_op[virtual_op.name] = [op]
                        virtual_op.add_virtual_pred_op(op)
            actual_view = actual_next_view
            # create channel
            processed_op = set()
            for o_idx, op in enumerate(actual_next_view):
                if op.name in processed_op:
                    continue
                channel = Channel(self._manager, name=next(channel_name_gen))
                channels.append(channel)
                logging.debug("{} => {}".format(channel.name, op.name))
                op.add_input_channel(channel)
                pred_ops = pred_op_of_next_view_op[op.name]
                if v_idx == 0:
                    # The channel after the first layer is returned as the
                    # input channel; the read Op is not attached to it here.
                    input_channel = channel
                else:
                    # if pred_op is virtual op, it will use ancestors as producers to channel
                    for pred_op in pred_ops:
                        logging.debug("{} => {}".format(pred_op.name,
                                                        channel.name))
                        pred_op.add_output_channel(channel)
                processed_op.add(op.name)
                # find same input op to combine channel
                for other_op in actual_next_view[o_idx + 1:]:
                    if other_op.name in processed_op:
                        continue
                    other_pred_ops = pred_op_of_next_view_op[other_op.name]
                    if len(other_pred_ops) != len(pred_ops):
                        continue
                    same_flag = True
                    for pred_op in pred_ops:
                        if pred_op not in other_pred_ops:
                            same_flag = False
                            break
                    if same_flag:
                        logging.debug("{} => {}".format(channel.name,
                                                        other_op.name))
                        other_op.add_input_channel(channel)
                        processed_op.add(other_op.name)
        output_channel = Channel(self._manager, name=next(channel_name_gen))
        channels.append(output_channel)
        last_op = dag_views[-1][0]
        last_op.add_output_channel(output_channel)

        self._actual_ops = virtual_ops
        for op in self._user_ops:
            if len(op.get_input_ops()) == 0:
                # pass read op
                continue
            self._actual_ops.append(op)
        self._channels = channels
        for c in channels:
            logging.debug(c.debug())
        return input_channel, output_channel

B
barrierye 已提交
1116 1117 1118
    def prepare_server(self, port, worker_num):
        self._port = port
        self._worker_num = worker_num
1119 1120 1121

        input_channel, output_channel = self._topo_sort()
        self._in_channel = input_channel
B
fix bug  
barrierye 已提交
1122
        self._out_channel = output_channel
B
bug fix  
barrierye 已提交
1123
        for op in self._actual_ops:
B
barrierye 已提交
1124
            if op.with_serving:
B
fix bug  
barrierye 已提交
1125
                self.prepare_serving(op)
B
barrierye 已提交
1126 1127
        self.gen_desc()

1128
    def _run_ops(self):
B
barrierye 已提交
1129
        proces = []
B
bug fix  
barrierye 已提交
1130
        for op in self._actual_ops:
B
barrierye 已提交
1131 1132
            proces.extend(op.start())
        return proces
1133

1134
    def _stop_ops(self):
B
bug fix  
barrierye 已提交
1135
        for op in self._actual_ops:
1136 1137
            op.stop()

1138
    def run_server(self):
        """Start the Ops and serve gRPC requests until the server terminates.

        The Ops are started first. A gRPC server backed by a thread pool of
        `_worker_num` workers is then bound to `_port` on all interfaces.
        Once the server terminates, the Ops are stopped and everything
        returned by `_run_ops` is joined.
        """
        op_processes = self._run_ops()

        executor = futures.ThreadPoolExecutor(max_workers=self._worker_num)
        server = grpc.server(executor)
        servicer = GeneralPythonService(self._in_channel, self._out_channel,
                                        self._retry)
        general_python_service_pb2_grpc.add_GeneralPythonServiceServicer_to_server(
            servicer, server)

        address = '[::]:{}'.format(self._port)
        server.add_insecure_port(address)
        server.start()
        server.wait_for_termination()

        self._stop_ops()  # TODO
        for process in op_processes:
            process.join()
B
barrierye 已提交
1151 1152 1153 1154 1155 1156 1157

    def prepare_serving(self, op):
        model_path = op._server_model
        port = op._server_port
        device = op._device

        if device == "cpu":
1158 1159
            cmd = "(Use MultiLangServer) python -m paddle_serving_server.serve" \
                  " --model {} --thread 4 --port {} --use_multilang &>/dev/null &".format(model_path, port)
B
barrierye 已提交
1160
        else:
1161 1162
            cmd = "(Use MultiLangServer) python -m paddle_serving_server_gpu.serve" \
                  " --model {} --thread 4 --port {} --use_multilang &>/dev/null &".format(model_path, port)
1163 1164
        # run a server (not in PyServing)
        logging.info("run a server (not in PyServing): {}".format(cmd))