pyserver.py 33.7 KB
Newer Older
B
barrierye 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
import threading
import multiprocessing
B
barrierye 已提交
17
import Queue
B
barrierye 已提交
18
import os
B
barrierye 已提交
19
import sys
B
barrierye 已提交
20
import paddle_serving_server
21
from paddle_serving_client import MultiLangClient as Client
B
barrierye 已提交
22
from concurrent import futures
B
barrierye 已提交
23
import numpy as np
B
barrierye 已提交
24
import grpc
25 26 27 28
from .proto import general_model_config_pb2 as m_config
from .proto import general_python_service_pb2 as pyservice_pb2
from .proto import pyserving_channel_pb2 as channel_pb2
from .proto import general_python_service_pb2_grpc
B
barrierye 已提交
29
import logging
30
import random
B
barrierye 已提交
31
import time
B
barrierye 已提交
32
import func_timeout
33
import enum
34
import collections
B
barrierye 已提交
35 36


B
barrierye 已提交
37 38 39 40 41 42 43 44 45 46 47
class _TimeProfiler(object):
    """Record microsecond timestamps for tagged stages and dump them to stderr.

    Each record is given as "<name>_<tag>"; the text after the last
    underscore is the tag and everything before it is the stage name.
    Nothing is recorded or printed while the profiler is disabled.
    """

    def __init__(self):
        self._pid = os.getpid()
        # Prefix written at the start of every profile line.
        self._print_head = 'PROFILE\tpid:{}\t'.format(self._pid)
        self._time_record = Queue.Queue()
        self._enable = False

    def enable(self, enable):
        """Switch recording and printing on (True) or off (False)."""
        self._enable = enable

    def record(self, name_with_tag):
        """Queue (name, tag, timestamp in microseconds) for `name_with_tag`."""
        if self._enable is False:
            return
        # rpartition splits on the last "_" only, so names may contain "_".
        name, _, tag = name_with_tag.rpartition("_")
        now_us = int(round(time.time() * 1000000))
        self._time_record.put((name, tag, now_us))

    def print_profile(self):
        """Write every completed pair of records to stderr on one line.

        Two records sharing the same name form a pair and are printed
        together. Records still waiting for their partner are put back
        into the queue so a later call can complete them.
        """
        if self._enable is False:
            return
        sys.stderr.write(self._print_head)
        unpaired = {}
        while not self._time_record.empty():
            name, tag, timestamp = self._time_record.get()
            if name not in unpaired:
                unpaired[name] = (tag, timestamp)
                continue
            first_tag, first_timestamp = unpaired.pop(name)
            sys.stderr.write("{}_{}:{} ".format(name, first_tag,
                                                first_timestamp))
            sys.stderr.write("{}_{}:{} ".format(name, tag, timestamp))
        sys.stderr.write('\n')
        for name, (tag, timestamp) in unpaired.items():
            self._time_record.put((name, tag, timestamp))


# Module-level profiler instance used by the classes below; it starts
# disabled and is switched on or off through _profiler.enable(...).
_profiler = _TimeProfiler()


77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121
@enum.unique
class ChannelDataEcode(enum.Enum):
    """Error codes carried in the `ecode` field of channel data."""

    OK = 0
    TIMEOUT = 1


@enum.unique
class ChannelDataType(enum.Enum):
    """Payload kinds carried in the `type` field of channel data."""

    # The payload is serialized in the protobuf `insts` field.
    CHANNEL_PBDATA = 0
    # The payload is obtained from a future's result().
    CHANNEL_FUTURE = 1


class ChannelData(object):
    """One item travelling through a Channel.

    The payload is either serialized inside `pbdata.insts`
    (CHANNEL_PBDATA) or still pending behind `future` (CHANNEL_FUTURE).

    Args:
        future: object whose result() yields the payload; only read by
            parse() for CHANNEL_FUTURE data.
        pbdata: a ready-made channel_pb2.ChannelData message. When omitted,
            a CHANNEL_FUTURE message with ecode OK is created.
        data_id: id stored in the generated message; required when
            `pbdata` is omitted.
        callback_func: optional callable applied to the future's result
            inside parse().

    Raises:
        ValueError: if both `pbdata` and `data_id` are None.
    """

    def __init__(self,
                 future=None,
                 pbdata=None,
                 data_id=None,
                 callback_func=None):
        self.future = future
        if pbdata is None:
            if data_id is None:
                raise ValueError("data_id cannot be None")
            pbdata = channel_pb2.ChannelData()
            pbdata.type = ChannelDataType.CHANNEL_FUTURE.value
            pbdata.ecode = ChannelDataEcode.OK.value
            pbdata.id = data_id
        self.pbdata = pbdata
        self.callback_func = callback_func

    def parse(self):
        """Return the payload of this item.

        For CHANNEL_PBDATA, returns a dict mapping each inst name to a
        numpy array rebuilt from the inst's raw bytes, dtype string and
        int32-encoded shape. For CHANNEL_FUTURE, returns future.result(),
        passed through `callback_func` when one was given.

        Raises:
            TypeError: if `pbdata.type` is neither of the known types.
        """
        data_type = self.pbdata.type
        if data_type == ChannelDataType.CHANNEL_PBDATA.value:
            feed = {}
            for inst in self.pbdata.insts:
                shape = np.frombuffer(inst.shape, dtype="int32")
                values = np.frombuffer(inst.data, dtype=inst.type)
                feed[inst.name] = values.reshape(tuple(shape))
            return feed
        if data_type == ChannelDataType.CHANNEL_FUTURE.value:
            feed = self.future.result()
            if self.callback_func is not None:
                feed = self.callback_func(feed)
            return feed
        # This class has no _log() helper, so the message is built directly;
        # calling self._log here used to raise AttributeError instead.
        raise TypeError("Error type({}) in pbdata.type.".format(data_type))


B
barrierye 已提交
122
class Channel(Queue.Queue):
    """
    The channel used for communication between Ops.

    1. Supports multiple producers: when more than one Op feeds the
       channel, the items pushed under the same data ID are gathered into
       one ``{op_name: channeldata}`` dict, which is queued once every
       producer has pushed its part.
    2. Supports multiple consumers: an item is only dropped after every
       registered consumer has fetched it, and a consumer never receives
       the same item twice.
    3. (TODO) Timeout and BatchSize are not fully supported.

    Note:
    1. The ID of the data in the channel must be different.
    2. add_producer() and add_consumer() are not thread safe,
       and can only be called during initialization.
    """

    def __init__(self, name=None, maxsize=-1, timeout=None):
        Queue.Queue.__init__(self, maxsize=maxsize)
        self._maxsize = maxsize
        self._timeout = timeout  # stored only; not used by this class yet
        self.name = name
        self._stop = False

        # Guards all bookkeeping below and is used to park blocked
        # producers/consumers until the queue changes or stop() is called.
        self._cv = threading.Condition()

        self._producers = []
        self._producer_res_count = {}  # {data_id: count}
        self._push_res = {}  # {data_id: {op_name: data}}

        self._consumers = {}  # {op_name: idx}
        self._idx_consumer_num = {}  # {idx: num}
        self._consumer_base_idx = 0
        self._front_res = []

    def get_producers(self):
        """Return the list of registered producer names."""
        return self._producers

    def get_consumers(self):
        """Return the names of the registered consumers."""
        return self._consumers.keys()

    def _log(self, info_str):
        """Prefix `info_str` with this channel's name."""
        return "[{}] {}".format(self.name, info_str)

    def debug(self):
        """Return a one-line summary of the producers and consumers."""
        return self._log("p: {}, c: {}".format(self.get_producers(),
                                               self.get_consumers()))

    def add_producer(self, op_name):
        """ not thread safe, and can only be called during initialization. """
        if op_name in self._producers:
            raise ValueError(
                self._log("producer({}) is already in channel".format(op_name)))
        self._producers.append(op_name)

    def add_consumer(self, op_name):
        """ not thread safe, and can only be called during initialization. """
        if op_name in self._consumers:
            raise ValueError(
                self._log("consumer({}) is already in channel".format(op_name)))
        self._consumers[op_name] = 0

        # Every new consumer starts at index 0.
        self._idx_consumer_num[0] = self._idx_consumer_num.get(0, 0) + 1

    def _put_or_wait(self, item):
        """Enqueue `item`, waiting on the condition while the queue is full.

        Must be called with self._cv held. Gives up without enqueueing
        once the channel has been stopped.
        """
        while self._stop is False:
            try:
                self.put_nowait(item)
                return
            except Queue.Full:
                # A full queue raises Full (not Empty); wait for a consumer
                # to take an item and notify.
                self._cv.wait()

    def push(self, channeldata, op_name=None):
        """Feed `channeldata` into the channel on behalf of `op_name`.

        With a single producer the item is queued directly. With several
        producers the items sharing `channeldata.pbdata.id` are collected
        and queued as one dict once the last producer has pushed.

        Returns:
            True.

        Raises:
            Exception: if no producer is registered, or if there are
                several producers and `op_name` is None.
        """
        logging.debug(
            self._log("{} try to push data: {}".format(op_name,
                                                       channeldata.pbdata)))
        producer_num = len(self._producers)
        if producer_num == 0:
            raise Exception(
                self._log(
                    "expected number of producers to be greater than 0, but the it is 0."
                ))
        if producer_num == 1:
            with self._cv:
                self._put_or_wait(channeldata)
                self._cv.notify_all()
            logging.debug(self._log("{} push data succ!".format(op_name)))
            return True
        if op_name is None:
            raise Exception(
                self._log(
                    "There are multiple producers, so op_name cannot be None."))

        data_id = channeldata.pbdata.id
        put_data = None
        with self._cv:
            logging.debug(self._log("{} get lock".format(op_name)))
            if data_id not in self._push_res:
                self._push_res[data_id] = {
                    name: None
                    for name in self._producers
                }
                self._producer_res_count[data_id] = 0
            self._push_res[data_id][op_name] = channeldata
            if self._producer_res_count[data_id] + 1 == producer_num:
                # Last missing part arrived: the bundle is complete.
                put_data = self._push_res.pop(data_id)
                self._producer_res_count.pop(data_id)
            else:
                self._producer_res_count[data_id] += 1

            if put_data is None:
                logging.debug(
                    self._log("{} push data succ, but not push to queue.".
                              format(op_name)))
            else:
                self._put_or_wait(put_data)
                logging.debug(
                    self._log("multi | {} push data succ!".format(op_name)))
            self._cv.notify_all()
        return True

    def front(self, op_name=None):
        """Fetch the next item for consumer `op_name`, blocking until one exists.

        With several consumers the returned object is shared between them
        and must be treated as read only.

        Returns:
            The next item, or None if the channel was stopped before an
            item became available.

        Raises:
            Exception: if no consumer is registered, or if there are
                several consumers and `op_name` is None.
        """
        logging.debug(self._log("{} try to get data".format(op_name)))
        consumer_num = len(self._consumers)
        if consumer_num == 0:
            raise Exception(
                self._log(
                    "expected number of consumers to be greater than 0, but the it is 0."
                ))
        if consumer_num == 1:
            resp = None
            with self._cv:
                while self._stop is False and resp is None:
                    try:
                        resp = self.get_nowait()
                        break
                    except Queue.Empty:
                        self._cv.wait()
                # Wake producers that may be blocked on a full queue.
                self._cv.notify_all()
            logging.debug(self._log("{} get data succ!".format(op_name)))
            return resp
        if op_name is None:
            raise Exception(
                self._log(
                    "There are multiple consumers, so op_name cannot be None."))

        with self._cv:
            # data_idx = consumer_idx - base_idx
            while self._stop is False and self._consumers[
                    op_name] - self._consumer_base_idx >= len(self._front_res):
                try:
                    self._front_res.append(self.get_nowait())
                    break
                except Queue.Empty:
                    self._cv.wait()

            consumer_idx = self._consumers[op_name]
            base_idx = self._consumer_base_idx
            data_idx = consumer_idx - base_idx
            if data_idx >= len(self._front_res):
                # Only reachable when the wait was ended by stop(): there is
                # nothing to hand out, so mirror the single-consumer path.
                return None
            resp = self._front_res[data_idx]
            logging.debug(self._log("{} get data: {}".format(op_name, resp)))

            self._idx_consumer_num[consumer_idx] -= 1
            if consumer_idx == base_idx and self._idx_consumer_num[
                    consumer_idx] == 0:
                # Every consumer has moved past the oldest item: drop it.
                self._idx_consumer_num.pop(consumer_idx)
                self._front_res.pop(0)
                self._consumer_base_idx += 1

            self._consumers[op_name] += 1
            new_consumer_idx = self._consumers[op_name]
            self._idx_consumer_num[new_consumer_idx] = \
                self._idx_consumer_num.get(new_consumer_idx, 0) + 1

            self._cv.notify_all()

        logging.debug(self._log("multi | {} get data succ!".format(op_name)))
        return resp  # reference, read only

    def stop(self):
        """Mark the channel as stopped and wake every blocked push()/front()."""
        #TODO
        # Queue.Queue has no close(); setting the flag under the condition
        # and notifying is what lets waiting threads leave their loops.
        with self._cv:
            self._stop = True
            self._cv.notify_all()

B
barrierye 已提交
313 314 315

class Op(object):
    """One processing node: reads from an input Channel, writes to output Channels.

    Each item goes through preprocess() -> midprocess() -> postprocess().
    midprocess() is only used when a serving client was configured (see
    set_client()); otherwise postprocess() is applied directly to the
    output of preprocess().

    Args:
        name: identifies the type of Op; it must be globally unique.
        inputs: an Op, a list of Ops, or None; the upstream Ops.
        server_model, server_port, device: stored on the instance; not
            used by the methods of this class.
        client_config, server_name, fetch_names: when all three are
            given, a serving client is created and connected.
        concurrency: amount of concurrency, exposed via get_concurrency().
        timeout: seconds allowed for one midprocess() call; values <= 0
            disable the limit.
        retry: maximum number of midprocess() attempts per item.
    """

    def __init__(self,
                 name,
                 inputs,
                 server_model=None,
                 server_port=None,
                 device=None,
                 client_config=None,
                 server_name=None,
                 fetch_names=None,
                 concurrency=1,
                 timeout=-1,
                 retry=2):
        self._run = False
        self.name = name  # to identify the type of OP, it must be globally unique
        self._concurrency = concurrency  # amount of concurrency
        self.set_input_ops(inputs)
        self.set_client(client_config, server_name, fetch_names)
        self._server_model = server_model
        self._server_port = server_port
        self._device = device
        self._timeout = timeout
        self._retry = retry
        self._input = None
        self._outputs = []

    def set_client(self, client_config, server_name, fetch_names):
        """Create and connect the serving client when fully configured.

        If any of the three arguments is None, no client is created and
        with_serving() returns False.
        """
        self._client = None
        # Always defined so midprocess() never hits a missing attribute.
        self._fetch_names = fetch_names
        if client_config is None or \
                server_name is None or \
                fetch_names is None:
            return
        self._client = Client()
        self._client.load_client_config(client_config)
        self._client.connect([server_name])

    def with_serving(self):
        """Return True when a serving client is configured."""
        return self._client is not None

    def get_input_channel(self):
        return self._input

    def get_input_ops(self):
        return self._input_ops

    def set_input_ops(self, ops):
        """Store the upstream Ops; accepts None, a single Op or a list of Ops.

        Raises:
            TypeError: if any element is not an Op.
        """
        if not isinstance(ops, list):
            ops = [] if ops is None else [ops]
        self._input_ops = []
        for op in ops:
            if not isinstance(op, Op):
                raise TypeError(
                    self._log('input op must be Op type, not {}'.format(
                        type(op))))
            self._input_ops.append(op)

    def add_input_channel(self, channel):
        """Register this Op as a consumer of `channel` and use it as input."""
        if not isinstance(channel, Channel):
            raise TypeError(
                self._log('input channel must be Channel type, not {}'.format(
                    type(channel))))
        channel.add_consumer(self.name)
        self._input = channel

    def get_output_channels(self):
        return self._outputs

    def add_output_channel(self, channel):
        """Register this Op as a producer of `channel` and add it to the outputs."""
        if not isinstance(channel, Channel):
            raise TypeError(
                self._log('output channel must be Channel type, not {}'.format(
                    type(channel))))
        channel.add_producer(self.name)
        self._outputs.append(channel)

    def preprocess(self, channeldata):
        """Turn one input item into the feed passed on to the next stage.

        The default implementation only handles a single ChannelData; an
        Op fed by several producers receives a dict and must override this.
        """
        if isinstance(channeldata, dict):
            raise Exception(
                self._log(
                    'this Op has multiple previous inputs. Please override this method'
                ))
        feed = channeldata.parse()
        return feed

    def midprocess(self, data):
        """Send `data` to the serving client and return the resulting future."""
        if not isinstance(data, dict):
            raise Exception(
                self._log(
                    'data must be dict type(the output of preprocess()), but get {}'.
                    format(type(data))))
        logging.debug(self._log('data: {}'.format(data)))
        logging.debug(self._log('fetch: {}'.format(self._fetch_names)))
        call_future = self._client.predict(
            feed=data, fetch=self._fetch_names, asyn=True)
        logging.debug(self._log("get call_future"))
        return call_future

    def postprocess(self, output_data):
        """Hook applied to the result; the default returns it unchanged."""
        return output_data

    def errorprocess(self, error_info, data_id):
        """Build the protobuf message reporting `error_info` for `data_id`."""
        data = channel_pb2.ChannelData()
        data.ecode = 1  # any non-zero ecode marks the item as failed
        data.id = data_id
        data.error_info = error_info
        return data

    def stop(self):
        """Stop the attached channels and make start() leave its loop."""
        if self._input is not None:
            self._input.stop()
        for channel in self._outputs:
            channel.stop()
        self._run = False

    def _parse_channeldata(self, channeldata):
        """Return (data_id, error_pbdata) for one input item.

        `channeldata` is either a ChannelData or, for multi-producer
        channels, a dict of them. `error_pbdata` is the pbdata of the
        first element whose ecode is non-zero, or None when all are OK.
        """
        if isinstance(channeldata, dict):
            parts = list(channeldata.values())
            data_id = parts[0].pbdata.id
            for part in parts:
                if part.pbdata.ecode != 0:
                    # Return the protobuf message (not the wrapper) so the
                    # caller can pass it as `pbdata`, as in the other branch.
                    return data_id, part.pbdata
            return data_id, None
        pbdata = channeldata.pbdata
        return pbdata.id, (pbdata if pbdata.ecode != 0 else None)

    def _call_midprocess(self, data, concurrency_idx):
        """Run midprocess() with timeout and retry; return (future, error_info)."""
        call_future, error_info = None, None
        for attempt in range(self._retry):
            error_info = None
            _profiler.record("{}{}-midp_0".format(self.name, concurrency_idx))
            if self._timeout > 0:
                try:
                    call_future = func_timeout.func_timeout(
                        self._timeout, self.midprocess, args=(data, ))
                except func_timeout.FunctionTimedOut:
                    logging.error("error: timeout")
                    error_info = "{}({}): timeout".format(self.name,
                                                          concurrency_idx)
                except Exception as e:
                    logging.error("error: {}".format(e))
                    error_info = "{}({}): {}".format(self.name,
                                                     concurrency_idx, e)
            else:
                call_future = self.midprocess(data)
            _profiler.record("{}{}-midp_1".format(self.name, concurrency_idx))
            if error_info is None:
                # Success: do not issue the request again.
                break
            if attempt + 1 < self._retry:
                logging.warning(
                    self._log("warn: timeout, retry({})".format(attempt + 1)))
        return call_future, error_info

    def _pack_output(self, data, data_id, call_future, error_info):
        """Wrap the outcome of one item into the ChannelData to push."""
        if error_info is not None:
            return ChannelData(pbdata=self.errorprocess(error_info, data_id))
        if self.with_serving():  # use call_future
            return ChannelData(
                future=call_future,
                data_id=data_id,
                callback_func=self.postprocess)
        post_data = self.postprocess(data)
        if not isinstance(post_data, dict):
            raise TypeError(
                self._log('output_data must be dict type, but get {}'.format(
                    type(post_data))))
        pbdata = channel_pb2.ChannelData()
        for name, value in post_data.items():
            inst = channel_pb2.Inst()
            inst.data = value.tobytes()
            inst.name = name
            inst.shape = np.array(value.shape, dtype="int32").tobytes()
            inst.type = str(value.dtype)
            pbdata.insts.append(inst)
        pbdata.ecode = 0
        pbdata.id = data_id
        return ChannelData(pbdata=pbdata)

    def start(self, concurrency_idx):
        """Consume the input channel until stopped, pushing results downstream.

        Args:
            concurrency_idx: index of this worker; only used in profiler
                tags and error messages.
        """
        self._run = True
        while self._run:
            _profiler.record("{}{}-get_0".format(self.name, concurrency_idx))
            input_data = self._input.front(self.name)
            _profiler.record("{}{}-get_1".format(self.name, concurrency_idx))
            logging.debug(self._log("input_data: {}".format(input_data)))
            if input_data is None:
                # front() hands out None once the channel is stopped.
                break

            data_id, error_data = self._parse_channeldata(input_data)

            if error_data is None:
                _profiler.record("{}{}-prep_0".format(self.name,
                                                      concurrency_idx))
                data = self.preprocess(input_data)
                _profiler.record("{}{}-prep_1".format(self.name,
                                                      concurrency_idx))

                call_future, error_info = None, None
                if self.with_serving():
                    call_future, error_info = self._call_midprocess(
                        data, concurrency_idx)
                _profiler.record("{}{}-postp_0".format(self.name,
                                                       concurrency_idx))
                output_data = self._pack_output(data, data_id, call_future,
                                                error_info)
                _profiler.record("{}{}-postp_1".format(self.name,
                                                       concurrency_idx))
            else:
                # Forward an upstream failure untouched.
                output_data = ChannelData(pbdata=error_data)

            _profiler.record("{}{}-push_0".format(self.name, concurrency_idx))
            for channel in self._outputs:
                channel.push(output_data, self.name)
            _profiler.record("{}{}-push_1".format(self.name, concurrency_idx))

    def _log(self, info_str):
        """Prefix `info_str` with this Op's name."""
        return "[{}] {}".format(self.name, info_str)

    def get_concurrency(self):
        return self._concurrency

B
barrierye 已提交
537 538 539

class GeneralPythonService(
        general_python_service_pb2_grpc.GeneralPythonService):
    """Service front end: pushes requests into the graph and returns results.

    inference() packs each request into a ChannelData with a fresh id and
    pushes it into `in_channel`. A background thread keeps draining
    `out_channel` into a dict keyed by data id, from which inference()
    picks up the matching result.

    Args:
        in_channel: Channel this service produces into.
        out_channel: Channel this service consumes from.
        retry: maximum number of push/wait rounds per request.
    """

    def __init__(self, in_channel, out_channel, retry=2):
        super(GeneralPythonService, self).__init__()
        self.name = "#G"
        self.set_in_channel(in_channel)
        self.set_out_channel(out_channel)
        logging.debug(self._log(in_channel.debug()))
        logging.debug(self._log(out_channel.debug()))
        #TODO:
        #  multi-lock for different clients
        #  different lock for server and client
        self._id_lock = threading.Lock()  # guards _id_counter
        self._cv = threading.Condition()  # guards _globel_resp_dict
        self._globel_resp_dict = {}  # {data_id: ChannelData}
        self._id_counter = 0
        self._retry = retry
        self._recive_func = threading.Thread(
            target=GeneralPythonService._recive_out_channel_func, args=(self, ))
        self._recive_func.start()

    def _log(self, info_str):
        """Prefix `info_str` with this service's name."""
        return "[{}] {}".format(self.name, info_str)

    def set_in_channel(self, in_channel):
        """Register this service as a producer of `in_channel`."""
        if not isinstance(in_channel, Channel):
            raise TypeError(
                self._log('in_channel must be Channel type, but get {}'.format(
                    type(in_channel))))
        in_channel.add_producer(self.name)
        self._in_channel = in_channel

    def set_out_channel(self, out_channel):
        """Register this service as a consumer of `out_channel`."""
        if not isinstance(out_channel, Channel):
            raise TypeError(
                self._log('out_channel must be Channel type, but get {}'.format(
                    type(out_channel))))
        out_channel.add_consumer(self.name)
        self._out_channel = out_channel

    def _recive_out_channel_func(self):
        """Thread body: move every item of the output channel into the dict."""
        while True:
            channeldata = self._out_channel.front(self.name)
            if not isinstance(channeldata, ChannelData):
                raise TypeError(
                    self._log('data must be ChannelData type, but get {}'.
                              format(type(channeldata))))
            with self._cv:
                data_id = channeldata.pbdata.id
                self._globel_resp_dict[data_id] = channeldata
                self._cv.notify_all()

    def _get_next_id(self):
        """Return a new data id; ids count up from 0."""
        with self._id_lock:
            self._id_counter += 1
            return self._id_counter - 1

    def _get_data_in_globel_resp_dict(self, data_id):
        """Block until the result for `data_id` arrives, then remove and return it."""
        with self._cv:
            while data_id not in self._globel_resp_dict:
                self._cv.wait()
            resp = self._globel_resp_dict.pop(data_id)
            self._cv.notify_all()
        return resp

    def _pack_data_for_infer(self, request):
        """Convert `request` into (ChannelData, data_id) for the input channel."""
        logging.debug(self._log('start inferce'))
        pbdata = channel_pb2.ChannelData()
        data_id = self._get_next_id()
        pbdata.id = data_id
        for idx, name in enumerate(request.feed_var_names):
            logging.debug(
                self._log('name: {}'.format(request.feed_var_names[idx])))
            logging.debug(self._log('data: {}'.format(request.feed_insts[idx])))
            inst = channel_pb2.Inst()
            inst.data = request.feed_insts[idx]
            inst.shape = request.shape[idx]
            inst.name = name
            inst.type = request.type[idx]
            pbdata.insts.append(inst)
        pbdata.ecode = 0  #TODO: parse request error
        return ChannelData(pbdata=pbdata), data_id

    def _pack_data_for_resp(self, channeldata):
        """Convert the final ChannelData of a request into a Response message.

        Raises:
            TypeError: if the item succeeded but its pbdata.type is unknown.
        """
        logging.debug(self._log('get channeldata'))
        logging.debug(self._log('gen resp'))
        resp = pyservice_pb2.Response()
        pbdata = channeldata.pbdata
        resp.ecode = pbdata.ecode
        if resp.ecode != 0:
            resp.error_info = pbdata.error_info
            return resp

        if pbdata.type == ChannelDataType.CHANNEL_PBDATA.value:
            for inst in pbdata.insts:
                logging.debug(self._log('append data'))
                resp.fetch_insts.append(inst.data)
                logging.debug(self._log('append name'))
                resp.fetch_var_names.append(inst.name)
                logging.debug(self._log('append shape'))
                resp.shape.append(inst.shape)
                logging.debug(self._log('append type'))
                resp.type.append(inst.type)
        elif pbdata.type == ChannelDataType.CHANNEL_FUTURE.value:
            # ChannelData stores the pending call as `future` (singular).
            feed = channeldata.future.result()
            if channeldata.callback_func is not None:
                feed = channeldata.callback_func(feed)
            # Iterate name/array pairs; iterating the dict itself would
            # yield only the keys.
            for name, var in feed.items():
                logging.debug(self._log('append data'))
                resp.fetch_insts.append(var.tobytes())
                logging.debug(self._log('append name'))
                resp.fetch_var_names.append(name)
                logging.debug(self._log('append shape'))
                resp.shape.append(
                    np.array(
                        var.shape, dtype="int32").tobytes())
                resp.type.append(str(var.dtype))
        else:
            raise TypeError(
                self._log("Error type({}) in pbdata.type.".format(
                    pbdata.type)))
        return resp

    def inference(self, request, context):
        """Run one request through the graph and return its Response.

        The request is pushed again (up to `retry` rounds in total) while
        the returned item carries a non-zero ecode.
        """
        _profiler.record("{}-prepack_0".format(self.name))
        data, data_id = self._pack_data_for_infer(request)
        _profiler.record("{}-prepack_1".format(self.name))

        resp_channeldata = None
        for i in range(self._retry):
            logging.debug(self._log('push data'))
            _profiler.record("{}-push_0".format(self.name))
            self._in_channel.push(data, self.name)
            _profiler.record("{}-push_1".format(self.name))

            logging.debug(self._log('wait for infer'))
            _profiler.record("{}-fetch_0".format(self.name))
            resp_channeldata = self._get_data_in_globel_resp_dict(data_id)
            _profiler.record("{}-fetch_1".format(self.name))

            if resp_channeldata.pbdata.ecode == 0:
                break
            logging.warning("retry({}): {}".format(
                i + 1, resp_channeldata.pbdata.error_info))

        _profiler.record("{}-postpack_0".format(self.name))
        resp = self._pack_data_for_resp(resp_channeldata)
        _profiler.record("{}-postpack_1".format(self.name))
        _profiler.print_profile()
        return resp

B
barrierye 已提交
688

689 690 691 692
class VirtualOp(Op):
    """An Op subclass that adds no behavior of its own.

    NOTE(review): presumably a placeholder node used while building the
    Op graph -- confirm against the code that instantiates it.
    """


B
barrierye 已提交
693
class PyServer(object):
B
barrierye 已提交
694
    def __init__(self, retry=2, profile=False):
        """Create an empty server; Ops and channels are attached afterwards.

        Args:
            retry: stored as ``self._retry`` and later handed to
                ``GeneralPythonService`` by ``run_server``.
            profile: forwarded to the module-level ``_profiler.enable``.
        """
        self._channels = []
        self._user_ops = []
        self._total_ops = []
        # `_run_ops` and `_stop_ops` iterate `self._ops`, which is otherwise
        # only assigned by `_topo_sort`; start with an empty list so those
        # methods do not fail with AttributeError on an unprepared server.
        self._ops = []
        self._op_threads = []
        self._port = None
        self._worker_num = None
        self._in_channel = None
        self._out_channel = None
        self._retry = retry
        _profiler.enable(profile)
B
barrierye 已提交
705 706 707 708 709

    def add_channel(self, channel):
        self._channels.append(channel)

    def add_op(self, op):
710 711 712 713
        self._user_ops.append(op)

    def add_ops(self, ops):
        self._user_ops.expand(ops)
B
barrierye 已提交
714 715

    def gen_desc(self):
716
        logging.info('here will generate desc for PAAS')
B
barrierye 已提交
717 718
        pass

719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832
    def _topo_sort(self):
        indeg_num = {}
        outdegs = {}
        que_idx = 0  # scroll queue 
        ques = [Queue.SimpleQueue() for _ in range(2)]
        for idx, op in enumerate(self._user_ops):
            # check the name of op is globally unique
            if op.name in indeg_num:
                raise Exception("the name of Op must be unique")
            indeg_num[op.name] = len(op.get_input_ops())
            if indeg_num[op.name] == 0:
                ques[que_idx].put(op)
            for pred_op in op.get_input_ops():
                if op.name in outdegs:
                    outdegs[op.name].append(op)
                else:
                    outdegs[op.name] = [op]

        # get dag_views
        dag_views = []
        sorted_op_num = 0
        while True:
            que = ques[que_idx]
            next_que = ques[(que_idx + 1) % 2]
            dag_view = []
            while que.qsize() != 0:
                op = que.get()
                dag_view.append(op)
                op_name = op.name
                sorted_op_num += 1
                for succ_op in outdegs[op_name]:
                    indeg_num[op_name] -= 1
                    if indeg_num[succ_op.name] == 0:
                        next_que.put(succ_op)
            dag_views.append(dag_view)
            if next_que.qsize() == 0:
                break
            que_idx = (que_idx + 1) % 2
        if sorted_op_num < len(self._user_ops):
            raise Exception("not legal DAG")
        if len(dag_views[0]) != 1:
            raise Exception("DAG contains multiple input Ops")
        if len(dag_views[-1]) != 1:
            raise Exception("DAG contains multiple output Ops")

        # create channels and virtual ops
        virtual_op_idx = 0
        channel_idx = 0
        virtual_ops = []
        channels = []
        input_channel = None
        for v_idx, view in enumerate(dag_views):
            if v_idx + 1 >= len(dag_views):
                break
            next_view = dag_views[v_idx + 1]
            actual_next_view = []
            pred_op_of_next_view_op = {}
            for op in view:
                # create virtual op
                for succ_op in outdegs[op.name]:
                    if succ_op in next_view:
                        actual_next_view.append(succ_op)
                        if succ_op.name not in pred_op_of_next_view_op:
                            pred_op_of_next_view_op[succ_op.name] = []
                        pred_op_of_next_view_op[succ_op.name].append(op)
                    else:
                        vop = VirtualOp(name="vir{}".format(virtual_op_idx))
                        virtual_op_idx += 1
                        virtual_ops.append(virtual_op)
                        outdegs[vop.name] = [succ_op]
                        actual_next_view.append(vop)
                        # TODO: combine vop
                        pred_op_of_next_view_op[vop.name] = [op]
            # create channel
            processed_op = set()
            for o_idx, op in enumerate(actual_next_view):
                op_name = op.name
                if op_name in processed_op:
                    continue
                channel = Channel(name="chl{}".format(channel_idx))
                channel_idx += 1
                channels.append(channel)
                op.add_input_channel(channel)
                pred_ops = pred_op_of_next_view_op[op_name]
                if v_idx == 0:
                    input_channel = channel
                else:
                    for pred_op in pred_ops:
                        pred_op.add_output_channel(channel)
                processed_op.add(op_name)
                # combine channel
                for other_op in actual_next_view[o_idx:]:
                    if other_op.name in processed_op:
                        continue
                    other_pred_ops = pred_op_of_next_view_op[other_op.name]
                    if len(other_pred_ops) != len(pred_ops):
                        continue
                    same_flag = True
                    for pred_op in pred_ops:
                        if pred_op not in other_pred_ops:
                            same_flag = False
                            break
                    if same_flag:
                        other_op.add_input_channel(channel)
                        processed_op.add(other_op.name)
        output_channel = Channel(name="Ochl")
        channels.append(output_channel)
        last_op = dag_views[-1][0]
        last_op.add_output_channel(output_channel)

        self._ops = self._user_ops + virtual_ops
        self._channels = channels
        return input_channel, output_channel

B
barrierye 已提交
833 834 835
    def prepare_server(self, port, worker_num):
        self._port = port
        self._worker_num = worker_num
836 837 838 839

        input_channel, output_channel = self._topo_sort()
        self._in_channel = input_channel
        self.out_channel = output_channel
B
barrierye 已提交
840 841
        self.gen_desc()

B
barrierye 已提交
842 843
    def _op_start_wrapper(self, op, concurrency_idx):
        return op.start(concurrency_idx)
B
barrierye 已提交
844

845
    def _run_ops(self):
846
        #TODO
B
barrierye 已提交
847
        for op in self._ops:
848
            op_concurrency = op.get_concurrency()
849
            logging.debug("run op: {}, op_concurrency: {}".format(
850
                op.name, op_concurrency))
851
            for c in range(op_concurrency):
B
barrierye 已提交
852
                # th = multiprocessing.Process(
853
                th = threading.Thread(
B
barrierye 已提交
854
                    target=self._op_start_wrapper, args=(op, c))
855 856 857
                th.start()
                self._op_threads.append(th)

858
    def _stop_ops(self):
859
        # TODO
860 861 862
        for op in self._ops:
            op.stop()

863 864
    def run_server(self):
        """Start the Op threads, serve gRPC until termination, then clean up.

        A gRPC server with ``self._worker_num`` pool threads is bound to
        ``self._port`` and serves a ``GeneralPythonService`` built from the
        input/output channels and the retry count. The call blocks in
        ``wait_for_termination``. Afterwards every Op is stopped and the Op
        threads are joined.
        """
        self._run_ops()
        try:
            server = grpc.server(
                futures.ThreadPoolExecutor(max_workers=self._worker_num))
            general_python_service_pb2_grpc.add_GeneralPythonServiceServicer_to_server(
                GeneralPythonService(self._in_channel, self._out_channel,
                                     self._retry), server)
            server.add_insecure_port('[::]:{}'.format(self._port))
            server.start()
            server.wait_for_termination()
        finally:
            # Run the shutdown even when serving is interrupted (e.g. by
            # KeyboardInterrupt) or fails to start; otherwise the already
            # running Op threads are never asked to stop.
            self._stop_ops()  # TODO
            for th in self._op_threads:
                th.join()
B
barrierye 已提交
876 877 878 879 880 881 882

    def prepare_serving(self, op):
        model_path = op._server_model
        port = op._server_port
        device = op._device

        if device == "cpu":
883 884
            cmd = "(Use MultiLangServer) python -m paddle_serving_server.serve" \
                  " --model {} --thread 4 --port {} --use_multilang &>/dev/null &".format(model_path, port)
B
barrierye 已提交
885
        else:
886 887
            cmd = "(Use MultiLangServer) python -m paddle_serving_server_gpu.serve" \
                  " --model {} --thread 4 --port {} --use_multilang &>/dev/null &".format(model_path, port)
888 889
        # run a server (not in PyServing)
        logging.info("run a server (not in PyServing): {}".format(cmd))
B
barrierye 已提交
890
        return
891
        # os.system(cmd)