pyserver_multithread.py 43.1 KB
Newer Older
B
barrierye 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
import threading
import multiprocessing
B
barrierye 已提交
17
import sys
B
barrierye 已提交
18 19 20 21 22 23 24
if sys.version_info.major == 2:
    import Queue
elif sys.version_info.major == 3:
    import queue as Queue
else:
    raise Exception("Error Python version")
import os
B
barrierye 已提交
25
import paddle_serving_server_gpu
W
wangjiawei04 已提交
26 27
#from paddle_serving_client import MultiLangClient as Client
from paddle_serving_client import Client
B
barrierye 已提交
28
from concurrent import futures
B
barrierye 已提交
29
import numpy as np
B
barrierye 已提交
30
import grpc
31 32 33 34
from .proto import general_model_config_pb2 as m_config
from .proto import general_python_service_pb2 as pyservice_pb2
from .proto import pyserving_channel_pb2 as channel_pb2
from .proto import general_python_service_pb2_grpc
B
barrierye 已提交
35
import logging
36
import random
B
barrierye 已提交
37
import time
B
barrierye 已提交
38
import func_timeout
39
import enum
40
import collections
B
barrierye 已提交
41 42


B
barrierye 已提交
43 44 45 46 47 48 49 50 51 52 53
class _TimeProfiler(object):
    """Collect named timestamps and dump matched pairs to stderr.

    Records are ``(name, tag, timestamp_in_microseconds)`` tuples kept in a
    queue. Profiling is off by default; while it is off, ``record`` and
    ``print_profile`` do nothing.
    """

    def __init__(self):
        pid = os.getpid()
        self._pid = pid
        self._print_head = 'PROFILE\tpid:{}\t'.format(pid)
        self._time_record = Queue.Queue()
        self._enable = False

    def enable(self, enable):
        """Switch profiling on or off."""
        self._enable = enable

    def record(self, name_with_tag):
        """Store the current time under ``name_with_tag``.

        The text after the last ``_`` is the tag, everything before it is
        the record name (e.g. ``"op-get_0"`` -> name ``"op-get"``, tag ``"0"``).
        """
        if self._enable is False:
            return
        name, _, tag = name_with_tag.rpartition("_")
        now_us = int(round(time.time() * 1000000))
        self._time_record.put((name, tag, now_us))

    def print_profile(self):
        """Write every completed pair of records to stderr on one line.

        Two records with the same name form a pair. Records that have no
        partner yet are put back into the queue for the next call.
        """
        if self._enable is False:
            return
        write = sys.stderr.write
        records = self._time_record
        write(self._print_head)
        pending = {}  # {name: (tag, timestamp)} records still waiting for a partner
        while not records.empty():
            name, tag, timestamp = records.get()
            if name not in pending:
                pending[name] = (tag, timestamp)
                continue
            prev_tag, prev_timestamp = pending.pop(name)
            write("{}_{}:{} ".format(name, prev_tag, prev_timestamp))
            write("{}_{}:{} ".format(name, tag, timestamp))
        write('\n')
        for name, (tag, timestamp) in pending.items():
            records.put((name, tag, timestamp))


# Module-level profiler used by Op.start() / VirtualOp.start(); it records
# nothing until enable(True) is called on it.
_profiler = _TimeProfiler()


83 84 85
@enum.unique
class ChannelDataEcode(enum.Enum):
    """Error codes carried by a ChannelData object (OK means no error)."""

    OK = 0
    TIMEOUT = 1
    NOT_IMPLEMENTED = 2
    TYPE_ERROR = 3
    RPC_PACKAGE_ERROR = 4
    # NOTE: spelling kept as-is; other code refers to this member by name.
    UNKNOW = 5
90 91 92 93 94


@enum.unique
class ChannelDataType(enum.Enum):
    """Kinds of payload a ChannelData object can hold."""

    CHANNEL_PBDATA = 0  # protobuf message
    CHANNEL_FUTURE = 1  # future whose result is fetched lazily
    CHANNEL_NPDATA = 2  # dict of numpy arrays
    ERROR = 3  # no payload, only an error code and message
97 98 99 100


class ChannelData(object):
    """Unit of data exchanged between Ops through a Channel.

    Depending on ``datatype`` the payload is a protobuf message, a future,
    or a dict of numpy arrays. An instance may instead describe an error
    (``ecode`` / ``error_info``) produced while processing ``data_id``.
    """

    def __init__(self,
                 datatype=None,
                 future=None,
                 pbdata=None,
                 npdata=None,
                 data_id=None,
                 callback_func=None,
                 ecode=None,
                 error_info=None):
        """
        Supported ways to build an instance (keyword arguments recommended):

        1. ChannelData(datatype=ChannelDataType.CHANNEL_FUTURE.value,
                       future=..., data_id=...[, callback_func=...])
        2. ChannelData(datatype=ChannelDataType.CHANNEL_PBDATA.value, pbdata=...)
        3. ChannelData(datatype=ChannelDataType.CHANNEL_PBDATA.value,
                       npdata=..., data_id=...)
           (the numpy dict is serialized into a new protobuf message)
        4. ChannelData(datatype=ChannelDataType.CHANNEL_NPDATA.value,
                       npdata=..., data_id=...)
        5. ChannelData(ecode=..., error_info=..., data_id=...)
           (datatype is forced to ChannelDataType.ERROR.value)

        Raises:
            ValueError: a required argument of the chosen mode is missing,
                or ``datatype`` is not one of the supported values.
        """
        if ecode is not None:
            if data_id is None or error_info is None:
                raise ValueError("data_id and error_info cannot be None")
            datatype = ChannelDataType.ERROR.value
        elif datatype == ChannelDataType.CHANNEL_FUTURE.value:
            if data_id is None:
                raise ValueError("data_id cannot be None")
            ecode = ChannelDataEcode.OK.value
        elif datatype == ChannelDataType.CHANNEL_PBDATA.value:
            if pbdata is None:
                if data_id is None:
                    raise ValueError("data_id cannot be None")
                pbdata = channel_pb2.ChannelData()
                ecode, error_info = self._check_npdata(npdata)
                if ecode != ChannelDataEcode.OK.value:
                    logging.error(error_info)
                else:
                    for name, value in npdata.items():
                        inst = channel_pb2.Inst()
                        inst.data = value.tobytes()
                        inst.name = name
                        inst.shape = np.array(
                            value.shape, dtype="int32").tobytes()
                        inst.type = str(value.dtype)
                        pbdata.insts.append(inst)
            else:
                # A ready-made protobuf payload carries no error. Previously
                # ecode stayed None here, which consumers comparing against
                # ChannelDataEcode.OK.value treated as a failure.
                ecode = ChannelDataEcode.OK.value
        elif datatype == ChannelDataType.CHANNEL_NPDATA.value:
            ecode, error_info = self._check_npdata(npdata)
            if ecode != ChannelDataEcode.OK.value:
                logging.error(error_info)
        else:
            raise ValueError("datatype not match")

        self.future = future
        self.pbdata = pbdata
        self.npdata = npdata
        self.datatype = datatype
        self.id = data_id
        self.ecode = ecode
        self.error_info = error_info
        self.callback_func = callback_func

    def _check_npdata(self, npdata):
        """Validate that ``npdata`` is a dict whose values are numpy arrays.

        Returns:
            (ecode, error_info): ``(OK, None)`` when valid, otherwise
            ``(TYPE_ERROR, message)``.
        """
        if not isinstance(npdata, dict):
            # Report the problem through the error code instead of crashing
            # on ``npdata.items()``.
            return (ChannelDataEcode.TYPE_ERROR.value,
                    "npdata must be dict type, but get {}".format(
                        type(npdata)))
        for value in npdata.values():
            if not isinstance(value, np.ndarray):
                # The original called an undefined ``log`` helper here, which
                # raised NameError; the message is now built directly.
                return (ChannelDataEcode.TYPE_ERROR.value,
                        "the value of postped_data must "
                        "be np.ndarray, but get {}".format(type(value)))
        return ChannelDataEcode.OK.value, None

    def parse(self):
        """Return the payload in the form expected by Op.preprocess().

        * CHANNEL_PBDATA: dict ``{name: ndarray}`` rebuilt from the message.
        * CHANNEL_FUTURE: the future's result, passed through
          ``callback_func`` when one was given.
        * CHANNEL_NPDATA: the stored numpy dict itself.

        Raises:
            TypeError: ``datatype`` is none of the above (e.g. ERROR).
        """
        feed = None
        if self.datatype == ChannelDataType.CHANNEL_PBDATA.value:
            feed = {}
            for inst in self.pbdata.insts:
                feed[inst.name] = np.frombuffer(inst.data, dtype=inst.type)
                feed[inst.name].shape = np.frombuffer(inst.shape, dtype="int32")
        elif self.datatype == ChannelDataType.CHANNEL_FUTURE.value:
            feed = self.future.result()
            if self.callback_func is not None:
                feed = self.callback_func(feed)
        elif self.datatype == ChannelDataType.CHANNEL_NPDATA.value:
            feed = self.npdata
        else:
            raise TypeError("Error type({}) in datatype.".format(self.datatype))
        return feed

    def __str__(self):
        return "type[{}], ecode[{}], id[{}]".format(
            ChannelDataType(self.datatype).name, self.ecode, self.id)
B
barrierye 已提交
192

193

B
barrierye 已提交
194
class Channel(Queue.Queue):
    """
    The channel used for communication between Ops.

    1. Support multiple different Op feed data (multiple producer)
        Different types of data will be packaged through the data ID
    2. Support multiple different Op fetch data (multiple consumer)
        Only when all types of Ops get the data of the same ID,
        the data will be poped; The Op of the same type will not
        get the data of the same ID.
    3. (TODO) Timeout and BatchSize are not fully supported.

    Note:
    1. The ID of the data in the channel must be different.
    2. The function add_producer() and add_consumer() are not thread safe,
       and can only be called during initialization.
    3. After stop() has been called, front() returns None instead of
       blocking, and push() no longer waits for free space.
    """

    def __init__(self, name=None, maxsize=-1, timeout=None):
        Queue.Queue.__init__(self, maxsize=maxsize)
        self._maxsize = maxsize
        self._timeout = timeout
        self.name = name
        self._stop = False

        # Guards all bookkeeping below and is used to wait for queue
        # space (push) or queue items (front).
        self._cv = threading.Condition()

        self._producers = []
        self._producer_res_count = {}  # {data_id: count}
        self._push_res = {}  # {data_id: {op_name: data}}

        self._consumers = {}  # {op_name: idx}
        self._idx_consumer_num = {}  # {idx: num}
        self._consumer_base_idx = 0
        self._front_res = []

    def get_producers(self):
        return self._producers

    def get_consumers(self):
        return self._consumers.keys()

    def _log(self, info_str):
        return "[{}] {}".format(self.name, info_str)

    def debug(self):
        return self._log("p: {}, c: {}".format(self.get_producers(),
                                               self.get_consumers()))

    def add_producer(self, op_name):
        """ not thread safe, and can only be called during initialization. """
        if op_name in self._producers:
            raise ValueError(
                self._log("producer({}) is already in channel".format(op_name)))
        self._producers.append(op_name)

    def add_consumer(self, op_name):
        """ not thread safe, and can only be called during initialization. """
        if op_name in self._consumers:
            raise ValueError(
                self._log("consumer({}) is already in channel".format(op_name)))
        self._consumers[op_name] = 0
        # every new consumer starts at read index 0
        self._idx_consumer_num[0] = self._idx_consumer_num.get(0, 0) + 1

    def _put_until_stopped(self, item):
        """Enqueue ``item``, waiting on the condition while the queue is full.

        Must be called with ``self._cv`` held. Returns True when the item
        was enqueued and False when the channel was stopped first.
        """
        while self._stop is False:
            try:
                self.put(item, timeout=0)
                return True
            except Queue.Full:
                self._cv.wait()
        return False

    def push(self, channeldata, op_name=None):
        """Feed ``channeldata`` into the channel on behalf of ``op_name``.

        With a single producer the data is queued directly. With several
        producers the data is held back until every producer has pushed
        data with the same ``id``; then a dict ``{op_name: data}`` is queued.

        Returns:
            True (always).

        Raises:
            Exception: no producer is registered, or ``op_name`` is None
                while several producers are registered.
        """
        logging.debug(
            self._log("{} try to push data: {}".format(op_name,
                                                       channeldata.__str__())))
        if len(self._producers) == 0:
            raise Exception(
                self._log(
                    "expected number of producers to be greater than 0, but the it is 0."
                ))
        elif len(self._producers) == 1:
            with self._cv:
                self._put_until_stopped(channeldata)
                self._cv.notify_all()
            logging.debug(self._log("{} push data succ!".format(op_name)))
            return True
        elif op_name is None:
            raise Exception(
                self._log(
                    "There are multiple producers, so op_name cannot be None."))

        producer_num = len(self._producers)
        data_id = channeldata.id
        put_data = None
        with self._cv:
            logging.debug(self._log("{} get lock".format(op_name)))
            if data_id not in self._push_res:
                self._push_res[data_id] = {
                    name: None
                    for name in self._producers
                }
                self._producer_res_count[data_id] = 0
            self._push_res[data_id][op_name] = channeldata
            if self._producer_res_count[data_id] + 1 == producer_num:
                # last missing part arrived: hand the complete dict over
                put_data = self._push_res.pop(data_id)
                self._producer_res_count.pop(data_id)
            else:
                self._producer_res_count[data_id] += 1

            if put_data is None:
                logging.debug(
                    self._log("{} push data succ, but not push to queue.".
                              format(op_name)))
            else:
                # The original caught Queue.Empty here, but put() raises
                # Queue.Full, so a full queue crashed the producer instead
                # of making it wait.
                self._put_until_stopped(put_data)
                logging.debug(
                    self._log("multi | {} push data succ!".format(op_name)))
            self._cv.notify_all()
        return True

    def front(self, op_name=None):
        """Fetch the next piece of data for consumer ``op_name``.

        With a single consumer the item is simply removed from the queue.
        With several consumers every consumer sees every item once; an item
        is dropped after all consumers have read it. The returned object is
        shared between consumers and must be treated as read only.

        Returns:
            The data, or None when the channel was stopped before data
            became available.

        Raises:
            Exception: no consumer is registered, or ``op_name`` is None
                while several consumers are registered.
        """
        logging.debug(self._log("{} try to get data".format(op_name)))
        if len(self._consumers) == 0:
            raise Exception(
                self._log(
                    "expected number of consumers to be greater than 0, but the it is 0."
                ))
        elif len(self._consumers) == 1:
            resp = None
            with self._cv:
                while self._stop is False and resp is None:
                    try:
                        resp = self.get(timeout=0)
                        break
                    except Queue.Empty:
                        self._cv.wait()
                # wake producers that are waiting for free space
                self._cv.notify_all()
            logging.debug(
                self._log("{} get data succ: {}".format(op_name, resp.__str__(
                ))))
            return resp
        elif op_name is None:
            raise Exception(
                self._log(
                    "There are multiple consumers, so op_name cannot be None."))

        with self._cv:
            # data_idx = consumer_idx - base_idx
            while self._stop is False and self._consumers[
                    op_name] - self._consumer_base_idx >= len(self._front_res):
                try:
                    channeldata = self.get(timeout=0)
                    self._front_res.append(channeldata)
                    break
                except Queue.Empty:
                    self._cv.wait()

            consumer_idx = self._consumers[op_name]
            base_idx = self._consumer_base_idx
            data_idx = consumer_idx - base_idx
            if data_idx >= len(self._front_res):
                # Only reachable once the channel has been stopped; behave
                # like the single-consumer path instead of raising IndexError.
                return None
            resp = self._front_res[data_idx]
            logging.debug(self._log("{} get data: {}".format(op_name, resp)))

            self._idx_consumer_num[consumer_idx] -= 1
            if consumer_idx == base_idx and self._idx_consumer_num[
                    consumer_idx] == 0:
                # every consumer has read the oldest item: drop it
                self._idx_consumer_num.pop(consumer_idx)
                self._front_res.pop(0)
                self._consumer_base_idx += 1

            self._consumers[op_name] += 1
            new_consumer_idx = self._consumers[op_name]
            if self._idx_consumer_num.get(new_consumer_idx) is None:
                self._idx_consumer_num[new_consumer_idx] = 0
            self._idx_consumer_num[new_consumer_idx] += 1

            self._cv.notify_all()

        logging.debug(self._log("multi | {} get data succ!".format(op_name)))
        return resp  # reference, read only

    def stop(self):
        """Mark the channel as stopped and wake every waiting thread.

        Queue.Queue has no close() method, so the former ``self.close()``
        call always raised AttributeError; it has been removed. The
        condition must be held while notifying, otherwise notify_all()
        raises RuntimeError.
        """
        with self._cv:
            self._stop = True
            self._cv.notify_all()
387

B
barrierye 已提交
388 389 390

class Op(object):
    """A processing node that reads from one input Channel and writes to
    any number of output Channels.

    Each item goes through preprocess() -> midprocess() -> postprocess().
    midprocess() (a prediction request sent through the serving client) is
    only executed when ``client_config``, ``server_name`` and
    ``fetch_names`` are all given; otherwise the preprocessed data is
    handed to postprocess() directly.

    Args:
        name: identifies the type of Op; must be globally unique.
        inputs: predecessor Op, list of Ops, or None.
        server_model, server_port, device: stored on the instance; not used
            by the code of this class.
        client_config, server_name, fetch_names: serving client settings.
        concurrency: amount of concurrency, returned by get_concurrency().
        timeout: if > 0, midprocess() is run through
            func_timeout.func_timeout with this limit; otherwise it is
            called directly.
        retry: maximum number of midprocess() attempts on timeout
            (values below 1 are raised to 1).
    """

    def __init__(self,
                 name,
                 inputs,
                 server_model=None,
                 server_port=None,
                 device=None,
                 client_config=None,
                 server_name=None,
                 fetch_names=None,
                 concurrency=1,
                 timeout=-1,
                 retry=2):
        self._is_run = False
        self.name = name  # to identify the type of OP, it must be globally unique
        self._concurrency = concurrency  # amount of concurrency
        self.set_input_ops(inputs)
        self.set_client(client_config, server_name, fetch_names)
        self._server_model = server_model
        self._server_port = server_port
        self._device = device
        self._timeout = timeout
        self._retry = max(1, retry)
        self._input = None
        self._outputs = []

    def set_client(self, client_config, server_name, fetch_names):
        """Create and connect the serving client, if fully configured.

        When any argument is None, no client is created and
        ``with_serving`` is set to False.
        """
        self.with_serving = True
        if client_config is None or \
                server_name is None or \
                fetch_names is None:
            self.with_serving = False
            return
        self._client = Client()
        self._client.load_client_config(client_config)
        self._client.connect([server_name])
        self._fetch_names = fetch_names

    def get_input_channel(self):
        return self._input

    def get_input_ops(self):
        return self._input_ops

    def set_input_ops(self, ops):
        """Store the predecessor Ops; accepts None, a single Op or a list.

        Raises:
            TypeError: an element is not an Op.
        """
        if not isinstance(ops, list):
            ops = [] if ops is None else [ops]
        self._input_ops = []
        for op in ops:
            if not isinstance(op, Op):
                raise TypeError(
                    self._log('input op must be Op type, not {}'.format(
                        type(op))))
            self._input_ops.append(op)

    def add_input_channel(self, channel):
        """Register this Op as a consumer of ``channel`` and use it as input."""
        if not isinstance(channel, Channel):
            raise TypeError(
                self._log('input channel must be Channel type, not {}'.format(
                    type(channel))))
        channel.add_consumer(self.name)
        self._input = channel

    def get_output_channels(self):
        return self._outputs

    def add_output_channel(self, channel):
        """Register this Op as a producer of ``channel`` and add it as output."""
        if not isinstance(channel, Channel):
            raise TypeError(
                self._log('output channel must be Channel type, not {}'.format(
                    type(channel))))
        channel.add_producer(self.name)
        self._outputs.append(channel)

    def preprocess(self, channeldata):
        """Turn the input into the feed for midprocess().

        The default implementation only supports a single predecessor and
        returns ``channeldata.parse()``.

        Raises:
            NotImplementedError: ``channeldata`` is a dict, i.e. the Op has
                several predecessors and must override this method.
        """
        if isinstance(channeldata, dict):
            raise NotImplementedError(
                'this Op has multiple previous inputs. Please override this method'
            )
        feed = channeldata.parse()
        return feed

    def midprocess(self, data, asyn=True):
        """Send ``data`` to the serving client and return its prediction.

        ``asyn`` is only forwarded when the client class is named
        ``MultiLangClient``.
        """
        if not isinstance(data, dict):
            raise Exception(
                self._log(
                    'data must be dict type(the output of preprocess()), but get {}'.
                    format(type(data))))
        logging.debug(self._log('data: {}'.format(data)))
        logging.debug(self._log('fetch: {}'.format(self._fetch_names)))
        if Client.__name__ == "MultiLangClient":
            call_result = self._client.predict(
                feed=data, fetch=self._fetch_names, asyn=asyn)
        elif Client.__name__ == "Client":
            call_result = self._client.predict(
                feed=data, fetch=self._fetch_names)
        else:
            raise Exception("unknow client type: {}".format(Client.__name__))
        logging.debug(self._log("get call_result"))
        return call_result

    def postprocess(self, output_data):
        """Hook for subclasses; the default returns ``output_data`` unchanged."""
        return output_data

    def stop(self):
        """Stop the attached channels and end the start() loop."""
        # _input stays None until add_input_channel() has been called.
        if self._input is not None:
            self._input.stop()
        for channel in self._outputs:
            channel.stop()
        self._is_run = False

    def _parse_channeldata(self, channeldata):
        """Return ``(data_id, error_channeldata)`` for the fetched input.

        ``channeldata`` is a single ChannelData, or a dict of them when the
        input channel has several producers. ``error_channeldata`` is the
        first item whose ecode is not OK, or None.
        """
        if isinstance(channeldata, dict):
            items = list(channeldata.values())
        else:
            items = [channeldata]
        data_id = items[0].id
        error_channeldata = None
        for data in items:
            if data.ecode != ChannelDataEcode.OK.value:
                error_channeldata = data
                break
        return data_id, error_channeldata

    def _push_to_output_channels(self, data, name=None):
        """Push ``data`` to every output channel as producer ``name``
        (defaults to this Op's name)."""
        if name is None:
            name = self.name
        for channel in self._outputs:
            channel.push(data, name)

    def _push_error(self, ecode, error_info, data_id):
        """Log ``error_info`` and forward it downstream as error ChannelData."""
        logging.error(error_info)
        self._push_to_output_channels(
            ChannelData(
                ecode=ecode, error_info=error_info, data_id=data_id))

    def _run_midprocess(self, preped_data, asyn, log):
        """Run midprocess(), honouring the timeout/retry settings.

        Returns:
            (midped_data, ecode, error_info); ``ecode`` is OK on success.
        """
        ecode = ChannelDataEcode.OK.value
        error_info = None
        midped_data = None
        if self._timeout <= 0:
            try:
                midped_data = self.midprocess(preped_data, asyn)
            except Exception as e:
                ecode = ChannelDataEcode.UNKNOW.value
                error_info = log(e)
            return midped_data, ecode, error_info

        for i in range(self._retry):
            try:
                midped_data = func_timeout.func_timeout(
                    self._timeout,
                    self.midprocess,
                    args=(preped_data, asyn))
            except func_timeout.FunctionTimedOut as e:
                if i + 1 >= self._retry:
                    ecode = ChannelDataEcode.TIMEOUT.value
                    error_info = log(e)
                else:
                    logging.warning(log("timeout, retry({})".format(i + 1)))
            except Exception as e:
                # only timeouts are retried
                ecode = ChannelDataEcode.UNKNOW.value
                error_info = log(e)
                break
            else:
                break
        return midped_data, ecode, error_info

    def start(self, concurrency_idx):
        """Worker loop: fetch, process and forward data until stopped.

        Any failure in a stage is converted into an error ChannelData that
        is pushed downstream, and the loop continues with the next item.
        The loop ends when stop() is called or when the input channel
        returns None (which it does once it has been stopped).
        """
        op_info_prefix = "[{}|{}]".format(self.name, concurrency_idx)
        log = self._get_log_func(op_info_prefix)
        self._is_run = True
        while self._is_run:
            _profiler.record("{}-get_0".format(op_info_prefix))
            channeldata = self._input.front(self.name)
            _profiler.record("{}-get_1".format(op_info_prefix))
            logging.debug(log("input_data: {}".format(channeldata)))

            if channeldata is None:
                # Input channel was stopped; there is nothing to process and
                # dereferencing None below would raise AttributeError.
                break

            data_id, error_channeldata = self._parse_channeldata(channeldata)

            # error data in predecessor Op
            if error_channeldata is not None:
                self._push_to_output_channels(error_channeldata)
                continue

            # preprocess
            try:
                _profiler.record("{}-prep_0".format(op_info_prefix))
                preped_data = self.preprocess(channeldata)
                _profiler.record("{}-prep_1".format(op_info_prefix))
            except Exception as e:
                if isinstance(e, NotImplementedError):
                    # preprocess function not implemented
                    ecode = ChannelDataEcode.NOT_IMPLEMENTED.value
                elif isinstance(e, TypeError):
                    # Error type in channeldata.datatype
                    ecode = ChannelDataEcode.TYPE_ERROR.value
                else:
                    ecode = ChannelDataEcode.UNKNOW.value
                self._push_error(ecode, log(e), data_id)
                continue

            # midprocess
            midped_data = None
            asyn = False
            if self.with_serving:
                _profiler.record("{}-midp_0".format(op_info_prefix))
                midped_data, ecode, error_info = self._run_midprocess(
                    preped_data, asyn, log)
                if ecode != ChannelDataEcode.OK.value:
                    self._push_error(ecode, error_info, data_id)
                    continue
                _profiler.record("{}-midp_1".format(op_info_prefix))
            else:
                midped_data = preped_data

            # postprocess
            output_data = None
            _profiler.record("{}-postp_0".format(op_info_prefix))
            if self.with_serving and asyn:
                # use call_future (asyn is currently always False, so this
                # branch is not taken)
                output_data = ChannelData(
                    datatype=ChannelDataType.CHANNEL_FUTURE.value,
                    future=midped_data,
                    data_id=data_id,
                    callback_func=self.postprocess)
            else:
                try:
                    postped_data = self.postprocess(midped_data)
                except Exception as e:
                    self._push_error(ChannelDataEcode.UNKNOW.value,
                                     log(e), data_id)
                    continue
                if not isinstance(postped_data, dict):
                    error_info = log("output of postprocess funticon must be " \
                            "dict type, but get {}".format(type(postped_data)))
                    self._push_error(ChannelDataEcode.TYPE_ERROR.value,
                                     error_info, data_id)
                    continue

                output_data = ChannelData(
                    ChannelDataType.CHANNEL_NPDATA.value,
                    npdata=postped_data,
                    data_id=data_id)
            _profiler.record("{}-postp_1".format(op_info_prefix))

            # push data to channel (if run succ)
            _profiler.record("{}-push_0".format(op_info_prefix))
            self._push_to_output_channels(output_data)
            _profiler.record("{}-push_1".format(op_info_prefix))

    def _log(self, info):
        return "{} {}".format(self.name, info)

    def _get_log_func(self, op_info_prefix):
        """Return a function that prefixes messages with ``op_info_prefix``."""

        def log_func(info_str):
            return "{} {}".format(op_info_prefix, info_str)

        return log_func

    def get_concurrency(self):
        return self._concurrency

B
barrierye 已提交
674

B
bug fix  
barrierye 已提交
675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694
class VirtualOp(Op):
    """For connecting two channels.

    A VirtualOp does no computation: ``start`` only forwards whatever it
    reads from its input channel to its output channels.  The ops registered
    through ``add_virtual_pred_op`` are the ones announced as producers of
    the output channels.
    """

    def __init__(self, name, concurrency=1):
        super(VirtualOp, self).__init__(
            name=name, inputs=None, concurrency=concurrency)
        # Ops whose names are registered as producers on the output channels.
        self._virtual_pred_ops = []

    def add_virtual_pred_op(self, op):
        """Record ``op`` as a predecessor represented by this virtual op."""
        self._virtual_pred_ops.append(op)

    def add_output_channel(self, channel):
        """Attach ``channel`` as an output channel.

        Every recorded virtual predecessor op is added to the channel as a
        producer (by name) before the channel is stored.

        Raises:
            TypeError: if ``channel`` is not a ``Channel``.
        """
        if not isinstance(channel, Channel):
            raise TypeError(
                self._log('output channel must be Channel type, not {}'.format(
                    type(channel))))
        for op in self._virtual_pred_ops:
            channel.add_producer(op.name)
        self._outputs.append(channel)

    def start(self, concurrency_idx):
        """Forward data from the input channel until ``_is_run`` is cleared.

        A dict read from the input channel is pushed entry by entry, using
        each key as the name; any other value is pushed under the name of
        the first virtual predecessor op.
        """
        op_info_prefix = "[{}|{}]".format(self.name, concurrency_idx)
        self._is_run = True
        while self._is_run:
            _profiler.record("{}-get_0".format(op_info_prefix))
            channeldata = self._input.front(self.name)
            _profiler.record("{}-get_1".format(op_info_prefix))

            _profiler.record("{}-push_0".format(op_info_prefix))
            if isinstance(channeldata, dict):
                for name, data in channeldata.items():
                    self._push_to_output_channels(data, name=name)
            else:
                self._push_to_output_channels(channeldata,
                                              self._virtual_pred_ops[0].name)
            _profiler.record("{}-push_1".format(op_info_prefix))


B
barrierye 已提交
714
class GeneralPythonService(
        general_python_service_pb2_grpc.GeneralPythonServiceServicer):
    """gRPC servicer that bridges RPC requests and the op channels.

    A request is packed into a ChannelData and pushed into ``in_channel``.
    A background thread reads results from ``out_channel`` and stores them
    by data id; ``inference`` waits for the entry matching its own id.
    """

    def __init__(self, in_channel, out_channel, retry=2):
        """Register on both channels and start the receiver thread.

        Args:
            in_channel: Channel this service pushes requests into.
            out_channel: Channel this service reads results from.
            retry: number of attempts ``inference`` makes per request.

        Raises:
            TypeError: if either channel is not a ``Channel``.
        """
        super(GeneralPythonService, self).__init__()
        self.name = "#G"
        self.set_in_channel(in_channel)
        self.set_out_channel(out_channel)
        logging.debug(self._log(in_channel.debug()))
        logging.debug(self._log(out_channel.debug()))
        #TODO:
        #  multi-lock for different clients
        #  different lock for server and client
        self._id_lock = threading.Lock()
        self._cv = threading.Condition()
        # data id -> ChannelData received from the output channel
        self._globel_resp_dict = {}
        self._id_counter = 0
        self._retry = retry
        self._recive_func = threading.Thread(
            target=GeneralPythonService._recive_out_channel_func, args=(self, ))
        self._recive_func.start()

    def _log(self, info_str):
        """Return ``info_str`` prefixed with "[<service name>] "."""
        return "[{}] {}".format(self.name, info_str)

    def set_in_channel(self, in_channel):
        """Store ``in_channel`` and register this service as its producer.

        Raises:
            TypeError: if ``in_channel`` is not a ``Channel``.
        """
        if not isinstance(in_channel, Channel):
            raise TypeError(
                self._log('in_channel must be Channel type, but get {}'.format(
                    type(in_channel))))
        in_channel.add_producer(self.name)
        self._in_channel = in_channel

    def set_out_channel(self, out_channel):
        """Store ``out_channel`` and register this service as its consumer.

        Raises:
            TypeError: if ``out_channel`` is not a ``Channel``.
        """
        if not isinstance(out_channel, Channel):
            raise TypeError(
                self._log('out_channel must be Channel type, but get {}'.format(
                    type(out_channel))))
        out_channel.add_consumer(self.name)
        self._out_channel = out_channel

    def _recive_out_channel_func(self):
        """Receiver loop: move results from the output channel into the dict.

        Each result is stored under its ``id`` and all waiters on the
        condition variable are woken up.

        Raises:
            TypeError: if the channel yields something other than ChannelData.
        """
        while True:
            channeldata = self._out_channel.front(self.name)
            if not isinstance(channeldata, ChannelData):
                raise TypeError(
                    self._log('data must be ChannelData type, but get {}'.
                              format(type(channeldata))))
            with self._cv:
                data_id = channeldata.id
                self._globel_resp_dict[data_id] = channeldata
                self._cv.notify_all()

    def _get_next_id(self):
        """Return the current id counter value and increment it under a lock."""
        with self._id_lock:
            self._id_counter += 1
            return self._id_counter - 1

    def _get_data_in_globel_resp_dict(self, data_id):
        """Block until a result for ``data_id`` is stored, then pop and return it."""
        resp = None
        with self._cv:
            while data_id not in self._globel_resp_dict:
                self._cv.wait()
            resp = self._globel_resp_dict.pop(data_id)
            self._cv.notify_all()
        return resp

    def _pack_data_for_infer(self, request):
        """Convert an RPC request into a ChannelData.

        Each feed variable is decoded with ``np.frombuffer`` using the dtype
        from ``request.type`` and reshaped with the int32 shape buffer from
        ``request.shape``.

        Returns:
            (channeldata, data_id).  If decoding fails, the ChannelData
            carries the RPC_PACKAGE_ERROR code instead of data.
        """
        logging.debug(self._log('start inferce'))
        data_id = self._get_next_id()
        npdata = {}
        try:
            for idx, name in enumerate(request.feed_var_names):
                logging.debug(self._log('name: {}'.format(name)))
                logging.debug(
                    self._log('data: {}'.format(request.feed_insts[idx])))
                npdata[name] = np.frombuffer(
                    request.feed_insts[idx], dtype=request.type[idx])
                npdata[name].shape = np.frombuffer(
                    request.shape[idx], dtype="int32")
        except Exception as e:
            # Keep the cause visible on the server side; the client only
            # receives the generic error info below.
            logging.error(self._log('rpc package error: {}'.format(e)))
            return ChannelData(
                ecode=ChannelDataEcode.RPC_PACKAGE_ERROR.value,
                error_info="rpc package error",
                data_id=data_id), data_id
        return ChannelData(
            datatype=ChannelDataType.CHANNEL_PBDATA.value,
            npdata=npdata,
            data_id=data_id), data_id

    def _pack_data_for_resp(self, channeldata):
        """Convert a result ChannelData into an RPC ``Response``.

        On an OK code the fetch variables are copied into the response; for
        any other code only ``error_info`` is filled in.

        Raises:
            TypeError: if an OK result has an unsupported datatype.
        """
        logging.debug(self._log('get channeldata'))
        resp = pyservice_pb2.Response()
        resp.ecode = channeldata.ecode
        if resp.ecode == ChannelDataEcode.OK.value:
            if channeldata.datatype == ChannelDataType.CHANNEL_PBDATA.value:
                for inst in channeldata.pbdata.insts:
                    resp.fetch_insts.append(inst.data)
                    resp.fetch_var_names.append(inst.name)
                    resp.shape.append(inst.shape)
                    resp.type.append(inst.type)
            elif channeldata.datatype in (ChannelDataType.CHANNEL_FUTURE.value,
                                          ChannelDataType.CHANNEL_NPDATA.value):
                feed = channeldata.parse()
                for name, var in feed.items():
                    resp.fetch_insts.append(var.tobytes())
                    resp.fetch_var_names.append(name)
                    resp.shape.append(
                        np.array(
                            var.shape, dtype="int32").tobytes())
                    resp.type.append(str(var.dtype))
            else:
                raise TypeError(
                    self._log("Error type({}) in datatype.".format(
                        channeldata.datatype)))
        else:
            resp.error_info = channeldata.error_info
        return resp

    def inference(self, request, context):
        """Handle one RPC: push the request, wait for the result, reply.

        The same packed data is pushed again while the result code is not OK,
        for up to ``retry`` attempts.  At least one attempt is always made so
        that a non-positive ``retry`` cannot leave the response undefined.
        """
        _profiler.record("{}-prepack_0".format(self.name))
        data, data_id = self._pack_data_for_infer(request)
        _profiler.record("{}-prepack_1".format(self.name))

        resp_channeldata = None
        attempts = max(1, self._retry)
        for i in range(attempts):
            logging.debug(self._log('push data'))
            _profiler.record("{}-push_0".format(self.name))
            self._in_channel.push(data, self.name)
            _profiler.record("{}-push_1".format(self.name))

            logging.debug(self._log('wait for infer'))
            _profiler.record("{}-fetch_0".format(self.name))
            resp_channeldata = self._get_data_in_globel_resp_dict(data_id)
            _profiler.record("{}-fetch_1".format(self.name))

            if resp_channeldata.ecode == ChannelDataEcode.OK.value:
                break
            if i + 1 < attempts:
                logging.warning("retry({}): {}".format(
                    i + 1, resp_channeldata.error_info))

        _profiler.record("{}-postpack_0".format(self.name))
        resp = self._pack_data_for_resp(resp_channeldata)
        _profiler.record("{}-postpack_1".format(self.name))
        _profiler.print_profile()
        return resp

B
barrierye 已提交
862 863

class PyServer(object):
B
barrierye 已提交
864
    def __init__(self, retry=2, profile=False):
        """Create an empty server.

        Args:
            retry: value stored as ``self._retry`` and later handed to the
                gRPC service.
            profile: flag passed to ``_profiler.enable``.
        """
        # Four independent, initially empty lists.
        (self._channels, self._user_ops, self._actual_ops,
         self._op_threads) = [], [], [], []
        # Filled in later by prepare_server().
        self._port = self._worker_num = None
        self._in_channel = self._out_channel = None
        self._retry = retry
        _profiler.enable(profile)
B
barrierye 已提交
875 876 877 878 879

    def add_channel(self, channel):
        """Append ``channel`` to the server's channel list."""
        registered = self._channels
        registered.append(channel)

    def add_op(self, op):
        """Append a single ``op`` to the list of user-defined ops."""
        user_ops = self._user_ops
        user_ops.append(op)

    def add_ops(self, ops):
        """Append every op in the iterable ``ops`` to the user-defined ops."""
        user_ops = self._user_ops
        user_ops += ops  # in-place extend of the same list object
B
barrierye 已提交
884 885

    def gen_desc(self):
        """Placeholder for PAAS description generation; currently only logs."""
        logging.info('here will generate desc for PAAS')

889 890 891
    def _topo_sort(self):
        indeg_num = {}
        que_idx = 0  # scroll queue 
B
fix bug  
barrierye 已提交
892
        ques = [Queue.Queue() for _ in range(2)]
B
bug fix  
barrierye 已提交
893 894 895 896 897
        for op in self._user_ops:
            if len(op.get_input_ops()) == 0:
                op.name = "#G"  # update read_op.name
                break
        outdegs = {op.name: [] for op in self._user_ops}
B
barrierye 已提交
898
        zero_indeg_num, zero_outdeg_num = 0, 0
899 900 901 902 903 904 905
        for idx, op in enumerate(self._user_ops):
            # check the name of op is globally unique
            if op.name in indeg_num:
                raise Exception("the name of Op must be unique")
            indeg_num[op.name] = len(op.get_input_ops())
            if indeg_num[op.name] == 0:
                ques[que_idx].put(op)
B
barrierye 已提交
906
                zero_indeg_num += 1
907
            for pred_op in op.get_input_ops():
B
fix bug  
barrierye 已提交
908
                outdegs[pred_op.name].append(op)
B
barrierye 已提交
909 910 911 912 913 914 915
        if zero_indeg_num != 1:
            raise Exception("DAG contains multiple input Ops")
        for _, succ_list in outdegs.items():
            if len(succ_list) == 0:
                zero_outdeg_num += 1
        if zero_outdeg_num != 1:
            raise Exception("DAG contains multiple output Ops")
916

B
bug fix  
barrierye 已提交
917
        # topo sort to get dag_views
918 919 920 921 922 923 924 925 926 927
        dag_views = []
        sorted_op_num = 0
        while True:
            que = ques[que_idx]
            next_que = ques[(que_idx + 1) % 2]
            dag_view = []
            while que.qsize() != 0:
                op = que.get()
                dag_view.append(op)
                sorted_op_num += 1
B
bug fix  
barrierye 已提交
928
                for succ_op in outdegs[op.name]:
B
fix bug  
barrierye 已提交
929
                    indeg_num[succ_op.name] -= 1
930 931 932 933 934 935 936 937 938 939
                    if indeg_num[succ_op.name] == 0:
                        next_que.put(succ_op)
            dag_views.append(dag_view)
            if next_que.qsize() == 0:
                break
            que_idx = (que_idx + 1) % 2
        if sorted_op_num < len(self._user_ops):
            raise Exception("not legal DAG")

        # create channels and virtual ops
B
bug fix  
barrierye 已提交
940 941 942 943 944 945 946 947 948 949 950
        def name_generator(prefix):
            def number_generator():
                idx = 0
                while True:
                    yield "{}{}".format(prefix, idx)
                    idx += 1

            return number_generator()

        virtual_op_name_gen = name_generator("vir")
        channel_name_gen = name_generator("chl")
951 952 953
        virtual_ops = []
        channels = []
        input_channel = None
B
bug fix  
barrierye 已提交
954
        actual_view = None
955 956 957 958
        for v_idx, view in enumerate(dag_views):
            if v_idx + 1 >= len(dag_views):
                break
            next_view = dag_views[v_idx + 1]
B
bug fix  
barrierye 已提交
959 960
            if actual_view is None:
                actual_view = view
961 962
            actual_next_view = []
            pred_op_of_next_view_op = {}
B
bug fix  
barrierye 已提交
963 964
            for op in actual_view:
                # find actual succ op in next view and create virtual op
965 966
                for succ_op in outdegs[op.name]:
                    if succ_op in next_view:
B
bug fix  
barrierye 已提交
967 968
                        if succ_op not in actual_next_view:
                            actual_next_view.append(succ_op)
969 970 971 972
                        if succ_op.name not in pred_op_of_next_view_op:
                            pred_op_of_next_view_op[succ_op.name] = []
                        pred_op_of_next_view_op[succ_op.name].append(op)
                    else:
B
bug fix  
barrierye 已提交
973 974
                        # create virtual op
                        virtual_op = None
B
barrierye 已提交
975 976 977 978 979 980 981 982
                        if sys.version_info.major == 2:
                            virtual_op = VirtualOp(
                                name=virtual_op_name_gen.next())
                        elif sys.version_info.major == 3:
                            virtual_op = VirtualOp(
                                name=virtual_op_name_gen.__next__())
                        else:
                            raise Exception("Error Python version")
983
                        virtual_ops.append(virtual_op)
B
bug fix  
barrierye 已提交
984 985 986 987 988
                        outdegs[virtual_op.name] = [succ_op]
                        actual_next_view.append(virtual_op)
                        pred_op_of_next_view_op[virtual_op.name] = [op]
                        virtual_op.add_virtual_pred_op(op)
            actual_view = actual_next_view
989 990 991
            # create channel
            processed_op = set()
            for o_idx, op in enumerate(actual_next_view):
B
bug fix  
barrierye 已提交
992
                if op.name in processed_op:
993
                    continue
B
barrierye 已提交
994 995 996 997 998 999 1000
                if sys.version_info.major == 2:
                    channel = Channel(name=channel_name_gen.next())
                elif sys.version_info.major == 3:
                    channel = Channel(
                        self._manager, name=channel_name_gen.__next__())
                else:
                    raise Exception("Error Python version")
1001
                channels.append(channel)
B
bug fix  
barrierye 已提交
1002
                logging.debug("{} => {}".format(channel.name, op.name))
1003
                op.add_input_channel(channel)
B
bug fix  
barrierye 已提交
1004
                pred_ops = pred_op_of_next_view_op[op.name]
1005 1006 1007
                if v_idx == 0:
                    input_channel = channel
                else:
B
bug fix  
barrierye 已提交
1008
                    # if pred_op is virtual op, it will use ancestors as producers to channel
1009
                    for pred_op in pred_ops:
B
bug fix  
barrierye 已提交
1010 1011
                        logging.debug("{} => {}".format(pred_op.name,
                                                        channel.name))
1012
                        pred_op.add_output_channel(channel)
B
bug fix  
barrierye 已提交
1013 1014 1015
                processed_op.add(op.name)
                # find same input op to combine channel
                for other_op in actual_next_view[o_idx + 1:]:
1016 1017 1018 1019 1020 1021 1022 1023 1024 1025 1026
                    if other_op.name in processed_op:
                        continue
                    other_pred_ops = pred_op_of_next_view_op[other_op.name]
                    if len(other_pred_ops) != len(pred_ops):
                        continue
                    same_flag = True
                    for pred_op in pred_ops:
                        if pred_op not in other_pred_ops:
                            same_flag = False
                            break
                    if same_flag:
B
bug fix  
barrierye 已提交
1027 1028
                        logging.debug("{} => {}".format(channel.name,
                                                        other_op.name))
1029 1030
                        other_op.add_input_channel(channel)
                        processed_op.add(other_op.name)
B
barrierye 已提交
1031 1032 1033 1034 1035 1036 1037
        if sys.version_info.major == 2:
            output_channel = Channel(name=channel_name_gen.next())
        elif sys.version_info.major == 3:
            output_channel = Channel(
                self._manager, name=channel_name_gen.__next__())
        else:
            raise Exception("Error Python version")
1038
        channels.append(output_channel)
B
bug fix  
barrierye 已提交
1039
        last_op = dag_views[-1][0]
1040 1041
        last_op.add_output_channel(output_channel)

B
bug fix  
barrierye 已提交
1042
        self._actual_ops = virtual_ops
B
fix bug  
barrierye 已提交
1043 1044
        for op in self._user_ops:
            if len(op.get_input_ops()) == 0:
B
bug fix  
barrierye 已提交
1045
                # pass read op
B
fix bug  
barrierye 已提交
1046
                continue
B
bug fix  
barrierye 已提交
1047
            self._actual_ops.append(op)
1048
        self._channels = channels
B
bug fix  
barrierye 已提交
1049 1050
        for c in channels:
            logging.debug(c.debug())
1051 1052
        return input_channel, output_channel

B
barrierye 已提交
1053 1054 1055
    def prepare_server(self, port, worker_num):
        """Record the server settings and wire up the op graph.

        Stores ``port`` and ``worker_num``, builds the channels through
        ``_topo_sort``, calls ``prepare_serving`` for each actual op whose
        ``with_serving`` is truthy, and finally calls ``gen_desc``.
        """
        self._port, self._worker_num = port, worker_num

        self._in_channel, self._out_channel = self._topo_sort()
        for op in self._actual_ops:
            if not op.with_serving:
                continue
            self.prepare_serving(op)
        self.gen_desc()
B
barrierye 已提交
1065 1066 1067
    def _op_start_wrapper(self, op, concurrency_idx):
        return op.start(concurrency_idx)

1068
    def _run_ops(self):
B
bug fix  
barrierye 已提交
1069
        for op in self._actual_ops:
B
barrierye 已提交
1070 1071 1072 1073 1074 1075 1076 1077
            op_concurrency = op.get_concurrency()
            logging.debug("run op: {}, op_concurrency: {}".format(
                op.name, op_concurrency))
            for c in range(op_concurrency):
                th = threading.Thread(
                    target=self._op_start_wrapper, args=(op, c))
                th.start()
                self._op_threads.append(th)
1078

1079
    def _stop_ops(self):
B
bug fix  
barrierye 已提交
1080
        for op in self._actual_ops:
1081 1082
            op.stop()

1083
    def run_server(self):
        """Start the op threads and serve gRPC requests until termination.

        After the gRPC server terminates, the ops are asked to stop and
        their threads are joined.
        """
        self._run_ops()
        executor = futures.ThreadPoolExecutor(max_workers=self._worker_num)
        server = grpc.server(executor)
        servicer = GeneralPythonService(self._in_channel, self._out_channel,
                                        self._retry)
        general_python_service_pb2_grpc.add_GeneralPythonServiceServicer_to_server(
            servicer, server)
        server.add_insecure_port('[::]:{}'.format(self._port))
        server.start()
        server.wait_for_termination()
        self._stop_ops()  # TODO
        for op_thread in self._op_threads:
            op_thread.join()
B
barrierye 已提交
1096 1097 1098 1099 1100 1101

    def prepare_serving(self, op):
        model_path = op._server_model
        port = op._server_port
        device = op._device

B
barrierye 已提交
1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115
        if Client.__name__ == "MultiLangClient":
            if device == "cpu":
                cmd = "(Use grpc impl) python -m paddle_serving_server.serve" \
                      " --model {} --thread 4 --port {} --use_multilang &>/dev/null &".format(model_path, port)
            else:
                cmd = "(Use grpc impl) python -m paddle_serving_server_gpu.serve" \
                      " --model {} --thread 4 --port {} --use_multilang &>/dev/null &".format(model_path, port)
        elif Client.__name__ == "Client":
            if device == "cpu":
                cmd = "(Use brpc impl) python -m paddle_serving_server.serve" \
                      " --model {} --thread 4 --port {} &>/dev/null &".format(model_path, port)
            else:
                cmd = "(Use brpc impl) python -m paddle_serving_server_gpu.serve" \
                      " --model {} --thread 4 --port {} &>/dev/null &".format(model_path, port)
B
barrierye 已提交
1116
        else:
B
barrierye 已提交
1117
            raise Exception("unknow client type: {}".format(Client.__name__))
1118 1119
        # run a server (not in PyServing)
        logging.info("run a server (not in PyServing): {}".format(cmd))