server.py 3.9 KB
Newer Older
C
chenxuyi 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119
#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import division
from __future__ import absolute_import
from __future__ import print_function
from __future__ import unicode_literals

import logging
import six
import asyncio
import threading

import grpc
from propeller.service import interface_pb2
from propeller.service import interface_pb2_grpc
import propeller.paddle.service.utils as serv_utils

from concurrent.futures import ThreadPoolExecutor

import paddle.fluid as F

from time import sleep, time

# Module-level logger used by profile() and serve(). When this file is run as a
# script, the `from propeller import log` in the __main__ block rebinds this name.
log = logging.getLogger(__name__)


def profile(msg):
    """Decorator factory that logs how long each call of the wrapped function takes.

    Args:
        msg: label prefixed to the timing line, e.g. ``'paddle'`` or ``'service'``.

    Returns:
        A decorator. Every call of the decorated function emits
        ``'<msg> timecost: <seconds>'`` at DEBUG level on the module logger,
        including calls that end by raising.
    """
    # Local imports keep this block self-contained; they run only when the
    # decorator is applied, not on every call.
    import functools
    from time import perf_counter

    def decfn(fn):
        # wraps() keeps fn's __name__/__doc__ so tracebacks and introspection
        # show the real function instead of 'retfn'.
        @functools.wraps(fn)
        def retfn(*args, **kwargs):
            # perf_counter is monotonic, so system clock adjustments cannot
            # distort the measured duration the way time() can.
            start = perf_counter()
            try:
                return fn(*args, **kwargs)
            finally:
                # Lazy %-args: the message is only formatted if DEBUG is enabled.
                log.debug('%s timecost: %.5f', msg, perf_counter() - start)

        return retfn

    return decfn


def serve(model_dir, host, num_concurrent=None):
    """Serve the Paddle inference model in ``model_dir`` over gRPC on ``host``.

    Each worker thread of the gRPC thread pool lazily builds its own
    predictor on a distinct GPU card the first time it handles a request.
    The call blocks until interrupted (Ctrl-C), then stops the gRPC server.

    Args:
        model_dir: model directory passed to ``F.core.AnalysisConfig``.
        host: address passed to ``server.add_insecure_port``,
            e.g. ``'0.0.0.0:8334'``.
        num_concurrent: number of worker threads, and therefore of GPU cards
            used (ids 0..num_concurrent-1). Defaults to
            ``len(F.cuda_places())``.

    Raises:
        RuntimeError: under Python 2, or when ``host`` cannot be bound.
        ValueError: when the resolved number of workers is less than 1.
    """
    import itertools

    if six.PY2:
        raise RuntimeError('propeller service work in python3 only')
    if num_concurrent is None:
        num_worker = len(F.cuda_places())
    else:
        num_worker = num_concurrent
    if num_worker < 1:
        # ThreadPoolExecutor would reject this too, but with a less helpful message.
        raise ValueError('need at least one worker, got %r (no CUDA place '
                         'visible and num_concurrent not given?)' % num_worker)
    pool = ThreadPoolExecutor(num_worker)

    class Predictor(object):
        """Paddle predictor bound to a single GPU card ``did``."""

        def __init__(self, did):
            log.debug('create predictor on card %d' % did)
            config = F.core.AnalysisConfig(model_dir)
            # 5000: initial GPU memory pool size for enable_use_gpu
            # (presumably MB -- confirm against Paddle AnalysisConfig docs).
            config.enable_use_gpu(5000, did)
            self._predictor = F.core.create_paddle_predictor(config)

        @profile('paddle')
        def __call__(self, args):
            # Inputs are matched by position: rename them to the feed names
            # the exported model presumably expects (placeholder_0, ...).
            for i, a in enumerate(args):
                a.name = 'placeholder_%d' % i
            return self._predictor.run(args)

    # Per-thread predictor cache. Card ids come from a shared, lock-protected
    # counter, so every pool thread gets a distinct id. They are deliberately
    # not derived from the private `pool._threads`: it is a set whose iteration
    # order can change as threads are added, so two threads could pick one card.
    thread_state = threading.local()
    device_ids = itertools.count()
    device_ids_lock = threading.Lock()

    def get_predictor():
        """Return the calling thread's Predictor, creating it on first use."""
        predictor = getattr(thread_state, 'predictor', None)
        if predictor is not None:
            return predictor
        did = getattr(thread_state, 'did', None)
        if did is None:
            # Reserve the id once per thread, so a failed Predictor() does
            # not make a retry on the same thread consume another card.
            with device_ids_lock:
                did = next(device_ids)
            thread_state.did = did
        log.debug('spawning worker thread %d' % did)
        predictor = Predictor(did)
        thread_state.predictor = predictor
        return predictor

    class InferenceService(interface_pb2_grpc.InferenceServicer):
        @profile('service')
        def Infer(self, request, context):
            """Run one inference request: slots in, slots out."""
            try:
                slots = request.slots
                log.debug('%d slots received dispatch to thread %s' %
                          (len(slots), threading.current_thread()))
                predictor = get_predictor()
                inputs = [serv_utils.slot_to_paddlearray(s) for s in slots]
                outputs = predictor(inputs)
                response = [serv_utils.paddlearray_to_slot(r) for r in outputs]
            except Exception as e:
                log.exception(e)
                # Bare raise re-raises the active exception unchanged.
                raise
            return interface_pb2.Slots(slots=response)

    server = grpc.server(pool)
    interface_pb2_grpc.add_InferenceServicer_to_server(InferenceService(),
                                                       server)
    # Older grpc versions signal a bind failure by returning port 0 instead
    # of raising; fail loudly rather than serve on nothing.
    if not server.add_insecure_port(host):
        raise RuntimeError('failed to bind %s' % host)
    server.start()
    log.info('server started on %s...' % host)
    try:
        while True:
            sleep(100000)
    except KeyboardInterrupt:
        pass
    finally:
        # Actually shut the server down; grace=0 aborts in-flight RPCs.
        server.stop(0)
    log.info('server stoped...')


if __name__ == '__main__':
    import argparse

    # Module-scope import: this rebinds the global `log` used by profile() and
    # serve() (they look it up at call time) to propeller's logger, presumably
    # a logging.Logger with handlers attached, then enables DEBUG output.
    from propeller import log
    log.setLevel(logging.DEBUG)

    # Defaults reproduce the previously hard-coded development settings.
    parser = argparse.ArgumentParser(
        description='propeller paddle gRPC inference server')
    parser.add_argument(
        '--model_dir',
        default='/home/work/chenxuyi/playground/grpc_play/ernie2.0/',
        help='directory of the exported inference model')
    parser.add_argument(
        '--host',
        default='10.255.138.19:8334',
        help='address to listen on, host:port')
    parser.add_argument(
        '--num_concurrent',
        type=int,
        default=3,
        help='number of worker threads / GPU cards')
    cli_args = parser.parse_args()
    serve(
        cli_args.model_dir,
        cli_args.host,
        num_concurrent=cli_args.num_concurrent)