未验证 提交 082fb744 编写于 作者: T ThreadDao 提交者: GitHub

[skip ci] update search case for the change of query param in http-api (#4320)

* [skip ci] update http search case for field
Signed-off-by: NThreadDao <zongyufen@foxmail.com>

* [skip ci] update http search case for field
Signed-off-by: NThreadDao <zongyufen@foxmail.com>
上级 fbf5972f
......@@ -18,6 +18,7 @@ class Request(object):
def _check_status(self, result):
# logging.getLogger().info(result.text)
if result.status_code not in [200, 201, 204]:
logging.getLogger().error(result.text)
return False
if not result.text or "code" not in json.loads(result.text):
return True
......@@ -283,17 +284,16 @@ class MilvusClient(object):
if field["field_name"] == field_name:
return field["index_params"]
def search(self, collection_name, query_expr, fields=None, partition_tags=None):
def search(self, collection_name, query_expr):
url = self._url+url_collections+'/'+str(collection_name)+'/entities'
r = Request(url)
search_params = {
"query": query_expr,
"fields": fields,
"partition_tags": partition_tags
"query": query_expr
}
# logging.getLogger().info(search_params)
try:
status, data = r.get_with_body(search_params)
logging.getLogger().info(status)
if status:
return data
else:
......
import logging
import pytest
import requests
from utils import *
from constants import *
......@@ -135,14 +134,35 @@ class TestSearchBase:
expected: return field value
"""
entities, ids = init_data(client, collection)
query, query_vectors = gen_query_vectors(field_name, entities, default_top_k, default_nq)
data = client.search(collection, query, fields=[default_int_field_name])
query, query_vectors = gen_query_vectors(field_name, entities, default_top_k, default_nq, fields=[default_int_field_name])
client.search(collection, query)
data = client.search(collection, query)
res = data['result']
assert data['nq'] == default_nq
assert len(res) == default_nq
assert len(res[0]) == default_top_k
assert default_int_field_name in res[0][0]['entity'].keys()
def test_search_with_not_exist_field(self, client, collection):
"""
target: test search with not existed field
method: call search with exist field and not exist field
expected: not ok
"""
entities, ids = init_data(client, collection)
query, query_vectors = gen_query_vectors(field_name, entities, default_top_k, default_nq, fields=[default_int_field_name, "default_int_field_name"])
assert not client.search(collection, query)
def test_search_with_none_field(self, client, collection):
"""
target: test search with not existed field
method: call search with exist field and not exist field
expected: not ok
"""
entities, ids = init_data(client, collection)
query, query_vectors = gen_query_vectors(field_name, entities, default_top_k, default_nq, fields=[None])
assert not client.search(collection, query)
# TODO
def test_search_invalid_n_probe(self, client, collection, ):
"""
......@@ -238,12 +258,12 @@ class TestSearchBase:
"""
entities, ids = init_data(client, collection)
client.create_partition(collection, default_tag)
query, query_vectors = gen_query_vectors(field_name, entities, default_top_k, default_nq)
data = client.search(collection, query, partition_tags=default_tag)
query, query_vectors = gen_query_vectors(field_name, entities, default_top_k, default_nq, partition_tags=[default_tag])
data = client.search(collection, query)
res = data['result']
assert data['nq'] == default_nq
assert len(res) == default_nq
assert len(res[0]) == 0
assert res[0] is None
def test_search_binary_flat(self, client, binary_collection):
"""
......@@ -252,7 +272,8 @@ class TestSearchBase:
expected:
"""
raw_vectors, binary_entities, ids = init_binary_data(client, binary_collection)
query, query_vectors = gen_query_vectors(default_binary_vec_field_name, binary_entities, default_top_k,default_nq, metric_type='JACCARD')
query, query_vectors = gen_query_vectors(default_binary_vec_field_name, binary_entities, default_top_k,
default_nq, metric_type='JACCARD')
data = client.search(binary_collection, query)
res = data['result']
assert data['nq'] == default_nq
......
......@@ -311,7 +311,7 @@ def assert_equal_entity(a, b):
def gen_query_vectors(field_name, entities, top_k, nq, search_params={"nprobe": 10}, rand_vector=False,
metric_type="L2", replace_vecs=None):
metric_type="L2", fields=None, partition_tags=None, replace_vecs=None):
if rand_vector is True:
dimension = len(entities[0][field_name][0])
query_vectors = gen_vectors(nq, dimension)
......@@ -326,6 +326,10 @@ def gen_query_vectors(field_name, entities, top_k, nq, search_params={"nprobe":
"must": [must_param]
}
}
if fields:
query.update({"fields": fields})
if partition_tags:
query.update({"partition_tags": partition_tags})
# logging.getLogger().info(len(query_vectors[0]))
return query, query_vectors
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册