Unverified commit 7f744c03, authored by del-zhenwu, committed by GitHub

[skip ci] Filter id:-1 (#2274)

* add async
Signed-off-by: zw <zw@zilliz.com>

* Update case
Signed-off-by: zw <zw@zilliz.com>

* Update some cases
Signed-off-by: zw <zw@zilliz.com>

* add has_partition case
Signed-off-by: zw <zw@zilliz.com>

* add search_by_id case: ids duplicate
Signed-off-by: zw <zw@zilliz.com>

* add async add
Signed-off-by: zw <zw@zilliz.com>

* add flush case
Signed-off-by: zw <zw@zilliz.com>

* update case
Signed-off-by: zw <zw@zilliz.com>

* fix search by id case
Signed-off-by: zw <zw@zilliz.com>

* update case
Signed-off-by: zw <zw@zilliz.com>

* filter id:-1 (see the sketch below)
Signed-off-by: zw <zw@zilliz.com>

* pq distance
Signed-off-by: zw <zw@zilliz.com>

* [skip ci] skip ci
Signed-off-by: zw <zw@zilliz.com>
Parent d49b609a
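For context on the "filter id:-1" change referenced above: search results used to be padded toward top_k with placeholder hits whose id is -1, and the updated cases now expect those placeholders to be stripped, so the old assertions on result[i][0].id == -1 become assertions on empty hit lists. A minimal sketch of that filtering, assuming each hit exposes an id attribute; Hit and drop_placeholder_hits are hypothetical names used only for illustration, not part of this change:

from collections import namedtuple

Hit = namedtuple("Hit", ["id", "distance"])

def drop_placeholder_hits(result):
    # result is a list (one entry per query vector) of hit lists;
    # id == -1 marks a padded placeholder entry, not a real match.
    return [[hit for hit in hits if hit.id != -1] for hits in result]

raw = [[Hit(10, 0.0), Hit(-1, 0.0)], [Hit(-1, 0.0)]]
filtered = drop_placeholder_hits(raw)
assert len(filtered[0]) == 1 and len(filtered[1]) == 0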
import time
import pdb
import threading
import logging
@@ -676,6 +677,121 @@ class TestAddBase:
status, ids = connect.add_vectors(collection_name=collection_list[i], records=vectors)
assert status.OK()
class TestAddAsync:
@pytest.fixture(
scope="function",
params=[
1,
1000
],
)
def insert_count(self, request):
yield request.param
def check_status(self, status, result):
logging.getLogger().info("In callback check status")
assert status.OK()
def check_status_not_ok(self, status, result):
logging.getLogger().info("In callback check status")
assert not status.OK()
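The status-checking helpers above are passed as _callback in the async cases below. A minimal, self-contained sketch of the future-plus-callback flow they exercise, using concurrent.futures as a stand-in for the pymilvus future (an assumption for illustration only; _fake_add_vectors and _check_status are hypothetical names, not part of this diff):

from concurrent.futures import ThreadPoolExecutor

def _check_status(status, result):
    # Mirrors the (status, result) shape of the callbacks above.
    assert status == "OK"

def _fake_add_vectors(vectors, _callback=None):
    # Stand-in for an insert call; returns a status and generated ids.
    status, ids = "OK", list(range(len(vectors)))
    if _callback is not None:
        _callback(status, ids)
    return status, ids

with ThreadPoolExecutor(max_workers=1) as pool:
    future = pool.submit(_fake_add_vectors, [[0.0] * 4] * 3, _callback=_check_status)
    status, ids = future.result()  # blocks until the call (and its callback) has finished
    assert status == "OK" and len(ids) == 3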
def test_insert_async(self, connect, collection, insert_count):
'''
target: test insert vectors with _async=True
method: add vectors asynchronously and wait on the returned future
expected: status ok, length of ids equals the number of vectors inserted
'''
nb = insert_count
insert_vec_list = gen_vectors(nb, dim)
future = connect.add_vectors(collection, insert_vec_list, _async=True)
status, ids = future.result()
connect.flush([collection])
assert len(ids) == nb
assert status.OK()
@pytest.mark.level(2)
def test_insert_async_false(self, connect, collection, insert_count):
'''
target: test insert vectors with _async=False
method: add vectors with the async flag explicitly disabled
expected: status ok, length of ids equals the number of vectors inserted
'''
nb = insert_count
insert_vec_list = gen_vectors(nb, dim)
status, ids = connect.add_vectors(collection, insert_vec_list, _async=False)
connect.flush([collection])
assert len(ids) == nb
assert status.OK()
def test_insert_async_callback(self, connect, collection, insert_count):
'''
target: test async insert with a callback
method: add vectors with _async=True and _callback, then wait for completion
expected: the callback is invoked and checks that the status is ok
'''
nb = insert_count
insert_vec_list = gen_vectors(nb, dim)
future = connect.add_vectors(collection, insert_vec_list, _async=True, _callback=self.check_status)
future.done()
@pytest.mark.level(2)
def test_insert_async_long(self, connect, collection):
'''
target: test async insert with a large number of vectors
method: add 50000 vectors with _async=True and a callback, then flush and count
expected: status ok, collection count equals the number of vectors inserted
'''
nb = 50000
insert_vec_list = gen_vectors(nb, dim)
future = connect.add_vectors(collection, insert_vec_list, _async=True, _callback=self.check_status)
status, result = future.result()
assert status.OK()
assert len(result) == nb
connect.flush([collection])
status, count = connect.count_collection(collection)
assert status.OK()
logging.getLogger().info(status)
logging.getLogger().info(count)
assert count == nb
def test_insert_async_callback_timeout(self, connect, collection):
'''
target: test async insert with a callback and a short timeout
method: add 100000 vectors with _async=True, _callback and timeout=1
expected: the future completes and the callback receives the returned status
'''
nb = 100000
insert_vec_list = gen_vectors(nb, dim)
future = connect.add_vectors(collection, insert_vec_list, _async=True, _callback=self.check_status, timeout=1)
future.done()
def test_insert_async_invalid_params(self, connect, collection):
'''
target: test async insert into a collection that does not exist
method: add vectors with _async=True using a non-existent collection name
expected: returned status is not ok
'''
insert_vec_list = gen_vectors(nb, dim)
collection_new = gen_unique_str()
future = connect.add_vectors(collection_new, insert_vec_list, _async=True)
status, result = future.result()
assert not status.OK()
# TODO: add assertion
def test_insert_async_invalid_params_raise_exception(self, connect, collection):
'''
target: test async insert with an empty vector list
method: add an empty list with _async=True to a non-existent collection
expected: exception raised
'''
insert_vec_list = []
collection_new = gen_unique_str()
with pytest.raises(Exception) as e:
future = connect.add_vectors(collection_new, insert_vec_list, _async=True)
class TestAddIP:
"""
******************************************************************
@@ -233,6 +233,44 @@ class TestFlushBase:
assert res == 0
class TestFlushAsync:
"""
******************************************************************
The following cases are used to test `flush` function
******************************************************************
"""
def check_status(self, status, result):
logging.getLogger().info("In callback check status")
assert status.OK()
def test_flush_empty_collection(self, connect, collection):
'''
method: flush collection with no vectors
expected: status ok
'''
future = connect.flush([collection], _async=True)
status = future.result()
assert status.OK()
def test_flush_async(self, connect, collection):
vectors = gen_vectors(nb, dim)
status, ids = connect.add_vectors(collection, vectors)
future = connect.flush([collection], _async=True)
status = future.result()
assert status.OK()
def test_flush_async_long(self, connect, collection):
nb = 100000
vectors = gen_vectors(nb, dim)
connect.add_vectors(collection, vectors)
logging.getLogger().info("before")
future = connect.flush([collection], _async=True, _callback=self.check_status)
logging.getLogger().info("after")
future.done()
status = future.result()
assert status.OK()
class TestCollectionNameInvalid(object):
"""
Test adding vectors with invalid collection names
@@ -1806,3 +1806,75 @@ class TestCreateIndexParamsInvalid(object):
logging.getLogger().info(result)
assert result._collection_name == collection
assert result._index_type == IndexType.FLAT
class TestIndexAsync:
"""
******************************************************************
The following cases are used to test `create_index` function
******************************************************************
"""
@pytest.fixture(
scope="function",
params=gen_index()
)
def get_index(self, request, connect):
if str(connect._cmd("mode")[1]) == "CPU":
if request.param["index_type"] == IndexType.IVF_SQ8H:
pytest.skip("sq8h not support in CPU mode")
if str(connect._cmd("mode")[1]) == "GPU":
if request.param["index_type"] == IndexType.IVF_PQ:
pytest.skip("ivfpq not support in GPU mode")
return request.param
@pytest.fixture(
scope="function",
params=gen_simple_index()
)
def get_simple_index(self, request, connect):
if str(connect._cmd("mode")[1]) == "CPU":
if request.param["index_type"] == IndexType.IVF_SQ8H:
pytest.skip("sq8h not support in CPU mode")
if str(connect._cmd("mode")[1]) == "GPU":
# if request.param["index_type"] == IndexType.IVF_PQ:
if request.param["index_type"] not in [IndexType.IVF_FLAT]:
# pytest.skip("ivfpq not support in GPU mode")
pytest.skip("debug ivf_flat in GPU mode")
return request.param
def check_status(self, status):
logging.getLogger().info("In callback check status")
assert status.OK()
"""
******************************************************************
The following cases are used to test `create_index` function
******************************************************************
"""
@pytest.mark.timeout(BUILD_TIMEOUT)
def test_create_index(self, connect, collection, get_simple_index):
'''
target: test create index interface with _async=True
method: create collection and add vectors in it, then create index asynchronously
expected: create index status ok
'''
index_param = get_simple_index["index_param"]
index_type = get_simple_index["index_type"]
logging.getLogger().info(get_simple_index)
vectors = gen_vectors(nb, dim)
status, ids = connect.add_vectors(collection, vectors)
logging.getLogger().info("start index")
# future = connect.create_index(collection, index_type, index_param, _async=True, _callback=self.check_status)
future = connect.create_index(collection, index_type, index_param, _async=True)
logging.getLogger().info("before result")
status = future.result()
assert status.OK()
def test_create_index_with_invalid_collectionname(self, connect):
collection_name = " "
nlist = NLIST
index_param = {"nlist": nlist}
future = connect.create_index(collection_name, IndexType.IVF_SQ8, index_param, _async=True)
status = future.result()
assert not status.OK()
@@ -228,6 +228,78 @@ class TestShowBase:
assert status.OK()
class TestHasBase:
"""
******************************************************************
The following cases are used to test `has_partition` function
******************************************************************
"""
@pytest.fixture(
scope="function",
params=gen_invalid_collection_names()
)
def get_tag_name(self, request):
yield request.param
def test_has_partition(self, connect, collection):
'''
target: test has_partition, check status and result
method: create partition first, then call function: has_partition
expected: status ok, result true
'''
status = connect.create_partition(collection, tag)
status, res = connect.has_partition(collection, tag)
assert status.OK()
logging.getLogger().info(res)
assert res
def test_has_partition_multi_partitions(self, connect, collection):
'''
target: test has_partition with multiple partitions, check status and result
method: create several partitions, then call function: has_partition on each
expected: status ok, result true for each partition
'''
for tag_name in [tag, "tag_new", "tag_new_new"]:
status = connect.create_partition(collection, tag_name)
for tag_name in [tag, "tag_new", "tag_new_new"]:
status, res = connect.has_partition(collection, tag_name)
assert status.OK()
assert res
def test_has_partition_tag_not_existed(self, connect, collection):
'''
target: test has_partition with a tag that does not exist, check status and result
method: call function: has_partition without creating the partition first
expected: status ok, result false
'''
status, res = connect.has_partition(collection, tag)
assert status.OK()
logging.getLogger().info(res)
assert not res
def test_has_partition_collection_not_existed(self, connect, collection):
'''
target: test has_partition with a collection that does not exist, check status and result
method: call function: has_partition with a non-existent collection name
expected: status not ok
'''
status, res = connect.has_partition("not_existed_collection", tag)
assert not status.OK()
@pytest.mark.level(2)
def test_has_partition_with_invalid_tag_name(self, connect, collection, get_tag_name):
'''
target: test has partition, with invalid tag name, check status returned
method: call function: has_partition
expected: status ok
'''
tag_name = get_tag_name
status = connect.create_partition(collection, tag)
status, res = connect.has_partition(collection, tag_name)
assert status.OK()
class TestDropBase:
"""
@@ -29,12 +29,14 @@ raw_vectors, binary_vectors = gen_binary_vectors(6000, dim)
class TestSearchBase:
@pytest.fixture(scope="function", autouse=True)
def skip_check(self, connect):
if str(connect._cmd("mode")[1]) == "CPU" or str(connect._cmd("mode")[1]) == "GPU":
reason = "GPU mode not support"
logging.getLogger().info(reason)
pytest.skip(reason)
# @pytest.fixture(scope="function", autouse=True)
# def skip_check(self, connect):
# if str(connect._cmd("mode")[1]) == "CPU":
# if request.param["index_type"] == IndexType.IVF_SQ8H:
# pytest.skip("sq8h not support in CPU mode")
# if str(connect._cmd("mode")[1]) == "GPU":
# if request.param["index_type"] == IndexType.IVF_PQ:
# pytest.skip("ivfpq not support in GPU mode")
def init_data(self, connect, collection, nb=6000):
'''
@@ -82,16 +84,6 @@ class TestSearchBase:
connect.flush([collection])
return add_vectors, ids
def check_no_result(self, results):
if len(results) == 0:
return True
flag = True
for r in results:
flag = flag and (r.id == -1)
if not flag:
return False
return flag
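The check_no_result helper removed here detected an empty result by scanning for id == -1 placeholders; with those placeholders filtered out, the updated cases assert on empty hit lists instead. A minimal sketch of the two checks under that assumption (Hit, no_result_old and no_result_new are hypothetical names, not part of this change):

from collections import namedtuple

Hit = namedtuple("Hit", ["id", "distance"])

def no_result_old(hits):
    # Pre-change semantics: a query had no match when every hit was the id == -1 placeholder.
    return all(h.id == -1 for h in hits)

def no_result_new(hits):
    # Post-change semantics: placeholders are filtered out, so "no match" is simply an empty list.
    return len(hits) == 0

assert no_result_old([Hit(-1, 0.0)]) and no_result_new([])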
def init_data_partition(self, connect, collection, partition_tag, nb=6000):
'''
Generate vectors and add it in collection, before search vectors
@@ -104,6 +96,7 @@ class TestSearchBase:
add_vectors = sklearn.preprocessing.normalize(add_vectors, axis=1, norm='l2')
add_vectors = add_vectors.tolist()
status, ids = connect.add_vectors(collection, add_vectors, partition_tag=partition_tag)
assert status.OK()
connect.flush([collection])
return add_vectors, ids
@@ -178,6 +171,22 @@ class TestSearchBase:
assert result[0][0].distance <= epsilon
assert check_result(result[0], ids[0])
def test_search_flat_same_ids(self, connect, collection):
'''
target: test basic search function with duplicate ids in the query
method: search with the same vector id given twice, check the result
expected: search status ok, and both hit lists contain the queried id
'''
vectors, ids = self.init_data(connect, collection)
query_ids = [ids[0], ids[0]]
status, result = connect.search_by_ids(collection, query_ids, top_k, params={})
assert status.OK()
assert len(result[0]) == min(len(vectors), top_k)
assert result[0][0].distance <= epsilon
assert result[1][0].distance <= epsilon
assert check_result(result[0], ids[0])
assert check_result(result[1], ids[0])
def test_search_flat_max_topk(self, connect, collection):
'''
target: test basic search function, all the search params are correct, change top-k value
@@ -186,7 +195,7 @@
'''
top_k = 2049
vectors, ids = self.init_data(connect, collection)
query_ids = ids[0]
query_ids = [ids[0]]
status, result = connect.search_by_ids(collection, query_ids, top_k, params={})
assert not status.OK()
@@ -200,7 +209,7 @@
query_ids = non_exist_id
status, result = connect.search_by_ids(collection, query_ids, top_k, params={})
assert status.OK()
assert len(result[0]) == min(len(vectors), top_k)
assert len(result[0]) == 0
def test_search_collection_empty(self, connect, collection):
'''
@@ -209,9 +218,11 @@
expected: search status not ok when searching an empty collection by a non-existent id
'''
query_ids = non_exist_id
logging.getLogger().info(query_ids)
logging.getLogger().info(collection)
logging.getLogger().info(connect.describe_collection(collection))
status, result = connect.search_by_ids(collection, query_ids, top_k, params={})
assert status.OK()
assert len(result) == 0
assert not status.OK()
def test_search_index_l2(self, connect, collection, get_simple_index):
'''
@@ -221,6 +232,8 @@
'''
index_param = get_simple_index["index_param"]
index_type = get_simple_index["index_type"]
if index_type == IndexType.IVF_PQ:
pytest.skip("skip pq")
vectors, ids = self.init_data(connect, collection)
status = connect.create_index(collection, index_type, index_param)
query_ids = [ids[0]]
@@ -239,6 +252,8 @@
'''
index_param = get_simple_index["index_param"]
index_type = get_simple_index["index_type"]
if index_type == IndexType.IVF_PQ:
pytest.skip("skip pq")
vectors, ids = self.init_data(connect, collection)
status = connect.create_index(collection, index_type, index_param)
query_ids = ids[0:nq]
@@ -246,7 +261,7 @@
status, result = connect.search_by_ids(collection, query_ids, top_k, params=search_param)
assert status.OK()
assert len(result) == nq
for i in nq:
for i in range(nq):
assert len(result[i]) == min(len(vectors), top_k)
assert result[i][0].distance <= epsilon
assert check_result(result[i], ids[i])
@@ -259,17 +274,19 @@
'''
index_param = get_simple_index["index_param"]
index_type = get_simple_index["index_type"]
if index_type == IndexType.IVF_PQ:
pytest.skip("skip pq")
vectors, ids = self.init_data(connect, collection)
status = connect.create_index(collection, index_type, index_param)
query_ids = ids[0:nq]
query_ids[0] = non_exist_id
query_ids[0] = 1
search_param = get_search_param(index_type)
status, result = connect.search_by_ids(collection, [query_ids], top_k, params=search_param)
status, result = connect.search_by_ids(collection, query_ids, top_k, params=search_param)
assert status.OK()
assert len(result) == nq
for i in nq:
for i in range(nq):
if i == 0:
assert result[i].id == -1
assert len(result[i]) == 0
else:
assert len(result[i]) == min(len(vectors), top_k)
assert result[i][0].distance <= epsilon
@@ -277,15 +294,16 @@
def test_search_index_delete(self, connect, collection):
vectors, ids = self.init_data(connect, collection)
query_ids = ids[0]
status = connect.delete_by_id(collection, [query_ids])
query_ids = ids[0:nq]
status = connect.delete_by_id(collection, [query_ids[0]])
assert status.OK()
status = connect.flush(collection)
status, result = connect.search_by_ids(collection, [query_ids], top_k, params={})
status = connect.flush([collection])
status, result = connect.search_by_ids(collection, query_ids, top_k, params={})
assert status.OK()
assert len(result) == 1
assert result[0][0].distance <= epsilon
assert result[0][0].id != ids[0]
assert len(result) == nq
assert len(result[0]) == 0
assert len(result[1]) == top_k
assert result[1][0].distance <= epsilon
def test_search_l2_partition_tag_not_existed(self, connect, collection):
'''
@@ -295,28 +313,31 @@
'''
status = connect.create_partition(collection, tag)
vectors, ids = self.init_data(connect, collection)
query_ids = ids[0]
status, result = connect.search_by_ids(collection, query_ids, top_k, partition_tags=[tag], params=search_param)
assert status.OK()
query_ids = [ids[0]]
new_tag = gen_unique_str()
status, result = connect.search_by_ids(collection, query_ids, top_k, partition_tags=[new_tag], params={})
assert not status.OK()
logging.getLogger().info(status)
assert len(result) == 0
def test_search_l2_partition_other(self, connect, collection):
tag = gen_unique_str()
def test_search_l2_partition_empty(self, connect, collection):
status = connect.create_partition(collection, tag)
vectors, ids = self.init_data(connect, collection)
query_ids = ids[0]
status, result = connect.search_by_ids(collection, query_ids, top_k, partition_tags=[tag], params=search_param)
assert status.OK()
query_ids = [ids[0]]
status, result = connect.search_by_ids(collection, query_ids, top_k, partition_tags=[tag], params={})
assert not status.OK()
logging.getLogger().info(status)
assert len(result) == 0
def test_search_l2_partition(self, connect, collection):
status = connect.create_partition(collection, tag)
vectors, ids = self.init_data_partition(connect, collection, tag)
query_ids = ids[-1]
query_ids = ids[-1:]
status, result = connect.search_by_ids(collection, query_ids, top_k, partition_tags=[tag])
assert status.OK()
assert len(result) == 1
assert len(result[0]) == min(len(vectors), top_k)
assert check_result(result[0], query_ids)
assert check_result(result[0], query_ids[-1])
def test_search_l2_partition_B(self, connect, collection):
status = connect.create_partition(collection, tag)
@@ -325,7 +346,7 @@ class TestSearchBase:
status, result = connect.search_by_ids(collection, query_ids, top_k, partition_tags=[tag])
assert status.OK()
assert len(result) == nq
for i in nq:
for i in range(nq):
assert len(result[i]) == min(len(vectors), top_k)
assert result[i][0].distance <= epsilon
assert check_result(result[i], ids[i])
@@ -338,14 +359,17 @@ class TestSearchBase:
vectors, new_ids = self.init_data_partition(connect, collection, new_tag, nb=nb+1)
tmp = 2
query_ids = ids[0:tmp]
query_ids.extend(new_ids[0:nq-tmp])
query_ids.extend(new_ids[tmp:nq])
status, result = connect.search_by_ids(collection, query_ids, top_k, partition_tags=[tag, new_tag], params={})
assert status.OK()
assert len(result) == nq
for i in nq:
for i in range(nq):
assert len(result[i]) == min(len(vectors), top_k)
assert result[i][0].distance <= epsilon
assert check_result(result[i], ids[i])
if i < tmp:
assert result[i][0].id == ids[i]
else:
assert result[i][0].id == new_ids[i]
def test_search_l2_index_partitions_match_one_tag(self, connect, collection):
new_tag = "new_tag"
@@ -355,18 +379,19 @@ class TestSearchBase:
vectors, new_ids = self.init_data_partition(connect, collection, new_tag, nb=nb+1)
tmp = 2
query_ids = ids[0:tmp]
query_ids.extend(new_ids[0:nq-tmp])
query_ids.extend(new_ids[tmp:nq])
status, result = connect.search_by_ids(collection, query_ids, top_k, partition_tags=[new_tag], params={})
assert status.OK()
assert len(result) == nq
for i in nq:
for i in range(nq):
if i < tmp:
assert result[i][0].distance > epsilon
assert result[i][0].id != ids[i]
else:
assert len(result[i]) == min(len(vectors), top_k)
assert result[i][0].distance <= epsilon
assert check_result(result[i], ids[i])
assert result[i][0].id == new_ids[i]
assert result[i][1].distance > epsilon
# def test_search_by_ids_without_connect(self, dis_connect, collection):
# '''
@@ -411,7 +436,7 @@ class TestSearchBase:
status, result = connect.search_by_ids(jac_collection, query_ids, top_k, params=search_param)
assert status.OK()
assert len(result) == nq
for i in nq:
for i in range(nq):
assert len(result[i]) == min(len(vectors), top_k)
assert result[i][0].distance <= epsilon
assert check_result(result[i], ids[i])
@@ -499,7 +524,7 @@ class TestSearchParamsInvalid(object):
def check_result(result, id):
if len(result) >= 5:
return id in [x.id for x in result[:5]]
if len(result) >= top_k:
return id in [x.id for x in result[:top_k]]
else:
return id in (i.id for i in result)
@@ -666,7 +666,7 @@ class TestSearchBase:
status, result = connect.search_vectors(substructure_collection, top_k, query_vecs, params=search_param)
logging.getLogger().info(status)
logging.getLogger().info(result)
assert result[0][0].id == -1
assert len(result[0]) == 0
def test_search_distance_substructure_flat_index_B(self, connect, substructure_collection):
'''
@@ -690,12 +690,12 @@ class TestSearchBase:
status, result = connect.search_vectors(substructure_collection, top_k, query_vecs, params=search_param)
logging.getLogger().info(status)
logging.getLogger().info(result)
assert len(result[0]) == 1
assert len(result[1]) == 1
assert result[0][0].distance <= epsilon
assert result[0][0].id == ids[0]
assert result[1][0].distance <= epsilon
assert result[1][0].id == ids[1]
assert result[0][1].id == -1
assert result[1][1].id == -1
def test_search_distance_superstructure_flat_index(self, connect, superstructure_collection):
'''
@@ -720,7 +720,7 @@ class TestSearchBase:
status, result = connect.search_vectors(superstructure_collection, top_k, query_vecs, params=search_param)
logging.getLogger().info(status)
logging.getLogger().info(result)
assert result[0][0].id == -1
assert len(result[0]) == 0
def test_search_distance_superstructure_flat_index_B(self, connect, superstructure_collection):
'''
@@ -744,12 +744,12 @@ class TestSearchBase:
status, result = connect.search_vectors(superstructure_collection, top_k, query_vecs, params=search_param)
logging.getLogger().info(status)
logging.getLogger().info(result)
assert len(result[0]) == 2
assert len(result[1]) == 2
assert result[0][0].id in ids
assert result[0][0].distance <= epsilon
assert result[1][0].id in ids
assert result[1][0].distance <= epsilon
assert result[0][2].id == -1
assert result[1][2].id == -1
def test_search_distance_tanimoto_flat_index(self, connect, tanimoto_collection):
'''