pyserver.py 33.9 KB
Newer Older
B
barrierye 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
import threading
import multiprocessing
B
barrierye 已提交
17
import Queue
B
barrierye 已提交
18
import os
B
barrierye 已提交
19
import sys
B
barrierye 已提交
20
import paddle_serving_server
21
from paddle_serving_client import MultiLangClient as Client
B
barrierye 已提交
22
from concurrent import futures
B
barrierye 已提交
23
import numpy as np
B
barrierye 已提交
24
import grpc
25 26 27 28
from .proto import general_model_config_pb2 as m_config
from .proto import general_python_service_pb2 as pyservice_pb2
from .proto import pyserving_channel_pb2 as channel_pb2
from .proto import general_python_service_pb2_grpc
B
barrierye 已提交
29
import logging
30
import random
B
barrierye 已提交
31
import time
B
barrierye 已提交
32
import func_timeout
33
import enum
34
import collections
B
barrierye 已提交
35 36


B
barrierye 已提交
37 38 39 40 41 42 43 44 45 46 47
class _TimeProfiler(object):
    """Collect per-stage timestamps and print them to stderr.

    Names passed to record() look like "<stage>_<tag>": the text after the
    last "_" is the tag and everything before it is the stage name.
    print_profile() pairs up two records that share a stage name and writes
    both; unpaired records are put back so they can be matched later.
    Timestamps are integer microseconds (time.time() * 1e6, rounded).
    """

    def __init__(self):
        self._pid = os.getpid()
        self._print_head = 'PROFILE\tpid:{}\t'.format(self._pid)
        self._time_record = Queue.Queue()
        self._enable = False

    def enable(self, enable):
        """Turn profiling on or off; any truthy value enables it."""
        self._enable = enable

    def record(self, name_with_tag):
        """Record the current time under `name_with_tag` ("<name>_<tag>")."""
        # Truthiness check: enable(None) / enable(0) must mean "disabled".
        if not self._enable:
            return
        name, _, tag = name_with_tag.rpartition("_")
        self._time_record.put((name, tag, int(round(time.time() * 1000000))))

    def print_profile(self):
        """Write all paired records to stderr as a single line."""
        if not self._enable:
            return
        pending = {}  # {name: (tag, timestamp)} waiting for a partner
        parts = [self._print_head]
        while not self._time_record.empty():
            name, tag, timestamp = self._time_record.get()
            if name in pending:
                ptag, ptimestamp = pending.pop(name)
                parts.append("{}_{}:{} ".format(name, ptag, ptimestamp))
                parts.append("{}_{}:{} ".format(name, tag, timestamp))
            else:
                pending[name] = (tag, timestamp)
        parts.append('\n')
        # A single write call makes it much less likely that lines from
        # threads calling print_profile() concurrently get interleaved.
        sys.stderr.write(''.join(parts))
        # Keep unpaired records so they can be matched on a later call.
        for name, (tag, timestamp) in pending.items():
            self._time_record.put((name, tag, timestamp))


_profiler = _TimeProfiler()


77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121
class ChannelDataEcode(enum.Enum):
    """Values stored in ChannelData.pbdata.ecode; 0 (OK) means success."""
    OK = 0
    TIMEOUT = 1


class ChannelDataType(enum.Enum):
    """How the payload of a ChannelData is carried (pbdata.type)."""
    CHANNEL_PBDATA = 0  # payload is serialized in pbdata.insts
    CHANNEL_FUTURE = 1  # payload is produced by `future.result()`


class ChannelData(object):
    """Unit of data passed between Ops through a Channel.

    Args:
        future: object with a result() method; used when the payload comes
            from an asynchronous prediction call.
        pbdata: a channel_pb2.ChannelData message. When None, a new
            CHANNEL_FUTURE message with ecode OK and id `data_id` is created.
        data_id: id for the generated message; required when pbdata is None.
        callback_func: optional function applied to future.result() in
            parse().

    Raises:
        ValueError: if both pbdata and data_id are None.
    """

    def __init__(self,
                 future=None,
                 pbdata=None,
                 data_id=None,
                 callback_func=None):
        self.future = future
        if pbdata is None:
            if data_id is None:
                raise ValueError("data_id cannot be None")
            pbdata = channel_pb2.ChannelData()
            pbdata.type = ChannelDataType.CHANNEL_FUTURE.value
            pbdata.ecode = ChannelDataEcode.OK.value
            pbdata.id = data_id
        self.pbdata = pbdata
        self.callback_func = callback_func

    def parse(self):
        """Return the payload as a dict.

        For CHANNEL_PBDATA, each inst becomes a numpy array (dtype from
        inst.type, shape from the int32 bytes in inst.shape) keyed by
        inst.name. For CHANNEL_FUTURE, returns future.result(), passed
        through callback_func when one is set.

        Raises:
            TypeError: if pbdata.type is not a known ChannelDataType value.
        """
        feed = {}
        data_type = self.pbdata.type
        if data_type == ChannelDataType.CHANNEL_PBDATA.value:
            for inst in self.pbdata.insts:
                shape = np.frombuffer(inst.shape, dtype="int32").tolist()
                feed[inst.name] = np.frombuffer(
                    inst.data, dtype=inst.type).reshape(shape)
        elif data_type == ChannelDataType.CHANNEL_FUTURE.value:
            feed = self.future.result()
            if self.callback_func is not None:
                feed = self.callback_func(feed)
        else:
            # ChannelData has no _log() helper; build the message directly so
            # the intended TypeError is raised instead of an AttributeError.
            raise TypeError("Error type({}) in pbdata.type.".format(data_type))
        return feed


B
barrierye 已提交
122
class Channel(Queue.Queue):
    """
    The channel used for communication between Ops.

    1. Support multiple different Op feed data (multiple producer)
        Different types of data will be packaged through the data ID
    2. Support multiple different Op fetch data (multiple consumer)
        Only when all types of Ops get the data of the same ID,
        the data will be popped; The Op of the same type will not
        get the data of the same ID.
    3. (TODO) Timeout and BatchSize are not fully supported.

    Note:
    1. The ID of the data in the channel must be different.
    2. The function add_producer() and add_consumer() are not thread safe,
       and can only be called during initialization.
    3. After stop(), push() no longer enqueues data and front() returns
       None instead of blocking when no data is available.
    """

    def __init__(self, name=None, maxsize=-1, timeout=None):
        Queue.Queue.__init__(self, maxsize=maxsize)
        self._maxsize = maxsize
        self._timeout = timeout
        self.name = name
        self._stop = False

        # Guards the bookkeeping below. Producers wait on it while the queue
        # is full; consumers wait on it while no data is available.
        self._cv = threading.Condition()

        self._producers = []
        self._producer_res_count = {}  # {data_id: count}
        self._push_res = {}  # {data_id: {op_name: data}}

        self._consumers = {}  # {op_name: idx}
        self._idx_consumer_num = {}  # {idx: num}
        self._consumer_base_idx = 0
        self._front_res = []

    def get_producers(self):
        return self._producers

    def get_consumers(self):
        return self._consumers.keys()

    def _log(self, info_str):
        return "[{}] {}".format(self.name, info_str)

    def debug(self):
        return self._log("p: {}, c: {}".format(self.get_producers(),
                                               self.get_consumers()))

    def add_producer(self, op_name):
        """ not thread safe, and can only be called during initialization. """
        if op_name in self._producers:
            raise ValueError(
                self._log("producer({}) is already in channel".format(op_name)))
        self._producers.append(op_name)

    def add_consumer(self, op_name):
        """ not thread safe, and can only be called during initialization. """
        if op_name in self._consumers:
            raise ValueError(
                self._log("consumer({}) is already in channel".format(op_name)))
        self._consumers[op_name] = 0

        if self._idx_consumer_num.get(0) is None:
            self._idx_consumer_num[0] = 0
        self._idx_consumer_num[0] += 1

    def _put_until_stopped(self, data):
        """Put `data` into the queue, waiting on `_cv` while the queue is full.

        Must be called with `self._cv` held. Gives up (without putting) once
        the channel is stopped.
        """
        while self._stop is False:
            try:
                self.put_nowait(data)
                break
            except Queue.Full:  # a full queue raises Full, not Empty
                self._cv.wait()

    def push(self, channeldata, op_name=None):
        """Push data from producer `op_name`; always returns True.

        With several producers, data is buffered per data id until every
        producer has pushed it, then {producer_name: channeldata} is queued.
        """
        logging.debug(
            self._log("{} try to push data: {}".format(op_name,
                                                       channeldata.pbdata)))
        if len(self._producers) == 0:
            raise Exception(
                self._log(
                    "expected number of producers to be greater than 0, but the it is 0."
                ))
        elif len(self._producers) == 1:
            with self._cv:
                self._put_until_stopped(channeldata)
                self._cv.notify_all()
            logging.debug(self._log("{} push data succ!".format(op_name)))
            return True
        elif op_name is None:
            raise Exception(
                self._log(
                    "There are multiple producers, so op_name cannot be None."))

        producer_num = len(self._producers)
        data_id = channeldata.pbdata.id
        put_data = None
        with self._cv:
            logging.debug(self._log("{} get lock".format(op_name)))
            if data_id not in self._push_res:
                self._push_res[data_id] = {
                    name: None
                    for name in self._producers
                }
                self._producer_res_count[data_id] = 0
            self._push_res[data_id][op_name] = channeldata
            if self._producer_res_count[data_id] + 1 == producer_num:
                # Last producer for this id: release the merged dict.
                put_data = self._push_res.pop(data_id)
                self._producer_res_count.pop(data_id)
            else:
                self._producer_res_count[data_id] += 1

            if put_data is None:
                logging.debug(
                    self._log("{} push data succ, but not push to queue.".
                              format(op_name)))
            else:
                self._put_until_stopped(put_data)
                logging.debug(
                    self._log("multi | {} push data succ!".format(op_name)))
            self._cv.notify_all()
        return True

    def front(self, op_name=None):
        """Return the next data for consumer `op_name`, or None once stopped.

        With several consumers, every consumer sees every data item in order;
        an item is dropped only after all consumers have read it.
        """
        logging.debug(self._log("{} try to get data".format(op_name)))
        if len(self._consumers) == 0:
            raise Exception(
                self._log(
                    "expected number of consumers to be greater than 0, but the it is 0."
                ))
        elif len(self._consumers) == 1:
            resp = None
            with self._cv:
                while self._stop is False and resp is None:
                    try:
                        resp = self.get_nowait()
                        break
                    except Queue.Empty:
                        self._cv.wait()
                # Wake producers that may be waiting for free queue space.
                self._cv.notify_all()
            logging.debug(self._log("{} get data succ!".format(op_name)))
            return resp
        elif op_name is None:
            raise Exception(
                self._log(
                    "There are multiple consumers, so op_name cannot be None."))

        with self._cv:
            # data_idx = consumer_idx - base_idx
            while self._stop is False and self._consumers[
                    op_name] - self._consumer_base_idx >= len(self._front_res):
                try:
                    channeldata = self.get_nowait()
                    self._front_res.append(channeldata)
                    break
                except Queue.Empty:
                    self._cv.wait()

            consumer_idx = self._consumers[op_name]
            base_idx = self._consumer_base_idx
            data_idx = consumer_idx - base_idx
            if data_idx >= len(self._front_res):
                # Stopped before any data for this consumer arrived.
                return None
            resp = self._front_res[data_idx]
            logging.debug(self._log("{} get data: {}".format(op_name, resp)))

            self._idx_consumer_num[consumer_idx] -= 1
            if consumer_idx == base_idx and self._idx_consumer_num[
                    consumer_idx] == 0:
                # Every consumer has read the oldest item: drop it.
                self._idx_consumer_num.pop(consumer_idx)
                self._front_res.pop(0)
                self._consumer_base_idx += 1

            self._consumers[op_name] += 1
            new_consumer_idx = self._consumers[op_name]
            if self._idx_consumer_num.get(new_consumer_idx) is None:
                self._idx_consumer_num[new_consumer_idx] = 0
            self._idx_consumer_num[new_consumer_idx] += 1

            self._cv.notify_all()

        logging.debug(self._log("multi | {} get data succ!".format(op_name)))
        return resp  # reference, read only

    def stop(self):
        """Stop the channel and wake every thread blocked in push()/front()."""
        # Queue.Queue has no close(); only the flag and a wake-up are needed.
        with self._cv:
            self._stop = True
            self._cv.notify_all()

B
barrierye 已提交
313 314 315

class Op(object):
    """A node of the serving DAG.

    An Op reads data from its input Channel, runs preprocess(), then
    midprocess() (an asynchronous prediction call, only when a serving
    client is configured), then postprocess(), and pushes the result to
    every output Channel. Subclasses customize behaviour by overriding
    preprocess(), midprocess() and postprocess().
    """

    def __init__(self,
                 name,
                 inputs,
                 server_model=None,
                 server_port=None,
                 device=None,
                 client_config=None,
                 server_name=None,
                 fetch_names=None,
                 concurrency=1,
                 timeout=-1,
                 retry=2):
        self._run = False
        self.name = name  # to identify the type of OP, it must be globally unique
        self._concurrency = concurrency  # amount of concurrency
        self.set_input_ops(inputs)
        self.set_client(client_config, server_name, fetch_names)
        self._server_model = server_model
        self._server_port = server_port
        self._device = device
        self._timeout = timeout  # <= 0 disables the midprocess timeout
        self._retry = retry  # attempts for midprocess when timeout > 0
        self._input = None
        self._outputs = []

    def set_client(self, client_config, server_name, fetch_names):
        """Create the serving client; leave it None if any argument is None."""
        self._client = None
        self._fetch_names = fetch_names
        if client_config is None or \
                server_name is None or \
                fetch_names is None:
            return
        self._client = Client()
        self._client.load_client_config(client_config)
        self._client.connect([server_name])

    def with_serving(self):
        return self._client is not None

    def get_input_channel(self):
        return self._input

    def get_input_ops(self):
        return self._input_ops

    def set_input_ops(self, ops):
        """Set predecessor Ops; `ops` may be None, one Op, or a list of Ops."""
        if not isinstance(ops, list):
            ops = [] if ops is None else [ops]
        self._input_ops = []
        for op in ops:
            if not isinstance(op, Op):
                raise TypeError(
                    self._log('input op must be Op type, not {}'.format(
                        type(op))))
            self._input_ops.append(op)

    def add_input_channel(self, channel):
        if not isinstance(channel, Channel):
            raise TypeError(
                self._log('input channel must be Channel type, not {}'.format(
                    type(channel))))
        channel.add_consumer(self.name)
        self._input = channel

    def get_output_channels(self):
        return self._outputs

    def add_output_channel(self, channel):
        if not isinstance(channel, Channel):
            raise TypeError(
                self._log('output channel must be Channel type, not {}'.format(
                    type(channel))))
        channel.add_producer(self.name)
        self._outputs.append(channel)

    def preprocess(self, channeldata):
        """Default: parse a single ChannelData; multi-input Ops must override."""
        if isinstance(channeldata, dict):
            raise Exception(
                self._log(
                    'this Op has multiple previous inputs. Please override this method'
                ))
        feed = channeldata.parse()
        return feed

    def midprocess(self, data):
        """Start an asynchronous prediction with the serving client."""
        if not isinstance(data, dict):
            raise Exception(
                self._log(
                    'data must be dict type(the output of preprocess()), but get {}'.
                    format(type(data))))
        logging.debug(self._log('data: {}'.format(data)))
        logging.debug(self._log('fetch: {}'.format(self._fetch_names)))
        call_future = self._client.predict(
            feed=data, fetch=self._fetch_names, asyn=True)
        logging.debug(self._log("get call_future"))
        return call_future

    def postprocess(self, output_data):
        return output_data

    def errorprocess(self, error_info, data_id):
        """Build an error pbdata for `data_id` carrying `error_info`."""
        data = channel_pb2.ChannelData()
        data.ecode = 1  # any non-zero ecode marks the data as failed
        data.id = data_id
        data.error_info = error_info
        return data

    def stop(self):
        """Stop the worker loop and all attached channels."""
        # Clear the flag first so threads woken by channel.stop() exit.
        self._run = False
        if self._input is not None:
            self._input.stop()
        for channel in self._outputs:
            channel.stop()

    def _parse_channeldata(self, channeldata):
        """Return (data_id, error_pbdata) for data read from the input channel.

        `channeldata` is a single ChannelData or, when the input channel has
        several producers, a dict {producer_name: ChannelData}. error_pbdata
        is the pbdata of the first entry with a non-zero ecode, else None.
        The same type (a pbdata) is returned in both cases so the caller can
        wrap it with ChannelData(pbdata=...).
        """
        if isinstance(channeldata, dict):
            items = list(channeldata.values())
        else:
            items = [channeldata]
        data_id = items[0].pbdata.id
        for data in items:
            if data.pbdata.ecode != 0:
                return data_id, data.pbdata
        return data_id, None

    def _call_midprocess(self, data, concurrency_idx):
        """Run midprocess(), applying self._timeout / self._retry if enabled.

        Returns (call_future, error_info); error_info is None on success.
        """
        if self._timeout <= 0:
            return self.midprocess(data), None
        call_future, error_info = None, None
        for i in range(self._retry):
            try:
                call_future = func_timeout.func_timeout(
                    self._timeout, self.midprocess, args=(data, ))
            except func_timeout.FunctionTimedOut:
                logging.error("error: timeout")
                error_info = "{}({}): timeout".format(self.name,
                                                      concurrency_idx)
                if i + 1 < self._retry:
                    error_info = None
                    logging.warning(
                        self._log("warn: timeout, retry({})".format(i + 1)))
            except Exception as e:
                logging.error("error: {}".format(e))
                error_info = "{}({}): {}".format(self.name, concurrency_idx,
                                                 e)
                logging.warning(self._log(e))
                # TODO: decide whether non-timeout errors should be retried
                break
            else:
                break
        return call_future, error_info

    def _pack_pbdata(self, post_data, data_id):
        """Serialize a dict of numpy arrays into a channel_pb2.ChannelData."""
        if not isinstance(post_data, dict):
            raise TypeError(
                self._log('output_data must be dict type, but get {}'.format(
                    type(post_data))))
        pbdata = channel_pb2.ChannelData()
        for name, value in post_data.items():
            inst = channel_pb2.Inst()
            inst.data = value.tobytes()
            inst.name = name
            inst.shape = np.array(value.shape, dtype="int32").tobytes()
            inst.type = str(value.dtype)
            pbdata.insts.append(inst)
        pbdata.ecode = 0
        pbdata.id = data_id
        return pbdata

    def _process(self, input_data, data_id, concurrency_idx):
        """Run pre/mid/post-process on error-free input; return a ChannelData."""
        _profiler.record("{}{}-prep_0".format(self.name, concurrency_idx))
        data = self.preprocess(input_data)
        _profiler.record("{}{}-prep_1".format(self.name, concurrency_idx))

        call_future, error_info = None, None
        if self.with_serving():
            _profiler.record("{}{}-midp_0".format(self.name,
                                                  concurrency_idx))
            call_future, error_info = self._call_midprocess(data,
                                                            concurrency_idx)
            _profiler.record("{}{}-midp_1".format(self.name,
                                                  concurrency_idx))

        _profiler.record("{}{}-postp_0".format(self.name, concurrency_idx))
        if error_info is not None:
            output_data = ChannelData(
                pbdata=self.errorprocess(error_info, data_id))
        elif self.with_serving():
            # postprocess() runs lazily on the future's result.
            output_data = ChannelData(
                future=call_future,
                data_id=data_id,
                callback_func=self.postprocess)
        else:
            output_data = ChannelData(
                pbdata=self._pack_pbdata(self.postprocess(data), data_id))
        _profiler.record("{}{}-postp_1".format(self.name, concurrency_idx))
        return output_data

    def start(self, concurrency_idx):
        """Worker loop: read, process and push data until stopped."""
        self._run = True
        while self._run:
            _profiler.record("{}{}-get_0".format(self.name, concurrency_idx))
            input_data = self._input.front(self.name)
            _profiler.record("{}{}-get_1".format(self.name, concurrency_idx))
            logging.debug(self._log("input_data: {}".format(input_data)))
            if input_data is None:
                # front() returns None when the channel has been stopped;
                # there is nothing to process.
                break

            data_id, error_data = self._parse_channeldata(input_data)
            if error_data is None:
                output_data = self._process(input_data, data_id,
                                            concurrency_idx)
            else:
                # Forward the upstream error unchanged.
                output_data = ChannelData(pbdata=error_data)

            _profiler.record("{}{}-push_0".format(self.name, concurrency_idx))
            for channel in self._outputs:
                channel.push(output_data, self.name)
            _profiler.record("{}{}-push_1".format(self.name, concurrency_idx))

    def _log(self, info_str):
        return "[{}] {}".format(self.name, info_str)

    def get_concurrency(self):
        return self._concurrency

B
barrierye 已提交
543 544 545

class GeneralPythonService(
        general_python_service_pb2_grpc.GeneralPythonService):
    """gRPC servicer that feeds requests into the Op DAG.

    Each request gets a unique id and is pushed into `in_channel`. A
    background thread reads results from `out_channel` and stores them by
    id; inference() waits for the result with its id.
    """

    def __init__(self, in_channel, out_channel, retry=2):
        super(GeneralPythonService, self).__init__()
        self.name = "#G"
        self.set_in_channel(in_channel)
        self.set_out_channel(out_channel)
        logging.debug(self._log(in_channel.debug()))
        logging.debug(self._log(out_channel.debug()))
        #TODO:
        #  multi-lock for different clients
        #  different lock for server and client
        self._id_lock = threading.Lock()
        self._cv = threading.Condition()
        self._globel_resp_dict = {}  # {data_id: ChannelData}
        self._id_counter = 0
        self._retry = retry
        self._recive_func = threading.Thread(
            target=GeneralPythonService._recive_out_channel_func, args=(self, ))
        self._recive_func.start()

    def _log(self, info_str):
        return "[{}] {}".format(self.name, info_str)

    def set_in_channel(self, in_channel):
        if not isinstance(in_channel, Channel):
            raise TypeError(
                self._log('in_channel must be Channel type, but get {}'.format(
                    type(in_channel))))
        in_channel.add_producer(self.name)
        self._in_channel = in_channel

    def set_out_channel(self, out_channel):
        if not isinstance(out_channel, Channel):
            raise TypeError(
                self._log('out_channel must be Channel type, but get {}'.format(
                    type(out_channel))))
        out_channel.add_consumer(self.name)
        self._out_channel = out_channel

    def _recive_out_channel_func(self):
        """Background loop: store each result by id and wake waiters."""
        while True:
            channeldata = self._out_channel.front(self.name)
            if channeldata is None:
                # front() returns None once the out channel is stopped.
                break
            if not isinstance(channeldata, ChannelData):
                raise TypeError(
                    self._log('data must be ChannelData type, but get {}'.
                              format(type(channeldata))))
            with self._cv:
                data_id = channeldata.pbdata.id
                self._globel_resp_dict[data_id] = channeldata
                self._cv.notify_all()

    def _get_next_id(self):
        with self._id_lock:
            self._id_counter += 1
            return self._id_counter - 1

    def _get_data_in_globel_resp_dict(self, data_id):
        """Block until the result for `data_id` arrives, then remove it."""
        with self._cv:
            while data_id not in self._globel_resp_dict:
                self._cv.wait()
            resp = self._globel_resp_dict.pop(data_id)
            self._cv.notify_all()
        return resp

    def _pack_data_for_infer(self, request):
        """Convert a request into (ChannelData, data_id)."""
        logging.debug(self._log('start inferce'))
        pbdata = channel_pb2.ChannelData()
        data_id = self._get_next_id()
        pbdata.id = data_id
        for idx, name in enumerate(request.feed_var_names):
            logging.debug(
                self._log('name: {}'.format(request.feed_var_names[idx])))
            logging.debug(self._log('data: {}'.format(request.feed_insts[idx])))
            inst = channel_pb2.Inst()
            inst.data = request.feed_insts[idx]
            inst.shape = request.shape[idx]
            inst.name = name
            inst.type = request.type[idx]
            pbdata.insts.append(inst)
        pbdata.ecode = 0  #TODO: parse request error
        return ChannelData(pbdata=pbdata), data_id

    def _pack_data_for_resp(self, channeldata):
        """Convert a result ChannelData into a pyservice_pb2.Response.

        Raises:
            TypeError: if pbdata.type is not a known ChannelDataType value.
        """
        logging.debug(self._log('get channeldata'))
        logging.debug(self._log('gen resp'))
        resp = pyservice_pb2.Response()
        resp.ecode = channeldata.pbdata.ecode
        if resp.ecode != 0:
            resp.error_info = channeldata.pbdata.error_info
            return resp

        data_type = channeldata.pbdata.type
        if data_type == ChannelDataType.CHANNEL_PBDATA.value:
            # Already serialized: copy the raw bytes through.
            for inst in channeldata.pbdata.insts:
                logging.debug(self._log('append data'))
                resp.fetch_insts.append(inst.data)
                logging.debug(self._log('append name'))
                resp.fetch_var_names.append(inst.name)
                logging.debug(self._log('append shape'))
                resp.shape.append(inst.shape)
                logging.debug(self._log('append type'))
                resp.type.append(inst.type)
        elif data_type == ChannelDataType.CHANNEL_FUTURE.value:
            # parse() waits on the future and applies the callback.
            feed = channeldata.parse()
            for name, var in feed.items():
                logging.debug(self._log('append data'))
                resp.fetch_insts.append(var.tobytes())
                logging.debug(self._log('append name'))
                resp.fetch_var_names.append(name)
                logging.debug(self._log('append shape'))
                resp.shape.append(
                    np.array(
                        var.shape, dtype="int32").tobytes())
                logging.debug(self._log('append type'))
                resp.type.append(str(var.dtype))
        else:
            raise TypeError(
                self._log("Error type({}) in pbdata.type.".format(data_type)))
        return resp

    def inference(self, request, context):
        """gRPC entry point: run `request` through the DAG, retrying on error."""
        _profiler.record("{}-prepack_0".format(self.name))
        data, data_id = self._pack_data_for_infer(request)
        _profiler.record("{}-prepack_1".format(self.name))

        resp_channeldata = None
        # Always make at least one attempt so resp_channeldata is set.
        for i in range(max(self._retry, 1)):
            logging.debug(self._log('push data'))
            _profiler.record("{}-push_0".format(self.name))
            self._in_channel.push(data, self.name)
            _profiler.record("{}-push_1".format(self.name))

            logging.debug(self._log('wait for infer'))
            _profiler.record("{}-fetch_0".format(self.name))
            resp_channeldata = self._get_data_in_globel_resp_dict(data_id)
            _profiler.record("{}-fetch_1".format(self.name))

            if resp_channeldata.pbdata.ecode == 0:
                break
            logging.warning("retry({}): {}".format(
                i + 1, resp_channeldata.pbdata.error_info))

        _profiler.record("{}-postpack_0".format(self.name))
        resp = self._pack_data_for_resp(resp_channeldata)
        _profiler.record("{}-postpack_1".format(self.name))
        _profiler.print_profile()
        return resp

B
barrierye 已提交
694 695

class PyServer(object):
B
barrierye 已提交
696
    def __init__(self, retry=2, profile=False):
        """Create an empty server; register Ops, then call prepare_server.

        Args:
            retry: attempts per request made by the gRPC service; forwarded
                to GeneralPythonService in run_server.
            profile: enable or disable the module-level time profiler.
        """
        self._channels = []
        self._user_ops = []
        self._total_ops = []
        # Scheduled ops (virtual + user ops with inputs); filled in by
        # _topo_sort. Initialised here so _run_ops/_stop_ops always see a
        # defined attribute.
        self._ops = []
        self._op_threads = []
        self._port = None
        self._worker_num = None
        self._in_channel = None
        self._out_channel = None
        self._retry = retry
        _profiler.enable(profile)
B
barrierye 已提交
707 708 709 710 711

    def add_channel(self, channel):
        self._channels.append(channel)

    def add_op(self, op):
712 713 714
        self._user_ops.append(op)

    def add_ops(self, ops):
B
fix bug  
barrierye 已提交
715
        self._user_ops.extend(ops)
B
barrierye 已提交
716 717

    def gen_desc(self):
718
        logging.info('here will generate desc for PAAS')
B
barrierye 已提交
719 720
        pass

721 722
    def _topo_sort(self):
        """Build the runtime pipeline (ops and channels) from the user Ops.

        Ops are grouped into topological levels ("dag views"): level 0 holds
        the Ops with no inputs, and each later level holds the Ops whose last
        input became available in the previous level. Channels are created
        between consecutive levels. When an Op feeds a successor more than one
        level further on, a virtual Op is inserted in each skipped level so the
        data is relayed level by level. Successors in the same level that have
        exactly the same predecessors share one input Channel.

        Side effects:
            Sets ``self._ops`` to the virtual Ops followed by every user Op
            that has inputs (the input Op itself is not scheduled), and
            ``self._channels`` to every Channel created (output Channel last).

        Returns:
            (input_channel, output_channel): the Channel feeding the first
            level after the input Op, and the Channel written by the last Op.

        Raises:
            Exception: if Op names are not unique, an input Op was never
                registered, the graph has a cycle, or the graph does not have
                exactly one input and one output Op.
        """
        indeg_num = {}
        outdegs = {op.name: [] for op in self._user_ops}
        first_level = []
        for op in self._user_ops:
            # check the name of op is globally unique
            if op.name in indeg_num:
                raise Exception("the name of Op must be unique")
            input_ops = op.get_input_ops()
            indeg_num[op.name] = len(input_ops)
            if indeg_num[op.name] == 0:
                first_level.append(op)
            for pred_op in input_ops:
                if pred_op.name not in outdegs:
                    raise Exception(
                        "input Op {} of Op {} was not added to the server".
                        format(pred_op.name, op.name))
                outdegs[pred_op.name].append(op)

        # Group ops into topological levels (Kahn's algorithm, level by level).
        dag_views = []
        sorted_op_num = 0
        level = first_level
        while True:
            next_level = []
            for op in level:
                sorted_op_num += 1
                for succ_op in outdegs[op.name]:
                    indeg_num[succ_op.name] -= 1
                    if indeg_num[succ_op.name] == 0:
                        next_level.append(succ_op)
            dag_views.append(level)
            if not next_level:
                break
            level = next_level
        if sorted_op_num < len(self._user_ops):
            raise Exception("not legal DAG")
        if len(dag_views[0]) != 1:
            raise Exception("DAG contains multiple input Ops")
        if len(dag_views[-1]) != 1:
            raise Exception("DAG contains multiple output Ops")

        # create channels and virtual ops
        virtual_op_idx = 0
        channel_idx = 0
        virtual_ops = []
        channels = []
        input_channel = None
        # Ops that actually produce data at the current level: the raw dag
        # view plus the virtual ops created while handling the previous level.
        # Using dag_views[v_idx] here instead would drop those virtual ops and
        # leave them without an output channel.
        actual_view = dag_views[0]
        for v_idx in range(len(dag_views) - 1):
            next_view = dag_views[v_idx + 1]
            actual_next_view = []
            pred_op_of_next_view_op = {}
            for op in actual_view:
                for succ_op in outdegs[op.name]:
                    if succ_op in next_view:
                        # An op with several predecessors is listed once.
                        if succ_op not in actual_next_view:
                            actual_next_view.append(succ_op)
                        pred_op_of_next_view_op.setdefault(succ_op.name,
                                                           []).append(op)
                    else:
                        # succ_op lies in a later level: relay via a virtual op.
                        vop = Op(name="vir{}".format(virtual_op_idx), inputs=[])
                        virtual_op_idx += 1
                        virtual_ops.append(vop)
                        outdegs[vop.name] = [succ_op]
                        actual_next_view.append(vop)
                        # TODO: combine vop
                        pred_op_of_next_view_op[vop.name] = [op]

            # create channel
            processed_op = set()
            for o_idx, op in enumerate(actual_next_view):
                if op.name in processed_op:
                    continue
                channel = Channel(name="chl{}".format(channel_idx))
                channel_idx += 1
                channels.append(channel)
                op.add_input_channel(channel)
                pred_ops = pred_op_of_next_view_op[op.name]
                if v_idx == 0:
                    # Level 0 is the input Op, which is not scheduled; this
                    # channel is returned as the pipeline's input channel.
                    input_channel = channel
                else:
                    for pred_op in pred_ops:
                        pred_op.add_output_channel(channel)
                processed_op.add(op.name)
                # combine channel: ops with exactly the same predecessors share it
                for other_op in actual_next_view[o_idx + 1:]:
                    if other_op.name in processed_op:
                        continue
                    other_pred_ops = pred_op_of_next_view_op[other_op.name]
                    if len(other_pred_ops) != len(pred_ops):
                        continue
                    if all(pred_op in other_pred_ops for pred_op in pred_ops):
                        other_op.add_input_channel(channel)
                        processed_op.add(other_op.name)
            actual_view = actual_next_view

        output_channel = Channel(name="Ochl")
        channels.append(output_channel)
        last_op = dag_views[-1][0]
        last_op.add_output_channel(output_channel)

        # Schedule virtual ops plus every user op that has inputs.
        self._ops = virtual_ops + [
            op for op in self._user_ops if len(op.get_input_ops()) != 0
        ]
        self._channels = channels
        return input_channel, output_channel

B
barrierye 已提交
836 837 838
    def prepare_server(self, port, worker_num):
        """Store server settings and build the op/channel pipeline.

        Args:
            port: port the gRPC service will listen on (used by run_server).
            worker_num: max_workers of the gRPC thread pool (used by run_server).
        """
        self._port, self._worker_num = port, worker_num

        self._in_channel, self._out_channel = self._topo_sort()
        for scheduled_op in self._ops:
            if not scheduled_op.with_serving():
                continue
            self.prepare_serving(scheduled_op)

        self.gen_desc()

B
barrierye 已提交
848 849
    def _op_start_wrapper(self, op, concurrency_idx):
        return op.start(concurrency_idx)
B
barrierye 已提交
850

851
    def _run_ops(self):
B
barrierye 已提交
852
        for op in self._ops:
853
            op_concurrency = op.get_concurrency()
854
            logging.debug("run op: {}, op_concurrency: {}".format(
855
                op.name, op_concurrency))
856 857
            for c in range(op_concurrency):
                th = threading.Thread(
B
barrierye 已提交
858
                    target=self._op_start_wrapper, args=(op, c))
859 860 861
                th.start()
                self._op_threads.append(th)

862 863 864 865
    def _stop_ops(self):
        for op in self._ops:
            op.stop()

866 867
    def run_server(self):
        """Start the op threads and serve gRPC requests until termination.

        Expects prepare_server to have been called (it sets the port, worker
        count, channels and ops). Blocks until the gRPC server terminates,
        then stops the ops and joins their threads.
        """
        self._run_ops()
        try:
            server = grpc.server(
                futures.ThreadPoolExecutor(max_workers=self._worker_num))
            general_python_service_pb2_grpc.add_GeneralPythonServiceServicer_to_server(
                GeneralPythonService(self._in_channel, self._out_channel,
                                     self._retry), server)
            server.add_insecure_port('[::]:{}'.format(self._port))
            server.start()
            server.wait_for_termination()
        finally:
            # Also reached on KeyboardInterrupt or a setup error (e.g. port in
            # use). The op threads are created without daemon=True, so they
            # are stopped and joined here rather than left running, which
            # could keep the process alive.
            self._stop_ops()  # TODO
            for th in self._op_threads:
                th.join()
B
barrierye 已提交
879 880 881 882 883 884 885

    def prepare_serving(self, op):
        model_path = op._server_model
        port = op._server_port
        device = op._device

        if device == "cpu":
886 887
            cmd = "(Use MultiLangServer) python -m paddle_serving_server.serve" \
                  " --model {} --thread 4 --port {} --use_multilang &>/dev/null &".format(model_path, port)
B
barrierye 已提交
888
        else:
889 890
            cmd = "(Use MultiLangServer) python -m paddle_serving_server_gpu.serve" \
                  " --model {} --thread 4 --port {} --use_multilang &>/dev/null &".format(model_path, port)
891 892
        # run a server (not in PyServing)
        logging.info("run a server (not in PyServing): {}".format(cmd))
B
barrierye 已提交
893
        return
894
        # os.system(cmd)