From 8ee2419582035bc352c23b47bff1173e39897e87 Mon Sep 17 00:00:00 2001
From: nico <109071306+NicoYuan1986@users.noreply.github.com>
Date: Tue, 25 Jul 2023 10:47:01 +0800
Subject: [PATCH] Add test cases of query count(*) filter (#25844)

Signed-off-by: nico
---
 tests/python_client/testcases/test_query.py | 133 +++++++++++++++++++-
 1 file changed, 131 insertions(+), 2 deletions(-)

diff --git a/tests/python_client/testcases/test_query.py b/tests/python_client/testcases/test_query.py
index 8c8db08bc..96521a9e3 100644
--- a/tests/python_client/testcases/test_query.py
+++ b/tests/python_client/testcases/test_query.py
@@ -2533,8 +2533,7 @@ class TestQueryCount(TestcaseBase):
         # count with expr
         collection_w.query(expr=default_expr, output_fields=[ct.default_count_output],
                            check_task=CheckTasks.check_query_results,
-                           check_items={exp_res: [{count: ct.default_nb}]}
-                           )
+                           check_items={exp_res: [{count: ct.default_nb}]})
 
         collection_w.query(expr=default_term_expr, output_fields=[ct.default_count_output],
                            check_task=CheckTasks.check_query_results,
@@ -2747,3 +2746,133 @@ class TestQueryCount(TestcaseBase):
         collection_w.query(expr=default_expr, output_fields=[ct.default_count_output], ignore_growing=True,
                            check_task=CheckTasks.check_query_results,
                            check_items={exp_res: [{count: 0}]})
+
+    @pytest.mark.tags(CaseLabel.L1)
+    @pytest.mark.parametrize("expression", cf.gen_normal_expressions())
+    def test_count_expressions(self, expression):
+        """
+        target: test count with each expression from cf.gen_normal_expressions()
+        method: rewrite &&/|| to and/or, evaluate the expr locally per inserted row, then count with that expr
+        expected: count equals the number of rows matched locally (all rows when the expr is empty)
+        """
+        # create -> insert -> index -> load
+        collection_w, _vectors, _, insert_ids = self.init_collection_general(insert_data=True)[0:4]
+
+        # filter result with expression in collection
+        _vectors = _vectors[0]
+        expression = expression.replace("&&", "and").replace("||", "or")
+        filter_ids = []
+        for i, _id in enumerate(insert_ids):
+            int64 = _vectors.int64[i]  # local looked up by name if eval(expression) references it
+            float = _vectors.float[i]  # shadows the builtin on purpose: same by-name lookup in eval()
+            if not expression or eval(expression):
+                filter_ids.append(_id)
+        res = len(filter_ids)
+
+        # count with expr
+        collection_w.query(expr=expression, output_fields=[count],
+                           check_task=CheckTasks.check_query_results,
+                           check_items={exp_res: [{count: res}]})
+
+    @pytest.mark.tags(CaseLabel.L2)
+    @pytest.mark.parametrize("bool_type", [True, False, "true", "false"])
+    def test_count_bool_expressions(self, bool_type):
+        """
+        target: test count with bool expr
+        method: count with "<bool field> == bool_type" for True/False and their "true"/"false" spellings
+        expected: count equals the number of inserted rows whose bool value matches bool_type
+        """
+        # create -> insert -> index -> load
+        collection_w, _vectors, _, insert_ids = \
+            self.init_collection_general(insert_data=True, is_all_data_type=True)[0:4]
+
+        # filter result with expression in collection
+        filter_ids = []
+        bool_type_cmp = bool_type
+        if bool_type == "true":
+            bool_type_cmp = True
+        if bool_type == "false":
+            bool_type_cmp = False
+        for i, _id in enumerate(insert_ids):
+            if _vectors[0][f"{ct.default_bool_field_name}"][i] == bool_type_cmp:
+                filter_ids.append(_id)
+        res = len(filter_ids)
+
+        # count with expr
+        expression = f"{ct.default_bool_field_name} == {bool_type}"
+        collection_w.query(expr=expression, output_fields=[count],
+                           check_task=CheckTasks.check_query_results,
+                           check_items={exp_res: [{count: res}]})
+
+    @pytest.mark.tags(CaseLabel.L2)
+    @pytest.mark.parametrize("expression", cf.gen_normal_expressions_field(default_float_field_name))
+    def test_count_expression_auto_field(self, expression):
+        """
+        target: test count with each expression from cf.gen_normal_expressions_field(default_float_field_name)
+        method: rewrite &&/|| to and/or, evaluate the expr locally per inserted float value, then count with it
+        expected: count equals the number of rows matched locally (all rows when the expr is empty)
+        """
+        # create -> insert -> index -> load
+        collection_w, _vectors, _, insert_ids = self.init_collection_general(insert_data=True)[0:4]
+
+        # filter result with expression in collection
+        _vectors = _vectors[0]
+        expression = expression.replace("&&", "and").replace("||", "or")
+        filter_ids = []
+        for i, _id in enumerate(insert_ids):
+            float = _vectors.float[i]  # shadows the builtin on purpose: looked up by name in eval()
+            if not expression or eval(expression):
+                filter_ids.append(_id)
+        res = len(filter_ids)
+
+        # count with expr
+        collection_w.query(expr=expression, output_fields=[count],
+                           check_task=CheckTasks.check_query_results,
+                           check_items={exp_res: [{count: res}]})
+
+    @pytest.mark.tags(CaseLabel.L2)
+    @pytest.mark.skip(reason="issue #25841")
+    def test_count_expression_all_datatype(self):
+        """
+        target: test count with one expr combining int64/int32/int16/int8/float/double conditions
+        method: count with the combined expr on a collection created with is_all_data_type=True
+        expected: count is 1 (currently skipped, see issue #25841)
+        """
+        # create -> insert -> index -> load
+        collection_w = self.init_collection_general(insert_data=True, is_all_data_type=True)[0]
+
+        # count with expr
+        expression = "int64 >= 0 && int32 >= 1999 && int16 >= 0 && int8 >= 0 && float <= 1999.0 && double >= 0"
+        # expression = "int64 == 1999"
+        collection_w.query(expr=expression, output_fields=[count],
+                           check_task=CheckTasks.check_query_results,
+                           check_items={exp_res: [{count: 1}]})
+
+    @pytest.mark.tags(CaseLabel.L1)
+    def test_count_expression_comparative(self):
+        """
+        target: test count with an expr that compares two int64 fields
+        method: insert int64_1 = 0..nb-1 and random int64_2 values, then count with "int64_1 >= int64_2"
+        expected: count equals the number of inserted rows where int64_1 >= int64_2
+        """
+        # create -> insert -> index -> load
+        fields = [cf.gen_int64_field("int64_1"), cf.gen_int64_field("int64_2"),
+                  cf.gen_float_vec_field()]
+        schema = cf.gen_collection_schema(fields=fields, primary_field="int64_1")
+        collection_w = self.init_collection_wrap(schema=schema)
+
+        nb, res = 10, 0
+        int_values = [random.randint(0, nb) for _ in range(nb)]
+        data = [[i for i in range(nb)], int_values, cf.gen_vectors(nb, ct.default_dim)]
+        collection_w.insert(data)
+        collection_w.create_index(ct.default_float_vec_field_name)
+        collection_w.load()
+
+        for i in range(nb):
+            res = res + 1 if i >= int_values[i] else res
+
+        # count with expr
+        expression = "int64_1 >= int64_2"
+        collection_w.query(expr=expression, output_fields=[count],
+                           check_task=CheckTasks.check_query_results,
+                           check_items={exp_res: [{count: res}]})
-- 
GitLab