Unverified commit 84319ed3, authored by 紫晴, committed by GitHub

Update testcases of query (#6154)

* [skip ci] Update query test cases
Signed-off-by: wangting0128 <ting.wang@zilliz.com>

* [skip ci] Conflict resolution
Signed-off-by: wangting0128 <ting.wang@zilliz.com>

* [skip ci] Update teardown
Signed-off-by: wangting0128 <ting.wang@zilliz.com>
Parent b87baa10
@@ -41,6 +41,7 @@ class Base:
     utility_wrap = None
     collection_schema_wrap = None
     field_schema_wrap = None
+    collection_object_list = []
     def setup_class(self):
         log.info("[setup_class] Start setup class...")
@@ -52,6 +53,7 @@ class Base:
         log.info(("*" * 35) + " setup " + ("*" * 35))
         self.connection_wrap = ApiConnectionsWrapper()
         self.collection_wrap = ApiCollectionWrapper()
+        self.collection_object_list.append(self.collection_wrap)
         self.partition_wrap = ApiPartitionWrapper()
         self.index_wrap = ApiIndexWrapper()
         self.utility_wrap = ApiUtilityWrapper()
@@ -63,14 +65,20 @@ class Base:
         try:
             """ Drop collection before disconnect """
-            # if self.collection_wrap is not None and self.collection_wrap.collection is not None:
-            #     self.collection_wrap.drop()
-            if self.collection_wrap is not None:
-                collection_list = self.utility_wrap.list_collections()[0]
-                for i in collection_list:
-                    collection_wrap = ApiCollectionWrapper()
-                    collection_wrap.init_collection(name=i)
-                    collection_wrap.drop()
+            if self.connection_wrap.get_connection(alias=DefaultConfig.DEFAULT_USING)[0] is None:
+                self.connection_wrap.connect(alias=DefaultConfig.DEFAULT_USING, host=param_info.param_host,
+                                             port=param_info.param_port)
+            for collection_object in self.collection_object_list:
+                if collection_object is not None and collection_object.collection is not None:
+                    collection_object.drop()
+            # if self.collection_wrap is not None:
+            #     collection_list = self.utility_wrap.list_collections()[0]
+            #     for i in collection_list:
+            #         collection_wrap = ApiCollectionWrapper()
+            #         collection_wrap.init_collection(name=i)
+            #         collection_wrap.drop()
         except Exception as e:
             pass
@@ -111,13 +119,14 @@ class TestcaseBase(Base):
     Public methods that can be used to add cases.
     """
-    @pytest.fixture(scope="module", params=ct.get_invalid_strs)
-    def get_invalid_string(self, request):
-        yield request.param
-    @pytest.fixture(scope="module", params=cf.gen_simple_index())
-    def get_index_param(self, request):
-        yield request.param
+    # move to conftest.py
+    # @pytest.fixture(scope="module", params=ct.get_invalid_strs)
+    # def get_invalid_string(self, request):
+    #     yield request.param
+    #
+    # @pytest.fixture(scope="module", params=cf.gen_simple_index())
+    # def get_index_param(self, request):
+    #     yield request.param
     def _connect(self):
         """ Add an connection and create the connect """
@@ -131,6 +140,7 @@ class TestcaseBase(Base):
         if self.connection_wrap.get_connection(alias=DefaultConfig.DEFAULT_USING)[0] is None:
             self._connect()
         collection_w = ApiCollectionWrapper()
+        self.collection_object_list.append(collection_w)
         collection_w.init_collection(name=name, schema=schema, check_task=check_task, check_items=check_items, **kwargs)
         return collection_w
@@ -176,7 +186,7 @@ class TestcaseBase(Base):
         if insert_data:
            collection_w, vectors, binary_raw_vectors = \
                cf.insert_data(collection_w, nb, is_binary, is_all_data_type)
-            assert collection_w.is_empty == False
+            assert collection_w.is_empty is False
             assert collection_w.num_entities == nb
             collection_w.load()
...
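The teardown change above stops listing and dropping every collection on the server and instead drops only the wrappers the tests themselves registered in collection_object_list. A minimal, self-contained sketch of that register-then-drop pattern (plain Python; CollectionWrapper, the method bodies, and the collection names below are illustrative stand-ins, not the framework's actual ApiCollectionWrapper):

# Sketch of the "register what you create, drop it in teardown" pattern.
class CollectionWrapper:
    def __init__(self, name):
        self.collection = name            # stand-in handle; None means never initialized

    def drop(self):
        print("dropping %s" % self.collection)
        self.collection = None


class Base:
    collection_object_list = []           # every wrapper created during the test run

    def init_collection_wrap(self, name):
        wrapper = CollectionWrapper(name)
        self.collection_object_list.append(wrapper)   # register for cleanup
        return wrapper

    def teardown(self):
        # drop only what was created here, instead of everything on the server
        for obj in self.collection_object_list:
            if obj is not None and obj.collection is not None:
                obj.drop()


base = Base()
base.init_collection_wrap("query_coll_1")   # hypothetical names
base.init_collection_wrap("query_coll_2")
base.teardown()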
@@ -61,7 +61,7 @@ class ResponseChecker:
         assert len(error_dict) > 0
         if isinstance(res, Error):
             error_code = error_dict[ct.err_code]
-            assert res.code == error_code or error_dict[ct.err_msg] in res.message
+            assert res.code == error_code and error_dict[ct.err_msg] in res.message
         else:
             log.error("[CheckFunc] Response of API is not an error: %s" % str(res))
             assert False
@@ -190,3 +190,4 @@ class ResponseChecker:
         # assert len(exp_res) == len(query_res)
         # for i in range(len(exp_res)):
         #     assert_entity_equal(exp=exp_res[i], actual=query_res[i])
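The single-operator change in the first hunk above tightens the error check: previously an error response passed if either the error code or the error message matched the expectation; now both must match. A standalone, hedged illustration of the difference (Error and the err_code/err_msg keys mirror the test framework's names, but the snippet runs on its own):

# Illustration only: the stricter check requires both code and message to match.
class Error:
    def __init__(self, code, message):
        self.code = code
        self.message = message

def check_old(res, expected):      # old behaviour: either field is enough
    return res.code == expected["err_code"] or expected["err_msg"] in res.message

def check_new(res, expected):      # new behaviour: both fields must match
    return res.code == expected["err_code"] and expected["err_msg"] in res.message

res = Error(code=1, message="collection query_abc was not loaded into memory")
expected = {"err_code": 1, "err_msg": "can not find collection"}
print(check_old(res, expected))    # True  - a matching code alone used to pass
print(check_new(res, expected))    # False - the message no longer matches, so the case fails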
@@ -11,6 +11,16 @@ ErrorMessage = {ErrorCode.ErrorOk: "",
                 ErrorCode.Error: "is illegal"}
+class ErrorMap:
+    def __init__(self, err_code, err_msg):
+        self.err_code = err_code
+        self.err_msg = err_msg
 class ConnectionErrorMessage(ExceptionsMessage):
     FailConnect = "Fail connecting to server on %s:%s. Timeout"
     ConnectExist = "The connection named %s already creating, but passed parameters don't match the configured parameters"
+class CollectionErrorMessage(ExceptionsMessage):
+    CollNotLoaded = "collection %s was not loaded into memory"
 import pytest
+import common.common_type as ct
+import common.common_func as cf
 def pytest_addoption(parser):
     parser.addoption("--ip", action="store", default="localhost", help="service's ip")
@@ -20,6 +23,8 @@ def pytest_addoption(parser):
     parser.addoption('--clean_log', action='store_true', default=False, help="clean log before testing")
     parser.addoption('--schema', action='store', default="schema", help="schema of test interface")
     parser.addoption('--err_msg', action='store', default="err_msg", help="error message of test")
+    parser.addoption('--term_expr', action='store', default="term_expr", help="expr of query quest")
+    parser.addoption('--check_content', action='store', default="check_content", help="content of check")
 @pytest.fixture
@@ -110,3 +115,26 @@ def schema(request):
 @pytest.fixture
 def err_msg(request):
     return request.config.getoption("--err_msg")
+@pytest.fixture
+def term_expr(request):
+    return request.config.getoption("--term_expr")
+@pytest.fixture
+def check_content(request):
+    return request.config.getoption("--check_content")
+""" fixture func """
+@pytest.fixture(params=ct.get_invalid_strs)
+def get_invalid_string(request):
+    yield request.param
+@pytest.fixture(params=cf.gen_simple_index())
+def get_index_param(request):
+    yield request.param
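With the two new command-line options and the fixtures relocated into conftest.py above, any test in the suite can take term_expr or check_content as a fixture argument and receive whatever was passed on the command line. A hedged usage sketch (the file name, option values, and assertions below are only examples):

# Hypothetical test consuming the new command-line fixtures.
# Run, for example, as:
#   pytest test_example.py --term_expr "int64 in [0]" --check_content "content of check"
def test_uses_cli_options(term_expr, check_content):
    # both values arrive via request.config.getoption(...) in conftest.py
    assert isinstance(term_expr, str)
    assert isinstance(check_content, str)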
 [pytest]
-addopts = --host 192.168.1.239 --html=/tmp/ci_logs/report.html --self-contained-html
+addopts = --host 10.98.0.11 --html=/tmp/ci_logs/report.html --self-contained-html
 # -;addopts = --host 172.28.255.155 --html=/tmp/report.html
 # python3 -W ignore -m pytest
\ No newline at end of file
 import pytest
+import random
 from pymilvus_orm.default_config import DefaultConfig
 from base.client_base import TestcaseBase
 from common.code_mapping import ConnectionErrorMessage as cem
+from common.code_mapping import CollectionErrorMessage as clem
 from common import common_func as cf
 from common import common_type as ct
 from common.common_type import CaseLabel, CheckTasks
@@ -433,7 +435,6 @@ class TestQueryBase(TestcaseBase):
                            check_items=CheckTasks.err_res, check_task=error)
-# @pytest.mark.skip(reason="waiting for debug")
 class TestQueryOperation(TestcaseBase):
     """
     ******************************************************************
@@ -463,106 +464,146 @@ class TestQueryOperation(TestcaseBase):
         collection_w.query(default_term_expr, check_task=CheckTasks.err_res,
                            check_items={ct.err_code: 0, ct.err_msg: cem.ConnectFirst})
-    def test_query_without_loading(self):
+    @pytest.mark.tags(ct.CaseLabel.L3)
+    @pytest.mark.parametrize("collection_name, data",
+                             [(cf.gen_unique_str(prefix), cf.gen_default_list_data(ct.default_nb))])
+    def test_query_without_loading(self, collection_name, data):
         """
         target: test query without loading
         method: no loading before query
         expected: raise exception
         """
-        c_name = cf.gen_unique_str(prefix)
-        collection_w = self.init_collection_wrap(name=c_name)
-        data = cf.gen_default_list_data(ct.default_nb)
-        collection_w.insert(data=data)
-        conn, _ = self.connection_wrap.get_connection()
-        conn.flush([c_name])
-        assert collection_w.num_entities == ct.default_nb
-        error = {ct.err_code: 1, ct.err_msg: "can not find collection"}
-        collection_w.query(default_term_expr, check_task=CheckTasks.err_res, check_items=error)
+        # init a collection with default connection
+        collection_w = self.init_collection_wrap(name=collection_name)
+        # insert data to collection
+        collection_w.insert(data=data)
+        # check number of entities and that method calls the flush interface
+        assert collection_w.num_entities == ct.default_nb
+        # query without load
+        collection_w.query(default_term_expr, check_task=CheckTasks.err_res,
+                           check_items={ct.err_code: 1, ct.err_msg: clem.CollNotLoaded % collection_name})
-    def test_query_expr_single_term_array(self):
+    @pytest.mark.tags(ct.CaseLabel.L3)
+    @pytest.mark.parametrize("term_expr", [f'{ct.default_int64_field_name} in [0]'])
+    def test_query_expr_single_term_array(self, term_expr):
         """
         target: test query with single array term expr
         method: query with single array value
         expected: query result is one entity
         """
-        collection_w, vectors, _, = self.init_collection_general(prefix, insert_data=True)
-        term_expr = f'{ct.default_int64_field_name} in [0]'
-        res, _ = collection_w.query(term_expr)
-        assert len(res) == 1
-        df = vectors[0]
-        assert res[0][ct.default_int64_field_name] == df[ct.default_int64_field_name].values.tolist()[0]
-        assert res[1][ct.default_float_field_name] == df[ct.default_float_field_name].values.tolist()[0]
-        assert res[2][ct.default_float_vec_field_name] == df[ct.default_float_vec_field_name].values.tolist()[0]
+        # init a collection and insert data
+        collection_w, vectors, binary_raw_vectors = self.init_collection_general(prefix, insert_data=True)
+        # query the first row of data
+        check_vec = vectors[0].iloc[:, [0, 1]][0:1].to_dict('records')
+        collection_w.query(term_expr, check_task=CheckTasks.check_query_results, check_items={exp_res: check_vec})
-    def test_query_binary_expr_single_term_array(self):
+    @pytest.mark.tags(ct.CaseLabel.L3)
+    @pytest.mark.parametrize("term_expr", [f'{ct.default_int64_field_name} in [0]'])
+    def test_query_binary_expr_single_term_array(self, term_expr, check_content):
         """
         target: test query with single array term expr
         method: query with single array value
         expected: query result is one entity
         """
-        collection_w, vectors, _, = self.init_collection_general(prefix, insert_data=True, is_binary=True)
-        term_expr = f'{ct.default_int64_field_name} in [0]'
-        res, _ = collection_w.query(term_expr)
-        assert len(res) == 1
-        int_values = vectors[0][ct.default_int64_field_name].values.tolist()
-        float_values = vectors[0][ct.default_float_field_name].values.tolist()
-        vec_values = vectors[0][ct.default_float_vec_field_name].values.tolist()
-        assert res[0][ct.default_int64_field_name] == int_values[0]
-        assert res[1][ct.default_float_field_name] == float_values[0]
-        assert res[2][ct.default_float_vec_field_name] == vec_values[0]
+        # init a collection and insert data
+        collection_w, vectors, binary_raw_vectors = self.init_collection_general(prefix, insert_data=True,
+                                                                                 is_binary=True)
+        # query the first row of data
+        check_vec = vectors[0].iloc[:, [0, 1]][0:1].to_dict('records')
+        collection_w.query(term_expr, check_task=CheckTasks.check_query_results, check_items={exp_res: check_vec})
+    @pytest.mark.tags(ct.CaseLabel.L3)
     def test_query_expr_all_term_array(self):
         """
         target: test query with all array term expr
         method: query with all array value
         expected: verify query result
         """
-        collection_w, vectors, _, = self.init_collection_general(prefix, insert_data=True)
+        # init a collection and insert data
+        collection_w, vectors, binary_raw_vectors = self.init_collection_general(prefix, insert_data=True)
+        # data preparation
         int_values = vectors[0][ct.default_int64_field_name].values.tolist()
         term_expr = f'{ct.default_int64_field_name} in {int_values}'
-        res, _ = collection_w.query(term_expr)
-        assert len(res) == ct.default_nb
-        for i in ct.default_nb:
-            assert res[i][ct.default_int64_field_name] == int_values[i]
+        check_vec = vectors[0].iloc[:, [0, 1]][0:len(int_values)].to_dict('records')
+        # query all array value
+        collection_w.query(term_expr, check_task=CheckTasks.check_query_results, check_items={exp_res: check_vec})
+    @pytest.mark.tags(ct.CaseLabel.L3)
     def test_query_expr_half_term_array(self):
         """
         target: test query with half array term expr
         method: query with half array value
         expected: verify query result
         """
         half = ct.default_nb // 2
-        collection_w, partition_w, _, df_default = self.insert_entities_into_two_partitions_in_half(half)
+        collection_w, partition_w, df_partition, df_default = self.insert_entities_into_two_partitions_in_half(half)
         int_values = df_default[ct.default_int64_field_name].values.tolist()
-        float_values = df_default[ct.default_float_field_name].values.tolist()
-        vec_values = df_default[ct.default_float_vec_field_name].values.tolist()
         term_expr = f'{ct.default_int64_field_name} in {int_values}'
         res, _ = collection_w.query(term_expr)
-        assert len(res) == half
-        for i in half:
-            assert res[i][ct.default_int64_field_name] == int_values[i]
-            assert res[i][ct.default_float_field_name] == float_values[i]
-            assert res[i][ct.default_float_vec_field_name] == vec_values[i]
+        assert len(res) == len(int_values)
+        # half = ct.default_nb // 2
+        # collection_w, partition_w, _, df_default = self.insert_entities_into_two_partitions_in_half(half)
+        # int_values = df_default[ct.default_int64_field_name].values.tolist()
+        # float_values = df_default[ct.default_float_field_name].values.tolist()
+        # vec_values = df_default[ct.default_float_vec_field_name].values.tolist()
+        # term_expr = f'{ct.default_int64_field_name} in {int_values}'
+        # res, _ = collection_w.query(term_expr)
+        # assert len(res) == half
+        # for i in half:
+        #     assert res[i][ct.default_int64_field_name] == int_values[i]
+        #     assert res[i][ct.default_float_field_name] == float_values[i]
+        #     assert res[i][ct.default_float_vec_field_name] == vec_values[i]
@pytest.mark.xfail(reason="fail")
@pytest.mark.tags(ct.CaseLabel.L3)
def test_query_expr_repeated_term_array(self): def test_query_expr_repeated_term_array(self):
""" """
target: test query with repeated term array on primary field with unique value target: test query with repeated term array on primary field with unique value
method: query with repeated array value method: query with repeated array value
expected: verify query result expected: verify query result
""" """
collection_w, vectors, _, = self.init_collection_general(prefix, insert_data=True) collection_w, vectors, binary_raw_vectors = self.init_collection_general(prefix, insert_data=True)
int_values = [0, 0] int_values = [0, 0, 0, 0]
term_expr = f'{ct.default_int64_field_name} in {int_values}' term_expr = f'{ct.default_int64_field_name} in {int_values}'
res, _ = collection_w.query(term_expr) res, _ = collection_w.query(term_expr)
assert len(res) == 1 assert len(res) == 1
assert res[0][ct.default_int64_field_name] == int_values[0] assert res[0][ct.default_int64_field_name] == int_values[0]
def test_query_after_index(self, get_simple_index): @pytest.mark.tags(ct.CaseLabel.L3)
def test_query_after_index(self):
""" """
target: test query after creating index target: test query after creating index
method: query after index method: query after index
expected: query result is correct expected: query result is correct
""" """
collection_w, vectors, binary_raw_vectors = self.init_collection_general(prefix, insert_data=True)
default_field_name = ct.default_float_vec_field_name
default_index_params = {"index_type": "IVF_SQ8", "metric_type": "L2", "params": {"nlist": 64}}
index_name = ct.default_index_name
collection_w.create_index(default_field_name, default_index_params, index_name=index_name)
collection_w.load()
int_values = [0]
term_expr = f'{ct.default_int64_field_name} in {int_values}'
check_vec = vectors[0].iloc[:, [0, 1]][0:len(int_values)].to_dict('records')
collection_w.query(term_expr, check_task=CheckTasks.check_query_results, check_items={exp_res: check_vec})
# entities, ids = init_data(connect, collection) # entities, ids = init_data(connect, collection)
# assert len(ids) == ut.default_nb # assert len(ids) == ut.default_nb
# connect.create_index(collection, ut.default_float_vec_field_name, get_simple_index) # connect.create_index(collection, ut.default_float_vec_field_name, get_simple_index)
...@@ -570,12 +611,33 @@ class TestQueryOperation(TestcaseBase): ...@@ -570,12 +611,33 @@ class TestQueryOperation(TestcaseBase):
# res = connect.query(collection, default_term_expr) # res = connect.query(collection, default_term_expr)
# logging.getLogger().info(res) # logging.getLogger().info(res)
+    @pytest.mark.xfail(reason='')
+    @pytest.mark.tags(ct.CaseLabel.L3)
     def test_query_after_search(self):
         """
         target: test query after search
         method: query after search
         expected: query result is correct
         """
+        limit = 1000
+        nb_old = 500
+        collection_w, vectors, binary_raw_vectors = self.init_collection_general(prefix, True, nb_old)
+        # 2. search for original data after load
+        vectors_s = [[random.random() for _ in range(ct.default_dim)] for _ in range(ct.default_nq)]
+        collection_w.search(vectors_s[:ct.default_nq], ct.default_float_vec_field_name,
+                            ct.default_search_params, limit, "int64 >= 0",
+                            check_task=CheckTasks.check_search_results,
+                            check_items={"nq": ct.default_nq, "limit": nb_old})
+        # check number of entities and that method calls the flush interface
+        assert collection_w.num_entities == nb_old
+        term_expr = f'{ct.default_int64_field_name} in {default_term_expr}'
+        check_vec = vectors[0].iloc[:, [0, 1]][0:len(default_term_expr)].to_dict('records')
+        collection_w.query(term_expr, check_task=CheckTasks.check_query_results, check_items={exp_res: check_vec})
         # entities, ids = init_data(connect, collection)
         # assert len(ids) == ut.default_nb
         # top_k = 10
@@ -588,23 +650,50 @@ class TestQueryOperation(TestcaseBase):
         # query_res = connect.query(collection, default_term_expr)
         # logging.getLogger().info(query_res)
+    @pytest.mark.tags(ct.CaseLabel.L3)
     def test_query_partition_repeatedly(self):
         """
         target: test query repeatedly on partition
         method: query on partition twice
         expected: verify query result
         """
-        conn = self._connect()
+        # create connection
+        self._connect()
+        # init collection
         collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix))
+        # init partition
         partition_w = self.init_partition_wrap(collection_wrap=collection_w)
+        # insert data to partition
         df = cf.gen_default_dataframe_data(ct.default_nb)
         partition_w.insert(df)
-        conn.flush([collection_w.name])
+        # check number of entities and that method calls the flush interface
+        assert collection_w.num_entities == ct.default_nb
+        # load partition
         partition_w.load()
+        # query twice
         res_one, _ = collection_w.query(default_term_expr, partition_names=[partition_w.name])
         res_two, _ = collection_w.query(default_term_expr, partition_names=[partition_w.name])
         assert res_one == res_two
+        # conn = self._connect()
+        # collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix))
+        # partition_w = self.init_partition_wrap(collection_wrap=collection_w)
+        # df = cf.gen_default_dataframe_data(ct.default_nb)
+        # partition_w.insert(df)
+        # conn.flush([collection_w.name])
+        # partition_w.load()
+        # res_one, _ = collection_w.query(default_term_expr, partition_names=[partition_w.name])
+        # res_two, _ = collection_w.query(default_term_expr, partition_names=[partition_w.name])
+        # assert res_one == res_two
+    @pytest.mark.tags(ct.CaseLabel.L3)
     def test_query_another_partition(self):
         """
         target: test query another partition
@@ -614,11 +703,13 @@ class TestQueryOperation(TestcaseBase):
         """
         half = ct.default_nb // 2
         collection_w, partition_w, _, _ = self.insert_entities_into_two_partitions_in_half(half)
         term_expr = f'{ct.default_int64_field_name} in [{half}]'
         # half entity in _default partition rather than partition_w
-        res, _ = collection_w.query(term_expr, partition_names=[partition_w.name])
-        assert len(res) == 0
+        collection_w.query(term_expr, partition_names=[partition_w.name], check_task=CheckTasks.check_query_results,
+                           check_items={exp_res: []})
+    @pytest.mark.tags(ct.CaseLabel.L3)
     def test_query_multi_partitions_multi_results(self):
         """
         target: test query on multi partitions and get multi results
@@ -628,11 +719,13 @@ class TestQueryOperation(TestcaseBase):
         """
         half = ct.default_nb // 2
         collection_w, partition_w, _, _ = self.insert_entities_into_two_partitions_in_half(half)
         term_expr = f'{ct.default_int64_field_name} in [{half - 1}, {half}]'
         # half entity in _default, half-1 entity in partition_w
         res, _ = collection_w.query(term_expr, partition_names=[ct.default_partition_name, partition_w.name])
         assert len(res) == 2
+    @pytest.mark.tags(ct.CaseLabel.L3)
     def test_query_multi_partitions_single_result(self):
         """
         target: test query on multi partitions and get single result
@@ -641,7 +734,8 @@ class TestQueryOperation(TestcaseBase):
         expected: query from two partitions and get single result
         """
         half = ct.default_nb // 2
-        collection_w, partition_w = self.insert_entities_into_two_partitions_in_half(half)
+        collection_w, partition_w, df_partition, df_default = self.insert_entities_into_two_partitions_in_half(half)
         term_expr = f'{ct.default_int64_field_name} in [{half}]'
         # half entity in _default
         res, _ = collection_w.query(term_expr, partition_names=[ct.default_partition_name, partition_w.name])
...
@@ -8,7 +8,7 @@ class Error:
         self.message = getattr(error, 'message', str(error))
-log_row_length = 300
+log_row_length = 3000
 def api_request_catch():
@@ -16,12 +16,13 @@ def api_request_catch():
         def inner_wrapper(*args, **kwargs):
             try:
                 res = func(*args, **kwargs)
-                # log.debug("(api_response) Response : %s " % str(res)[0:log_row_length])
+                log_res = str(res)[0:log_row_length] + '......' if len(str(res)) > log_row_length else str(res)
+                log.debug("(api_response) Response : %s " % log_res)
                 return res, True
             except Exception as e:
+                log_e = str(e)[0:log_row_length] + '......' if len(str(e)) > log_row_length else str(e)
                 log.error(traceback.format_exc())
-                log.error("(api_response) [Milvus API Exception]%s: %s"
-                          % (str(func), str(e)[0:log_row_length]))
+                log.error("(api_response) [Milvus API Exception]%s: %s" % (str(func), log_e))
                 return Error(e), False
         return inner_wrapper
     return wrapper
@@ -36,8 +37,8 @@ def api_request(_list, **kwargs):
             if len(_list) > 1:
                 for a in _list[1:]:
                     arg.append(a)
-            # log.debug("(api_request) Request: [%s] args: %s, kwargs: %s"
-            #           % (str(func), str(arg)[0:log_row_length], str(kwargs)))
+            log_arg = str(arg)[0:log_row_length] + '......' if len(str(arg)) > log_row_length else str(arg)
+            log.debug("(api_request) Request: [%s] args: %s, kwargs: %s" % (str(func), log_arg, str(kwargs)))
             return func(*arg, **kwargs)
     return False, False
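The api_request/api_request_catch changes above re-enable request and response logging but cap each logged value at log_row_length characters (now 3000) and append '......' when it is cut off. A standalone sketch of that truncation idiom, with a tiny limit so the effect is visible:

# Same truncation idiom as in the hunk above, shortened so the cut is visible.
log_row_length = 20                     # the real module now uses 3000

def truncate_for_log(value):
    text = str(value)
    return text[0:log_row_length] + '......' if len(text) > log_row_length else text

print(truncate_for_log("short response"))          # printed unchanged
print(truncate_for_log(list(range(100))))          # long repr, truncated with '......'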