# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
import threading
import multiprocessing
B
barrierye 已提交
17
import Queue
B
barrierye 已提交
18
import os
B
barrierye 已提交
19
import sys
B
barrierye 已提交
20 21 22
import paddle_serving_server
from paddle_serving_client import Client
from concurrent import futures
B
barrierye 已提交
23
import numpy as np
B
barrierye 已提交
24 25 26
import grpc
import general_python_service_pb2
import general_python_service_pb2_grpc
B
barrierye 已提交
27
import python_service_channel_pb2
B
barrierye 已提交
28
import logging
29
import random
B
barrierye 已提交
30
import time
B
barrierye 已提交
31
import func_timeout
B
barrierye 已提交
32 33


B
barrierye 已提交
34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69
class _TimeProfiler(object):
    """Collect named timestamps and print them in begin/end pairs.

    Each record is identified by a string of the form ``<name>_<tag>``
    (for example ``op0-get_0`` / ``op0-get_1``).  Two records sharing the
    same ``<name>`` are printed together by print_profile(); a record whose
    partner has not arrived yet is kept for the next call.

    Profiling is off by default; while it is off, record() and
    print_profile() do nothing.
    """

    def __init__(self):
        self._pid = os.getpid()
        self._print_head = 'PROFILE\tpid:{}\t'.format(self._pid)
        self._time_record = Queue.Queue()
        self._enable = False

    def enable(self, enable):
        """Turn profiling on (True) or off (False)."""
        self._enable = enable

    def record(self, name_with_tag):
        """Store the current time (in microseconds) under ``name_with_tag``.

        The text after the last underscore is the tag, everything before
        it is the name used to pair records.
        """
        if not self._enable:
            # Without this guard the queue grows even when nobody asked
            # for profiling.
            return
        name, _, tag = name_with_tag.rpartition("_")
        timestamp = int(round(time.time() * 1000000))
        self._time_record.put((name, tag, timestamp))

    def print_profile(self):
        """Write all completed record pairs to stderr on a single line."""
        if not self._enable:
            return
        sys.stderr.write(self._print_head)
        pending = {}  # {name: (tag, timestamp)} still waiting for a partner
        while not self._time_record.empty():
            name, tag, timestamp = self._time_record.get()
            if name in pending:
                ptag, ptimestamp = pending.pop(name)
                sys.stderr.write("{}_{}:{} ".format(name, ptag, ptimestamp))
                sys.stderr.write("{}_{}:{} ".format(name, tag, timestamp))
            else:
                pending[name] = (tag, timestamp)
        sys.stderr.write('\n')
        # Unpaired records go back into the queue so that a later call can
        # still match them.
        for name, (tag, timestamp) in pending.items():
            self._time_record.put((name, tag, timestamp))


_profiler = _TimeProfiler()


B
barrierye 已提交
70
class Channel(Queue.Queue):
    """
    The channel used for communication between Ops.

    1. Support multiple different Op feed data (multiple producer)
        Different types of data will be packaged through the data ID
    2. Support multiple different Op fetch data (multiple consumer)
        Only when all types of Ops get the data of the same ID,
        the data will be popped; The Op of the same type will not
        get the data of the same ID.
    3. (TODO) Timeout and BatchSize are not fully supported.

    Note:
    1. The ID of the data in the channel must be different.
    2. The function add_producer() and add_consumer() are not thread safe,
       and can only be called during initialization.
    """

    def __init__(self, name=None, maxsize=-1, timeout=None):
        Queue.Queue.__init__(self, maxsize=maxsize)
        self._maxsize = maxsize
        self._timeout = timeout
        self._name = name

        # Guards every piece of bookkeeping below and is used to park
        # producers/consumers while the queue is full/empty.
        self._cv = threading.Condition()

        self._producers = []
        self._producer_res_count = {}  # {data_id: count}
        self._push_res = {}  # {data_id: {op_name: data}}

        self._front_wait_interval = 0.1  # second
        self._consumers = {}  # {op_name: idx}
        self._idx_consumer_num = {}  # {idx: num}
        self._consumer_base_idx = 0
        self._front_res = []

    def get_producers(self):
        return self._producers

    def get_consumers(self):
        return self._consumers.keys()

    def _log(self, info_str):
        return "[{}] {}".format(self._name, info_str)

    def debug(self):
        return self._log("p: {}, c: {}".format(self.get_producers(),
                                               self.get_consumers()))

    def add_producer(self, op_name):
        """ not thread safe, and can only be called during initialization. """
        if op_name in self._producers:
            raise ValueError(
                self._log("producer({}) is already in channel".format(op_name)))
        self._producers.append(op_name)

    def add_consumer(self, op_name):
        """ not thread safe, and can only be called during initialization. """
        if op_name in self._consumers:
            raise ValueError(
                self._log("consumer({}) is already in channel".format(op_name)))
        self._consumers[op_name] = 0

        # Every new consumer starts at index 0.
        self._idx_consumer_num[0] = self._idx_consumer_num.get(0, 0) + 1

    def _put_or_wait(self, item):
        """Put ``item`` into the queue, waiting on the condition while full.

        Must be called with ``self._cv`` held.  A full queue makes put()
        raise Queue.Full (not Queue.Empty), so that is what is caught here.
        """
        while True:
            try:
                self.put(item, timeout=0)
                return
            except Queue.Full:
                self._cv.wait()

    def push(self, data, op_name=None):
        """Feed ``data`` into the channel on behalf of producer ``op_name``.

        With a single producer the data is queued directly.  With several
        producers the pieces sharing the same ``data.id`` are gathered into
        a ``{op_name: data}`` dict, which is queued once every producer has
        delivered its piece.

        Returns:
            True once the data has been accepted.

        Raises:
            Exception: if no producer is registered, or if there are
                several producers and ``op_name`` is None.
        """
        logging.debug(
            self._log("{} try to push data: {}".format(op_name, data)))
        if len(self._producers) == 0:
            raise Exception(
                self._log(
                    "expected number of producers to be greater than 0, but the it is 0."
                ))
        elif len(self._producers) == 1:
            with self._cv:
                self._put_or_wait(data)
                self._cv.notify_all()
            logging.debug(self._log("{} push data succ!".format(op_name)))
            return True
        elif op_name is None:
            raise Exception(
                self._log(
                    "There are multiple producers, so op_name cannot be None."))

        producer_num = len(self._producers)
        data_id = data.id
        put_data = None
        with self._cv:
            logging.debug(self._log("{} get lock ~".format(op_name)))
            if data_id not in self._push_res:
                self._push_res[data_id] = {
                    name: None
                    for name in self._producers
                }
                self._producer_res_count[data_id] = 0
            self._push_res[data_id][op_name] = data
            if self._producer_res_count[data_id] + 1 == producer_num:
                # Last missing piece: hand the complete dict over.
                put_data = self._push_res[data_id]
                self._push_res.pop(data_id)
                self._producer_res_count.pop(data_id)
            else:
                self._producer_res_count[data_id] += 1

            if put_data is None:
                logging.debug(
                    self._log("{} push data succ, not not push to queue.".
                              format(op_name)))
            else:
                self._put_or_wait(put_data)
                logging.debug(
                    self._log("multi | {} push data succ!".format(op_name)))
            self._cv.notify_all()
        return True

    def front(self, op_name=None):
        """Fetch the next piece of data for consumer ``op_name``.

        With a single consumer the data is simply taken from the queue.
        With several consumers each piece of data is handed to every
        consumer exactly once and is dropped after the last one has read
        it.  The call waits until data is available.

        Returns:
            The data object (a shared reference, to be treated as read
            only when there are several consumers).

        Raises:
            Exception: if no consumer is registered, or if there are
                several consumers and ``op_name`` is None.
        """
        logging.debug(self._log("{} try to get data".format(op_name)))
        if len(self._consumers) == 0:
            raise Exception(
                self._log(
                    "expected number of consumers to be greater than 0, but the it is 0."
                ))
        elif len(self._consumers) == 1:
            resp = None
            with self._cv:
                while resp is None:
                    try:
                        resp = self.get(timeout=0)
                        break
                    except Queue.Empty:
                        self._cv.wait()
                # A slot was just freed: wake producers that are waiting
                # for room in a bounded queue.
                self._cv.notify_all()
            logging.debug(self._log("{} get data succ!".format(op_name)))
            return resp
        elif op_name is None:
            raise Exception(
                self._log(
                    "There are multiple consumers, so op_name cannot be None."))

        with self._cv:
            # data_idx = consumer_idx - base_idx
            while self._consumers[op_name] - self._consumer_base_idx >= len(
                    self._front_res):
                try:
                    data = self.get(timeout=0)
                    self._front_res.append(data)
                    break
                except Queue.Empty:
                    self._cv.wait()

            consumer_idx = self._consumers[op_name]
            base_idx = self._consumer_base_idx
            data_idx = consumer_idx - base_idx
            resp = self._front_res[data_idx]
            logging.debug(self._log("{} get data: {}".format(op_name, resp)))

            # This consumer leaves consumer_idx; when it was the last one
            # on the oldest buffered item, that item can be released.
            self._idx_consumer_num[consumer_idx] -= 1
            if consumer_idx == base_idx and self._idx_consumer_num[
                    consumer_idx] == 0:
                self._idx_consumer_num.pop(consumer_idx)
                self._front_res.pop(0)
                self._consumer_base_idx += 1

            self._consumers[op_name] += 1
            new_consumer_idx = self._consumers[op_name]
            self._idx_consumer_num[new_consumer_idx] = \
                self._idx_consumer_num.get(new_consumer_idx, 0) + 1

            self._cv.notify_all()

        logging.debug(self._log("multi | {} get data succ!".format(op_name)))
        return resp  # reference, read only
B
barrierye 已提交
254 255 256 257


class Op(object):
    """One processing stage that reads from a Channel and feeds others.

    An Op repeatedly takes data from its input channel, runs
    preprocess() -> midprocess() (only when a serving client is
    configured) -> postprocess(), and pushes the result to every output
    channel.  Data already flagged as an error is forwarded untouched.

    Args:
        name: identifies the type of Op; it must be globally unique.
        input: the Channel this Op consumes from.
        in_dtype: dtype used by the default preprocess() to decode
            instance buffers.
        outputs: list of Channels this Op produces into.
        out_dtype: stored as ``_out_dtype``; not used by this class.
        server_model, server_port, device: description of the serving
            backend; read by PyServer.prepare_serving().
        client_config, server_name, fetch_names: when all three are given
            a serving client is created via set_client().
        concurrency: number of worker threads to run for this Op.
        timeout: per-call timeout in seconds for midprocess(); values
            <= 0 disable the timeout.
        retry: maximum number of midprocess() attempts.
    """

    def __init__(self,
                 name,
                 input,
                 in_dtype,
                 outputs,
                 out_dtype,
                 server_model=None,
                 server_port=None,
                 device=None,
                 client_config=None,
                 server_name=None,
                 fetch_names=None,
                 concurrency=1,
                 timeout=-1,
                 retry=2):
        self._run = False
        # TODO: globally unique check
        self._name = name  # to identify the type of OP, it must be globally unique
        self._concurrency = concurrency  # amount of concurrency
        self.set_input(input)
        self._in_dtype = in_dtype
        self.set_outputs(outputs)
        self._out_dtype = out_dtype
        self._client = None
        if client_config is not None and \
                server_name is not None and \
                fetch_names is not None:
            self.set_client(client_config, server_name, fetch_names)
        self._server_model = server_model
        self._server_port = server_port
        self._device = device
        self._timeout = timeout
        self._retry = retry

    def set_client(self, client_config, server_name, fetch_names):
        """Create the serving client and remember which variables to fetch."""
        self._client = Client()
        self._client.load_client_config(client_config)
        self._client.connect([server_name])
        self._fetch_names = fetch_names

    def with_serving(self):
        """Return True when a serving client has been configured."""
        return self._client is not None

    def get_input(self):
        return self._input

    def set_input(self, channel):
        """Register this Op as a consumer of ``channel`` and store it.

        Raises:
            TypeError: if ``channel`` is not a Channel.
        """
        if not isinstance(channel, Channel):
            raise TypeError(
                self._log('input channel must be Channel type, not {}'.format(
                    type(channel))))
        channel.add_consumer(self._name)
        self._input = channel

    def get_outputs(self):
        return self._outputs

    def set_outputs(self, channels):
        """Register this Op as a producer of every channel in ``channels``.

        Raises:
            TypeError: if ``channels`` is not a list.
        """
        if not isinstance(channels, list):
            raise TypeError(
                self._log('output channels must be list type, not {}'.format(
                    type(channels))))
        for channel in channels:
            channel.add_producer(self._name)
        self._outputs = channels

    def preprocess(self, data):
        """Turn channel data into a ``{name: ndarray}`` feed dict.

        The default implementation only handles a single upstream input;
        an Op fed by several producers receives a dict and must override
        this method.
        """
        if isinstance(data, dict):
            raise Exception(
                self._log(
                    'this Op has multiple previous inputs. Please override this method'
                ))
        feed = {}
        for inst in data.insts:
            feed[inst.name] = np.frombuffer(inst.data, dtype=self._in_dtype)
        return feed

    def midprocess(self, data):
        """Run prediction on the serving client with the preprocessed feed.

        Raises:
            Exception: if ``data`` is not a dict.
        """
        if not isinstance(data, dict):
            raise Exception(
                self._log(
                    'data must be dict type(the output of preprocess()), but get {}'.
                    format(type(data))))
        logging.debug(self._log('data: {}'.format(data)))
        logging.debug(self._log('fetch: {}'.format(self._fetch_names)))
        fetch_map = self._client.predict(feed=data, fetch=self._fetch_names)
        logging.debug(self._log("finish predict"))
        return fetch_map

    def postprocess(self, output_data):
        """Convert the processing result to ChannelData; must be overridden."""
        raise Exception(
            self._log(
                'Please override this method to convert data to the format in channel.'
            ))

    def errorprocess(self, error_info):
        """Build a ChannelData object that carries ``error_info``."""
        data = python_service_channel_pb2.ChannelData()
        data.is_error = 1
        data.error_info = error_info
        return data

    def stop(self):
        """Ask the start() loop to exit after its current iteration."""
        self._run = False

    def start(self, concurrency_idx):
        """Run the processing loop until stop() is called.

        Args:
            concurrency_idx: index of this worker among the Op's workers;
                used in profiler tags and error messages.

        Raises:
            TypeError: if postprocess() does not return ChannelData.
        """
        self._run = True
        tag = "{}{}".format(self._name, concurrency_idx)
        while self._run:
            _profiler.record(tag + "-get_0")
            input_data = self._input.front(self._name)
            _profiler.record(tag + "-get_1")
            data_id = None
            output_data = None
            error_data = None
            logging.debug(self._log("input_data: {}".format(input_data)))
            if isinstance(input_data, dict):
                # Several upstream producers: all pieces share one id.
                # next(iter(...)) instead of keys()[0], which is
                # Python 2 only.
                data_id = input_data[next(iter(input_data))].id
                for item in input_data.values():
                    if item.is_error != 0:
                        error_data = item
                        break
            else:
                data_id = input_data.id
                if input_data.is_error != 0:
                    error_data = input_data

            if error_data is None:
                _profiler.record(tag + "-prep_0")
                data = self.preprocess(input_data)
                _profiler.record(tag + "-prep_1")

                error_info = None
                if self.with_serving():
                    for i in range(self._retry):
                        _profiler.record(tag + "-midp_0")
                        # The timeout lives in self._timeout (set in
                        # __init__); there is no self._time attribute.
                        if self._timeout > 0:
                            try:
                                middata = func_timeout.func_timeout(
                                    self._timeout,
                                    self.midprocess,
                                    args=(data, ))
                            except func_timeout.FunctionTimedOut:
                                logging.error("error: timeout")
                                error_info = "{}({}): timeout".format(
                                    self._name, concurrency_idx)
                            except Exception as e:
                                logging.error("error: {}".format(e))
                                error_info = "{}({}): {}".format(
                                    self._name, concurrency_idx, e)
                        else:
                            middata = self.midprocess(data)
                        _profiler.record(tag + "-midp_1")
                        if error_info is None:
                            data = middata
                            break
                        if i + 1 < self._retry:
                            # Another attempt is left: forget this error.
                            error_info = None
                            logging.warning(
                                self._log("warn: timeout, retry({})".format(
                                    i + 1)))

                _profiler.record(tag + "-postp_0")
                if error_info is not None:
                    output_data = self.errorprocess(error_info)
                else:
                    output_data = self.postprocess(data)

                    if not isinstance(output_data,
                                      python_service_channel_pb2.ChannelData):
                        raise TypeError(
                            self._log(
                                'output_data must be ChannelData type, but get {}'.
                                format(type(output_data))))
                    output_data.is_error = 0
                _profiler.record(tag + "-postp_1")

                output_data.id = data_id
            else:
                # Upstream already failed: pass the error on unchanged.
                output_data = error_data

            _profiler.record(tag + "-push_0")
            for channel in self._outputs:
                channel.push(output_data, self._name)
            _profiler.record(tag + "-push_1")

    def _log(self, info_str):
        return "[{}] {}".format(self._name, info_str)

    def get_concurrency(self):
        return self._concurrency

B
barrierye 已提交
452 453 454

class GeneralPythonService(
        general_python_service_pb2_grpc.GeneralPythonService):
    """gRPC service that feeds requests into the Op graph.

    Each request is packed into ChannelData with a fresh id and pushed
    into ``in_channel``.  A background thread drains ``out_channel`` into
    a dict keyed by data id, from which inference() picks up the result
    for its own request.

    Args:
        in_channel: Channel the service produces into.
        out_channel: Channel the service consumes from.
        retry: maximum number of attempts per request.
    """

    def __init__(self, in_channel, out_channel, retry=2):
        super(GeneralPythonService, self).__init__()
        self._name = "#G"
        self.set_in_channel(in_channel)
        self.set_out_channel(out_channel)
        logging.debug(self._log(in_channel.debug()))
        logging.debug(self._log(out_channel.debug()))
        #TODO:
        #  multi-lock for different clients
        #  different lock for server and client
        self._id_lock = threading.Lock()
        self._cv = threading.Condition()
        self._globel_resp_dict = {}
        self._id_counter = 0
        self._retry = retry
        self._recive_func = threading.Thread(
            target=GeneralPythonService._recive_out_channel_func, args=(self, ))
        # The receiver loops forever and has no stop signal; as a
        # non-daemon thread it would keep the interpreter alive after the
        # main thread has finished.
        self._recive_func.daemon = True
        self._recive_func.start()

    def _log(self, info_str):
        return "[{}] {}".format(self._name, info_str)

    def set_in_channel(self, in_channel):
        """Register the service as producer of ``in_channel`` and store it.

        Raises:
            TypeError: if ``in_channel`` is not a Channel.
        """
        if not isinstance(in_channel, Channel):
            raise TypeError(
                self._log('in_channel must be Channel type, but get {}'.format(
                    type(in_channel))))
        in_channel.add_producer(self._name)
        self._in_channel = in_channel

    def set_out_channel(self, out_channel):
        """Register the service as consumer of ``out_channel`` and store it.

        Raises:
            TypeError: if ``out_channel`` is not a Channel.
        """
        if not isinstance(out_channel, Channel):
            raise TypeError(
                self._log('out_channel must be Channel type, but get {}'.format(
                    type(out_channel))))
        out_channel.add_consumer(self._name)
        self._out_channel = out_channel

    def _recive_out_channel_func(self):
        """Move results from the output channel into the response dict."""
        while True:
            data = self._out_channel.front(self._name)
            if not isinstance(data, python_service_channel_pb2.ChannelData):
                raise TypeError(
                    self._log('data must be ChannelData type, but get {}'.
                              format(type(data))))
            with self._cv:
                self._globel_resp_dict[data.id] = data
                self._cv.notify_all()

    def _get_next_id(self):
        """Return the next request id (0, 1, 2, ...)."""
        with self._id_lock:
            next_id = self._id_counter
            self._id_counter += 1
            return next_id

    def _get_data_in_globel_resp_dict(self, data_id):
        """Wait until the result for ``data_id`` arrives, then remove and return it."""
        with self._cv:
            while data_id not in self._globel_resp_dict:
                self._cv.wait()
            resp = self._globel_resp_dict.pop(data_id)
            self._cv.notify_all()
        return resp

    def _pack_data_for_infer(self, request):
        """Convert a gRPC request into ChannelData.

        Returns:
            A ``(data, data_id)`` tuple.
        """
        logging.debug(self._log('start inferce'))
        data = python_service_channel_pb2.ChannelData()
        data_id = self._get_next_id()
        data.id = data_id
        for idx, name in enumerate(request.feed_var_names):
            feed_inst = request.feed_insts[idx]
            logging.debug(self._log('name: {}'.format(name)))
            logging.debug(self._log('data: {}'.format(feed_inst)))
            inst = python_service_channel_pb2.Inst()
            inst.data = feed_inst
            inst.name = name
            data.insts.append(inst)
        return data, data_id

    def _pack_data_for_resp(self, data):
        """Convert ChannelData into a gRPC Response."""
        logging.debug(self._log('get data'))
        resp = general_python_service_pb2.Response()
        logging.debug(self._log('gen resp'))
        logging.debug(data)
        resp.is_error = data.is_error
        if resp.is_error == 0:
            for inst in data.insts:
                logging.debug(self._log('append data'))
                resp.fetch_insts.append(inst.data)
                logging.debug(self._log('append name'))
                resp.fetch_var_names.append(inst.name)
        else:
            resp.error_info = data.error_info
        return resp

    def inference(self, request, context):
        """Handle one inference RPC.

        The request is pushed through the Op graph; when the result is
        flagged as an error the same data is pushed again, up to
        ``retry`` attempts in total.
        """
        _profiler.record("{}-prepack_0".format(self._name))
        data, data_id = self._pack_data_for_infer(request)
        _profiler.record("{}-prepack_1".format(self._name))

        resp_data = None
        # Always make at least one attempt: with retry <= 0 the loop body
        # would never run and there would be no response to pack.
        for i in range(max(1, self._retry)):
            logging.debug(self._log('push data'))
            _profiler.record("{}-push_0".format(self._name))
            self._in_channel.push(data, self._name)
            _profiler.record("{}-push_1".format(self._name))

            logging.debug(self._log('wait for infer'))
            _profiler.record("{}-fetch_0".format(self._name))
            resp_data = self._get_data_in_globel_resp_dict(data_id)
            _profiler.record("{}-fetch_1".format(self._name))

            if resp_data.is_error == 0:
                break
            logging.warning("retry({}): {}".format(i + 1,
                                                   resp_data.error_info))

        _profiler.record("{}-postpack_0".format(self._name))
        resp = self._pack_data_for_resp(resp_data)
        _profiler.record("{}-postpack_1".format(self._name))
        _profiler.print_profile()
        return resp

B
barrierye 已提交
576 577

class PyServer(object):
    """Wire Ops and Channels together and expose them through gRPC.

    Typical use: add_channel()/add_op() for every component, then
    prepare_server(), then run_server().

    Args:
        retry: passed on to GeneralPythonService as its retry count.
        profile: whether the module-level profiler is enabled.
    """

    def __init__(self, retry=2, profile=False):
        self._channels = []
        self._ops = []
        self._op_threads = []
        self._port = None
        self._worker_num = None
        self._in_channel = None
        self._out_channel = None
        self._retry = retry
        _profiler.enable(profile)

    def add_channel(self, channel):
        self._channels.append(channel)

    def add_op(self, op):
        self._ops.append(op)

    def gen_desc(self):
        logging.info('here will generate desc for paas')

    def prepare_server(self, port, worker_num):
        """Store the server settings and locate the graph's entry and exit.

        The entry channel is the one that is some Op's input but no Op's
        output; the exit channel is the reverse.

        Raises:
            Exception: unless there is exactly one entry channel and
                exactly one exit channel.
        """
        self._port = port
        self._worker_num = worker_num
        inputs = set()
        outputs = set()
        for op in self._ops:
            inputs.add(op.get_input())
            outputs.update(op.get_outputs())
            if op.with_serving():
                self.prepare_serving(op)
        in_channel = inputs - outputs
        out_channel = outputs - inputs
        if len(in_channel) != 1 or len(out_channel) != 1:
            raise Exception(
                "in_channel(out_channel) more than 1 or no in_channel(out_channel)"
            )
        self._in_channel = in_channel.pop()
        self._out_channel = out_channel.pop()
        self.gen_desc()

    def _op_start_wrapper(self, op, concurrency_idx):
        return op.start(concurrency_idx)

    def _run_ops(self):
        """Start ``get_concurrency()`` worker threads for every Op."""
        for op in self._ops:
            op_concurrency = op.get_concurrency()
            logging.debug("run op: {}, op_concurrency: {}".format(
                op._name, op_concurrency))
            for c in range(op_concurrency):
                th = threading.Thread(
                    target=self._op_start_wrapper, args=(op, c))
                # Worker loops only end via Op.stop(); as daemon threads
                # they do not keep the process alive once run_server()
                # has returned (e.g. after Ctrl-C).
                th.daemon = True
                th.start()
                self._op_threads.append(th)

    def run_server(self):
        """Start the Op workers and the gRPC server, then block.

        Returns after a KeyboardInterrupt, once the gRPC server has been
        stopped.
        """
        self._run_ops()
        server = grpc.server(
            futures.ThreadPoolExecutor(max_workers=self._worker_num))
        general_python_service_pb2_grpc.add_GeneralPythonServiceServicer_to_server(
            GeneralPythonService(self._in_channel, self._out_channel,
                                 self._retry), server)
        server.add_insecure_port('[::]:{}'.format(self._port))
        server.start()
        try:
            for th in self._op_threads:
                th.join()
            # grpc.Server has no join(); wait_for_termination() is the
            # call that blocks until the server is stopped.
            server.wait_for_termination()
        except KeyboardInterrupt:
            server.stop(0)

    def prepare_serving(self, op):
        """Build and log the command that would launch ``op``'s backend.

        The command is only logged; actually running it is disabled.
        """
        model_path = op._server_model
        port = op._server_port
        device = op._device

        if device == "cpu":
            cmd = "python -m paddle_serving_server.serve --model {} --thread 4 --port {} &>/dev/null &".format(
                model_path, port)
        else:
            cmd = "python -m paddle_serving_server_gpu.serve --model {} --thread 4 --port {} &>/dev/null &".format(
                model_path, port)
        # run a server (not in PyServing)
        logging.info("run a server (not in PyServing): {}".format(cmd))
        # NOTE: launching the command (os.system(cmd)) is intentionally
        # left disabled, as in the original code.
        return