未验证 提交 626516bb 编写于 作者: N nico 提交者: GitHub

Add test cases of search iterator (#25039)

Signed-off-by: Nnico <cheng.yuan@zilliz.com>
上级 1357ef70
......@@ -176,6 +176,22 @@ class ApiCollectionWrapper:
timeout=timeout, **kwargs).run()
return res, check_result
@trace()
def search_iterator(self, data, anns_field, param, limit, expr=None,
partition_names=None, output_fields=None, timeout=None, round_decimal=-1,
check_task=None, check_items=None, **kwargs):
timeout = TIMEOUT if timeout is None else timeout
func_name = sys._getframe().f_code.co_name
res, check = api_request([self.collection.search_iterator, data, anns_field, param, limit,
expr, partition_names, output_fields, timeout, round_decimal], **kwargs)
check_result = ResponseChecker(res, func_name, check_task, check_items, check,
data=data, anns_field=anns_field, param=param, limit=limit,
expr=expr, partition_names=partition_names,
output_fields=output_fields,
timeout=timeout, **kwargs).run()
return res, check_result
@trace()
def query(self, expr, output_fields=None, partition_names=None, timeout=None, check_task=None, check_items=None,
**kwargs):
......@@ -190,6 +206,20 @@ class ApiCollectionWrapper:
timeout=timeout, **kwargs).run()
return res, check_result
@trace()
def query_iterator(self, expr, output_fields=None, partition_names=None, timeout=None, check_task=None,
check_items=None, **kwargs):
# time.sleep(5)
timeout = TIMEOUT if timeout is None else timeout
func_name = sys._getframe().f_code.co_name
res, check = api_request([self.collection.query_iterator, expr, output_fields, partition_names, timeout], **kwargs)
check_result = ResponseChecker(res, func_name, check_task, check_items, check,
expression=expr, partition_names=partition_names,
output_fields=output_fields,
timeout=timeout, **kwargs).run()
return res, check_result
@property
def partitions(self):
return self.collection.partitions
......
......@@ -56,10 +56,18 @@ class ResponseChecker:
# Search interface of collection and partition that response check
result = self.check_search_results(self.response, self.func_name, self.check_items)
elif self.check_task == CheckTasks.check_search_iterator:
# Search iterator interface of collection and partition that response check
result = self.check_search_iterator(self.response, self.func_name, self.check_items)
elif self.check_task == CheckTasks.check_query_results:
# Query interface of collection and partition that response check
result = self.check_query_results(self.response, self.func_name, self.check_items)
elif self.check_task == CheckTasks.check_query_iterator:
# query iterator interface of collection and partition that response check
result = self.check_query_iterator(self.response, self.func_name, self.check_items)
elif self.check_task == CheckTasks.check_query_empty:
result = self.check_query_empty(self.response, self.func_name)
......@@ -337,6 +345,46 @@ class ResponseChecker:
return True
@staticmethod
def check_search_iterator(search_res, func_name, check_items):
"""
target: check the search iterator results
method: 1. check the iterator number
2. check the limit(topK) and ids
3. check the distance
expected: check the search is ok
"""
log.info("search_iterator_results_check: checking the searching results")
if func_name != 'search_iterator':
log.warning("The function name is {} rather than {}".format(func_name, "search_iterator"))
search_iterator = search_res
pk_list = []
while True:
res = search_iterator.next()
if len(res[0]) == 0:
log.info("search iteration finished, close")
search_iterator.close()
break
if check_items.get("limit", None):
assert len(res[0].ids) <= check_items["limit"]
if check_items.get("radius", None):
for distance in res[0].distances:
if check_items["metric_type"] == "L2":
assert distance < check_items["radius"]
else:
assert distance > check_items["radius"]
if check_items.get("range_filter", None):
for distance in res[0].distances:
if check_items["metric_type"] == "L2":
assert distance >= check_items["range_filter"]
else:
assert distance <= check_items["range_filter"]
pk_list.extend(res[0].ids)
assert len(pk_list) == len(set(pk_list))
log.info("check: total %d results" % len(pk_list))
return True
@staticmethod
def check_query_results(query_res, func_name, check_items):
"""
......@@ -372,6 +420,37 @@ class ResponseChecker:
return False
log.warning(f'Expected query result is {exp_res}')
@staticmethod
def check_query_iterator(query_res, func_name, check_items):
"""
target: check the query results
method: 1. check the query number
2. check the limit(topK) and ids
3. check the distance
expected: check the search is ok
"""
log.info("query_iterator_results_check: checking the query results")
if func_name != 'query_iterator':
log.warning("The function name is {} rather than {}".format(func_name, "query_iterator"))
query_iterator = query_res
pk_list = []
while True:
res = query_iterator.next()
if len(res) == 0:
log.info("search iteration finished, close")
query_iterator.close()
break
for i in range(len(res)):
pk_list.append(res[i][ct.default_int64_field_name])
if check_items.get("limit", None):
assert len(res) <= check_items["limit"]
assert len(pk_list) == len(set(pk_list))
if check_items.get("count", None):
assert len(pk_list) == check_items["count"]
log.info("check: total %d results" % len(pk_list))
return True
@staticmethod
def check_query_empty(query_res, func_name):
"""
......
......@@ -243,7 +243,9 @@ class CheckTasks:
check_collection_property = "check_collection_property"
check_partition_property = "check_partition_property"
check_search_results = "check_search_results"
check_search_iterator = "check_search_iterator"
check_query_results = "check_query_results"
check_query_iterator = "check_query_iterator"
check_query_empty = "check_query_empty" # verify that query result is empty
check_query_not_empty = "check_query_not_empty"
check_distance = "check_distance"
......
......@@ -12,7 +12,7 @@ allure-pytest==2.7.0
pytest-print==0.2.1
pytest-level==0.1.1
pytest-xdist==2.5.0
pymilvus==2.4.0.dev75
pymilvus==2.4.0.dev81
pytest-rerunfailures==9.1.1
git+https://github.com/Projectplace/pytest-tags
ndg-httpsclient
......
......@@ -1618,6 +1618,68 @@ class TestQueryOperation(TestcaseBase):
assert res[ct.default_bool_field_name] is False
assert res[ct.default_string_field_name] == "abc"
@pytest.mark.tags(CaseLabel.L1)
def test_query_iterator_normal(self):
"""
target: test query iterator normal
method: 1. query iterator
2. check the result, expect pk
expected: query successfully
"""
# 1. initialize with data
limit = 100
collection_w = self.init_collection_general(prefix, True, is_index=False)[0]
collection_w.create_index(ct.default_float_vec_field_name, {"metric_type": "L2"})
collection_w.load()
# 2. search iterator
expr = "int64 >= 0"
collection_w.query_iterator(expr, limit=limit,
check_task=CheckTasks.check_query_iterator,
check_items={"count": ct.default_nb,
"limit": limit})
@pytest.mark.tags(CaseLabel.L1)
@pytest.mark.parametrize("offset", [500, 1000, 1777])
def test_query_iterator_with_offset(self, offset):
"""
target: test query iterator normal
method: 1. query iterator
2. check the result, expect pk
expected: query successfully
"""
# 1. initialize with data
limit = 100
collection_w = self.init_collection_general(prefix, True, is_index=False)[0]
collection_w.create_index(ct.default_float_vec_field_name, {"metric_type": "L2"})
collection_w.load()
# 2. search iterator
expr = "int64 >= 0"
collection_w.query_iterator(expr, limit=limit, offset=offset,
check_task=CheckTasks.check_query_iterator,
check_items={"count": ct.default_nb - offset,
"limit": limit})
@pytest.mark.tags(CaseLabel.L1)
@pytest.mark.parametrize("limit", [10, 100, 777, 1000])
def test_query_iterator_with_different_limit(self, limit):
"""
target: test query iterator normal
method: 1. query iterator
2. check the result, expect pk
expected: query successfully
"""
# 1. initialize with data
offset = 500
collection_w = self.init_collection_general(prefix, True, is_index=False)[0]
collection_w.create_index(ct.default_float_vec_field_name, {"metric_type": "L2"})
collection_w.load()
# 2. search iterator
expr = "int64 >= 0"
collection_w.query_iterator(expr, limit=limit, offset=offset,
check_task=CheckTasks.check_query_iterator,
check_items={"count": ct.default_nb - offset,
"limit": limit})
class TestQueryString(TestcaseBase):
"""
......
......@@ -8358,3 +8358,192 @@ class TestCollectionSearchJSON(TestcaseBase):
"ids": insert_ids,
"limit": default_limit})
class TestSearchIterator(TestcaseBase):
""" Test case of search iterator """
@pytest.mark.tags(CaseLabel.L1)
def test_search_iterator_normal(self):
"""
target: test search iterator normal
method: 1. search iterator
2. check the result, expect pk
expected: search successfully
"""
# 1. initialize with data
limit = 100
dim = 128
collection_w = self.init_collection_general(prefix, True, dim=dim, is_index=False)[0]
collection_w.create_index(field_name, {"metric_type": "L2"})
collection_w.load()
# 2. search iterator
search_params = {"metric_type": "L2"}
collection_w.search_iterator(vectors[:1], field_name, search_params, limit,
check_task=CheckTasks.check_search_iterator,
check_items={"limit": limit})
@pytest.mark.tags(CaseLabel.L1)
def test_search_iterator_binary(self):
"""
target: test search iterator binary
method: 1. search iterator
2. check the result, expect pk
expected: search successfully
"""
# 1. initialize with data
limit = 200
collection_w = self.init_collection_general(prefix, True, is_binary=True)[0]
# 2. search iterator
_, binary_vectors = cf.gen_binary_vectors(2, ct.default_dim)
collection_w.search_iterator(binary_vectors[:1], binary_field_name,
ct.default_search_binary_params, limit,
check_task=CheckTasks.check_search_iterator,
check_items={"limit": limit})
@pytest.mark.tags(CaseLabel.L2)
@pytest.mark.parametrize("metrics", ct.float_metrics)
def test_search_iterator_with_expression(self, metrics):
"""
target: test search iterator normal
method: 1. search iterator
2. check the result, expect pk not repeat and meet the expr requirements
expected: search successfully
"""
# 1. initialize with data
limit = 100
dim = 128
collection_w = self.init_collection_general(prefix, True, dim=dim, is_index=False)[0]
collection_w.create_index(field_name, {"metric_type": metrics})
collection_w.load()
# 2. search iterator
search_params = {"metric_type": metrics}
expression = "1000.0 <= float < 2000.0"
collection_w.search_iterator(vectors[:1], field_name, search_params, limit,
expr=expression, check_task=CheckTasks.check_search_iterator,
check_items={})
@pytest.mark.tags(CaseLabel.L2)
def test_range_search_iterator_L2(self):
"""
target: test iterator range search
method: 1. search iterator
2. check the result, expect pk not repeat and meet the expr requirements
expected: search successfully
"""
# 1. initialize with data
limit = 100
collection_w = self.init_collection_general(prefix, True, is_index=False)[0]
collection_w.create_index(field_name, {"metric_type": "L2"})
collection_w.load()
# 2. search iterator
search_params = {"metric_type": "L2", "params": {"radius": 35.0, "range_filter": 34.0}}
collection_w.search_iterator(vectors[:1], field_name, search_params, limit,
check_task=CheckTasks.check_search_iterator,
check_items={"metric_type": "L2",
"radius": 35.0,
"range_filter": 34.0})
@pytest.mark.tags(CaseLabel.L2)
@pytest.mark.parametrize("metrics", ct.float_metrics[1:])
def test_range_search_iterator_IP_COSINE(self, metrics):
"""
target: test iterator range search
method: 1. search iterator
2. check the result, expect pk not repeat and meet the expr requirements
expected: search successfully
"""
# 1. initialize with data
limit = 100
collection_w = self.init_collection_general(prefix, True, is_index=False)[0]
collection_w.create_index(field_name, {"metric_type": metrics})
collection_w.load()
# 2. search iterator
search_params = {"metric_type": metrics, "params": {"radius": 0, "range_filter": 45}}
collection_w.search_iterator(vectors[:1], field_name, search_params, limit,
check_task=CheckTasks.check_search_iterator,
check_items={"metric_type": metrics,
"radius": 0,
"range_filter": 45})
@pytest.mark.tags(CaseLabel.L2)
def test_range_search_iterator_only_radius(self):
"""
target: test search iterator normal
method: 1. search iterator
2. check the result, expect pk not repeat and meet the expr requirements
expected: search successfully
"""
# 1. initialize with data
limit = 100
collection_w = self.init_collection_general(prefix, True, is_index=False)[0]
collection_w.create_index(field_name, {"metric_type": "L2"})
collection_w.load()
# 2. search iterator
search_params = {"metric_type": "L2", "params": {"radius": 35.0}}
collection_w.search_iterator(vectors[:1], field_name, search_params, limit,
check_task=CheckTasks.check_search_iterator,
check_items={"metric_type": "L2",
"radius": 35.0})
@pytest.mark.tags(CaseLabel.L2)
@pytest.mark.skip("issue #25145")
@pytest.mark.parametrize("index, params",
zip(ct.all_index_types[:6],
ct.default_index_params[:6]))
@pytest.mark.parametrize("metrics", ct.float_metrics)
def test_search_iterator_after_different_index_metrics(self, index, params, metrics):
"""
target: test search iterator using different index
method: 1. search iterator
2. check the result, expect pk not repeat and meet the expr requirements
expected: search successfully
"""
# 1. initialize with data
limit = 100
collection_w = self.init_collection_general(prefix, True, is_index=False)[0]
default_index = {"index_type": index, "params": params, "metric_type": metrics}
collection_w.create_index(field_name, default_index)
collection_w.load()
# 2. search iterator
search_params = {"metric_type": metrics}
collection_w.search_iterator(vectors[:1], field_name, search_params, limit,
check_task=CheckTasks.check_search_iterator,
check_items={})
@pytest.mark.tags(CaseLabel.L2)
@pytest.mark.parametrize("limit", [10, 100, 777, 1000])
def test_search_iterator_with_different_limit(self, limit):
"""
target: test search iterator normal
method: 1. search iterator
2. check the result, expect pk not repeat and meet the expr requirements
expected: search successfully
"""
# 1. initialize with data
collection_w = self.init_collection_general(prefix, True)[0]
# 2. search iterator
search_params = {"metric_type": "COSINE"}
collection_w.search_iterator(vectors[:1], field_name, search_params, limit,
check_task=CheckTasks.check_search_iterator,
check_items={"limit": limit})
@pytest.mark.tags(CaseLabel.L2)
def test_search_iterator_invalid_nq(self):
"""
target: test search iterator normal
method: 1. search iterator
2. check the result, expect pk
expected: search successfully
"""
# 1. initialize with data
limit = 100
dim = 128
collection_w = self.init_collection_general(prefix, True, dim=dim, is_index=False)[0]
collection_w.create_index(field_name, {"metric_type": "L2"})
collection_w.load()
# 2. search iterator
search_params = {"metric_type": "L2"}
collection_w.search_iterator(vectors[:2], field_name, search_params, limit,
check_task=CheckTasks.err_res,
check_items={"err_code": 1,
"err_msg": "Not support multiple vector iterator at present"})
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册