# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
import threading
import multiprocessing
import Queue
import os
import sys
import paddle_serving_server
#from paddle_serving_client import MultiLangClient as Client
from paddle_serving_client import Client
from concurrent import futures
import numpy as np
import grpc
from .proto import general_model_config_pb2 as m_config
from .proto import general_python_service_pb2 as pyservice_pb2
from .proto import pyserving_channel_pb2 as channel_pb2
from .proto import general_python_service_pb2_grpc
import logging
import random
import time
import func_timeout
import enum
import collections


class _TimeProfiler(object):
    """Collect per-stage timestamps and dump them to stderr.

    Recording is a no-op until ``enable()`` is called with a truthy value.
    Each record is stored as a ``(name, tag, timestamp_in_microseconds)``
    tuple in a ``Queue.Queue``.
    """

    def __init__(self):
        self._pid = os.getpid()
        self._print_head = 'PROFILE\tpid:{}\t'.format(self._pid)
        self._time_record = Queue.Queue()
        self._enable = False

    def enable(self, enable):
        """Turn recording on (truthy value) or off (falsy value)."""
        self._enable = enable

    def record(self, name_with_tag):
        """Record the current time for ``name_with_tag``.

        The text after the last underscore is the tag and the rest is the
        name, e.g. ``"[op|0]-get_0"`` -> name ``"[op|0]-get"``, tag ``"0"``.
        A string without an underscore yields an empty name.
        """
        # Truthiness check: enable(0) / enable(None) must also disable.
        if not self._enable:
            return
        name, _, tag = name_with_tag.rpartition("_")
        self._time_record.put((name, tag, int(round(time.time() * 1000000))))

    def print_profile(self):
        """Write one line of paired timestamps to stderr.

        Two records with the same name are printed as a pair
        (``name_tag:timestamp`` for each) and consumed. Records that have
        no partner yet are put back into the queue so they can be paired
        on a later call.
        """
        if not self._enable:
            return
        parts = [self._print_head]
        pending = {}  # {name: (tag, timestamp)} waiting for a partner
        while not self._time_record.empty():
            name, tag, timestamp = self._time_record.get()
            if name in pending:
                ptag, ptimestamp = pending.pop(name)
                parts.append("{}_{}:{} ".format(name, ptag, ptimestamp))
                parts.append("{}_{}:{} ".format(name, tag, timestamp))
            else:
                pending[name] = (tag, timestamp)
        parts.append('\n')
        # A single write (instead of many small ones) makes it less likely
        # that the line interleaves with other stderr output.
        sys.stderr.write("".join(parts))
        for name, (tag, timestamp) in pending.items():
            self._time_record.put((name, tag, timestamp))


# Module-wide profiler instance used by the Ops below.
_profiler = _TimeProfiler()


class ChannelDataEcode(enum.Enum):
    """Status codes stored in the ``ecode`` field of channel data."""

    OK = 0
    TIMEOUT = 1
    NOT_IMPLEMENTED = 2
    TYPE_ERROR = 3
    RPC_PACKAGE_ERROR = 4
    # (sic) spelling kept: the member name is part of the public interface.
    UNKNOW = 5


class ChannelDataType(enum.Enum):
    """Where a ChannelData keeps its payload (see ChannelData.parse)."""

    CHANNEL_PBDATA = 0  # serialized tensors inside the protobuf message
    CHANNEL_FUTURE = 1  # result of a pending future
    CHANNEL_NPDATA = 2  # dict of numpy arrays
    ERROR = 3  # error record, no payload


class ChannelData(object):
    """Unit of data passed between Ops through a Channel.

    Every instance holds a ``channel_pb2.ChannelData`` message in
    ``self.pbdata``. Its ``id`` identifies the request and its ``ecode`` (a
    ChannelDataEcode value) marks success or failure. Depending on
    ``self.datatype`` (a ChannelDataType value), the payload lives in
    ``pbdata.insts``, in ``self.npdata`` or behind ``self.future``.
    """

    def __init__(self,
                 datatype=None,
                 future=None,
                 pbdata=None,
                 npdata=None,
                 data_id=None,
                 callback_func=None,
                 ecode=None,
                 error_info=None):
        '''
        Build a ChannelData. Supported keyword combinations:

        1. ChannelData(datatype=ChannelDataType.CHANNEL_FUTURE.value,
                       future=future, pbdata=pbdata[, callback_func=func])
        2. ChannelData(datatype=ChannelDataType.CHANNEL_FUTURE.value,
                       future=future, data_id=data_id[, callback_func=func])
        3. ChannelData(datatype=ChannelDataType.CHANNEL_PBDATA.value,
                       pbdata=pbdata)
        4. ChannelData(datatype=ChannelDataType.CHANNEL_PBDATA.value,
                       npdata=npdata, data_id=data_id)
        5. ChannelData(datatype=ChannelDataType.CHANNEL_NPDATA.value,
                       npdata=npdata, data_id=data_id)
        6. ChannelData(ecode=ecode, error_info=error_info, data_id=data_id)

        In form 6, ``ecode`` takes precedence: ``datatype`` becomes ERROR
        and ``pbdata`` is rebuilt from ecode/error_info/data_id.

        In forms 4 and 5, ``npdata`` (expected: dict of str -> numpy array)
        is validated. An invalid value does not raise. Instead,
        ``pbdata.ecode`` is set to TYPE_ERROR and the reason is logged.

        Raises:
            ValueError: a required data_id/error_info is missing, or
                datatype is not a known ChannelDataType value.
            TypeError: pbdata is not a channel_pb2.ChannelData.
        '''
        if ecode is not None:
            if data_id is None or error_info is None:
                raise ValueError("data_id and error_info cannot be None")
            pbdata = channel_pb2.ChannelData()
            pbdata.ecode = ecode
            pbdata.id = data_id
            pbdata.error_info = error_info
            datatype = ChannelDataType.ERROR.value
        elif datatype == ChannelDataType.CHANNEL_FUTURE.value:
            if pbdata is None:
                if data_id is None:
                    raise ValueError("data_id cannot be None")
                pbdata = channel_pb2.ChannelData()
                pbdata.ecode = ChannelDataEcode.OK.value
                pbdata.id = data_id
        elif datatype == ChannelDataType.CHANNEL_PBDATA.value:
            if pbdata is None:
                if data_id is None:
                    raise ValueError("data_id cannot be None")
                pbdata = self._npdata_to_pbdata(npdata, data_id)
        elif datatype == ChannelDataType.CHANNEL_NPDATA.value:
            # The arrays stay in self.npdata; pbdata only carries id/ecode.
            ecode, error_info = self._check_npdata(npdata)
            pbdata = channel_pb2.ChannelData()
            pbdata.id = data_id
            pbdata.ecode = ecode
            if pbdata.ecode != ChannelDataEcode.OK.value:
                pbdata.error_info = error_info
                logging.error(pbdata.error_info)
        else:
            raise ValueError("datatype not match")
        if not isinstance(pbdata, channel_pb2.ChannelData):
            raise TypeError(
                "pbdata must be pyserving_channel_pb2.ChannelData type({})".
                format(type(pbdata)))
        self.future = future
        self.pbdata = pbdata
        self.npdata = npdata
        self.datatype = datatype
        self.callback_func = callback_func

    def _npdata_to_pbdata(self, npdata, data_id):
        """Serialize a dict of numpy arrays into a new channel_pb2.ChannelData.

        Each array becomes one ``Inst`` holding its raw bytes, its name,
        its shape (as int32 bytes) and its dtype string. If ``npdata`` fails
        validation, the message carries the error code and reason instead.
        """
        pbdata = channel_pb2.ChannelData()
        pbdata.id = data_id
        ecode, error_info = self._check_npdata(npdata)
        pbdata.ecode = ecode
        if pbdata.ecode != ChannelDataEcode.OK.value:
            pbdata.error_info = error_info
            logging.error(pbdata.error_info)
            return pbdata
        for name, value in npdata.items():
            inst = channel_pb2.Inst()
            inst.data = value.tobytes()
            inst.name = name
            inst.shape = np.array(value.shape, dtype="int32").tobytes()
            inst.type = str(value.dtype)
            pbdata.insts.append(inst)
        return pbdata

    def _check_npdata(self, npdata):
        """Check that npdata is a dict mapping str names to numpy arrays.

        Returns:
            (ecode, error_info): ecode is ChannelDataEcode.OK.value with
            error_info None on success. Otherwise ecode is TYPE_ERROR and
            error_info describes the first offending entry.
        """
        try:
            text_types = (str, unicode)  # Python 2
        except NameError:
            text_types = (str, )  # Python 3 has no separate unicode type
        if not isinstance(npdata, dict):
            return (ChannelDataEcode.TYPE_ERROR.value,
                    "postped_data must be dict, but get {}".format(
                        type(npdata)))
        for name, value in npdata.items():
            if not isinstance(name, text_types):
                return (ChannelDataEcode.TYPE_ERROR.value,
                        "the key of postped_data must "
                        "be str, but get {}".format(type(name)))
            if not isinstance(value, np.ndarray):
                return (ChannelDataEcode.TYPE_ERROR.value,
                        "the value of postped_data must "
                        "be np.ndarray, but get {}".format(type(value)))
        return ChannelDataEcode.OK.value, None

    def parse(self):
        """Return the payload as a feed for the next stage.

        CHANNEL_PBDATA: dict of numpy arrays rebuilt from ``pbdata.insts``.
        CHANNEL_FUTURE: ``future.result()``, passed through
            ``callback_func`` when one was given.
        CHANNEL_NPDATA: ``self.npdata`` as stored.

        Raises:
            TypeError: for any other datatype (e.g. ERROR records).
        """
        if self.datatype == ChannelDataType.CHANNEL_PBDATA.value:
            feed = {}
            for inst in self.pbdata.insts:
                array = np.frombuffer(inst.data, dtype=inst.type)
                array.shape = np.frombuffer(inst.shape, dtype="int32")
                feed[inst.name] = array
            return feed
        if self.datatype == ChannelDataType.CHANNEL_FUTURE.value:
            feed = self.future.result()
            if self.callback_func is not None:
                feed = self.callback_func(feed)
            return feed
        if self.datatype == ChannelDataType.CHANNEL_NPDATA.value:
            return self.npdata
        raise TypeError("Error type({}) in datatype.".format(self.datatype))

    def __str__(self):
        return "type[{}], ecode[{}]".format(
            ChannelDataType(self.datatype).name, self.pbdata.ecode)


class Channel(Queue.Queue):
    """
    The channel used for communication between Ops.

    1. Supports data pushed by several different Ops (multiple producers).
       With more than one producer, push() collects the data of every
       producer for the same data ID and enqueues them together as one
       dict {op_name: channeldata}.
    2. Supports data fetched by several different Ops (multiple consumers).
       An item is kept until every consumer Op has fetched it, and a
       consumer never receives the same item twice.
    3. (TODO) Timeout and BatchSize are not fully supported.

    Note:
    1. The IDs of the data in the channel must be distinct.
    2. add_producer() and add_consumer() are not thread safe and can only
       be called during initialization.
    """

    def __init__(self, name=None, maxsize=-1, timeout=None):
        Queue.Queue.__init__(self, maxsize=maxsize)
        self._maxsize = maxsize
        self._timeout = timeout  # stored but not used yet (see TODO above)
        self.name = name
        self._stop = False

        # Protects all bookkeeping below. Waiters in push()/front() are
        # woken by notify_all() after every state change and by stop().
        self._cv = threading.Condition()

        self._producers = []
        self._producer_res_count = {}  # {data_id: count}
        self._push_res = {}  # {data_id: {op_name: data}}

        self._consumers = {}  # {op_name: idx}
        self._idx_consumer_num = {}  # {idx: num}
        self._consumer_base_idx = 0
        self._front_res = []

    def get_producers(self):
        return self._producers

    def get_consumers(self):
        return self._consumers.keys()

    def _log(self, info_str):
        return "[{}] {}".format(self.name, info_str)

    def debug(self):
        return self._log("p: {}, c: {}".format(self.get_producers(),
                                               self.get_consumers()))

    def add_producer(self, op_name):
        """ not thread safe, and can only be called during initialization. """
        if op_name in self._producers:
            raise ValueError(
                self._log("producer({}) is already in channel".format(op_name)))
        self._producers.append(op_name)

    def add_consumer(self, op_name):
        """ not thread safe, and can only be called during initialization. """
        if op_name in self._consumers:
            raise ValueError(
                self._log("consumer({}) is already in channel".format(op_name)))
        self._consumers[op_name] = 0

        if self._idx_consumer_num.get(0) is None:
            self._idx_consumer_num[0] = 0
        self._idx_consumer_num[0] += 1

    def push(self, channeldata, op_name=None):
        """Push ``channeldata`` on behalf of producer ``op_name``.

        With a single producer the data is enqueued directly. With several
        producers it is buffered by ``channeldata.pbdata.id`` until every
        producer has pushed that id. The dict {op_name: channeldata} is then
        enqueued as one item. Blocks while the queue is full, unless the
        channel is stopped (the data is then dropped).

        Returns:
            True.
        Raises:
            Exception: no producer is registered, or there are several
                producers and ``op_name`` is None.
        """
        logging.debug(
            self._log("{} try to push data: {}".format(op_name,
                                                       channeldata.__str__())))
        if len(self._producers) == 0:
            raise Exception(
                self._log(
                    "expected number of producers to be greater than 0, but the it is 0."
                ))
        elif len(self._producers) == 1:
            with self._cv:
                while self._stop is False:
                    try:
                        self.put(channeldata, timeout=0)
                        break
                    except Queue.Full:
                        self._cv.wait()
                self._cv.notify_all()
            logging.debug(self._log("{} push data succ!".format(op_name)))
            return True
        elif op_name is None:
            raise Exception(
                self._log(
                    "There are multiple producers, so op_name cannot be None."))

        producer_num = len(self._producers)
        data_id = channeldata.pbdata.id
        put_data = None
        with self._cv:
            logging.debug(self._log("{} get lock".format(op_name)))
            if data_id not in self._push_res:
                self._push_res[data_id] = {
                    name: None
                    for name in self._producers
                }
                self._producer_res_count[data_id] = 0
            self._push_res[data_id][op_name] = channeldata
            if self._producer_res_count[data_id] + 1 == producer_num:
                # Last producer for this id: hand the whole group over.
                put_data = self._push_res.pop(data_id)
                self._producer_res_count.pop(data_id)
            else:
                self._producer_res_count[data_id] += 1

            if put_data is None:
                logging.debug(
                    self._log("{} push data succ, but not push to queue.".
                              format(op_name)))
            else:
                while self._stop is False:
                    try:
                        self.put(put_data, timeout=0)
                        break
                    except Queue.Full:
                        # put() signals a full queue with Full (never Empty).
                        self._cv.wait()

                logging.debug(
                    self._log("multi | {} push data succ!".format(op_name)))
            self._cv.notify_all()
        return True

    def front(self, op_name=None):
        """Return the next item for consumer ``op_name``.

        Blocks until data is available. Returns None if the channel is
        stopped before data for this consumer arrives. With several
        consumers, the returned object is shared between them and must be
        treated as read-only.

        Raises:
            Exception: no consumer is registered, or there are several
                consumers and ``op_name`` is None.
        """
        logging.debug(self._log("{} try to get data".format(op_name)))
        if len(self._consumers) == 0:
            raise Exception(
                self._log(
                    "expected number of consumers to be greater than 0, but the it is 0."
                ))
        elif len(self._consumers) == 1:
            resp = None
            with self._cv:
                while self._stop is False and resp is None:
                    try:
                        resp = self.get(timeout=0)
                        break
                    except Queue.Empty:
                        self._cv.wait()
            logging.debug(
                self._log("{} get data succ: {}".format(op_name, resp.__str__(
                ))))
            return resp
        elif op_name is None:
            raise Exception(
                self._log(
                    "There are multiple consumers, so op_name cannot be None."))

        with self._cv:
            # data_idx = consumer_idx - base_idx
            while self._stop is False and self._consumers[
                    op_name] - self._consumer_base_idx >= len(self._front_res):
                try:
                    channeldata = self.get(timeout=0)
                    self._front_res.append(channeldata)
                    break
                except Queue.Empty:
                    self._cv.wait()

            consumer_idx = self._consumers[op_name]
            base_idx = self._consumer_base_idx
            data_idx = consumer_idx - base_idx
            if data_idx >= len(self._front_res):
                # Only reachable when the channel was stopped while this
                # consumer was still waiting: behave like the single-consumer
                # path instead of raising IndexError.
                logging.debug(
                    self._log("{} get data failed: channel stopped".format(
                        op_name)))
                return None
            resp = self._front_res[data_idx]
            logging.debug(self._log("{} get data: {}".format(op_name, resp)))

            self._idx_consumer_num[consumer_idx] -= 1
            if consumer_idx == base_idx and self._idx_consumer_num[
                    consumer_idx] == 0:
                # Every consumer has read the oldest item: drop it.
                self._idx_consumer_num.pop(consumer_idx)
                self._front_res.pop(0)
                self._consumer_base_idx += 1

            self._consumers[op_name] += 1
            new_consumer_idx = self._consumers[op_name]
            if self._idx_consumer_num.get(new_consumer_idx) is None:
                self._idx_consumer_num[new_consumer_idx] = 0
            self._idx_consumer_num[new_consumer_idx] += 1

            self._cv.notify_all()

        logging.debug(self._log("multi | {} get data succ!".format(op_name)))
        return resp  # reference, read only

    def stop(self):
        """Stop the channel and wake every thread blocked in push()/front().

        Afterwards, blocked push() calls return without enqueuing and
        blocked front() calls return None.
        """
        # notify_all() requires the condition's lock to be held.
        with self._cv:
            self._stop = True
            self._cv.notify_all()


class Op(object):
    """A pipeline stage that reads from one Channel and writes to others.

    start() loops:
    1. fetch ChannelData from the input channel;
    2. run preprocess();
    3. run midprocess(), only when a serving client is configured (see
       set_client());
    4. run postprocess();
    5. push the result to every output channel.

    A failure in any stage is not raised. It is turned into an error
    ChannelData with the same data id and pushed downstream. Errors received
    from predecessor Ops are forwarded unchanged.
    """

    def __init__(self,
                 name,
                 inputs,
                 server_model=None,
                 server_port=None,
                 device=None,
                 client_config=None,
                 server_name=None,
                 fetch_names=None,
                 concurrency=1,
                 timeout=-1,
                 retry=2):
        self._run = False
        self.name = name  # to identify the type of OP, it must be globally unique
        self._concurrency = concurrency  # amount of concurrency
        self.set_input_ops(inputs)
        self.set_client(client_config, server_name, fetch_names)
        self._server_model = server_model
        self._server_port = server_port
        self._device = device
        # Time limit for midprocess, passed to func_timeout. A value <= 0
        # disables both the limit and the retry loop.
        self._timeout = timeout
        self._retry = max(1, retry)  # midprocess attempts on timeout
        self._input = None
        self._outputs = []

    def set_client(self, client_config, server_name, fetch_names):
        """Create the serving client used by midprocess().

        Does nothing (with_serving() stays False) unless client_config,
        server_name and fetch_names are all given.
        """
        self._client = None
        # Always defined, so subclasses/midprocess can read it safely.
        self._fetch_names = fetch_names
        if client_config is None or \
                server_name is None or \
                fetch_names is None:
            return
        self._client = Client()
        self._client.load_client_config(client_config)
        self._client.connect([server_name])

    def with_serving(self):
        """True when a serving client was configured by set_client()."""
        return self._client is not None

    def get_input_channel(self):
        return self._input

    def get_input_ops(self):
        return self._input_ops

    def set_input_ops(self, ops):
        """Set the predecessor Ops; accepts None, a single Op or a list of Ops.

        Raises:
            TypeError: an element is not an Op.
        """
        if not isinstance(ops, list):
            ops = [] if ops is None else [ops]
        self._input_ops = []
        for op in ops:
            if not isinstance(op, Op):
                raise TypeError(
                    self._log('input op must be Op type, not {}'.format(
                        type(op))))
            self._input_ops.append(op)

    def add_input_channel(self, channel):
        """Register this Op as a consumer of ``channel`` and read from it."""
        if not isinstance(channel, Channel):
            raise TypeError(
                self._log('input channel must be Channel type, not {}'.format(
                    type(channel))))
        channel.add_consumer(self.name)
        self._input = channel

    def get_output_channels(self):
        return self._outputs

    def add_output_channel(self, channel):
        """Register this Op as a producer of ``channel`` and write to it."""
        if not isinstance(channel, Channel):
            raise TypeError(
                self._log('output channel must be Channel type, not {}'.format(
                    type(channel))))
        channel.add_producer(self.name)
        self._outputs.append(channel)

    def preprocess(self, channeldata):
        """Turn the input ChannelData into the feed for midprocess().

        When the input channel has several producers, ``channeldata`` is a
        dict {op_name: ChannelData}. Such Ops must override this method.
        """
        if isinstance(channeldata, dict):
            raise NotImplementedError(
                'this Op has multiple previous inputs. Please override this method'
            )
        feed = channeldata.parse()
        return feed

    def midprocess(self, data, asyn):
        """Send ``data`` (a dict) to the serving client and return the result.

        NOTE: ``asyn`` is accepted but currently ignored; the client call is
        synchronous.
        """
        if not isinstance(data, dict):
            raise Exception(
                self._log(
                    'data must be dict type(the output of preprocess()), but get {}'.
                    format(type(data))))
        logging.debug(self._log('data: {}'.format(data)))
        logging.debug(self._log('fetch: {}'.format(self._fetch_names)))
        call_result = self._client.predict(
            feed=data, fetch=self._fetch_names)
        logging.debug(self._log("get call_future"))
        return call_result

    def postprocess(self, output_data):
        """Hook applied to the midprocess() result; identity by default.

        start() requires the return value to be a dict. ChannelData then
        validates it as a dict of str -> numpy array.
        """
        return output_data

    def stop(self):
        """Stop the input/output channels and ask start() to leave its loop."""
        if self._input is not None:
            self._input.stop()
        for channel in self._outputs:
            channel.stop()
        self._run = False

    def _parse_channeldata(self, channeldata):
        """Return ``(data_id, error_pbdata)`` for an input item.

        ``error_pbdata`` is the pbdata of the first non-OK entry, or None.
        For a dict input, all entries share one id because the Channel
        groups them by id, so the id is read from the first entry.
        """
        if isinstance(channeldata, dict):
            data_id = next(iter(channeldata.values())).pbdata.id
            for data in channeldata.values():
                if data.pbdata.ecode != ChannelDataEcode.OK.value:
                    return data_id, data.pbdata
            return data_id, None
        data_id = channeldata.pbdata.id
        if channeldata.pbdata.ecode != ChannelDataEcode.OK.value:
            return data_id, channeldata.pbdata
        return data_id, None

    def _push_to_output_channels(self, data, name=None):
        """Push ``data`` to every output channel as producer ``name``."""
        if name is None:
            name = self.name
        for channel in self._outputs:
            channel.push(data, name)

    def _push_error(self, ecode, error_info, data_id):
        """Log ``error_info`` and push an error ChannelData for ``data_id``."""
        logging.error(error_info)
        self._push_to_output_channels(
            ChannelData(ecode=ecode, error_info=error_info, data_id=data_id))

    def _run_midprocess(self, preped_data, asyn, log):
        """Call midprocess(), applying self._timeout and self._retry.

        Returns:
            (midped_data, ecode, error_info). error_info is None when ecode
            is ChannelDataEcode.OK.value.
        """
        if self._timeout <= 0:
            try:
                return (self.midprocess(preped_data, asyn),
                        ChannelDataEcode.OK.value, None)
            except Exception as e:
                return None, ChannelDataEcode.UNKNOW.value, log(e)
        # self._retry >= 1, so the last iteration always returns.
        for i in range(self._retry):
            try:
                midped_data = func_timeout.func_timeout(
                    self._timeout,
                    self.midprocess,
                    args=(preped_data, asyn))
            except func_timeout.FunctionTimedOut as e:
                if i + 1 >= self._retry:
                    return None, ChannelDataEcode.TIMEOUT.value, log(e)
                logging.warning(log("timeout, retry({})".format(i + 1)))
            except Exception as e:
                return None, ChannelDataEcode.UNKNOW.value, log(e)
            else:
                return midped_data, ChannelDataEcode.OK.value, None

    def start(self, concurrency_idx):
        """Run the processing loop until stop() is called.

        Also exits when the input channel returns None, which the Channel
        does after it has been stopped.
        """
        op_info_prefix = "[{}|{}]".format(self.name, concurrency_idx)
        log = self._get_log_func(op_info_prefix)
        self._run = True
        while self._run:
            _profiler.record("{}-get_0".format(op_info_prefix))
            channeldata = self._input.front(self.name)
            _profiler.record("{}-get_1".format(op_info_prefix))
            logging.debug(log("input_data: {}".format(channeldata)))
            if channeldata is None:
                # Input channel stopped: nothing more will arrive.
                break

            data_id, error_pbdata = self._parse_channeldata(channeldata)

            # error data in predecessor Op: forward it unchanged
            if error_pbdata is not None:
                self._push_to_output_channels(
                    ChannelData(
                        datatype=ChannelDataType.CHANNEL_PBDATA.value,
                        pbdata=error_pbdata))
                continue

            # preprocess
            try:
                _profiler.record("{}-prep_0".format(op_info_prefix))
                preped_data = self.preprocess(channeldata)
                _profiler.record("{}-prep_1".format(op_info_prefix))
            except NotImplementedError as e:
                # preprocess function not implemented
                self._push_error(ChannelDataEcode.NOT_IMPLEMENTED.value,
                                 log(e), data_id)
                continue
            except TypeError as e:
                # Error type in channeldata.datatype
                self._push_error(ChannelDataEcode.TYPE_ERROR.value, log(e),
                                 data_id)
                continue
            except Exception as e:
                self._push_error(ChannelDataEcode.UNKNOW.value, log(e),
                                 data_id)
                continue

            # midprocess
            asyn = False
            if self.with_serving():
                _profiler.record("{}-midp_0".format(op_info_prefix))
                midped_data, ecode, error_info = self._run_midprocess(
                    preped_data, asyn, log)
                if ecode != ChannelDataEcode.OK.value:
                    self._push_error(ecode, error_info, data_id)
                    continue
                _profiler.record("{}-midp_1".format(op_info_prefix))
            else:
                midped_data = preped_data

            # postprocess
            output_data = None
            _profiler.record("{}-postp_0".format(op_info_prefix))
            if self.with_serving() and asyn:
                # use call_future
                output_data = ChannelData(
                    datatype=ChannelDataType.CHANNEL_FUTURE.value,
                    future=midped_data,
                    data_id=data_id,
                    callback_func=self.postprocess)
            else:
                try:
                    postped_data = self.postprocess(midped_data)
                except Exception as e:
                    self._push_error(ChannelDataEcode.UNKNOW.value, log(e),
                                     data_id)
                    continue
                if not isinstance(postped_data, dict):
                    error_info = log("output of postprocess funticon must be " \
                            "dict type, but get {}".format(type(postped_data)))
                    self._push_error(ChannelDataEcode.TYPE_ERROR.value,
                                     error_info, data_id)
                    continue

                output_data = ChannelData(
                    ChannelDataType.CHANNEL_NPDATA.value,
                    npdata=postped_data,
                    data_id=data_id)
            _profiler.record("{}-postp_1".format(op_info_prefix))

            # push data to channel (if run succ)
            _profiler.record("{}-push_0".format(op_info_prefix))
            self._push_to_output_channels(output_data)
            _profiler.record("{}-push_1".format(op_info_prefix))

    def _log(self, info):
        return "{} {}".format(self.name, info)

    def _get_log_func(self, op_info_prefix):
        """Return a function that prefixes messages with ``op_info_prefix``."""
        def log_func(info_str):
            return "{} {}".format(op_info_prefix, info_str)

        return log_func

    def get_concurrency(self):
        return self._concurrency


class VirtualOp(Op):
    """Pass-through Op that connects two channels.

    The DAG builder inserts one when an Op's successor lies more than one
    level away, so that every edge joins neighbouring levels. Its output
    channels register the forwarded (real) predecessor Ops as producers,
    not the VirtualOp itself.
    """

    def __init__(self, name, concurrency=1):
        super(VirtualOp, self).__init__(
            name=name, inputs=None, concurrency=concurrency)
        # upstream Ops whose data this VirtualOp forwards
        self._virtual_pred_ops = []

    def add_virtual_pred_op(self, op):
        """Register ``op`` as an upstream Op whose data is forwarded."""
        self._virtual_pred_ops.append(op)

    def add_output_channel(self, channel):
        """Attach ``channel`` as an output of this VirtualOp.

        Every registered virtual predecessor is added to ``channel`` as a
        producer, so downstream consumers see the original producer names.

        Raises:
            TypeError: if ``channel`` is not a Channel.
        """
        if not isinstance(channel, Channel):
            raise TypeError(
                self._log('output channel must be Channel type, not {}'.format(
                    type(channel))))
        for op in self._virtual_pred_ops:
            channel.add_producer(op.name)
        self._outputs.append(channel)

    def start(self, concurrency_idx):
        """Forward data from the input channel to the output channels.

        Loops until ``self._run`` is cleared.

        Args:
            concurrency_idx: index of this worker, used only in profile tags.
        """
        op_info_prefix = "[{}|{}]".format(self.name, concurrency_idx)
        self._run = True
        while self._run:
            _profiler.record("{}-get_0".format(op_info_prefix))
            channeldata = self._input.front(self.name)
            _profiler.record("{}-get_1".format(op_info_prefix))

            _profiler.record("{}-push_0".format(op_info_prefix))
            if isinstance(channeldata, dict):
                # presumably keyed by producer Op name (verify against
                # Channel.front); each item is re-pushed under its key
                for name, data in channeldata.items():
                    self._push_to_output_channels(data, name=name)
            else:
                # a single item is pushed under the first virtual
                # predecessor's name
                self._push_to_output_channels(channeldata,
                                              self._virtual_pred_ops[0].name)
            _profiler.record("{}-push_1".format(op_info_prefix))


B
barrierye 已提交
730
class GeneralPythonService(
        general_python_service_pb2_grpc.GeneralPythonService):
    """gRPC servicer that feeds requests into the Op DAG and returns results.

    Each request is packed into a ChannelData with a fresh id and pushed into
    the input channel. A background receiver thread drains the output
    channel into a response dict keyed by id. ``inference`` waits on that
    dict for its own id.
    """

    def __init__(self, in_channel, out_channel, retry=2):
        """
        Args:
            in_channel: Channel this service produces into (DAG input).
            out_channel: Channel this service consumes from (DAG output).
            retry: attempts per request; values below 1 act as 1.

        Raises:
            TypeError: if either channel is not a Channel.
        """
        super(GeneralPythonService, self).__init__()
        self.name = "#G"
        self.set_in_channel(in_channel)
        self.set_out_channel(out_channel)
        logging.debug(self._log(in_channel.debug()))
        logging.debug(self._log(out_channel.debug()))
        # TODO:
        #  multi-lock for different clients
        #  different lock for server and client
        self._id_lock = threading.Lock()
        # guards _globel_resp_dict; waiters are woken whenever it changes
        self._cv = threading.Condition()
        self._globel_resp_dict = {}  # data_id -> ChannelData
        self._id_counter = 0
        self._retry = retry
        self._recive_func = threading.Thread(
            target=GeneralPythonService._recive_out_channel_func, args=(self, ))
        # The receiver loops forever and is never joined. As a daemon it
        # does not keep the interpreter alive after the server terminates.
        self._recive_func.daemon = True
        self._recive_func.start()

    def _log(self, info_str):
        """Return ``info_str`` prefixed with ``[<service name>]``."""
        return "[{}] {}".format(self.name, info_str)

    def set_in_channel(self, in_channel):
        """Use ``in_channel`` as the DAG input and register as its producer.

        Raises:
            TypeError: if ``in_channel`` is not a Channel.
        """
        if not isinstance(in_channel, Channel):
            raise TypeError(
                self._log('in_channel must be Channel type, but get {}'.format(
                    type(in_channel))))
        in_channel.add_producer(self.name)
        self._in_channel = in_channel

    def set_out_channel(self, out_channel):
        """Use ``out_channel`` as the DAG output and register as its consumer.

        Raises:
            TypeError: if ``out_channel`` is not a Channel.
        """
        if not isinstance(out_channel, Channel):
            raise TypeError(
                self._log('out_channel must be Channel type, but get {}'.format(
                    type(out_channel))))
        out_channel.add_consumer(self.name)
        self._out_channel = out_channel

    def _recive_out_channel_func(self):
        """Receiver-thread loop.

        Moves each ChannelData read from the output channel into the response
        dict (keyed by ``pbdata.id``) and wakes all waiting requests.

        Raises:
            TypeError: if the channel yields something other than ChannelData.
                This ends the receiver thread.
                NOTE(review): pending requests would then wait forever.
        """
        while True:
            channeldata = self._out_channel.front(self.name)
            if not isinstance(channeldata, ChannelData):
                raise TypeError(
                    self._log('data must be ChannelData type, but get {}'.
                              format(type(channeldata))))
            with self._cv:
                data_id = channeldata.pbdata.id
                self._globel_resp_dict[data_id] = channeldata
                self._cv.notify_all()

    def _get_next_id(self):
        """Return a process-unique, monotonically increasing request id."""
        with self._id_lock:
            self._id_counter += 1
            return self._id_counter - 1

    def _get_data_in_globel_resp_dict(self, data_id):
        """Block until the response for ``data_id`` arrives, then remove and
        return it.

        NOTE(review): waits without a timeout; a lost response blocks the
        calling RPC thread forever.
        """
        resp = None
        with self._cv:
            while data_id not in self._globel_resp_dict:
                self._cv.wait()
            resp = self._globel_resp_dict.pop(data_id)
            self._cv.notify_all()
        return resp

    def _pack_data_for_infer(self, request):
        """Convert a gRPC request into a PBDATA ChannelData with a fresh id.

        Packing errors are not raised. The returned ChannelData carries
        ``RPC_PACKAGE_ERROR`` instead, and the exception is logged.

        Returns:
            (ChannelData, data_id)
        """
        logging.debug(self._log('start inferce'))
        pbdata = channel_pb2.ChannelData()
        data_id = self._get_next_id()
        pbdata.id = data_id
        pbdata.ecode = ChannelDataEcode.OK.value
        try:
            for idx, name in enumerate(request.feed_var_names):
                logging.debug(self._log('name: {}'.format(name)))
                logging.debug(
                    self._log('data: {}'.format(request.feed_insts[idx])))
                inst = channel_pb2.Inst()
                inst.data = request.feed_insts[idx]
                inst.shape = request.shape[idx]
                inst.name = name
                inst.type = request.type[idx]
                pbdata.insts.append(inst)
        except Exception:
            # keep the error-code contract (no raise) but record the cause
            logging.exception(self._log('failed to pack request'))
            pbdata.ecode = ChannelDataEcode.RPC_PACKAGE_ERROR.value
            pbdata.error_info = "rpc package error"
        return ChannelData(
            datatype=ChannelDataType.CHANNEL_PBDATA.value,
            pbdata=pbdata), data_id

    def _pack_data_for_resp(self, channeldata):
        """Convert a result ChannelData into a gRPC Response.

        On a non-OK ecode only ``ecode`` and ``error_info`` are filled.

        Raises:
            TypeError: if ``channeldata.datatype`` is not a supported type.
        """
        logging.debug(self._log('get channeldata'))
        resp = pyservice_pb2.Response()
        resp.ecode = channeldata.pbdata.ecode
        if resp.ecode == ChannelDataEcode.OK.value:
            if channeldata.datatype == ChannelDataType.CHANNEL_PBDATA.value:
                for inst in channeldata.pbdata.insts:
                    resp.fetch_insts.append(inst.data)
                    resp.fetch_var_names.append(inst.name)
                    resp.shape.append(inst.shape)
                    resp.type.append(inst.type)
            elif channeldata.datatype in (ChannelDataType.CHANNEL_FUTURE.value,
                                          ChannelDataType.CHANNEL_NPDATA.value):
                feed = channeldata.parse()
                for name, var in feed.items():
                    # numpy arrays are serialized as raw bytes; the shape is
                    # sent as raw int32 bytes and the dtype as its name
                    resp.fetch_insts.append(var.tobytes())
                    resp.fetch_var_names.append(name)
                    resp.shape.append(
                        np.array(
                            var.shape, dtype="int32").tobytes())
                    resp.type.append(str(var.dtype))
            else:
                raise TypeError(
                    self._log("Error type({}) in datatype.".format(
                        channeldata.datatype)))
        else:
            resp.error_info = channeldata.pbdata.error_info
        return resp

    def inference(self, request, context):
        """gRPC entry point: run ``request`` through the DAG.

        The packed request is pushed up to ``max(1, self._retry)`` times and
        stops at the first response whose ecode is OK. The last response
        received is converted and returned.
        """
        _profiler.record("{}-prepack_0".format(self.name))
        data, data_id = self._pack_data_for_infer(request)
        _profiler.record("{}-prepack_1".format(self.name))

        # With retry < 1 the loop would never run. resp_channeldata would then
        # stay None and crash _pack_data_for_resp, so always try at least once.
        attempts = max(1, self._retry)
        resp_channeldata = None
        for i in range(attempts):
            logging.debug(self._log('push data'))
            _profiler.record("{}-push_0".format(self.name))
            self._in_channel.push(data, self.name)
            _profiler.record("{}-push_1".format(self.name))

            logging.debug(self._log('wait for infer'))
            _profiler.record("{}-fetch_0".format(self.name))
            resp_channeldata = self._get_data_in_globel_resp_dict(data_id)
            _profiler.record("{}-fetch_1".format(self.name))

            if resp_channeldata.pbdata.ecode == ChannelDataEcode.OK.value:
                break
            if i + 1 < attempts:
                logging.warning("retry({}): {}".format(
                    i + 1, resp_channeldata.pbdata.error_info))

        _profiler.record("{}-postpack_0".format(self.name))
        resp = self._pack_data_for_resp(resp_channeldata)
        _profiler.record("{}-postpack_1".format(self.name))
        _profiler.print_profile()
        return resp

B
barrierye 已提交
879 880

class PyServer(object):
B
barrierye 已提交
881
    def __init__(self, retry=2, profile=False):
        """Create an empty server.

        Args:
            retry: attempts per request, handed to GeneralPythonService.
            profile: passed to the module-level profiler's ``enable``.
        """
        # user-supplied DAG and the runtime graph derived from it
        self._user_ops = []
        self._channels = []
        self._actual_ops = []
        self._op_threads = []
        # gRPC settings, set by prepare_server()
        self._port = self._worker_num = None
        # DAG entry/exit channels, set by prepare_server()
        self._in_channel = self._out_channel = None
        self._retry = retry
        _profiler.enable(profile)
B
barrierye 已提交
892 893 894 895 896

    def add_channel(self, channel):
        """Append ``channel`` to the server's channel list.

        NOTE(review): _topo_sort assigns a fresh list to self._channels, so
        channels added here are dropped once prepare_server runs.
        """
        channels = self._channels
        channels.append(channel)

    def add_op(self, op):
        """Register a single user Op for the DAG."""
        user_ops = self._user_ops
        user_ops.append(op)

    def add_ops(self, ops):
        """Register every Op in the iterable ``ops`` for the DAG."""
        self._user_ops += ops
B
barrierye 已提交
901 902

    def gen_desc(self):
        """Placeholder for generating a PAAS service description.

        Currently it only writes an info log line.
        """
        logging.info('here will generate desc for PAAS')

906 907 908
    def _topo_sort(self):
        """Build the runtime graph (channels and VirtualOps) from the user Ops.

        Steps:
          1. Rename the first Op without inputs (the read Op) to "#G".
          2. Sort the user Ops level by level (Kahn's algorithm) into
             ``dag_views``. Each view holds the Ops whose predecessors all
             lie in earlier views.
          3. Walk adjacent views. Whenever an edge skips a level, insert a
             VirtualOp so that every edge joins neighbouring levels.
          4. Create one Channel per group of next-level Ops that share the
             same predecessor set, plus a final output channel on the last Op.

        Side effects:
            Sets self._actual_ops to the VirtualOps followed by every user Op
            that has inputs, and replaces self._channels.

        Returns:
            (input_channel, output_channel). ``input_channel`` is None if the
            DAG has only one level.

        Raises:
            Exception: if no Op was added, an Op name is duplicated, the
                graph has a cycle, or the first/last level does not hold
                exactly one Op.
        """
        if not self._user_ops:
            raise Exception("DAG must contain at least one Op")
        for op in self._user_ops:
            if len(op.get_input_ops()) == 0:
                op.name = "#G"  # update read_op.name
                break
        outdegs = {op.name: [] for op in self._user_ops}
        indeg_num = {}
        level = []  # Ops whose in-degree is zero
        for op in self._user_ops:
            # check the name of op is globally unique
            if op.name in indeg_num:
                raise Exception("the name of Op must be unique")
            indeg_num[op.name] = len(op.get_input_ops())
            if indeg_num[op.name] == 0:
                level.append(op)
            for pred_op in op.get_input_ops():
                outdegs[pred_op.name].append(op)

        # topo sort to get dag_views (one list of Ops per level)
        dag_views = []
        sorted_op_num = 0
        while True:
            next_level = []
            for op in level:
                sorted_op_num += 1
                for succ_op in outdegs[op.name]:
                    indeg_num[succ_op.name] -= 1
                    if indeg_num[succ_op.name] == 0:
                        next_level.append(succ_op)
            dag_views.append(level)
            if not next_level:
                break
            level = next_level
        if sorted_op_num < len(self._user_ops):
            # some Ops never reached in-degree zero: there is a cycle
            raise Exception("not legal DAG")
        if len(dag_views[0]) != 1:
            raise Exception("DAG contains multiple input Ops")
        if len(dag_views[-1]) != 1:
            raise Exception("DAG contains multiple output Ops")

        # create channels and virtual ops
        def name_generator(prefix):
            # yields prefix0, prefix1, ...
            idx = 0
            while True:
                yield "{}{}".format(prefix, idx)
                idx += 1

        virtual_op_name_gen = name_generator("vir")
        channel_name_gen = name_generator("chl")
        virtual_ops = []
        channels = []
        input_channel = None
        actual_view = None
        for v_idx, view in enumerate(dag_views):
            if v_idx + 1 >= len(dag_views):
                break
            next_view = dag_views[v_idx + 1]
            if actual_view is None:
                actual_view = view
            actual_next_view = []
            pred_op_of_next_view_op = {}
            for op in actual_view:
                # find actual succ op in next view and create virtual op
                for succ_op in outdegs[op.name]:
                    if succ_op in next_view:
                        if succ_op not in actual_next_view:
                            actual_next_view.append(succ_op)
                        pred_op_of_next_view_op.setdefault(succ_op.name,
                                                           []).append(op)
                    else:
                        # succ_op is further than one level away: bridge
                        # the gap with a VirtualOp placed in the next level
                        virtual_op = VirtualOp(name=next(virtual_op_name_gen))
                        virtual_ops.append(virtual_op)
                        outdegs[virtual_op.name] = [succ_op]
                        actual_next_view.append(virtual_op)
                        pred_op_of_next_view_op[virtual_op.name] = [op]
                        virtual_op.add_virtual_pred_op(op)
            actual_view = actual_next_view
            # create channel
            processed_op = set()
            for o_idx, op in enumerate(actual_next_view):
                if op.name in processed_op:
                    continue
                channel = Channel(name=next(channel_name_gen))
                channels.append(channel)
                logging.debug("{} => {}".format(channel.name, op.name))
                op.add_input_channel(channel)
                pred_ops = pred_op_of_next_view_op[op.name]
                if v_idx == 0:
                    # level 0 is the read Op "#G"; its channel becomes the
                    # server's input channel
                    input_channel = channel
                else:
                    # if pred_op is virtual op, it will use ancestors as producers to channel
                    for pred_op in pred_ops:
                        logging.debug("{} => {}".format(pred_op.name,
                                                        channel.name))
                        pred_op.add_output_channel(channel)
                processed_op.add(op.name)
                # find same input op to combine channel
                for other_op in actual_next_view[o_idx + 1:]:
                    if other_op.name in processed_op:
                        continue
                    other_pred_ops = pred_op_of_next_view_op[other_op.name]
                    if len(other_pred_ops) != len(pred_ops):
                        continue
                    if all(pred_op in other_pred_ops for pred_op in pred_ops):
                        logging.debug("{} => {}".format(channel.name,
                                                        other_op.name))
                        other_op.add_input_channel(channel)
                        processed_op.add(other_op.name)
        output_channel = Channel(name=next(channel_name_gen))
        channels.append(output_channel)
        last_op = dag_views[-1][0]
        last_op.add_output_channel(output_channel)

        # the read Op (no inputs) is not run as a worker
        self._actual_ops = virtual_ops + [
            op for op in self._user_ops if len(op.get_input_ops()) != 0
        ]
        self._channels = channels
        for c in channels:
            logging.debug(c.debug())
        return input_channel, output_channel

B
barrierye 已提交
1046 1047 1048
    def prepare_server(self, port, worker_num):
        """Record gRPC settings and build the runtime graph.

        Every actual Op whose ``with_serving()`` is true is passed to
        ``prepare_serving``. Finally ``gen_desc`` is called.
        """
        self._port, self._worker_num = port, worker_num
        self._in_channel, self._out_channel = self._topo_sort()
        for op in self._actual_ops:
            if not op.with_serving():
                continue
            self.prepare_serving(op)
        self.gen_desc()

W
wangjiawei04 已提交
1058 1059 1060
    def _op_start_wrapper(self, op, concurrency_idx):
        return op.start(concurrency_idx)

1061
    def _run_ops(self):
B
bug fix  
barrierye 已提交
1062
        for op in self._actual_ops:
W
wangjiawei04 已提交
1063 1064 1065 1066 1067 1068 1069 1070
            op_concurrency = op.get_concurrency()
            logging.debug("run op: {}, op_concurrency: {}".format(
                op.name, op_concurrency))
            for c in range(op_concurrency):
                th = threading.Thread(
                    target=self._op_start_wrapper, args=(op, c))
                th.start()
                self._op_threads.append(th)
1071

1072
    def _stop_ops(self):
B
bug fix  
barrierye 已提交
1073
        for op in self._actual_ops:
1074 1075
            op.stop()

1076
    def run_server(self):
        """Start the Op worker threads and serve gRPC until termination.

        If waiting is interrupted (e.g. KeyboardInterrupt), the gRPC server
        is still stopped and every Op still receives ``stop()`` before the
        exception propagates. On normal termination the Op threads are also
        joined.
        """
        self._run_ops()
        server = grpc.server(
            futures.ThreadPoolExecutor(max_workers=self._worker_num))
        general_python_service_pb2_grpc.add_GeneralPythonServiceServicer_to_server(
            GeneralPythonService(self._in_channel, self._out_channel,
                                 self._retry), server)
        server.add_insecure_port('[::]:{}'.format(self._port))
        server.start()
        try:
            server.wait_for_termination()
        finally:
            # grpc Server.stop is idempotent; stop accepting RPCs before
            # shutting the Ops down
            server.stop(None)
            self._stop_ops()  # TODO
        for th in self._op_threads:
            th.join()
B
barrierye 已提交
1089 1090 1091 1092 1093 1094 1095

    def prepare_serving(self, op):
        model_path = op._server_model
        port = op._server_port
        device = op._device

        if device == "cpu":
1096 1097
            cmd = "(Use MultiLangServer) python -m paddle_serving_server.serve" \
                  " --model {} --thread 4 --port {} --use_multilang &>/dev/null &".format(model_path, port)
B
barrierye 已提交
1098
        else:
1099 1100
            cmd = "(Use MultiLangServer) python -m paddle_serving_server_gpu.serve" \
                  " --model {} --thread 4 --port {} --use_multilang &>/dev/null &".format(model_path, port)
1101 1102
        # run a server (not in PyServing)
        logging.info("run a server (not in PyServing): {}".format(cmd))