From 8a68d34919811b83a8db8327f999b75803d5971c Mon Sep 17 00:00:00 2001 From: BossZou <40255591+BossZou@users.noreply.github.com> Date: Fri, 8 May 2020 16:40:30 +0800 Subject: [PATCH] Mishards upgrade (#2260) * shards ci Signed-off-by: yhz <413554850@qq.com> * update interface Signed-off-by: Yhz * update mishards Signed-off-by: Yhz * update changlog Signed-off-by: yhz <413554850@qq.com> * [skip ci] mishards dev test pass (fix #2252) Signed-off-by: yhz <413554850@qq.com> --- CHANGELOG.md | 4 + shards/mishards/connections.py | 430 +++++++++--------- .../mishards/grpc_utils/grpc_args_parser.py | 4 +- shards/mishards/router/__init__.py | 26 +- .../plugins/file_based_hash_ring_router.py | 19 +- shards/mishards/service_handler.py | 195 +++++--- shards/requirements.txt | 4 +- tests/milvus_python_test/test_connect.py | 4 +- 8 files changed, 378 insertions(+), 308 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9d20d9dc..04239337 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,10 +13,13 @@ Please mark all change in change log and use the issue from GitHub - \#1997 Index file missed after compact - \#2073 Fix CheckDBConfigBackendUrl error message - \#2076 CheckMetricConfigAddress error message +- \#2120 Fix Search expected failed if search params set invalid +- \#2121 Allow regex match partition tag when search - \#2128 Check has_partition params - \#2131 Distance/ID returned is not correct if searching with duplicate ids - \#2141 Fix server start failed if wal directory exist - \#2169 Fix SingleIndexTest.IVFSQHybrid unittest +- \#2194 Fix get collection info failed - \#2196 Fix server start failed if wal is disabled - \#2231 Use server_config to define hard-delete delay time for segment files @@ -45,6 +48,7 @@ Please mark all change in change log and use the issue from GitHub - \#2185 Change id to string format in http module - \#2186 Update endpoints in http module - \#2190 Fix memory usage is twice of index size when using GPU searching +- \#2252 Upgrade mishards to v0.9.0 ## Task diff --git a/shards/mishards/connections.py b/shards/mishards/connections.py index ab8c78e3..e94e99f5 100644 --- a/shards/mishards/connections.py +++ b/shards/mishards/connections.py @@ -5,7 +5,7 @@ import threading from functools import wraps from collections import defaultdict from milvus import Milvus -from milvus.client.hooks import BaseSearchHook +# from milvus.client.hooks import BaseSearchHook from mishards import (settings, exceptions, topology) from utils import singleton @@ -13,216 +13,216 @@ from utils import singleton logger = logging.getLogger(__name__) -class Searchook(BaseSearchHook): - - def on_response(self, *args, **kwargs): - return True - - -class Connection: - def __init__(self, name, uri, max_retry=1, error_handlers=None, **kwargs): - self.name = name - self.uri = uri - self.max_retry = max_retry - self.retried = 0 - self.conn = Milvus() - self.error_handlers = [] if not error_handlers else error_handlers - self.on_retry_func = kwargs.get('on_retry_func', None) - - # define search hook - self.conn.set_hook(search_in_file=Searchook()) - # self._connect() - - def __str__(self): - return 'Connection:name=\"{}\";uri=\"{}\"'.format(self.name, self.uri) - - def _connect(self, metadata=None): - try: - self.conn.connect(uri=self.uri) - except Exception as e: - if not self.error_handlers: - raise exceptions.ConnectionConnectError(message=str(e), metadata=metadata) - for handler in self.error_handlers: - handler(e, metadata=metadata) - - @property - def can_retry(self): - return self.retried < 
self.max_retry - - @property - def connected(self): - return self.conn.connected() - - def on_retry(self): - if self.on_retry_func: - self.on_retry_func(self) - else: - self.retried > 1 and logger.warning('{} is retrying {}'.format(self, self.retried)) - - def on_connect(self, metadata=None): - while not self.connected and self.can_retry: - self.retried += 1 - self.on_retry() - self._connect(metadata=metadata) - - if not self.can_retry and not self.connected: - raise exceptions.ConnectionConnectError(message='Max retry {} reached!'.format(self.max_retry, - metadata=metadata)) - - self.retried = 0 - - def connect(self, func, exception_handler=None): - @wraps(func) - def inner(*args, **kwargs): - self.on_connect() - try: - return func(*args, **kwargs) - except Exception as e: - if exception_handler: - exception_handler(e) - else: - raise e - return inner - - def __str__(self): - return ''.format(self.name, id(self)) - - def __repr__(self): - return self.__str__() - - -class Duration: - def __init__(self): - self.start_ts = time.time() - self.end_ts = None - - def stop(self): - if self.end_ts: - return False - - self.end_ts = time.time() - return True - - @property - def value(self): - if not self.end_ts: - return None - - return self.end_ts - self.start_ts - - -class ProxyMixin: - def __getattr__(self, name): - target = self.__dict__.get(name, None) - if target or not self.connection: - return target - return getattr(self.connection, name) - - -class ScopedConnection(ProxyMixin): - def __init__(self, pool, connection): - self.pool = pool - self.connection = connection - self.duration = Duration() - - def __del__(self): - self.release() - - def __str__(self): - return self.connection.__str__() - - def release(self): - if not self.pool or not self.connection: - return - self.pool.release(self.connection) - self.duration.stop() - self.pool.record_duration(self.connection, self.duration) - self.pool = None - self.connection = None - - -class ConnectionPool(topology.TopoObject): - def __init__(self, name, uri, max_retry=1, capacity=-1, **kwargs): - super().__init__(name) - self.capacity = capacity - self.pending_pool = set() - self.active_pool = set() - self.connection_ownership = {} - self.uri = uri - self.max_retry = max_retry - self.kwargs = kwargs - self.cv = threading.Condition() - self.durations = defaultdict(list) - - def record_duration(self, conn, duration): - if len(self.durations[conn]) >= 10000: - self.durations[conn].pop(0) - - self.durations[conn].append(duration) - - def stats(self): - out = {'connections': {}} - connections = out['connections'] - take_time = [] - for conn, durations in self.durations.items(): - total_time = sum(d.value for d in durations) - connections[id(conn)] = { - 'total_time': total_time, - 'called_times': len(durations) - } - take_time.append(total_time) - - out['max-time'] = max(take_time) - out['num'] = len(self.durations) - logger.debug(json.dumps(out, indent=2)) - return out - - def __len__(self): - return len(self.pending_pool) + len(self.active_pool) - - @property - def active_num(self): - return len(self.active_pool) - - def _is_full(self): - if self.capacity < 0: - return False - return len(self) >= self.capacity - - def fetch(self, timeout=1): - with self.cv: - timeout_times = 0 - while (len(self.pending_pool) == 0 and self._is_full() and timeout_times < 1): - self.cv.notifyAll() - self.cv.wait(timeout) - timeout_times += 1 - - connection = None - if timeout_times >= 1: - return connection - - # logger.error('[Connection] Pool \"{}\" SIZE={} 
ACTIVE={}'.format(self.name, len(self), self.active_num)) - if len(self.pending_pool) == 0: - connection = self.create() - else: - connection = self.pending_pool.pop() - # logger.debug('[Connection] Registerring \"{}\" into pool \"{}\"'.format(connection, self.name)) - self.active_pool.add(connection) - scoped_connection = ScopedConnection(self, connection) - return scoped_connection - - def release(self, connection): - with self.cv: - if connection not in self.active_pool: - raise RuntimeError('\"{}\" not found in pool \"{}\"'.format(connection, self.name)) - # logger.debug('[Connection] Releasing \"{}\" from pool \"{}\"'.format(connection, self.name)) - # logger.debug('[Connection] Pool \"{}\" SIZE={} ACTIVE={}'.format(self.name, len(self), self.active_num)) - self.active_pool.remove(connection) - self.pending_pool.add(connection) - - def create(self): - connection = Connection(name=self.name, uri=self.uri, max_retry=self.max_retry, **self.kwargs) - return connection +# class Searchook(BaseSearchHook): +# +# def on_response(self, *args, **kwargs): +# return True +# +# +# class Connection: +# def __init__(self, name, uri, max_retry=1, error_handlers=None, **kwargs): +# self.name = name +# self.uri = uri +# self.max_retry = max_retry +# self.retried = 0 +# self.conn = Milvus() +# self.error_handlers = [] if not error_handlers else error_handlers +# self.on_retry_func = kwargs.get('on_retry_func', None) +# +# # define search hook +# self.conn.set_hook(search_in_file=Searchook()) +# # self._connect() +# +# def __str__(self): +# return 'Connection:name=\"{}\";uri=\"{}\"'.format(self.name, self.uri) +# +# def _connect(self, metadata=None): +# try: +# self.conn.connect(uri=self.uri) +# except Exception as e: +# if not self.error_handlers: +# raise exceptions.ConnectionConnectError(message=str(e), metadata=metadata) +# for handler in self.error_handlers: +# handler(e, metadata=metadata) +# +# @property +# def can_retry(self): +# return self.retried < self.max_retry +# +# @property +# def connected(self): +# return self.conn.connected() +# +# def on_retry(self): +# if self.on_retry_func: +# self.on_retry_func(self) +# else: +# self.retried > 1 and logger.warning('{} is retrying {}'.format(self, self.retried)) +# +# def on_connect(self, metadata=None): +# while not self.connected and self.can_retry: +# self.retried += 1 +# self.on_retry() +# self._connect(metadata=metadata) +# +# if not self.can_retry and not self.connected: +# raise exceptions.ConnectionConnectError(message='Max retry {} reached!'.format(self.max_retry, +# metadata=metadata)) +# +# self.retried = 0 +# +# def connect(self, func, exception_handler=None): +# @wraps(func) +# def inner(*args, **kwargs): +# self.on_connect() +# try: +# return func(*args, **kwargs) +# except Exception as e: +# if exception_handler: +# exception_handler(e) +# else: +# raise e +# return inner +# +# def __str__(self): +# return ''.format(self.name, id(self)) +# +# def __repr__(self): +# return self.__str__() +# +# +# class Duration: +# def __init__(self): +# self.start_ts = time.time() +# self.end_ts = None +# +# def stop(self): +# if self.end_ts: +# return False +# +# self.end_ts = time.time() +# return True +# +# @property +# def value(self): +# if not self.end_ts: +# return None +# +# return self.end_ts - self.start_ts +# +# +# class ProxyMixin: +# def __getattr__(self, name): +# target = self.__dict__.get(name, None) +# if target or not self.connection: +# return target +# return getattr(self.connection, name) +# +# +# class 
ScopedConnection(ProxyMixin): +# def __init__(self, pool, connection): +# self.pool = pool +# self.connection = connection +# self.duration = Duration() +# +# def __del__(self): +# self.release() +# +# def __str__(self): +# return self.connection.__str__() +# +# def release(self): +# if not self.pool or not self.connection: +# return +# self.pool.release(self.connection) +# self.duration.stop() +# self.pool.record_duration(self.connection, self.duration) +# self.pool = None +# self.connection = None +# +# +# class ConnectionPool(topology.TopoObject): +# def __init__(self, name, uri, max_retry=1, capacity=-1, **kwargs): +# super().__init__(name) +# self.capacity = capacity +# self.pending_pool = set() +# self.active_pool = set() +# self.connection_ownership = {} +# self.uri = uri +# self.max_retry = max_retry +# self.kwargs = kwargs +# self.cv = threading.Condition() +# self.durations = defaultdict(list) +# +# def record_duration(self, conn, duration): +# if len(self.durations[conn]) >= 10000: +# self.durations[conn].pop(0) +# +# self.durations[conn].append(duration) +# +# def stats(self): +# out = {'connections': {}} +# connections = out['connections'] +# take_time = [] +# for conn, durations in self.durations.items(): +# total_time = sum(d.value for d in durations) +# connections[id(conn)] = { +# 'total_time': total_time, +# 'called_times': len(durations) +# } +# take_time.append(total_time) +# +# out['max-time'] = max(take_time) +# out['num'] = len(self.durations) +# logger.debug(json.dumps(out, indent=2)) +# return out +# +# def __len__(self): +# return len(self.pending_pool) + len(self.active_pool) +# +# @property +# def active_num(self): +# return len(self.active_pool) +# +# def _is_full(self): +# if self.capacity < 0: +# return False +# return len(self) >= self.capacity +# +# def fetch(self, timeout=1): +# with self.cv: +# timeout_times = 0 +# while (len(self.pending_pool) == 0 and self._is_full() and timeout_times < 1): +# self.cv.notifyAll() +# self.cv.wait(timeout) +# timeout_times += 1 +# +# connection = None +# if timeout_times >= 1: +# return connection +# +# # logger.error('[Connection] Pool \"{}\" SIZE={} ACTIVE={}'.format(self.name, len(self), self.active_num)) +# if len(self.pending_pool) == 0: +# connection = self.create() +# else: +# connection = self.pending_pool.pop() +# # logger.debug('[Connection] Registerring \"{}\" into pool \"{}\"'.format(connection, self.name)) +# self.active_pool.add(connection) +# scoped_connection = ScopedConnection(self, connection) +# return scoped_connection +# +# def release(self, connection): +# with self.cv: +# if connection not in self.active_pool: +# raise RuntimeError('\"{}\" not found in pool \"{}\"'.format(connection, self.name)) +# # logger.debug('[Connection] Releasing \"{}\" from pool \"{}\"'.format(connection, self.name)) +# # logger.debug('[Connection] Pool \"{}\" SIZE={} ACTIVE={}'.format(self.name, len(self), self.active_num)) +# self.active_pool.remove(connection) +# self.pending_pool.add(connection) +# +# def create(self): +# connection = Connection(name=self.name, uri=self.uri, max_retry=self.max_retry, **self.kwargs) +# return connection class ConnectionGroup(topology.TopoGroup): @@ -237,9 +237,9 @@ class ConnectionGroup(topology.TopoGroup): return out def on_pre_add(self, topo_object): - conn = topo_object.fetch() - conn.on_connect(metadata=None) - status, version = conn.conn.server_version() + # conn = topo_object.fetch() + # conn.on_connect(metadata=None) + status, version = topo_object.server_version() if not 
status.OK(): logger.error('Cannot connect to newly added address: {}. Remove it now'.format(topo_object.name)) return False @@ -254,7 +254,7 @@ class ConnectionGroup(topology.TopoGroup): uri = kwargs.get('uri', None) if not uri: raise RuntimeError('\"uri\" is required to create connection pool') - pool = ConnectionPool(name=name, **kwargs) + pool = Milvus(name=name, **kwargs) status = self.add(pool) if status != topology.StatusType.OK: pool = None diff --git a/shards/mishards/grpc_utils/grpc_args_parser.py b/shards/mishards/grpc_utils/grpc_args_parser.py index 67ca043b..7a5e36c4 100644 --- a/shards/mishards/grpc_utils/grpc_args_parser.py +++ b/shards/mishards/grpc_utils/grpc_args_parser.py @@ -116,9 +116,9 @@ class GrpcArgsParser(object): @error_status def parse_proto_VectorIdentity(cls, param): _collection_name = param.collection_name - _id = param.id + _ids = list(param.id_array) - return _collection_name, _id + return _collection_name, _ids @classmethod @error_status diff --git a/shards/mishards/router/__init__.py b/shards/mishards/router/__init__.py index 033aa3f5..3f064cb3 100644 --- a/shards/mishards/router/__init__.py +++ b/shards/mishards/router/__init__.py @@ -10,11 +10,13 @@ class RouterMixin: raise NotImplemented() def connection(self, metadata=None): - conn = self.writable_topo.get_group('default').get('WOSERVER').fetch() - if conn: - conn.on_connect(metadata=metadata) + # conn = self.writable_topo.get_group('default').get('WOSERVER').fetch() + conn = self.writable_topo.get_group('default').get('WOSERVER') + # if conn: + # conn.on_connect(metadata=metadata) # PXU TODO: should return conn - return conn.conn + return conn + # return conn.conn def query_conn(self, name, metadata=None): if not name: @@ -27,9 +29,15 @@ class RouterMixin: raise exceptions.ConnectionNotFoundError( message=f'Conn Group {name} is Empty. 
Please Check your configurations', metadata=metadata) - conn = group.get(name).fetch() - if not conn: - raise exceptions.ConnectionNotFoundError( - message=f'Conn {name} Not Found', metadata=metadata) - conn.on_connect(metadata=metadata) + # conn = group.get(name).fetch() + # if not conn: + # raise exceptions.ConnectionNotFoundError( + # message=f'Conn {name} Not Found', metadata=metadata) + # conn.on_connect(metadata=metadata) + + # conn = self.readonly_topo.get_group(name).get(name).fetch() + conn = self.readonly_topo.get_group(name).get(name) + # if not conn: + # raise exceptions.ConnectionNotFoundError(name, metadata=metadata) + # conn.on_connect(metadata=metadata) return conn diff --git a/shards/mishards/router/plugins/file_based_hash_ring_router.py b/shards/mishards/router/plugins/file_based_hash_ring_router.py index d4c66cce..8e691075 100644 --- a/shards/mishards/router/plugins/file_based_hash_ring_router.py +++ b/shards/mishards/router/plugins/file_based_hash_ring_router.py @@ -1,4 +1,5 @@ import logging +import re from sqlalchemy import exc as sqlalchemy_exc from sqlalchemy import and_, or_ from mishards.models import Tables, TableFiles @@ -31,8 +32,8 @@ class Factory(RouterMixin): else: # TODO: collection default partition is '_default' cond = and_(Tables.state != Tables.TO_DELETE, - Tables.owner_table == collection_name, - Tables.partition_tag.in_(partition_tags)) + Tables.owner_table == collection_name) + # Tables.partition_tag.in_(partition_tags)) if '_default' in partition_tags: default_par_cond = and_(Tables.table_id == collection_name, Tables.state != Tables.TO_DELETE) cond = or_(cond, default_par_cond) @@ -45,7 +46,19 @@ class Factory(RouterMixin): logger.error("Cannot find collection {} / {} in metadata".format(collection_name, partition_tags)) raise exceptions.CollectionNotFoundError('{}:{}'.format(collection_name, partition_tags), metadata=metadata) - collection_list = [str(collection.table_id) for collection in collections] + collection_list = [] + if not partition_tags: + collection_list = [str(collection.table_id) for collection in collections] + else: + for collection in collections: + if collection.table_id == collection_name: + collection_list.append(collection_name) + continue + + for tag in partition_tags: + if re.match(tag, collection.partition_tag): + collection_list.append(collection.table_id) + break file_type_cond = or_( TableFiles.file_type == TableFiles.FILE_TYPE_RAW, diff --git a/shards/mishards/service_handler.py b/shards/mishards/service_handler.py index a0661d84..51f496d1 100644 --- a/shards/mishards/service_handler.py +++ b/shards/mishards/service_handler.py @@ -122,48 +122,36 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): metadata = kwargs.get('metadata', None) - rs = [] all_topk_results = [] - def search(addr, collection_id, file_ids, vectors, topk, params, **kwargs): - logger.info( - 'Send Search Request: addr={};collection_id={};ids={};nq={};topk={};params={}' - .format(addr, collection_id, file_ids, len(vectors), topk, params)) - - conn = self.router.query_conn(addr, metadata=metadata) - start = time.time() - span = kwargs.get('span', None) - span = span if span else (None if self.tracer.empty else - context.get_active_span().context) - - with self.tracer.start_span('search_{}'.format(addr), - child_of=span): - ret = conn.conn.search_vectors_in_files(collection_name=collection_id, - file_ids=file_ids, - query_records=vectors, - top_k=topk, - params=params) - if ret.status.error_code != 0: - logger.error("Search fail 
{}".format(ret.status)) - - end = time.time() - all_topk_results.append(ret) - with self.tracer.start_span('do_search', child_of=p_span) as span: - with ThreadPoolExecutor(max_workers=self.max_workers) as pool: + if len(routing) == 0: + logger.warning('SearchVector: partition_tags = {}'.format(partition_tags)) + ft = self.router.connection().search(collection_id, topk, vectors, list(partition_tags), search_params, _async=True) + ret = ft.result(raw=True) + all_topk_results.append(ret) + else: + futures = [] for addr, file_ids in routing.items(): - res = pool.submit(search, - addr, - collection_id, - file_ids, - vectors, - topk, - search_params, - span=span) - rs.append(res) - - for res in rs: - res.result() + conn = self.router.query_conn(addr, metadata=metadata) + start = time.time() + span = kwargs.get('span', None) + span = span if span else (None if self.tracer.empty else + context.get_active_span().context) + + with self.tracer.start_span('search_{}'.format(addr), + child_of=span): + logger.warning("Search file ids is {}".format(file_ids)) + future = conn.search_vectors_in_files(collection_name=collection_id, + file_ids=file_ids, + query_records=vectors, + top_k=topk, + params=search_params, _async=True) + futures.append(future) + + for f in futures: + ret = f.result(raw=True) + all_topk_results.append(ret) reverse = collection_meta.metric_type == Types.MetricType.IP with self.tracer.start_span('do_merge', child_of=p_span): @@ -231,6 +219,13 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): return status_pb2.Status(error_code=_status.code, reason=_status.message) + @mark_grpc_method + def HasPartition(self, request, context): + _collection_name, _tag = Parser.parse_proto_PartitionParam(request) + _status, _ok = self.router.connection().has_partition(_collection_name, _tag) + return milvus_pb2.BoolReply(status_pb2.Status(error_code=_status.code, + reason=_status.message), bool_reply=_ok) + @mark_grpc_method def ShowPartitions(self, request, context): _status, _collection_name = Parser.parse_proto_CollectionName(request) @@ -370,6 +365,72 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): def SearchInFiles(self, request, context): raise NotImplemented() + @mark_grpc_method + def SearchByID(self, request, context): + metadata = {'resp_class': milvus_pb2.TopKQueryResult} + + collection_name = request.collection_name + + topk = request.topk + + if len(request.extra_params) == 0: + raise exceptions.SearchParamError(message="Search param loss", metadata=metadata) + params = ujson.loads(str(request.extra_params[0].value)) + + logger.info('Search {}: topk={} params={}'.format( + collection_name, topk, params)) + + if topk > self.MAX_TOPK or topk <= 0: + raise exceptions.InvalidTopKError( + message='Invalid topk: {}'.format(topk), metadata=metadata) + + collection_meta = self.collection_meta.get(collection_name, None) + + if not collection_meta: + status, info = self.router.connection( + metadata=metadata).describe_collection(collection_name) + if not status.OK(): + raise exceptions.CollectionNotFoundError(collection_name, + metadata=metadata) + + self.collection_meta[collection_name] = info + collection_meta = info + + start = time.time() + + query_record_array = [] + if int(collection_meta.metric_type) >= MetricType.HAMMING.value: + for query_record in request.query_record_array: + query_record_array.append(bytes(query_record.binary_data)) + else: + for query_record in request.query_record_array: + query_record_array.append(list(query_record.float_data)) + + 
partition_tags = getattr(request, "partition_tag_array", []) + ids = getattr(request, "id_array", []) + search_result = self.router.connection(metadata=metadata).search_by_ids(collection_name, ids, topk, partition_tags, params) + # status, id_results, dis_results = self._do_query(context, + # collection_name, + # collection_meta, + # query_record_array, + # topk, + # params, + # partition_tags=getattr(request, "partition_tag_array", []), + # metadata=metadata) + + now = time.time() + logger.info('SearchVector takes: {}'.format(now - start)) + return search_result + # + # topk_result_list = milvus_pb2.TopKQueryResult( + # status=status_pb2.Status(error_code=status.error_code, + # reason=status.reason), + # row_num=len(request.query_record_array) if len(id_results) else 0, + # ids=id_results, + # distances=dis_results) + # return topk_result_list + # raise NotImplemented() + def _describe_collection(self, collection_name, metadata=None): return self.router.connection(metadata=metadata).describe_collection(collection_name) @@ -416,32 +477,16 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): metadata = {'resp_class': milvus_pb2.CollectionInfo} - logger.info('ShowCollectionInfo {}'.format(_collection_name)) _status, _info = self._collection_info(metadata=metadata, collection_name=_collection_name) + _info_str = ujson.dumps(_info) if _status.OK(): - _collection_info = milvus_pb2.CollectionInfo( + return milvus_pb2.CollectionInfo( status=status_pb2.Status(error_code=_status.code, reason=_status.message), - total_row_count=_info.count + json_info=_info_str ) - for par_stat in _info.partitions_stat: - _par = milvus_pb2.PartitionStat( - tag=par_stat.tag, - total_row_count=par_stat.count - ) - for seg_stat in par_stat.segments_stat: - _par.segments_stat.add( - segment_name=seg_stat.segment_name, - row_count=seg_stat.count, - index_name=seg_stat.index_name, - data_size=seg_stat.data_size, - ) - - _collection_info.partitions_stat.append(_par) - return _collection_info - return milvus_pb2.CollectionInfo( status=status_pb2.Status(error_code=_status.code, reason=_status.message), @@ -564,35 +609,35 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): grpc_index.extra_params.add(key='params', value=ujson.dumps(_index_param._params)) return grpc_index - def _get_vector_by_id(self, collection_name, vec_id, metadata): - return self.router.connection(metadata=metadata).get_vector_by_id(collection_name, vec_id) + def _get_vectors_by_id(self, collection_name, ids, metadata): + return self.router.connection(metadata=metadata).get_vectors_by_ids(collection_name, ids) @mark_grpc_method - def GetVectorByID(self, request, context): + def GetVectorsByID(self, request, context): _status, unpacks = Parser.parse_proto_VectorIdentity(request) if not _status.OK(): return status_pb2.Status(error_code=_status.code, reason=_status.message) - metadata = {'resp_class': milvus_pb2.VectorData} + metadata = {'resp_class': milvus_pb2.VectorsData} - _collection_name, _id = unpacks + _collection_name, _ids = unpacks logger.info('GetVectorByID {}'.format(_collection_name)) - _status, vector = self._get_vector_by_id(_collection_name, _id, metadata) - - if not vector: - return milvus_pb2.VectorData(status=status_pb2.Status( - error_code=_status.code, reason=_status.message), ) - - if isinstance(vector, bytes): - records = milvus_pb2.RowRecord(binary_data=vector) + _status, vectors = self._get_vectors_by_id(_collection_name, _ids, metadata) + _rpc_status = status_pb2.Status(error_code=_status.code, 
reason=_status.message) + if not vectors: + return milvus_pb2.VectorsData(status=_rpc_status, ) + + if len(vectors) == 0: + return milvus_pb2.VectorsData(status=_rpc_status, vectors_data=[]) + if isinstance(vectors[0], bytes): + records = [milvus_pb2.RowRecord(binary_data=v) for v in vectors] else: - records = milvus_pb2.RowRecord(float_data=vector) + records = [milvus_pb2.RowRecord(float_data=v) for v in vectors] - return milvus_pb2.VectorData(status=status_pb2.Status( - error_code=_status.code, reason=_status.message), - vector_data=records - ) + response = milvus_pb2.VectorsData(status=_rpc_status) + response.vectors_data.extend(records) + return response def _get_vector_ids(self, collection_name, segment_name, metadata): return self.router.connection(metadata=metadata).get_vector_ids(collection_name, segment_name) diff --git a/shards/requirements.txt b/shards/requirements.txt index b4c1921c..fcadc447 100644 --- a/shards/requirements.txt +++ b/shards/requirements.txt @@ -14,8 +14,8 @@ py==1.8.0 pyasn1==0.4.7 pyasn1-modules==0.2.6 pylint==2.5.0 -pymilvus==0.2.10 -#pymilvus-test==0.3.3 +#pymilvus==0.2.10 +pymilvus-test==0.3.10 pyparsing==2.4.0 pytest==4.6.3 pytest-level==0.1.1 diff --git a/tests/milvus_python_test/test_connect.py b/tests/milvus_python_test/test_connect.py index a5ee634f..e380d943 100644 --- a/tests/milvus_python_test/test_connect.py +++ b/tests/milvus_python_test/test_connect.py @@ -104,11 +104,11 @@ class TestConnect: ''' uri_value = "" if self.local_ip(args): - milvus = get_milvus(uri=uri_value, handler=args["handler"]) + milvus = get_milvus(None, None, uri=uri_value, handler=args["handler"]) # assert milvus.connected() else: with pytest.raises(Exception) as e: - milvus = get_milvus(uri=uri_value, handler=args["handler"]) + milvus = get_milvus(None, None, uri=uri_value, handler=args["handler"]) # assert not milvus.connected() # disable -- GitLab
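
Reviewer note on the routing change: the behavioral core of #2121 (regex partition-tag matching) is the new filtering loop added to shards/mishards/router/plugins/file_based_hash_ring_router.py. The sketch below is a minimal, self-contained restatement of that loop for reference. It assumes `collections` is an iterable of rows carrying `table_id` and `partition_tag` attributes; `Row` and `match_routing_ids` are illustrative names for this note, not part of the mishards API.

    import re
    from collections import namedtuple

    # Illustrative stand-in for a SQLAlchemy Tables row.
    Row = namedtuple('Row', ['table_id', 'partition_tag'])

    def match_routing_ids(collections, collection_name, partition_tags):
        # No tags requested: route to every returned table id,
        # mirroring the plugin's fallback branch.
        if not partition_tags:
            return [str(row.table_id) for row in collections]

        matched = []
        for row in collections:
            # The base collection row (present when '_default' was
            # requested) always routes.
            if row.table_id == collection_name:
                matched.append(collection_name)
                continue
            # A partition routes if any requested tag pattern matches
            # its tag from the start of the string (re.match semantics).
            for tag in partition_tags:
                if re.match(tag, row.partition_tag):
                    matched.append(row.table_id)
                    break
        return matched

    rows = [Row('c1', '_default'),
            Row('c1_p2020', '2020'),
            Row('c1_p2021', '2021')]
    print(match_routing_ids(rows, 'c1', ['20.*']))
    # -> ['c1', 'c1_p2020', 'c1_p2021']

Because re.match anchors only at the beginning of the tag, the pattern '20' already matches both '2020' and '2021'; callers that want an exact match should pass an anchored pattern such as '2020$'. This loose-by-default matching is the semantics the patch ships with.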