pyserver.py 29.0 KB
Newer Older
B
barrierye 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
import threading
import multiprocessing
B
barrierye 已提交
17
import Queue
B
barrierye 已提交
18
import os
B
barrierye 已提交
19
import sys
B
barrierye 已提交
20
import paddle_serving_server
21
from paddle_serving_client import MultiLangClient as Client
B
barrierye 已提交
22
from concurrent import futures
B
barrierye 已提交
23
import numpy as np
B
barrierye 已提交
24
import grpc
25 26 27 28
from .proto import general_model_config_pb2 as m_config
from .proto import general_python_service_pb2 as pyservice_pb2
from .proto import pyserving_channel_pb2 as channel_pb2
from .proto import general_python_service_pb2_grpc
B
barrierye 已提交
29
import logging
30
import random
B
barrierye 已提交
31
import time
B
barrierye 已提交
32
import func_timeout
33
import enum
B
barrierye 已提交
34 35


B
barrierye 已提交
36 37 38 39 40 41 42 43 44 45 46
class _TimeProfiler(object):
    def __init__(self):
        self._pid = os.getpid()
        self._print_head = 'PROFILE\tpid:{}\t'.format(self._pid)
        self._time_record = Queue.Queue()
        self._enable = False

    def enable(self, enable):
        self._enable = enable

    def record(self, name_with_tag):
B
bug fix  
barrierye 已提交
47 48
        if self._enable is False:
            return
B
barrierye 已提交
49 50 51 52 53 54
        name_with_tag = name_with_tag.split("_")
        tag = name_with_tag[-1]
        name = '_'.join(name_with_tag[:-1])
        self._time_record.put((name, tag, int(round(time.time() * 1000000))))

    def print_profile(self):
B
bug fix  
barrierye 已提交
55 56
        if self._enable is False:
            return
B
barrierye 已提交
57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75
        sys.stderr.write(self._print_head)
        tmp = {}
        while not self._time_record.empty():
            name, tag, timestamp = self._time_record.get()
            if name in tmp:
                ptag, ptimestamp = tmp.pop(name)
                sys.stderr.write("{}_{}:{} ".format(name, ptag, ptimestamp))
                sys.stderr.write("{}_{}:{} ".format(name, tag, timestamp))
            else:
                tmp[name] = (tag, timestamp)
        sys.stderr.write('\n')
        for name, item in tmp.items():
            tag, timestamp = item
            self._time_record.put((name, tag, timestamp))


_profiler = _TimeProfiler()


76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120
@enum.unique
class ChannelDataEcode(enum.Enum):
    """Status codes stored in the ``ecode`` field of a channel protobuf."""
    # Data was produced without error.
    OK = 0
    # Non-OK status (named after the timeout case).
    TIMEOUT = 1


@enum.unique
class ChannelDataType(enum.Enum):
    """How a ChannelData carries its payload (stored in ``pbdata.type``)."""
    # Payload is serialized inside the protobuf's ``insts``.
    CHANNEL_PBDATA = 0
    # Payload is obtained from a future's ``result()``.
    CHANNEL_FUTURE = 1


class ChannelData(object):
    """A unit of data travelling through a Channel.

    It either wraps a protobuf ``pbdata`` directly, or wraps a ``future``
    (in which case a CHANNEL_FUTURE protobuf carrying only the id and an OK
    ecode is created). ``callback_func``, when given, is applied to the
    future's result in ``parse()``.
    """

    def __init__(self,
                 future=None,
                 pbdata=None,
                 data_id=None,
                 callback_func=None):
        """
        Raises:
            ValueError: if neither ``pbdata`` nor ``data_id`` is given.
        """
        self.future = future
        if pbdata is None:
            if data_id is None:
                raise ValueError("data_id cannot be None")
            pbdata = channel_pb2.ChannelData()
            pbdata.type = ChannelDataType.CHANNEL_FUTURE.value
            pbdata.ecode = ChannelDataEcode.OK.value
            pbdata.id = data_id
        self.pbdata = pbdata
        self.callback_func = callback_func

    def parse(self):
        """Return the payload as a dict.

        For CHANNEL_PBDATA each inst becomes a numpy array built from its raw
        bytes, dtype string and int32-encoded shape. For CHANNEL_FUTURE this
        blocks on ``future.result()`` and applies ``callback_func`` if set.

        Raises:
            TypeError: if ``pbdata.type`` is not a known ChannelDataType.
        """
        data_type = self.pbdata.type
        if data_type == ChannelDataType.CHANNEL_PBDATA.value:
            feed = {}
            for inst in self.pbdata.insts:
                value = np.frombuffer(inst.data, dtype=inst.type)
                value.shape = np.frombuffer(inst.shape, dtype="int32")
                feed[inst.name] = value
            return feed
        if data_type == ChannelDataType.CHANNEL_FUTURE.value:
            feed = self.future.result()
            if self.callback_func is not None:
                feed = self.callback_func(feed)
            return feed
        # ChannelData has no _log helper; building the message through
        # self._log raised AttributeError instead of this TypeError.
        raise TypeError("Error type({}) in pbdata.type.".format(data_type))


B
barrierye 已提交
121
class Channel(Queue.Queue):
    """
    The channel used for communication between Ops.

    1. Support multiple different Op feed data (multiple producer)
        Different types of data will be packaged through the data ID
    2. Support multiple different Op fetch data (multiple consumer)
        Only when all types of Ops get the data of the same ID,
        the data will be popped; The Op of the same type will not
        get the data of the same ID.
    3. (TODO) Timeout and BatchSize are not fully supported.

    Note:
    1. The ID of the data in the channel must be different.
    2. The function add_producer() and add_consumer() are not thread safe,
       and can only be called during initialization.
    3. stop() wakes every blocked push()/front(); afterwards front() returns
       None when no data is available for the caller.
    """

    def __init__(self, name=None, maxsize=-1, timeout=None):
        # Queue.Queue is an old-style class on Python 2, so its initializer
        # is called directly rather than through super().
        Queue.Queue.__init__(self, maxsize=maxsize)
        self._maxsize = maxsize
        self._timeout = timeout
        self._name = name
        self._stop = False

        # One condition guards all bookkeeping below and wakes both blocked
        # producers (queue full) and blocked consumers (queue empty).
        self._cv = threading.Condition()

        self._producers = []
        self._producer_res_count = {}  # {data_id: count}
        self._push_res = {}  # {data_id: {op_name: data}}

        self._consumers = {}  # {op_name: idx}
        self._idx_consumer_num = {}  # {idx: num}
        self._consumer_base_idx = 0
        self._front_res = []

    def get_producers(self):
        return self._producers

    def get_consumers(self):
        return self._consumers.keys()

    def _log(self, info_str):
        return "[{}] {}".format(self._name, info_str)

    def debug(self):
        return self._log("p: {}, c: {}".format(self.get_producers(),
                                               self.get_consumers()))

    def add_producer(self, op_name):
        """ not thread safe, and can only be called during initialization. """
        if op_name in self._producers:
            raise ValueError(
                self._log("producer({}) is already in channel".format(op_name)))
        self._producers.append(op_name)

    def add_consumer(self, op_name):
        """ not thread safe, and can only be called during initialization. """
        if op_name in self._consumers:
            raise ValueError(
                self._log("consumer({}) is already in channel".format(op_name)))
        self._consumers[op_name] = 0

        if self._idx_consumer_num.get(0) is None:
            self._idx_consumer_num[0] = 0
        self._idx_consumer_num[0] += 1

    def _put_until_stopped(self, data):
        """Put ``data`` into the queue, waiting on the condition while full.

        Must be called with ``self._cv`` held. Gives up silently once the
        channel is stopped.
        """
        while self._stop is False:
            try:
                self.put(data, timeout=0)
                break
            except Queue.Full:
                # put() signals a full queue with Full (never Empty); wait
                # for a consumer to free a slot and notify.
                self._cv.wait()

    def push(self, channeldata, op_name=None):
        """Push data produced by ``op_name``.

        With several producers, the data of every producer for the same
        ``pbdata.id`` is collected and enqueued as one
        ``{op_name: channeldata}`` dict once all of them have pushed.
        """
        logging.debug(
            self._log("{} try to push data: {}".format(op_name,
                                                       channeldata.pbdata)))
        if len(self._producers) == 0:
            raise Exception(
                self._log(
                    "expected number of producers to be greater than 0, but the it is 0."
                ))
        elif len(self._producers) == 1:
            with self._cv:
                self._put_until_stopped(channeldata)
                self._cv.notify_all()
            logging.debug(self._log("{} push data succ!".format(op_name)))
            return True
        elif op_name is None:
            raise Exception(
                self._log(
                    "There are multiple producers, so op_name cannot be None."))

        producer_num = len(self._producers)
        data_id = channeldata.pbdata.id
        put_data = None
        with self._cv:
            logging.debug(self._log("{} get lock".format(op_name)))
            if data_id not in self._push_res:
                self._push_res[data_id] = {
                    name: None
                    for name in self._producers
                }
                self._producer_res_count[data_id] = 0
            self._push_res[data_id][op_name] = channeldata
            if self._producer_res_count[data_id] + 1 == producer_num:
                put_data = self._push_res[data_id]
                self._push_res.pop(data_id)
                self._producer_res_count.pop(data_id)
            else:
                self._producer_res_count[data_id] += 1

            if put_data is None:
                logging.debug(
                    self._log("{} push data succ, but not push to queue.".
                              format(op_name)))
            else:
                self._put_until_stopped(put_data)
                logging.debug(
                    self._log("multi | {} push data succ!".format(op_name)))
            self._cv.notify_all()
        return True

    def front(self, op_name=None):
        """Return the next item for consumer ``op_name``, blocking if needed.

        Every consumer sees every item exactly once; an item is dropped from
        the internal buffer after all consumers have read it. Returns None
        when the channel is stopped and nothing is available.
        """
        logging.debug(self._log("{} try to get data".format(op_name)))
        if len(self._consumers) == 0:
            raise Exception(
                self._log(
                    "expected number of consumers to be greater than 0, but the it is 0."
                ))
        elif len(self._consumers) == 1:
            resp = None
            with self._cv:
                while self._stop is False and resp is None:
                    try:
                        resp = self.get(timeout=0)
                        break
                    except Queue.Empty:
                        self._cv.wait()
                # Wake producers that may be waiting for a free slot.
                self._cv.notify_all()
            logging.debug(self._log("{} get data succ!".format(op_name)))
            return resp
        elif op_name is None:
            raise Exception(
                self._log(
                    "There are multiple consumers, so op_name cannot be None."))

        with self._cv:
            # data_idx = consumer_idx - base_idx
            while self._stop is False and self._consumers[
                    op_name] - self._consumer_base_idx >= len(self._front_res):
                try:
                    channeldata = self.get(timeout=0)
                    self._front_res.append(channeldata)
                    break
                except Queue.Empty:
                    self._cv.wait()

            consumer_idx = self._consumers[op_name]
            base_idx = self._consumer_base_idx
            data_idx = consumer_idx - base_idx
            if data_idx >= len(self._front_res):
                # Woken by stop() with nothing buffered for this consumer.
                return None
            resp = self._front_res[data_idx]
            logging.debug(self._log("{} get data: {}".format(op_name, resp)))

            # Drop the oldest buffered item once no consumer still needs it.
            self._idx_consumer_num[consumer_idx] -= 1
            if consumer_idx == base_idx and self._idx_consumer_num[
                    consumer_idx] == 0:
                self._idx_consumer_num.pop(consumer_idx)
                self._front_res.pop(0)
                self._consumer_base_idx += 1

            self._consumers[op_name] += 1
            new_consumer_idx = self._consumers[op_name]
            if self._idx_consumer_num.get(new_consumer_idx) is None:
                self._idx_consumer_num[new_consumer_idx] = 0
            self._idx_consumer_num[new_consumer_idx] += 1

            self._cv.notify_all()

        logging.debug(self._log("multi | {} get data succ!".format(op_name)))
        return resp  # reference, read only

    def stop(self):
        """Mark the channel stopped and wake every blocked push()/front().

        Queue.Queue has no close() method, so the previous call to it raised
        AttributeError before the flag was set.
        """
        with self._cv:
            self._stop = True
            self._cv.notify_all()

B
barrierye 已提交
312 313 314

class Op(object):
    """A pipeline stage between Channels.

    Each worker thread (see ``start``) reads from the input Channel, runs
    ``preprocess`` -> optional remote ``midprocess`` (when a client is
    configured) -> ``postprocess``, and pushes the result to every output
    Channel. Errors reported by upstream stages are forwarded unchanged.
    """

    def __init__(self,
                 name,
                 input,
                 outputs,
                 server_model=None,
                 server_port=None,
                 device=None,
                 client_config=None,
                 server_name=None,
                 fetch_names=None,
                 concurrency=1,
                 timeout=-1,
                 retry=2):
        self._run = False
        # TODO: globally unique check
        self._name = name  # to identify the type of OP, it must be globally unique
        self._concurrency = concurrency  # amount of concurrency
        self.set_input(input)
        self.set_outputs(outputs)
        self._client = None
        if client_config is not None and \
                server_name is not None and \
                fetch_names is not None:
            self.set_client(client_config, server_name, fetch_names)
        self._server_model = server_model
        self._server_port = server_port
        self._device = device
        self._timeout = timeout  # seconds for midprocess; <= 0 disables it
        self._retry = retry  # attempts for midprocess

    def set_client(self, client_config, server_name, fetch_names):
        self._client = Client()
        self._client.load_client_config(client_config)
        self._client.connect([server_name])
        self._fetch_names = fetch_names

    def with_serving(self):
        """True when a serving client is configured for midprocess()."""
        return self._client is not None

    def get_input(self):
        return self._input

    def set_input(self, channel):
        if not isinstance(channel, Channel):
            raise TypeError(
                self._log('input channel must be Channel type, not {}'.format(
                    type(channel))))
        channel.add_consumer(self._name)
        self._input = channel

    def get_outputs(self):
        return self._outputs

    def set_outputs(self, channels):
        if not isinstance(channels, list):
            raise TypeError(
                self._log('output channels must be list type, not {}'.format(
                    type(channels))))
        for channel in channels:
            channel.add_producer(self._name)
        self._outputs = channels

    def preprocess(self, channeldata):
        """Default preprocess: parse a single ChannelData into a feed dict.

        Ops with several previous inputs receive a dict and must override it.
        """
        if isinstance(channeldata, dict):
            raise Exception(
                self._log(
                    'this Op has multiple previous inputs. Please override this method'
                ))
        feed = channeldata.parse()
        return feed

    def midprocess(self, data):
        """Send ``data`` to the serving client asynchronously; return the future."""
        if not isinstance(data, dict):
            raise Exception(
                self._log(
                    'data must be dict type(the output of preprocess()), but get {}'.
                    format(type(data))))
        logging.debug(self._log('data: {}'.format(data)))
        logging.debug(self._log('fetch: {}'.format(self._fetch_names)))
        call_future = self._client.predict(
            feed=data, fetch=self._fetch_names, asyn=True)
        logging.debug(self._log("get call_future"))
        return call_future

    def postprocess(self, output_data):
        """Default postprocess: identity."""
        return output_data

    def errorprocess(self, error_info, data_id):
        """Build an error protobuf for ``data_id`` carrying ``error_info``."""
        data = channel_pb2.ChannelData()
        data.ecode = 1  # any non-zero ecode marks the data as failed
        data.id = data_id
        data.error_info = error_info
        return data

    def stop(self):
        self._input.stop()
        for channel in self._outputs:
            channel.stop()
        self._run = False

    def _parse_channeldata(self, channeldata):
        """Return ``(data_id, error_pbdata)`` for a ChannelData or a dict of them.

        ``error_pbdata`` is the protobuf of the first input whose ecode is
        non-zero, or None. It is always a protobuf (never a ChannelData) so
        that the caller can wrap it with ``ChannelData(pbdata=...)``.
        """
        data_id, error_data = None, None
        if isinstance(channeldata, dict):
            # All entries share the same id (the Channel groups by id).
            # next(iter(...)) works on both Python 2 and 3, unlike keys()[0].
            data_id = next(iter(channeldata.values())).pbdata.id
            for data in channeldata.values():
                if data.pbdata.ecode != 0:
                    error_data = data.pbdata
                    break
        else:
            data_id = channeldata.pbdata.id
            if channeldata.pbdata.ecode != 0:
                error_data = channeldata.pbdata
        return data_id, error_data

    def _midprocess_with_retry(self, data, concurrency_idx):
        """Run midprocess() up to ``self._retry`` times (at least once).

        Stops at the first success so the prediction is not issued twice.

        Returns:
            (call_future, error_info): ``error_info`` is None on success,
            otherwise it describes the failure of the last attempt.
        """
        call_future, error_info = None, None
        for i in range(max(self._retry, 1)):
            _profiler.record("{}{}-midp_0".format(self._name,
                                                  concurrency_idx))
            if self._timeout > 0:
                try:
                    call_future = func_timeout.func_timeout(
                        self._timeout, self.midprocess, args=(data, ))
                except func_timeout.FunctionTimedOut:
                    logging.error("error: timeout")
                    error_info = "{}({}): timeout".format(self._name,
                                                          concurrency_idx)
                except Exception as e:
                    logging.error("error: {}".format(e))
                    error_info = "{}({}): {}".format(self._name,
                                                     concurrency_idx, e)
            else:
                call_future = self.midprocess(data)
            _profiler.record("{}{}-midp_1".format(self._name,
                                                  concurrency_idx))
            if error_info is None:
                break
            if i + 1 < self._retry:
                # Clear the error only when another attempt will follow, so
                # the last attempt's error is reported.
                error_info = None
                logging.warning(
                    self._log("warn: timeout, retry({})".format(i + 1)))
        return call_future, error_info

    def _postprocess_to_channeldata(self, data, call_future, data_id):
        """Wrap this Op's successful result into a ChannelData for the outputs."""
        if self.with_serving():
            # Postprocess is deferred until a consumer resolves the future.
            return ChannelData(
                future=call_future,
                data_id=data_id,
                callback_func=self.postprocess)
        post_data = self.postprocess(data)
        if not isinstance(post_data, dict):
            raise TypeError(
                self._log('output_data must be dict type, but get {}'.format(
                    type(post_data))))
        pbdata = channel_pb2.ChannelData()
        for name, value in post_data.items():
            inst = channel_pb2.Inst()
            inst.data = value.tobytes()
            inst.name = name
            inst.shape = np.array(value.shape, dtype="int32").tobytes()
            inst.type = str(value.dtype)
            pbdata.insts.append(inst)
        pbdata.ecode = 0
        pbdata.id = data_id
        return ChannelData(pbdata=pbdata)

    def start(self, concurrency_idx):
        """Worker loop: process input data until stopped."""
        self._run = True
        while self._run:
            _profiler.record("{}{}-get_0".format(self._name, concurrency_idx))
            input_data = self._input.front(self._name)
            _profiler.record("{}{}-get_1".format(self._name, concurrency_idx))
            logging.debug(self._log("input_data: {}".format(input_data)))
            if input_data is None:
                # Channel.front() yields None once the channel is stopped.
                break

            data_id, error_data = self._parse_channeldata(input_data)

            if error_data is None:
                _profiler.record("{}{}-prep_0".format(self._name,
                                                      concurrency_idx))
                data = self.preprocess(input_data)
                _profiler.record("{}{}-prep_1".format(self._name,
                                                      concurrency_idx))

                call_future, error_info = None, None
                if self.with_serving():
                    call_future, error_info = self._midprocess_with_retry(
                        data, concurrency_idx)

                _profiler.record("{}{}-postp_0".format(self._name,
                                                       concurrency_idx))
                if error_info is not None:
                    error_pbdata = self.errorprocess(error_info, data_id)
                    output_data = ChannelData(pbdata=error_pbdata)
                else:
                    output_data = self._postprocess_to_channeldata(
                        data, call_future, data_id)
                _profiler.record("{}{}-postp_1".format(self._name,
                                                       concurrency_idx))
            else:
                # Forward the upstream error unchanged.
                output_data = ChannelData(pbdata=error_data)

            _profiler.record("{}{}-push_0".format(self._name, concurrency_idx))
            for channel in self._outputs:
                channel.push(output_data, self._name)
            _profiler.record("{}{}-push_1".format(self._name, concurrency_idx))

    def _log(self, info_str):
        return "[{}] {}".format(self._name, info_str)

    def get_concurrency(self):
        return self._concurrency

B
barrierye 已提交
523 524 525

class GeneralPythonService(
        general_python_service_pb2_grpc.GeneralPythonService):
    """gRPC servicer that feeds requests into the Op pipeline.

    Each request is packed into a ChannelData with a fresh id and pushed to
    ``in_channel``. A background thread drains ``out_channel`` into a dict
    keyed by data id, from which each request waits for its own result.
    """

    def __init__(self, in_channel, out_channel, retry=2):
        super(GeneralPythonService, self).__init__()
        self._name = "#G"
        self.set_in_channel(in_channel)
        self.set_out_channel(out_channel)
        logging.debug(self._log(in_channel.debug()))
        logging.debug(self._log(out_channel.debug()))
        #TODO:
        #  multi-lock for different clients
        #  different lock for server and client
        self._id_lock = threading.Lock()  # guards _id_counter
        self._cv = threading.Condition()  # guards _globel_resp_dict
        self._globel_resp_dict = {}  # {data_id: ChannelData}
        self._id_counter = 0
        self._retry = retry
        self._recive_func = threading.Thread(
            target=GeneralPythonService._recive_out_channel_func, args=(self, ))
        self._recive_func.start()

    def _log(self, info_str):
        return "[{}] {}".format(self._name, info_str)

    def set_in_channel(self, in_channel):
        if not isinstance(in_channel, Channel):
            raise TypeError(
                self._log('in_channel must be Channel type, but get {}'.format(
                    type(in_channel))))
        in_channel.add_producer(self._name)
        self._in_channel = in_channel

    def set_out_channel(self, out_channel):
        if not isinstance(out_channel, Channel):
            raise TypeError(
                self._log('out_channel must be Channel type, but get {}'.format(
                    type(out_channel))))
        out_channel.add_consumer(self._name)
        self._out_channel = out_channel

    def _recive_out_channel_func(self):
        """Background loop: move results from out_channel into the resp dict."""
        while True:
            channeldata = self._out_channel.front(self._name)
            if channeldata is None:
                # Channel.front() yields None once the channel is stopped;
                # exit cleanly instead of dying on the type check below.
                logging.debug(self._log("out_channel stopped"))
                break
            if not isinstance(channeldata, ChannelData):
                raise TypeError(
                    self._log('data must be ChannelData type, but get {}'.
                              format(type(channeldata))))
            with self._cv:
                data_id = channeldata.pbdata.id
                self._globel_resp_dict[data_id] = channeldata
                self._cv.notify_all()

    def _get_next_id(self):
        with self._id_lock:
            self._id_counter += 1
            return self._id_counter - 1

    def _get_data_in_globel_resp_dict(self, data_id):
        """Block until the result for ``data_id`` arrives, then remove and return it."""
        resp = None
        with self._cv:
            while data_id not in self._globel_resp_dict:
                self._cv.wait()
            resp = self._globel_resp_dict.pop(data_id)
            self._cv.notify_all()
        return resp

    def _pack_data_for_infer(self, request):
        """Convert a gRPC request into ``(ChannelData, data_id)``."""
        logging.debug(self._log('start inferce'))
        pbdata = channel_pb2.ChannelData()
        data_id = self._get_next_id()
        pbdata.id = data_id
        for idx, name in enumerate(request.feed_var_names):
            logging.debug(self._log('name: {}'.format(name)))
            logging.debug(self._log('data: {}'.format(request.feed_insts[idx])))
            inst = channel_pb2.Inst()
            inst.data = request.feed_insts[idx]
            inst.shape = request.shape[idx]
            inst.name = name
            inst.type = request.type[idx]
            pbdata.insts.append(inst)
        pbdata.ecode = 0  #TODO: parse request error
        return ChannelData(pbdata=pbdata), data_id

    def _pack_data_for_resp(self, channeldata):
        """Convert the pipeline's ChannelData into a gRPC Response.

        Raises:
            TypeError: if ``channeldata.pbdata.type`` is unknown.
        """
        logging.debug(self._log('get channeldata'))
        logging.debug(self._log('gen resp'))
        resp = pyservice_pb2.Response()
        resp.ecode = channeldata.pbdata.ecode
        if resp.ecode != 0:
            resp.error_info = channeldata.pbdata.error_info
            return resp

        data_type = channeldata.pbdata.type
        if data_type == ChannelDataType.CHANNEL_PBDATA.value:
            for inst in channeldata.pbdata.insts:
                logging.debug(self._log('append data'))
                resp.fetch_insts.append(inst.data)
                logging.debug(self._log('append name'))
                resp.fetch_var_names.append(inst.name)
                logging.debug(self._log('append shape'))
                resp.shape.append(inst.shape)
                logging.debug(self._log('append type'))
                resp.type.append(inst.type)
        elif data_type == ChannelDataType.CHANNEL_FUTURE.value:
            # The attribute is ``future`` (``futures`` raised AttributeError).
            feed = channeldata.future.result()
            if channeldata.callback_func is not None:
                feed = channeldata.callback_func(feed)
            # feed maps fetch name -> numpy array; iterate its items, not keys.
            for name, var in feed.items():
                logging.debug(self._log('append data'))
                resp.fetch_insts.append(var.tobytes())
                logging.debug(self._log('append name'))
                resp.fetch_var_names.append(name)
                logging.debug(self._log('append shape'))
                resp.shape.append(
                    np.array(
                        var.shape, dtype="int32").tobytes())
                resp.type.append(str(var.dtype))
        else:
            raise TypeError(
                self._log("Error type({}) in pbdata.type.".format(data_type)))
        return resp

    def inference(self, request, context):
        """gRPC handler: run ``request`` through the pipeline.

        The request is re-pushed up to ``self._retry`` times (at least once)
        while the pipeline reports a non-zero ecode.
        """
        _profiler.record("{}-prepack_0".format(self._name))
        data, data_id = self._pack_data_for_infer(request)
        _profiler.record("{}-prepack_1".format(self._name))

        resp_channeldata = None
        # At least one attempt, otherwise resp_channeldata stays None.
        for i in range(max(self._retry, 1)):
            logging.debug(self._log('push data'))
            _profiler.record("{}-push_0".format(self._name))
            self._in_channel.push(data, self._name)
            _profiler.record("{}-push_1".format(self._name))

            logging.debug(self._log('wait for infer'))
            _profiler.record("{}-fetch_0".format(self._name))
            resp_channeldata = self._get_data_in_globel_resp_dict(data_id)
            _profiler.record("{}-fetch_1".format(self._name))

            if resp_channeldata.pbdata.ecode == 0:
                break
            logging.warning("retry({}): {}".format(
                i + 1, resp_channeldata.pbdata.error_info))

        _profiler.record("{}-postpack_0".format(self._name))
        resp = self._pack_data_for_resp(resp_channeldata)
        _profiler.record("{}-postpack_1".format(self._name))
        _profiler.print_profile()
        return resp

B
barrierye 已提交
674 675

class PyServer(object):
    """Wire Ops and Channels into a pipeline and expose it over gRPC.

    Usage: register components with add_channel()/add_op(), call
    prepare_server(port, worker_num), then run_server().
    """

    def __init__(self, retry=2, profile=False):
        # Registered pipeline components.
        self._channels = []
        self._ops = []
        # Threads started by _run_ops(); joined when the server terminates.
        self._op_threads = []
        # Populated by prepare_server().
        self._port = None
        self._worker_num = None
        self._in_channel = None
        self._out_channel = None
        # Forwarded to GeneralPythonService.
        self._retry = retry
        _profiler.enable(profile)

    def add_channel(self, channel):
        """Register a Channel."""
        self._channels.append(channel)

    def add_op(self, op):
        """Register an Op."""
        self._ops.append(op)

    def gen_desc(self):
        """Placeholder for describing the pipeline; currently only logs."""
        logging.info('here will generate desc for PAAS')

    def prepare_server(self, port, worker_num):
        """Record server settings and locate the pipeline's entry/exit channels.

        The entry channel is read by some Op but written by none; the exit
        channel is written by some Op but read by none. Exactly one of each
        must exist.

        Raises:
            Exception: if there is not exactly one entry and one exit channel.
        """
        self._port = port
        self._worker_num = worker_num
        consumed = set()
        produced = set()
        for op in self._ops:
            consumed.add(op.get_input())
            produced.update(op.get_outputs())
            if op.with_serving():
                self.prepare_serving(op)
        entries = consumed - produced
        exits = produced - consumed
        if len(entries) != 1 or len(exits) != 1:
            raise Exception(
                "in_channel(out_channel) more than 1 or no in_channel(out_channel)"
            )
        self._in_channel = entries.pop()
        self._out_channel = exits.pop()
        self.gen_desc()

    def _op_start_wrapper(self, op, concurrency_idx):
        return op.start(concurrency_idx)

    def _run_ops(self):
        """Start ``op.get_concurrency()`` worker threads for every Op."""
        for op in self._ops:
            worker_count = op.get_concurrency()
            logging.debug("run op: {}, op_concurrency: {}".format(
                op._name, worker_count))
            for worker_idx in range(worker_count):
                # th = multiprocessing.Process(
                worker = threading.Thread(
                    target=self._op_start_wrapper, args=(op, worker_idx))
                worker.start()
                self._op_threads.append(worker)

    def _stop_ops(self):
        for op in self._ops:
            op.stop()

    def run_server(self):
        """Start the Op workers and serve gRPC until termination."""
        self._run_ops()
        executor = futures.ThreadPoolExecutor(max_workers=self._worker_num)
        server = grpc.server(executor)
        servicer = GeneralPythonService(self._in_channel, self._out_channel,
                                        self._retry)
        general_python_service_pb2_grpc.add_GeneralPythonServiceServicer_to_server(
            servicer, server)
        server.add_insecure_port('[::]:{}'.format(self._port))
        server.start()
        server.wait_for_termination()
        self._stop_ops()  # TODO
        for worker in self._op_threads:
            worker.join()

    def prepare_serving(self, op):
        """Log the command that would start ``op``'s serving backend.

        A device of "cpu" selects the CPU package; any other value (including
        None) selects the GPU package. The command is only logged, not run.
        """
        model_path = op._server_model
        port = op._server_port
        device = op._device

        if device == "cpu":
            serve_module = "paddle_serving_server"
        else:
            serve_module = "paddle_serving_server_gpu"
        cmd = ("(Use MultiLangServer) python -m {}.serve"
               " --model {} --thread 4 --port {} --use_multilang &>/dev/null &"
               ).format(serve_module, model_path, port)
        # run a server (not in PyServing)
        logging.info("run a server (not in PyServing): {}".format(cmd))
        # Launching is currently disabled (was: os.system(cmd)).
        return