未验证 提交 cb74d0a3 编写于 作者: T ThreadDao 提交者: GitHub

Fix invalid ranges and delete logging (#3303)

* invalid ranges
Signed-off-by: Nzongyufen <zongyufen@foxmail.com>

* delete gen invalid ranges
Signed-off-by: Nzongyufen <zongyufen@foxmail.com>
上级 ebe06e3c
...@@ -225,7 +225,8 @@ class TestSearchBase: ...@@ -225,7 +225,8 @@ class TestSearchBase:
entities, ids = init_data(connect, collection) entities, ids = init_data(connect, collection)
connect.create_index(collection, field_name, get_simple_index) connect.create_index(collection, field_name, get_simple_index)
search_param = get_search_param(index_type) search_param = get_search_param(index_type)
query, vecs = gen_query_vectors(field_name, entities, top_k, nq, metric_type=search_metric_type, search_params=search_param) query, vecs = gen_query_vectors(field_name, entities, top_k, nq, metric_type=search_metric_type,
search_params=search_param)
res = connect.search(collection, query) res = connect.search(collection, query)
assert len(res) == nq assert len(res) == nq
assert len(res[0]) == top_k assert len(res[0]) == top_k
...@@ -376,9 +377,9 @@ class TestSearchBase: ...@@ -376,9 +377,9 @@ class TestSearchBase:
assert res[0]._distances[0] > epsilon assert res[0]._distances[0] > epsilon
assert res[1]._distances[0] < epsilon assert res[1]._distances[0] < epsilon
# #
# test for ip metric # test for ip metric
# #
@pytest.mark.level(2) @pytest.mark.level(2)
def test_search_ip_flat(self, connect, collection, get_simple_index, get_top_k, get_nq): def test_search_ip_flat(self, connect, collection, get_simple_index, get_top_k, get_nq):
''' '''
...@@ -658,7 +659,7 @@ class TestSearchBase: ...@@ -658,7 +659,7 @@ class TestSearchBase:
def _test_search_distance_substructure_flat_index_B(self, connect, binary_collection): def _test_search_distance_substructure_flat_index_B(self, connect, binary_collection):
''' '''
target: search binary_collection, and check the result: distance target: search binary_collection, and check the result: distance
method: compare the return distance value with value computed with SUB method: compare the return distance value with value computed with SUB
expected: the return distance equals to the computed value expected: the return distance equals to the computed value
''' '''
# from scipy.spatial import distance # from scipy.spatial import distance
...@@ -1108,7 +1109,6 @@ class TestSearchDSL(object): ...@@ -1108,7 +1109,6 @@ class TestSearchDSL(object):
term["term"].update({"a": [0]}) term["term"].update({"a": [0]})
expr = {"must": [gen_default_vector_expr(default_query), term]} expr = {"must": [gen_default_vector_expr(default_query), term]}
query = update_query_expr(default_query, expr=expr) query = update_query_expr(default_query, expr=expr)
logging.getLogger().info(query)
with pytest.raises(Exception) as e: with pytest.raises(Exception) as e:
res = connect.search(collection, query) res = connect.search(collection, query)
...@@ -1151,28 +1151,34 @@ class TestSearchDSL(object): ...@@ -1151,28 +1151,34 @@ class TestSearchDSL(object):
with pytest.raises(Exception) as e: with pytest.raises(Exception) as e:
res = connect.search(collection, query) res = connect.search(collection, query)
@pytest.fixture(
scope="function",
params=gen_invalid_ranges()
)
def get_invalid_ranges(self, request):
return request.param
@pytest.mark.level(2) @pytest.mark.level(2)
def test_query_range_invalid_ranges(self, connect, collection, get_invalid_ranges): def test_query_range_string_ranges(self, connect, collection):
''' '''
method: build query with invalid ranges method: build query with invalid ranges
expected: raise Exception expected: raise Exception
''' '''
entities, ids = init_data(connect, collection) entities, ids = init_data(connect, collection)
ranges = get_invalid_ranges ranges = {"GT": "0", "LT": "1000"}
range = gen_default_range_expr(ranges=ranges) range = gen_default_range_expr(ranges=ranges)
expr = {"must": [gen_default_vector_expr(default_query), range]} expr = {"must": [gen_default_vector_expr(default_query), range]}
query = update_query_expr(default_query, expr=expr) query = update_query_expr(default_query, expr=expr)
logging.getLogger().info(query)
with pytest.raises(Exception) as e: with pytest.raises(Exception) as e:
res = connect.search(collection, query) res = connect.search(collection, query)
@pytest.mark.level(2)
def test_query_range_invalid_ranges(self, connect, collection):
'''
method: build query with invalid ranges
expected: 0
'''
entities, ids = init_data(connect, collection)
ranges = {"GT": nb, "LT": 0}
range = gen_default_range_expr(ranges=ranges)
expr = {"must": [gen_default_vector_expr(default_query), range]}
query = update_query_expr(default_query, expr=expr)
res = connect.search(collection, query)
assert len(res[0]) == 0
@pytest.fixture( @pytest.fixture(
scope="function", scope="function",
params=gen_valid_ranges() params=gen_valid_ranges()
...@@ -1180,7 +1186,6 @@ class TestSearchDSL(object): ...@@ -1180,7 +1186,6 @@ class TestSearchDSL(object):
def get_valid_ranges(self, request): def get_valid_ranges(self, request):
return request.param return request.param
# TODO:
@pytest.mark.level(2) @pytest.mark.level(2)
def test_query_range_valid_ranges(self, connect, collection, get_valid_ranges): def test_query_range_valid_ranges(self, connect, collection, get_valid_ranges):
''' '''
...@@ -1192,7 +1197,6 @@ class TestSearchDSL(object): ...@@ -1192,7 +1197,6 @@ class TestSearchDSL(object):
range = gen_default_range_expr(ranges=ranges) range = gen_default_range_expr(ranges=ranges)
expr = {"must": [gen_default_vector_expr(default_query), range]} expr = {"must": [gen_default_vector_expr(default_query), range]}
query = update_query_expr(default_query, expr=expr) query = update_query_expr(default_query, expr=expr)
logging.getLogger().info(query)
res = connect.search(collection, query) res = connect.search(collection, query)
assert len(res) == nq assert len(res) == nq
assert len(res[0]) == top_k assert len(res[0]) == top_k
...@@ -1256,7 +1260,7 @@ class TestSearchDSL(object): ...@@ -1256,7 +1260,7 @@ class TestSearchDSL(object):
''' '''
entities, ids = init_data(connect, collection) entities, ids = init_data(connect, collection)
term_first = gen_default_term_expr() term_first = gen_default_term_expr()
term_second = gen_default_term_expr(field="float", values=[float(i) for i in range(nb//2, nb)]) term_second = gen_default_term_expr(field="float", values=[float(i) for i in range(nb // 2, nb)])
expr = {"must": [gen_default_vector_expr(default_query), term_first, term_second]} expr = {"must": [gen_default_vector_expr(default_query), term_first, term_second]}
query = update_query_expr(default_query, expr=expr) query = update_query_expr(default_query, expr=expr)
res = connect.search(collection, query) res = connect.search(collection, query)
...@@ -1271,12 +1275,11 @@ class TestSearchDSL(object): ...@@ -1271,12 +1275,11 @@ class TestSearchDSL(object):
expected: pass expected: pass
''' '''
entities, ids = init_data(connect, collection) entities, ids = init_data(connect, collection)
term_first = {"int64": {"values": [i for i in range(nb//2)]}} term_first = {"int64": {"values": [i for i in range(nb // 2)]}}
term_second = {"float": {"values": [float(i) for i in range(nb//2, nb)]}} term_second = {"float": {"values": [float(i) for i in range(nb // 2, nb)]}}
term = update_term_expr({"term": {}}, [term_first, term_second]) term = update_term_expr({"term": {}}, [term_first, term_second])
expr = {"must": [gen_default_vector_expr(default_query), term]} expr = {"must": [gen_default_vector_expr(default_query), term]}
query = update_query_expr(default_query, expr=expr) query = update_query_expr(default_query, expr=expr)
logging.getLogger().info(query)
res = connect.search(collection, query) res = connect.search(collection, query)
assert len(res) == nq assert len(res) == nq
assert len(res[0]) == 0 assert len(res[0]) == 0
...@@ -1325,7 +1328,6 @@ class TestSearchDSL(object): ...@@ -1325,7 +1328,6 @@ class TestSearchDSL(object):
range_second = gen_default_range_expr(field="float", ranges={"GT": nb // 2, "LT": nb}) range_second = gen_default_range_expr(field="float", ranges={"GT": nb // 2, "LT": nb})
expr = {"must": [gen_default_vector_expr(default_query), range_first, range_second]} expr = {"must": [gen_default_vector_expr(default_query), range_first, range_second]}
query = update_query_expr(default_query, expr=expr) query = update_query_expr(default_query, expr=expr)
logging.getLogger().info(query)
res = connect.search(collection, query) res = connect.search(collection, query)
assert len(res) == nq assert len(res) == nq
assert len(res[0]) == 0 assert len(res[0]) == 0
...@@ -1343,7 +1345,6 @@ class TestSearchDSL(object): ...@@ -1343,7 +1345,6 @@ class TestSearchDSL(object):
range = update_range_expr({"range": {}}, [range_first, range_second]) range = update_range_expr({"range": {}}, [range_first, range_second])
expr = {"must": [gen_default_vector_expr(default_query), range]} expr = {"must": [gen_default_vector_expr(default_query), range]}
query = update_query_expr(default_query, expr=expr) query = update_query_expr(default_query, expr=expr)
logging.getLogger().info(query)
res = connect.search(collection, query) res = connect.search(collection, query)
assert len(res) == nq assert len(res) == nq
assert len(res[0]) == 0 assert len(res[0]) == 0
...@@ -1362,7 +1363,7 @@ class TestSearchDSL(object): ...@@ -1362,7 +1363,7 @@ class TestSearchDSL(object):
expected: pass expected: pass
''' '''
term = gen_default_term_expr() term = gen_default_term_expr()
range = gen_default_range_expr(ranges={"GT": -1, "LT": nb//2}) range = gen_default_range_expr(ranges={"GT": -1, "LT": nb // 2})
expr = {"must": [gen_default_vector_expr(default_query), term, range]} expr = {"must": [gen_default_vector_expr(default_query), term, range]}
query = update_query_expr(default_query, expr=expr) query = update_query_expr(default_query, expr=expr)
res = connect.search(collection, query) res = connect.search(collection, query)
...@@ -1383,7 +1384,6 @@ class TestSearchDSL(object): ...@@ -1383,7 +1384,6 @@ class TestSearchDSL(object):
assert len(res) == nq assert len(res) == nq
assert len(res[0]) == 0 assert len(res[0]) == 0
""" """
****************************************************************** ******************************************************************
# The following cases are used to build multi vectors query expr # The following cases are used to build multi vectors query expr
......
...@@ -346,14 +346,6 @@ def gen_invalid_range(): ...@@ -346,14 +346,6 @@ def gen_invalid_range():
return range return range
def gen_invalid_ranges():
ranges = [
{"GT": nb, "LT": 0},
{"GT": "0", "LT": "1000"}
]
return ranges
def gen_valid_ranges(): def gen_valid_ranges():
ranges = [ ranges = [
{"GT": 0, "LT": nb//2}, {"GT": 0, "LT": nb//2},
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册