未验证 提交 8cdbc006 编写于 作者: D del-zhenwu 提交者: GitHub

add annoy case (#1868)

* add annoy case
Signed-off-by: Nzw <zw@zilliz.com>

* update case
Signed-off-by: Nzw <zw@zilliz.com>
上级 bf6d22e2
...@@ -31,11 +31,12 @@ class TestDeleteBase: ...@@ -31,11 +31,12 @@ class TestDeleteBase:
params=gen_simple_index() params=gen_simple_index()
) )
def get_simple_index(self, request, connect): def get_simple_index(self, request, connect):
if str(connect._cmd("mode")[1]) == "CPU": if str(connect._cmd("mode")[1]) == "GPU":
if request.param["index_type"] not in [IndexType.IVF_SQ8, IndexType.IVFLAT, IndexType.FLAT, IndexType.IVF_PQ, IndexType.HNSW]: if request.param["index_type"] not in [IndexType.IVF_SQ8, IndexType.IVFLAT, IndexType.FLAT, IndexType.IVF_PQ, IndexType.IVF_SQ8H]:
pytest.skip("Only support index_type: flat/ivf_flat/ivf_sq8/hnsw/ivf_pq") pytest.skip("Only support index_type: idmap/ivf")
else: elif str(connect._cmd("mode")[1]) == "CPU":
pytest.skip("Only support CPU mode") if request.param["index_type"] in [IndexType.IVF_SQ8H]:
pytest.skip("CPU not support index_type: ivf_sq8h")
return request.param return request.param
def test_delete_vector_search(self, connect, collection, get_simple_index): def test_delete_vector_search(self, connect, collection, get_simple_index):
...@@ -170,8 +171,6 @@ class TestDeleteBase: ...@@ -170,8 +171,6 @@ class TestDeleteBase:
assert status.OK() assert status.OK()
status = connect.flush([collection]) status = connect.flush([collection])
assert status.OK() assert status.OK()
status = connect.flush([collection])
assert status.OK()
delete_ids = [ids[0], ids[-1]] delete_ids = [ids[0], ids[-1]]
query_vecs = [vectors[0], vectors[1], vectors[-1]] query_vecs = [vectors[0], vectors[1], vectors[-1]]
status = connect.delete_by_id(collection, delete_ids) status = connect.delete_by_id(collection, delete_ids)
...@@ -298,11 +297,12 @@ class TestDeleteIndexedVectors: ...@@ -298,11 +297,12 @@ class TestDeleteIndexedVectors:
params=gen_simple_index() params=gen_simple_index()
) )
def get_simple_index(self, request, connect): def get_simple_index(self, request, connect):
if str(connect._cmd("mode")[1]) == "CPU": if str(connect._cmd("mode")[1]) == "GPU":
if request.param["index_type"] not in [IndexType.IVF_SQ8, IndexType.IVFLAT, IndexType.FLAT, IndexType.IVF_PQ, IndexType.HNSW]: if request.param["index_type"] not in [IndexType.IVF_SQ8, IndexType.IVFLAT, IndexType.FLAT, IndexType.IVF_PQ, IndexType.IVF_SQ8H]:
pytest.skip("Only support index_type: flat/ivf_flat/ivf_sq8") pytest.skip("Only support index_type: idmap/ivf")
else: elif str(connect._cmd("mode")[1]) == "CPU":
pytest.skip("Only support CPU mode") if request.param["index_type"] in [IndexType.IVF_SQ8H]:
pytest.skip("CPU not support index_type: ivf_sq8h")
return request.param return request.param
def test_delete_vectors_after_index_created_search(self, connect, collection, get_simple_index): def test_delete_vectors_after_index_created_search(self, connect, collection, get_simple_index):
......
...@@ -31,11 +31,9 @@ class TestFlushBase: ...@@ -31,11 +31,9 @@ class TestFlushBase:
params=gen_simple_index() params=gen_simple_index()
) )
def get_simple_index(self, request, connect): def get_simple_index(self, request, connect):
if str(connect._cmd("mode")[1]) == "CPU": if str(connect._cmd("mode")[1]) == "GPU":
if request.param["index_type"] not in [IndexType.IVF_SQ8, IndexType.IVFLAT, IndexType.FLAT, IndexType.IVF_PQ, IndexType.HNSW]: if request.param["index_type"] not in [IndexType.IVF_SQ8, IndexType.IVFLAT, IndexType.FLAT, IndexType.IVF_PQ, IndexType.IVF_SQ8H]:
pytest.skip("Only support index_type: flat/ivf_flat/ivf_sq8") pytest.skip("Only support index_type: idmap/flat")
else:
pytest.skip("Only support CPU mode")
return request.param return request.param
def test_flush_collection_not_existed(self, connect, collection): def test_flush_collection_not_existed(self, connect, collection):
......
...@@ -268,7 +268,7 @@ class TestIndexBase: ...@@ -268,7 +268,7 @@ class TestIndexBase:
p.join() p.join()
# TODO: enable # TODO: enable
@pytest.mark.timeout(BUILD_TIMEOUT) @pytest.mark.timeout(BUILD_TIMEOUT)
@pytest.mark.level(2) @pytest.mark.level(2)
def _test_create_index_multiprocessing(self, connect, collection, args): def _test_create_index_multiprocessing(self, connect, collection, args):
......
...@@ -19,6 +19,7 @@ add_interval_time = 2 ...@@ -19,6 +19,7 @@ add_interval_time = 2
vectors = gen_vectors(6000, dim) vectors = gen_vectors(6000, dim)
vectors = sklearn.preprocessing.normalize(vectors, axis=1, norm='l2') vectors = sklearn.preprocessing.normalize(vectors, axis=1, norm='l2')
vectors = vectors.tolist() vectors = vectors.tolist()
top_k = 1
nprobe = 1 nprobe = 1
epsilon = 0.001 epsilon = 0.001
tag = "1970-01-01" tag = "1970-01-01"
...@@ -1198,7 +1199,7 @@ class TestSearchParamsInvalid(object): ...@@ -1198,7 +1199,7 @@ class TestSearchParamsInvalid(object):
scope="function", scope="function",
params=gen_invaild_search_params() params=gen_invaild_search_params()
) )
def get_invalid_searh_param(self, request, connect): def get_invalid_search_param(self, request, connect):
if str(connect._cmd("mode")[1]) == "CPU": if str(connect._cmd("mode")[1]) == "CPU":
if request.param["index_type"] == IndexType.IVF_SQ8H: if request.param["index_type"] == IndexType.IVF_SQ8H:
pytest.skip("sq8h not support in CPU mode") pytest.skip("sq8h not support in CPU mode")
...@@ -1207,25 +1208,17 @@ class TestSearchParamsInvalid(object): ...@@ -1207,25 +1208,17 @@ class TestSearchParamsInvalid(object):
pytest.skip("ivfpq not support in GPU mode") pytest.skip("ivfpq not support in GPU mode")
return request.param return request.param
def test_search_with_invalid_params(self, connect, collection, get_invalid_searh_param): def test_search_with_invalid_params(self, connect, collection, get_invalid_search_param):
''' '''
target: test search fuction, with invalid search params target: test search fuction, with invalid search params
method: search with params method: search with params
expected: search status not ok, and the connection is normal expected: search status not ok, and the connection is normal
''' '''
index_type = get_invalid_searh_param["index_type"] index_type = get_invalid_search_param["index_type"]
search_param = get_invalid_searh_param["search_param"] search_param = get_invalid_search_param["search_param"]
for index in gen_simple_index():
if index_type in [IndexType.IVFLAT, IndexType.IVF_SQ8, IndexType.IVF_SQ8H]: if index_type == index["index_type"]:
connect.create_index(collection, index_type, {"nlist": 16384}) connect.create_index(collection, index_type, index["index_param"])
if (index_type == IndexType.IVF_PQ):
connect.create_index(collection, index_type, {"nlist": 16384, "m": 16})
if(index_type == IndexType.HNSW):
connect.create_index(collection, index_type, {"M": 16, "efConstruction": 500})
if (index_type == IndexType.RNSG):
connect.create_index(collection, index_type, {"search_length": 60, "out_degree": 50, "candidate_pool_size": 300, "knng": 100})
top_k = 1
query_vecs = gen_vectors(1, dim) query_vecs = gen_vectors(1, dim)
status, result = connect.search_vectors(collection, top_k, query_vecs, params=search_param) status, result = connect.search_vectors(collection, top_k, query_vecs, params=search_param)
assert not status.OK() assert not status.OK()
......
...@@ -12,6 +12,17 @@ from milvus import Milvus, IndexType, MetricType ...@@ -12,6 +12,17 @@ from milvus import Milvus, IndexType, MetricType
port = 19530 port = 19530
epsilon = 0.000001 epsilon = 0.000001
all_index_types = [
IndexType.FLAT,
IndexType.IVFLAT,
IndexType.IVF_SQ8,
IndexType.IVF_SQ8H,
IndexType.IVF_PQ,
IndexType.HNSW,
IndexType.RNSG,
IndexType.ANNOY
]
def get_milvus(handler=None): def get_milvus(handler=None):
if handler is None: if handler is None:
...@@ -460,34 +471,31 @@ def gen_invalid_engine_config(): ...@@ -460,34 +471,31 @@ def gen_invalid_engine_config():
def gen_invaild_search_params(): def gen_invaild_search_params():
index_types = [ invalid_search_key = 100
IndexType.FLAT,
IndexType.IVFLAT,
IndexType.IVF_SQ8,
IndexType.IVF_SQ8H,
IndexType.IVF_PQ,
IndexType.HNSW,
IndexType.RNSG
]
search_params = [] search_params = []
for index_type in index_types: for index_type in all_index_types:
if index_type == IndexType.FLAT:
continue
search_params.append({"index_type": index_type, "search_param": {"invalid_key": invalid_search_key}})
if index_type in [IndexType.IVFLAT, IndexType.IVF_SQ8, IndexType.IVF_SQ8H, IndexType.IVF_PQ]: if index_type in [IndexType.IVFLAT, IndexType.IVF_SQ8, IndexType.IVF_SQ8H, IndexType.IVF_PQ]:
for nprobe in gen_invalid_params(): for nprobe in gen_invalid_params():
ivf_search_params = {"index_type": index_type, "search_param": {"nprobe": nprobe}} ivf_search_params = {"index_type": index_type, "search_param": {"nprobe": nprobe}}
search_params.append(ivf_search_params) search_params.append(ivf_search_params)
search_params.append({"index_type": index_type, "search_param": {"invalid_key": 100}})
elif index_type == IndexType.HNSW: elif index_type == IndexType.HNSW:
for ef in gen_invalid_params(): for ef in gen_invalid_params():
hnsw_search_param = {"index_type": index_type, "search_param": {"ef": ef}} hnsw_search_param = {"index_type": index_type, "search_param": {"ef": ef}}
search_params.append(hnsw_search_param) search_params.append(hnsw_search_param)
search_params.append({"index_type": index_type, "search_param": {"invalid_key": 100}})
elif index_type == IndexType.RNSG: elif index_type == IndexType.RNSG:
for search_length in gen_invalid_params(): for search_length in gen_invalid_params():
nsg_search_param = {"index_type": index_type, "search_param": {"search_length": search_length}} nsg_search_param = {"index_type": index_type, "search_param": {"search_length": search_length}}
search_params.append(nsg_search_param) search_params.append(nsg_search_param)
search_params.append({"index_type": index_type, "search_param": {"invalid_key": 100}}) search_params.append({"index_type": index_type, "search_param": {"invalid_key": 100}})
elif index_type == IndexType.ANNOY:
for search_k in gen_invalid_params():
if isinstance(search_k, int):
continue
annoy_search_param = {"index_type": index_type, "search_param": {"search_k": search_k}}
search_params.append(annoy_search_param)
return search_params return search_params
...@@ -525,20 +533,13 @@ def gen_invalid_index(): ...@@ -525,20 +533,13 @@ def gen_invalid_index():
index_params.append({"index_type": IndexType.RNSG, index_params.append({"index_type": IndexType.RNSG,
"index_param": {"invalid_key": 100, "out_degree": 40, "candidate_pool_size": 300, "index_param": {"invalid_key": 100, "out_degree": 40, "candidate_pool_size": 300,
"knng": 100}}) "knng": 100}})
for invalid_n_trees in gen_invalid_params():
index_params.append({"index_type": IndexType.ANNOY, "index_param": {"n_trees": invalid_n_trees}})
return index_params return index_params
def gen_index(): def gen_index():
index_types = [
IndexType.FLAT,
IndexType.IVFLAT,
IndexType.IVF_SQ8,
IndexType.IVF_SQ8H,
IndexType.IVF_PQ,
IndexType.HNSW,
IndexType.RNSG
]
nlists = [1, 1024, 16384] nlists = [1, 1024, 16384]
pq_ms = [128, 64, 32, 16, 8, 4] pq_ms = [128, 64, 32, 16, 8, 4]
Ms = [5, 24, 48] Ms = [5, 24, 48]
...@@ -549,7 +550,7 @@ def gen_index(): ...@@ -549,7 +550,7 @@ def gen_index():
knngs = [5, 100, 300] knngs = [5, 100, 300]
index_params = [] index_params = []
for index_type in index_types: for index_type in all_index_types:
if index_type == IndexType.FLAT: if index_type == IndexType.FLAT:
index_params.append({"index_type": index_type, "index_param": {"nlist": 1024}}) index_params.append({"index_type": index_type, "index_param": {"nlist": 1024}})
elif index_type in [IndexType.IVFLAT, IndexType.IVF_SQ8, IndexType.IVF_SQ8H]: elif index_type in [IndexType.IVFLAT, IndexType.IVF_SQ8, IndexType.IVF_SQ8H]:
...@@ -580,15 +581,6 @@ def gen_index(): ...@@ -580,15 +581,6 @@ def gen_index():
def gen_simple_index(): def gen_simple_index():
index_types = [
IndexType.FLAT,
IndexType.IVFLAT,
IndexType.IVF_SQ8,
IndexType.IVF_SQ8H,
IndexType.IVF_PQ,
IndexType.HNSW,
IndexType.RNSG
]
params = [ params = [
{"nlist": 1024}, {"nlist": 1024},
{"nlist": 1024}, {"nlist": 1024},
...@@ -596,12 +588,12 @@ def gen_simple_index(): ...@@ -596,12 +588,12 @@ def gen_simple_index():
{"nlist": 1024}, {"nlist": 1024},
{"nlist": 1024, "m": 16}, {"nlist": 1024, "m": 16},
{"M": 16, "efConstruction": 500}, {"M": 16, "efConstruction": 500},
{"search_length": 50, "out_degree": 40, "candidate_pool_size": 100, "knng": 50} {"search_length": 50, "out_degree": 40, "candidate_pool_size": 100, "knng": 50},
{"n_trees": 4}
] ]
index_params = [] index_params = []
for i in range(len(index_types)): for i in range(len(all_index_types)):
index_params.append({"index_type": index_types[i], "index_param": params[i]}) index_params.append({"index_type": all_index_types[i], "index_param": params[i]})
return index_params return index_params
...@@ -612,6 +604,9 @@ def get_search_param(index_type): ...@@ -612,6 +604,9 @@ def get_search_param(index_type):
return {"ef": 64} return {"ef": 64}
elif index_type == IndexType.RNSG: elif index_type == IndexType.RNSG:
return {"search_length": 50} return {"search_length": 50}
elif index_type == IndexType.ANNOY:
return {"search_k": 100}
else: else:
logging.getLogger().info("Invalid index_type.") logging.getLogger().info("Invalid index_type.")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册