import logging
import sys
import grpc
import time
import socket
import inspect
from urllib.parse import urlparse
from functools import wraps
from concurrent import futures
from grpc._cython import cygrpc
from milvus.grpc_gen.milvus_pb2_grpc import add_MilvusServiceServicer_to_server
from mishards.grpc_utils import is_grpc_method
from mishards.service_handler import ServiceHandler
from mishards import settings

logger = logging.getLogger(__name__)


class Server:
    """gRPC server wrapper that serves a ``ServiceHandler``.

    Typical life cycle: construct, call :meth:`init_app` with the
    collaborating objects, then call :meth:`run` (blocking) or
    :meth:`start` (non-blocking). Hooks registered with
    :meth:`register_pre_run_handler` execute before the server starts, and
    handlers registered with :meth:`errorhandler` translate exceptions
    raised by the handler's gRPC methods into return values.
    """

    def __init__(self):
        # Callables invoked (in set iteration order) by on_pre_run().
        self.pre_run_handlers = set()
        # Not read or written anywhere else in this class.
        self.grpc_methods = set()
        # Maps an exception class to the callable that handles it.
        self.error_handlers = {}
        # Set by stop(); polled by the blocking loop in run().
        self.exit_flag = False

    def init_app(self,
                 writable_topo,
                 readonly_topo,
                 tracer,
                 router,
                 discover,
                 port=19530,
                 max_workers=10,
                 **kwargs):
        """Store collaborators and build the underlying gRPC server.

        Args:
            writable_topo: Topology object; ``create('default')`` is called
                on it by :meth:`pre_run_handler`.
            readonly_topo: Stored on the instance; not used by this class.
            tracer: Must provide ``decorate(server)`` and ``close()``.
            router: Passed to the ``ServiceHandler`` in :meth:`start`.
            discover: Must provide ``start()``; its result decides whether
                :meth:`run` proceeds.
            port: Default listening port; coerced with ``int()``.
            max_workers: Size of the thread pool that serves RPCs.
            **kwargs: Accepted and ignored.
        """
        self.port = int(port)
        self.writable_topo = writable_topo
        self.readonly_topo = readonly_topo
        self.tracer = tracer
        self.router = router
        self.discover = discover

        logger.debug('Init grpc server with max_workers: {}'.format(max_workers))

        # -1 lifts gRPC's message size caps in both directions.
        self.server_impl = grpc.server(
            thread_pool=futures.ThreadPoolExecutor(max_workers=max_workers),
            options=[(cygrpc.ChannelArgKey.max_send_message_length, -1),
                     (cygrpc.ChannelArgKey.max_receive_message_length, -1)])

        self.server_impl = self.tracer.decorate(self.server_impl)

        self.register_pre_run_handler(self.pre_run_handler)

    def pre_run_handler(self):
        """Register the server named by ``settings.WOSERVER`` in the writable topology.

        The setting is parsed as a URL and its host name is resolved to an
        address. The entry is created under the ``'default'`` group with the
        name ``'WOSERVER'``; port 80 is used when the URL carries no port.
        """
        woserver = settings.WOSERVER
        url = urlparse(woserver)
        ip = socket.gethostbyname(url.hostname)
        # Raises if the resolved address is not a valid IPv4 string.
        socket.inet_pton(socket.AF_INET, ip)
        _, group = self.writable_topo.create('default')
        group.create(name='WOSERVER', uri='{}://{}:{}'.format(url.scheme, ip, url.port or 80))

    def register_pre_run_handler(self, func):
        """Add ``func`` to the pre-run hooks and return it (usable as a decorator)."""
        logger.info('Regiterring {} into server pre_run_handlers'.format(func))
        self.pre_run_handlers.add(func)
        return func

    def _find_error_handler(self, exc_class):
        """Return the handler registered for ``exc_class`` or its nearest base, else ``None``."""
        for klass in exc_class.__mro__:
            handler = self.error_handlers.get(klass)
            if handler is not None:
                return handler
        return None

    def wrap_method_with_errorhandler(self, func):
        """Wrap ``func`` so registered error handlers convert exceptions to results.

        When ``func`` raises an ``Exception``, the handler registered for the
        exception's class is called with the exception and its return value
        is returned. If no handler exists for the exact class, handlers
        registered for its base classes are tried, nearest first. With no
        matching handler the exception propagates unchanged.
        """
        @wraps(func)
        def wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except Exception as e:
                handler = self._find_error_handler(type(e))
                if handler is None:
                    raise
                return handler(e)

        return wrapper

    def errorhandler(self, exception):
        """Return a decorator registering a handler for ``exception``.

        Usage: ``@server.errorhandler(SomeError)``. If ``exception`` is not
        an ``Exception`` subclass, nothing is registered and the argument is
        returned unchanged.
        """
        if inspect.isclass(exception) and issubclass(exception, Exception):

            def wrapper(func):
                self.error_handlers[exception] = func
                return func

            return wrapper
        return exception

    def on_pre_run(self):
        """Run every pre-run hook, then return the result of ``discover.start()``."""
        for handler in self.pre_run_handlers:
            handler()
        return self.discover.start()

    def start(self, port=None):
        """Attach the service handler, bind the port and start serving.

        Does not block. ``port`` falls back to the port given to
        :meth:`init_app` when it is ``None`` (or otherwise falsy).
        """
        handler_class = self.decorate_handler(ServiceHandler)
        add_MilvusServiceServicer_to_server(
            handler_class(tracer=self.tracer,
                          router=self.router), self.server_impl)
        self.server_impl.add_insecure_port("[::]:{}".format(port or self.port))
        self.server_impl.start()

    def run(self, port=None):
        """Run pre-run hooks, start the server and block until stopped.

        Args:
            port: Port to listen on; defaults to the port given to
                :meth:`init_app`.

        The process exits with status 1 when :meth:`on_pre_run` returns a
        falsy value. Otherwise this method polls ``exit_flag`` every five
        seconds and calls :meth:`stop` on ``KeyboardInterrupt``.
        """
        logger.info('Milvus server start ......')
        port = port or self.port
        ok = self.on_pre_run()

        if not ok:
            logger.error('Terminate server due to error found in on_pre_run')
            sys.exit(1)

        self.start(port)
        logger.info('Listening on port {}'.format(port))

        try:
            while not self.exit_flag:
                time.sleep(5)
        except KeyboardInterrupt:
            self.stop()

    def stop(self):
        """Stop the gRPC server immediately and close the tracer."""
        logger.info('Server is shuting down ......')
        self.exit_flag = True
        try:
            self.server_impl.stop(0)
        finally:
            # Close the tracer even if stopping the server raised.
            self.tracer.close()
        logger.info('Server is closed')

    def decorate_handler(self, handler):
        """Wrap each gRPC method defined directly on ``handler`` with error handling.

        ``handler`` (a class) is modified in place and returned. Only
        attributes in the class's own ``__dict__`` are considered.
        """
        # Snapshot the items: setattr below writes to the mapping being walked.
        for key, attr in list(handler.__dict__.items()):
            if is_grpc_method(attr):
                setattr(handler, key, self.wrap_method_with_errorhandler(attr))
        return handler