pyserver.py 24.8 KB
Newer Older
B
barrierye 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
import threading
import multiprocessing
B
barrierye 已提交
17
import Queue
B
barrierye 已提交
18
import os
B
barrierye 已提交
19
import sys
B
barrierye 已提交
20 21 22
import paddle_serving_server
from paddle_serving_client import Client
from concurrent import futures
B
barrierye 已提交
23
import numpy as np
B
barrierye 已提交
24 25 26
import grpc
import general_python_service_pb2
import general_python_service_pb2_grpc
B
barrierye 已提交
27
import python_service_channel_pb2
B
barrierye 已提交
28
import logging
29
import random
B
barrierye 已提交
30
import time
B
barrierye 已提交
31
import func_timeout
B
barrierye 已提交
32 33


B
barrierye 已提交
34 35 36 37 38 39 40 41 42 43 44
class _TimeProfiler(object):
    """Buffer named timestamps and dump them to stderr as matched pairs.

    Each record is a ``(name, tag, microsecond timestamp)`` tuple kept in a
    ``Queue.Queue``. ``print_profile`` writes two records sharing a name on
    the same output line and keeps unmatched records for the next dump.
    """

    def __init__(self):
        self._pid = os.getpid()
        self._print_head = 'PROFILE\tpid:{}\t'.format(self._pid)
        self._time_record = Queue.Queue()
        # Profiling is off until enable(True) is called.
        self._enable = False

    def enable(self, enable):
        """Switch profiling on or off."""
        self._enable = enable

    def record(self, name_with_tag):
        """Store the current time under ``<name>_<tag>``.

        The text after the last underscore is the tag; everything before it
        is the name. Does nothing while profiling is disabled.
        """
        if self._enable is False:
            return
        name, _, tag = name_with_tag.rpartition("_")
        now_us = int(round(time.time() * 1000000))
        self._time_record.put((name, tag, now_us))

    def print_profile(self):
        """Write all paired records to stderr on a single line.

        Records whose name has no partner yet are put back into the buffer.
        Does nothing while profiling is disabled.
        """
        if self._enable is False:
            return
        sys.stderr.write(self._print_head)
        pending = {}
        while not self._time_record.empty():
            name, tag, timestamp = self._time_record.get()
            if name not in pending:
                pending[name] = (tag, timestamp)
                continue
            prev_tag, prev_timestamp = pending.pop(name)
            sys.stderr.write("{}_{}:{} ".format(name, prev_tag,
                                                prev_timestamp))
            sys.stderr.write("{}_{}:{} ".format(name, tag, timestamp))
        sys.stderr.write('\n')
        # Re-queue the records that are still waiting for their partner.
        for name, (tag, timestamp) in pending.items():
            self._time_record.put((name, tag, timestamp))


# Module-wide profiler instance used by Op and GeneralPythonService; it starts
# disabled and is switched on/off by PyServer.__init__ via enable(profile).
_profiler = _TimeProfiler()


B
barrierye 已提交
74
class Channel(Queue.Queue):
    """
    The channel used for communication between Ops.

    1. Supports multiple producers (different Ops feeding data):
        the pieces pushed by the different producers are grouped by data ID
        and enqueued as one ``{op_name: data}`` dict once every producer has
        pushed its piece.
    2. Supports multiple consumers (different Ops fetching data):
        an item is dropped only after every registered consumer has fetched
        it; a given consumer never receives the same item twice.
    3. (TODO) Timeout and BatchSize are not fully supported.

    Note:
    1. The ID of the data in the channel must be different.
    2. The function add_producer() and add_consumer() are not thread safe,
       and can only be called during initialization.
    """

    def __init__(self, name=None, maxsize=-1, timeout=None):
        Queue.Queue.__init__(self, maxsize=maxsize)
        self._maxsize = maxsize
        self._timeout = timeout
        self._name = name

        # Guards all producer/consumer bookkeeping below and is used to
        # sleep while the queue is full (push) or empty (front).
        self._cv = threading.Condition()

        self._producers = []
        self._producer_res_count = {}  # {data_id: count}
        self._push_res = {}  # {data_id: {op_name: data}}

        self._front_wait_interval = 0.1  # second
        self._consumers = {}  # {op_name: idx}
        self._idx_consumer_num = {}  # {idx: num}
        self._consumer_base_idx = 0
        self._front_res = []

    def get_producers(self):
        """Return the list of registered producer names."""
        return self._producers

    def get_consumers(self):
        """Return the names of the registered consumers."""
        return self._consumers.keys()

    def _log(self, info_str):
        """Prefix ``info_str`` with this channel's name."""
        return "[{}] {}".format(self._name, info_str)

    def debug(self):
        """Return a one-line description of the producers and consumers."""
        return self._log("p: {}, c: {}".format(self.get_producers(),
                                               self.get_consumers()))

    def add_producer(self, op_name):
        """ not thread safe, and can only be called during initialization. """
        if op_name in self._producers:
            raise ValueError(
                self._log("producer({}) is already in channel".format(op_name)))
        self._producers.append(op_name)

    def add_consumer(self, op_name):
        """ not thread safe, and can only be called during initialization. """
        if op_name in self._consumers:
            raise ValueError(
                self._log("consumer({}) is already in channel".format(op_name)))
        # Every consumer starts at item index 0.
        self._consumers[op_name] = 0
        self._idx_consumer_num[0] = self._idx_consumer_num.get(0, 0) + 1

    def _put_or_wait(self, item):
        """Enqueue ``item``, sleeping on the condition while the queue is full.

        The caller must hold ``self._cv``.
        """
        while True:
            try:
                self.put(item, timeout=0)
                return
            except Queue.Full:
                # put() signals a full queue with Queue.Full (not
                # Queue.Empty); wait until a consumer frees a slot.
                self._cv.wait()

    def push(self, data, op_name=None):
        """Feed ``data`` into the channel on behalf of producer ``op_name``.

        With a single producer the data is enqueued as-is. With several
        producers, ``data.id`` is used to group the pieces; the combined
        ``{op_name: data}`` dict is enqueued when the last piece arrives.

        Returns:
            True once the data has been accepted.

        Raises:
            Exception: if no producer is registered, or if there are several
                producers and ``op_name`` is None.
        """
        logging.debug(
            self._log("{} try to push data: {}".format(op_name, data)))
        producer_num = len(self._producers)
        if producer_num == 0:
            raise Exception(
                self._log(
                    "expected number of producers to be greater than 0, but the it is 0."
                ))
        if producer_num == 1:
            with self._cv:
                self._put_or_wait(data)
                self._cv.notify_all()
            logging.debug(self._log("{} push data succ!".format(op_name)))
            return True
        if op_name is None:
            raise Exception(
                self._log(
                    "There are multiple producers, so op_name cannot be None."))

        data_id = data.id
        put_data = None
        with self._cv:
            logging.debug(self._log("{} get lock ~".format(op_name)))
            if data_id not in self._push_res:
                self._push_res[data_id] = {
                    name: None
                    for name in self._producers
                }
                self._producer_res_count[data_id] = 0
            self._push_res[data_id][op_name] = data
            if self._producer_res_count[data_id] + 1 == producer_num:
                # Last missing piece: hand the completed group to the queue.
                put_data = self._push_res.pop(data_id)
                self._producer_res_count.pop(data_id)
            else:
                self._producer_res_count[data_id] += 1

            if put_data is None:
                logging.debug(
                    self._log("{} push data succ, not not push to queue.".
                              format(op_name)))
            else:
                self._put_or_wait(put_data)
                logging.debug(
                    self._log("multi | {} push data succ!".format(op_name)))
            self._cv.notify_all()
        return True

    def front(self, op_name=None):
        """Fetch the next item for consumer ``op_name``, blocking until ready.

        With a single consumer the item is simply dequeued. With several
        consumers, dequeued items are cached in ``_front_res`` until every
        consumer has read them, and the same object is returned to each
        consumer (treat it as read only).

        Raises:
            Exception: if no consumer is registered, or if there are several
                consumers and ``op_name`` is None.
        """
        logging.debug(self._log("{} try to get data".format(op_name)))
        consumer_num = len(self._consumers)
        if consumer_num == 0:
            raise Exception(
                self._log(
                    "expected number of consumers to be greater than 0, but the it is 0."
                ))
        if consumer_num == 1:
            resp = None
            with self._cv:
                while resp is None:
                    try:
                        resp = self.get(timeout=0)
                        break
                    except Queue.Empty:
                        self._cv.wait()
                # A slot was freed: wake producers blocked on a full queue.
                self._cv.notify_all()
            logging.debug(self._log("{} get data succ!".format(op_name)))
            return resp
        if op_name is None:
            raise Exception(
                self._log(
                    "There are multiple consumers, so op_name cannot be None."))

        with self._cv:
            # data_idx = consumer_idx - base_idx
            while self._consumers[op_name] - self._consumer_base_idx >= len(
                    self._front_res):
                try:
                    self._front_res.append(self.get(timeout=0))
                    break
                except Queue.Empty:
                    self._cv.wait()

            consumer_idx = self._consumers[op_name]
            base_idx = self._consumer_base_idx
            resp = self._front_res[consumer_idx - base_idx]
            logging.debug(self._log("{} get data: {}".format(op_name, resp)))

            # This consumer leaves index `consumer_idx`; once nobody is left
            # on the oldest cached item, that item can be discarded.
            self._idx_consumer_num[consumer_idx] -= 1
            if consumer_idx == base_idx and self._idx_consumer_num[
                    consumer_idx] == 0:
                self._idx_consumer_num.pop(consumer_idx)
                self._front_res.pop(0)
                self._consumer_base_idx += 1

            # ...and moves on to the next index.
            self._consumers[op_name] += 1
            new_consumer_idx = self._consumers[op_name]
            self._idx_consumer_num[new_consumer_idx] = \
                self._idx_consumer_num.get(new_consumer_idx, 0) + 1

            self._cv.notify_all()

        logging.debug(self._log("multi | {} get data succ!".format(op_name)))
        return resp  # reference, read only
B
barrierye 已提交
258 259 260 261


class Op(object):
    """One processing stage that reads from an input Channel and writes to
    one or more output Channels.

    For every input item the worker loop runs ``preprocess`` ->
    ``midprocess`` (only when a serving client is configured) ->
    ``postprocess`` and pushes the resulting ChannelData to every output
    channel. Subclasses are expected to override ``postprocess`` (and
    ``preprocess`` when the input channel has several producers).
    """

    def __init__(self,
                 name,
                 input,
                 in_dtype,
                 outputs,
                 out_dtype,
                 server_model=None,
                 server_port=None,
                 device=None,
                 client_config=None,
                 server_name=None,
                 fetch_names=None,
                 concurrency=1,
                 timeout=-1,
                 retry=2):
        """
        Args:
            name: identifier of this Op; registered as consumer/producer
                name on its channels, so it must be globally unique.
            input: the Channel this Op consumes from.
            in_dtype: numpy dtype used by ``preprocess`` to decode inputs.
            outputs: list of Channels this Op produces to.
            out_dtype: stored for subclasses; not used by this base class.
            server_model, server_port, device: stored and read by
                ``PyServer.prepare_serving``.
            client_config, server_name, fetch_names: when all three are
                given, a serving Client is created and ``midprocess`` is
                enabled.
            concurrency: number of worker threads PyServer starts.
            timeout: per-call limit passed to ``func_timeout`` for
                ``midprocess``; values <= 0 disable the limit.
            retry: maximum number of ``midprocess`` attempts per item.
        """
        self._run = False
        # TODO: globally unique check
        self._name = name  # to identify the type of OP, it must be globally unique
        self._concurrency = concurrency  # amount of concurrency
        self.set_input(input)
        self._in_dtype = in_dtype
        self.set_outputs(outputs)
        self._out_dtype = out_dtype
        self._client = None
        # Always defined, so the attribute exists even without a client.
        self._fetch_names = None
        if client_config is not None and \
                server_name is not None and \
                fetch_names is not None:
            self.set_client(client_config, server_name, fetch_names)
        self._server_model = server_model
        self._server_port = server_port
        self._device = device
        self._timeout = timeout
        self._retry = retry

    def set_client(self, client_config, server_name, fetch_names):
        """Create the serving Client, load its config and connect it."""
        self._client = Client()
        self._client.load_client_config(client_config)
        self._client.connect([server_name])
        self._fetch_names = fetch_names

    def with_serving(self):
        """Return True when a serving client has been configured."""
        return self._client is not None

    def get_input(self):
        """Return the input Channel."""
        return self._input

    def set_input(self, channel):
        """Register this Op as a consumer of ``channel`` and use it as input.

        Raises:
            TypeError: if ``channel`` is not a Channel.
        """
        if not isinstance(channel, Channel):
            raise TypeError(
                self._log('input channel must be Channel type, not {}'.format(
                    type(channel))))
        channel.add_consumer(self._name)
        self._input = channel

    def get_outputs(self):
        """Return the list of output Channels."""
        return self._outputs

    def set_outputs(self, channels):
        """Register this Op as a producer of every channel in ``channels``.

        Raises:
            TypeError: if ``channels`` is not a list.
        """
        if not isinstance(channels, list):
            raise TypeError(
                self._log('output channels must be list type, not {}'.format(
                    type(channels))))
        for channel in channels:
            channel.add_producer(self._name)
        self._outputs = channels

    def preprocess(self, data):
        """Turn one upstream ChannelData into a feed dict.

        Each inst's bytes are decoded with ``np.frombuffer`` using
        ``in_dtype`` and stored under the inst's name.

        Raises:
            Exception: if ``data`` is a dict (several upstream producers);
                such Ops must override this method.
        """
        if isinstance(data, dict):
            raise Exception(
                self._log(
                    'this Op has multiple previous inputs. Please override this method'
                ))
        feed = {}
        for inst in data.insts:
            feed[inst.name] = np.frombuffer(inst.data, dtype=self._in_dtype)
        return feed

    def midprocess(self, data):
        """Run the serving client's ``predict`` on the feed dict ``data``.

        Raises:
            Exception: if ``data`` is not a dict.
        """
        if not isinstance(data, dict):
            raise Exception(
                self._log(
                    'data must be dict type(the output of preprocess()), but get {}'.
                    format(type(data))))
        logging.debug(self._log('data: {}'.format(data)))
        logging.debug(self._log('fetch: {}'.format(self._fetch_names)))
        fetch_map = self._client.predict(feed=data, fetch=self._fetch_names)
        logging.debug(self._log("finish predict"))
        return fetch_map

    def postprocess(self, output_data):
        """Convert the processed data into ChannelData; must be overridden."""
        raise Exception(
            self._log(
                'Please override this method to convert data to the format in channel.'
            ))

    def errorprocess(self, error_info):
        """Build a ChannelData flagged as an error carrying ``error_info``."""
        data = python_service_channel_pb2.ChannelData()
        data.is_error = 1
        data.error_info = error_info
        return data

    def stop(self):
        """Ask the worker loop to exit after its current iteration."""
        self._run = False

    def _inspect_input(self, input_data):
        """Return ``(data_id, error_data)`` for an item read from the input.

        ``input_data`` is either a single ChannelData or a dict of them (one
        per upstream producer). ``error_data`` is the first entry already
        flagged as an error, or None.
        """
        if isinstance(input_data, dict):
            # All entries of a group share the same id; read it from any one.
            data_id = input_data[next(iter(input_data))].id
            for item in input_data.values():
                if item.is_error != 0:
                    return data_id, item
            return data_id, None
        if input_data.is_error != 0:
            return input_data.id, input_data
        return input_data.id, None

    def _midprocess_with_retry(self, data, concurrency_idx):
        """Call ``midprocess`` up to ``retry`` times.

        Returns ``(result, None)`` on success, or ``(data, error_info)`` when
        every attempt failed. Failures are reported the same way whether or
        not a timeout is configured, so an exception never kills the worker.
        """
        error_info = None
        for attempt in range(self._retry):
            _profiler.record("{}{}-midp_0".format(self._name,
                                                  concurrency_idx))
            error_info = None
            middata = None
            try:
                if self._timeout > 0:
                    middata = func_timeout.func_timeout(
                        self._timeout, self.midprocess, args=(data, ))
                else:
                    middata = self.midprocess(data)
            except func_timeout.FunctionTimedOut:
                logging.error("error: timeout")
                error_info = "{}({}): timeout".format(self._name,
                                                      concurrency_idx)
            except Exception as e:
                logging.error("error: {}".format(e))
                error_info = "{}({}): {}".format(self._name, concurrency_idx,
                                                 e)
            _profiler.record("{}{}-midp_1".format(self._name,
                                                  concurrency_idx))
            if error_info is None:
                return middata, None
            if attempt + 1 < self._retry:
                logging.warning(
                    self._log("warn: timeout, retry({})".format(attempt +
                                                                1)))
        return data, error_info

    def _process(self, input_data, concurrency_idx):
        """Run pre/mid/post-processing for one item and return ChannelData.

        Raises:
            TypeError: if ``postprocess`` does not return ChannelData.
        """
        _profiler.record("{}{}-prep_0".format(self._name, concurrency_idx))
        data = self.preprocess(input_data)
        _profiler.record("{}{}-prep_1".format(self._name, concurrency_idx))

        error_info = None
        if self.with_serving():
            data, error_info = self._midprocess_with_retry(data,
                                                           concurrency_idx)

        _profiler.record("{}{}-postp_0".format(self._name, concurrency_idx))
        if error_info is not None:
            output_data = self.errorprocess(error_info)
        else:
            output_data = self.postprocess(data)
            if not isinstance(output_data,
                              python_service_channel_pb2.ChannelData):
                raise TypeError(
                    self._log(
                        'output_data must be ChannelData type, but get {}'.
                        format(type(output_data))))
            output_data.is_error = 0
        _profiler.record("{}{}-postp_1".format(self._name, concurrency_idx))
        return output_data

    def start(self, concurrency_idx):
        """Worker loop: read, process and forward items until ``stop()``.

        Items already flagged as errors upstream are forwarded unchanged.

        Args:
            concurrency_idx: index of this worker, used in profiler record
                names and error messages.
        """
        self._run = True
        while self._run:
            _profiler.record("{}{}-get_0".format(self._name, concurrency_idx))
            input_data = self._input.front(self._name)
            _profiler.record("{}{}-get_1".format(self._name, concurrency_idx))
            logging.debug(self._log("input_data: {}".format(input_data)))

            data_id, error_data = self._inspect_input(input_data)
            if error_data is None:
                output_data = self._process(input_data, concurrency_idx)
                output_data.id = data_id
            else:
                output_data = error_data

            _profiler.record("{}{}-push_0".format(self._name, concurrency_idx))
            for channel in self._outputs:
                channel.push(output_data, self._name)
            _profiler.record("{}{}-push_1".format(self._name, concurrency_idx))

    def _log(self, info_str):
        """Prefix ``info_str`` with this Op's name."""
        return "[{}] {}".format(self._name, info_str)

    def get_concurrency(self):
        """Return the configured number of worker threads."""
        return self._concurrency

B
barrierye 已提交
458 459 460

class GeneralPythonService(
        general_python_service_pb2_grpc.GeneralPythonService):
    """gRPC entry point that feeds requests into the Op pipeline.

    Requests are packed into ChannelData and pushed to ``in_channel``; a
    background thread reads results from ``out_channel`` and stores them by
    data id so the waiting ``inference`` call can pick up its own response.
    """

    def __init__(self, in_channel, out_channel, retry=2):
        """
        Args:
            in_channel: Channel this service produces requests to.
            out_channel: Channel this service consumes results from.
            retry: how many times a request is pushed while the pipeline
                answers with an error (at least one attempt is always made).
        """
        super(GeneralPythonService, self).__init__()
        self._name = "#G"
        self.set_in_channel(in_channel)
        self.set_out_channel(out_channel)
        logging.debug(self._log(in_channel.debug()))
        logging.debug(self._log(out_channel.debug()))
        #TODO:
        #  multi-lock for different clients
        #  different lock for server and client
        self._id_lock = threading.Lock()
        self._cv = threading.Condition()
        self._globel_resp_dict = {}
        self._id_counter = 0
        self._retry = retry
        self._recive_func = threading.Thread(
            target=GeneralPythonService._recive_out_channel_func, args=(self, ))
        # The receiver loops forever; as a daemon it cannot keep the process
        # alive once the main thread is done.
        self._recive_func.daemon = True
        self._recive_func.start()

    def _log(self, info_str):
        """Prefix ``info_str`` with this service's name."""
        return "[{}] {}".format(self._name, info_str)

    def set_in_channel(self, in_channel):
        """Register as producer of ``in_channel``.

        Raises:
            TypeError: if ``in_channel`` is not a Channel.
        """
        if not isinstance(in_channel, Channel):
            raise TypeError(
                self._log('in_channel must be Channel type, but get {}'.format(
                    type(in_channel))))
        in_channel.add_producer(self._name)
        self._in_channel = in_channel

    def set_out_channel(self, out_channel):
        """Register as consumer of ``out_channel``.

        Raises:
            TypeError: if ``out_channel`` is not a Channel.
        """
        if not isinstance(out_channel, Channel):
            raise TypeError(
                self._log('out_channel must be Channel type, but get {}'.format(
                    type(out_channel))))
        out_channel.add_consumer(self._name)
        self._out_channel = out_channel

    def _recive_out_channel_func(self):
        """Receiver loop: move results from the out channel into the
        response dict and wake the waiting ``inference`` calls.

        Raises:
            TypeError: if the out channel yields something other than
                ChannelData.
        """
        while True:
            data = self._out_channel.front(self._name)
            if not isinstance(data, python_service_channel_pb2.ChannelData):
                raise TypeError(
                    self._log('data must be ChannelData type, but get {}'.
                              format(type(data))))
            with self._cv:
                self._globel_resp_dict[data.id] = data
                self._cv.notify_all()

    def _get_next_id(self):
        """Return a fresh request id (a counter protected by a lock)."""
        with self._id_lock:
            next_id = self._id_counter
            self._id_counter += 1
            return next_id

    def _get_data_in_globel_resp_dict(self, data_id):
        """Block until the response for ``data_id`` arrives, then remove and
        return it."""
        with self._cv:
            while data_id not in self._globel_resp_dict:
                self._cv.wait()
            resp = self._globel_resp_dict.pop(data_id)
            self._cv.notify_all()
        return resp

    def _pack_data_for_infer(self, request):
        """Convert a gRPC request into ``(ChannelData, data_id)``.

        The i-th entry of ``request.feed_var_names`` is paired with the i-th
        entry of ``request.feed_insts``.
        """
        logging.debug(self._log('start inferce'))
        data = python_service_channel_pb2.ChannelData()
        data_id = self._get_next_id()
        data.id = data_id
        data.is_error = 0
        for idx, name in enumerate(request.feed_var_names):
            logging.debug(
                self._log('name: {}'.format(request.feed_var_names[idx])))
            logging.debug(self._log('data: {}'.format(request.feed_insts[idx])))
            inst = python_service_channel_pb2.Inst()
            inst.data = request.feed_insts[idx]
            inst.name = name
            data.insts.append(inst)
        return data, data_id

    def _pack_data_for_resp(self, data):
        """Convert pipeline ChannelData into a gRPC Response.

        On success the insts are copied to ``fetch_insts`` /
        ``fetch_var_names``; on error only ``error_info`` is filled in.
        """
        logging.debug(self._log('get data'))
        resp = general_python_service_pb2.Response()
        logging.debug(self._log('gen resp'))
        logging.debug(data)
        resp.is_error = data.is_error
        if resp.is_error == 0:
            for inst in data.insts:
                logging.debug(self._log('append data'))
                resp.fetch_insts.append(inst.data)
                logging.debug(self._log('append name'))
                resp.fetch_var_names.append(inst.name)
        else:
            resp.error_info = data.error_info
        return resp

    def inference(self, request, context):
        """Handle one RPC: push the request through the pipeline and return
        the Response, re-pushing while the pipeline reports an error."""
        _profiler.record("{}-prepack_0".format(self._name))
        data, data_id = self._pack_data_for_infer(request)
        _profiler.record("{}-prepack_1".format(self._name))

        resp_data = None
        # Always try at least once so resp_data is defined even if the
        # configured retry count is below 1.
        for i in range(max(1, self._retry)):
            logging.debug(self._log('push data'))
            _profiler.record("{}-push_0".format(self._name))
            self._in_channel.push(data, self._name)
            _profiler.record("{}-push_1".format(self._name))

            logging.debug(self._log('wait for infer'))
            _profiler.record("{}-fetch_0".format(self._name))
            resp_data = self._get_data_in_globel_resp_dict(data_id)
            _profiler.record("{}-fetch_1".format(self._name))

            if resp_data.is_error == 0:
                break
            logging.warning("retry({}): {}".format(i + 1,
                                                   resp_data.error_info))

        _profiler.record("{}-postpack_0".format(self._name))
        resp = self._pack_data_for_resp(resp_data)
        _profiler.record("{}-postpack_1".format(self._name))
        _profiler.print_profile()
        return resp

B
barrierye 已提交
583 584

class PyServer(object):
    """Wire Ops and Channels together and expose them through gRPC.

    Typical use: ``add_op`` for every Op, ``prepare_server(port,
    worker_num)``, then ``run_server()``.
    """

    def __init__(self, retry=2, profile=False):
        """
        Args:
            retry: retry count handed to GeneralPythonService.
            profile: passed to the module-level profiler's ``enable``.
        """
        self._channels = []
        self._ops = []
        self._op_threads = []
        self._port = None
        self._worker_num = None
        self._in_channel = None
        self._out_channel = None
        self._retry = retry
        _profiler.enable(profile)

    def add_channel(self, channel):
        """Remember ``channel`` in the server's channel list."""
        self._channels.append(channel)

    def add_op(self, op):
        """Register ``op`` so it is wired and started by this server."""
        self._ops.append(op)

    def gen_desc(self):
        """Placeholder: only logs that a description would be generated."""
        logging.info('here will generate desc for paas')

    def prepare_server(self, port, worker_num):
        """Store the server settings and work out the pipeline's endpoints.

        The entry channel is the one that is an Op input but no Op's output;
        the exit channel is the one that is an Op output but no Op's input.

        Raises:
            Exception: unless there is exactly one entry channel and exactly
                one exit channel.
        """
        self._port = port
        self._worker_num = worker_num
        inputs = set()
        outputs = set()
        for op in self._ops:
            inputs.add(op.get_input())
            outputs.update(op.get_outputs())
            if op.with_serving():
                self.prepare_serving(op)
        in_channel = inputs - outputs
        out_channel = outputs - inputs
        if len(in_channel) != 1 or len(out_channel) != 1:
            raise Exception(
                "in_channel(out_channel) more than 1 or no in_channel(out_channel)"
            )
        self._in_channel = in_channel.pop()
        self._out_channel = out_channel.pop()
        self.gen_desc()

    def _op_start_wrapper(self, op, concurrency_idx):
        """Thread target: run the worker loop of ``op``."""
        return op.start(concurrency_idx)

    def _run_ops(self):
        """Start ``get_concurrency()`` worker threads for every Op."""
        for op in self._ops:
            op_concurrency = op.get_concurrency()
            logging.debug("run op: {}, op_concurrency: {}".format(
                op._name, op_concurrency))
            for c in range(op_concurrency):
                th = threading.Thread(
                    target=self._op_start_wrapper, args=(op, c))
                # Workers block while waiting for input; as daemons they do
                # not prevent the process from exiting after shutdown.
                th.daemon = True
                th.start()
                self._op_threads.append(th)

    def run_server(self):
        """Start the Op workers and the gRPC server, then block until the
        workers finish or the user interrupts; finally stop everything."""
        self._run_ops()
        server = grpc.server(
            futures.ThreadPoolExecutor(max_workers=self._worker_num))
        general_python_service_pb2_grpc.add_GeneralPythonServiceServicer_to_server(
            GeneralPythonService(self._in_channel, self._out_channel,
                                 self._retry), server)
        server.add_insecure_port('[::]:{}'.format(self._port))
        server.start()
        try:
            for th in self._op_threads:
                # Join with a timeout: a bare join() on the main thread keeps
                # Python 2 from handling KeyboardInterrupt.
                while th.is_alive():
                    th.join(1.0)
        except KeyboardInterrupt:
            pass
        finally:
            # grpc.Server offers stop()/wait_for_termination() but no join();
            # shut down explicitly instead.
            for op in self._ops:
                op.stop()
            server.stop(0)

    def prepare_serving(self, op):
        """Build and log the shell command that would start the model server
        backing ``op``.

        The command is only logged: launching it is disabled in this
        version (the original ``os.system(cmd)`` call was commented out).
        """
        model_path = op._server_model
        port = op._server_port
        device = op._device

        if device == "cpu":
            module = "paddle_serving_server"
        else:
            module = "paddle_serving_server_gpu"
        cmd = "python -m {}.serve --model {} --thread 4 --port {} &>/dev/null &".format(
            module, model_path, port)
        # run a server (not in PyServing)
        logging.info("run a server (not in PyServing): {}".format(cmd))
        return