未验证 提交 cb74d0a3 编写于 作者: T ThreadDao 提交者: GitHub

Fix invalid ranges and delete logging (#3303)

* invalid ranges
Signed-off-by: Nzongyufen <zongyufen@foxmail.com>

* delete gen invalid ranges
Signed-off-by: Nzongyufen <zongyufen@foxmail.com>
上级 ebe06e3c
......@@ -225,7 +225,8 @@ class TestSearchBase:
entities, ids = init_data(connect, collection)
connect.create_index(collection, field_name, get_simple_index)
search_param = get_search_param(index_type)
query, vecs = gen_query_vectors(field_name, entities, top_k, nq, metric_type=search_metric_type, search_params=search_param)
query, vecs = gen_query_vectors(field_name, entities, top_k, nq, metric_type=search_metric_type,
search_params=search_param)
res = connect.search(collection, query)
assert len(res) == nq
assert len(res[0]) == top_k
......@@ -376,9 +377,9 @@ class TestSearchBase:
assert res[0]._distances[0] > epsilon
assert res[1]._distances[0] < epsilon
#
#
# test for ip metric
#
#
@pytest.mark.level(2)
def test_search_ip_flat(self, connect, collection, get_simple_index, get_top_k, get_nq):
'''
......@@ -658,7 +659,7 @@ class TestSearchBase:
def _test_search_distance_substructure_flat_index_B(self, connect, binary_collection):
'''
target: search binary_collection, and check the result: distance
method: compare the return distance value with value computed with SUB
method: compare the return distance value with value computed with SUB
expected: the return distance equals to the computed value
'''
# from scipy.spatial import distance
......@@ -1108,7 +1109,6 @@ class TestSearchDSL(object):
term["term"].update({"a": [0]})
expr = {"must": [gen_default_vector_expr(default_query), term]}
query = update_query_expr(default_query, expr=expr)
logging.getLogger().info(query)
with pytest.raises(Exception) as e:
res = connect.search(collection, query)
......@@ -1151,28 +1151,34 @@ class TestSearchDSL(object):
with pytest.raises(Exception) as e:
res = connect.search(collection, query)
@pytest.fixture(
scope="function",
params=gen_invalid_ranges()
)
def get_invalid_ranges(self, request):
return request.param
@pytest.mark.level(2)
def test_query_range_invalid_ranges(self, connect, collection, get_invalid_ranges):
def test_query_range_string_ranges(self, connect, collection):
'''
method: build query with invalid ranges
expected: raise Exception
'''
entities, ids = init_data(connect, collection)
ranges = get_invalid_ranges
ranges = {"GT": "0", "LT": "1000"}
range = gen_default_range_expr(ranges=ranges)
expr = {"must": [gen_default_vector_expr(default_query), range]}
query = update_query_expr(default_query, expr=expr)
logging.getLogger().info(query)
with pytest.raises(Exception) as e:
res = connect.search(collection, query)
@pytest.mark.level(2)
def test_query_range_invalid_ranges(self, connect, collection):
'''
method: build query with invalid ranges
expected: 0
'''
entities, ids = init_data(connect, collection)
ranges = {"GT": nb, "LT": 0}
range = gen_default_range_expr(ranges=ranges)
expr = {"must": [gen_default_vector_expr(default_query), range]}
query = update_query_expr(default_query, expr=expr)
res = connect.search(collection, query)
assert len(res[0]) == 0
@pytest.fixture(
scope="function",
params=gen_valid_ranges()
......@@ -1180,7 +1186,6 @@ class TestSearchDSL(object):
def get_valid_ranges(self, request):
return request.param
# TODO:
@pytest.mark.level(2)
def test_query_range_valid_ranges(self, connect, collection, get_valid_ranges):
'''
......@@ -1192,7 +1197,6 @@ class TestSearchDSL(object):
range = gen_default_range_expr(ranges=ranges)
expr = {"must": [gen_default_vector_expr(default_query), range]}
query = update_query_expr(default_query, expr=expr)
logging.getLogger().info(query)
res = connect.search(collection, query)
assert len(res) == nq
assert len(res[0]) == top_k
......@@ -1256,7 +1260,7 @@ class TestSearchDSL(object):
'''
entities, ids = init_data(connect, collection)
term_first = gen_default_term_expr()
term_second = gen_default_term_expr(field="float", values=[float(i) for i in range(nb//2, nb)])
term_second = gen_default_term_expr(field="float", values=[float(i) for i in range(nb // 2, nb)])
expr = {"must": [gen_default_vector_expr(default_query), term_first, term_second]}
query = update_query_expr(default_query, expr=expr)
res = connect.search(collection, query)
......@@ -1271,12 +1275,11 @@ class TestSearchDSL(object):
expected: pass
'''
entities, ids = init_data(connect, collection)
term_first = {"int64": {"values": [i for i in range(nb//2)]}}
term_second = {"float": {"values": [float(i) for i in range(nb//2, nb)]}}
term_first = {"int64": {"values": [i for i in range(nb // 2)]}}
term_second = {"float": {"values": [float(i) for i in range(nb // 2, nb)]}}
term = update_term_expr({"term": {}}, [term_first, term_second])
expr = {"must": [gen_default_vector_expr(default_query), term]}
query = update_query_expr(default_query, expr=expr)
logging.getLogger().info(query)
res = connect.search(collection, query)
assert len(res) == nq
assert len(res[0]) == 0
......@@ -1325,7 +1328,6 @@ class TestSearchDSL(object):
range_second = gen_default_range_expr(field="float", ranges={"GT": nb // 2, "LT": nb})
expr = {"must": [gen_default_vector_expr(default_query), range_first, range_second]}
query = update_query_expr(default_query, expr=expr)
logging.getLogger().info(query)
res = connect.search(collection, query)
assert len(res) == nq
assert len(res[0]) == 0
......@@ -1343,7 +1345,6 @@ class TestSearchDSL(object):
range = update_range_expr({"range": {}}, [range_first, range_second])
expr = {"must": [gen_default_vector_expr(default_query), range]}
query = update_query_expr(default_query, expr=expr)
logging.getLogger().info(query)
res = connect.search(collection, query)
assert len(res) == nq
assert len(res[0]) == 0
......@@ -1362,7 +1363,7 @@ class TestSearchDSL(object):
expected: pass
'''
term = gen_default_term_expr()
range = gen_default_range_expr(ranges={"GT": -1, "LT": nb//2})
range = gen_default_range_expr(ranges={"GT": -1, "LT": nb // 2})
expr = {"must": [gen_default_vector_expr(default_query), term, range]}
query = update_query_expr(default_query, expr=expr)
res = connect.search(collection, query)
......@@ -1383,7 +1384,6 @@ class TestSearchDSL(object):
assert len(res) == nq
assert len(res[0]) == 0
"""
******************************************************************
# The following cases are used to build multi vectors query expr
......
......@@ -346,14 +346,6 @@ def gen_invalid_range():
return range
def gen_invalid_ranges():
ranges = [
{"GT": nb, "LT": 0},
{"GT": "0", "LT": "1000"}
]
return ranges
def gen_valid_ranges():
ranges = [
{"GT": 0, "LT": nb//2},
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册