未验证 提交 d34be8bc 编写于 作者: D del-zhenwu 提交者: GitHub

Skip flat search params (#3381)

* assert top ids
Signed-off-by: Nzw <zw@milvus.io>

* update milvus-helm to 0.11.0
Signed-off-by: Nzw <zw@milvus.io>
Co-authored-by: Nzw <zw@milvus.io>
上级 75868b20
......@@ -28,7 +28,7 @@ pipeline {
LOWER_BUILD_TYPE = params.BUILD_TYPE.toLowerCase()
SEMVER = "${BRANCH_NAME.contains('/') ? BRANCH_NAME.substring(BRANCH_NAME.lastIndexOf('/') + 1) : BRANCH_NAME}"
PIPELINE_NAME = "milvus-ci"
HELM_BRANCH = "0.10.1"
HELM_BRANCH = "0.11.0"
}
stages {
stage ('Milvus Build and Unittest') {
......
......@@ -33,7 +33,7 @@ default_query, default_query_vecs = gen_query_vectors(field_name, entities, top_
default_binary_query, default_binary_query_vecs = gen_query_vectors(binary_field_name, binary_entities, top_k, nq)
def init_data(connect, collection, nb=6000, partition_tags=None):
def init_data(connect, collection, nb=6000, partition_tags=None, auto_id=True):
'''
Generate entities and insert them into the collection
'''
......@@ -43,9 +43,15 @@ def init_data(connect, collection, nb=6000, partition_tags=None):
else:
insert_entities = gen_entities(nb, is_normal=True)
if partition_tags is None:
ids = connect.insert(collection, insert_entities)
if auto_id:
ids = connect.insert(collection, insert_entities)
else:
ids = connect.insert(collection, insert_entities, ids=[i for i in range(nb)])
else:
ids = connect.insert(collection, insert_entities, partition_tag=partition_tags)
if auto_id:
ids = connect.insert(collection, insert_entities, partition_tag=partition_tags)
else:
ids = connect.insert(collection, insert_entities, ids=[i for i in range(nb)], partition_tag=partition_tags)
connect.flush([collection])
return insert_entities, ids
......@@ -532,7 +538,7 @@ class TestSearchBase:
res = connect.search(collection, query)
assert abs(np.sqrt(res[0]._distances[0]) - min(distance_0, distance_1)) <= gen_inaccuracy(res[0]._distances[0])
def test_search_distance_l2_after_index(self, connect, collection, get_simple_index):
def test_search_distance_l2_after_index(self, connect, id_collection, get_simple_index):
'''
target: search collection, and check the result: distance
method: compare the returned distance value with the value computed with L2 (Euclidean) distance
......@@ -540,22 +546,25 @@ class TestSearchBase:
'''
index_type = get_simple_index["index_type"]
nq = 2
entities, ids = init_data(connect, collection)
connect.create_index(collection, field_name, get_simple_index)
entities, ids = init_data(connect, id_collection, auto_id=False)
connect.create_index(id_collection, field_name, get_simple_index)
search_param = get_search_param(index_type)
query, vecs = gen_query_vectors(field_name, entities, top_k, nq, rand_vector=True, search_params=search_param)
inside_vecs = entities[-1]["values"]
min_distance = 1.0
min_id = None
for i in range(nb):
tmp_dis = l2(vecs[0], inside_vecs[i])
if min_distance > tmp_dis:
min_distance = tmp_dis
res = connect.search(collection, query)
min_id = ids[i]
res = connect.search(id_collection, query)
tmp_epsilon = epsilon
check_id_result(res[0], min_id)
# if index_type in ["ANNOY", "IVF_PQ"]:
# tmp_epsilon = 0.1
# TODO:
if index_type in ["ANNOY", "IVF_PQ"]:
tmp_epsilon = 0.1
assert abs(np.sqrt(res[0]._distances[0]) - min_distance) <= tmp_epsilon
# assert abs(np.sqrt(res[0]._distances[0]) - min_distance) <= tmp_epsilon
@pytest.mark.level(2)
def test_search_distance_ip(self, connect, collection):
......@@ -576,7 +585,7 @@ class TestSearchBase:
res = connect.search(collection, query)
assert abs(res[0]._distances[0] - max(distance_0, distance_1)) <= epsilon
def test_search_distance_ip_after_index(self, connect, collection, get_simple_index):
def test_search_distance_ip_after_index(self, connect, id_collection, get_simple_index):
'''
target: search collection, and check the result: distance
method: compare the return distance value with value computed with Inner product
......@@ -585,24 +594,27 @@ class TestSearchBase:
index_type = get_simple_index["index_type"]
nq = 2
metirc_type = "IP"
entities, ids = init_data(connect, collection)
entities, ids = init_data(connect, id_collection, auto_id=False)
get_simple_index["metric_type"] = metirc_type
connect.create_index(collection, field_name, get_simple_index)
connect.create_index(id_collection, field_name, get_simple_index)
search_param = get_search_param(index_type)
query, vecs = gen_query_vectors(field_name, entities, top_k, nq, rand_vector=True, metric_type=metirc_type,
search_params=search_param)
inside_vecs = entities[-1]["values"]
max_distance = 0
max_id = None
for i in range(nb):
tmp_dis = ip(vecs[0], inside_vecs[i])
if max_distance < tmp_dis:
max_distance = tmp_dis
res = connect.search(collection, query)
max_id = ids[i]
res = connect.search(id_collection, query)
tmp_epsilon = epsilon
check_id_result(res[0], max_id)
# if index_type in ["ANNOY", "IVF_PQ"]:
# tmp_epsilon = 0.1
# TODO:
if index_type in ["ANNOY", "IVF_PQ"]:
tmp_epsilon = 0.1
assert abs(res[0]._distances[0] - max_distance) <= tmp_epsilon
# assert abs(res[0]._distances[0] - max_distance) <= tmp_epsilon
def test_search_distance_jaccard_flat_index(self, connect, binary_collection):
'''
......@@ -1559,6 +1571,8 @@ class TestSearchInvalid(object):
'''
search_params = get_search_params
index_type = get_simple_index["index_type"]
if index_type in ["FLAT"]:
pytest.skip("skip in FLAT index")
entities, ids = init_data(connect, collection)
connect.create_index(collection, field_name, get_simple_index)
query, vecs = gen_query_vectors(field_name, entities, top_k, 1, search_params=search_params["search_params"])
......
......@@ -594,17 +594,9 @@ def gen_invalid_params():
-1,
# None,
[1, 2, 3],
(1, 2),
{"a": 1},
" ",
"",
"String",
"12-s",
"BB。A",
" siede ",
"(mn)",
"pip+",
"=c",
"中文"
]
return params
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册