pyserver.py 47.0 KB
Newer Older
B
barrierye 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
import threading
import multiprocessing
B
barrierye 已提交
17
import multiprocessing.queues
B
barrierye 已提交
18
import Queue
B
barrierye 已提交
19
import os
B
barrierye 已提交
20
import sys
B
barrierye 已提交
21
import paddle_serving_server
22
from paddle_serving_client import MultiLangClient as Client
B
barrierye 已提交
23
from paddle_serving_client import MultiLangPredictFuture
B
barrierye 已提交
24
from concurrent import futures
B
barrierye 已提交
25
import numpy as np
B
barrierye 已提交
26
import grpc
27 28 29 30
from .proto import general_model_config_pb2 as m_config
from .proto import general_python_service_pb2 as pyservice_pb2
from .proto import pyserving_channel_pb2 as channel_pb2
from .proto import general_python_service_pb2_grpc
B
barrierye 已提交
31
import logging
32
import random
B
barrierye 已提交
33
import time
B
barrierye 已提交
34
import func_timeout
35
import enum
36
import collections
B
barrierye 已提交
37 38


B
barrierye 已提交
39 40 41 42 43 44 45 46 47 48 49
class _TimeProfiler(object):
    """Lightweight event profiler that dumps paired timestamps to stderr.

    Events are recorded as ``"<name>_<tag>"`` strings; the text after the
    last underscore is the tag.  Each event is stored as
    ``(name, tag, timestamp_in_microseconds)`` until ``print_profile`` pairs
    up two events with the same name and writes them out.
    """

    def __init__(self):
        self._pid = os.getpid()
        self._print_head = 'PROFILE\tpid:{}\t'.format(self._pid)
        self._time_record = Queue.Queue()
        # Recording is off until enable(True) is called.
        self._enable = False

    def enable(self, enable):
        """Switch recording on or off."""
        self._enable = enable

    def record(self, name_with_tag):
        """Queue a timestamped event named ``"<name>_<tag>"`` (no-op if disabled)."""
        if self._enable is False:
            return
        # Split on the LAST underscore: everything before it is the name.
        name, _, tag = name_with_tag.rpartition("_")
        now_us = int(round(time.time() * 1000000))
        self._time_record.put((name, tag, now_us))

    def print_profile(self):
        """Write one profile line to stderr made of same-name event pairs.

        Events whose partner has not been recorded yet are pushed back into
        the queue so that a later call can still pair them.
        """
        if self._enable is False:
            return
        sys.stderr.write(self._print_head)
        unmatched = {}
        while not self._time_record.empty():
            name, tag, timestamp = self._time_record.get()
            if name not in unmatched:
                unmatched[name] = (tag, timestamp)
                continue
            first_tag, first_timestamp = unmatched.pop(name)
            sys.stderr.write("{}_{}:{} ".format(name, first_tag,
                                                first_timestamp))
            sys.stderr.write("{}_{}:{} ".format(name, tag, timestamp))
        sys.stderr.write('\n')
        for name, (tag, timestamp) in unmatched.items():
            self._time_record.put((name, tag, timestamp))


_profiler = _TimeProfiler()


79 80 81
class ChannelDataEcode(enum.Enum):
    """Status codes attached to a ChannelData.

    ChannelData stores the raw integer ``.value`` of a member, not the
    member itself.
    """
    OK = 0
    # midprocess timed out on every retry
    TIMEOUT = 1
    # preprocess raised NotImplementedError
    NOT_IMPLEMENTED = 2
    # bad input type (e.g. npdata validation failure)
    TYPE_ERROR = 3
    RPC_PACKAGE_ERROR = 4
    # catch-all for unexpected exceptions; spelling kept as-is for callers
    UNKNOW = 5
86 87 88 89 90


class ChannelDataType(enum.Enum):
    """Payload kinds a ChannelData can carry.

    ChannelData stores the raw integer ``.value`` in its ``datatype`` field.
    """
    # payload is a channel_pb2.ChannelData protobuf
    CHANNEL_PBDATA = 0
    # payload is an asynchronous prediction future
    CHANNEL_FUTURE = 1
    # payload is a dict of numpy arrays
    CHANNEL_NPDATA = 2
    # no payload; ecode/error_info describe a failure
    ERROR = 3
93 94 95 96


class ChannelData(object):
    """A unit of data passed between Ops through a Channel.

    ``datatype`` and ``ecode`` hold the raw ``.value`` of ChannelDataType /
    ChannelDataEcode members.
    """

    def __init__(self,
                 datatype=None,
                 future=None,
                 pbdata=None,
                 npdata=None,
                 data_id=None,
                 callback_func=None,
                 ecode=None,
                 error_info=None):
        '''
        There are several ways to use it (pass arguments by keyword):

        1. ChannelData(datatype=ChannelDataType.CHANNEL_FUTURE.value,
                       future=future, data_id=data_id[, callback_func=func])
        2. ChannelData(datatype=ChannelDataType.CHANNEL_PBDATA.value,
                       pbdata=pbdata)
        3. ChannelData(datatype=ChannelDataType.CHANNEL_PBDATA.value,
                       npdata=npdata, data_id=data_id)
           npdata is serialized into a new channel_pb2.ChannelData.
        4. ChannelData(datatype=ChannelDataType.CHANNEL_NPDATA.value,
                       npdata=npdata, data_id=data_id)
        5. ChannelData(ecode=ecode, error_info=error_info, data_id=data_id)
           datatype is forced to ChannelDataType.ERROR.value.

        npdata is expected to be a dict of {str: np.ndarray}. If it is not,
        ecode is set to ChannelDataEcode.TYPE_ERROR.value, error_info
        describes the problem, and the problem is logged.

        Protobufs are not pickle-able:
        https://stackoverflow.com/questions/55344376/how-to-import-protobuf-module

        Raises:
            ValueError: a required data_id / error_info is missing, or
                datatype is not a supported type.
        '''
        if ecode is not None:
            if data_id is None or error_info is None:
                raise ValueError("data_id and error_info cannot be None")
            datatype = ChannelDataType.ERROR.value
        else:
            if datatype == ChannelDataType.CHANNEL_FUTURE.value:
                if data_id is None:
                    raise ValueError("data_id cannot be None")
                ecode = ChannelDataEcode.OK.value
            elif datatype == ChannelDataType.CHANNEL_PBDATA.value:
                if pbdata is None:
                    if data_id is None:
                        raise ValueError("data_id cannot be None")
                    pbdata = channel_pb2.ChannelData()
                    ecode, error_info = self._check_npdata(npdata)
                    if ecode != ChannelDataEcode.OK.value:
                        logging.error(error_info)
                    else:
                        # Serialize each array as raw bytes + int32 shape + dtype name.
                        for name, value in npdata.items():
                            inst = channel_pb2.Inst()
                            inst.data = value.tobytes()
                            inst.name = name
                            inst.shape = np.array(
                                value.shape, dtype="int32").tobytes()
                            inst.type = str(value.dtype)
                            pbdata.insts.append(inst)
                else:
                    # A ready-made protobuf was supplied. Mark it OK explicitly:
                    # leaving ecode as None made Ops treat it as an error.
                    ecode = ChannelDataEcode.OK.value
            elif datatype == ChannelDataType.CHANNEL_NPDATA.value:
                ecode, error_info = self._check_npdata(npdata)
                if ecode != ChannelDataEcode.OK.value:
                    logging.error(error_info)
            else:
                raise ValueError("datatype not match")
        self.future = future
        self.pbdata = pbdata
        self.npdata = npdata
        self.datatype = datatype
        self.callback_func = callback_func
        self.id = data_id
        self.ecode = ecode
        self.error_info = error_info

    def _check_npdata(self, npdata):
        """Validate that ``npdata`` maps string keys to numpy arrays.

        Returns:
            (ecode, error_info). ecode is ChannelDataEcode.OK.value and
            error_info is None when the data is valid. Otherwise ecode is
            ChannelDataEcode.TYPE_ERROR.value and error_info is a message.
        """
        try:
            # Python 2: accept both byte strings and unicode keys.
            str_types = (str, unicode)  # pylint: disable=undefined-variable
        except NameError:
            # Python 3 has no `unicode` builtin.
            str_types = (str, )
        if not hasattr(npdata, "items"):
            return (ChannelDataEcode.TYPE_ERROR.value,
                    "npdata must be dict, but get {}".format(type(npdata)))
        ecode = ChannelDataEcode.OK.value
        error_info = None
        for name, value in npdata.items():
            if not isinstance(name, str_types):
                ecode = ChannelDataEcode.TYPE_ERROR.value
                error_info = "the key of postped_data must " \
                        "be str, but get {}".format(type(name))
                break
            if not isinstance(value, np.ndarray):
                ecode = ChannelDataEcode.TYPE_ERROR.value
                error_info = "the value of postped_data must " \
                        "be np.ndarray, but get {}".format(type(value))
                break
        return ecode, error_info

    def parse(self):
        """Return the payload as model input.

        CHANNEL_PBDATA: dict {name: np.ndarray}, decoded from the protobuf.
        CHANNEL_FUTURE: the future's result, passed through callback_func
            when one was given.
        CHANNEL_NPDATA: the stored npdata.

        Raises:
            TypeError: for any other datatype, including ERROR.
        """
        feed = None
        if self.datatype == ChannelDataType.CHANNEL_PBDATA.value:
            feed = {}
            for inst in self.pbdata.insts:
                feed[inst.name] = np.frombuffer(inst.data, dtype=inst.type)
                feed[inst.name].shape = np.frombuffer(inst.shape, dtype="int32")
        elif self.datatype == ChannelDataType.CHANNEL_FUTURE.value:
            feed = self.future.result()
            if self.callback_func is not None:
                feed = self.callback_func(feed)
        elif self.datatype == ChannelDataType.CHANNEL_NPDATA.value:
            feed = self.npdata
        else:
            raise TypeError("Error type({}) in datatype.".format(self.datatype))
        return feed

    def __str__(self):
        return "type[{}], ecode[{}], id[{}]".format(
            ChannelDataType(self.datatype).name, self.ecode, self.id)
B
barrierye 已提交
196

197

B
barrierye 已提交
198
class Channel(multiprocessing.queues.Queue):
    """
    The channel used for communication between Ops.

    1. Support multiple different Op feed data (multiple producer)
        Different types of data will be packaged through the data ID
    2. Support multiple different Op fetch data (multiple consumer)
        Only when all types of Ops get the data of the same ID,
        the data will be popped; The Op of the same type will not
        get the data of the same ID.
    3. (TODO) Timeout and BatchSize are not fully supported.

    Note:
    1. The ID of the data in the channel must be different.
    2. The function add_producer() and add_consumer() are not thread safe,
       and can only be called during initialization.
    """

    def __init__(self, manager, name=None, maxsize=0, timeout=None):
        """
        Args:
            manager: a multiprocessing manager used to create the shared
                bookkeeping dicts/list/value.
            name: channel name, used as a prefix in log messages.
            maxsize: queue capacity (0 means unbounded).
            timeout: stored only; not used by the code in this class.
        """
        # https://stackoverflow.com/questions/39496554/cannot-subclass-multiprocessing-queue-in-python-3-5/
        if sys.version_info.major == 2:
            super(Channel, self).__init__(maxsize=maxsize)
        elif sys.version_info.major == 3:
            super(Channel, self).__init__(
                maxsize=maxsize, ctx=multiprocessing.get_context())
        else:
            raise Exception("Error Python version")
        self._maxsize = maxsize
        self._timeout = timeout
        self.name = name
        # NOTE(review): a plain attribute, so it is not shared across processes.
        self._stop = False

        self._cv = multiprocessing.Condition()

        self._producers = []
        self._producer_res_count = manager.dict()  # {data_id: count}
        self._push_res = manager.dict()  # {data_id: {op_name: data}}

        self._consumers = manager.dict()  # {op_name: idx}
        self._idx_consumer_num = manager.dict()  # {idx: num}
        self._consumer_base_idx = manager.Value('i', 0)
        self._front_res = manager.list()

    def get_producers(self):
        return self._producers

    def get_consumers(self):
        return self._consumers.keys()

    def _log(self, info_str):
        return "[{}] {}".format(self.name, info_str)

    def debug(self):
        return self._log("p: {}, c: {}".format(self.get_producers(),
                                               self.get_consumers()))

    def add_producer(self, op_name):
        """ not thread safe, and can only be called during initialization. """
        if op_name in self._producers:
            raise ValueError(
                self._log("producer({}) is already in channel".format(op_name)))
        self._producers.append(op_name)

    def add_consumer(self, op_name):
        """ not thread safe, and can only be called during initialization. """
        if op_name in self._consumers:
            raise ValueError(
                self._log("consumer({}) is already in channel".format(op_name)))
        self._consumers[op_name] = 0

        if self._idx_consumer_num.get(0) is None:
            self._idx_consumer_num[0] = 0
        self._idx_consumer_num[0] += 1

    def push(self, channeldata, op_name=None):
        """Push ``channeldata`` from producer ``op_name`` into the channel.

        With a single producer the data is queued directly. With several
        producers, the items for one data id are collected into a dict
        {op_name: channeldata}. That dict is queued only once every
        producer has pushed for the id.

        Returns:
            True.
        Raises:
            Exception: no producers registered, or op_name is None while
                there are several producers.
        """
        logging.debug(
            self._log("{} try to push data: {}".format(op_name,
                                                       channeldata.__str__())))
        if len(self._producers) == 0:
            raise Exception(
                self._log(
                    "expected number of producers to be greater than 0, but the it is 0."
                ))
        elif len(self._producers) == 1:
            with self._cv:
                while self._stop is False:
                    try:
                        self.put(channeldata, timeout=0)
                        break
                    except Queue.Full:
                        self._cv.wait()
                logging.debug(
                    self._log("{} channel size: {}".format(op_name,
                                                           self.qsize())))
                self._cv.notify_all()
                logging.debug(self._log("{} notify all".format(op_name)))
            logging.debug(self._log("{} push data succ!".format(op_name)))
            return True
        elif op_name is None:
            raise Exception(
                self._log(
                    "There are multiple producers, so op_name cannot be None."))

        producer_num = len(self._producers)
        data_id = channeldata.id
        put_data = None
        with self._cv:
            logging.debug(self._log("{} get lock".format(op_name)))
            if data_id not in self._push_res:
                self._push_res[data_id] = {
                    name: None
                    for name in self._producers
                }
                self._producer_res_count[data_id] = 0
            # Mutating a nested dict through a manager proxy is not
            # propagated, so read-modify-write the inner dict. See:
            # https://docs.python.org/3.6/library/multiprocessing.html?highlight=multiprocess#proxy-objects
            tmp_push_res = self._push_res[data_id]
            tmp_push_res[op_name] = channeldata
            self._push_res[data_id] = tmp_push_res

            if self._producer_res_count[data_id] + 1 == producer_num:
                # Last producer for this id: hand the whole dict to the queue.
                put_data = self._push_res[data_id]
                self._push_res.pop(data_id)
                self._producer_res_count.pop(data_id)
            else:
                self._producer_res_count[data_id] += 1

            if put_data is None:
                logging.debug(
                    self._log("{} push data succ, but not push to queue.".
                              format(op_name)))
            else:
                while self._stop is False:
                    try:
                        logging.debug(
                            self._log("{} push data succ: {}".format(
                                op_name, put_data.__str__())))
                        self.put(put_data, timeout=0)
                        break
                    except Queue.Full:
                        # put() signals a full queue with Full (previously
                        # Empty was caught here, so a full queue raised
                        # instead of waiting for a consumer).
                        self._cv.wait()

                logging.debug(
                    self._log("multi | {} push data succ!".format(op_name)))
            self._cv.notify_all()
        return True

    def front(self, op_name=None):
        """Return the next item for consumer ``op_name``.

        With a single consumer, the next queued item is returned. It is None
        if the channel was stopped while waiting. With several consumers,
        each item is kept in a shared list until every consumer has read
        it. The returned object must be treated as read-only.

        Raises:
            Exception: no consumers registered, or op_name is None while
                there are several consumers.
        """
        logging.debug(self._log("{} try to get data...".format(op_name)))
        if len(self._consumers) == 0:
            raise Exception(
                self._log(
                    "expected number of consumers to be greater than 0, but the it is 0."
                ))
        elif len(self._consumers) == 1:
            resp = None
            with self._cv:
                while self._stop is False and resp is None:
                    try:
                        logging.debug(
                            self._log("{} try to get(with channel empty: {})".
                                      format(op_name, self.empty())))
                        # For Python2, after putting an object on an empty queue there may
                        # be an infinitessimal delay before the queue's :meth:`~Queue.empty`
                        # see more:
                        # - https://bugs.python.org/issue18277
                        # - https://hg.python.org/cpython/rev/860fc6a2bd21
                        if sys.version_info.major == 2:
                            resp = self.get(timeout=1e-3)
                        elif sys.version_info.major == 3:
                            resp = self.get(timeout=0)
                        else:
                            raise Exception("Error Python version")
                        break
                    except Queue.Empty:
                        logging.debug(
                            self._log(
                                "{} wait for empty queue(with channel empty: {})".
                                format(op_name, self.empty())))
                        # NOTE(review): put() hands data to a background feeder
                        # thread. If a producer's notify arrives before the
                        # data is readable, this wait may block until the
                        # next push. Consider a bounded wait.
                        self._cv.wait()
            logging.debug(
                self._log("{} get data succ: {}".format(op_name, resp.__str__(
                ))))
            return resp
        elif op_name is None:
            raise Exception(
                self._log(
                    "There are multiple consumers, so op_name cannot be None."))

        with self._cv:
            # data_idx = consumer_idx - base_idx; fetch from the queue only
            # when this consumer has already read everything in _front_res.
            while self._stop is False and self._consumers[
                    op_name] - self._consumer_base_idx.value >= len(
                        self._front_res):
                logging.debug(
                    self._log(
                        "({}) B self._consumers: {}, self._consumer_base_idx: {}, len(self._front_res): {}".
                        format(op_name, self._consumers, self.
                               _consumer_base_idx.value, len(self._front_res))))
                try:
                    logging.debug(
                        self._log("{} try to get(with channel size: {})".format(
                            op_name, self.qsize())))
                    # For Python2, after putting an object on an empty queue there may
                    # be an infinitessimal delay before the queue's :meth:`~Queue.empty`
                    # see more:
                    # - https://bugs.python.org/issue18277
                    # - https://hg.python.org/cpython/rev/860fc6a2bd21
                    if sys.version_info.major == 2:
                        channeldata = self.get(timeout=1e-3)
                    elif sys.version_info.major == 3:
                        channeldata = self.get(timeout=0)
                    else:
                        raise Exception("Error Python version")
                    self._front_res.append(channeldata)
                    break
                except Queue.Empty:
                    logging.debug(
                        self._log(
                            "{} wait for empty queue(with channel size: {})".
                            format(op_name, self.qsize())))
                    self._cv.wait()

            consumer_idx = self._consumers[op_name]
            base_idx = self._consumer_base_idx.value
            data_idx = consumer_idx - base_idx
            resp = self._front_res[data_idx]
            logging.debug(self._log("{} get data: {}".format(op_name, resp)))

            # This consumer leaves index consumer_idx. Drop the oldest
            # buffered item once no consumer is still positioned on it.
            self._idx_consumer_num[consumer_idx] -= 1
            if consumer_idx == base_idx and self._idx_consumer_num[
                    consumer_idx] == 0:
                self._idx_consumer_num.pop(consumer_idx)
                self._front_res.pop(0)
                self._consumer_base_idx.value += 1

            # ...and moves to the next index.
            self._consumers[op_name] += 1
            new_consumer_idx = self._consumers[op_name]
            if self._idx_consumer_num.get(new_consumer_idx) is None:
                self._idx_consumer_num[new_consumer_idx] = 0
            self._idx_consumer_num[new_consumer_idx] += 1
            logging.debug(
                self._log(
                    "({}) A self._consumers: {}, self._consumer_base_idx: {}, len(self._front_res): {}".
                    format(op_name, self._consumers, self._consumer_base_idx.
                           value, len(self._front_res))))
            logging.debug(self._log("{} notify all".format(op_name)))
            self._cv.notify_all()

        logging.debug(self._log("multi | {} get data succ!".format(op_name)))
        return resp  # reference, read only

    def stop(self):
        """Close the queue, set the stop flag and wake all waiters."""
        #TODO
        self.close()
        # Condition.notify_all() requires the caller to hold the lock;
        # calling it unlocked fails with "lock is not owned".
        with self._cv:
            self._stop = True
            self._cv.notify_all()
461

B
barrierye 已提交
462 463 464

class Op(object):
    def __init__(self,
465
                 name,
466
                 inputs,
B
barrierye 已提交
467 468 469 470 471
                 server_model=None,
                 server_port=None,
                 device=None,
                 client_config=None,
                 server_name=None,
472
                 fetch_names=None,
B
barrierye 已提交
473
                 concurrency=1,
B
barrierye 已提交
474 475
                 timeout=-1,
                 retry=2):
B
barrierye 已提交
476
        self._is_run = False
477
        self.name = name  # to identify the type of OP, it must be globally unique
478
        self._concurrency = concurrency  # amount of concurrency
479
        self.set_input_ops(inputs)
B
barrierye 已提交
480
        self._timeout = timeout
B
bug fix  
barrierye 已提交
481
        self._retry = max(1, retry)
482 483
        self._input = None
        self._outputs = []
B
barrierye 已提交
484

B
barrierye 已提交
485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500
        self.with_serving = False
        self._client_config = client_config
        self._server_name = server_name
        self._fetch_names = fetch_names
        self._server_model = server_model
        self._server_port = server_port
        self._device = device
        if self._client_config is not None and \
                self._server_name is not None and \
                self._fetch_names is not None and \
                self._server_model is not None and \
                self._server_port is not None and \
                self._device is not None:
            self.with_serving = True

    def init_client(self, client_config, server_name, fetch_names):
B
barrierye 已提交
501 502
        if self.with_serving == False:
            logging.debug("{} no client".format(self.name))
503
            return
B
barrierye 已提交
504 505 506
        logging.debug("{} client_config: {}".format(self.name, client_config))
        logging.debug("{} server_name: {}".format(self.name, server_name))
        logging.debug("{} fetch_names: {}".format(self.name, fetch_names))
B
barrierye 已提交
507 508 509 510 511
        self._client = Client()
        self._client.load_client_config(client_config)
        self._client.connect([server_name])
        self._fetch_names = fetch_names

512
    def get_input_channel(self):
513
        return self._input
B
barrierye 已提交
514

515 516 517 518 519 520 521 522 523 524 525 526 527 528 529
    def get_input_ops(self):
        return self._input_ops

    def set_input_ops(self, ops):
        if not isinstance(ops, list):
            ops = [] if ops is None else [ops]
        self._input_ops = []
        for op in ops:
            if not isinstance(op, Op):
                raise TypeError(
                    self._log('input op must be Op type, not {}'.format(
                        type(op))))
            self._input_ops.append(op)

    def add_input_channel(self, channel):
        """Use ``channel`` as this Op's input and register as its consumer.

        Raises:
            TypeError: if ``channel`` is not a Channel.
        """
        if isinstance(channel, Channel):
            channel.add_consumer(self.name)
            self._input = channel
            return
        raise TypeError(
            self._log('input channel must be Channel type, not {}'.format(
                type(channel))))
B
barrierye 已提交
536

537
    def get_output_channels(self):
B
barrierye 已提交
538 539
        return self._outputs

540 541
    def add_output_channel(self, channel):
        """Append ``channel`` to the outputs and register as its producer.

        Raises:
            TypeError: if ``channel`` is not a Channel.
        """
        if isinstance(channel, Channel):
            channel.add_producer(self.name)
            self._outputs.append(channel)
            return
        raise TypeError(
            self._log('output channel must be Channel type, not {}'.format(
                type(channel))))
B
barrierye 已提交
547

548 549
    def preprocess(self, channeldata):
        if isinstance(channeldata, dict):
B
barrierye 已提交
550 551 552
            raise NotImplementedError(
                'this Op has multiple previous inputs. Please override this method'
            )
553
        feed = channeldata.parse()
554
        return feed
B
barrierye 已提交
555 556

    def midprocess(self, data):
557 558 559 560 561 562 563
        if not isinstance(data, dict):
            raise Exception(
                self._log(
                    'data must be dict type(the output of preprocess()), but get {}'.
                    format(type(data))))
        logging.debug(self._log('data: {}'.format(data)))
        logging.debug(self._log('fetch: {}'.format(self._fetch_names)))
564 565 566 567
        call_future = self._client.predict(
            feed=data, fetch=self._fetch_names, asyn=True)
        logging.debug(self._log("get call_future"))
        return call_future
B
barrierye 已提交
568 569

    def postprocess(self, output_data):
B
barrierye 已提交
570
        return output_data
B
barrierye 已提交
571 572

    def stop(self):
573 574 575
        self._input.stop()
        for channel in self._outputs:
            channel.stop()
B
barrierye 已提交
576
        self._is_run = False
B
barrierye 已提交
577

578
    def _parse_channeldata(self, channeldata):
        """Extract the data id and the first errored ChannelData, if any.

        Args:
            channeldata: a single ChannelData, or a dict {op_name: ChannelData}
                as produced by a multi-producer Channel.
        Returns:
            (data_id, error_channeldata). error_channeldata is None when every
            item has ecode ChannelDataEcode.OK.value.
        Raises:
            ValueError: if ``channeldata`` is an empty dict.
        """
        data_id, error_channeldata = None, None
        if isinstance(channeldata, dict):
            if not channeldata:
                raise ValueError("channeldata dict cannot be empty")
            # dict.keys() is not indexable on Python 3; every entry shares
            # the same id, so take the first value.
            data_id = next(iter(channeldata.values())).id
            for data in channeldata.values():
                if data.ecode != ChannelDataEcode.OK.value:
                    error_channeldata = data
                    break
        else:
            data_id = channeldata.id
            if channeldata.ecode != ChannelDataEcode.OK.value:
                error_channeldata = channeldata
        return data_id, error_channeldata
593

B
barrierye 已提交
594
    def _push_to_output_channels(self, data, channels, name=None):
B
bug fix  
barrierye 已提交
595 596
        if name is None:
            name = self.name
B
barrierye 已提交
597
        for channel in channels:
B
bug fix  
barrierye 已提交
598
            channel.push(data, name)
B
barrierye 已提交
599

B
barrierye 已提交
600 601 602 603 604 605 606 607 608 609 610
    def start(self):
        proces = []
        for concurrency_idx in range(self._concurrency):
            p = multiprocessing.Process(
                target=self._run,
                args=(concurrency_idx, self.get_input_channel(),
                      self.get_output_channels()))
            p.start()
            proces.append(p)
        return proces

B
barrierye 已提交
611
    def _run(self, concurrency_idx, input_channel, output_channels):
B
barrierye 已提交
612 613
        self.init_client(self._client_config, self._server_name,
                         self._fetch_names)
B
bug fix  
barrierye 已提交
614
        op_info_prefix = "[{}|{}]".format(self.name, concurrency_idx)
B
barrierye 已提交
615
        log = self._get_log_func(op_info_prefix)
B
barrierye 已提交
616 617
        self._is_run = True
        while self._is_run:
B
barrierye 已提交
618
            _profiler.record("{}-get_0".format(op_info_prefix))
B
barrierye 已提交
619
            channeldata = input_channel.front(self.name)
B
barrierye 已提交
620
            _profiler.record("{}-get_1".format(op_info_prefix))
B
bug fix  
barrierye 已提交
621
            logging.debug(log("input_data: {}".format(channeldata)))
B
barrierye 已提交
622

B
barrierye 已提交
623
            data_id, error_channeldata = self._parse_channeldata(channeldata)
624

B
bug fix  
barrierye 已提交
625
            # error data in predecessor Op
B
barrierye 已提交
626 627 628
            if error_channeldata is not None:
                self._push_to_output_channels(error_channeldata,
                                              output_channels)
B
barrierye 已提交
629 630
                continue

B
bug fix  
barrierye 已提交
631
            # preprecess
B
barrierye 已提交
632 633
            try:
                _profiler.record("{}-prep_0".format(op_info_prefix))
B
bug fix  
barrierye 已提交
634
                preped_data = self.preprocess(channeldata)
B
barrierye 已提交
635 636
                _profiler.record("{}-prep_1".format(op_info_prefix))
            except NotImplementedError as e:
B
bug fix  
barrierye 已提交
637
                # preprocess function not implemented
B
barrierye 已提交
638 639 640 641 642 643
                error_info = log(e)
                logging.error(error_info)
                self._push_to_output_channels(
                    ChannelData(
                        ecode=ChannelDataEcode.NOT_IMPLEMENTED.value,
                        error_info=error_info,
B
barrierye 已提交
644 645
                        data_id=data_id),
                    output_channels)
B
barrierye 已提交
646
                continue
B
bug fix  
barrierye 已提交
647
            except TypeError as e:
B
barrierye 已提交
648
                # Error type in channeldata.datatype
B
bug fix  
barrierye 已提交
649 650 651 652 653 654
                error_info = log(e)
                logging.error(error_info)
                self._push_to_output_channels(
                    ChannelData(
                        ecode=ChannelDataEcode.TYPE_ERROR.value,
                        error_info=error_info,
B
barrierye 已提交
655 656
                        data_id=data_id),
                    output_channels)
B
bug fix  
barrierye 已提交
657 658 659 660 661 662
                continue
            except Exception as e:
                error_info = log(e)
                logging.error(error_info)
                self._push_to_output_channels(
                    ChannelData(
B
barrierye 已提交
663
                        ecode=ChannelDataEcode.UNKNOW.value,
B
bug fix  
barrierye 已提交
664
                        error_info=error_info,
B
barrierye 已提交
665 666
                        data_id=data_id),
                    output_channels)
B
bug fix  
barrierye 已提交
667
                continue
668

B
barrierye 已提交
669 670
            # midprocess
            call_future = None
B
barrierye 已提交
671
            if self.with_serving:
B
bug fix  
barrierye 已提交
672
                ecode = ChannelDataEcode.OK.value
B
barrierye 已提交
673 674 675
                _profiler.record("{}-midp_0".format(op_info_prefix))
                if self._timeout <= 0:
                    try:
B
bug fix  
barrierye 已提交
676
                        call_future = self.midprocess(preped_data)
B
barrierye 已提交
677 678 679 680
                    except Exception as e:
                        ecode = ChannelDataEcode.UNKNOW.value
                        error_info = log(e)
                        logging.error(error_info)
B
barrierye 已提交
681
                else:
B
barrierye 已提交
682 683 684
                    for i in range(self._retry):
                        try:
                            call_future = func_timeout.func_timeout(
B
bug fix  
barrierye 已提交
685 686 687 688
                                self._timeout,
                                self.midprocess,
                                args=(preped_data, ))
                        except func_timeout.FunctionTimedOut as e:
B
barrierye 已提交
689 690
                            if i + 1 >= self._retry:
                                ecode = ChannelDataEcode.TIMEOUT.value
B
bug fix  
barrierye 已提交
691 692
                                error_info = log(e)
                                logging.error(error_info)
B
barrierye 已提交
693 694
                            else:
                                logging.warn(
B
bug fix  
barrierye 已提交
695
                                    log("timeout, retry({})".format(i + 1)))
B
barrierye 已提交
696 697 698 699 700 701 702
                        except Exception as e:
                            ecode = ChannelDataEcode.UNKNOW.value
                            error_info = log(e)
                            logging.error(error_info)
                            break
                        else:
                            break
B
bug fix  
barrierye 已提交
703
                if ecode != ChannelDataEcode.OK.value:
B
barrierye 已提交
704 705 706
                    self._push_to_output_channels(
                        ChannelData(
                            ecode=ecode, error_info=error_info,
B
barrierye 已提交
707 708
                            data_id=data_id),
                        output_channels)
B
barrierye 已提交
709 710
                    continue
                _profiler.record("{}-midp_1".format(op_info_prefix))
711

B
barrierye 已提交
712 713 714
            # postprocess
            output_data = None
            _profiler.record("{}-postp_0".format(op_info_prefix))
B
barrierye 已提交
715
            if self.with_serving:
B
bug fix  
barrierye 已提交
716
                # use call_future
B
barrierye 已提交
717
                output_data = ChannelData(
B
barrierye 已提交
718
                    datatype=ChannelDataType.CHANNEL_FUTURE.value,
B
barrierye 已提交
719 720 721
                    future=call_future,
                    data_id=data_id,
                    callback_func=self.postprocess)
B
barrierye 已提交
722 723 724 725 726 727 728 729 730
                #TODO: for future are not picklable
                npdata = self.postprocess(call_future.result())
                self._push_to_output_channels(
                    ChannelData(
                        ChannelDataType.CHANNEL_NPDATA.value,
                        npdata=npdata,
                        data_id=data_id),
                    output_channels)
                continue
B
barrierye 已提交
731
            else:
B
bug fix  
barrierye 已提交
732 733 734 735 736 737 738 739 740
                try:
                    postped_data = self.postprocess(preped_data)
                except Exception as e:
                    ecode = ChannelDataEcode.UNKNOW.value
                    error_info = log(e)
                    logging.error(error_info)
                    self._push_to_output_channels(
                        ChannelData(
                            ecode=ecode, error_info=error_info,
B
barrierye 已提交
741 742
                            data_id=data_id),
                        output_channels)
B
bug fix  
barrierye 已提交
743 744
                    continue
                if not isinstance(postped_data, dict):
B
barrierye 已提交
745 746
                    ecode = ChannelDataEcode.TYPE_ERROR.value
                    error_info = log("output of postprocess funticon must be " \
B
bug fix  
barrierye 已提交
747
                            "dict type, but get {}".format(type(postped_data)))
B
barrierye 已提交
748 749 750 751
                    logging.error(error_info)
                    self._push_to_output_channels(
                        ChannelData(
                            ecode=ecode, error_info=error_info,
B
barrierye 已提交
752 753
                            data_id=data_id),
                        output_channels)
B
barrierye 已提交
754
                    continue
B
bug fix  
barrierye 已提交
755

B
barrierye 已提交
756 757 758 759
                output_data = ChannelData(
                    ChannelDataType.CHANNEL_NPDATA.value,
                    npdata=postped_data,
                    data_id=data_id)
B
barrierye 已提交
760 761 762 763
            _profiler.record("{}-postp_1".format(op_info_prefix))

            # push data to channel (if run succ)
            _profiler.record("{}-push_0".format(op_info_prefix))
B
barrierye 已提交
764
            self._push_to_output_channels(output_data, output_channels)
B
barrierye 已提交
765 766 767 768 769 770 771 772 773 774
            _profiler.record("{}-push_1".format(op_info_prefix))

    def _log(self, info):
        return "{} {}".format(self.name, info)

    def _get_log_func(self, op_info_prefix):
        def log_func(info_str):
            return "{} {}".format(op_info_prefix, info_str)

        return log_func
B
barrierye 已提交
775

776 777 778
    def get_concurrency(self):
        """Return the concurrency level this op was configured with."""
        concurrency = self._concurrency
        return concurrency

B
barrierye 已提交
779

B
bug fix  
barrierye 已提交
780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799
class VirtualOp(Op):
    ''' Pass-through Op for connecting two channels.

    PyServer._topo_sort creates a VirtualOp wherever a DAG edge skips a
    level, so that every channel links two adjacent levels. A VirtualOp does
    no computation. Each item read from its input channel is forwarded
    unchanged to its output channels. The item is attributed to its virtual
    predecessor op(s), so downstream channels see the producer names they
    registered.
    '''

    def __init__(self, name, concurrency=1):
        super(VirtualOp, self).__init__(
            name=name, inputs=None, concurrency=concurrency)
        # real (or virtual) ops whose output this op relays
        self._virtual_pred_ops = []

    def add_virtual_pred_op(self, op):
        """Record ``op`` as an upstream op whose data this op relays."""
        self._virtual_pred_ops.append(op)

    def add_output_channel(self, channel):
        """Attach ``channel`` as an output, registering every virtual
        predecessor's name (not this op's name) as a producer on it.

        Raises:
            TypeError: if ``channel`` is not a Channel.
        """
        if not isinstance(channel, Channel):
            raise TypeError(
                self._log('output channel must be Channel type, not {}'.format(
                    type(channel))))
        for op in self._virtual_pred_ops:
            channel.add_producer(op.name)
        self._outputs.append(channel)

    def _run(self, concurrency_idx, input_channel, output_channels):
        """Forward items from ``input_channel`` to ``output_channels`` until
        ``self._is_run`` becomes False.

        The previous signature omitted ``concurrency_idx``, yet the body used
        it, so this raised NameError on its first line.
        NOTE(review): assumes Op.start invokes
        ``_run(concurrency_idx, input_channel, output_channels)`` like
        Op._run — confirm against Op.start.
        """
        op_info_prefix = "[{}|{}]".format(self.name, concurrency_idx)
        self._is_run = True
        while self._is_run:
            _profiler.record("{}-get_0".format(op_info_prefix))
            channeldata = input_channel.front(self.name)
            _profiler.record("{}-get_1".format(op_info_prefix))

            _profiler.record("{}-push_0".format(op_info_prefix))
            if isinstance(channeldata, dict):
                # presumably one entry per producer name (several producers
                # on the input channel) — relay each under its own name
                for name, data in channeldata.items():
                    self._push_to_output_channels(
                        data, channels=output_channels, name=name)
            else:
                self._push_to_output_channels(
                    channeldata,
                    channels=output_channels,
                    name=self._virtual_pred_ops[0].name)
            _profiler.record("{}-push_1".format(op_info_prefix))


B
barrierye 已提交
822
class GeneralPythonService(
        general_python_service_pb2_grpc.GeneralPythonServiceServicer):
    """gRPC servicer that feeds requests into the Op DAG and returns results.

    Each request gets a fresh integer data id and is pushed into
    ``in_channel``. A background thread reads finished ChannelData from
    ``out_channel`` into a dict keyed by id. ``inference`` blocks on a
    condition variable until its id appears there.
    """

    def __init__(self, in_channel, out_channel, retry=2):
        """
        Args:
            in_channel: Channel this service produces requests into.
            out_channel: Channel this service consumes responses from.
            retry: maximum number of attempts per request in ``inference``
                (values <= 0 still make one attempt).
        """
        super(GeneralPythonService, self).__init__()
        # must match the name PyServer._topo_sort gives the read op
        self.name = "#G"
        self.set_in_channel(in_channel)
        self.set_out_channel(out_channel)
        logging.debug(self._log(in_channel.debug()))
        logging.debug(self._log(out_channel.debug()))
        #TODO:
        #  multi-lock for different clients
        #  diffenert lock for server and client
        self._id_lock = threading.Lock()
        self._cv = threading.Condition()
        self._globel_resp_dict = {}
        self._id_counter = 0
        self._retry = retry
        self._recive_func = threading.Thread(
            target=GeneralPythonService._recive_out_channel_func, args=(self, ))
        # The receive loop never returns; as a non-daemon thread it would keep
        # the process alive after the gRPC server has terminated.
        self._recive_func.daemon = True
        self._recive_func.start()

    def _log(self, info_str):
        return "[{}] {}".format(self.name, info_str)

    def set_in_channel(self, in_channel):
        """Use ``in_channel`` for requests and register as its producer."""
        if not isinstance(in_channel, Channel):
            raise TypeError(
                self._log('in_channel must be Channel type, but get {}'.format(
                    type(in_channel))))
        in_channel.add_producer(self.name)
        self._in_channel = in_channel

    def set_out_channel(self, out_channel):
        """Use ``out_channel`` for responses and register as its consumer."""
        if not isinstance(out_channel, Channel):
            raise TypeError(
                self._log('out_channel must be Channel type, but get {}'.format(
                    type(out_channel))))
        out_channel.add_consumer(self.name)
        self._out_channel = out_channel

    def _recive_out_channel_func(self):
        """Background loop: file every response under its data id and wake
        the waiting ``inference`` calls."""
        while True:
            channeldata = self._out_channel.front(self.name)
            if not isinstance(channeldata, ChannelData):
                # Raising here would kill this thread and leave every pending
                # and future inference() blocked forever; log and skip.
                logging.error(
                    self._log('data must be ChannelData type, but get {}'.
                              format(type(channeldata))))
                continue
            with self._cv:
                self._globel_resp_dict[channeldata.id] = channeldata
                self._cv.notify_all()

    def _get_next_id(self):
        """Return a fresh, monotonically increasing data id."""
        with self._id_lock:
            self._id_counter += 1
            return self._id_counter - 1

    def _get_data_in_globel_resp_dict(self, data_id):
        """Block until the response for ``data_id`` arrives, then remove and
        return it."""
        with self._cv:
            while data_id not in self._globel_resp_dict:
                self._cv.wait()
            resp = self._globel_resp_dict.pop(data_id)
            self._cv.notify_all()
        return resp

    def _pack_data_for_infer(self, request):
        """Decode a gRPC request into ChannelData.

        Each feed var is rebuilt with ``np.frombuffer`` from
        ``feed_insts[i]`` using dtype ``type[i]``; its shape comes from the
        int32 bytes in ``shape[i]``.

        Returns:
            (ChannelData, data_id). On a decoding failure the ChannelData
            carries ecode RPC_PACKAGE_ERROR.
        """
        logging.debug(self._log('start inferce'))
        data_id = self._get_next_id()
        npdata = {}
        try:
            for idx, name in enumerate(request.feed_var_names):
                logging.debug(self._log('name: {}'.format(name)))
                logging.debug(
                    self._log('data: {}'.format(request.feed_insts[idx])))
                npdata[name] = np.frombuffer(
                    request.feed_insts[idx], dtype=request.type[idx])
                npdata[name].shape = np.frombuffer(
                    request.shape[idx], dtype="int32")
        except Exception as e:
            # keep the cause in the server log; the client only gets a code
            logging.error(self._log("rpc package error: {}".format(e)))
            return ChannelData(
                ecode=ChannelDataEcode.RPC_PACKAGE_ERROR.value,
                error_info="rpc package error",
                data_id=data_id), data_id
        else:
            return ChannelData(
                datatype=ChannelDataType.CHANNEL_NPDATA.value,
                npdata=npdata,
                data_id=data_id), data_id

    def _pack_data_for_resp(self, channeldata):
        """Encode ChannelData into a gRPC Response.

        Each fetch var is sent as raw bytes with its name, int32 shape bytes
        and dtype string. Non-OK ecodes carry only ``error_info``.

        Raises:
            TypeError: on an unsupported ``channeldata.datatype``.
        """
        logging.debug(self._log('get channeldata'))
        resp = pyservice_pb2.Response()
        resp.ecode = channeldata.ecode
        if resp.ecode == ChannelDataEcode.OK.value:
            if channeldata.datatype == ChannelDataType.CHANNEL_PBDATA.value:
                for inst in channeldata.pbdata.insts:
                    resp.fetch_insts.append(inst.data)
                    resp.fetch_var_names.append(inst.name)
                    resp.shape.append(inst.shape)
                    resp.type.append(inst.type)
            elif channeldata.datatype in (ChannelDataType.CHANNEL_FUTURE.value,
                                          ChannelDataType.CHANNEL_NPDATA.value):
                feed = channeldata.parse()
                for name, var in feed.items():
                    resp.fetch_insts.append(var.tobytes())
                    resp.fetch_var_names.append(name)
                    resp.shape.append(
                        np.array(
                            var.shape, dtype="int32").tobytes())
                    resp.type.append(str(var.dtype))
            else:
                raise TypeError(
                    self._log("Error type({}) in datatype.".format(
                        channeldata.datatype)))
        else:
            resp.error_info = channeldata.error_info
        return resp

    def inference(self, request, context):
        """gRPC entry point: run ``request`` through the DAG.

        The request is pushed again, with the same data id, while the
        response ecode is not OK, for up to ``self._retry`` attempts.
        """
        _profiler.record("{}-prepack_0".format(self.name))
        data, data_id = self._pack_data_for_infer(request)
        _profiler.record("{}-prepack_1".format(self.name))

        # at least one attempt; otherwise resp_channeldata would stay None
        attempts = max(self._retry, 1)
        resp_channeldata = None
        for i in range(attempts):
            logging.debug(self._log('push data'))
            _profiler.record("{}-push_0".format(self.name))
            self._in_channel.push(data, self.name)
            _profiler.record("{}-push_1".format(self.name))

            logging.debug(self._log('wait for infer'))
            _profiler.record("{}-fetch_0".format(self.name))
            resp_channeldata = self._get_data_in_globel_resp_dict(data_id)
            _profiler.record("{}-fetch_1".format(self.name))

            if resp_channeldata.ecode == ChannelDataEcode.OK.value:
                break
            if i + 1 < attempts:
                logging.warning("retry({}): {}".format(
                    i + 1, resp_channeldata.error_info))

        _profiler.record("{}-postpack_0".format(self.name))
        resp = self._pack_data_for_resp(resp_channeldata)
        _profiler.record("{}-postpack_1".format(self.name))
        _profiler.print_profile()
        return resp

B
barrierye 已提交
971 972

class PyServer(object):
B
barrierye 已提交
973
    def __init__(self, retry=2, profile=False):
        """Set up an empty pipeline server.

        Args:
            retry: maximum attempts per request, passed to
                GeneralPythonService in run_server.
            profile: whether to enable the module-level ``_profiler``.
        """
        # DAG state: filled by add_op(s) and _topo_sort
        self._user_ops, self._actual_ops, self._channels = [], [], []
        self._in_channel = self._out_channel = None
        # gRPC settings: filled by prepare_server
        self._port = self._worker_num = None
        self._retry = retry
        # manager handed to every Channel built in _topo_sort
        self._manager = multiprocessing.Manager()
        _profiler.enable(profile)
B
barrierye 已提交
984 985 986 987 988

    def add_channel(self, channel):
        """Append ``channel`` to the server's channel list.

        NOTE(review): _topo_sort replaces ``self._channels`` with the
        channels it builds, so channels added here appear to be discarded —
        confirm this method is still needed.
        """
        self._channels += [channel]

    def add_op(self, op):
        """Register one user Op; wiring into channels happens in _topo_sort."""
        self._user_ops += [op]

    def add_ops(self, ops):
        """Register every Op from the iterable ``ops``, keeping their order."""
        self._user_ops += ops
B
barrierye 已提交
993 994

    def gen_desc(self):
        """Placeholder for generating a service description for PaaS.

        Only logs a notice for now; returns None.
        """
        logging.info('here will generate desc for PAAS')

998 999 1000
    def _topo_sort(self):
        """Wire the user Ops into a levelled DAG connected by Channels.

        Steps:
          1. The first Op with no inputs is the read Op. It is renamed "#G",
             the name GeneralPythonService uses on the channels.
          2. Ops are grouped into levels (``dag_views``) with Kahn's
             algorithm. Each level holds the Ops whose predecessors all lie
             in earlier levels.
          3. Between consecutive levels, one Channel is created per distinct
             set of predecessor Ops. An edge that skips levels is bridged by
             a VirtualOp at each skipped level.
          4. A final Channel is attached to the single output Op.

        Side effects: sets ``self._actual_ops`` (VirtualOps first, then every
        user Op except the read Op) and ``self._channels``.

        Returns:
            (input_channel, output_channel)

        Raises:
            Exception: duplicate Op names, a cycle, or a DAG that does not
                have exactly one input Op and one output Op.
        """
        # 1. rename the read op so channel producer/consumer names match
        #    GeneralPythonService.name
        for op in self._user_ops:
            if len(op.get_input_ops()) == 0:
                op.name = "#G"  # update read_op.name
                break

        indeg_num = {}
        outdegs = {op.name: [] for op in self._user_ops}
        first_level = []
        for op in self._user_ops:
            # check the name of op is globally unique
            if op.name in indeg_num:
                raise Exception("the name of Op must be unique")
            indeg_num[op.name] = len(op.get_input_ops())
            if indeg_num[op.name] == 0:
                first_level.append(op)
            for pred_op in op.get_input_ops():
                outdegs[pred_op.name].append(op)

        # 2. topo sort, level by level, to get dag_views
        dag_views = []
        sorted_op_num = 0
        level = first_level
        while True:
            next_level = []
            for op in level:
                sorted_op_num += 1
                for succ_op in outdegs[op.name]:
                    indeg_num[succ_op.name] -= 1
                    if indeg_num[succ_op.name] == 0:
                        next_level.append(succ_op)
            dag_views.append(level)
            if not next_level:
                break
            level = next_level
        if sorted_op_num < len(self._user_ops):
            # some ops never reached in-degree 0: there is a cycle
            raise Exception("not legal DAG")
        if len(dag_views[0]) != 1:
            raise Exception("DAG contains multiple input Ops")
        if len(dag_views[-1]) != 1:
            raise Exception("DAG contains multiple output Ops")

        # 3. create channels and virtual ops
        def name_generator(prefix):
            idx = 0
            while True:
                yield "{}{}".format(prefix, idx)
                idx += 1

        # builtin next() works on both Python 2 and 3 (gen.next() is py2-only)
        virtual_op_name_gen = name_generator("vir")
        channel_name_gen = name_generator("chl")
        virtual_ops = []
        channels = []
        input_channel = None
        # ops (real or virtual) sitting at the current level
        actual_view = dag_views[0]
        for v_idx in range(len(dag_views) - 1):
            next_view = dag_views[v_idx + 1]
            actual_next_view = []
            pred_op_of_next_view_op = {}
            for op in actual_view:
                # find actual succ op in next view and create virtual op
                for succ_op in outdegs[op.name]:
                    if succ_op in next_view:
                        if succ_op not in actual_next_view:
                            actual_next_view.append(succ_op)
                        pred_op_of_next_view_op.setdefault(succ_op.name,
                                                           []).append(op)
                    else:
                        # succ_op lies further down: relay through a virtual op
                        virtual_op = VirtualOp(name=next(virtual_op_name_gen))
                        virtual_ops.append(virtual_op)
                        outdegs[virtual_op.name] = [succ_op]
                        actual_next_view.append(virtual_op)
                        pred_op_of_next_view_op[virtual_op.name] = [op]
                        virtual_op.add_virtual_pred_op(op)
            actual_view = actual_next_view

            # create channel
            processed_op = set()
            for o_idx, op in enumerate(actual_next_view):
                if op.name in processed_op:
                    continue
                channel = Channel(self._manager, name=next(channel_name_gen))
                channels.append(channel)
                logging.debug("{} => {}".format(channel.name, op.name))
                op.add_input_channel(channel)
                pred_ops = pred_op_of_next_view_op[op.name]
                if v_idx == 0:
                    # the read op's channel is fed by GeneralPythonService
                    input_channel = channel
                else:
                    # if pred_op is virtual op, it will use ancestors as producers to channel
                    for pred_op in pred_ops:
                        logging.debug("{} => {}".format(pred_op.name,
                                                        channel.name))
                        pred_op.add_output_channel(channel)
                processed_op.add(op.name)
                # find same input op to combine channel
                for other_op in actual_next_view[o_idx + 1:]:
                    if other_op.name in processed_op:
                        continue
                    other_pred_ops = pred_op_of_next_view_op[other_op.name]
                    if len(other_pred_ops) != len(pred_ops):
                        continue
                    if all(pred_op in other_pred_ops for pred_op in pred_ops):
                        logging.debug("{} => {}".format(channel.name,
                                                        other_op.name))
                        other_op.add_input_channel(channel)
                        processed_op.add(other_op.name)

        # 4. output channel for the single last op
        output_channel = Channel(self._manager, name=next(channel_name_gen))
        channels.append(output_channel)
        last_op = dag_views[-1][0]  # TODO: fix it
        last_op.add_output_channel(output_channel)

        # every op except the read op runs as a worker; virtual ops first
        self._actual_ops = virtual_ops + [
            op for op in self._user_ops if len(op.get_input_ops()) != 0
        ]
        self._channels = channels
        for c in channels:
            logging.debug(c.debug())
        return input_channel, output_channel

B
barrierye 已提交
1138 1139 1140
    def prepare_server(self, port, worker_num):
        """Record gRPC settings, build the DAG, and prepare serving ops.

        Args:
            port: port the gRPC server will listen on.
            worker_num: max_workers for the gRPC thread pool.
        """
        self._port, self._worker_num = port, worker_num
        self._in_channel, self._out_channel = self._topo_sort()
        for actual_op in self._actual_ops:
            if not actual_op.with_serving:
                continue
            self.prepare_serving(actual_op)
        self.gen_desc()

1150
    def _run_ops(self):
B
barrierye 已提交
1151
        proces = []
B
bug fix  
barrierye 已提交
1152
        for op in self._actual_ops:
B
barrierye 已提交
1153 1154
            proces.extend(op.start())
        return proces
1155

1156
    def _stop_ops(self):
B
bug fix  
barrierye 已提交
1157
        for op in self._actual_ops:
1158 1159
            op.stop()

1160
    def run_server(self):
        """Start the op workers and the gRPC front end; block until it ends.

        Ops are stopped, and their workers joined, even if setup fails (for
        example, the port cannot be bound) or waiting is interrupted (for
        example, KeyboardInterrupt). Previously either case leaked the
        workers.
        NOTE(review): Op.stop() is not visible here; if it only flips a flag
        in this (parent) process, the joins below may block — confirm.
        """
        op_proces = self._run_ops()
        try:
            server = grpc.server(
                futures.ThreadPoolExecutor(max_workers=self._worker_num))
            general_python_service_pb2_grpc.add_GeneralPythonServiceServicer_to_server(
                GeneralPythonService(self._in_channel, self._out_channel,
                                     self._retry), server)
            server.add_insecure_port('[::]:{}'.format(self._port))
            server.start()
            server.wait_for_termination()
        finally:
            self._stop_ops()  # TODO
            for p in op_proces:
                p.join()
B
barrierye 已提交
1173 1174 1175 1176 1177 1178 1179

    def prepare_serving(self, op):
        model_path = op._server_model
        port = op._server_port
        device = op._device

        if device == "cpu":
1180 1181
            cmd = "(Use MultiLangServer) python -m paddle_serving_server.serve" \
                  " --model {} --thread 4 --port {} --use_multilang &>/dev/null &".format(model_path, port)
B
barrierye 已提交
1182
        else:
1183 1184
            cmd = "(Use MultiLangServer) python -m paddle_serving_server_gpu.serve" \
                  " --model {} --thread 4 --port {} --use_multilang &>/dev/null &".format(model_path, port)
1185 1186
        # run a server (not in PyServing)
        logging.info("run a server (not in PyServing): {}".format(cmd))