未验证 提交 ec22185f 编写于 作者: D del-zhenwu 提交者: GitHub

Support output vectors in query (#6636)

* support output vectors in query
Signed-off-by: zhenwu <zhenxiang.li@zilliz.com>

* add query case
Signed-off-by: del-zhenwu <zhenxiang.li@zilliz.com>

* add query case
Signed-off-by: del-zhenwu <zhenxiang.li@zilliz.com>
上级 ac50c5dd
......@@ -78,6 +78,7 @@ class TestQueryBase:
def get_simple_index(self, request, connect):
return request.param
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
def test_query_invalid(self, connect, collection):
"""
target: test query
......@@ -91,6 +92,7 @@ class TestQueryBase:
with pytest.raises(Exception):
res = connect.query(collection, term_expr)
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
def test_query_valid(self, connect, collection):
"""
target: test query
......@@ -101,35 +103,39 @@ class TestQueryBase:
assert len(ids) == ut.default_nb
connect.load_collection(collection)
term_expr = f'{default_int_field_name} in {ids[:default_pos]}'
res = connect.query(collection, term_expr)
res = connect.query(collection, term_expr, output_fields=["*", "%"])
assert len(res) == default_pos
for _id, index in enumerate(ids[:default_pos]):
if res[index][default_int_field_name] == entities[0]["values"][index]:
assert res[index][default_float_field_name] == entities[1]["values"][index]
# not support
# ut.assert_equal_vector(res[index][ut.default_float_vec_field_name], entities[2]["values"][index])
res = connect.query(collection, term_expr, output_fields=[ut.default_float_vec_field_name])
assert len(res) == default_pos
for _id, index in enumerate(ids[:default_pos]):
if res[index][default_int_field_name] == entities[0]["values"][index]:
ut.assert_equal_vector(res[index][ut.default_float_vec_field_name], entities[2]["values"][index])
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
def test_query_collection_not_existed(self, connect):
"""
target: test query not existed collection
method: query not existed collection
expected: raise exception
"""
ex_msg = 'find collection'
collection = "not_exist"
with pytest.raises(Exception, match=ex_msg):
with pytest.raises(Exception):
connect.query(collection, default_term_expr)
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
def test_query_without_connect(self, dis_connect, collection):
"""
target: test query without connection
method: close connect and query
expected: raise exception
"""
ex_msg = 'NoneType'
with pytest.raises(Exception, match=ex_msg):
with pytest.raises(Exception):
dis_connect.query(collection, default_term_expr)
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
def test_query_invalid_collection_name(self, connect, get_collection_name):
"""
target: test query with invalid collection name
......@@ -140,6 +146,7 @@ class TestQueryBase:
with pytest.raises(Exception):
connect.query(collection_name, default_term_expr)
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
def test_query_after_index(self, connect, collection, get_simple_index):
"""
target: test query after creating index
......@@ -151,13 +158,13 @@ class TestQueryBase:
connect.create_index(collection, ut.default_float_vec_field_name, get_simple_index)
connect.load_collection(collection)
term_expr = f'{default_int_field_name} in {ids[:default_pos]}'
res = connect.query(collection, term_expr)
res = connect.query(collection, term_expr, output_fields=["*", "%"])
logging.getLogger().info(res)
assert len(res) == default_pos
for _id, index in enumerate(ids[:default_pos]):
if res[index][default_int_field_name] == entities[0]["values"][index]:
assert res[index][default_float_field_name] == entities[1]["values"][index]
# # ut.assert_equal_vector(res[i][ut.default_float_vec_field_name], entities[-1]["values"][i])
ut.assert_equal_vector(res[index][ut.default_float_vec_field_name], entities[-1]["values"][index])
def test_query_after_search(self, connect, collection):
"""
......@@ -175,15 +182,15 @@ class TestQueryBase:
assert len(search_res) == nq
assert len(search_res[0]) == top_k
term_expr = f'{default_int_field_name} in {ids[:default_pos]}'
res = connect.query(collection, term_expr)
res = connect.query(collection, term_expr, output_fields=["*", "%"])
logging.getLogger().info(res)
assert len(res) == default_pos
for _id, index in enumerate(ids[:default_pos]):
if res[index][default_int_field_name] == entities[0]["values"][index]:
assert res[index][default_float_field_name] == entities[1]["values"][index]
# not support
# ut.assert_equal_vector(res[index][ut.default_float_vec_field_name], entities[2]["values"][index])
ut.assert_equal_vector(res[index][ut.default_float_vec_field_name], entities[2]["values"][index])
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
def test_query_empty_collection(self, connect, collection):
"""
target: test query empty collection
......@@ -195,6 +202,7 @@ class TestQueryBase:
logging.getLogger().info(res)
assert len(res) == 0
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
def test_query_without_loading(self, connect, collection):
"""
target: test query without loading
......@@ -206,6 +214,7 @@ class TestQueryBase:
with pytest.raises(Exception):
connect.query(collection, default_term_expr)
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
def test_query_collection_not_primary_key(self, connect, collection):
"""
target: test query on collection that not on the primary field
......@@ -219,6 +228,7 @@ class TestQueryBase:
with pytest.raises(Exception):
connect.query(collection, term_expr)
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
def test_query_expr_none(self, connect, collection):
"""
target: test query with none expr
......@@ -231,6 +241,7 @@ class TestQueryBase:
with pytest.raises(Exception):
connect.query(collection, None)
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
@pytest.mark.parametrize("expr", [1, "1", "12-s", "中文", [], {}, ()])
def test_query_expr_invalid_string(self, connect, collection, expr):
"""
......@@ -244,6 +255,7 @@ class TestQueryBase:
with pytest.raises(Exception):
connect.query(collection, expr)
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
def test_query_expr_not_existed_field(self, connect, collection):
"""
target: test query with not existed field
......@@ -260,6 +272,7 @@ class TestQueryBase:
@pytest.mark.parametrize("expr", [f'{default_int_field_name} inn [1, 2]',
f'{default_int_field_name} not in [1, 2]',
f'{default_int_field_name} in not [1, 2]'])
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
def test_query_expr_wrong_term_keyword(self, connect, collection, expr):
"""
target: test query with wrong term expr keyword
......@@ -273,6 +286,7 @@ class TestQueryBase:
@pytest.mark.parametrize("expr", [f'{default_int_field_name} in 1',
f'{default_int_field_name} in "in"',
f'{default_int_field_name} in (mn)'])
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
def test_query_expr_non_array_term(self, connect, collection, expr):
"""
target: test query with non-array term expr
......@@ -283,6 +297,7 @@ class TestQueryBase:
with pytest.raises(Exception):
connect.query(collection, expr)
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
def test_query_expr_empty_term_array(self, connect, collection):
"""
target: test query with empty array term expr
......@@ -296,6 +311,7 @@ class TestQueryBase:
res = connect.query(collection, term_expr)
assert len(res) == 0
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
def test_query_expr_single_term_array(self, connect, collection):
"""
target: test query with single array term expr
......@@ -306,14 +322,14 @@ class TestQueryBase:
assert len(ids) == ut.default_nb
connect.load_collection(collection)
term_expr = f'{default_int_field_name} in [0]'
res = connect.query(collection, term_expr)
res = connect.query(collection, term_expr, output_fields=["*", "%"])
assert len(res) == 1
assert res[0][default_int_field_name] == entities[0]["values"][0]
assert res[0][default_float_field_name] == entities[1]["values"][0]
# not support
# ut.assert_equal_vector(res[0][ut.default_float_vec_field_name], entities[2]["values"][0])
ut.assert_equal_vector(res[0][ut.default_float_vec_field_name], entities[2]["values"][0])
@pytest.mark.xfail(reason="#6072")
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
def test_query_binary_expr_single_term_array(self, connect, binary_collection):
"""
target: test query with single array term expr
......@@ -324,14 +340,14 @@ class TestQueryBase:
assert len(ids) == ut.default_nb
connect.load_collection(binary_collection)
term_expr = f'{default_int_field_name} in [0]'
res = connect.query(binary_collection, term_expr)
res = connect.query(binary_collection, term_expr, output_fields=["*", "%"])
assert len(res) == 1
assert res[0][default_int_field_name] == binary_entities[0]["values"][0]
assert res[1][default_float_field_name] == binary_entities[1]["values"][0]
# not support
# assert res[2][ut.default_float_vec_field_name] == binary_entities[2]["values"][0]
assert res[2][ut.default_float_vec_field_name] == binary_entities[2]["values"][0]
@pytest.mark.level(2)
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
def test_query_expr_all_term_array(self, connect, collection):
"""
target: test query with all array term expr
......@@ -342,14 +358,14 @@ class TestQueryBase:
assert len(ids) == ut.default_nb
connect.load_collection(collection)
term_expr = f'{default_int_field_name} in {ids}'
res = connect.query(collection, term_expr)
res = connect.query(collection, term_expr, output_fields=["*", "%"])
assert len(res) == ut.default_nb
for _id, index in enumerate(ids):
if res[index][default_int_field_name] == entities[0]["values"][index]:
assert res[index][default_float_field_name] == entities[1]["values"][index]
# not support
# ut.assert_equal_vector(res[index][ut.default_float_vec_field_name], entities[2]["values"][index])
ut.assert_equal_vector(res[index][ut.default_float_vec_field_name], entities[2]["values"][index])
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
def test_query_expr_repeated_term_array(self, connect, collection):
"""
target: test query with repeated term array on primary field with unique value
......@@ -364,6 +380,7 @@ class TestQueryBase:
res = connect.query(collection, term_expr)
assert len(res) == 2
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
def test_query_expr_inconstant_term_array(self, connect, collection):
"""
target: test query with term expr that field and array are inconsistent
......@@ -377,6 +394,7 @@ class TestQueryBase:
with pytest.raises(Exception):
connect.query(collection, expr)
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
def test_query_expr_mix_term_array(self, connect, collection):
"""
target: test query with mix type value expr
......@@ -391,6 +409,7 @@ class TestQueryBase:
connect.query(collection, expr)
@pytest.mark.parametrize("constant", [[1], (), {}])
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
def test_query_expr_non_constant_array_term(self, connect, collection, constant):
"""
target: test query with non-constant array term expr
......@@ -404,6 +423,7 @@ class TestQueryBase:
with pytest.raises(Exception):
connect.query(collection, expr)
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
def test_query_output_field_empty(self, connect, collection):
"""
target: test query with none output field
......@@ -414,11 +434,11 @@ class TestQueryBase:
assert len(ids) == ut.default_nb
connect.load_collection(collection)
res = connect.query(collection, default_term_expr, output_fields=[])
# not support float_vector
fields = [default_int_field_name, default_float_field_name]
for field in fields:
assert field in res[0].keys()
assert default_int_field_name in res[0].keys()
assert default_float_field_name not in res[0].keys()
assert ut.default_float_vec_field_name not in res[0].keys()
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
def test_query_output_one_field(self, connect, collection):
"""
target: test query with output one field
......@@ -432,6 +452,7 @@ class TestQueryBase:
assert default_int_field_name in res[0].keys()
assert len(res[0].keys()) == 1
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
def test_query_output_all_fields(self, connect, collection):
"""
target: test query with none output field
......@@ -447,6 +468,7 @@ class TestQueryBase:
for field in fields:
assert field in res[0].keys()
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
def test_query_output_not_existed_field(self, connect, collection):
"""
target: test query output not existed field
......@@ -459,6 +481,7 @@ class TestQueryBase:
connect.query(collection, default_term_expr, output_fields=["int"])
# @pytest.mark.xfail(reason="#6074")
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
def test_query_output_part_not_existed_field(self, connect, collection):
"""
target: test query output part not existed field
......@@ -471,6 +494,7 @@ class TestQueryBase:
connect.query(collection, default_term_expr, output_fields=[default_int_field_name, "int"])
@pytest.mark.parametrize("fields", ut.gen_invalid_strs())
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
def test_query_invalid_output_fields(self, connect, collection, fields):
"""
target: test query with invalid output fields
......@@ -488,6 +512,7 @@ class TestQueryPartition:
test Query interface
query(collection_name, expr, output_fields=None, partition_names=None, timeout=None)
"""
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
def test_query_partition(self, connect, collection):
"""
target: test query on partition
......@@ -498,13 +523,13 @@ class TestQueryPartition:
entities, ids = init_data(connect, collection, partition_names=ut.default_tag)
assert len(ids) == ut.default_nb
connect.load_partitions(collection, [ut.default_tag])
res = connect.query(collection, default_term_expr, partition_names=[ut.default_tag])
res = connect.query(collection, default_term_expr, partition_names=[ut.default_tag], output_fields=["*", "%"])
for _id, index in enumerate(ids[:default_pos]):
if res[index][default_int_field_name] == entities[0]["values"][index]:
assert res[index][default_float_field_name] == entities[1]["values"][index]
# not support
# ut.assert_equal_vector(res[index][ut.default_float_vec_field_name], entities[2]["values"][index])
ut.assert_equal_vector(res[index][ut.default_float_vec_field_name], entities[2]["values"][index])
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
def test_query_partition_without_loading(self, connect, collection):
"""
target: test query on partition without loading
......@@ -517,6 +542,7 @@ class TestQueryPartition:
with pytest.raises(Exception):
connect.query(collection, default_term_expr, partition_names=[ut.default_tag])
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
def test_query_default_partition(self, connect, collection):
"""
target: test query on default partition
......@@ -526,12 +552,11 @@ class TestQueryPartition:
entities, ids = init_data(connect, collection)
assert len(ids) == ut.default_nb
connect.load_collection(collection)
res = connect.query(collection, default_term_expr, partition_names=[ut.default_partition_name])
res = connect.query(collection, default_term_expr, partition_names=[ut.default_partition_name], output_fields=["*", "%"])
for _id, index in enumerate(ids[:default_pos]):
if res[index][default_int_field_name] == entities[0]["values"][index]:
assert res[index][default_float_field_name] == entities[1]["values"][index]
# not support
# ut.assert_equal_vector(res[index][ut.default_float_vec_field_name], entities[2]["values"][index])
ut.assert_equal_vector(res[index][ut.default_float_vec_field_name], entities[2]["values"][index])
@pytest.mark.xfail(reason="#6075")
def test_query_empty_partition(self, connect, collection):
......@@ -545,6 +570,7 @@ class TestQueryPartition:
res = connect.query(collection, default_term_expr, partition_names=[ut.default_partition_name])
assert len(res) == 0
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
def test_query_not_existed_partition(self, connect, collection):
"""
target: test query on a not existed partition
......@@ -556,6 +582,7 @@ class TestQueryPartition:
with pytest.raises(Exception):
connect.query(collection, default_term_expr, partition_names=[tag])
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
def test_query_partition_repeatedly(self, connect, collection):
"""
target: test query repeatedly on partition
......@@ -570,6 +597,7 @@ class TestQueryPartition:
res_two = connect.query(collection, default_term_expr, partition_names=[ut.default_tag])
assert res_one == res_two
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
def test_query_another_partition(self, connect, collection):
"""
target: test query another partition
......@@ -583,6 +611,7 @@ class TestQueryPartition:
res = connect.query(collection, term_expr, partition_names=[ut.default_tag])
assert len(res) == 0
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
def test_query_multi_partitions_multi_results(self, connect, collection):
"""
target: test query on multi partitions and get multi results
......@@ -601,6 +630,7 @@ class TestQueryPartition:
assert len(res) == 1
assert res[0][default_int_field_name] == entities_2[0]["values"][0]
@pytest.mark.tags(ut.CaseLabel.tags_smoke)
def test_query_multi_partitions_single_result(self, connect, collection):
"""
target: test query on multi partitions and get single result
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册