未验证 提交 64d748ff 编写于 作者: D del-zhenwu 提交者: GitHub

[skip ci] remove index name && add metric type in index and search params (#3088)

Signed-off-by: zw <zw@milvus.io>
Co-authored-by: zw <zw@milvus.io>
Co-authored-by: Wang XiangYu <xy.wang@zilliz.com>
上级 5c37a9fa
......@@ -49,6 +49,7 @@ class TestCollectionCount:
if str(connect._cmd("mode")[1]) == "CPU":
if request.param["index_type"] in index_cpu_not_support():
pytest.skip("sq8h not support in cpu mode")
request.param.update({"metric_type": "L2"})
return request.param
def test_collection_count(self, connect, collection, insert_count):
......@@ -153,7 +154,7 @@ class TestCollectionCount:
entities = gen_entities(insert_count)
res = connect.insert(collection, entities)
connect.flush([collection])
connect.create_index(collection, field_name, index_name, get_simple_index)
connect.create_index(collection, field_name, get_simple_index)
res = connect.count_entities(collection)
assert res == insert_count
......@@ -204,6 +205,7 @@ class TestCollectionCountIP:
if str(connect._cmd("mode")[1]) == "CPU":
if request.param["index_type"] in index_cpu_not_support():
pytest.skip("sq8h not support in cpu mode")
request.param.update({"metric_type": "IP"})
return request.param
def test_collection_count(self, connect, ip_collection, insert_count):
......@@ -308,7 +310,7 @@ class TestCollectionCountIP:
entities = gen_entities(insert_count)
res = connect.insert(ip_collection, entities)
connect.flush([ip_collection])
connect.create_index(ip_collection, field_name, index_name, get_simple_index)
connect.create_index(ip_collection, field_name, get_simple_index)
res = connect.count_entities(ip_collection)
assert res == insert_count
......@@ -484,7 +486,7 @@ class TestCollectionCountBinary:
raw_vectors, entities = gen_binary_entities(insert_count)
res = connect.insert(jac_collection, entities)
connect.flush([jac_collection])
connect.create_index(jac_collection, field_name, index_name, get_simple_index)
connect.create_index(jac_collection, field_name, get_simple_index)
res = connect.count_entities(jac_collection)
assert res == insert_count
......
......@@ -16,7 +16,6 @@ nb = 6000
nlist = 1024
collection_id = "collection_stats"
field_name = "float_vector"
default_index_name = "stats_index"
entity = gen_entities(1)
raw_vector, binary_entity = gen_binary_entities(1)
entities = gen_entities(nb)
......@@ -55,6 +54,7 @@ class TestStatsBase:
def get_jaccard_index(self, request, connect):
logging.getLogger().info(request.param)
if request.param["index_type"] in binary_support():
request.param["metric_type"] = "JACCARD"
return request.param
else:
pytest.skip("Skip index Temporary")
......@@ -241,7 +241,7 @@ class TestStatsBase:
'''
ids = connect.insert(collection, entities)
connect.flush([collection])
connect.create_index(collection, field_name, default_index_name, get_simple_index)
connect.create_index(collection, field_name, get_simple_index)
stats = connect.get_collection_stats(collection)
logging.getLogger().info(stats)
assert stats["partitions"][0]["segments"][0]["row_count"] == nb
......@@ -255,7 +255,8 @@ class TestStatsBase:
'''
ids = connect.insert(ip_collection, entities)
connect.flush([ip_collection])
connect.create_index(ip_collection, field_name, default_index_name, get_simple_index)
get_simple_index.update({"metric_type": "IP"})
connect.create_index(ip_collection, field_name, get_simple_index)
stats = connect.get_collection_stats(ip_collection)
logging.getLogger().info(stats)
assert stats["partitions"][0]["segments"][0]["row_count"] == nb
......@@ -269,7 +270,7 @@ class TestStatsBase:
'''
ids = connect.insert(jac_collection, binary_entities)
connect.flush([jac_collection])
connect.create_index(jac_collection, "binary_vector", default_index_name, get_jaccard_index)
connect.create_index(jac_collection, "binary_vector", get_jaccard_index)
stats = connect.get_collection_stats(jac_collection)
logging.getLogger().info(stats)
assert stats["partitions"][0]["segments"][0]["row_count"] == nb
......@@ -284,7 +285,7 @@ class TestStatsBase:
ids = connect.insert(collection, entities)
connect.flush([collection])
for index_type in ["IVF_FLAT", "IVF_SQ8"]:
connect.create_index(collection, field_name, default_index_name, {"index_type": index_type, "nlist": 1024})
connect.create_index(collection, field_name, {"index_type": index_type, "nlist": 1024, "metric_type": "L2"})
stats = connect.get_collection_stats(collection)
logging.getLogger().info(stats)
assert stats["partitions"][0]["segments"][0]["index_name"] == index_type
......@@ -326,9 +327,9 @@ class TestStatsBase:
res = connect.insert(collection_name, entities)
connect.flush(collection_list)
if i % 2:
connect.create_index(collection_name, field_name, default_index_name, {"index_type": "IVF_SQ8", "nlist": 1024})
connect.create_index(collection_name, field_name, {"index_type": "IVF_SQ8", "nlist": 1024, "metric_type": "L2"})
else:
connect.create_index(collection_name, field_name, default_index_name, {"index_type": "IVF_FLAT", "nlist": 1024})
connect.create_index(collection_name, field_name, {"index_type": "IVF_FLAT", "nlist": 1024, "metric_type": "L2"})
for i in range(collection_num):
stats = connect.get_collection_stats(collection_list[i])
assert stats["partitions"][0]["segments"][0]["row_count"] == nb
......
......@@ -7,7 +7,6 @@ from multiprocessing import Process
from utils import *
collection_id = "load_collection"
index_name = "load_index_name"
nb = 6000
default_fields = gen_default_fields()
entities = gen_entities(nb)
......@@ -41,7 +40,7 @@ class TestLoadCollection:
connect.insert(collection, entities)
connect.flush([collection])
logging.getLogger().info(get_simple_index)
connect.create_index(collection, field_name, index_name, get_simple_index)
connect.create_index(collection, field_name, get_simple_index)
connect.load_collection(collection)
def load_empty_collection(self, connect, collection):
......
......@@ -16,7 +16,6 @@ DELETE_TIMEOUT = 60
tag = "1970-01-01"
nb = 6000
field_name = "float_vector"
default_index_name = "insert_index"
entity = gen_entities(1)
raw_vector, binary_entity = gen_binary_entities(1)
entities = gen_entities(nb)
......@@ -29,12 +28,6 @@ default_single_query = {
}
}
def query_with_index(index_name):
    """Return a copy of the default single-vector query with *index_name* set.

    Deep-copies the module-level ``default_single_query`` so the template is
    never mutated, then injects the given index name into the vector search
    params.

    :param index_name: name of the index to search with.
    :return: a new query dict with ``params["index_name"]`` set.
    """
    query = copy.deepcopy(default_single_query)
    # Bug fix: use the parameter, not the (removed) global default_index_name,
    # which made the argument dead and crashed once the global was deleted.
    query["bool"]["must"][0]["vector"]["params"].update({"index_name": index_name})
    return query
class TestDeleteBase:
"""
******************************************************************
......@@ -310,7 +303,7 @@ class TestDeleteBase:
connect.flush([collection])
delete_ids = [ids[0], ids[-1]]
status = connect.delete_entity_by_id(collection, delete_ids)
connect.create_index(collection, field_name, default_index_name, get_simple_index)
connect.create_index(collection, field_name, get_simple_index)
# assert index info
# TODO
......@@ -336,7 +329,7 @@ class TestDeleteBase:
method: create index, insert entities, and delete
expected: entities deleted
'''
connect.create_index(collection, field_name, default_index_name, get_simple_index)
connect.create_index(collection, field_name, get_simple_index)
ids = connect.insert(collection, entities)
connect.flush([collection])
delete_ids = [ids[0], ids[-1]]
......@@ -356,7 +349,7 @@ class TestDeleteBase:
expected: entities deleted
'''
ids = [i for i in range(nb)]
connect.create_index(collection, field_name, default_index_name, get_simple_index)
connect.create_index(collection, field_name, get_simple_index)
for i in range(nb):
connect.insert(collection, entity, [ids[i]])
connect.flush([collection])
......@@ -431,7 +424,7 @@ class TestDeleteBase:
ids = connect.insert(collection, entities, partition_tag=tag)
ids_new = connect.insert(collection, entities, partition_tag=tag_new)
connect.flush([collection])
connect.create_index(collection, field_name, default_index_name, get_simple_index)
connect.create_index(collection, field_name, get_simple_index)
delete_ids = [ids[0], ids_new[0]]
status = connect.delete_entity_by_id(collection, delete_ids)
assert status
......
......@@ -17,7 +17,6 @@ DELETE_TIMEOUT = 60
tag = "1970-01-01"
nb = 6000
field_name = "float_entity"
default_index_name = "insert_index"
entity = gen_entities(1)
binary_entity = gen_binary_entities(1)
entities = gen_entities(nb)
......@@ -286,7 +285,7 @@ class TestGetBase:
connect.create_partition(collection, tag)
ids = connect.insert(collection, entities, partition_tag=tag)
connect.flush([collection])
connect.create_index(collection, field_name, default_index_name, get_simple_index)
connect.create_index(collection, field_name, get_simple_index)
get_ids = ids[:get_pos]
res = connect.get_entity_by_id(collection, get_ids)
for i in range(get_pos):
......@@ -438,7 +437,7 @@ class TestGetBase:
'''
ids = connect.insert(collection, entities)
connect.flush([collection])
connect.create_index(collection, field_name, default_index_name, get_simple_index)
connect.create_index(collection, field_name, get_simple_index)
get_ids = ids[:get_pos]
res = connect.get_entity_by_id(collection, get_ids)
for i in range(get_pos):
......@@ -454,7 +453,7 @@ class TestGetBase:
for i in range(nb):
ids.append(connect.insert(collection, entity)[0])
connect.flush([collection])
connect.create_index(collection, field_name, default_index_name, get_simple_index)
connect.create_index(collection, field_name, get_simple_index)
get_ids = ids[:get_pos]
res = connect.get_entity_by_id(collection, get_ids)
for i in range(get_pos):
......
......@@ -16,7 +16,6 @@ tag = "1970-01-01"
insert_interval_time = 1.5
nb = 6000
field_name = "float_vector"
default_index_name = "insert_index"
entity = gen_entities(1)
raw_vector, binary_entity = gen_binary_entities(1)
entities = gen_entities(nb)
......@@ -26,7 +25,7 @@ default_single_query = {
"bool": {
"must": [
{"vector": {field_name: {"topk": 10, "query": gen_vectors(1, dim),
"params": {"index_name": default_index_name, "nprobe": 10}}}}
"params": {"nprobe": 10}}}}
]
}
}
......@@ -127,7 +126,7 @@ class TestInsertBase:
ids = connect.insert(collection, entities)
assert len(ids) == nb
connect.flush([collection])
connect.create_index(collection, field_name, default_index_name, get_simple_index)
connect.create_index(collection, field_name, get_simple_index)
@pytest.mark.timeout(ADD_TIMEOUT)
def test_insert_after_create_index(self, connect, collection, get_simple_index):
......@@ -136,7 +135,7 @@ class TestInsertBase:
method: insert vector and build index
expected: no error raised
'''
connect.create_index(collection, field_name, default_index_name, get_simple_index)
connect.create_index(collection, field_name, get_simple_index)
ids = connect.insert(collection, entities)
assert len(ids) == nb
......@@ -692,7 +691,7 @@ class TestInsertMultiCollections:
'''
collection_name = gen_unique_str(collection_id)
connect.create_collection(collection_name, default_fields)
connect.create_index(collection, field_name, default_index_name, get_simple_index)
connect.create_index(collection, field_name, get_simple_index)
ids = connect.insert(collection, entity)
connect.drop_collection(collection_name)
......@@ -706,7 +705,7 @@ class TestInsertMultiCollections:
collection_name = gen_unique_str(collection_id)
connect.create_collection(collection_name, default_fields)
ids = connect.insert(collection, entity)
connect.create_index(collection_name, field_name, default_index_name, get_simple_index)
connect.create_index(collection, field_name, get_simple_index)
count = connect.count_entities(collection_name)
assert count == 0
......@@ -721,7 +720,7 @@ class TestInsertMultiCollections:
connect.create_collection(collection_name, default_fields)
ids = connect.insert(collection, entity)
connect.flush([collection])
connect.create_index(collection_name, field_name, default_index_name, get_simple_index)
connect.create_index(collection, field_name, get_simple_index)
count = connect.count_entities(collection)
assert count == 1
......
......@@ -12,7 +12,6 @@ segment_row_count = 100000
nb = 6000
tag = "1970-01-01"
field_name = "float_vector"
default_index_name = "list_index"
collection_id = "list_id_in_segment"
entity = gen_entities(1)
raw_vector, binary_entity = gen_binary_entities(1)
......@@ -29,7 +28,7 @@ def get_segment_id(connect, collection, nb=1, vec_type='float', index_params=Non
ids = connect.insert(collection, entities)
connect.flush([collection])
if index_params:
connect.create_index(collection, field_name, default_index_name, index_params)
connect.create_index(collection, field_name, index_params)
stats = connect.get_collection_stats(collection)
return ids, stats["partitions"][0]["segments"][0]["id"]
......@@ -259,6 +258,7 @@ class TestListIdInSegmentIP:
method: call list_id_in_segment and check if the segment contains vectors
expected: status ok
'''
get_simple_index["metric_type"] = "IP"
ids, seg_id = get_segment_id(connect, ip_collection, nb=nb, index_params=get_simple_index)
vector_ids = connect.list_id_in_segment(ip_collection, seg_id)
# TODO:
......@@ -280,13 +280,14 @@ class TestListIdInSegmentIP:
# TODO
@pytest.mark.level(2)
def test_list_id_in_segment_after_delete_vectors(self, connect, ip_collection):
def test_list_id_in_segment_after_delete_vectors(self, connect, ip_collection, get_simple_index):
'''
target: get vector ids after vectors are deleted
method: add vectors and delete a few, call list_id_in_segment
expected: status ok, vector_ids decreased after vectors deleted
'''
nb = 2
get_simple_index["metric_type"] = "IP"
ids, seg_id = get_segment_id(connect, ip_collection, nb=nb)
delete_ids = [ids[0]]
status = connect.delete_entity_by_id(ip_collection, delete_ids)
......@@ -357,6 +358,7 @@ class TestListIdInSegmentJAC:
method: call list_id_in_segment and check if the segment contains vectors
expected: status ok
'''
get_jaccard_index["metric_type"] = "JACCARD"
ids, seg_id = get_segment_id(connect, jac_collection, nb=nb, index_params=get_jaccard_index, vec_type='binary')
vector_ids = connect.list_id_in_segment(jac_collection, seg_id)
# TODO:
......@@ -383,6 +385,7 @@ class TestListIdInSegmentJAC:
expected: status ok, vector_ids decreased after vectors deleted
'''
nb = 2
get_jaccard_index["metric_type"] = "JACCARD"
ids, seg_id = get_segment_id(connect, jac_collection, nb=nb, vec_type='binary', index_params=get_jaccard_index)
delete_ids = [ids[0]]
status = connect.delete_entity_by_id(jac_collection, delete_ids)
......
......@@ -21,7 +21,6 @@ top_k = 10
nprobe = 1
epsilon = 0.001
field_name = "float_vector"
default_index_name = "insert_index"
default_fields = gen_default_fields()
search_param = {"nprobe": 1}
entity = gen_entities(1, is_normal=True)
......@@ -34,7 +33,7 @@ query, query_vecs = gen_query_vectors_inside_entities(field_name, entities, top_
# "must": [
# {"term": {"A": {"values": [1, 2, 5]}}},
# {"range": {"B": {"ranges": {"GT": 1, "LT": 100}}}},
# {"vector": {"Vec": {"topk": 10, "query": vec[: 1], "params": {"index_name": "IVFFLAT", "nprobe": 10}}}}
# {"vector": {"Vec": {"topk": 10, "query": vec[: 1], "params": {"nprobe": 10}}}}
# ],
# },
# }
......@@ -206,7 +205,7 @@ class TestSearchBase:
if index_type == "IVF_PQ":
pytest.skip("Skip PQ")
entities, ids = init_data(connect, collection)
connect.create_index(collection, field_name, default_index_name, get_simple_index)
connect.create_index(collection, field_name, get_simple_index)
search_param = get_search_param(index_type)
query, vecs = gen_query_vectors_inside_entities(field_name, entities, top_k, nq, search_params=search_param)
if top_k > top_k_limit:
......@@ -233,7 +232,7 @@ class TestSearchBase:
pytest.skip("Skip PQ")
connect.create_partition(collection, tag)
entities, ids = init_data(connect, collection)
connect.create_index(collection, field_name, default_index_name, get_simple_index)
connect.create_index(collection, field_name, get_simple_index)
search_param = get_search_param(index_type)
query, vecs = gen_query_vectors_inside_entities(field_name, entities, top_k, nq, search_params=search_param)
if top_k > top_k_limit:
......@@ -262,7 +261,7 @@ class TestSearchBase:
pytest.skip("Skip PQ")
connect.create_partition(collection, tag)
entities, ids = init_data(connect, collection, partition_tags=tag)
connect.create_index(collection, field_name, default_index_name, get_simple_index)
connect.create_index(collection, field_name, get_simple_index)
search_param = get_search_param(index_type)
query, vecs = gen_query_vectors_inside_entities(field_name, entities, top_k, nq, search_params=search_param)
for tags in [[tag], [tag, "new_tag"]]:
......@@ -312,7 +311,7 @@ class TestSearchBase:
connect.create_partition(collection, new_tag)
entities, ids = init_data(connect, collection, partition_tags=tag)
new_entities, new_ids = init_data(connect, collection, nb=6001, partition_tags=new_tag)
connect.create_index(collection, field_name, default_index_name, get_simple_index)
connect.create_index(collection, field_name, get_simple_index)
search_param = get_search_param(index_type)
query, vecs = gen_query_vectors_inside_entities(field_name, entities, top_k, nq, search_params=search_param)
if top_k > top_k_limit:
......@@ -347,7 +346,7 @@ class TestSearchBase:
connect.create_partition(collection, new_tag)
entities, ids = init_data(connect, collection, partition_tags=tag)
new_entities, new_ids = init_data(connect, collection, nb=6001, partition_tags=new_tag)
connect.create_index(collection, field_name, default_index_name, get_simple_index)
connect.create_index(collection, field_name, get_simple_index)
search_param = get_search_param(index_type)
query, vecs = gen_query_vectors_inside_entities(field_name, new_entities, top_k, nq, search_params=search_param)
if top_k > top_k_limit:
......@@ -399,7 +398,8 @@ class TestSearchBase:
if index_type == "IVF_PQ":
pytest.skip("Skip PQ")
entities, ids = init_data(connect, ip_collection)
connect.create_index(ip_collection, field_name, default_index_name, get_simple_index)
get_simple_index["metric_type"] = "IP"
connect.create_index(ip_collection, field_name, get_simple_index)
search_param = get_search_param(index_type)
query, vecs = gen_query_vectors_inside_entities(field_name, entities, top_k, nq, search_params=search_param)
if top_k > top_k_limit:
......@@ -427,8 +427,10 @@ class TestSearchBase:
pytest.skip("Skip PQ")
connect.create_partition(ip_collection, tag)
entities, ids = init_data(connect, ip_collection)
connect.create_index(ip_collection, field_name, default_index_name, get_simple_index)
get_simple_index["metric_type"] = "IP"
connect.create_index(ip_collection, field_name, get_simple_index)
search_param = get_search_param(index_type)
search_param["metric_type"] = "IP"
query, vecs = gen_query_vectors_inside_entities(field_name, entities, top_k, nq, search_params=search_param)
if top_k > top_k_limit:
with pytest.raises(Exception) as e:
......@@ -459,8 +461,10 @@ class TestSearchBase:
connect.create_partition(ip_collection, new_tag)
entities, ids = init_data(connect, ip_collection, partition_tags=tag)
new_entities, new_ids = init_data(connect, ip_collection, nb=6001, partition_tags=new_tag)
connect.create_index(ip_collection, field_name, default_index_name, get_simple_index)
get_simple_index["metric_type"] = "IP"
connect.create_index(ip_collection, field_name, get_simple_index)
search_param = get_search_param(index_type)
search_param["metric_type"] = "IP"
query, vecs = gen_query_vectors_inside_entities(field_name, entities, top_k, nq, search_params=search_param)
if top_k > top_k_limit:
with pytest.raises(Exception) as e:
......@@ -522,7 +526,7 @@ class TestSearchBase:
index_type = get_simple_index["index_type"]
nq = 2
entities, ids = init_data(connect, collection)
connect.create_index(collection, field_name, default_index_name, get_simple_index)
connect.create_index(collection, field_name, get_simple_index)
search_param = get_search_param(index_type)
query, vecs = gen_query_vectors_rand_entities(field_name, entities, top_k, nq, search_params=search_param)
inside_vecs = entities[-1]["values"]
......@@ -560,8 +564,10 @@ class TestSearchBase:
index_type = get_simple_index["index_type"]
nq = 2
entities, ids = init_data(connect, ip_collection)
connect.create_index(ip_collection, field_name, default_index_name, get_simple_index)
get_simple_index["metric_type"] = "IP"
connect.create_index(ip_collection, field_name, get_simple_index)
search_param = get_search_param(index_type)
search_param["metric_type"] = "IP"
query, vecs = gen_query_vectors_rand_entities(field_name, entities, top_k, nq, search_params=search_param)
inside_vecs = entities[-1]["values"]
max_distance = 0
......@@ -614,9 +620,10 @@ class TestSearchBase:
int_vectors, vectors, ids = self.init_binary_data(connect, substructure_collection, nb=2)
index_type = "FLAT"
index_param = {
"nlist": 16384
"nlist": 16384,
"metric_type": "SUBSTRUCTURE"
}
connect.create_index(substructure_collection, index_type, index_param)
connect.create_index(substructure_collection, binary_field_name, index_param)
logging.getLogger().info(connect.get_collection_info(substructure_collection))
logging.getLogger().info(connect.get_index_info(substructure_collection))
query_int_vectors, query_vecs, tmp_ids = self.init_binary_data(connect, substructure_collection, nb=1, insert=False)
......@@ -640,9 +647,10 @@ class TestSearchBase:
int_vectors, vectors, ids = self.init_binary_data(connect, substructure_collection, nb=2)
index_type = "FLAT"
index_param = {
"nlist": 16384
"nlist": 16384,
"metric_type": "SUBSTRUCTURE"
}
connect.create_index(substructure_collection, index_type, index_param)
connect.create_index(substructure_collection, binary_field_name, index_param)
logging.getLogger().info(connect.get_collection_info(substructure_collection))
logging.getLogger().info(connect.get_index_info(substructure_collection))
query_int_vectors, query_vecs = gen_binary_sub_vectors(int_vectors, 2)
......@@ -668,9 +676,10 @@ class TestSearchBase:
int_vectors, vectors, ids = self.init_binary_data(connect, superstructure_collection, nb=2)
index_type = "FLAT"
index_param = {
"nlist": 16384
"nlist": 16384,
"metric_type": "SUBSTRUCTURE"
}
connect.create_index(superstructure_collection, index_type, index_param)
connect.create_index(superstructure_collection, binary_field_name, index_param)
logging.getLogger().info(connect.get_collection_info(superstructure_collection))
logging.getLogger().info(connect.get_index_info(superstructure_collection))
query_int_vectors, query_vecs, tmp_ids = self.init_binary_data(connect, superstructure_collection, nb=1, insert=False)
......@@ -694,9 +703,10 @@ class TestSearchBase:
int_vectors, vectors, ids = self.init_binary_data(connect, superstructure_collection, nb=2)
index_type = "FLAT"
index_param = {
"nlist": 16384
"nlist": 16384,
"metric_type": "SUBSTRUCTURE"
}
connect.create_index(superstructure_collection, index_type, index_param)
connect.create_index(superstructure_collection, binary_field_name, index_param)
logging.getLogger().info(connect.get_collection_info(superstructure_collection))
logging.getLogger().info(connect.get_index_info(superstructure_collection))
query_int_vectors, query_vecs = gen_binary_super_vectors(int_vectors, 2)
......@@ -722,9 +732,10 @@ class TestSearchBase:
int_vectors, vectors, ids = self.init_binary_data(connect, tanimoto_collection, nb=2)
index_type = "FLAT"
index_param = {
"nlist": 16384
"nlist": 16384,
"metric_type": "TANIMOTO"
}
connect.create_index(tanimoto_collection, index_type, index_param)
connect.create_index(tanimoto_collection, binary_field_name, index_param)
logging.getLogger().info(connect.get_collection_info(tanimoto_collection))
logging.getLogger().info(connect.get_index_info(tanimoto_collection))
query_int_vectors, query_vecs, tmp_ids = self.init_binary_data(connect, tanimoto_collection, nb=1, insert=False)
......@@ -929,7 +940,7 @@ class TestSearchInvalid(object):
search_params = get_search_params
index_type = get_simple_index["index_type"]
entities, ids = init_data(connect, collection)
connect.create_index(collection, field_name, default_index_name, get_simple_index)
connect.create_index(collection, field_name, get_simple_index)
if search_params["index_type"] != index_type:
pytest.skip("Skip case")
query, vecs = gen_query_vectors_inside_entities(field_name, entities, top_k, 1, search_params=search_params["search_params"])
......@@ -948,7 +959,7 @@ class TestSearchInvalid(object):
if index_type == "FLAT":
pytest.skip("skip in FLAT index")
entities, ids = init_data(connect, collection)
connect.create_index(collection, field_name, default_index_name, get_simple_index)
connect.create_index(collection, field_name, get_simple_index)
query, vecs = gen_query_vectors_inside_entities(field_name, entities, top_k, 1, search_params={})
with pytest.raises(Exception) as e:
res = connect.search(collection, query)
......
......@@ -20,12 +20,11 @@ raw_vector, binary_entity = gen_binary_entities(1)
raw_vectors, binary_entities = gen_binary_entities(nb)
default_fields = gen_default_fields()
field_name = "float_vector"
default_index_name = "insert_index"
default_single_query = {
"bool": {
"must": [
{"vector": {field_name: {"topk": 10, "query": gen_vectors(1, dim),
"params": {"index_name": default_index_name, "nprobe": 10}}}}
"params": {"nprobe": 10}}}}
]
}
}
......@@ -227,8 +226,7 @@ class TestCompactBase:
count = 10
ids = connect.insert(collection, entities)
connect.flush([collection])
status = connect.create_index(collection, field_name, default_index_name, get_simple_index)
assert status.OK()
connect.create_index(collection, field_name, get_simple_index)
connect.flush([collection])
# get collection info before compact
info = connect.get_collection_stats(collection)
......@@ -363,9 +361,7 @@ class TestCompactBase:
connect.flush([collection])
status = connect.compact(collection)
assert status.OK()
# index_param = get_simple_index["index_param"]
# index_type = get_simple_index["index_type"]
status = connect.create_index(collection, field_name, default_index_name, get_simple_index)
status = connect.create_index(collection, field_name, get_simple_index)
assert status.OK()
# status, result = connect.get_index_info(collection)
......
......@@ -17,7 +17,6 @@ top_k = 1
nb = 6000
tag = "partition_tag"
field_name = "float_vector"
default_index_name = "insert_index"
entity = gen_entities(1)
entities = gen_entities(nb)
raw_vector, binary_entity = gen_binary_entities(1)
......@@ -27,7 +26,7 @@ default_single_query = {
"bool": {
"must": [
{"vector": {field_name: {"topk": 10, "query": gen_vectors(1, dim),
"params": {"index_name": default_index_name, "nprobe": 10}}}}
"params": {"nprobe": 10}}}}
]
}
}
......
......@@ -17,7 +17,6 @@ TIMEOUT = 120
nb = 6000
tag = "partition_tag"
field_name = "float_vector"
default_index_name = "partition"
entity = gen_entities(1)
entities = gen_entities(nb)
raw_vector, binary_entity = gen_binary_entities(1)
......
......@@ -14,7 +14,6 @@ tag = "1970-01-01"
insert_interval_time = 1.5
nb = 6000
field_name = "float_vector"
default_index_name = "insert_index"
entity = gen_entities(1)
binary_entity = gen_binary_entities(1)
entities = gen_entities(nb)
......
......@@ -651,7 +651,7 @@ def gen_simple_index():
for i in range(len(all_index_types)):
if all_index_types[i] in binary_support():
continue
dic = {"index_type": all_index_types[i]}
dic = {"index_type": all_index_types[i], "metric_type": "L2"}
dic.update(default_index_params[i])
index_params.append(dic)
return index_params
......@@ -668,16 +668,19 @@ def gen_binary_index():
def get_search_param(index_type):
    """Build default search params for the given index type.

    Every returned dict carries ``metric_type`` (defaults to ``"L2"``) plus the
    index-specific tuning knob (``nprobe`` / ``ef`` / ``search_length`` /
    ``search_k``).

    :param index_type: index type string, e.g. ``"IVF_FLAT"``, ``"HNSW"``.
    :return: dict of search parameters including ``metric_type``.
    :raises Exception: if *index_type* is not a supported index type.
    """
    search_params = {"metric_type": "L2"}
    # ivf() / binary_support() are project helpers listing IVF-family and
    # binary-metric index types respectively.
    if index_type in ivf() or index_type in binary_support():
        search_params.update({"nprobe": 32})
    elif index_type == "HNSW":
        search_params.update({"ef": 64})
    elif index_type == "NSG":
        search_params.update({"search_length": 100})
    elif index_type == "ANNOY":
        search_params.update({"search_k": 100})
    else:
        logging.getLogger().error("Invalid index_type.")
        raise Exception("Invalid index_type.")
    return search_params
def assert_equal_vector(v1, v2):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册