# server.py
import logging
import sys
import grpc
import time
import socket
import inspect
from urllib.parse import urlparse
from functools import wraps
from concurrent import futures
from grpc._cython import cygrpc
import milvus
from milvus.grpc_gen.milvus_pb2_grpc import add_MilvusServiceServicer_to_server
from mishards.grpc_utils import is_grpc_method
from mishards.service_handler import ServiceHandler
from mishards import settings

logger = logging.getLogger(__name__)


class Server:
    """gRPC front-end server for mishards.

    Lifecycle as driven by ``run``: ``init_app`` wires in the collaborators and
    builds the gRPC server; ``run`` executes the registered pre-run handlers,
    starts discovery and the gRPC server, then blocks until ``stop`` is called
    or a ``KeyboardInterrupt`` arrives.
    """

    def __init__(self):
        # Callables executed by on_pre_run. This is a set, so their execution
        # order is not guaranteed and duplicates are collapsed.
        self.pre_run_handlers = set()
        self.grpc_methods = set()
        # Exception class -> callable(exc) whose return value replaces the
        # result of a wrapped gRPC method (see wrap_method_with_errorhandler).
        self.error_handlers = {}
        # Polled by run(); set to True by stop() to leave the wait loop.
        self.exit_flag = False

    def init_app(self,
                 writable_topo,
                 readonly_topo,
                 tracer,
                 router,
                 discover,
                 port=19530,
                 max_workers=10,
                 **kwargs):
        """Store the collaborators and build the underlying gRPC server.

        Args:
            writable_topo: Topology of writable nodes; ``pre_run_handler``
                registers the WOSERVER endpoint in its ``'default'`` group.
            readonly_topo: Topology of read-only nodes (only stored here).
            tracer: Must provide ``decorate(server)`` (its return value
                replaces the gRPC server) and ``close()`` (called by ``stop``).
            router: Passed to ``ServiceHandler`` when the server starts.
            discover: Its ``start()`` result decides whether ``run`` proceeds.
            port: Default listening port; coerced with ``int()``.
            max_workers: Thread-pool size handed to ``grpc.server``.
            **kwargs: Accepted for call-site compatibility; ignored.
        """
        self.port = int(port)
        self.writable_topo = writable_topo
        self.readonly_topo = readonly_topo
        self.tracer = tracer
        self.router = router
        self.discover = discover

        logger.debug('Init grpc server with max_workers: %s', max_workers)

        self.server_impl = grpc.server(
            thread_pool=futures.ThreadPoolExecutor(max_workers=max_workers),
            # -1 lifts gRPC's message-size limit in both directions.
            options=[(cygrpc.ChannelArgKey.max_send_message_length, -1),
                     (cygrpc.ChannelArgKey.max_receive_message_length, -1)])

        self.server_impl = self.tracer.decorate(self.server_impl)

        self.register_pre_run_handler(self.pre_run_handler)

    def pre_run_handler(self):
        """Register the writable server from ``settings.WOSERVER`` in the topology.

        The URL's host is resolved with ``socket.gethostbyname`` and a
        ``'WOSERVER'`` entry with URI ``<scheme>://<ip>:<port>`` (port
        defaulting to 80) is created in the writable topology's ``'default'``
        group.
        """
        woserver = settings.WOSERVER
        url = urlparse(woserver)
        ip = socket.gethostbyname(url.hostname)
        # Sanity check: raises OSError if ``ip`` is not a valid IPv4 literal.
        socket.inet_pton(socket.AF_INET, ip)

        _, group = self.writable_topo.create('default')
        group.create(name='WOSERVER',
                     uri='{}://{}:{}'.format(url.scheme, ip, url.port or 80))

    def register_pre_run_handler(self, func):
        """Add ``func`` to the handlers executed by ``on_pre_run``.

        Returns ``func`` unchanged, so this can also be used as a decorator.
        """
        logger.info('Registering %s into server pre_run_handlers', func)
        self.pre_run_handlers.add(func)
        return func

    def wrap_method_with_errorhandler(self, func):
        """Return ``func`` wrapped so registered error handlers can intercept it.

        If ``func`` raises an exception whose *exact* class is a key of
        ``self.error_handlers``, the handler's return value is returned instead.
        Subclasses of a registered class are not matched, and any other
        exception propagates unchanged.
        """
        @wraps(func)
        def wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except Exception as e:
                if e.__class__ in self.error_handlers:
                    return self.error_handlers[e.__class__](e)
                raise

        return wrapper

    def errorhandler(self, exception):
        """Decorator factory registering a handler for ``exception``.

        Usage: ``@server.errorhandler(SomeError)`` on a ``handler(err)``
        callable. If ``exception`` is not an ``Exception`` subclass, it is
        returned unchanged and nothing is registered.
        """
        if inspect.isclass(exception) and issubclass(exception, Exception):

            def wrapper(func):
                self.error_handlers[exception] = func
                return func

            return wrapper
        return exception

    def on_pre_run(self):
        """Run every registered pre-run handler, then start service discovery.

        Returns:
            The value of ``self.discover.start()``; ``run`` treats a falsy
            result as failure and exits the process.
        """
        for handler in self.pre_run_handlers:
            handler()
        return self.discover.start()

    def start(self, port=None):
        """Install the service handler, bind ``[::]:<port>`` (insecure) and start.

        Args:
            port: Port to bind; ``self.port`` is used when this is falsy.
        """
        handler_class = self.decorate_handler(ServiceHandler)
        add_MilvusServiceServicer_to_server(
            handler_class(tracer=self.tracer, router=self.router),
            self.server_impl)
        self.server_impl.add_insecure_port('[::]:{}'.format(port or self.port))
        self.server_impl.start()

    def run(self, port):
        """Run pre-start hooks, start serving and block until stopped.

        Exits the process with status 1 if ``on_pre_run`` reports failure.
        After starting, it polls ``exit_flag`` every 5 seconds; a
        ``KeyboardInterrupt`` triggers ``stop()``.

        Args:
            port: Port to listen on; ``self.port`` is used when this is falsy.
        """
        logger.info('Milvus server start ......')
        port = port or self.port

        ok = self.on_pre_run()
        if not ok:
            logger.error('Terminate server due to error found in on_pre_run')
            sys.exit(1)

        self.start(port)

        logger.info('Server Version: %s', settings.SERVER_VERSIONS[-1])
        logger.info('Python SDK Version: %s', milvus.__version__)
        logger.info('Listening on port %s', port)

        try:
            while not self.exit_flag:
                time.sleep(5)
        except KeyboardInterrupt:
            self.stop()

    def stop(self):
        """Stop the gRPC server (grace period 0) and close the tracer."""
        logger.info('Server is shutting down ......')
        self.exit_flag = True
        self.server_impl.stop(0)
        self.tracer.close()
        logger.info('Server is closed')

    def decorate_handler(self, handler):
        """Wrap each gRPC method defined on ``handler`` (a class) in place.

        Every attribute for which ``is_grpc_method`` is true is replaced by
        ``wrap_method_with_errorhandler(attr)``. The same class is returned.
        """
        # Iterate over a snapshot: setattr() below writes into the very class
        # namespace that would otherwise be iterated live.
        for key, attr in list(vars(handler).items()):
            if is_grpc_method(attr):
                setattr(handler, key, self.wrap_method_with_errorhandler(attr))
        return handler