未验证 提交 5543dba8 编写于 作者: T ThreadDao 提交者: GitHub

Update query cases and query result check (#6146)

* update query cases

check query result
Signed-off-by: NThreadDao <yufen.zong@zilliz.com>

* [skip ci] skip ci
Signed-off-by: NThreadDao <yufen.zong@zilliz.com>
上级 6d7edc60
...@@ -186,25 +186,7 @@ class ResponseChecker: ...@@ -186,25 +186,7 @@ class ResponseChecker:
raise Exception("No expect values found in the check task") raise Exception("No expect values found in the check task")
exp_res = check_items.get("exp_res", None) exp_res = check_items.get("exp_res", None)
if exp_res and isinstance(query_res, list): if exp_res and isinstance(query_res, list):
# assert exp_res == query_res assert pc.equal_entities_list(exp=exp_res, actual=query_res)
assert len(exp_res) == len(query_res) # assert len(exp_res) == len(query_res)
for i in range(len(exp_res)): # for i in range(len(exp_res)):
assert_entity_equal(exp=exp_res[i], actual=query_res[i]) # assert_entity_equal(exp=exp_res[i], actual=query_res[i])
def assert_entity_equal(exp, actual):
    """
    Assert that two entities are equal, field by field.

    Scalar fields must compare equal; list-valued fields (e.g. float vectors)
    must have the same length and every pair of elements must differ by less
    than ct.epsilon.
    example entity:
        {"int64": 0, "float": 0.0, "float_vec": [0.09111554112502457, ..., 0.08652634258062468]}
    :param exp: exp entity, dict of field name -> value
    :param actual: actual entity, dict of field name -> value
    :return: True when every assertion holds
    :raises AssertionError: when field names, list lengths or any value differ
    """
    assert actual.keys() == exp.keys()
    for field, exp_value in exp.items():
        actual_value = actual[field]
        if isinstance(exp_value, list):
            assert len(actual_value) == len(exp_value)
            # Floats are compared with a tolerance, not exact equality
            for actual_item, exp_item in zip(actual_value, exp_value):
                assert abs(actual_item - exp_item) < ct.epsilon
        else:
            assert actual_value == exp_value
    # Return the bool the contract promises so `assert assert_entity_equal(...)` works
    return True
import pytest
import sys import sys
import operator import operator
from common import common_type as ct
sys.path.append("..") sys.path.append("..")
from utils.util_log import test_log as log from utils.util_log import test_log as log
...@@ -63,7 +63,8 @@ def list_de_duplication(_list): ...@@ -63,7 +63,8 @@ def list_de_duplication(_list):
# Keep the order of the elements unchanged # Keep the order of the elements unchanged
result.sort(key=_list.index) result.sort(key=_list.index)
log.debug("[LIST_DE_DUPLICATION] %s after removing the duplicate elements, the list becomes %s" % (str(_list), str(result))) log.debug("[LIST_DE_DUPLICATION] %s after removing the duplicate elements, the list becomes %s" % (
str(_list), str(result)))
return result return result
...@@ -116,3 +117,86 @@ def get_connect_object_name(_list): ...@@ -116,3 +117,86 @@ def get_connect_object_name(_list):
log.debug("[GET_CONNECT_OBJECT_NAME] list:%s is reset to list:%s" % (str(_list), str(new_list))) log.debug("[GET_CONNECT_OBJECT_NAME] list:%s is reset to list:%s" % (str(_list), str(new_list)))
return new_list return new_list
def equal_entity(exp, actual):
    """
    Compare two entities, including entities that contain a vector field.

    Scalar fields must compare equal; list-valued fields (e.g. float vectors)
    must have the same length and every pair of elements must differ by less
    than ct.epsilon.
    example entity:
        {"int64": 0, "float": 0.0, "float_vec": [0.09111554112502457, ..., 0.08652634258062468]}
    :param exp: exp entity, dict of field name -> value
    :param actual: actual entity, dict of field name -> value
    :return: True when the entities match
    :raises AssertionError: when field names, list lengths or any value differ
    """
    assert actual.keys() == exp.keys()
    for field, exp_value in exp.items():
        actual_value = actual[field]
        if isinstance(exp_value, list):
            assert len(actual_value) == len(exp_value)
            # Floats are compared with a tolerance, not exact equality
            for actual_item, exp_item in zip(actual_value, exp_value):
                assert abs(actual_item - exp_item) < ct.epsilon
        else:
            assert actual_value == exp_value
    # Return the documented bool so callers that use the result see a match
    return True
def entity_in(entity, entities, primary_field=ct.default_int64_field_name):
    """
    Judge by primary key whether an entity is contained in an entities list.

    The entity with the same primary key is located first, then the two
    entities are compared with equal_entity.
    :param entity: dict
        {"int": 0, "vec": [0.999999, 0.111111]}
    :param entities: list of dict
        [{"int": 0, "vec": [0.999999, 0.111111]}, {"int": 1, "vec": [0.888888, 0.222222]}]
    :param primary_field: collection primary field
    :return: True or False
    """
    primary_key = entity.get(primary_field, None)
    primary_keys = [e[primary_field] for e in entities]
    if primary_key not in primary_keys:
        return False
    # Search the collected keys; calling .index on the key itself was a bug
    index = primary_keys.index(primary_key)
    # equal_entity reports a mismatch through a failing assert, so translate
    # that into the documented False instead of relying on its return value
    try:
        equal_entity(entities[index], entity)
    except AssertionError:
        return False
    return True
def remove_entity(entity, entities, primary_field=ct.default_int64_field_name):
    """
    Remove from an entities list the entity that shares a primary key with `entity`.

    The list is modified in place and the same list object is returned.
    :param entity: dict
        {"int": 0, "vec": [0.999999, 0.111111]}
    :param entities: list of dict
        [{"int": 0, "vec": [0.999999, 0.111111]}, {"int": 1, "vec": [0.888888, 0.222222]}]
    :param primary_field: collection primary field
    :return: the entities list after the matching entity has been removed
    """
    target_key = entity.get(primary_field, None)
    all_keys = [item[primary_field] for item in entities]
    # list.index raises ValueError when no entity carries the target key
    position = all_keys.index(target_key)
    del entities[position]
    return entities
def equal_entities_list(exp, actual):
    """
    Compare two entities lists regardless of order.

    Each entity in `actual` must match a distinct entity in `exp` by exact
    dict equality. Neither input list is modified.
    :param exp: exp entities list, list of dict
    :param actual: actual entities list, list of dict
    :return: True or False
    example:
        exp = [{"int": 0, "vec": [0.999999, 0.111111]}, {"int": 1, "vec": [0.888888, 0.222222]}]
        actual = [{"int": 1, "vec": [0.888888, 0.222222]}, {"int": 0, "vec": [0.999999, 0.111111]}]
        exp = actual
    """
    if len(exp) != len(actual):
        return False
    # Work on a shallow copy so the caller's expected data survives the check
    remaining = list(exp)
    for a in actual:
        # NOTE(review): exact equality is used here; if float vector fields are
        # returned in the query result a tolerance-based match (entity_in /
        # remove_entity) is presumably needed instead - confirm
        if a not in remaining:
            return False
        # Consume the match so duplicates must be matched one-to-one
        remaining.remove(a)
    return not remaining
...@@ -31,11 +31,10 @@ class TestQueryBase(TestcaseBase): ...@@ -31,11 +31,10 @@ class TestQueryBase(TestcaseBase):
int_values = vectors[0][ct.default_int64_field_name].values.tolist() int_values = vectors[0][ct.default_int64_field_name].values.tolist()
pos = 5 pos = 5
term_expr = f'{ct.default_int64_field_name} in {int_values[:pos]}' term_expr = f'{ct.default_int64_field_name} in {int_values[:pos]}'
res = vectors[0].iloc[0:pos].to_dict('records') res = vectors[0].iloc[0:pos, :2].to_dict('records')
collection_w.query(term_expr, check_task=CheckTasks.check_query_results, check_items={exp_res: res}) collection_w.query(term_expr, check_task=CheckTasks.check_query_results, check_items={exp_res: res})
@pytest.mark.tags(CaseLabel.L1) @pytest.mark.tags(CaseLabel.L1)
@pytest.mark.xfail(reason="#6028")
def test_query_empty_collection(self): def test_query_empty_collection(self):
""" """
target: test query empty collection target: test query empty collection
...@@ -85,7 +84,6 @@ class TestQueryBase(TestcaseBase): ...@@ -85,7 +84,6 @@ class TestQueryBase(TestcaseBase):
assert len(res) == 0 assert len(res) == 0
@pytest.mark.tags(CaseLabel.L1) @pytest.mark.tags(CaseLabel.L1)
@pytest.mark.xfail(reason="#6033")
def test_query_expr_none(self): def test_query_expr_none(self):
""" """
target: test query with none expr target: test query with none expr
...@@ -93,11 +91,10 @@ class TestQueryBase(TestcaseBase): ...@@ -93,11 +91,10 @@ class TestQueryBase(TestcaseBase):
expected: raise exception expected: raise exception
""" """
collection_w, vectors, _, = self.init_collection_general(prefix, insert_data=True) collection_w, vectors, _, = self.init_collection_general(prefix, insert_data=True)
error = {ct.err_code: 1, ct.err_msg: "invalid expr"} error = {ct.err_code: 0, ct.err_msg: "The type of expr must be string"}
collection_w.query(None, check_task=CheckTasks.err_res, check_items=error) collection_w.query(None, check_task=CheckTasks.err_res, check_items=error)
@pytest.mark.tags(CaseLabel.L1) @pytest.mark.tags(CaseLabel.L1)
@pytest.mark.xfail(reason="#6044")
@pytest.mark.parametrize("expr", [1, 2., [], {}, ()]) @pytest.mark.parametrize("expr", [1, 2., [], {}, ()])
def test_query_expr_non_string(self, expr): def test_query_expr_non_string(self, expr):
""" """
...@@ -106,7 +103,7 @@ class TestQueryBase(TestcaseBase): ...@@ -106,7 +103,7 @@ class TestQueryBase(TestcaseBase):
expected: raise exception expected: raise exception
""" """
collection_w, vectors, _, = self.init_collection_general(prefix, insert_data=True) collection_w, vectors, _, = self.init_collection_general(prefix, insert_data=True)
error = {ct.err_code: 1, ct.err_msg: "expr must string type"} error = {ct.err_code: 0, ct.err_msg: "The type of expr must be string"}
collection_w.query(expr, check_task=CheckTasks.err_res, check_items=error) collection_w.query(expr, check_task=CheckTasks.err_res, check_items=error)
@pytest.mark.tags(CaseLabel.L1) @pytest.mark.tags(CaseLabel.L1)
...@@ -130,7 +127,7 @@ class TestQueryBase(TestcaseBase): ...@@ -130,7 +127,7 @@ class TestQueryBase(TestcaseBase):
expected: query result is correct expected: query result is correct
""" """
collection_w, vectors, _, = self.init_collection_general(prefix, insert_data=True) collection_w, vectors, _, = self.init_collection_general(prefix, insert_data=True)
res = vectors[0].iloc[0:2].to_dict('records') res = vectors[0].iloc[:2, :2].to_dict('records')
collection_w.query(default_term_expr, check_task=CheckTasks.check_query_results, check_items={exp_res: res}) collection_w.query(default_term_expr, check_task=CheckTasks.check_query_results, check_items={exp_res: res})
@pytest.mark.tags(CaseLabel.L1) @pytest.mark.tags(CaseLabel.L1)
...@@ -214,7 +211,7 @@ class TestQueryBase(TestcaseBase): ...@@ -214,7 +211,7 @@ class TestQueryBase(TestcaseBase):
""" """
target: test query with empty array term expr target: test query with empty array term expr
method: query with empty term expr method: query with empty term expr
expected: empty rsult expected: empty result
""" """
term_expr = f'{ct.default_int64_field_name} in []' term_expr = f'{ct.default_int64_field_name} in []'
collection_w, vectors, _, = self.init_collection_general(prefix, insert_data=True) collection_w, vectors, _, = self.init_collection_general(prefix, insert_data=True)
...@@ -270,8 +267,8 @@ class TestQueryBase(TestcaseBase): ...@@ -270,8 +267,8 @@ class TestQueryBase(TestcaseBase):
""" """
collection_w, vectors, _, = self.init_collection_general(prefix, insert_data=True) collection_w, vectors, _, = self.init_collection_general(prefix, insert_data=True)
res, _ = collection_w.query(default_term_expr, output_fields=None) res, _ = collection_w.query(default_term_expr, output_fields=None)
fields = [ct.default_int64_field_name, ct.default_float_field_name, ct.default_float_vec_field_name] fields = [ct.default_int64_field_name, ct.default_float_field_name]
assert list(res[0].keys()) == fields assert set(res[0].keys()) == set(fields)
@pytest.mark.tags(CaseLabel.L0) @pytest.mark.tags(CaseLabel.L0)
def test_query_output_one_field(self): def test_query_output_one_field(self):
...@@ -282,7 +279,7 @@ class TestQueryBase(TestcaseBase): ...@@ -282,7 +279,7 @@ class TestQueryBase(TestcaseBase):
""" """
collection_w, vectors, _, = self.init_collection_general(prefix, insert_data=True) collection_w, vectors, _, = self.init_collection_general(prefix, insert_data=True)
res, _ = collection_w.query(default_term_expr, output_fields=[ct.default_int64_field_name]) res, _ = collection_w.query(default_term_expr, output_fields=[ct.default_int64_field_name])
assert list(res[0].keys()) == [ct.default_int64_field_name] assert set(res[0].keys()) == set([ct.default_int64_field_name])
@pytest.mark.tags(CaseLabel.L1) @pytest.mark.tags(CaseLabel.L1)
def test_query_output_all_fields(self): def test_query_output_all_fields(self):
...@@ -292,36 +289,51 @@ class TestQueryBase(TestcaseBase): ...@@ -292,36 +289,51 @@ class TestQueryBase(TestcaseBase):
expected: return all fields expected: return all fields
""" """
collection_w, vectors, _, = self.init_collection_general(prefix, insert_data=True) collection_w, vectors, _, = self.init_collection_general(prefix, insert_data=True)
fields = [ct.default_int64_field_name, ct.default_float_field_name, ct.default_float_vec_field_name] fields = [ct.default_int64_field_name, ct.default_float_field_name]
log.debug(collection_w.num_entities)
res, _ = collection_w.query(default_term_expr, output_fields=fields) res, _ = collection_w.query(default_term_expr, output_fields=fields)
log.debug(res) assert set(res[0].keys()) == set(fields)
assert list(res[0].keys()) == fields res_1, _ = collection_w.query(default_term_expr, output_fields=[ct.default_float_field_name])
assert set(res_1[0].keys()) == set(fields)
@pytest.mark.tags(CaseLabel.L1) @pytest.mark.tags(CaseLabel.L1)
@pytest.mark.xfail(reason="issue #6049") @pytest.mark.xfail(reason="issue #6143")
def test_query_output_not_existed_field(self): @pytest.mark.parametrize("output_fields", [[ct.default_float_vec_field_name],
[ct.default_int64_field_name, ct.default_float_vec_field_name]])
def test_query_output_vec_field(self, output_fields):
""" """
target: test query output not existed field target: test query with vec output field
method: query with not existed output field method: specify vec field as output field
expected: raise exception expected: raise exception
""" """
collection_w, vectors, _, = self.init_collection_general(prefix, insert_data=True) collection_w, vectors, _, = self.init_collection_general(prefix, insert_data=True)
error = {ct.err_code: 1, ct.err_msg: 'cannot find field'} error = {ct.err_code: 1, ct.err_msg: "unsupported leaf node"}
collection_w.query(default_term_expr, output_fields=["int"], check_items=CheckTasks.err_res, check_task=error) collection_w.query(default_term_expr, output_fields=output_fields,
check_task=CheckTasks.err_res, check_items=error)
@pytest.mark.tags(CaseLabel.L1) @pytest.mark.tags(CaseLabel.L1)
@pytest.mark.xfail(reason="issue #6049") def test_query_output_primary_field(self):
def test_query_output_part_not_existed_field(self):
""" """
target: test query output part not existed field target: test query with output field only primary field
method: query with part not existed field method: specify int64 primary field as output field
expected: return int64 field
"""
collection_w, vectors, _, = self.init_collection_general(prefix, insert_data=True)
res, _ = collection_w.query(default_term_expr, output_fields=[ct.default_int64_field_name])
assert list(res[0].keys()) == [ct.default_int64_field_name]
@pytest.mark.tags(CaseLabel.L1)
@pytest.mark.xfail(reason="issue #6074")
@pytest.mark.parametrize("output_fields", [["int"],
[ct.default_int64_field_name, "int"]])
def test_query_output_not_existed_field(self, output_fields):
"""
target: test query output not existed field
method: query with not existed output field
expected: raise exception expected: raise exception
""" """
collection_w, vectors, _, = self.init_collection_general(prefix, insert_data=True) collection_w, vectors, _, = self.init_collection_general(prefix, insert_data=True)
error = {ct.err_code: 1, ct.err_msg: 'cannot find field'} error = {ct.err_code: 1, ct.err_msg: 'cannot find field'}
fields = [ct.default_int64_field_name, "int"] collection_w.query(default_term_expr, output_fields=output_fields, check_items=CheckTasks.err_res, check_task=error)
collection_w.query(default_term_expr, output_fields=fields, check_items=CheckTasks.err_res, check_task=error)
@pytest.mark.tags(CaseLabel.L1) @pytest.mark.tags(CaseLabel.L1)
def test_query_empty_output_fields(self): def test_query_empty_output_fields(self):
...@@ -332,10 +344,9 @@ class TestQueryBase(TestcaseBase): ...@@ -332,10 +344,9 @@ class TestQueryBase(TestcaseBase):
""" """
collection_w, vectors, _, = self.init_collection_general(prefix, insert_data=True) collection_w, vectors, _, = self.init_collection_general(prefix, insert_data=True)
query_res, _ = collection_w.query(default_term_expr, output_fields=[]) query_res, _ = collection_w.query(default_term_expr, output_fields=[])
fields = [ct.default_int64_field_name, ct.default_float_field_name, ct.default_float_vec_field_name] fields = [ct.default_int64_field_name, ct.default_float_field_name]
assert list(query_res[0].keys()) == fields assert list(query_res[0].keys()) == fields
@pytest.mark.xfail(reason="issue #6056")
@pytest.mark.tags(CaseLabel.L1) @pytest.mark.tags(CaseLabel.L1)
@pytest.mark.parametrize("fields", ct.get_invalid_strs) @pytest.mark.parametrize("fields", ct.get_invalid_strs)
def test_query_invalid_output_fields(self, fields): def test_query_invalid_output_fields(self, fields):
...@@ -345,7 +356,7 @@ class TestQueryBase(TestcaseBase): ...@@ -345,7 +356,7 @@ class TestQueryBase(TestcaseBase):
expected: raise exception expected: raise exception
""" """
collection_w, vectors, _, = self.init_collection_general(prefix, insert_data=True) collection_w, vectors, _, = self.init_collection_general(prefix, insert_data=True)
error = {ct.err_code: 1, ct.err_msg: 'invalid output fields'} error = {ct.err_code: 1, ct.err_msg: 'Invalid query format. must be a list'}
collection_w.query(default_term_expr, output_fields=fields, check_items=CheckTasks.err_res, check_task=error) collection_w.query(default_term_expr, output_fields=fields, check_items=CheckTasks.err_res, check_task=error)
@pytest.mark.tags(CaseLabel.L0) @pytest.mark.tags(CaseLabel.L0)
...@@ -361,12 +372,11 @@ class TestQueryBase(TestcaseBase): ...@@ -361,12 +372,11 @@ class TestQueryBase(TestcaseBase):
partition_w.insert(df) partition_w.insert(df)
assert collection_w.num_entities == ct.default_nb assert collection_w.num_entities == ct.default_nb
partition_w.load() partition_w.load()
res = df.iloc[0:2].to_dict('records') res = df.iloc[:2, :2].to_dict('records')
collection_w.query(default_term_expr, partition_names=[partition_w.name], collection_w.query(default_term_expr, partition_names=[partition_w.name],
check_task=CheckTasks.check_query_results, check_items={exp_res: res}) check_task=CheckTasks.check_query_results, check_items={exp_res: res})
@pytest.mark.tags(CaseLabel.L1) @pytest.mark.tags(CaseLabel.L1)
@pytest.mark.xfail(reason="issue #6059")
def test_query_partition_without_loading(self): def test_query_partition_without_loading(self):
""" """
target: test query on partition without loading target: test query on partition without loading
...@@ -378,7 +388,7 @@ class TestQueryBase(TestcaseBase): ...@@ -378,7 +388,7 @@ class TestQueryBase(TestcaseBase):
df = cf.gen_default_dataframe_data(ct.default_nb) df = cf.gen_default_dataframe_data(ct.default_nb)
partition_w.insert(df) partition_w.insert(df)
assert partition_w.num_entities == ct.default_nb assert partition_w.num_entities == ct.default_nb
error = {ct.err_code: 1, ct.err_msg: 'cannot find collection'} error = {ct.err_code: 1, ct.err_msg: f'collection {collection_w.name} was not loaded into memory'}
collection_w.query(default_term_expr, partition_names=[partition_w.name], collection_w.query(default_term_expr, partition_names=[partition_w.name],
check_items=CheckTasks.err_res, check_task=error) check_items=CheckTasks.err_res, check_task=error)
...@@ -390,12 +400,11 @@ class TestQueryBase(TestcaseBase): ...@@ -390,12 +400,11 @@ class TestQueryBase(TestcaseBase):
expected: verify query result expected: verify query result
""" """
collection_w, vectors, _, = self.init_collection_general(prefix, insert_data=True) collection_w, vectors, _, = self.init_collection_general(prefix, insert_data=True)
res = vectors[0].iloc[0:2].to_dict('records') res = vectors[0].iloc[:2, :2].to_dict('records')
collection_w.query(default_term_expr, partition_names=[ct.default_partition_name], collection_w.query(default_term_expr, partition_names=[ct.default_partition_name],
check_task=CheckTasks.check_query_results, check_items={exp_res: res}) check_task=CheckTasks.check_query_results, check_items={exp_res: res})
@pytest.mark.tags(CaseLabel.L1) @pytest.mark.tags(CaseLabel.L1)
@pytest.mark.xfail(reason="issue #6028")
def test_query_empty_partition(self): def test_query_empty_partition(self):
""" """
target: test query on empty partition target: test query on empty partition
...@@ -410,7 +419,6 @@ class TestQueryBase(TestcaseBase): ...@@ -410,7 +419,6 @@ class TestQueryBase(TestcaseBase):
assert len(res) == 0 assert len(res) == 0
@pytest.mark.tags(CaseLabel.L1) @pytest.mark.tags(CaseLabel.L1)
@pytest.mark.xfail(reason="issue #6061")
def test_query_not_existed_partition(self): def test_query_not_existed_partition(self):
""" """
target: test query on a not existed partition target: test query on a not existed partition
...@@ -420,7 +428,7 @@ class TestQueryBase(TestcaseBase): ...@@ -420,7 +428,7 @@ class TestQueryBase(TestcaseBase):
collection_w = self.init_collection_wrap(cf.gen_unique_str(prefix)) collection_w = self.init_collection_wrap(cf.gen_unique_str(prefix))
collection_w.load() collection_w.load()
partition_names = cf.gen_unique_str() partition_names = cf.gen_unique_str()
error = {ct.err_code: 1, ct.err_msg: 'cannot find partition'} error = {ct.err_code: 1, ct.err_msg: f'PartitonName: {partition_names} not found'}
collection_w.query(default_term_expr, partition_names=[partition_names], collection_w.query(default_term_expr, partition_names=[partition_names],
check_items=CheckTasks.err_res, check_task=error) check_items=CheckTasks.err_res, check_task=error)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册