未验证 提交 8a68d349 编写于 作者: B BossZou 提交者: GitHub

Mishards upgrade (#2260)

* shards ci
Signed-off-by: Nyhz <413554850@qq.com>

* update interface
Signed-off-by: NYhz <yinghao.zou@zilliz.com>

* update mishards
Signed-off-by: NYhz <yinghao.zou@zilliz.com>

* update changlog
Signed-off-by: Nyhz <413554850@qq.com>

* [skip ci] mishards dev test pass (fix #2252)
Signed-off-by: Nyhz <413554850@qq.com>
上级 6048a3bc
......@@ -13,10 +13,13 @@ Please mark all change in change log and use the issue from GitHub
- \#1997 Index file missed after compact
- \#2073 Fix CheckDBConfigBackendUrl error message
- \#2076 CheckMetricConfigAddress error message
- \#2120 Fix Search expected failed if search params set invalid
- \#2121 Allow regex match partition tag when search
- \#2128 Check has_partition params
- \#2131 Distance/ID returned is not correct if searching with duplicate ids
- \#2141 Fix server start failed if wal directory exist
- \#2169 Fix SingleIndexTest.IVFSQHybrid unittest
- \#2194 Fix get collection info failed
- \#2196 Fix server start failed if wal is disabled
- \#2231 Use server_config to define hard-delete delay time for segment files
......@@ -45,6 +48,7 @@ Please mark all change in change log and use the issue from GitHub
- \#2185 Change id to string format in http module
- \#2186 Update endpoints in http module
- \#2190 Fix memory usage is twice of index size when using GPU searching
- \#2252 Upgrade mishards to v0.9.0
## Task
......
......@@ -5,7 +5,7 @@ import threading
from functools import wraps
from collections import defaultdict
from milvus import Milvus
from milvus.client.hooks import BaseSearchHook
# from milvus.client.hooks import BaseSearchHook
from mishards import (settings, exceptions, topology)
from utils import singleton
......@@ -13,216 +13,216 @@ from utils import singleton
logger = logging.getLogger(__name__)
class Searchook(BaseSearchHook):
def on_response(self, *args, **kwargs):
return True
class Connection:
def __init__(self, name, uri, max_retry=1, error_handlers=None, **kwargs):
self.name = name
self.uri = uri
self.max_retry = max_retry
self.retried = 0
self.conn = Milvus()
self.error_handlers = [] if not error_handlers else error_handlers
self.on_retry_func = kwargs.get('on_retry_func', None)
# define search hook
self.conn.set_hook(search_in_file=Searchook())
# self._connect()
def __str__(self):
return 'Connection:name=\"{}\";uri=\"{}\"'.format(self.name, self.uri)
def _connect(self, metadata=None):
try:
self.conn.connect(uri=self.uri)
except Exception as e:
if not self.error_handlers:
raise exceptions.ConnectionConnectError(message=str(e), metadata=metadata)
for handler in self.error_handlers:
handler(e, metadata=metadata)
@property
def can_retry(self):
return self.retried < self.max_retry
@property
def connected(self):
return self.conn.connected()
def on_retry(self):
if self.on_retry_func:
self.on_retry_func(self)
else:
self.retried > 1 and logger.warning('{} is retrying {}'.format(self, self.retried))
def on_connect(self, metadata=None):
while not self.connected and self.can_retry:
self.retried += 1
self.on_retry()
self._connect(metadata=metadata)
if not self.can_retry and not self.connected:
raise exceptions.ConnectionConnectError(message='Max retry {} reached!'.format(self.max_retry,
metadata=metadata))
self.retried = 0
def connect(self, func, exception_handler=None):
@wraps(func)
def inner(*args, **kwargs):
self.on_connect()
try:
return func(*args, **kwargs)
except Exception as e:
if exception_handler:
exception_handler(e)
else:
raise e
return inner
def __str__(self):
return '<Connection: {}:{}>'.format(self.name, id(self))
def __repr__(self):
return self.__str__()
class Duration:
def __init__(self):
self.start_ts = time.time()
self.end_ts = None
def stop(self):
if self.end_ts:
return False
self.end_ts = time.time()
return True
@property
def value(self):
if not self.end_ts:
return None
return self.end_ts - self.start_ts
class ProxyMixin:
def __getattr__(self, name):
target = self.__dict__.get(name, None)
if target or not self.connection:
return target
return getattr(self.connection, name)
class ScopedConnection(ProxyMixin):
def __init__(self, pool, connection):
self.pool = pool
self.connection = connection
self.duration = Duration()
def __del__(self):
self.release()
def __str__(self):
return self.connection.__str__()
def release(self):
if not self.pool or not self.connection:
return
self.pool.release(self.connection)
self.duration.stop()
self.pool.record_duration(self.connection, self.duration)
self.pool = None
self.connection = None
class ConnectionPool(topology.TopoObject):
def __init__(self, name, uri, max_retry=1, capacity=-1, **kwargs):
super().__init__(name)
self.capacity = capacity
self.pending_pool = set()
self.active_pool = set()
self.connection_ownership = {}
self.uri = uri
self.max_retry = max_retry
self.kwargs = kwargs
self.cv = threading.Condition()
self.durations = defaultdict(list)
def record_duration(self, conn, duration):
if len(self.durations[conn]) >= 10000:
self.durations[conn].pop(0)
self.durations[conn].append(duration)
def stats(self):
out = {'connections': {}}
connections = out['connections']
take_time = []
for conn, durations in self.durations.items():
total_time = sum(d.value for d in durations)
connections[id(conn)] = {
'total_time': total_time,
'called_times': len(durations)
}
take_time.append(total_time)
out['max-time'] = max(take_time)
out['num'] = len(self.durations)
logger.debug(json.dumps(out, indent=2))
return out
def __len__(self):
return len(self.pending_pool) + len(self.active_pool)
@property
def active_num(self):
return len(self.active_pool)
def _is_full(self):
if self.capacity < 0:
return False
return len(self) >= self.capacity
def fetch(self, timeout=1):
with self.cv:
timeout_times = 0
while (len(self.pending_pool) == 0 and self._is_full() and timeout_times < 1):
self.cv.notifyAll()
self.cv.wait(timeout)
timeout_times += 1
connection = None
if timeout_times >= 1:
return connection
# logger.error('[Connection] Pool \"{}\" SIZE={} ACTIVE={}'.format(self.name, len(self), self.active_num))
if len(self.pending_pool) == 0:
connection = self.create()
else:
connection = self.pending_pool.pop()
# logger.debug('[Connection] Registerring \"{}\" into pool \"{}\"'.format(connection, self.name))
self.active_pool.add(connection)
scoped_connection = ScopedConnection(self, connection)
return scoped_connection
def release(self, connection):
with self.cv:
if connection not in self.active_pool:
raise RuntimeError('\"{}\" not found in pool \"{}\"'.format(connection, self.name))
# logger.debug('[Connection] Releasing \"{}\" from pool \"{}\"'.format(connection, self.name))
# logger.debug('[Connection] Pool \"{}\" SIZE={} ACTIVE={}'.format(self.name, len(self), self.active_num))
self.active_pool.remove(connection)
self.pending_pool.add(connection)
def create(self):
connection = Connection(name=self.name, uri=self.uri, max_retry=self.max_retry, **self.kwargs)
return connection
# class Searchook(BaseSearchHook):
#
# def on_response(self, *args, **kwargs):
# return True
#
#
# class Connection:
# def __init__(self, name, uri, max_retry=1, error_handlers=None, **kwargs):
# self.name = name
# self.uri = uri
# self.max_retry = max_retry
# self.retried = 0
# self.conn = Milvus()
# self.error_handlers = [] if not error_handlers else error_handlers
# self.on_retry_func = kwargs.get('on_retry_func', None)
#
# # define search hook
# self.conn.set_hook(search_in_file=Searchook())
# # self._connect()
#
# def __str__(self):
# return 'Connection:name=\"{}\";uri=\"{}\"'.format(self.name, self.uri)
#
# def _connect(self, metadata=None):
# try:
# self.conn.connect(uri=self.uri)
# except Exception as e:
# if not self.error_handlers:
# raise exceptions.ConnectionConnectError(message=str(e), metadata=metadata)
# for handler in self.error_handlers:
# handler(e, metadata=metadata)
#
# @property
# def can_retry(self):
# return self.retried < self.max_retry
#
# @property
# def connected(self):
# return self.conn.connected()
#
# def on_retry(self):
# if self.on_retry_func:
# self.on_retry_func(self)
# else:
# self.retried > 1 and logger.warning('{} is retrying {}'.format(self, self.retried))
#
# def on_connect(self, metadata=None):
# while not self.connected and self.can_retry:
# self.retried += 1
# self.on_retry()
# self._connect(metadata=metadata)
#
# if not self.can_retry and not self.connected:
# raise exceptions.ConnectionConnectError(message='Max retry {} reached!'.format(self.max_retry,
# metadata=metadata))
#
# self.retried = 0
#
# def connect(self, func, exception_handler=None):
# @wraps(func)
# def inner(*args, **kwargs):
# self.on_connect()
# try:
# return func(*args, **kwargs)
# except Exception as e:
# if exception_handler:
# exception_handler(e)
# else:
# raise e
# return inner
#
# def __str__(self):
# return '<Connection: {}:{}>'.format(self.name, id(self))
#
# def __repr__(self):
# return self.__str__()
#
#
# class Duration:
# def __init__(self):
# self.start_ts = time.time()
# self.end_ts = None
#
# def stop(self):
# if self.end_ts:
# return False
#
# self.end_ts = time.time()
# return True
#
# @property
# def value(self):
# if not self.end_ts:
# return None
#
# return self.end_ts - self.start_ts
#
#
# class ProxyMixin:
# def __getattr__(self, name):
# target = self.__dict__.get(name, None)
# if target or not self.connection:
# return target
# return getattr(self.connection, name)
#
#
# class ScopedConnection(ProxyMixin):
# def __init__(self, pool, connection):
# self.pool = pool
# self.connection = connection
# self.duration = Duration()
#
# def __del__(self):
# self.release()
#
# def __str__(self):
# return self.connection.__str__()
#
# def release(self):
# if not self.pool or not self.connection:
# return
# self.pool.release(self.connection)
# self.duration.stop()
# self.pool.record_duration(self.connection, self.duration)
# self.pool = None
# self.connection = None
#
#
# class ConnectionPool(topology.TopoObject):
# def __init__(self, name, uri, max_retry=1, capacity=-1, **kwargs):
# super().__init__(name)
# self.capacity = capacity
# self.pending_pool = set()
# self.active_pool = set()
# self.connection_ownership = {}
# self.uri = uri
# self.max_retry = max_retry
# self.kwargs = kwargs
# self.cv = threading.Condition()
# self.durations = defaultdict(list)
#
# def record_duration(self, conn, duration):
# if len(self.durations[conn]) >= 10000:
# self.durations[conn].pop(0)
#
# self.durations[conn].append(duration)
#
# def stats(self):
# out = {'connections': {}}
# connections = out['connections']
# take_time = []
# for conn, durations in self.durations.items():
# total_time = sum(d.value for d in durations)
# connections[id(conn)] = {
# 'total_time': total_time,
# 'called_times': len(durations)
# }
# take_time.append(total_time)
#
# out['max-time'] = max(take_time)
# out['num'] = len(self.durations)
# logger.debug(json.dumps(out, indent=2))
# return out
#
# def __len__(self):
# return len(self.pending_pool) + len(self.active_pool)
#
# @property
# def active_num(self):
# return len(self.active_pool)
#
# def _is_full(self):
# if self.capacity < 0:
# return False
# return len(self) >= self.capacity
#
# def fetch(self, timeout=1):
# with self.cv:
# timeout_times = 0
# while (len(self.pending_pool) == 0 and self._is_full() and timeout_times < 1):
# self.cv.notifyAll()
# self.cv.wait(timeout)
# timeout_times += 1
#
# connection = None
# if timeout_times >= 1:
# return connection
#
# # logger.error('[Connection] Pool \"{}\" SIZE={} ACTIVE={}'.format(self.name, len(self), self.active_num))
# if len(self.pending_pool) == 0:
# connection = self.create()
# else:
# connection = self.pending_pool.pop()
# # logger.debug('[Connection] Registerring \"{}\" into pool \"{}\"'.format(connection, self.name))
# self.active_pool.add(connection)
# scoped_connection = ScopedConnection(self, connection)
# return scoped_connection
#
# def release(self, connection):
# with self.cv:
# if connection not in self.active_pool:
# raise RuntimeError('\"{}\" not found in pool \"{}\"'.format(connection, self.name))
# # logger.debug('[Connection] Releasing \"{}\" from pool \"{}\"'.format(connection, self.name))
# # logger.debug('[Connection] Pool \"{}\" SIZE={} ACTIVE={}'.format(self.name, len(self), self.active_num))
# self.active_pool.remove(connection)
# self.pending_pool.add(connection)
#
# def create(self):
# connection = Connection(name=self.name, uri=self.uri, max_retry=self.max_retry, **self.kwargs)
# return connection
class ConnectionGroup(topology.TopoGroup):
......@@ -237,9 +237,9 @@ class ConnectionGroup(topology.TopoGroup):
return out
def on_pre_add(self, topo_object):
conn = topo_object.fetch()
conn.on_connect(metadata=None)
status, version = conn.conn.server_version()
# conn = topo_object.fetch()
# conn.on_connect(metadata=None)
status, version = topo_object.server_version()
if not status.OK():
logger.error('Cannot connect to newly added address: {}. Remove it now'.format(topo_object.name))
return False
......@@ -254,7 +254,7 @@ class ConnectionGroup(topology.TopoGroup):
uri = kwargs.get('uri', None)
if not uri:
raise RuntimeError('\"uri\" is required to create connection pool')
pool = ConnectionPool(name=name, **kwargs)
pool = Milvus(name=name, **kwargs)
status = self.add(pool)
if status != topology.StatusType.OK:
pool = None
......
......@@ -116,9 +116,9 @@ class GrpcArgsParser(object):
@error_status
def parse_proto_VectorIdentity(cls, param):
_collection_name = param.collection_name
_id = param.id
_ids = list(param.id_array)
return _collection_name, _id
return _collection_name, _ids
@classmethod
@error_status
......
......@@ -10,11 +10,13 @@ class RouterMixin:
raise NotImplemented()
def connection(self, metadata=None):
conn = self.writable_topo.get_group('default').get('WOSERVER').fetch()
if conn:
conn.on_connect(metadata=metadata)
# conn = self.writable_topo.get_group('default').get('WOSERVER').fetch()
conn = self.writable_topo.get_group('default').get('WOSERVER')
# if conn:
# conn.on_connect(metadata=metadata)
# PXU TODO: should return conn
return conn.conn
return conn
# return conn.conn
def query_conn(self, name, metadata=None):
if not name:
......@@ -27,9 +29,15 @@ class RouterMixin:
raise exceptions.ConnectionNotFoundError(
message=f'Conn Group {name} is Empty. Please Check your configurations',
metadata=metadata)
conn = group.get(name).fetch()
if not conn:
raise exceptions.ConnectionNotFoundError(
message=f'Conn {name} Not Found', metadata=metadata)
conn.on_connect(metadata=metadata)
# conn = group.get(name).fetch()
# if not conn:
# raise exceptions.ConnectionNotFoundError(
# message=f'Conn {name} Not Found', metadata=metadata)
# conn.on_connect(metadata=metadata)
# conn = self.readonly_topo.get_group(name).get(name).fetch()
conn = self.readonly_topo.get_group(name).get(name)
# if not conn:
# raise exceptions.ConnectionNotFoundError(name, metadata=metadata)
# conn.on_connect(metadata=metadata)
return conn
import logging
import re
from sqlalchemy import exc as sqlalchemy_exc
from sqlalchemy import and_, or_
from mishards.models import Tables, TableFiles
......@@ -31,8 +32,8 @@ class Factory(RouterMixin):
else:
# TODO: collection default partition is '_default'
cond = and_(Tables.state != Tables.TO_DELETE,
Tables.owner_table == collection_name,
Tables.partition_tag.in_(partition_tags))
Tables.owner_table == collection_name)
# Tables.partition_tag.in_(partition_tags))
if '_default' in partition_tags:
default_par_cond = and_(Tables.table_id == collection_name, Tables.state != Tables.TO_DELETE)
cond = or_(cond, default_par_cond)
......@@ -45,7 +46,19 @@ class Factory(RouterMixin):
logger.error("Cannot find collection {} / {} in metadata".format(collection_name, partition_tags))
raise exceptions.CollectionNotFoundError('{}:{}'.format(collection_name, partition_tags), metadata=metadata)
collection_list = []
if not partition_tags:
collection_list = [str(collection.table_id) for collection in collections]
else:
for collection in collections:
if collection.table_id == collection_name:
collection_list.append(collection_name)
continue
for tag in partition_tags:
if re.match(tag, collection.partition_tag):
collection_list.append(collection.table_id)
break
file_type_cond = or_(
TableFiles.file_type == TableFiles.FILE_TYPE_RAW,
......
......@@ -122,14 +122,17 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
metadata = kwargs.get('metadata', None)
rs = []
all_topk_results = []
def search(addr, collection_id, file_ids, vectors, topk, params, **kwargs):
logger.info(
'Send Search Request: addr={};collection_id={};ids={};nq={};topk={};params={}'
.format(addr, collection_id, file_ids, len(vectors), topk, params))
with self.tracer.start_span('do_search', child_of=p_span) as span:
if len(routing) == 0:
logger.warning('SearchVector: partition_tags = {}'.format(partition_tags))
ft = self.router.connection().search(collection_id, topk, vectors, list(partition_tags), search_params, _async=True)
ret = ft.result(raw=True)
all_topk_results.append(ret)
else:
futures = []
for addr, file_ids in routing.items():
conn = self.router.query_conn(addr, metadata=metadata)
start = time.time()
span = kwargs.get('span', None)
......@@ -138,33 +141,18 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
with self.tracer.start_span('search_{}'.format(addr),
child_of=span):
ret = conn.conn.search_vectors_in_files(collection_name=collection_id,
logger.warning("Search file ids is {}".format(file_ids))
future = conn.search_vectors_in_files(collection_name=collection_id,
file_ids=file_ids,
query_records=vectors,
top_k=topk,
params=params)
if ret.status.error_code != 0:
logger.error("Search fail {}".format(ret.status))
params=search_params, _async=True)
futures.append(future)
end = time.time()
for f in futures:
ret = f.result(raw=True)
all_topk_results.append(ret)
with self.tracer.start_span('do_search', child_of=p_span) as span:
with ThreadPoolExecutor(max_workers=self.max_workers) as pool:
for addr, file_ids in routing.items():
res = pool.submit(search,
addr,
collection_id,
file_ids,
vectors,
topk,
search_params,
span=span)
rs.append(res)
for res in rs:
res.result()
reverse = collection_meta.metric_type == Types.MetricType.IP
with self.tracer.start_span('do_merge', child_of=p_span):
return self._do_merge(all_topk_results,
......@@ -231,6 +219,13 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
return status_pb2.Status(error_code=_status.code,
reason=_status.message)
@mark_grpc_method
def HasPartition(self, request, context):
_collection_name, _tag = Parser.parse_proto_PartitionParam(request)
_status, _ok = self.router.connection().has_partition(_collection_name, _tag)
return milvus_pb2.BoolReply(status_pb2.Status(error_code=_status.code,
reason=_status.message), bool_reply=_ok)
@mark_grpc_method
def ShowPartitions(self, request, context):
_status, _collection_name = Parser.parse_proto_CollectionName(request)
......@@ -370,6 +365,72 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
def SearchInFiles(self, request, context):
raise NotImplemented()
@mark_grpc_method
def SearchByID(self, request, context):
metadata = {'resp_class': milvus_pb2.TopKQueryResult}
collection_name = request.collection_name
topk = request.topk
if len(request.extra_params) == 0:
raise exceptions.SearchParamError(message="Search param loss", metadata=metadata)
params = ujson.loads(str(request.extra_params[0].value))
logger.info('Search {}: topk={} params={}'.format(
collection_name, topk, params))
if topk > self.MAX_TOPK or topk <= 0:
raise exceptions.InvalidTopKError(
message='Invalid topk: {}'.format(topk), metadata=metadata)
collection_meta = self.collection_meta.get(collection_name, None)
if not collection_meta:
status, info = self.router.connection(
metadata=metadata).describe_collection(collection_name)
if not status.OK():
raise exceptions.CollectionNotFoundError(collection_name,
metadata=metadata)
self.collection_meta[collection_name] = info
collection_meta = info
start = time.time()
query_record_array = []
if int(collection_meta.metric_type) >= MetricType.HAMMING.value:
for query_record in request.query_record_array:
query_record_array.append(bytes(query_record.binary_data))
else:
for query_record in request.query_record_array:
query_record_array.append(list(query_record.float_data))
partition_tags = getattr(request, "partition_tag_array", [])
ids = getattr(request, "id_array", [])
search_result = self.router.connection(metadata=metadata).search_by_ids(collection_name, ids, topk, partition_tags, params)
# status, id_results, dis_results = self._do_query(context,
# collection_name,
# collection_meta,
# query_record_array,
# topk,
# params,
# partition_tags=getattr(request, "partition_tag_array", []),
# metadata=metadata)
now = time.time()
logger.info('SearchVector takes: {}'.format(now - start))
return search_result
#
# topk_result_list = milvus_pb2.TopKQueryResult(
# status=status_pb2.Status(error_code=status.error_code,
# reason=status.reason),
# row_num=len(request.query_record_array) if len(id_results) else 0,
# ids=id_results,
# distances=dis_results)
# return topk_result_list
# raise NotImplemented()
def _describe_collection(self, collection_name, metadata=None):
return self.router.connection(metadata=metadata).describe_collection(collection_name)
......@@ -416,32 +477,16 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
metadata = {'resp_class': milvus_pb2.CollectionInfo}
logger.info('ShowCollectionInfo {}'.format(_collection_name))
_status, _info = self._collection_info(metadata=metadata, collection_name=_collection_name)
_info_str = ujson.dumps(_info)
if _status.OK():
_collection_info = milvus_pb2.CollectionInfo(
return milvus_pb2.CollectionInfo(
status=status_pb2.Status(error_code=_status.code,
reason=_status.message),
total_row_count=_info.count
json_info=_info_str
)
for par_stat in _info.partitions_stat:
_par = milvus_pb2.PartitionStat(
tag=par_stat.tag,
total_row_count=par_stat.count
)
for seg_stat in par_stat.segments_stat:
_par.segments_stat.add(
segment_name=seg_stat.segment_name,
row_count=seg_stat.count,
index_name=seg_stat.index_name,
data_size=seg_stat.data_size,
)
_collection_info.partitions_stat.append(_par)
return _collection_info
return milvus_pb2.CollectionInfo(
status=status_pb2.Status(error_code=_status.code,
reason=_status.message),
......@@ -564,35 +609,35 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
grpc_index.extra_params.add(key='params', value=ujson.dumps(_index_param._params))
return grpc_index
def _get_vector_by_id(self, collection_name, vec_id, metadata):
return self.router.connection(metadata=metadata).get_vector_by_id(collection_name, vec_id)
def _get_vectors_by_id(self, collection_name, ids, metadata):
return self.router.connection(metadata=metadata).get_vectors_by_ids(collection_name, ids)
@mark_grpc_method
def GetVectorByID(self, request, context):
def GetVectorsByID(self, request, context):
_status, unpacks = Parser.parse_proto_VectorIdentity(request)
if not _status.OK():
return status_pb2.Status(error_code=_status.code,
reason=_status.message)
metadata = {'resp_class': milvus_pb2.VectorData}
metadata = {'resp_class': milvus_pb2.VectorsData}
_collection_name, _id = unpacks
_collection_name, _ids = unpacks
logger.info('GetVectorByID {}'.format(_collection_name))
_status, vector = self._get_vector_by_id(_collection_name, _id, metadata)
if not vector:
return milvus_pb2.VectorData(status=status_pb2.Status(
error_code=_status.code, reason=_status.message), )
if isinstance(vector, bytes):
records = milvus_pb2.RowRecord(binary_data=vector)
_status, vectors = self._get_vectors_by_id(_collection_name, _ids, metadata)
_rpc_status = status_pb2.Status(error_code=_status.code, reason=_status.message)
if not vectors:
return milvus_pb2.VectorsData(status=_rpc_status, )
if len(vectors) == 0:
return milvus_pb2.VectorsData(status=_rpc_status, vectors_data=[])
if isinstance(vectors[0], bytes):
records = [milvus_pb2.RowRecord(binary_data=v) for v in vectors]
else:
records = milvus_pb2.RowRecord(float_data=vector)
records = [milvus_pb2.RowRecord(float_data=v) for v in vectors]
return milvus_pb2.VectorData(status=status_pb2.Status(
error_code=_status.code, reason=_status.message),
vector_data=records
)
response = milvus_pb2.VectorsData(status=_rpc_status)
response.vectors_data.extend(records)
return response
def _get_vector_ids(self, collection_name, segment_name, metadata):
return self.router.connection(metadata=metadata).get_vector_ids(collection_name, segment_name)
......
......@@ -14,8 +14,8 @@ py==1.8.0
pyasn1==0.4.7
pyasn1-modules==0.2.6
pylint==2.5.0
pymilvus==0.2.10
#pymilvus-test==0.3.3
#pymilvus==0.2.10
pymilvus-test==0.3.10
pyparsing==2.4.0
pytest==4.6.3
pytest-level==0.1.1
......
......@@ -104,11 +104,11 @@ class TestConnect:
'''
uri_value = ""
if self.local_ip(args):
milvus = get_milvus(uri=uri_value, handler=args["handler"])
milvus = get_milvus(None, None, uri=uri_value, handler=args["handler"])
# assert milvus.connected()
else:
with pytest.raises(Exception) as e:
milvus = get_milvus(uri=uri_value, handler=args["handler"])
milvus = get_milvus(None, None, uri=uri_value, handler=args["handler"])
# assert not milvus.connected()
# disable
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册