import copy
import time
import json
import logging
import threading
from functools import wraps
from collections import defaultdict
from milvus import Milvus
from mishards import (settings, exceptions, topology)
from utils import singleton

logger = logging.getLogger(__name__)


# NOTE: everything down to ``ConnectionGroup`` is commented out and is not
# executed. The live code below builds ``Milvus`` clients directly (see
# ``ConnectionGroup.create``).
#
# class Searchook(BaseSearchHook):
#
#     def on_response(self, *args, **kwargs):
#         return True
#
#
# class Connection:
#     def __init__(self, name, uri, max_retry=1, error_handlers=None, **kwargs):
#         self.name = name
#         self.uri = uri
#         self.max_retry = max_retry
#         self.retried = 0
#         self.conn = Milvus()
#         self.error_handlers = [] if not error_handlers else error_handlers
#         self.on_retry_func = kwargs.get('on_retry_func', None)
#
#         # define search hook
#         self.conn.set_hook(search_in_file=Searchook())
#
#         self._connect()
#
#     def __str__(self):
#         return 'Connection:name=\"{}\";uri=\"{}\"'.format(self.name, self.uri)
#
#     def _connect(self, metadata=None):
#         try:
#             self.conn.connect(uri=self.uri)
#         except Exception as e:
#             if not self.error_handlers:
#                 raise exceptions.ConnectionConnectError(message=str(e), metadata=metadata)
#             for handler in self.error_handlers:
#                 handler(e, metadata=metadata)
#
#     @property
#     def can_retry(self):
#         return self.retried < self.max_retry
#
#     @property
#     def connected(self):
#         return self.conn.connected()
#
#     def on_retry(self):
#         if self.on_retry_func:
#             self.on_retry_func(self)
#         else:
#             self.retried > 1 and logger.warning('{} is retrying {}'.format(self, self.retried))
#
#     def on_connect(self, metadata=None):
#         while not self.connected and self.can_retry:
#             self.retried += 1
#             self.on_retry()
#             self._connect(metadata=metadata)
#
#         if not self.can_retry and not self.connected:
#             raise exceptions.ConnectionConnectError(message='Max retry {} reached!'.format(self.max_retry,
#                                                                                             metadata=metadata))
#
#         self.retried = 0
#
#     def connect(self, func, exception_handler=None):
#         @wraps(func)
#         def inner(*args, **kwargs):
#             self.on_connect()
#             try:
#                 return func(*args, **kwargs)
#             except Exception as e:
#                 if exception_handler:
#                     exception_handler(e)
#                 else:
#                     raise e
#         return inner
#
#     def __str__(self):
#         return ''.format(self.name, id(self))
#
#     def __repr__(self):
#         return self.__str__()
#
#
# class Duration:
#     def __init__(self):
#         self.start_ts = time.time()
#         self.end_ts = None
#
#     def stop(self):
#         if self.end_ts:
#             return False
#
#         self.end_ts = time.time()
#         return True
#
#     @property
#     def value(self):
#         if not self.end_ts:
#             return None
#
#         return self.end_ts - self.start_ts
#
#
# class ProxyMixin:
#     def __getattr__(self, name):
#         target = self.__dict__.get(name, None)
#         if target or not self.connection:
#             return target
#         return getattr(self.connection, name)
#
#
# class ScopedConnection(ProxyMixin):
#     def __init__(self, pool, connection):
#         self.pool = pool
#         self.connection = connection
#         self.duration = Duration()
#
#     def __del__(self):
#         self.release()
#
#     def __str__(self):
#         return self.connection.__str__()
#
#     def release(self):
#         if not self.pool or not self.connection:
#             return
#         self.pool.release(self.connection)
#         self.duration.stop()
#         self.pool.record_duration(self.connection, self.duration)
#         self.pool = None
#         self.connection = None
#
#
# class ConnectionPool(topology.TopoObject):
#     def __init__(self, name, uri, max_retry=1, capacity=-1, **kwargs):
#         super().__init__(name)
#         self.capacity = capacity
#         self.pending_pool = set()
#         self.active_pool = set()
#         self.connection_ownership = {}
#         self.uri = uri
#         self.max_retry = max_retry
#         self.kwargs = kwargs
#         self.cv = threading.Condition()
#         self.durations = defaultdict(list)
#
#     def record_duration(self, conn, duration):
#         if len(self.durations[conn]) >= 10000:
#             self.durations[conn].pop(0)
#
#         self.durations[conn].append(duration)
#
#     def stats(self):
#         out = {'connections': {}}
#         connections = out['connections']
#         take_time = []
#         for conn, durations in self.durations.items():
#             total_time = sum(d.value for d in durations)
#             connections[id(conn)] = {
#                 'total_time': total_time,
#                 'called_times': len(durations)
#             }
#             take_time.append(total_time)
#
#         out['max-time'] = max(take_time)
#         out['num'] = len(self.durations)
#         logger.debug(json.dumps(out, indent=2))
#         return out
#
#     def __len__(self):
#         return len(self.pending_pool) + len(self.active_pool)
#
#     @property
#     def active_num(self):
#         return len(self.active_pool)
#
#     def _is_full(self):
#         if self.capacity < 0:
#             return False
#         return len(self) >= self.capacity
#
#     def fetch(self, timeout=1):
#         with self.cv:
#             timeout_times = 0
#             while (len(self.pending_pool) == 0 and self._is_full() and timeout_times < 1):
#                 self.cv.notifyAll()
#                 self.cv.wait(timeout)
#                 timeout_times += 1
#
#             connection = None
#             if timeout_times >= 1:
#                 return connection
#
#             # logger.error('[Connection] Pool \"{}\" SIZE={} ACTIVE={}'.format(self.name, len(self), self.active_num))
#             if len(self.pending_pool) == 0:
#                 connection = self.create()
#             else:
#                 connection = self.pending_pool.pop()
#             # logger.debug('[Connection] Registerring \"{}\" into pool \"{}\"'.format(connection, self.name))
#             self.active_pool.add(connection)
#         scoped_connection = ScopedConnection(self, connection)
#         return scoped_connection
#
#     def release(self, connection):
#         with self.cv:
#             if connection not in self.active_pool:
#                 raise RuntimeError('\"{}\" not found in pool \"{}\"'.format(connection, self.name))
#             # logger.debug('[Connection] Releasing \"{}\" from pool \"{}\"'.format(connection, self.name))
#             # logger.debug('[Connection] Pool \"{}\" SIZE={} ACTIVE={}'.format(self.name, len(self), self.active_num))
#             self.active_pool.remove(connection)
#             self.pending_pool.add(connection)
#
#     def create(self):
#         connection = Connection(name=self.name, uri=self.uri, max_retry=self.max_retry, **self.kwargs)
#         return connection


class ConnectionGroup(topology.TopoGroup):
    """A ``topology.TopoGroup`` whose members are ``Milvus`` client objects."""

    def __init__(self, name):
        super().__init__(name)

    def stats(self):
        """Return ``{item_name: item.stats()}`` for every entry in ``self.items``."""
        return {key: member.stats() for key, member in self.items.items()}

    def on_pre_add(self, topo_object):
        """Decide whether ``topo_object`` is acceptable as a group member.

        Queries ``topo_object.server_version()`` and returns ``True`` only when
        the returned status is OK and the reported version is listed in
        ``settings.SERVER_VERSIONS``. Otherwise an error is logged and
        ``False`` is returned.

        NOTE(review): presumably invoked by ``topology.TopoGroup`` before an
        object is added -- confirm against the base class.
        """
        # conn = topo_object.fetch()
        # conn.on_connect(metadata=None)
        status, version = topo_object.server_version()

        if status.OK():
            supported = settings.SERVER_VERSIONS
            if version in supported:
                return True
            message = 'Cannot connect to server of version: {}. Only {} supported'.format(version, supported)
            logger.error(message)
            return False

        message = 'Cannot connect to newly added address: {}. Remove it now'.format(topo_object.name)
        logger.error(message)
        return False

    def create(self, name, **kwargs):
        """Build a ``Milvus`` client called ``name`` and add it to this group.

        ``kwargs`` must contain a truthy ``uri``. A deep copy of ``kwargs`` is
        forwarded to ``Milvus`` with ``max_retry`` forced to
        ``settings.MAX_RETRY``.

        Returns:
            ``(status, client)`` where ``status`` is whatever ``self.add``
            returned and ``client`` is ``None`` unless that status equals
            ``topology.StatusType.OK``.

        Raises:
            RuntimeError: if ``uri`` is missing or falsy.
        """
        if not kwargs.get('uri', None):
            raise RuntimeError('\"uri\" is required to create connection pool')

        # Work on a copy so the caller's kwargs are never mutated.
        client_options = copy.deepcopy(kwargs)
        client_options["max_retry"] = settings.MAX_RETRY

        client = Milvus(name=name, **client_options)
        status = self.add(client)
        return status, (None if status != topology.StatusType.OK else client)


class ConnectionTopology(topology.Topology):
    """A ``topology.Topology`` that creates ``ConnectionGroup`` instances."""

    def __init__(self):
        super().__init__()

    def stats(self):
        """Return ``{group_name: group.stats()}`` for every entry in ``self.topo_groups``."""
        return {key: group.stats() for key, group in self.topo_groups.items()}

    def create(self, name):
        """Create a ``ConnectionGroup`` called ``name`` and register it.

        Returns:
            ``(status, group)`` where ``status`` is the result of
            ``self.add_group`` and ``group`` is ``None`` when that status is
            ``topology.StatusType.DUPLICATED``.
        """
        new_group = ConnectionGroup(name)
        status = self.add_group(new_group)
        if status == topology.StatusType.DUPLICATED:
            return status, None
        return status, new_group