未验证 提交 31122a68 编写于 作者: B binbin 提交者: GitHub

Update high level api test cases (#25118)

Signed-off-by: NBinbin Lv <binbin.lv@zilliz.com>
上级 fe242289
......@@ -10,6 +10,7 @@ from base.partition_wrapper import ApiPartitionWrapper
from base.index_wrapper import ApiIndexWrapper
from base.utility_wrapper import ApiUtilityWrapper
from base.schema_wrapper import ApiCollectionSchemaWrapper, ApiFieldSchemaWrapper
from base.high_level_api_wrapper import HighLevelApiWrapper
from utils.util_log import test_log as log
from common import common_func as cf
from common import common_type as ct
......@@ -28,6 +29,7 @@ class Base:
field_schema_wrap = None
collection_object_list = []
resource_group_list = []
high_level_api_wrap = None
def setup_class(self):
log.info("[setup_class] Start setup class...")
......@@ -45,6 +47,7 @@ class Base:
self.index_wrap = ApiIndexWrapper()
self.collection_schema_wrap = ApiCollectionSchemaWrapper()
self.field_schema_wrap = ApiFieldSchemaWrapper()
self.high_level_api_wrap = HighLevelApiWrapper()
def teardown_method(self, method):
log.info(("*" * 35) + " teardown " + ("*" * 35))
......@@ -118,18 +121,28 @@ class TestcaseBase(Base):
Public methods that can be used for test cases.
"""
def _connect(self):
def _connect(self, enable_high_level_api=False):
""" Add a connection and create the connect """
if cf.param_info.param_user and cf.param_info.param_password:
res, is_succ = self.connection_wrap.connect(alias=DefaultConfig.DEFAULT_USING,
host=cf.param_info.param_host,
port=cf.param_info.param_port, user=cf.param_info.param_user,
password=cf.param_info.param_password,
secure=cf.param_info.param_secure)
if enable_high_level_api:
if cf.param_info.param_uri:
uri = cf.param_info.param_uri
else:
uri = "http://" + cf.param_info.param_host + ":" + str(cf.param_info.param_port)
res, is_succ = self.connection_wrap.MilvusClient(uri=uri,
token=cf.param_info.param_token)
else:
res, is_succ = self.connection_wrap.connect(alias=DefaultConfig.DEFAULT_USING,
host=cf.param_info.param_host,
port=cf.param_info.param_port)
if cf.param_info.param_user and cf.param_info.param_password:
res, is_succ = self.connection_wrap.connect(alias=DefaultConfig.DEFAULT_USING,
host=cf.param_info.param_host,
port=cf.param_info.param_port,
user=cf.param_info.param_user,
password=cf.param_info.param_password,
secure=cf.param_info.param_secure)
else:
res, is_succ = self.connection_wrap.connect(alias=DefaultConfig.DEFAULT_USING,
host=cf.param_info.param_host,
port=cf.param_info.param_port)
return res
def init_collection_wrap(self, name=None, schema=None, check_task=None, check_items=None,
......
......@@ -330,6 +330,7 @@ class ApiCollectionWrapper:
check_result = ResponseChecker(res, func_name, check_task, check_items, check, **kwargs).run()
return res, check_result
@trace()
def get_compaction_state(self, timeout=None, check_task=None, check_items=None, **kwargs):
timeout = TIMEOUT if timeout is None else timeout
func_name = sys._getframe().f_code.co_name
......@@ -337,6 +338,7 @@ class ApiCollectionWrapper:
check_result = ResponseChecker(res, func_name, check_task, check_items, check, **kwargs).run()
return res, check_result
@trace()
def get_compaction_plans(self, timeout=None, check_task=None, check_items={}, **kwargs):
timeout = TIMEOUT if timeout is None else timeout
func_name = sys._getframe().f_code.co_name
......@@ -350,6 +352,7 @@ class ApiCollectionWrapper:
# log.debug(res)
return res
@trace()
def get_replicas(self, timeout=None, check_task=None, check_items=None, **kwargs):
timeout = TIMEOUT if timeout is None else timeout
func_name = sys._getframe().f_code.co_name
......@@ -357,9 +360,12 @@ class ApiCollectionWrapper:
check_result = ResponseChecker(res, func_name, check_task, check_items, check, **kwargs).run()
return res, check_result
@trace()
def describe(self, timeout=None, check_task=None, check_items=None):
timeout = TIMEOUT if timeout is None else timeout
func_name = sys._getframe().f_code.co_name
res, check = api_request([self.collection.describe, timeout])
check_result = ResponseChecker(res, func_name, check_task, check_items, check).run()
return res, check_result
from pymilvus import Connections
from pymilvus import DefaultConfig
from pymilvus import MilvusClient
import sys
sys.path.append("..")
......@@ -58,3 +59,10 @@ class ApiConnectionsWrapper:
response, is_succ = api_request([self.connection.get_connection_addr, alias])
check_result = ResponseChecker(response, func_name, check_task, check_items, is_succ, alias=alias).run()
return response, check_result
# high level api
def MilvusClient(self, check_task=None, check_items=None, **kwargs):
func_name = sys._getframe().f_code.co_name
response, succ = api_request([MilvusClient], **kwargs)
check_result = ResponseChecker(response, func_name, check_task, check_items, succ, **kwargs).run()
return response, check_result
import sys
import time
import timeout_decorator
from numpy import NaN
from pymilvus import Collection
sys.path.append("..")
from check.func_check import ResponseChecker
from utils.api_request import api_request
from utils.wrapper import trace
from utils.util_log import test_log as log
from pymilvus.orm.types import CONSISTENCY_STRONG
from common.common_func import param_info
TIMEOUT = 120
INDEX_NAME = ""
# keep small timeout for stability tests
# TIMEOUT = 5
class HighLevelApiWrapper:
    """Test wrapper around the pymilvus MilvusClient (high level api).

    Every public method forwards the call through ``api_request`` (which
    captures the response or exception) and then runs ``ResponseChecker`` so
    test cases can declare the expected outcome via ``check_task`` /
    ``check_items``.  All methods return ``(response, check_result)``.
    """

    def __init__(self, active_trace=False):
        # controls whether @trace() emits timing logs for the wrapped calls
        self.active_trace = active_trace

    def _request_and_check(self, func_name, request, check_task, check_items,
                           checker_kwargs, **kwargs):
        # Shared boilerplate for every wrapper: issue the client call and run
        # the response checker.  func_name is resolved in the public method so
        # the checker sees the real API name (some checks dispatch on it).
        res, check = api_request(request, **kwargs)
        check_result = ResponseChecker(res, func_name, check_task, check_items, check,
                                       **checker_kwargs, **kwargs).run()
        return res, check_result

    @trace()
    def create_collection(self, client, collection_name, dimension, timeout=None, check_task=None,
                          check_items=None, **kwargs):
        timeout = TIMEOUT if timeout is None else timeout
        kwargs.update({"timeout": timeout})
        func_name = sys._getframe().f_code.co_name
        return self._request_and_check(func_name,
                                       [client.create_collection, collection_name, dimension],
                                       check_task, check_items,
                                       {"collection_name": collection_name, "dimension": dimension},
                                       **kwargs)

    @trace()
    def insert(self, client, collection_name, data, timeout=None, check_task=None, check_items=None, **kwargs):
        timeout = TIMEOUT if timeout is None else timeout
        kwargs.update({"timeout": timeout})
        func_name = sys._getframe().f_code.co_name
        return self._request_and_check(func_name,
                                       [client.insert, collection_name, data],
                                       check_task, check_items,
                                       {"collection_name": collection_name, "data": data},
                                       **kwargs)

    @trace()
    def search(self, client, collection_name, data, limit=10, filter=None, output_fields=None, search_params=None,
               timeout=None, check_task=None, check_items=None, **kwargs):
        timeout = TIMEOUT if timeout is None else timeout
        kwargs.update({"timeout": timeout})
        func_name = sys._getframe().f_code.co_name
        # positional order follows MilvusClient.search(collection_name, data,
        # filter, limit, output_fields, search_params)
        return self._request_and_check(func_name,
                                       [client.search, collection_name, data, filter, limit,
                                        output_fields, search_params],
                                       check_task, check_items,
                                       {"collection_name": collection_name, "data": data,
                                        "limit": limit, "filter": filter,
                                        "output_fields": output_fields,
                                        "search_params": search_params},
                                       **kwargs)

    @trace()
    def query(self, client, collection_name, filter=None, output_fields=None,
              timeout=None, check_task=None, check_items=None, **kwargs):
        timeout = TIMEOUT if timeout is None else timeout
        kwargs.update({"timeout": timeout})
        func_name = sys._getframe().f_code.co_name
        return self._request_and_check(func_name,
                                       [client.query, collection_name, filter, output_fields],
                                       check_task, check_items,
                                       {"collection_name": collection_name, "filter": filter,
                                        "output_fields": output_fields},
                                       **kwargs)

    @trace()
    def get(self, client, collection_name, ids, output_fields=None,
            timeout=None, check_task=None, check_items=None, **kwargs):
        timeout = TIMEOUT if timeout is None else timeout
        kwargs.update({"timeout": timeout})
        func_name = sys._getframe().f_code.co_name
        return self._request_and_check(func_name,
                                       [client.get, collection_name, ids, output_fields],
                                       check_task, check_items,
                                       {"collection_name": collection_name, "ids": ids,
                                        "output_fields": output_fields},
                                       **kwargs)

    @trace()
    def num_entities(self, client, collection_name, timeout=None, check_task=None, check_items=None, **kwargs):
        timeout = TIMEOUT if timeout is None else timeout
        kwargs.update({"timeout": timeout})
        func_name = sys._getframe().f_code.co_name
        return self._request_and_check(func_name,
                                       [client.num_entities, collection_name],
                                       check_task, check_items,
                                       {"collection_name": collection_name},
                                       **kwargs)

    @trace()
    def delete(self, client, collection_name, pks, timeout=None, check_task=None, check_items=None, **kwargs):
        timeout = TIMEOUT if timeout is None else timeout
        kwargs.update({"timeout": timeout})
        func_name = sys._getframe().f_code.co_name
        return self._request_and_check(func_name,
                                       [client.delete, collection_name, pks],
                                       check_task, check_items,
                                       {"collection_name": collection_name, "pks": pks},
                                       **kwargs)

    @trace()
    def flush(self, client, collection_name, timeout=None, check_task=None, check_items=None, **kwargs):
        timeout = TIMEOUT if timeout is None else timeout
        kwargs.update({"timeout": timeout})
        func_name = sys._getframe().f_code.co_name
        return self._request_and_check(func_name,
                                       [client.flush, collection_name],
                                       check_task, check_items,
                                       {"collection_name": collection_name},
                                       **kwargs)

    @trace()
    def describe_collection(self, client, collection_name, timeout=None, check_task=None, check_items=None, **kwargs):
        timeout = TIMEOUT if timeout is None else timeout
        kwargs.update({"timeout": timeout})
        func_name = sys._getframe().f_code.co_name
        return self._request_and_check(func_name,
                                       [client.describe_collection, collection_name],
                                       check_task, check_items,
                                       {"collection_name": collection_name},
                                       **kwargs)

    @trace()
    def list_collections(self, client, timeout=None, check_task=None, check_items=None, **kwargs):
        timeout = TIMEOUT if timeout is None else timeout
        kwargs.update({"timeout": timeout})
        func_name = sys._getframe().f_code.co_name
        return self._request_and_check(func_name, [client.list_collections],
                                       check_task, check_items, {}, **kwargs)

    @trace()
    def drop_collection(self, client, collection_name, check_task=None, check_items=None, **kwargs):
        # NOTE: unlike the other wrappers this one does not inject a default timeout
        func_name = sys._getframe().f_code.co_name
        return self._request_and_check(func_name,
                                       [client.drop_collection, collection_name],
                                       check_task, check_items,
                                       {"collection_name": collection_name},
                                       **kwargs)
......@@ -83,9 +83,14 @@ class ResponseChecker:
elif self.check_task == CheckTasks.check_permission_deny:
# Collection interface response check
result = self.check_permission_deny(self.response, self.succ)
elif self.check_task == CheckTasks.check_rg_property:
# describe resource group interface response check
result = self.check_rg_property(self.response, self.func_name, self.check_items)
elif self.check_task == CheckTasks.check_describe_collection_property:
# describe collection interface(high level api) response check
result = self.check_describe_collection_property(self.response, self.func_name, self.check_items)
# Add check_items here if something new need verify
......@@ -178,6 +183,48 @@ class ResponseChecker:
assert collection.primary_field.name == check_items.get("primary")
return True
    @staticmethod
    def check_describe_collection_property(res, func_name, check_items):
        """
        According to the check_items to check collection properties of res, which return from func_name
        :param res: actual response of the describe_collection (high level api) call
        :type res: dict
        :param func_name: name of the API that produced res; expected to be "describe_collection"
        :type func_name: str
        :param check_items: which items expected to be checked, supported keys: collection_name,
                            auto_id, num_shards, consistency_level, enable_dynamic_field,
                            num_partitions, id_name, vector_name, dim
        :type check_items: dict, {check_key: expected_value}
        """
        exp_func_name = "describe_collection"
        if func_name != exp_func_name:
            log.warning("The function name is {} rather than {}".format(func_name, exp_func_name))
        if len(check_items) == 0:
            raise Exception("No expect values found in the check task")
        if check_items.get("collection_name", None) is not None:
            assert res["collection_name"] == check_items.get("collection_name")
        # NOTE(review): the guards below test truthiness of the *expected* value,
        # so a falsy expectation (e.g. auto_id=False, num_shards=0) silently skips
        # its assertion — confirm this is intended before relying on falsy checks.
        if check_items.get("auto_id", False):
            assert res["auto_id"] == check_items.get("auto_id")
        if check_items.get("num_shards", 1):
            assert res["num_shards"] == check_items.get("num_shards", 1)
        if check_items.get("consistency_level", 2):
            assert res["consistency_level"] == check_items.get("consistency_level", 2)
        if check_items.get("enable_dynamic_field", True):
            assert res["enable_dynamic_field"] == check_items.get("enable_dynamic_field", True)
        if check_items.get("num_partitions", 1):
            assert res["num_partitions"] == check_items.get("num_partitions", 1)
        # fields[0] is the primary key field, fields[1] the vector field created
        # by the high level create_collection api
        if check_items.get("id_name", "id"):
            assert res["fields"][0]["name"] == check_items.get("id_name", "id")
        if check_items.get("vector_name", "vector"):
            assert res["fields"][1]["name"] == check_items.get("vector_name", "vector")
        if check_items.get("dim", None) is not None:
            assert res["fields"][1]["params"]["dim"] == check_items.get("dim")
        # 100/101 are the server-assigned field ids; type values presumably map to
        # the Milvus DataType enum (5 = Int64, 101 = FloatVector) — confirm
        assert res["fields"][0]["is_primary"] is True
        assert res["fields"][0]["field_id"] == 100 and res["fields"][0]["type"] == 5
        assert res["fields"][1]["field_id"] == 101 and res["fields"][1]["type"] == 101
        return True
@staticmethod
def check_partition_property(partition, func_name, check_items):
exp_func_name = "init_partition"
......@@ -248,18 +295,26 @@ class ResponseChecker:
assert len(search_res) == check_items["nq"]
else:
log.info("search_results_check: Numbers of query searched is correct")
enable_high_level_api = check_items.get("enable_high_level_api", False)
log.debug(search_res)
for hits in search_res:
searched_original_vectors = []
ids = []
if enable_high_level_api:
for hit in hits:
ids.append(hit['id'])
else:
ids = list(hits.ids)
if (len(hits) != check_items["limit"]) \
or (len(hits.ids) != check_items["limit"]):
or (len(ids) != check_items["limit"]):
log.error("search_results_check: limit(topK) searched (%d) "
"is not equal with expected (%d)"
% (len(hits), check_items["limit"]))
assert len(hits) == check_items["limit"]
assert len(hits.ids) == check_items["limit"]
assert len(ids) == check_items["limit"]
else:
if check_items.get("ids", None) is not None:
ids_match = pc.list_contain_check(list(hits.ids),
ids_match = pc.list_contain_check(ids,
list(check_items["ids"]))
if not ids_match:
log.error("search_results_check: ids searched not match")
......
......@@ -38,8 +38,10 @@ class ParamInfo:
self.param_password = ""
self.param_secure = False
self.param_replica_num = ct.default_replica_num
self.param_uri = ""
self.param_token = ""
def prepare_param_info(self, host, port, handler, replica_num, user, password, secure):
def prepare_param_info(self, host, port, handler, replica_num, user, password, secure, uri, token):
self.param_host = host
self.param_port = port
self.param_handler = handler
......@@ -47,6 +49,8 @@ class ParamInfo:
self.param_password = password
self.param_secure = secure
self.param_replica_num = replica_num
self.param_uri = uri
self.param_token = token
param_info = ParamInfo()
......
......@@ -253,6 +253,7 @@ class CheckTasks:
check_permission_deny = "check_permission_deny"
check_value_equal = "check_value_equal"
check_rg_property = "check_resource_group_property"
check_describe_collection_property = "check_describe_collection_property"
class BulkLoadStates:
......
......@@ -45,6 +45,8 @@ def pytest_addoption(parser):
parser.addoption('--field_name', action='store', default="field_name", help="field_name of index")
parser.addoption('--replica_num', type='int', action='store', default=ct.default_replica_num, help="memory replica number")
parser.addoption('--minio_host', action='store', default="localhost", help="minio service's ip")
parser.addoption('--uri', action='store', default="", help="uri for high level api")
parser.addoption('--token', action='store', default="", help="token for high level api")
@pytest.fixture
......@@ -174,6 +176,16 @@ def minio_host(request):
return request.config.getoption("--minio_host")
@pytest.fixture
def uri(request):
    # value of the --uri command-line option: connection uri for the high level api
    return request.config.getoption("--uri")
@pytest.fixture
def token(request):
    # value of the --token command-line option: auth token for the high level api
    return request.config.getoption("--token")
""" fixture func """
......@@ -188,6 +200,8 @@ def initialize_env(request):
secure = request.config.getoption("--secure")
clean_log = request.config.getoption("--clean_log")
replica_num = request.config.getoption("--replica_num")
uri = request.config.getoption("--uri")
token = request.config.getoption("--token")
""" params check """
assert ip_check(host) and number_check(port)
......@@ -200,7 +214,7 @@ def initialize_env(request):
log.info("#" * 80)
log.info("[initialize_milvus] Log cleaned up, start testing...")
param_info.prepare_param_info(host, port, handler, replica_num, user, password, secure)
param_info.prepare_param_info(host, port, handler, replica_num, user, password, secure, uri, token)
@pytest.fixture(params=ct.get_invalid_strs)
......
import multiprocessing
import numbers
import random
import numpy
import threading
import pytest
import pandas as pd
import decimal
from decimal import Decimal, getcontext
from time import sleep
import heapq
from base.client_base import TestcaseBase
from utils.util_log import test_log as log
from common import common_func as cf
from common import common_type as ct
from common.common_type import CaseLabel, CheckTasks
from utils.util_pymilvus import *
from common.constants import *
from pymilvus.orm.types import CONSISTENCY_STRONG, CONSISTENCY_BOUNDED, CONSISTENCY_SESSION, CONSISTENCY_EVENTUALLY
from base.high_level_api_wrapper import HighLevelApiWrapper
# shared high level api wrapper instance used by every test case in this file
client_w = HighLevelApiWrapper()

prefix = "high_level_api"
epsilon = ct.epsilon
default_nb = ct.default_nb
default_nb_medium = ct.default_nb_medium
default_nq = ct.default_nq
default_dim = ct.default_dim
default_limit = ct.default_limit
# default filter/search expressions used by the search and query cases
default_search_exp = "id >= 0"
exp_res = "exp_res"
default_search_string_exp = "varchar >= \"0\""
default_search_mix_exp = "int64 >= 0 && varchar >= \"0\""
default_invaild_string_exp = "varchar >= 0"  # NOTE: misspelled name ("invaild") kept for compatibility
default_json_search_exp = "json_field[\"number\"] >= 0"
perfix_expr = 'varchar like "0%"'  # NOTE: misspelled name ("perfix") kept for compatibility
default_search_field = ct.default_float_vec_field_name
default_search_params = ct.default_search_params
# field names auto-created by the high level create_collection api
default_primary_key_field_name = "id"
default_vector_field_name = "vector"
default_float_field_name = ct.default_float_field_name
default_bool_field_name = ct.default_bool_field_name
default_string_field_name = ct.default_string_field_name
class TestHighLevelApi(TestcaseBase):
    """ Test cases for the high level api (MilvusClient) interface """

    @pytest.fixture(scope="function", params=[False, True])
    def auto_id(self, request):
        # parametrize tests over auto_id False/True
        yield request.param

    @pytest.fixture(scope="function", params=["COSINE", "L2"])
    def metric_type(self, request):
        # parametrize tests over supported metric types
        yield request.param

    """
    ******************************************************************
    #  The following are invalid base cases
    ******************************************************************
    """

    @pytest.mark.tags(CaseLabel.L2)
    @pytest.mark.xfail(reason="pymilvus issue 1554")
    def test_high_level_collection_invalid_primary_field(self):
        """
        target: test high level api: client.create_collection
        method: create collection with invalid primary field
        expected: Raise exception
        """
        client = self._connect(enable_high_level_api=True)
        collection_name = cf.gen_unique_str(prefix)
        # 1. create collection with an unsupported id_type
        error = {ct.err_code: 1, ct.err_msg: f"Param id_type must be int or string"}
        client_w.create_collection(client, collection_name, default_dim, id_type="invalid",
                                   check_task=CheckTasks.err_res, check_items=error)

    @pytest.mark.tags(CaseLabel.L2)
    def test_high_level_collection_string_auto_id(self):
        """
        target: test high level api: client.create_collection
        method: create collection with auto id on string primary key
        expected: Raise exception
        """
        client = self._connect(enable_high_level_api=True)
        collection_name = cf.gen_unique_str(prefix)
        # 1. create collection: auto_id is only legal for INT64 primary keys
        error = {ct.err_code: 1, ct.err_msg: f"The auto_id can only be specified on field with DataType.INT64"}
        client_w.create_collection(client, collection_name, default_dim, id_type="string", auto_id=True,
                                   check_task=CheckTasks.err_res, check_items=error)

    @pytest.mark.tags(CaseLabel.L1)
    def test_high_level_create_same_collection_different_params(self):
        """
        target: test high level api: client.create_collection
        method: create
        expected: 1. Successfully to create collection with same params
                  2. Report errors for creating collection with same name and different params
        """
        client = self._connect(enable_high_level_api=True)
        collection_name = cf.gen_unique_str(prefix)
        # 1. create collection
        client_w.create_collection(client, collection_name, default_dim)
        # 2. creating again with identical params is idempotent
        client_w.create_collection(client, collection_name, default_dim)
        # 3. same name but a different dimension must be rejected
        error = {ct.err_code: 1, ct.err_msg: f"create duplicate collection with different parameters, "
                                             f"collection: {collection_name}"}
        client_w.create_collection(client, collection_name, default_dim + 1,
                                   check_task=CheckTasks.err_res, check_items=error)
        client_w.drop_collection(client, collection_name)

    @pytest.mark.tags(CaseLabel.L2)
    def test_high_level_collection_invalid_metric_type(self):
        """
        target: test high level api: client.create_collection
        method: create collection with an invalid metric type
        expected: Raise exception
        """
        client = self._connect(enable_high_level_api=True)
        collection_name = cf.gen_unique_str(prefix)
        # 1. create collection with an unsupported metric type
        error = {ct.err_code: 1, ct.err_msg: f"metric type not found or not supported, supported: [L2 IP COSINE]"}
        client_w.create_collection(client, collection_name, default_dim, metric_type="invalid",
                                   check_task=CheckTasks.err_res, check_items=error)

    @pytest.mark.tags(CaseLabel.L2)
    def test_high_level_search_not_consistent_metric_type(self, metric_type):
        """
        target: test search with inconsistent metric type (default is IP) with that of index
        method: create connection, collection, insert and search with not consistent metric type
        expected: Raise exception
        """
        client = self._connect(enable_high_level_api=True)
        collection_name = cf.gen_unique_str(prefix)
        # 1. create collection (default metric type is IP)
        client_w.create_collection(client, collection_name, default_dim)
        # 2. search with a different metric type
        rng = np.random.default_rng(seed=19530)
        # use the collection's dimension (was hard-coded to 8)
        vectors_to_search = rng.random((1, default_dim))
        search_params = {"metric_type": metric_type}
        error = {ct.err_code: 1, ct.err_msg: f"metric type not match: expected=IP, actual={metric_type}"}
        client_w.search(client, collection_name, vectors_to_search, limit=default_limit,
                        search_params=search_params,
                        check_task=CheckTasks.err_res, check_items=error)
        client_w.drop_collection(client, collection_name)

    """
    ******************************************************************
    #  The following are valid base cases
    ******************************************************************
    """

    @pytest.mark.tags(CaseLabel.L1)
    def test_high_level_search_query_default(self):
        """
        target: test search (high level api) normal case
        method: create connection, collection, insert and search
        expected: search/query successfully
        """
        client = self._connect(enable_high_level_api=True)
        collection_name = cf.gen_unique_str(prefix)
        # 1. create collection and verify its description
        client_w.create_collection(client, collection_name, default_dim)
        collections = client_w.list_collections(client)[0]
        assert collection_name in collections
        client_w.describe_collection(client, collection_name,
                                     check_task=CheckTasks.check_describe_collection_property,
                                     check_items={"collection_name": collection_name,
                                                  "dim": default_dim})
        # 2. insert
        rng = np.random.default_rng(seed=19530)
        rows = [{default_primary_key_field_name: i, default_vector_field_name: list(rng.random((1, default_dim))[0]),
                 default_float_field_name: i * 1.0, default_string_field_name: str(i)} for i in range(default_nb)]
        client_w.insert(client, collection_name, rows)
        client_w.flush(client, collection_name)
        assert client_w.num_entities(client, collection_name)[0] == default_nb
        # 3. search
        vectors_to_search = rng.random((1, default_dim))
        insert_ids = list(range(default_nb))
        client_w.search(client, collection_name, vectors_to_search,
                        check_task=CheckTasks.check_search_results,
                        check_items={"enable_high_level_api": True,
                                     "nq": len(vectors_to_search),
                                     "ids": insert_ids,
                                     "limit": default_limit})
        # 4. query
        client_w.query(client, collection_name, filter=default_search_exp,
                       check_task=CheckTasks.check_query_results,
                       check_items={exp_res: rows,
                                    "with_vec": True,
                                    "primary_field": default_primary_key_field_name})
        client_w.drop_collection(client, collection_name)

    @pytest.mark.tags(CaseLabel.L2)
    @pytest.mark.skip(reason="issue 25110")
    def test_high_level_search_query_string(self):
        """
        target: test search (high level api) for string primary key
        method: create connection, collection, insert and search
        expected: search/query successfully
        """
        client = self._connect(enable_high_level_api=True)
        collection_name = cf.gen_unique_str(prefix)
        # 1. create collection with a string primary key
        client_w.create_collection(client, collection_name, default_dim, id_type="string", max_length=ct.default_length)
        # NOTE: dropped the previous "auto_id" expectation here — it referenced an
        # undefined name (no auto_id fixture in this test) and raised NameError.
        client_w.describe_collection(client, collection_name,
                                     check_task=CheckTasks.check_describe_collection_property,
                                     check_items={"collection_name": collection_name,
                                                  "dim": default_dim})
        # 2. insert
        rng = np.random.default_rng(seed=19530)
        rows = [{default_primary_key_field_name: str(i), default_vector_field_name: list(rng.random((1, default_dim))[0]),
                 default_float_field_name: i * 1.0, default_string_field_name: str(i)} for i in range(default_nb)]
        client_w.insert(client, collection_name, rows)
        client_w.flush(client, collection_name)
        assert client_w.num_entities(client, collection_name)[0] == default_nb
        # 3. search
        vectors_to_search = rng.random((1, default_dim))
        client_w.search(client, collection_name, vectors_to_search,
                        check_task=CheckTasks.check_search_results,
                        check_items={"enable_high_level_api": True,
                                     "nq": len(vectors_to_search),
                                     "limit": default_limit})
        # 4. query
        client_w.query(client, collection_name, filter=default_search_exp,
                       check_task=CheckTasks.check_query_results,
                       check_items={exp_res: rows,
                                    "with_vec": True,
                                    "primary_field": default_primary_key_field_name})
        client_w.drop_collection(client, collection_name)

    @pytest.mark.tags(CaseLabel.L2)
    def test_high_level_search_different_metric_types(self, metric_type, auto_id):
        """
        target: test search (high level api) normal case
        method: create connection, collection, insert and search
        expected: search successfully with limit(topK)
        """
        client = self._connect(enable_high_level_api=True)
        collection_name = cf.gen_unique_str(prefix)
        # 1. create collection
        client_w.create_collection(client, collection_name, default_dim, metric_type=metric_type, auto_id=auto_id)
        # 2. insert (drop the primary key column when the server assigns ids)
        rng = np.random.default_rng(seed=19530)
        rows = [{default_primary_key_field_name: i, default_vector_field_name: list(rng.random((1, default_dim))[0]),
                 default_float_field_name: i * 1.0, default_string_field_name: str(i)} for i in range(default_nb)]
        if auto_id:
            for row in rows:
                row.pop(default_primary_key_field_name)
        client_w.insert(client, collection_name, rows)
        client_w.flush(client, collection_name)
        assert client_w.num_entities(client, collection_name)[0] == default_nb
        # 3. search
        vectors_to_search = rng.random((1, default_dim))
        search_params = {"metric_type": metric_type}
        client_w.search(client, collection_name, vectors_to_search, limit=default_limit,
                        search_params=search_params,
                        output_fields=[default_primary_key_field_name],
                        check_task=CheckTasks.check_search_results,
                        check_items={"enable_high_level_api": True,
                                     "nq": len(vectors_to_search),
                                     "limit": default_limit})
        client_w.drop_collection(client, collection_name)

    @pytest.mark.tags(CaseLabel.L1)
    def test_high_level_delete(self):
        """
        target: test delete (high level api)
        method: create connection, collection, insert delete, and search
        expected: search/query successfully without deleted data
        """
        client = self._connect(enable_high_level_api=True)
        collection_name = cf.gen_unique_str(prefix)
        # 1. create collection (Strong consistency so deletes are visible to search)
        client_w.create_collection(client, collection_name, default_dim, consistency_level="Strong")
        # 2. insert
        default_nb = 1000
        rng = np.random.default_rng(seed=19530)
        rows = [{default_primary_key_field_name: i, default_vector_field_name: list(rng.random((1, default_dim))[0]),
                 default_float_field_name: i * 1.0, default_string_field_name: str(i)} for i in range(default_nb)]
        pks = client_w.insert(client, collection_name, rows)[0]
        client_w.flush(client, collection_name)
        assert client_w.num_entities(client, collection_name)[0] == default_nb
        # 3. get first primary key (exercises the get API; result unused)
        first_pk_data = client_w.get(client, collection_name, pks[0:1])
        # 4. delete the first few entities
        delete_num = 3
        client_w.delete(client, collection_name, pks[0:delete_num])
        # 5. search: deleted ids must not come back
        vectors_to_search = rng.random((1, default_dim))
        insert_ids = list(range(default_nb))
        for insert_id in pks[0:delete_num]:
            if insert_id in insert_ids:
                insert_ids.remove(insert_id)
        limit = default_nb - delete_num
        client_w.search(client, collection_name, vectors_to_search, limit=default_nb,
                        check_task=CheckTasks.check_search_results,
                        check_items={"enable_high_level_api": True,
                                     "nq": len(vectors_to_search),
                                     "ids": insert_ids,
                                     "limit": limit})
        # 6. query: only the surviving rows are expected back
        client_w.query(client, collection_name, filter=default_search_exp,
                       check_task=CheckTasks.check_query_results,
                       check_items={exp_res: rows[delete_num:],
                                    "with_vec": True,
                                    "primary_field": default_primary_key_field_name})
        client_w.drop_collection(client, collection_name)
......@@ -3771,8 +3771,7 @@ class TestCollectionSearch(TestcaseBase):
collection_w.search(vectors[:nq], default_search_field,
default_search_params, limit,
default_search_exp, _async=_async,
**kwargs
)
**kwargs)
@pytest.mark.tags(CaseLabel.L1)
def test_search_with_consistency_session(self, nq, dim, auto_id, _async, enable_dynamic_field):
......@@ -5624,7 +5623,7 @@ class TestSearchDiskann(TestcaseBase):
collection_w.create_index(ct.default_float_vec_field_name, default_index)
collection_w.load()
search_list = 20
default_search_params ={"metric_type": "L2", "params": {"search_list": search_list}}
default_search_params = {"metric_type": "L2", "params": {"search_list": search_list}}
vectors = [[random.random() for _ in range(dim)] for _ in range(default_nq)]
output_fields = [default_int64_field_name, default_float_field_name, default_string_field_name]
collection_w.search(vectors[:default_nq], default_search_field,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册