提交 ee9f67c2 编写于 作者: C Cai Yudong 提交者: yefu.chen

Add mix base

Signed-off-by: NCai Yudong <yudong.cai@zilliz.com>
上级 aee7c747
......@@ -19,19 +19,18 @@ class TestConnect:
else:
return False
# TODO: remove
def _test_disconnect(self, connect):
@pytest.mark.tags("0331")
def test_close(self, connect):
'''
target: test disconnect
method: disconnect a connected client
expected: connect failed after disconnected
'''
res = connect.close()
connect.close()
with pytest.raises(Exception) as e:
res = connect.list_collections()
connect.list_collections()
# TODO: remove
def _test_disconnect_repeatedly(self, dis_connect, args):
def test_close_repeatedly(self, dis_connect, args):
'''
target: test disconnect repeatedly
method: disconnect a connected client, disconnect again
......@@ -48,7 +47,6 @@ class TestConnect:
expected: connected is True
'''
milvus = get_milvus(args["ip"], args["port"], handler=args["handler"])
# assert milvus.connected()
# TODO: Currently we test with remote IP, localhost testing need to add
def _test_connect_ip_localhost(self, args):
......@@ -62,6 +60,7 @@ class TestConnect:
# assert milvus.connected()
@pytest.mark.timeout(CONNECT_TIMEOUT)
@pytest.mark.tags("0331")
def test_connect_wrong_ip_null(self, args):
'''
target: test connect with wrong ip value
......@@ -70,9 +69,9 @@ class TestConnect:
'''
ip = ""
with pytest.raises(Exception) as e:
milvus = get_milvus(ip, args["port"], args["handler"])
# assert not milvus.connected()
get_milvus(ip, args["port"], args["handler"])
@pytest.mark.tags("0331")
def test_connect_uri(self, args):
'''
target: test connect with correct uri
......@@ -81,8 +80,8 @@ class TestConnect:
'''
uri_value = "tcp://%s:%s" % (args["ip"], args["port"])
milvus = get_milvus(args["ip"], args["port"], uri=uri_value, handler=args["handler"])
# assert milvus.connected()
@pytest.mark.tags("0331")
def test_connect_uri_null(self, args):
'''
target: test connect with null uri
......@@ -92,28 +91,28 @@ class TestConnect:
uri_value = ""
if self.local_ip(args):
milvus = get_milvus(None, None, uri=uri_value, handler=args["handler"])
# assert milvus.connected()
else:
with pytest.raises(Exception) as e:
milvus = get_milvus(None, None, uri=uri_value, handler=args["handler"])
# assert not milvus.connected()
@pytest.mark.tags("0331")
def test_connect_with_multiprocess(self, args):
'''
target: test uri connect with multiprocess
method: set correct uri, test with multiprocessing connecting
expected: all connection is connected
'''
processes = []
def connect():
milvus = get_milvus(args["ip"], args["port"], handler=args["handler"])
assert milvus
assert milvus
with concurrent.futures.ThreadPoolExecutor(max_workers=20) as executor:
future_results = {executor.submit(
connect): i for i in range(100)}
for future in concurrent.futures.as_completed(future_results):
future.result()
@pytest.mark.tags("0331")
def test_connect_repeatedly(self, args):
'''
target: test connect repeatedly
......@@ -122,10 +121,7 @@ class TestConnect:
'''
uri_value = "tcp://%s:%s" % (args["ip"], args["port"])
milvus = Milvus(uri=uri_value, handler=args["handler"])
# milvus.connect(uri=uri_value, timeout=5)
# milvus.connect(uri=uri_value, timeout=5)
milvus = Milvus(uri=uri_value, handler=args["handler"])
# assert milvus.connected()
def _test_add_vector_and_disconnect_concurrently(self):
'''
......@@ -153,19 +149,20 @@ class TestConnect:
pass
def _test_thread_safe_with_one_connection_shared_in_multi_threads(self):
'''
'''
Target: test 1 connection thread safe
Method: 1 connection shared in multi-threads, all adding vectors, or other things
Expected: Functional as one thread
'''
pass
pass
class TestConnectIPInvalid(object):
"""
Test connect server with invalid ip
"""
@pytest.fixture(
scope="function",
params=gen_invalid_ips()
......@@ -175,11 +172,11 @@ class TestConnectIPInvalid(object):
@pytest.mark.level(2)
@pytest.mark.timeout(CONNECT_TIMEOUT)
@pytest.mark.tags("0331")
def test_connect_with_invalid_ip(self, args, get_invalid_ip):
ip = get_invalid_ip
with pytest.raises(Exception) as e:
milvus = get_milvus(ip, args["port"], args["handler"])
# assert not milvus.connected()
class TestConnectPortInvalid(object):
......@@ -196,6 +193,7 @@ class TestConnectPortInvalid(object):
@pytest.mark.level(2)
@pytest.mark.timeout(CONNECT_TIMEOUT)
@pytest.mark.tags("0331")
def test_connect_with_invalid_port(self, args, get_invalid_port):
'''
target: test ip:port connect with invalid port value
......@@ -205,13 +203,13 @@ class TestConnectPortInvalid(object):
port = get_invalid_port
with pytest.raises(Exception) as e:
milvus = get_milvus(args["ip"], port, args["handler"])
# assert not milvus.connected()
class TestConnectURIInvalid(object):
"""
Test connect server with invalid uri
"""
@pytest.fixture(
scope="function",
params=gen_invalid_uris()
......@@ -221,6 +219,7 @@ class TestConnectURIInvalid(object):
@pytest.mark.level(2)
@pytest.mark.timeout(CONNECT_TIMEOUT)
@pytest.mark.tags("0331")
def test_connect_with_invalid_uri(self, get_invalid_uri, args):
'''
target: test uri connect with invalid uri value
......@@ -230,4 +229,3 @@ class TestConnectURIInvalid(object):
uri_value = get_invalid_uri
with pytest.raises(Exception) as e:
milvus = get_milvus(uri=uri_value, handler=args["handler"])
# assert not milvus.connected()
......@@ -18,27 +18,50 @@ nprobe = 1
epsilon = 0.001
nlist = 128
# index_params = {'index_type': IndexType.IVFLAT, 'nlist': 16384}
default_index = {"index_type": "IVF_FLAT", "params": {"nlist": 16384}, "metric_type": "L2"}
class TestMixBase:
# TODO
def _test_mix_base(self, connect, collection):
nb = 200000
nq = 5
entities = gen_entities(nb=nb)
ids = connect.insert(collection, entities)
assert len(ids) == nb
connect.flush([collection])
connect.create_index(collection, default_float_vec_field_name, default_index)
index = connect.describe_index(collection, default_float_vec_field_name)
assert index == default_index
query, vecs = gen_query_vectors(default_float_vec_field_name, entities, default_top_k, nq)
connect.load_collection(collection)
res = connect.search(collection, query)
assert len(res) == nq
assert len(res[0]) == default_top_k
assert res[0]._distances[0] <= epsilon
assert check_id_result(res[0], ids[0])
# disable
def _test_search_during_createIndex(self, args):
loops = 10000
collection = gen_unique_str()
query_vecs = [vectors[0], vectors[1]]
uri = "tcp://%s:%s" % (args["ip"], args["port"])
id_0 = 0; id_1 = 0
id_0 = 0;
id_1 = 0
milvus_instance = get_milvus(args["handler"])
# milvus_instance.connect(uri=uri)
milvus_instance.create_collection({'collection_name': collection,
'dimension': default_dim,
'index_file_size': index_file_size,
'metric_type': "L2"})
'dimension': default_dim,
'index_file_size': index_file_size,
'metric_type': "L2"})
for i in range(10):
status, ids = milvus_instance.bulk_insert(collection, vectors)
# logging.getLogger().info(ids)
if i == 0:
id_0 = ids[0]; id_1 = ids[1]
id_0 = ids[0];
id_1 = ids[1]
# def create_index(milvus_instance):
# logging.getLogger().info("In create index")
# status = milvus_instance.create_index(collection, index_params)
......@@ -49,6 +72,7 @@ class TestMixBase:
logging.getLogger().info("In add vectors")
status, ids = milvus_instance.bulk_insert(collection, vectors)
logging.getLogger().info(status)
def search(milvus_instance):
logging.getLogger().info("In search vectors")
for i in range(loops):
......@@ -56,13 +80,14 @@ class TestMixBase:
logging.getLogger().info(status)
assert result[0][0].id == id_0
assert result[1][0].id == id_1
milvus_instance = get_milvus(args["handler"])
# milvus_instance.connect(uri=uri)
p_search = Process(target=search, args=(milvus_instance, ))
p_search = Process(target=search, args=(milvus_instance,))
p_search.start()
milvus_instance = get_milvus(args["handler"])
# milvus_instance.connect(uri=uri)
p_create = Process(target=insert, args=(milvus_instance, ))
p_create = Process(target=insert, args=(milvus_instance,))
p_create.start()
p_create.join()
......@@ -79,7 +104,7 @@ class TestMixBase:
idx = []
index_param = {'nlist': nlist}
#create collection and add vectors
# create collection and add vectors
for i in range(30):
collection_name = gen_unique_str('test_mix_multi_collections')
collection_list.append(collection_name)
......@@ -123,7 +148,7 @@ class TestMixBase:
status = connect.create_index(collection_list[50 + i], IndexType.IVF_SQ8, index_param)
assert status.OK()
#describe index
# describe index
for i in range(10):
status, result = connect.get_index_info(collection_list[i])
assert result._index_type == IndexType.FLAT
......@@ -138,7 +163,7 @@ class TestMixBase:
status, result = connect.get_index_info(collection_list[50 + i])
assert result._index_type == IndexType.IVF_SQ8
#search
# search
query_vecs = [vectors[0], vectors[10], vectors[20]]
for i in range(60):
collection = collection_list[i]
......@@ -154,8 +179,18 @@ class TestMixBase:
logging.getLogger().info(idx[3 * i + j])
assert check_result(result[j], idx[3 * i + j])
def check_result(result, id):
if len(result) >= 5:
return id in [result[0].id, result[1].id, result[2].id, result[3].id, result[4].id]
else:
return id in (i.id for i in result)
def check_id_result(result, id):
limit_in = 5
ids = [entity.id for entity in result]
if len(result) >= limit_in:
return id in ids[:limit_in]
else:
return id in ids
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册