未验证 提交 5543dba8 编写于 作者: T ThreadDao 提交者: GitHub

Update query cases and query result check (#6146)

* update query cases

check query result
Signed-off-by: NThreadDao <yufen.zong@zilliz.com>

* [skip ci] skip ci
Signed-off-by: NThreadDao <yufen.zong@zilliz.com>
上级 6d7edc60
......@@ -186,25 +186,7 @@ class ResponseChecker:
raise Exception("No expect values found in the check task")
exp_res = check_items.get("exp_res", None)
if exp_res and isinstance(query_res, list):
# assert exp_res == query_res
assert len(exp_res) == len(query_res)
for i in range(len(exp_res)):
assert_entity_equal(exp=exp_res[i], actual=query_res[i])
def assert_entity_equal(exp, actual):
"""
compare two entities
{"int64": 0, "float": 0.0, "float_vec": [0.09111554112502457, ..., 0.08652634258062468]}
:param exp: exp entity
:param actual: actual entity
:return: bool
"""
assert actual.keys() == exp.keys()
for field, value in exp.items():
if isinstance(value, list):
assert len(actual[field]) == len(exp[field])
for i in range(len(exp[field])):
assert abs(actual[field][i] - exp[field][i]) < ct.epsilon
else:
assert actual[field] == exp[field]
assert pc.equal_entities_list(exp=exp_res, actual=query_res)
# assert len(exp_res) == len(query_res)
# for i in range(len(exp_res)):
# assert_entity_equal(exp=exp_res[i], actual=query_res[i])
import pytest
import sys
import operator
from common import common_type as ct
sys.path.append("..")
from utils.util_log import test_log as log
......@@ -63,7 +63,8 @@ def list_de_duplication(_list):
# Keep the order of the elements unchanged
result.sort(key=_list.index)
log.debug("[LIST_DE_DUPLICATION] %s after removing the duplicate elements, the list becomes %s" % (str(_list), str(result)))
log.debug("[LIST_DE_DUPLICATION] %s after removing the duplicate elements, the list becomes %s" % (
str(_list), str(result)))
return result
......@@ -116,3 +117,86 @@ def get_connect_object_name(_list):
log.debug("[GET_CONNECT_OBJECT_NAME] list:%s is reset to list:%s" % (str(_list), str(new_list)))
return new_list
def equal_entity(exp, actual):
"""
compare two entities containing vector field
{"int64": 0, "float": 0.0, "float_vec": [0.09111554112502457, ..., 0.08652634258062468]}
:param exp: exp entity
:param actual: actual entity
:return: bool
"""
assert actual.keys() == exp.keys()
for field, value in exp.items():
if isinstance(value, list):
assert len(actual[field]) == len(exp[field])
for i in range(len(exp[field])):
assert abs(actual[field][i] - exp[field][i]) < ct.epsilon
else:
assert actual[field] == exp[field]
def entity_in(entity, entities, primary_field=ct.default_int64_field_name):
"""
according to the primary key to judge entity in the entities list
:param entity: dict
{"int": 0, "vec": [0.999999, 0.111111]}
:param entities: list of dict
[{"int": 0, "vec": [0.999999, 0.111111]}, {"int": 1, "vec": [0.888888, 0.222222]}]
:param primary_field: collection primary field
:return: True or False
"""
primary_key = entity.get(primary_field, None)
primary_keys = []
for e in entities:
primary_keys.append(e[primary_field])
if primary_key not in primary_keys:
return False
index = primary_key.index(primary_key)
return equal_entity(entities[index], entity)
def remove_entity(entity, entities, primary_field=ct.default_int64_field_name):
"""
according to the primary key to remove an entity from an entities list
:param entity: dict
{"int": 0, "vec": [0.999999, 0.111111]}
:param entities: list of dict
[{"int": 0, "vec": [0.999999, 0.111111]}, {"int": 1, "vec": [0.888888, 0.222222]}]
:param primary_field: collection primary field
:return: entities of removed entity
"""
primary_key = entity.get(primary_field, None)
primary_keys = []
for e in entities:
primary_keys.append(e[primary_field])
index = primary_keys.index(primary_key)
entities.pop(index)
return entities
def equal_entities_list(exp, actual):
"""
compare two entities lists in inconsistent order
:param exp: exp entities list, list of dict
:param actual: actual entities list, list of dict
:return: True or False
example:
exp = [{"int": 0, "vec": [0.999999, 0.111111]}, {"int": 1, "vec": [0.888888, 0.222222]}]
actual = [{"int": 1, "vec": [0.888888, 0.222222]}, {"int": 0, "vec": [0.999999, 0.111111]}]
exp = actual
"""
if len(exp) != len(actual):
return False
for a in actual:
# if vec field returned in query res
# if entity_in_entities(a, exp):
if a in exp:
try:
exp.remove(a)
# if vec field returned in query res
# remove_entity(a, exp)
except Exception as ex:
print(ex)
return True if len(exp) == 0 else False
......@@ -31,11 +31,10 @@ class TestQueryBase(TestcaseBase):
int_values = vectors[0][ct.default_int64_field_name].values.tolist()
pos = 5
term_expr = f'{ct.default_int64_field_name} in {int_values[:pos]}'
res = vectors[0].iloc[0:pos].to_dict('records')
res = vectors[0].iloc[0:pos, :2].to_dict('records')
collection_w.query(term_expr, check_task=CheckTasks.check_query_results, check_items={exp_res: res})
@pytest.mark.tags(CaseLabel.L1)
@pytest.mark.xfail(reason="#6028")
def test_query_empty_collection(self):
"""
target: test query empty collection
......@@ -85,7 +84,6 @@ class TestQueryBase(TestcaseBase):
assert len(res) == 0
@pytest.mark.tags(CaseLabel.L1)
@pytest.mark.xfail(reason="#6033")
def test_query_expr_none(self):
"""
target: test query with none expr
......@@ -93,11 +91,10 @@ class TestQueryBase(TestcaseBase):
expected: raise exception
"""
collection_w, vectors, _, = self.init_collection_general(prefix, insert_data=True)
error = {ct.err_code: 1, ct.err_msg: "invalid expr"}
error = {ct.err_code: 0, ct.err_msg: "The type of expr must be string"}
collection_w.query(None, check_task=CheckTasks.err_res, check_items=error)
@pytest.mark.tags(CaseLabel.L1)
@pytest.mark.xfail(reason="#6044")
@pytest.mark.parametrize("expr", [1, 2., [], {}, ()])
def test_query_expr_non_string(self, expr):
"""
......@@ -106,7 +103,7 @@ class TestQueryBase(TestcaseBase):
expected: raise exception
"""
collection_w, vectors, _, = self.init_collection_general(prefix, insert_data=True)
error = {ct.err_code: 1, ct.err_msg: "expr must string type"}
error = {ct.err_code: 0, ct.err_msg: "The type of expr must be string"}
collection_w.query(expr, check_task=CheckTasks.err_res, check_items=error)
@pytest.mark.tags(CaseLabel.L1)
......@@ -130,7 +127,7 @@ class TestQueryBase(TestcaseBase):
expected: query result is correct
"""
collection_w, vectors, _, = self.init_collection_general(prefix, insert_data=True)
res = vectors[0].iloc[0:2].to_dict('records')
res = vectors[0].iloc[:2, :2].to_dict('records')
collection_w.query(default_term_expr, check_task=CheckTasks.check_query_results, check_items={exp_res: res})
@pytest.mark.tags(CaseLabel.L1)
......@@ -214,7 +211,7 @@ class TestQueryBase(TestcaseBase):
"""
target: test query with empty array term expr
method: query with empty term expr
expected: empty rsult
expected: empty result
"""
term_expr = f'{ct.default_int64_field_name} in []'
collection_w, vectors, _, = self.init_collection_general(prefix, insert_data=True)
......@@ -270,8 +267,8 @@ class TestQueryBase(TestcaseBase):
"""
collection_w, vectors, _, = self.init_collection_general(prefix, insert_data=True)
res, _ = collection_w.query(default_term_expr, output_fields=None)
fields = [ct.default_int64_field_name, ct.default_float_field_name, ct.default_float_vec_field_name]
assert list(res[0].keys()) == fields
fields = [ct.default_int64_field_name, ct.default_float_field_name]
assert set(res[0].keys()) == set(fields)
@pytest.mark.tags(CaseLabel.L0)
def test_query_output_one_field(self):
......@@ -282,7 +279,7 @@ class TestQueryBase(TestcaseBase):
"""
collection_w, vectors, _, = self.init_collection_general(prefix, insert_data=True)
res, _ = collection_w.query(default_term_expr, output_fields=[ct.default_int64_field_name])
assert list(res[0].keys()) == [ct.default_int64_field_name]
assert set(res[0].keys()) == set([ct.default_int64_field_name])
@pytest.mark.tags(CaseLabel.L1)
def test_query_output_all_fields(self):
......@@ -292,36 +289,51 @@ class TestQueryBase(TestcaseBase):
expected: return all fields
"""
collection_w, vectors, _, = self.init_collection_general(prefix, insert_data=True)
fields = [ct.default_int64_field_name, ct.default_float_field_name, ct.default_float_vec_field_name]
log.debug(collection_w.num_entities)
fields = [ct.default_int64_field_name, ct.default_float_field_name]
res, _ = collection_w.query(default_term_expr, output_fields=fields)
log.debug(res)
assert list(res[0].keys()) == fields
assert set(res[0].keys()) == set(fields)
res_1, _ = collection_w.query(default_term_expr, output_fields=[ct.default_float_field_name])
assert set(res_1[0].keys()) == set(fields)
@pytest.mark.tags(CaseLabel.L1)
@pytest.mark.xfail(reason="issue #6049")
def test_query_output_not_existed_field(self):
@pytest.mark.xfail(reason="issue #6143")
@pytest.mark.parametrize("output_fields", [[ct.default_float_vec_field_name],
[ct.default_int64_field_name, ct.default_float_vec_field_name]])
def test_query_output_vec_field(self, output_fields):
"""
target: test query output not existed field
method: query with not existed output field
target: test query with vec output field
method: specify vec field as output field
expected: raise exception
"""
collection_w, vectors, _, = self.init_collection_general(prefix, insert_data=True)
error = {ct.err_code: 1, ct.err_msg: 'cannot find field'}
collection_w.query(default_term_expr, output_fields=["int"], check_items=CheckTasks.err_res, check_task=error)
error = {ct.err_code: 1, ct.err_msg: "unsupported leaf node"}
collection_w.query(default_term_expr, output_fields=output_fields,
check_task=CheckTasks.err_res, check_items=error)
@pytest.mark.tags(CaseLabel.L1)
@pytest.mark.xfail(reason="issue #6049")
def test_query_output_part_not_existed_field(self):
def test_query_output_primary_field(self):
"""
target: test query output part not existed field
method: query with part not existed field
target: test query with output field only primary field
method: specify int64 primary field as output field
expected: return int64 field
"""
collection_w, vectors, _, = self.init_collection_general(prefix, insert_data=True)
res, _ = collection_w.query(default_term_expr, output_fields=[ct.default_int64_field_name])
assert list(res[0].keys()) == [ct.default_int64_field_name]
@pytest.mark.tags(CaseLabel.L1)
@pytest.mark.xfail(reason="issue #6074")
@pytest.mark.parametrize("output_fields", [["int"],
[ct.default_int64_field_name, "int"]])
def test_query_output_not_existed_field(self, output_fields):
"""
target: test query output not existed field
method: query with not existed output field
expected: raise exception
"""
collection_w, vectors, _, = self.init_collection_general(prefix, insert_data=True)
error = {ct.err_code: 1, ct.err_msg: 'cannot find field'}
fields = [ct.default_int64_field_name, "int"]
collection_w.query(default_term_expr, output_fields=fields, check_items=CheckTasks.err_res, check_task=error)
collection_w.query(default_term_expr, output_fields=output_fields, check_items=CheckTasks.err_res, check_task=error)
@pytest.mark.tags(CaseLabel.L1)
def test_query_empty_output_fields(self):
......@@ -332,10 +344,9 @@ class TestQueryBase(TestcaseBase):
"""
collection_w, vectors, _, = self.init_collection_general(prefix, insert_data=True)
query_res, _ = collection_w.query(default_term_expr, output_fields=[])
fields = [ct.default_int64_field_name, ct.default_float_field_name, ct.default_float_vec_field_name]
fields = [ct.default_int64_field_name, ct.default_float_field_name]
assert list(query_res[0].keys()) == fields
@pytest.mark.xfail(reason="issue #6056")
@pytest.mark.tags(CaseLabel.L1)
@pytest.mark.parametrize("fields", ct.get_invalid_strs)
def test_query_invalid_output_fields(self, fields):
......@@ -345,7 +356,7 @@ class TestQueryBase(TestcaseBase):
expected: raise exception
"""
collection_w, vectors, _, = self.init_collection_general(prefix, insert_data=True)
error = {ct.err_code: 1, ct.err_msg: 'invalid output fields'}
error = {ct.err_code: 1, ct.err_msg: 'Invalid query format. must be a list'}
collection_w.query(default_term_expr, output_fields=fields, check_items=CheckTasks.err_res, check_task=error)
@pytest.mark.tags(CaseLabel.L0)
......@@ -361,12 +372,11 @@ class TestQueryBase(TestcaseBase):
partition_w.insert(df)
assert collection_w.num_entities == ct.default_nb
partition_w.load()
res = df.iloc[0:2].to_dict('records')
res = df.iloc[:2, :2].to_dict('records')
collection_w.query(default_term_expr, partition_names=[partition_w.name],
check_task=CheckTasks.check_query_results, check_items={exp_res: res})
@pytest.mark.tags(CaseLabel.L1)
@pytest.mark.xfail(reason="issue #6059")
def test_query_partition_without_loading(self):
"""
target: test query on partition without loading
......@@ -378,7 +388,7 @@ class TestQueryBase(TestcaseBase):
df = cf.gen_default_dataframe_data(ct.default_nb)
partition_w.insert(df)
assert partition_w.num_entities == ct.default_nb
error = {ct.err_code: 1, ct.err_msg: 'cannot find collection'}
error = {ct.err_code: 1, ct.err_msg: f'collection {collection_w.name} was not loaded into memory'}
collection_w.query(default_term_expr, partition_names=[partition_w.name],
check_items=CheckTasks.err_res, check_task=error)
......@@ -390,12 +400,11 @@ class TestQueryBase(TestcaseBase):
expected: verify query result
"""
collection_w, vectors, _, = self.init_collection_general(prefix, insert_data=True)
res = vectors[0].iloc[0:2].to_dict('records')
res = vectors[0].iloc[:2, :2].to_dict('records')
collection_w.query(default_term_expr, partition_names=[ct.default_partition_name],
check_task=CheckTasks.check_query_results, check_items={exp_res: res})
@pytest.mark.tags(CaseLabel.L1)
@pytest.mark.xfail(reason="issue #6028")
def test_query_empty_partition(self):
"""
target: test query on empty partition
......@@ -410,7 +419,6 @@ class TestQueryBase(TestcaseBase):
assert len(res) == 0
@pytest.mark.tags(CaseLabel.L1)
@pytest.mark.xfail(reason="issue #6061")
def test_query_not_existed_partition(self):
"""
target: test query on a not existed partition
......@@ -420,7 +428,7 @@ class TestQueryBase(TestcaseBase):
collection_w = self.init_collection_wrap(cf.gen_unique_str(prefix))
collection_w.load()
partition_names = cf.gen_unique_str()
error = {ct.err_code: 1, ct.err_msg: 'cannot find partition'}
error = {ct.err_code: 1, ct.err_msg: f'PartitonName: {partition_names} not found'}
collection_w.query(default_term_expr, partition_names=[partition_names],
check_items=CheckTasks.err_res, check_task=error)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册