diff --git a/tests/milvus_python_test/entity/test_search.py b/tests/milvus_python_test/entity/test_search.py index 1890b054286a2868043bc2203f91ad6087cc93bb..4a91ee861b1b1b597a58b9e77c96f5d6edff4d8e 100644 --- a/tests/milvus_python_test/entity/test_search.py +++ b/tests/milvus_python_test/entity/test_search.py @@ -225,7 +225,8 @@ class TestSearchBase: entities, ids = init_data(connect, collection) connect.create_index(collection, field_name, get_simple_index) search_param = get_search_param(index_type) - query, vecs = gen_query_vectors(field_name, entities, top_k, nq, metric_type=search_metric_type, search_params=search_param) + query, vecs = gen_query_vectors(field_name, entities, top_k, nq, metric_type=search_metric_type, + search_params=search_param) res = connect.search(collection, query) assert len(res) == nq assert len(res[0]) == top_k @@ -376,9 +377,9 @@ class TestSearchBase: assert res[0]._distances[0] > epsilon assert res[1]._distances[0] < epsilon - # + # # test for ip metric - # + # @pytest.mark.level(2) def test_search_ip_flat(self, connect, collection, get_simple_index, get_top_k, get_nq): ''' @@ -658,7 +659,7 @@ class TestSearchBase: def _test_search_distance_substructure_flat_index_B(self, connect, binary_collection): ''' target: search binary_collection, and check the result: distance - method: compare the return distance value with value computed with SUB + method: compare the return distance value with value computed with SUB expected: the return distance equals to the computed value ''' # from scipy.spatial import distance @@ -1108,7 +1109,6 @@ class TestSearchDSL(object): term["term"].update({"a": [0]}) expr = {"must": [gen_default_vector_expr(default_query), term]} query = update_query_expr(default_query, expr=expr) - logging.getLogger().info(query) with pytest.raises(Exception) as e: res = connect.search(collection, query) @@ -1151,28 +1151,34 @@ class TestSearchDSL(object): with pytest.raises(Exception) as e: res = connect.search(collection, query) - @pytest.fixture( - scope="function", - params=gen_invalid_ranges() - ) - def get_invalid_ranges(self, request): - return request.param - @pytest.mark.level(2) - def test_query_range_invalid_ranges(self, connect, collection, get_invalid_ranges): + def test_query_range_string_ranges(self, connect, collection): ''' method: build query with invalid ranges expected: raise Exception ''' entities, ids = init_data(connect, collection) - ranges = get_invalid_ranges + ranges = {"GT": "0", "LT": "1000"} range = gen_default_range_expr(ranges=ranges) expr = {"must": [gen_default_vector_expr(default_query), range]} query = update_query_expr(default_query, expr=expr) - logging.getLogger().info(query) with pytest.raises(Exception) as e: res = connect.search(collection, query) + @pytest.mark.level(2) + def test_query_range_invalid_ranges(self, connect, collection): + ''' + method: build query with invalid ranges + expected: 0 + ''' + entities, ids = init_data(connect, collection) + ranges = {"GT": nb, "LT": 0} + range = gen_default_range_expr(ranges=ranges) + expr = {"must": [gen_default_vector_expr(default_query), range]} + query = update_query_expr(default_query, expr=expr) + res = connect.search(collection, query) + assert len(res[0]) == 0 + @pytest.fixture( scope="function", params=gen_valid_ranges() @@ -1180,7 +1186,6 @@ class TestSearchDSL(object): def get_valid_ranges(self, request): return request.param - # TODO: @pytest.mark.level(2) def test_query_range_valid_ranges(self, connect, collection, get_valid_ranges): ''' @@ -1192,7 +1197,6 @@ class TestSearchDSL(object): range = gen_default_range_expr(ranges=ranges) expr = {"must": [gen_default_vector_expr(default_query), range]} query = update_query_expr(default_query, expr=expr) - logging.getLogger().info(query) res = connect.search(collection, query) assert len(res) == nq assert len(res[0]) == top_k @@ -1256,7 +1260,7 @@ class TestSearchDSL(object): ''' entities, ids = init_data(connect, collection) term_first = gen_default_term_expr() - term_second = gen_default_term_expr(field="float", values=[float(i) for i in range(nb//2, nb)]) + term_second = gen_default_term_expr(field="float", values=[float(i) for i in range(nb // 2, nb)]) expr = {"must": [gen_default_vector_expr(default_query), term_first, term_second]} query = update_query_expr(default_query, expr=expr) res = connect.search(collection, query) @@ -1271,12 +1275,11 @@ class TestSearchDSL(object): expected: pass ''' entities, ids = init_data(connect, collection) - term_first = {"int64": {"values": [i for i in range(nb//2)]}} - term_second = {"float": {"values": [float(i) for i in range(nb//2, nb)]}} + term_first = {"int64": {"values": [i for i in range(nb // 2)]}} + term_second = {"float": {"values": [float(i) for i in range(nb // 2, nb)]}} term = update_term_expr({"term": {}}, [term_first, term_second]) expr = {"must": [gen_default_vector_expr(default_query), term]} query = update_query_expr(default_query, expr=expr) - logging.getLogger().info(query) res = connect.search(collection, query) assert len(res) == nq assert len(res[0]) == 0 @@ -1325,7 +1328,6 @@ class TestSearchDSL(object): range_second = gen_default_range_expr(field="float", ranges={"GT": nb // 2, "LT": nb}) expr = {"must": [gen_default_vector_expr(default_query), range_first, range_second]} query = update_query_expr(default_query, expr=expr) - logging.getLogger().info(query) res = connect.search(collection, query) assert len(res) == nq assert len(res[0]) == 0 @@ -1343,7 +1345,6 @@ class TestSearchDSL(object): range = update_range_expr({"range": {}}, [range_first, range_second]) expr = {"must": [gen_default_vector_expr(default_query), range]} query = update_query_expr(default_query, expr=expr) - logging.getLogger().info(query) res = connect.search(collection, query) assert len(res) == nq assert len(res[0]) == 0 @@ -1362,7 +1363,7 @@ class TestSearchDSL(object): expected: pass ''' term = gen_default_term_expr() - range = gen_default_range_expr(ranges={"GT": -1, "LT": nb//2}) + range = gen_default_range_expr(ranges={"GT": -1, "LT": nb // 2}) expr = {"must": [gen_default_vector_expr(default_query), term, range]} query = update_query_expr(default_query, expr=expr) res = connect.search(collection, query) @@ -1383,7 +1384,6 @@ class TestSearchDSL(object): assert len(res) == nq assert len(res[0]) == 0 - """ ****************************************************************** # The following cases are used to build multi vectors query expr diff --git a/tests/milvus_python_test/utils.py b/tests/milvus_python_test/utils.py index 1701ba7cd3755473e1b0725ab6273f3ae534c989..0b26a72b8e4674734d8229b8f4f64b7d0e3f01ac 100644 --- a/tests/milvus_python_test/utils.py +++ b/tests/milvus_python_test/utils.py @@ -346,14 +346,6 @@ def gen_invalid_range(): return range -def gen_invalid_ranges(): - ranges = [ - {"GT": nb, "LT": 0}, - {"GT": "0", "LT": "1000"} - ] - return ranges - - def gen_valid_ranges(): ranges = [ {"GT": 0, "LT": nb//2},