pyserver.py 42.2 KB
Newer Older
B
barrierye 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
import threading
import multiprocessing
B
barrierye 已提交
17
import sys
B
barrierye 已提交
18 19 20 21 22 23 24
if sys.version_info.major == 2:
    import Queue
elif sys.version_info.major == 3:
    import queue as Queue
else:
    raise Exception("Error Python version")
import os
B
barrierye 已提交
25
import paddle_serving_server
26
from paddle_serving_client import MultiLangClient as Client
B
barrierye 已提交
27
from concurrent import futures
B
barrierye 已提交
28
import numpy as np
B
barrierye 已提交
29
import grpc
30 31 32 33
from .proto import general_model_config_pb2 as m_config
from .proto import general_python_service_pb2 as pyservice_pb2
from .proto import pyserving_channel_pb2 as channel_pb2
from .proto import general_python_service_pb2_grpc
B
barrierye 已提交
34
import logging
35
import random
B
barrierye 已提交
36
import time
B
barrierye 已提交
37
import func_timeout
38
import enum
39
import collections
B
barrierye 已提交
40 41


B
barrierye 已提交
42 43 44 45 46 47 48 49 50 51 52
class _TimeProfiler(object):
    """Collects timestamped records and dumps paired records to stderr.

    Records are stored in a thread-safe queue as ``(name, tag, timestamp)``
    tuples, where the timestamp is in microseconds since the epoch. The
    profiler is disabled by default; while disabled, ``record()`` and
    ``print_profile()`` do nothing.
    """

    def __init__(self):
        self._pid = os.getpid()
        self._print_head = 'PROFILE\tpid:{}\t'.format(self._pid)
        self._time_record = Queue.Queue()
        self._enable = False

    def enable(self, enable):
        """Turn profiling on or off.

        The argument is coerced to bool so that falsy values such as ``0``
        or ``None`` disable the profiler instead of silently enabling it.
        """
        self._enable = bool(enable)

    def record(self, name_with_tag):
        """Store the current time under ``name_with_tag``.

        ``name_with_tag`` is split at its last underscore: the part before
        it is the record name and the part after it is the tag. If there is
        no underscore, the name is empty and the whole string is the tag.
        """
        if not self._enable:
            return
        name, _, tag = name_with_tag.rpartition("_")
        timestamp = int(round(time.time() * 1000000))
        self._time_record.put((name, tag, timestamp))

    def print_profile(self):
        """Write all paired records to stderr on a single line.

        Two records with the same name form a pair and are printed together.
        Records that have no partner yet are put back into the queue so a
        later call can pair them.
        """
        if not self._enable:
            return
        sys.stderr.write(self._print_head)
        pending = {}  # {name: (tag, timestamp)} for records awaiting a partner
        while not self._time_record.empty():
            name, tag, timestamp = self._time_record.get()
            if name in pending:
                ptag, ptimestamp = pending.pop(name)
                sys.stderr.write("{}_{}:{} ".format(name, ptag, ptimestamp))
                sys.stderr.write("{}_{}:{} ".format(name, tag, timestamp))
            else:
                pending[name] = (tag, timestamp)
        sys.stderr.write('\n')
        for name, (tag, timestamp) in pending.items():
            self._time_record.put((name, tag, timestamp))


# Module-level profiler instance; the Op.start() loops in this file record
# their per-stage timestamps through it. Disabled until enable() is called.
_profiler = _TimeProfiler()


82 83 84
@enum.unique
class ChannelDataEcode(enum.Enum):
    """Status codes stored in the ``ecode`` field of a channel message.

    The integer ``.value`` of a member is what gets written into the
    protobuf message; ``OK`` marks a message that carries valid data.
    """
    OK = 0
    TIMEOUT = 1
    NOT_IMPLEMENTED = 2
    TYPE_ERROR = 3
    RPC_PACKAGE_ERROR = 4
    # NOTE: spelling kept as-is because the name is part of the public API.
    UNKNOW = 5
89 90 91 92 93


@enum.unique
class ChannelDataType(enum.Enum):
    """Kinds of payload a ChannelData object can carry.

    The integer ``.value`` of a member is what ChannelData stores in its
    ``datatype`` attribute.
    """
    CHANNEL_PBDATA = 0  # payload lives in the protobuf message
    CHANNEL_FUTURE = 1  # payload is obtained from a future's result()
    CHANNEL_NPDATA = 2  # payload is a dict kept as-is in memory
    ERROR = 3  # message only carries an error code and description
96 97 98 99


class ChannelData(object):
    """A unit of data passed between Ops through a Channel.

    Every instance owns a ``channel_pb2.ChannelData`` protobuf message
    (``self.pbdata``) that always carries the data id and an error code.
    Depending on ``datatype`` the actual payload is stored in the protobuf
    message, behind a future, or as a plain dict of numpy arrays.
    """

    def __init__(self,
                 datatype=None,
                 future=None,
                 pbdata=None,
                 npdata=None,
                 data_id=None,
                 callback_func=None,
                 ecode=None,
                 error_info=None):
        '''
        Supported ways to build an instance (use keyword arguments):

        1. ChannelData(datatype=CHANNEL_FUTURE, future=..., pbdata=...
                       [, callback_func=...])
        2. ChannelData(datatype=CHANNEL_FUTURE, future=..., data_id=...
                       [, callback_func=...])
        3. ChannelData(datatype=CHANNEL_PBDATA, pbdata=...)
        4. ChannelData(datatype=CHANNEL_PBDATA, npdata=..., data_id=...)
        5. ChannelData(datatype=CHANNEL_NPDATA, npdata=..., data_id=...)
        6. ChannelData(ecode=..., error_info=..., data_id=...)

        where ``datatype`` is the ``.value`` of a ChannelDataType member.
        When ``ecode`` is given, every other payload argument is ignored
        and the datatype becomes ChannelDataType.ERROR.

        Raises:
            ValueError: a required ``data_id``/``error_info`` is missing, or
                ``datatype`` is not one of the supported values.
            TypeError: the resulting ``pbdata`` is not a
                ``channel_pb2.ChannelData`` message.
        '''
        if ecode is not None:
            if data_id is None or error_info is None:
                raise ValueError("data_id and error_info cannot be None")
            pbdata = channel_pb2.ChannelData()
            pbdata.ecode = ecode
            pbdata.id = data_id
            pbdata.error_info = error_info
            datatype = ChannelDataType.ERROR.value
        elif datatype == ChannelDataType.CHANNEL_FUTURE.value:
            if pbdata is None:
                if data_id is None:
                    raise ValueError("data_id cannot be None")
                pbdata = channel_pb2.ChannelData()
                pbdata.ecode = ChannelDataEcode.OK.value
                pbdata.id = data_id
        elif datatype == ChannelDataType.CHANNEL_PBDATA.value:
            if pbdata is None:
                if data_id is None:
                    raise ValueError("data_id cannot be None")
                pbdata = channel_pb2.ChannelData()
                pbdata.id = data_id
                ecode, error_info = self._check_npdata(npdata)
                pbdata.ecode = ecode
                if pbdata.ecode != ChannelDataEcode.OK.value:
                    pbdata.error_info = error_info
                    logging.error(pbdata.error_info)
                else:
                    # Serialize every array into the protobuf message.
                    for name, value in npdata.items():
                        inst = channel_pb2.Inst()
                        inst.data = value.tobytes()
                        inst.name = name
                        inst.shape = np.array(
                            value.shape, dtype="int32").tobytes()
                        inst.type = str(value.dtype)
                        pbdata.insts.append(inst)
        elif datatype == ChannelDataType.CHANNEL_NPDATA.value:
            # Validate data_id up front, like the other branches do.
            if data_id is None:
                raise ValueError("data_id cannot be None")
            ecode, error_info = self._check_npdata(npdata)
            pbdata = channel_pb2.ChannelData()
            pbdata.id = data_id
            pbdata.ecode = ecode
            if pbdata.ecode != ChannelDataEcode.OK.value:
                pbdata.error_info = error_info
                logging.error(pbdata.error_info)
        else:
            raise ValueError("datatype not match")

        if not isinstance(pbdata, channel_pb2.ChannelData):
            raise TypeError(
                "pbdata must be pyserving_channel_pb2.ChannelData type({})".
                format(type(pbdata)))
        self.future = future
        self.pbdata = pbdata
        self.npdata = npdata
        self.datatype = datatype
        self.callback_func = callback_func

    def _check_npdata(self, npdata):
        """Validate that ``npdata`` is a dict of str -> np.ndarray.

        Returns:
            (ecode, error_info): ``ecode`` is ChannelDataEcode.OK.value and
            ``error_info`` is None when the data is valid; otherwise
            ``ecode`` is ChannelDataEcode.TYPE_ERROR.value and
            ``error_info`` describes the first offending entry.
        """
        # `unicode` only exists on Python 2; fall back to `str` on Python 3.
        try:
            string_types = (str, unicode)
        except NameError:
            string_types = (str, )

        ecode = ChannelDataEcode.OK.value
        error_info = None
        if not isinstance(npdata, dict):
            ecode = ChannelDataEcode.TYPE_ERROR.value
            error_info = "postped_data must be dict, but get {}".format(
                type(npdata))
            return ecode, error_info
        for name, value in npdata.items():
            if not isinstance(name, string_types):
                ecode = ChannelDataEcode.TYPE_ERROR.value
                error_info = "the key of postped_data must " \
                        "be str, but get {}".format(type(name))
                break
            if not isinstance(value, np.ndarray):
                ecode = ChannelDataEcode.TYPE_ERROR.value
                error_info = "the value of postped_data must " \
                        "be np.ndarray, but get {}".format(type(value))
                break
        return ecode, error_info

    def parse(self):
        """Return the payload in the form selected by ``self.datatype``.

        * CHANNEL_PBDATA: a dict rebuilt from the protobuf ``insts``
          (name -> np.ndarray with its recorded dtype and shape).
        * CHANNEL_FUTURE: the future's result, passed through
          ``callback_func`` when one was supplied.
        * CHANNEL_NPDATA: the stored ``npdata`` object itself.

        Raises:
            TypeError: ``self.datatype`` is none of the above (this includes
                ChannelDataType.ERROR).
        """
        feed = None
        if self.datatype == ChannelDataType.CHANNEL_PBDATA.value:
            feed = {}
            for inst in self.pbdata.insts:
                feed[inst.name] = np.frombuffer(inst.data, dtype=inst.type)
                feed[inst.name].shape = np.frombuffer(inst.shape, dtype="int32")
        elif self.datatype == ChannelDataType.CHANNEL_FUTURE.value:
            feed = self.future.result()
            if self.callback_func is not None:
                feed = self.callback_func(feed)
        elif self.datatype == ChannelDataType.CHANNEL_NPDATA.value:
            feed = self.npdata
        else:
            raise TypeError("Error type({}) in datatype.".format(
                self.datatype))
        return feed

    def __str__(self):
        return "type[{}], ecode[{}]".format(
            ChannelDataType(self.datatype).name, self.pbdata.ecode)
B
barrierye 已提交
211

212

B
barrierye 已提交
213
class Channel(Queue.Queue):
    """
    The channel used for communication between Ops.

    1. Support multiple different Op feed data (multiple producer)
        Different types of data will be packaged through the data ID
    2. Support multiple different Op fetch data (multiple consumer)
        Only when all types of Ops get the data of the same ID,
        the data will be poped; The Op of the same type will not
        get the data of the same ID.
    3. (TODO) Timeout and BatchSize are not fully supported.

    Note:
    1. The ID of the data in the channel must be different.
    2. The function add_producer() and add_consumer() are not thread safe,
       and can only be called during initialization.
    """

    def __init__(self, name=None, maxsize=-1, timeout=None):
        # Explicit base-class call (not super()) as in the original code.
        Queue.Queue.__init__(self, maxsize=maxsize)
        self._maxsize = maxsize
        self._timeout = timeout  # stored but not used by push()/front()
        self.name = name
        self._stop = False

        # Guards all bookkeeping below and is used to wait for queue changes.
        self._cv = threading.Condition()

        self._producers = []
        self._producer_res_count = {}  # {data_id: count}
        self._push_res = {}  # {data_id: {op_name: data}}

        self._consumers = {}  # {op_name: idx}
        self._idx_consumer_num = {}  # {idx: num}
        self._consumer_base_idx = 0
        self._front_res = []

    def get_producers(self):
        return self._producers

    def get_consumers(self):
        return self._consumers.keys()

    def _log(self, info_str):
        return "[{}] {}".format(self.name, info_str)

    def debug(self):
        return self._log("p: {}, c: {}".format(self.get_producers(),
                                               self.get_consumers()))

    def add_producer(self, op_name):
        """ not thread safe, and can only be called during initialization. """
        if op_name in self._producers:
            raise ValueError(
                self._log("producer({}) is already in channel".format(op_name)))
        self._producers.append(op_name)

    def add_consumer(self, op_name):
        """ not thread safe, and can only be called during initialization. """
        if op_name in self._consumers:
            raise ValueError(
                self._log("consumer({}) is already in channel".format(op_name)))
        self._consumers[op_name] = 0

        if self._idx_consumer_num.get(0) is None:
            self._idx_consumer_num[0] = 0
        self._idx_consumer_num[0] += 1

    def _put_or_wait(self, item):
        """Put ``item`` into the queue, waiting on the condition while full.

        Must be called with ``self._cv`` held. Gives up silently once the
        channel has been stopped.
        """
        while self._stop is False:
            try:
                self.put(item, timeout=0)
                return
            except Queue.Full:
                # put() signals a full queue with Queue.Full, not Queue.Empty.
                self._cv.wait()

    def push(self, channeldata, op_name=None):
        """Feed ``channeldata`` produced by ``op_name`` into the channel.

        With a single producer the data is queued directly. With several
        producers, data sharing the same ``pbdata.id`` is gathered into a
        ``{op_name: channeldata}`` dict and queued once the number of pushes
        for that id equals the number of producers.

        Returns:
            True.

        Raises:
            Exception: the channel has no producer, or it has several
                producers and ``op_name`` is None.
        """
        logging.debug(
            self._log("{} try to push data: {}".format(op_name,
                                                       str(channeldata))))
        producer_num = len(self._producers)
        if producer_num == 0:
            raise Exception(
                self._log(
                    "expected number of producers to be greater than 0, but the it is 0."
                ))
        if producer_num == 1:
            with self._cv:
                self._put_or_wait(channeldata)
                self._cv.notify_all()
            logging.debug(self._log("{} push data succ!".format(op_name)))
            return True
        if op_name is None:
            raise Exception(
                self._log(
                    "There are multiple producers, so op_name cannot be None."))

        data_id = channeldata.pbdata.id
        put_data = None
        with self._cv:
            logging.debug(self._log("{} get lock".format(op_name)))
            if data_id not in self._push_res:
                self._push_res[data_id] = {
                    name: None
                    for name in self._producers
                }
                self._producer_res_count[data_id] = 0
            self._push_res[data_id][op_name] = channeldata
            if self._producer_res_count[data_id] + 1 == producer_num:
                # Last missing part arrived: hand the whole group over.
                put_data = self._push_res.pop(data_id)
                self._producer_res_count.pop(data_id)
            else:
                self._producer_res_count[data_id] += 1

            if put_data is None:
                logging.debug(
                    self._log("{} push data succ, but not push to queue.".
                              format(op_name)))
            else:
                self._put_or_wait(put_data)
                logging.debug(
                    self._log("multi | {} push data succ!".format(op_name)))
            self._cv.notify_all()
        return True

    def front(self, op_name=None):
        """Fetch the next piece of data for consumer ``op_name``.

        With a single consumer the data is taken straight from the queue.
        With several consumers each of them receives every queued item
        once; an item is dropped from the internal buffer after all
        consumers have read it. The returned object is shared between
        consumers and must be treated as read only.

        Returns:
            The queued object, or None when the channel was stopped before
            data became available.

        Raises:
            Exception: the channel has no consumer, or it has several
                consumers and ``op_name`` is None.
        """
        logging.debug(self._log("{} try to get data".format(op_name)))
        consumer_num = len(self._consumers)
        if consumer_num == 0:
            raise Exception(
                self._log(
                    "expected number of consumers to be greater than 0, but the it is 0."
                ))
        if consumer_num == 1:
            resp = None
            with self._cv:
                while self._stop is False and resp is None:
                    try:
                        resp = self.get(timeout=0)
                        break
                    except Queue.Empty:
                        self._cv.wait()
            logging.debug(
                self._log("{} get data succ: {}".format(op_name, str(resp))))
            return resp
        if op_name is None:
            raise Exception(
                self._log(
                    "There are multiple consumers, so op_name cannot be None."))

        with self._cv:
            # data_idx = consumer_idx - base_idx
            while self._stop is False and self._consumers[
                    op_name] - self._consumer_base_idx >= len(self._front_res):
                try:
                    channeldata = self.get(timeout=0)
                    self._front_res.append(channeldata)
                    break
                except Queue.Empty:
                    self._cv.wait()

            consumer_idx = self._consumers[op_name]
            base_idx = self._consumer_base_idx
            data_idx = consumer_idx - base_idx
            if data_idx >= len(self._front_res):
                # Only reachable when the wait loop ended because of stop():
                # there is nothing to hand out, so leave the bookkeeping
                # untouched and report "no data" like the single-consumer path.
                logging.debug(
                    self._log("{} get no data, channel stopped".format(
                        op_name)))
                return None
            resp = self._front_res[data_idx]
            logging.debug(self._log("{} get data: {}".format(op_name, resp)))

            self._idx_consumer_num[consumer_idx] -= 1
            if consumer_idx == base_idx and self._idx_consumer_num[
                    consumer_idx] == 0:
                # Every consumer has read the oldest item: drop it.
                self._idx_consumer_num.pop(consumer_idx)
                self._front_res.pop(0)
                self._consumer_base_idx += 1

            self._consumers[op_name] += 1
            new_consumer_idx = self._consumers[op_name]
            if self._idx_consumer_num.get(new_consumer_idx) is None:
                self._idx_consumer_num[new_consumer_idx] = 0
            self._idx_consumer_num[new_consumer_idx] += 1

            self._cv.notify_all()

        logging.debug(self._log("multi | {} get data succ!".format(op_name)))
        return resp  # reference, read only

    def stop(self):
        """Mark the channel as stopped and wake up every waiting thread.

        Queue.Queue has no close() method, so stopping only consists of
        setting the flag. The notification is issued while holding the
        condition lock, as threading.Condition requires.
        """
        #TODO
        with self._cv:
            self._stop = True
            self._cv.notify_all()
406

B
barrierye 已提交
407 408 409

class Op(object):
    """A processing node that reads from one input channel and writes its
    result to every output channel.

    The work done for each piece of input is split into three overridable
    steps: preprocess() -> midprocess() -> postprocess(). midprocess() is
    only executed when the Op was configured with a serving client.
    """

    def __init__(self,
                 name,
                 inputs,
                 server_model=None,
                 server_port=None,
                 device=None,
                 client_config=None,
                 server_name=None,
                 fetch_names=None,
                 concurrency=1,
                 timeout=-1,
                 retry=2):
        self._run = False
        self.name = name  # to identify the type of OP, it must be globally unique
        self._concurrency = concurrency  # amount of concurrency
        self.set_input_ops(inputs)
        self.set_client(client_config, server_name, fetch_names)
        self._server_model = server_model
        self._server_port = server_port
        self._device = device
        self._timeout = timeout  # <= 0 disables the midprocess timeout
        self._retry = max(1, retry)  # midprocess is attempted at least once
        self._input = None
        self._outputs = []

    def set_client(self, client_config, server_name, fetch_names):
        """Create and connect the serving client.

        The client is only created when all three arguments are given;
        otherwise ``self._client`` stays None and the Op runs without
        the midprocess step.
        """
        self._client = None
        if client_config is None or \
                server_name is None or \
                fetch_names is None:
            return
        self._client = Client()
        self._client.load_client_config(client_config)
        self._client.connect([server_name])
        self._fetch_names = fetch_names

    def with_serving(self):
        return self._client is not None

    def get_input_channel(self):
        return self._input

    def get_input_ops(self):
        return self._input_ops

    def set_input_ops(self, ops):
        """Store the predecessor Ops.

        ``ops`` may be None, a single Op or a list of Ops.

        Raises:
            TypeError: an element is not an Op instance.
        """
        if not isinstance(ops, list):
            ops = [] if ops is None else [ops]
        self._input_ops = []
        for op in ops:
            if not isinstance(op, Op):
                raise TypeError(
                    self._log('input op must be Op type, not {}'.format(
                        type(op))))
            self._input_ops.append(op)

    def add_input_channel(self, channel):
        """Register this Op as a consumer of ``channel`` and use it as input."""
        if not isinstance(channel, Channel):
            raise TypeError(
                self._log('input channel must be Channel type, not {}'.format(
                    type(channel))))
        channel.add_consumer(self.name)
        self._input = channel

    def get_output_channels(self):
        return self._outputs

    def add_output_channel(self, channel):
        """Register this Op as a producer of ``channel`` and add it as output."""
        if not isinstance(channel, Channel):
            raise TypeError(
                self._log('output channel must be Channel type, not {}'.format(
                    type(channel))))
        channel.add_producer(self.name)
        self._outputs.append(channel)

    def preprocess(self, channeldata):
        """Turn the input into the data handed to the next step.

        The default implementation only supports a single predecessor and
        returns ``channeldata.parse()``.

        Raises:
            NotImplementedError: ``channeldata`` is a dict, i.e. the Op has
                several predecessors and must override this method.
        """
        if isinstance(channeldata, dict):
            raise NotImplementedError(
                'this Op has multiple previous inputs. Please override this method'
            )
        feed = channeldata.parse()
        return feed

    def midprocess(self, data):
        """Send ``data`` to the serving client and return the call future."""
        if not isinstance(data, dict):
            raise Exception(
                self._log(
                    'data must be dict type(the output of preprocess()), but get {}'.
                    format(type(data))))
        logging.debug(self._log('data: {}'.format(data)))
        logging.debug(self._log('fetch: {}'.format(self._fetch_names)))
        call_future = self._client.predict(
            feed=data, fetch=self._fetch_names, asyn=True)
        logging.debug(self._log("get call_future"))
        return call_future

    def postprocess(self, output_data):
        """Hook applied to the result; returns its argument unchanged."""
        return output_data

    def stop(self):
        """Stop the input/output channels and end the start() loop."""
        # The input channel is only set by add_input_channel().
        if self._input is not None:
            self._input.stop()
        for channel in self._outputs:
            channel.stop()
        self._run = False

    def _parse_channeldata(self, channeldata):
        """Extract the data id and the first error found in ``channeldata``.

        ``channeldata`` is either a single ChannelData or a dict of them
        (several predecessors).

        Returns:
            (data_id, error_pbdata): ``error_pbdata`` is the protobuf
            message of the first item whose ecode is not OK, or None.
        """
        data_id, error_pbdata = None, None
        if isinstance(channeldata, dict):
            for data in channeldata.values():
                if data_id is None:
                    # All items of a group share the same id; take the first.
                    data_id = data.pbdata.id
                if data.pbdata.ecode != ChannelDataEcode.OK.value:
                    error_pbdata = data.pbdata
                    break
        else:
            data_id = channeldata.pbdata.id
            if channeldata.pbdata.ecode != ChannelDataEcode.OK.value:
                error_pbdata = channeldata.pbdata
        return data_id, error_pbdata

    def _push_to_output_channels(self, data, name=None):
        if name is None:
            name = self.name
        for channel in self._outputs:
            channel.push(data, name)

    def _push_error(self, ecode, error_info, data_id):
        """Log ``error_info`` and push an error ChannelData to all outputs."""
        logging.error(error_info)
        self._push_to_output_channels(
            ChannelData(
                ecode=ecode, error_info=error_info, data_id=data_id))

    def _call_midprocess(self, preped_data, log):
        """Run midprocess(), honoring the configured timeout and retry count.

        Returns:
            (call_future, ecode, error_info): on success ``ecode`` is
            ChannelDataEcode.OK.value and ``error_info`` is None.
        """
        ok = ChannelDataEcode.OK.value
        if self._timeout <= 0:
            try:
                return self.midprocess(preped_data), ok, None
            except Exception as e:
                return None, ChannelDataEcode.UNKNOW.value, log(e)

        for i in range(self._retry):
            try:
                call_future = func_timeout.func_timeout(
                    self._timeout, self.midprocess, args=(preped_data, ))
            except func_timeout.FunctionTimedOut as e:
                if i + 1 >= self._retry:
                    return None, ChannelDataEcode.TIMEOUT.value, log(e)
                logging.warning(log("timeout, retry({})".format(i + 1)))
            except Exception as e:
                return None, ChannelDataEcode.UNKNOW.value, log(e)
            else:
                return call_future, ok, None
        return None, ok, None

    def start(self, concurrency_idx):
        """Run the processing loop until stop() is called.

        For every input: errors coming from a predecessor are forwarded
        unchanged; otherwise preprocess -> midprocess (only with a serving
        client) -> postprocess are executed and the result is pushed to all
        output channels. A failure in any step is pushed downstream as an
        error ChannelData carrying the same data id.
        """
        op_info_prefix = "[{}|{}]".format(self.name, concurrency_idx)
        log = self._get_log_func(op_info_prefix)
        self._run = True
        while self._run:
            _profiler.record("{}-get_0".format(op_info_prefix))
            channeldata = self._input.front(self.name)
            _profiler.record("{}-get_1".format(op_info_prefix))
            logging.debug(log("input_data: {}".format(channeldata)))

            if channeldata is None:
                # front() hands out None once the channel has been stopped;
                # there is nothing left to process.
                break

            data_id, error_pbdata = self._parse_channeldata(channeldata)

            # error data in predecessor Op
            if error_pbdata is not None:
                self._push_to_output_channels(
                    ChannelData(
                        datatype=ChannelDataType.CHANNEL_PBDATA.value,
                        pbdata=error_pbdata))
                continue

            # preprocess
            try:
                _profiler.record("{}-prep_0".format(op_info_prefix))
                preped_data = self.preprocess(channeldata)
                _profiler.record("{}-prep_1".format(op_info_prefix))
            except Exception as e:
                if isinstance(e, NotImplementedError):
                    # preprocess function not implemented
                    ecode = ChannelDataEcode.NOT_IMPLEMENTED.value
                elif isinstance(e, TypeError):
                    # Error type in channeldata.datatype
                    ecode = ChannelDataEcode.TYPE_ERROR.value
                else:
                    ecode = ChannelDataEcode.UNKNOW.value
                self._push_error(ecode, log(e), data_id)
                continue

            # midprocess
            call_future = None
            if self.with_serving():
                _profiler.record("{}-midp_0".format(op_info_prefix))
                call_future, ecode, error_info = self._call_midprocess(
                    preped_data, log)
                if ecode != ChannelDataEcode.OK.value:
                    self._push_error(ecode, error_info, data_id)
                    continue
                _profiler.record("{}-midp_1".format(op_info_prefix))

            # postprocess
            output_data = None
            _profiler.record("{}-postp_0".format(op_info_prefix))
            if self.with_serving():
                # use call_future; postprocess runs later as its callback
                output_data = ChannelData(
                    datatype=ChannelDataType.CHANNEL_FUTURE.value,
                    future=call_future,
                    data_id=data_id,
                    callback_func=self.postprocess)
            else:
                try:
                    postped_data = self.postprocess(preped_data)
                except Exception as e:
                    self._push_error(ChannelDataEcode.UNKNOW.value,
                                     log(e), data_id)
                    continue
                if not isinstance(postped_data, dict):
                    error_info = log("output of postprocess funticon must be " \
                            "dict type, but get {}".format(type(postped_data)))
                    self._push_error(ChannelDataEcode.TYPE_ERROR.value,
                                     error_info, data_id)
                    continue

                output_data = ChannelData(
                    ChannelDataType.CHANNEL_NPDATA.value,
                    npdata=postped_data,
                    data_id=data_id)
            _profiler.record("{}-postp_1".format(op_info_prefix))

            # push data to channel (if run succ)
            _profiler.record("{}-push_0".format(op_info_prefix))
            self._push_to_output_channels(output_data)
            _profiler.record("{}-push_1".format(op_info_prefix))

    def _log(self, info):
        return "{} {}".format(self.name, info)

    def _get_log_func(self, op_info_prefix):
        """Return a function that prefixes messages with ``op_info_prefix``."""

        def log_func(info_str):
            return "{} {}".format(op_info_prefix, info_str)

        return log_func

    def get_concurrency(self):
        return self._concurrency

B
barrierye 已提交
689

B
bug fix  
barrierye 已提交
690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709
class VirtualOp(Op):
    """For connecting two channels.

    A VirtualOp does no processing of its own: ``start`` forwards whatever
    it reads from its input channel to its output channels. The ops it
    stands in for are recorded through ``add_virtual_pred_op`` and are
    registered as the producers of every output channel.
    """

    def __init__(self, name, concurrency=1):
        """Create a pass-through op with no input ops.

        Args:
            name: op name, forwarded to ``Op.__init__``.
            concurrency: forwarded to ``Op.__init__``.
        """
        super(VirtualOp, self).__init__(
            name=name, inputs=None, concurrency=concurrency)
        self._virtual_pred_ops = []

    def add_virtual_pred_op(self, op):
        """Record ``op`` as one of the ops this virtual op stands in for."""
        self._virtual_pred_ops.append(op)

    def add_output_channel(self, channel):
        """Attach ``channel`` as an output of this op.

        Every recorded virtual predecessor op is registered on the channel
        as a producer (by name) before the channel is appended to
        ``self._outputs``.

        Raises:
            TypeError: if ``channel`` is not a ``Channel``.
        """
        if not isinstance(channel, Channel):
            raise TypeError(
                self._log('output channel must be Channel type, not {}'.format(
                    type(channel))))
        for op in self._virtual_pred_ops:
            channel.add_producer(op.name)
        self._outputs.append(channel)

    def start(self, concurrency_idx):
        """Forward data from the input channel until ``self._run`` is falsy.

        Args:
            concurrency_idx: index of this worker, used only in the
                profiler tags.
        """
        op_info_prefix = "[{}|{}]".format(self.name, concurrency_idx)
        # The profiler tags never change between iterations, so build them
        # once instead of re-formatting four strings per forwarded item.
        tag_get_0 = "{}-get_0".format(op_info_prefix)
        tag_get_1 = "{}-get_1".format(op_info_prefix)
        tag_push_0 = "{}-push_0".format(op_info_prefix)
        tag_push_1 = "{}-push_1".format(op_info_prefix)

        self._run = True
        while self._run:
            _profiler.record(tag_get_0)
            channeldata = self._input.front(self.name)
            _profiler.record(tag_get_1)

            _profiler.record(tag_push_0)
            if isinstance(channeldata, dict):
                # One entry per name: forward each under its own name.
                for name, data in channeldata.items():
                    self._push_to_output_channels(data, name=name)
            else:
                # Single item: forward it under the name of the first
                # recorded virtual predecessor op.
                self._push_to_output_channels(channeldata,
                                              self._virtual_pred_ops[0].name)
            _profiler.record(tag_push_1)


B
barrierye 已提交
729
class GeneralPythonService(
        general_python_service_pb2_grpc.GeneralPythonService):
    """gRPC servicer that feeds requests into the op pipeline.

    ``inference`` packs a request into a ``ChannelData``, pushes it into the
    input channel and blocks until the background receiver thread has stored
    the result with the same data id, read from the output channel.
    """

    def __init__(self, in_channel, out_channel, retry=2):
        """Bind the service to its channels and start the receiver thread.

        Args:
            in_channel: ``Channel`` that requests are pushed into.
            out_channel: ``Channel`` that results are read from.
            retry: number of attempts ``inference`` makes per request
                (at least one attempt is always made).

        Raises:
            TypeError: if either channel is not a ``Channel``.
        """
        super(GeneralPythonService, self).__init__()
        self.name = "#G"
        self.set_in_channel(in_channel)
        self.set_out_channel(out_channel)
        logging.debug(self._log(in_channel.debug()))
        logging.debug(self._log(out_channel.debug()))
        #TODO:
        #  multi-lock for different clients
        #  different lock for server and client
        self._id_lock = threading.Lock()
        self._cv = threading.Condition()
        # data id -> ChannelData, filled by the receiver thread.
        self._globel_resp_dict = {}
        self._id_counter = 0
        self._retry = retry
        self._recive_func = threading.Thread(
            target=GeneralPythonService._recive_out_channel_func, args=(self, ))
        self._recive_func.start()

    def _log(self, info_str):
        """Return ``info_str`` prefixed with ``[<service name>]``."""
        return "[{}] {}".format(self.name, info_str)

    def set_in_channel(self, in_channel):
        """Register this service as a producer of ``in_channel`` and keep it.

        Raises:
            TypeError: if ``in_channel`` is not a ``Channel``.
        """
        if not isinstance(in_channel, Channel):
            raise TypeError(
                self._log('in_channel must be Channel type, but get {}'.format(
                    type(in_channel))))
        in_channel.add_producer(self.name)
        self._in_channel = in_channel

    def set_out_channel(self, out_channel):
        """Register this service as a consumer of ``out_channel`` and keep it.

        Raises:
            TypeError: if ``out_channel`` is not a ``Channel``.
        """
        if not isinstance(out_channel, Channel):
            raise TypeError(
                self._log('out_channel must be Channel type, but get {}'.format(
                    type(out_channel))))
        out_channel.add_consumer(self.name)
        self._out_channel = out_channel

    def _recive_out_channel_func(self):
        """Receiver loop: move results from the output channel into the
        response dict and wake up every waiting ``inference`` call.

        Runs forever in the thread started by ``__init__``.

        Raises:
            TypeError: if the channel yields something that is not a
                ``ChannelData`` (this ends the receiver thread).
        """
        while True:
            channeldata = self._out_channel.front(self.name)
            if not isinstance(channeldata, ChannelData):
                raise TypeError(
                    self._log('data must be ChannelData type, but get {}'.
                              format(type(channeldata))))
            with self._cv:
                data_id = channeldata.pbdata.id
                self._globel_resp_dict[data_id] = channeldata
                self._cv.notify_all()

    def _get_next_id(self):
        """Return the next request id (0, 1, 2, ...) under ``_id_lock``."""
        with self._id_lock:
            self._id_counter += 1
            return self._id_counter - 1

    def _get_data_in_globel_resp_dict(self, data_id):
        """Block until a result for ``data_id`` is available, then remove
        it from the response dict and return it."""
        resp = None
        with self._cv:
            while data_id not in self._globel_resp_dict:
                self._cv.wait()
            resp = self._globel_resp_dict.pop(data_id)
            self._cv.notify_all()
        return resp

    def _pack_data_for_infer(self, request):
        """Convert an RPC request into a ``ChannelData`` for the pipeline.

        Returns:
            ``(channeldata, data_id)``. If copying the request fields fails,
            the returned data carries the ``RPC_PACKAGE_ERROR`` code instead
            of raising.
        """
        logging.debug(self._log('start inferce'))
        pbdata = channel_pb2.ChannelData()
        data_id = self._get_next_id()
        pbdata.id = data_id
        pbdata.ecode = ChannelDataEcode.OK.value
        try:
            for idx, name in enumerate(request.feed_var_names):
                logging.debug(self._log('name: {}'.format(name)))
                logging.debug(
                    self._log('data: {}'.format(request.feed_insts[idx])))
                inst = channel_pb2.Inst()
                inst.data = request.feed_insts[idx]
                inst.shape = request.shape[idx]
                inst.name = name
                inst.type = request.type[idx]
                pbdata.insts.append(inst)
        except Exception as e:
            # The failure is reported to the caller through the error code;
            # log the cause too so it is not lost.
            logging.error(self._log('rpc package error: {}'.format(e)))
            pbdata.ecode = ChannelDataEcode.RPC_PACKAGE_ERROR.value
            pbdata.error_info = "rpc package error"
        return ChannelData(
            datatype=ChannelDataType.CHANNEL_PBDATA.value,
            pbdata=pbdata), data_id

    def _pack_data_for_resp(self, channeldata):
        """Convert a pipeline result into an RPC ``Response``.

        On an OK error code the fetch fields are filled from the result;
        otherwise only ``ecode`` and ``error_info`` are set.

        Raises:
            TypeError: if the result's datatype is not PBDATA, FUTURE or
                NPDATA.
        """
        logging.debug(self._log('get channeldata'))
        resp = pyservice_pb2.Response()
        resp.ecode = channeldata.pbdata.ecode
        if resp.ecode == ChannelDataEcode.OK.value:
            if channeldata.datatype == ChannelDataType.CHANNEL_PBDATA.value:
                for inst in channeldata.pbdata.insts:
                    resp.fetch_insts.append(inst.data)
                    resp.fetch_var_names.append(inst.name)
                    resp.shape.append(inst.shape)
                    resp.type.append(inst.type)
            elif channeldata.datatype in (ChannelDataType.CHANNEL_FUTURE.value,
                                          ChannelDataType.CHANNEL_NPDATA.value):
                feed = channeldata.parse()
                for name, var in feed.items():
                    # Arrays travel as raw bytes plus shape and dtype.
                    resp.fetch_insts.append(var.tobytes())
                    resp.fetch_var_names.append(name)
                    resp.shape.append(
                        np.array(
                            var.shape, dtype="int32").tobytes())
                    resp.type.append(str(var.dtype))
            else:
                raise TypeError(
                    self._log("Error type({}) in datatype.".format(
                        channeldata.datatype)))
        else:
            resp.error_info = channeldata.pbdata.error_info
        return resp

    def inference(self, request, context):
        """RPC entry point: run ``request`` through the pipeline.

        The packed request is pushed and awaited up to ``retry`` times; the
        loop stops at the first result whose error code is OK, and the last
        result obtained is what gets packed into the response.
        """
        _profiler.record("{}-prepack_0".format(self.name))
        data, data_id = self._pack_data_for_infer(request)
        _profiler.record("{}-prepack_1".format(self.name))

        resp_channeldata = None
        # Always make at least one attempt: with retry <= 0 the loop body
        # would never run and resp_channeldata would still be None below.
        attempts = max(1, self._retry)
        for i in range(attempts):
            logging.debug(self._log('push data'))
            _profiler.record("{}-push_0".format(self.name))
            self._in_channel.push(data, self.name)
            _profiler.record("{}-push_1".format(self.name))

            logging.debug(self._log('wait for infer'))
            _profiler.record("{}-fetch_0".format(self.name))
            resp_channeldata = self._get_data_in_globel_resp_dict(data_id)
            _profiler.record("{}-fetch_1".format(self.name))

            if resp_channeldata.pbdata.ecode == ChannelDataEcode.OK.value:
                break
            if i + 1 < attempts:
                # logging.warn is a deprecated alias; use logging.warning.
                logging.warning("retry({}): {}".format(
                    i + 1, resp_channeldata.pbdata.error_info))

        _profiler.record("{}-postpack_0".format(self.name))
        resp = self._pack_data_for_resp(resp_channeldata)
        _profiler.record("{}-postpack_1".format(self.name))
        _profiler.print_profile()
        return resp

B
barrierye 已提交
878 879

class PyServer(object):
B
barrierye 已提交
880
    def __init__(self, retry=2, profile=False):
        """Create a server with no ops, channels or network settings yet.

        Args:
            retry: stored as ``self._retry``.
            profile: passed to ``_profiler.enable``.
        """
        self._retry = retry

        # Network settings, unset until prepare_server() stores them.
        self._port = None
        self._worker_num = None

        # Pipeline endpoints, unset until prepare_server() stores them.
        self._in_channel = None
        self._out_channel = None

        # Registered channels and ops, and the threads that run the ops.
        self._channels = []
        self._user_ops = []
        self._actual_ops = []
        self._op_threads = []

        _profiler.enable(profile)
B
barrierye 已提交
891 892 893 894 895

    def add_channel(self, channel):
        self._channels.append(channel)

    def add_op(self, op):
896 897 898
        self._user_ops.append(op)

    def add_ops(self, ops):
B
fix bug  
barrierye 已提交
899
        self._user_ops.extend(ops)
B
barrierye 已提交
900 901

    def gen_desc(self):
902
        logging.info('here will generate desc for PAAS')
B
barrierye 已提交
903 904
        pass

905 906 907
    def _topo_sort(self):
        indeg_num = {}
        que_idx = 0  # scroll queue 
B
fix bug  
barrierye 已提交
908
        ques = [Queue.Queue() for _ in range(2)]
B
bug fix  
barrierye 已提交
909 910 911 912 913
        for op in self._user_ops:
            if len(op.get_input_ops()) == 0:
                op.name = "#G"  # update read_op.name
                break
        outdegs = {op.name: [] for op in self._user_ops}
914 915 916 917 918 919 920 921
        for idx, op in enumerate(self._user_ops):
            # check the name of op is globally unique
            if op.name in indeg_num:
                raise Exception("the name of Op must be unique")
            indeg_num[op.name] = len(op.get_input_ops())
            if indeg_num[op.name] == 0:
                ques[que_idx].put(op)
            for pred_op in op.get_input_ops():
B
fix bug  
barrierye 已提交
922
                outdegs[pred_op.name].append(op)
923

B
bug fix  
barrierye 已提交
924
        # topo sort to get dag_views
925 926 927 928 929 930 931 932 933 934
        dag_views = []
        sorted_op_num = 0
        while True:
            que = ques[que_idx]
            next_que = ques[(que_idx + 1) % 2]
            dag_view = []
            while que.qsize() != 0:
                op = que.get()
                dag_view.append(op)
                sorted_op_num += 1
B
bug fix  
barrierye 已提交
935
                for succ_op in outdegs[op.name]:
B
fix bug  
barrierye 已提交
936
                    indeg_num[succ_op.name] -= 1
937 938 939 940 941 942 943 944
                    if indeg_num[succ_op.name] == 0:
                        next_que.put(succ_op)
            dag_views.append(dag_view)
            if next_que.qsize() == 0:
                break
            que_idx = (que_idx + 1) % 2
        if sorted_op_num < len(self._user_ops):
            raise Exception("not legal DAG")
B
barrierye 已提交
945 946 947 948
        if len(dag_views[0]) != 1:
            raise Exception("DAG contains multiple input Ops")
        if len(dag_views[-1]) != 1:
            raise Exception("DAG contains multiple output Ops")
949 950

        # create channels and virtual ops
B
bug fix  
barrierye 已提交
951 952 953 954 955 956 957 958 959 960 961
        def name_generator(prefix):
            def number_generator():
                idx = 0
                while True:
                    yield "{}{}".format(prefix, idx)
                    idx += 1

            return number_generator()

        virtual_op_name_gen = name_generator("vir")
        channel_name_gen = name_generator("chl")
962 963 964
        virtual_ops = []
        channels = []
        input_channel = None
B
bug fix  
barrierye 已提交
965
        actual_view = None
966 967 968 969
        for v_idx, view in enumerate(dag_views):
            if v_idx + 1 >= len(dag_views):
                break
            next_view = dag_views[v_idx + 1]
B
bug fix  
barrierye 已提交
970 971
            if actual_view is None:
                actual_view = view
972 973
            actual_next_view = []
            pred_op_of_next_view_op = {}
B
bug fix  
barrierye 已提交
974 975
            for op in actual_view:
                # find actual succ op in next view and create virtual op
976 977
                for succ_op in outdegs[op.name]:
                    if succ_op in next_view:
B
bug fix  
barrierye 已提交
978 979
                        if succ_op not in actual_next_view:
                            actual_next_view.append(succ_op)
980 981 982 983
                        if succ_op.name not in pred_op_of_next_view_op:
                            pred_op_of_next_view_op[succ_op.name] = []
                        pred_op_of_next_view_op[succ_op.name].append(op)
                    else:
B
bug fix  
barrierye 已提交
984 985
                        # create virtual op
                        virtual_op = None
B
barrierye 已提交
986
                        virtual_op = VirtualOp(name=virtual_op_name_gen.next())
987
                        virtual_ops.append(virtual_op)
B
bug fix  
barrierye 已提交
988 989 990 991 992
                        outdegs[virtual_op.name] = [succ_op]
                        actual_next_view.append(virtual_op)
                        pred_op_of_next_view_op[virtual_op.name] = [op]
                        virtual_op.add_virtual_pred_op(op)
            actual_view = actual_next_view
993 994 995
            # create channel
            processed_op = set()
            for o_idx, op in enumerate(actual_next_view):
B
bug fix  
barrierye 已提交
996
                if op.name in processed_op:
997
                    continue
B
barrierye 已提交
998
                channel = Channel(name=channel_name_gen.next())
999
                channels.append(channel)
B
bug fix  
barrierye 已提交
1000
                logging.debug("{} => {}".format(channel.name, op.name))
1001
                op.add_input_channel(channel)
B
bug fix  
barrierye 已提交
1002
                pred_ops = pred_op_of_next_view_op[op.name]
1003 1004 1005
                if v_idx == 0:
                    input_channel = channel
                else:
B
bug fix  
barrierye 已提交
1006
                    # if pred_op is virtual op, it will use ancestors as producers to channel
1007
                    for pred_op in pred_ops:
B
bug fix  
barrierye 已提交
1008 1009
                        logging.debug("{} => {}".format(pred_op.name,
                                                        channel.name))
1010
                        pred_op.add_output_channel(channel)
B
bug fix  
barrierye 已提交
1011 1012 1013
                processed_op.add(op.name)
                # find same input op to combine channel
                for other_op in actual_next_view[o_idx + 1:]:
1014 1015 1016 1017 1018 1019 1020 1021 1022 1023 1024
                    if other_op.name in processed_op:
                        continue
                    other_pred_ops = pred_op_of_next_view_op[other_op.name]
                    if len(other_pred_ops) != len(pred_ops):
                        continue
                    same_flag = True
                    for pred_op in pred_ops:
                        if pred_op not in other_pred_ops:
                            same_flag = False
                            break
                    if same_flag:
B
bug fix  
barrierye 已提交
1025 1026
                        logging.debug("{} => {}".format(channel.name,
                                                        other_op.name))
1027 1028
                        other_op.add_input_channel(channel)
                        processed_op.add(other_op.name)
B
barrierye 已提交
1029
        output_channel = Channel(name=channel_name_gen.next())
1030
        channels.append(output_channel)
B
bug fix  
barrierye 已提交
1031
        last_op = dag_views[-1][0]
1032 1033
        last_op.add_output_channel(output_channel)

B
bug fix  
barrierye 已提交
1034
        self._actual_ops = virtual_ops
B
fix bug  
barrierye 已提交
1035 1036
        for op in self._user_ops:
            if len(op.get_input_ops()) == 0:
B
bug fix  
barrierye 已提交
1037
                # pass read op
B
fix bug  
barrierye 已提交
1038
                continue
B
bug fix  
barrierye 已提交
1039
            self._actual_ops.append(op)
1040
        self._channels = channels
B
bug fix  
barrierye 已提交
1041 1042
        for c in channels:
            logging.debug(c.debug())
1043 1044
        return input_channel, output_channel

B
barrierye 已提交
1045 1046 1047
    def prepare_server(self, port, worker_num):
        self._port = port
        self._worker_num = worker_num
1048 1049 1050

        input_channel, output_channel = self._topo_sort()
        self._in_channel = input_channel
B
fix bug  
barrierye 已提交
1051
        self._out_channel = output_channel
B
bug fix  
barrierye 已提交
1052
        for op in self._actual_ops:
B
barrierye 已提交
1053
            if op.with_serving():
B
fix bug  
barrierye 已提交
1054
                self.prepare_serving(op)
B
barrierye 已提交
1055 1056
        self.gen_desc()

B
barrierye 已提交
1057 1058 1059
    def _op_start_wrapper(self, op, concurrency_idx):
        return op.start(concurrency_idx)

1060
    def _run_ops(self):
B
bug fix  
barrierye 已提交
1061
        for op in self._actual_ops:
B
barrierye 已提交
1062 1063 1064 1065 1066 1067 1068 1069
            op_concurrency = op.get_concurrency()
            logging.debug("run op: {}, op_concurrency: {}".format(
                op.name, op_concurrency))
            for c in range(op_concurrency):
                th = threading.Thread(
                    target=self._op_start_wrapper, args=(op, c))
                th.start()
                self._op_threads.append(th)
1070

1071
    def _stop_ops(self):
B
bug fix  
barrierye 已提交
1072
        for op in self._actual_ops:
1073 1074
            op.stop()

1075
    def run_server(self):
        """Start the op threads and serve gRPC requests until termination.

        A ``GeneralPythonService`` bound to the server's input/output
        channels is registered on a gRPC server listening on
        ``[::]:<port>``. When the server terminates, the ops are stopped
        and their threads joined.
        """
        self._run_ops()

        executor = futures.ThreadPoolExecutor(max_workers=self._worker_num)
        server = grpc.server(executor)
        servicer = GeneralPythonService(self._in_channel, self._out_channel,
                                        self._retry)
        general_python_service_pb2_grpc.add_GeneralPythonServiceServicer_to_server(
            servicer, server)

        listen_addr = '[::]:{}'.format(self._port)
        server.add_insecure_port(listen_addr)
        server.start()
        server.wait_for_termination()

        self._stop_ops()  # TODO
        for op_thread in self._op_threads:
            op_thread.join()
B
barrierye 已提交
1088 1089 1090 1091 1092 1093 1094

    def prepare_serving(self, op):
        model_path = op._server_model
        port = op._server_port
        device = op._device

        if device == "cpu":
1095 1096
            cmd = "(Use MultiLangServer) python -m paddle_serving_server.serve" \
                  " --model {} --thread 4 --port {} --use_multilang &>/dev/null &".format(model_path, port)
B
barrierye 已提交
1097
        else:
1098 1099
            cmd = "(Use MultiLangServer) python -m paddle_serving_server_gpu.serve" \
                  " --model {} --thread 4 --port {} --use_multilang &>/dev/null &".format(model_path, port)
1100 1101
        # run a server (not in PyServing)
        logging.info("run a server (not in PyServing): {}".format(cmd))