From 55cc6be6c93f9e1cd53bf97738e1e10465028738 Mon Sep 17 00:00:00 2001 From: del-zhenwu <56623710+del-zhenwu@users.noreply.github.com> Date: Tue, 21 Apr 2020 09:33:05 +0800 Subject: [PATCH] Set case level (#1968) * add mishards test stage Signed-off-by: zw * add clean up Signed-off-by: zw * Set case level Signed-off-by: zw * update test timeout Signed-off-by: zw * remove some test cases Signed-off-by: zw * udpate Signed-off-by: zw * remove connect Signed-off-by: zw * update case Signed-off-by: zw --- ci/jenkins/Jenkinsfile | 2 +- tests/milvus_python_test/conftest.py | 73 +++++--- tests/milvus_python_test/requirements.txt | 1 + tests/milvus_python_test/test_add_vectors.py | 6 +- tests/milvus_python_test/test_collection.py | 10 +- .../test_collection_count.py | 10 +- tests/milvus_python_test/test_connect.py | 129 +++---------- tests/milvus_python_test/test_index.py | 14 +- tests/milvus_python_test/test_partition.py | 14 ++ .../milvus_python_test/test_search_vectors.py | 177 +++++++++--------- tests/milvus_python_test/utils.py | 8 +- 11 files changed, 197 insertions(+), 247 deletions(-) diff --git a/ci/jenkins/Jenkinsfile b/ci/jenkins/Jenkinsfile index ec867ed8..70d249fe 100644 --- a/ci/jenkins/Jenkinsfile +++ b/ci/jenkins/Jenkinsfile @@ -13,7 +13,7 @@ pipeline { options { timestamps() - timeout(time: 3, unit: 'HOURS') + timeout(time: 4, unit: 'HOURS') } parameters{ diff --git a/tests/milvus_python_test/conftest.py b/tests/milvus_python_test/conftest.py index 4cca06d7..09461a8f 100644 --- a/tests/milvus_python_test/conftest.py +++ b/tests/milvus_python_test/conftest.py @@ -8,6 +8,7 @@ from milvus import Milvus, IndexType, MetricType from utils import * index_file_size = 10 +timeout = 1 def pytest_addoption(parser): @@ -37,26 +38,18 @@ def connect(request): port = request.config.getoption("--port") http_port = request.config.getoption("--http-port") handler = request.config.getoption("--handler") - milvus = get_milvus(handler=handler) + if handler == "HTTP": + port = http_port try: - if handler == "HTTP": - port = http_port - status = milvus.connect(host=ip, port=port) - logging.getLogger().info(status) - if not status.OK(): - # try again - logging.getLogger().info("------------------------------------") - logging.getLogger().info("Try to connect again") - logging.getLogger().info("------------------------------------") - res = milvus.connect(host=ip, port=port) + milvus = get_milvus(host=ip, port=port, handler=handler) except Exception as e: logging.getLogger().error(str(e)) pytest.exit("Milvus server can not connected, exit pytest ...") def fin(): try: milvus.disconnect() - except: - pass + except Exception as e: + logging.getLogger().info(str(e)) request.addfinalizer(fin) return milvus @@ -67,7 +60,10 @@ def dis_connect(request): port = request.config.getoption("--port") http_port = request.config.getoption("--http-port") handler = request.config.getoption("--handler") - milvus = get_milvus(handler=handler) + if handler == "HTTP": + port = http_port + milvus = get_milvus(host=ip, port=port, handler=handler) + milvus.disconnect() return milvus @@ -85,8 +81,13 @@ def args(request): @pytest.fixture(scope="module") def milvus(request): + ip = request.config.getoption("--ip") + port = request.config.getoption("--port") + http_port = request.config.getoption("--http-port") handler = request.config.getoption("--handler") - return get_milvus(handler=handler) + if handler == "HTTP": + port = http_port + return get_milvus(host=ip, port=port, handler=handler) @pytest.fixture(scope="function") @@ -98,8 +99,10 @@ def collection(request, connect): 'dimension': dim, 'index_file_size': index_file_size, 'metric_type': MetricType.L2} - status = connect.create_collection(param) - # logging.getLogger().info(status) + result = connect.create_collection(param, timeout=timeout) + status = result + if isinstance(result, tuple): + status = result[0] if not status.OK(): pytest.exit("collection can not be created, exit pytest ...") @@ -122,8 +125,10 @@ def ip_collection(request, connect): 'dimension': dim, 'index_file_size': index_file_size, 'metric_type': MetricType.IP} - status = connect.create_collection(param) - # logging.getLogger().info(status) + result = connect.create_collection(param, timeout=timeout) + status = result + if isinstance(result, tuple): + status = result[0] if not status.OK(): pytest.exit("collection can not be created, exit pytest ...") @@ -146,8 +151,10 @@ def jac_collection(request, connect): 'dimension': dim, 'index_file_size': index_file_size, 'metric_type': MetricType.JACCARD} - status = connect.create_collection(param) - # logging.getLogger().info(status) + result = connect.create_collection(param, timeout=timeout) + status = result + if isinstance(result, tuple): + status = result[0] if not status.OK(): pytest.exit("collection can not be created, exit pytest ...") @@ -169,8 +176,10 @@ def ham_collection(request, connect): 'dimension': dim, 'index_file_size': index_file_size, 'metric_type': MetricType.HAMMING} - status = connect.create_collection(param) - # logging.getLogger().info(status) + result = connect.create_collection(param, timeout=timeout) + status = result + if isinstance(result, tuple): + status = result[0] if not status.OK(): pytest.exit("collection can not be created, exit pytest ...") @@ -192,8 +201,10 @@ def tanimoto_collection(request, connect): 'dimension': dim, 'index_file_size': index_file_size, 'metric_type': MetricType.TANIMOTO} - status = connect.create_collection(param) - # logging.getLogger().info(status) + result = connect.create_collection(param, timeout=timeout) + status = result + if isinstance(result, tuple): + status = result[0] if not status.OK(): pytest.exit("collection can not be created, exit pytest ...") @@ -214,8 +225,10 @@ def substructure_collection(request, connect): 'dimension': dim, 'index_file_size': index_file_size, 'metric_type': MetricType.SUBSTRUCTURE} - status = connect.create_collection(param) - # logging.getLogger().info(status) + result = connect.create_collection(param, timeout=timeout) + status = result + if isinstance(result, tuple): + status = result[0] if not status.OK(): pytest.exit("collection can not be created, exit pytest ...") @@ -236,8 +249,10 @@ def superstructure_collection(request, connect): 'dimension': dim, 'index_file_size': index_file_size, 'metric_type': MetricType.SUPERSTRUCTURE} - status = connect.create_collection(param) - # logging.getLogger().info(status) + result = connect.create_collection(param, timeout=timeout) + status = result + if isinstance(result, tuple): + status = result[0] if not status.OK(): pytest.exit("collection can not be created, exit pytest ...") diff --git a/tests/milvus_python_test/requirements.txt b/tests/milvus_python_test/requirements.txt index 621cfc35..c40d5e3a 100644 --- a/tests/milvus_python_test/requirements.txt +++ b/tests/milvus_python_test/requirements.txt @@ -6,5 +6,6 @@ pytest-repeat==0.8.0 allure-pytest==2.7.0 pytest-print==0.1.2 pytest-level==0.1.1 +pytest-xdist==1.23.2 scikit-learn>=0.19.1 pymilvus-test>=0.2.0 diff --git a/tests/milvus_python_test/test_add_vectors.py b/tests/milvus_python_test/test_add_vectors.py index d0e43008..faae9cdb 100644 --- a/tests/milvus_python_test/test_add_vectors.py +++ b/tests/milvus_python_test/test_add_vectors.py @@ -1198,12 +1198,8 @@ class TestAddAdvance: scope="function", params=[ 1, - 10, - 100, 1000, - pytest.param(5000 - 1, marks=pytest.mark.xfail), - pytest.param(5000, marks=pytest.mark.xfail), - pytest.param(5000 + 1, marks=pytest.mark.xfail), + 6000 ], ) def insert_count(self, request): diff --git a/tests/milvus_python_test/test_collection.py b/tests/milvus_python_test/test_collection.py index c09a3e33..0eccbcbc 100644 --- a/tests/milvus_python_test/test_collection.py +++ b/tests/milvus_python_test/test_collection.py @@ -465,7 +465,7 @@ class TestCollection: assert the value returned by delete method expected: create ok and delete ok ''' - loops = 5 + loops = 2 for i in range(loops): collection_name = "test_collection" param = {'collection_name': collection_name, @@ -474,7 +474,7 @@ class TestCollection: 'metric_type': MetricType.L2} connect.create_collection(param) status = connect.drop_collection(collection_name) - time.sleep(2) + time.sleep(1) assert status.OK() def test_delete_create_collection_repeatedly_ip(self, connect): @@ -484,7 +484,7 @@ class TestCollection: assert the value returned by delete method expected: create ok and delete ok ''' - loops = 5 + loops = 2 for i in range(loops): collection_name = "test_collection" param = {'collection_name': collection_name, @@ -493,7 +493,7 @@ class TestCollection: 'metric_type': MetricType.IP} connect.create_collection(param) status = connect.drop_collection(collection_name) - time.sleep(2) + time.sleep(1) assert status.OK() # TODO: enable @@ -760,6 +760,7 @@ class TestCollection: with pytest.raises(Exception) as e: status = dis_connect.show_collections() + @pytest.mark.level(2) def test_show_collections_no_collection(self, connect): ''' target: test show collections is correct or not, if no collection in db @@ -1078,7 +1079,6 @@ def gen_sequence(): yield x class TestCollectionLogic(object): - @pytest.mark.parametrize("logic_seq", gen_sequence()) @pytest.mark.level(2) def test_logic(self, connect, logic_seq, args): diff --git a/tests/milvus_python_test/test_collection_count.py b/tests/milvus_python_test/test_collection_count.py index e34623aa..a669c9a4 100644 --- a/tests/milvus_python_test/test_collection_count.py +++ b/tests/milvus_python_test/test_collection_count.py @@ -23,7 +23,7 @@ class TestCollectionCount: params=[ 1, 5000, - 100000, + 20000, ], ) def add_vectors_nb(self, request): @@ -239,7 +239,7 @@ class TestCollectionCountIP: params=[ 1, 5000, - 100000, + 20000, ], ) def add_vectors_nb(self, request): @@ -384,7 +384,7 @@ class TestCollectionCountJAC: params=[ 1, 5000, - 100000, + 20000, ], ) def add_vectors_nb(self, request): @@ -495,7 +495,7 @@ class TestCollectionCountBinary: params=[ 1, 5000, - 100000, + 20000, ], ) def add_vectors_nb(self, request): @@ -689,7 +689,7 @@ class TestCollectionCountTANIMOTO: params=[ 1, 5000, - 100000, + 20000, ], ) def add_vectors_nb(self, request): diff --git a/tests/milvus_python_test/test_connect.py b/tests/milvus_python_test/test_connect.py index 14704b65..947b9bbb 100644 --- a/tests/milvus_python_test/test_connect.py +++ b/tests/milvus_python_test/test_connect.py @@ -36,16 +36,12 @@ class TestConnect: expected: raise an error after disconnected ''' if not connect.connected(): - milvus = get_milvus(args["handler"]) - uri_value = "tcp://%s:%s" % (args["ip"], args["port"]) - milvus.connect(uri=uri_value, timeout=5) - res = milvus.disconnect() with pytest.raises(Exception) as e: - res = milvus.disconnect() + connect.disconnect() else: - res = connect.disconnect() + connect.disconnect() with pytest.raises(Exception) as e: - res = connect.disconnect() + connect.disconnect() def test_connect_correct_ip_port(self, args): ''' @@ -53,8 +49,7 @@ class TestConnect: method: set correct ip and port expected: connected is True ''' - milvus = get_milvus(args["handler"]) - milvus.connect(host=args["ip"], port=args["port"]) + milvus = get_milvus(args["ip"], args["port"], handler=args["handler"]) assert milvus.connected() def test_connect_connected(self, args): @@ -63,8 +58,7 @@ class TestConnect: method: set correct ip and port expected: connected is False ''' - milvus = get_milvus(args["handler"]) - milvus.connect(host=args["ip"], port=args["port"]) + milvus = get_milvus(args["ip"], args["port"], handler=args["handler"]) milvus.disconnect() assert not milvus.connected() @@ -75,7 +69,7 @@ class TestConnect: method: set host localhost expected: connected is True ''' - milvus = get_milvus(args["handler"]) + milvus = get_milvus(args["ip"], args["port"], args["handler"]) milvus.connect(host='localhost', port=args["port"]) assert milvus.connected() @@ -86,11 +80,10 @@ class TestConnect: method: set host null expected: not use default ip, connected is False ''' - milvus = get_milvus(args["handler"]) ip = "" with pytest.raises(Exception) as e: - milvus.connect(host=ip, port=args["port"], timeout=1) - assert not milvus.connected() + milvus = get_milvus(ip, args["port"], args["handler"]) + assert not milvus.connected() def test_connect_uri(self, args): ''' @@ -98,9 +91,8 @@ class TestConnect: method: uri format and value are both correct expected: connected is True ''' - milvus = get_milvus(args["handler"]) uri_value = "tcp://%s:%s" % (args["ip"], args["port"]) - milvus.connect(uri=uri_value) + milvus = get_milvus(args["ip"], args["port"], uri=uri_value, handler=args["handler"]) assert milvus.connected() def test_connect_uri_null(self, args): @@ -109,44 +101,14 @@ class TestConnect: method: uri set null expected: connected is True ''' - milvus = get_milvus(args["handler"]) uri_value = "" - if self.local_ip(args): - milvus.connect(uri=uri_value, timeout=1) + milvus = get_milvus(uri=uri_value, handler=args["handler"]) assert milvus.connected() else: with pytest.raises(Exception) as e: - milvus.connect(uri=uri_value, timeout=1) - assert not milvus.connected() - - @pytest.mark.level(2) - @pytest.mark.timeout(CONNECT_TIMEOUT) - def test_connect_wrong_uri_wrong_port_null(self, args): - ''' - target: test uri connect with port value wouldn't connected - method: set uri port null - expected: connected is True - ''' - milvus = get_milvus(args["handler"]) - uri_value = "tcp://%s:" % args["ip"] - with pytest.raises(Exception) as e: - milvus.connect(uri=uri_value, timeout=1) - - @pytest.mark.level(2) - @pytest.mark.timeout(CONNECT_TIMEOUT) - def test_connect_wrong_uri_wrong_ip_null(self, args): - ''' - target: test uri connect with ip value wouldn't connected - method: set uri ip null - expected: connected is True - ''' - milvus = get_milvus(args["handler"]) - uri_value = "tcp://:%s" % args["port"] - - with pytest.raises(Exception) as e: - milvus.connect(uri=uri_value, timeout=1) - assert not milvus.connected() + milvus = get_milvus(uri=uri_value, handler=args["handler"]) + assert not milvus.connected() # disable def _test_connect_with_multiprocess(self, args): @@ -166,7 +128,7 @@ class TestConnect: assert milvus.connected() for i in range(process_num): - milvus = get_milvus(args["handler"]) + milvus = get_milvus(args["ip"], args["port"], args["handler"]) p = Process(target=connect, args=(milvus, )) processes.append(p) p.start() @@ -179,26 +141,12 @@ class TestConnect: method: connect again expected: status.code is 0, and status.message shows have connected already ''' - milvus = get_milvus(args["handler"]) uri_value = "tcp://%s:%s" % (args["ip"], args["port"]) + milvus = Milvus(uri=uri_value, handler=args["handler"]) milvus.connect(uri=uri_value, timeout=5) milvus.connect(uri=uri_value, timeout=5) assert milvus.connected() - def test_connect_disconnect_repeatedly_once(self, args): - ''' - target: test connect and disconnect repeatedly - method: disconnect, and then connect, assert connect status - expected: status.code is 0 - ''' - milvus = get_milvus(args["handler"]) - uri_value = "tcp://%s:%s" % (args["ip"], args["port"]) - milvus.connect(uri=uri_value) - - milvus.disconnect() - milvus.connect(uri=uri_value) - assert milvus.connected() - def test_connect_disconnect_repeatedly_times(self, args): ''' target: test connect and disconnect for 10 times repeatedly @@ -206,13 +154,10 @@ class TestConnect: expected: status.code is 0 ''' times = 10 - milvus = get_milvus(args["handler"]) - uri_value = "tcp://%s:%s" % (args["ip"], args["port"]) - milvus.connect(uri=uri_value, timeout=5) for i in range(times): + milvus = get_milvus(args["ip"], args["port"], handler=args["handler"]) milvus.disconnect() - milvus.connect(uri=uri_value, timeout=5) - assert milvus.connected() + assert not milvus.connected() # TODO: enable def _test_connect_disconnect_with_multiprocess(self, args): @@ -232,36 +177,13 @@ class TestConnect: assert milvus.connected() for i in range(process_num): - milvus = get_milvus(args["handler"]) + milvus = get_milvus(args["ip"], args["port"], args["handler"]) p = Process(target=connect, args=(milvus, )) processes.append(p) p.start() for p in processes: p.join() - def test_connect_param_priority_no_port(self, args): - ''' - target: both host_ip_port / uri are both given, if port is null, use the uri params - method: port set "", check if wrong uri connection is ok - expected: connect raise an exception and connected is false - ''' - milvus = get_milvus(args["handler"]) - uri_value = "tcp://%s:39540" % args["ip"] - with pytest.raises(Exception) as e: - milvus.connect(host=args["ip"], port="", uri=uri_value) - - def test_connect_param_priority_uri(self, args): - ''' - target: both host_ip_port / uri are both given, if host is null, use the uri params - method: host set "", check if correct uri connection is ok - expected: connected is False - ''' - milvus = get_milvus(args["handler"]) - uri_value = "tcp://%s:%s" % (args["ip"], args["port"]) - with pytest.raises(Exception) as e: - milvus.connect(host="", port=args["port"], uri=uri_value, timeout=1) - assert not milvus.connected() - # Disable, (issue: https://github.com/milvus-io/milvus/issues/288) def _test_connect_param_priority_both_hostip_uri(self, args): ''' @@ -269,7 +191,7 @@ class TestConnect: method: check if wrong uri connection is ok expected: connect raise an exception and connected is false ''' - milvus = get_milvus(args["handler"]) + milvus = get_milvus(args["ip"], args["port"], args["handler"]) uri_value = "tcp://%s:%s" % (args["ip"], args["port"]) with pytest.raises(Exception) as e: res = milvus.connect(host=args["ip"], port=39540, uri=uri_value, timeout=1) @@ -325,11 +247,10 @@ class TestConnectIPInvalid(object): @pytest.mark.level(2) @pytest.mark.timeout(CONNECT_TIMEOUT) def test_connect_with_invalid_ip(self, args, get_invalid_ip): - milvus = get_milvus(args["handler"]) ip = get_invalid_ip with pytest.raises(Exception) as e: - milvus.connect(host=ip, port=args["port"], timeout=1) - assert not milvus.connected() + milvus = get_milvus(ip, args["port"], args["handler"]) + assert not milvus.connected() class TestConnectPortInvalid(object): @@ -352,11 +273,10 @@ class TestConnectPortInvalid(object): method: set port in gen_invalid_ports expected: connected is False ''' - milvus = get_milvus(args["handler"]) port = get_invalid_port with pytest.raises(Exception) as e: - milvus.connect(host=args["ip"], port=port, timeout=1) - assert not milvus.connected() + milvus = get_milvus(args["ip"], port, args["handler"]) + assert not milvus.connected() class TestConnectURIInvalid(object): @@ -378,8 +298,7 @@ class TestConnectURIInvalid(object): method: set port in gen_invalid_uris expected: connected is False ''' - milvus = get_milvus(args["handler"]) uri_value = get_invalid_uri with pytest.raises(Exception) as e: - milvus.connect(uri=uri_value, timeout=1) - assert not milvus.connected() + milvus = get_milvus(uri=uri_value, handler=args["handler"]) + assert not milvus.connected() diff --git a/tests/milvus_python_test/test_index.py b/tests/milvus_python_test/test_index.py index d6e52e19..289fca52 100644 --- a/tests/milvus_python_test/test_index.py +++ b/tests/milvus_python_test/test_index.py @@ -186,6 +186,7 @@ class TestIndexBase: assert len(result[0]) == top_k assert result[0][0].distance == 0.0 + @pytest.mark.level(2) @pytest.mark.timeout(BUILD_TIMEOUT) def test_create_index_multithread_multicollection(self, connect, args): ''' @@ -399,6 +400,7 @@ class TestIndexBase: status, ids = connect.add_vectors(collection, vectors) assert status.OK() + @pytest.mark.level(2) @pytest.mark.timeout(BUILD_TIMEOUT) def test_create_same_index_repeatedly(self, connect, collection, get_simple_index): ''' @@ -412,6 +414,7 @@ class TestIndexBase: status = connect.create_index(collection, index_type, index_param) assert status.OK() + @pytest.mark.level(2) @pytest.mark.timeout(BUILD_TIMEOUT) def test_create_different_index_repeatedly(self, connect, collection): ''' @@ -568,6 +571,7 @@ class TestIndexBase: assert result._collection_name == collection assert result._index_type == IndexType.FLAT + @pytest.mark.level(2) def test_drop_index_repeatly(self, connect, collection, get_simple_index): ''' target: test drop index repeatly @@ -635,6 +639,7 @@ class TestIndexBase: logging.getLogger().info(status) assert status.OK() + @pytest.mark.level(2) def test_create_drop_index_repeatly(self, connect, collection, get_simple_index): ''' target: test create / drop index repeatly, use the same index params @@ -760,7 +765,7 @@ class TestIndexIP: def test_create_index_search_with_query_vectors(self, connect, ip_collection, get_simple_index): ''' target: test create index interface, search with more query vectors - method: create collection and add vectors in it, create index + method: create collection and add vectors in it, create index, with no manual flush expected: return code equals to 0, and search success ''' index_param = get_simple_index["index_param"] @@ -991,8 +996,8 @@ class TestIndexIP: logging.getLogger().info(get_simple_index) status = connect.create_partition(ip_collection, tag) status = connect.create_partition(ip_collection, new_tag) - status, ids = connect.add_vectors(ip_collection, vectors, partition_tag=tag) - status, ids = connect.add_vectors(ip_collection, vectors, partition_tag=new_tag) + # status, ids = connect.add_vectors(ip_collection, vectors, partition_tag=tag) + # status, ids = connect.add_vectors(ip_collection, vectors, partition_tag=new_tag) status = connect.create_index(ip_collection, index_type, index_param) status, result = connect.describe_index(ip_collection) logging.getLogger().info(result) @@ -1135,6 +1140,7 @@ class TestIndexIP: assert result._collection_name == ip_collection assert result._index_type == IndexType.FLAT + @pytest.mark.level(2) def test_drop_index_repeatly(self, connect, ip_collection, get_simple_index): ''' target: test drop index repeatly @@ -1190,6 +1196,7 @@ class TestIndexIP: logging.getLogger().info(status) assert status.OK() + @pytest.mark.level(2) def test_create_drop_index_repeatly(self, connect, ip_collection, get_simple_index): ''' target: test create / drop index repeatly, use the same index params @@ -1821,4 +1828,3 @@ class TestCreateIndexParamsInvalid(object): logging.getLogger().info(result) assert result._collection_name == collection assert result._index_type == IndexType.FLAT - diff --git a/tests/milvus_python_test/test_partition.py b/tests/milvus_python_test/test_partition.py index cf8cd9db..aa37c3d4 100644 --- a/tests/milvus_python_test/test_partition.py +++ b/tests/milvus_python_test/test_partition.py @@ -33,6 +33,19 @@ class TestCreateBase: status = connect.create_partition(collection, tag) assert status.OK() + def _test_create_partition_limit(self, connect, collection): + ''' + target: test create partitions, check status returned + method: call function: create_partition for 4097 times + expected: status not ok + ''' + for i in range(4096): + tag_tmp = gen_unique_str() + status = connect.create_partition(collection, tag_tmp) + assert status.OK() + status = connect.create_partition(collection, tag) + assert not status.OK() + def test_create_partition_repeat(self, connect, collection): ''' target: test create partition, check status returned @@ -259,6 +272,7 @@ class TestDropBase: status = connect.drop_partition(new_collection, tag) assert not status.OK() + @pytest.mark.level(2) def test_drop_partition_repeatedly(self, connect, collection): ''' target: test drop partition twice, check status and partition if existed diff --git a/tests/milvus_python_test/test_search_vectors.py b/tests/milvus_python_test/test_search_vectors.py index 98d48d9d..ccffe569 100644 --- a/tests/milvus_python_test/test_search_vectors.py +++ b/tests/milvus_python_test/test_search_vectors.py @@ -136,7 +136,7 @@ class TestSearchBase: """ @pytest.fixture( scope="function", - params=[1, 99, 1024, 2048, 2049] + params=[1, 99, 1024, 2049] ) def get_top_k(self, request): yield request.param @@ -166,20 +166,21 @@ class TestSearchBase: method: search with the given vectors, check the result expected: search status ok, and the length of the result is top_k ''' + top_k = 10 index_param = get_simple_index["index_param"] index_type = get_simple_index["index_type"] logging.getLogger().info(get_simple_index) + if index_type == IndexType.IVF_PQ: + pytest.skip("Skip PQ") + vectors, ids = self.init_data(connect, collection) status = connect.create_index(collection, index_type, index_param) query_vec = [vectors[0]] - top_k = 10 search_param = get_search_param(index_type) status, result = connect.search_vectors(collection, top_k, query_vec, params=search_param) logging.getLogger().info(result) if top_k <= 1024: assert status.OK() - if index_type == IndexType.IVF_PQ: - return assert len(result[0]) == min(len(vectors), top_k) assert check_result(result[0], ids[0]) assert result[0][0].distance <= epsilon @@ -192,21 +193,20 @@ class TestSearchBase: method: search with the given vectors, check the result expected: search status ok, and the length of the result is top_k ''' + top_k = 10 index_param = get_simple_index["index_param"] index_type = get_simple_index["index_type"] logging.getLogger().info(get_simple_index) + if index_type == IndexType.IVF_PQ: + pytest.skip("Skip PQ") + vectors, ids = self.init_data(connect, collection) status = connect.create_index(collection, index_type, index_param) - query_vec = [] - for i in range (1200): - query_vec.append(vectors[i]) - top_k = 10 + query_vec = vectors[:1000] search_param = get_search_param(index_type) status, result = connect.search_vectors(collection, top_k, query_vec, params=search_param) logging.getLogger().info(result) assert status.OK() - if index_type == IndexType.IVF_PQ: - return assert len(result[0]) == min(len(vectors), top_k) assert check_result(result[0], ids[0]) assert result[0][0].distance <= epsilon @@ -217,22 +217,23 @@ class TestSearchBase: method: add vectors into collection, search with the given vectors, check the result expected: search status ok, and the length of the result is top_k, search collection with partition tag return empty ''' + top_k = 10 index_param = get_simple_index["index_param"] index_type = get_simple_index["index_type"] logging.getLogger().info(get_simple_index) + if index_type == IndexType.IVF_PQ: + pytest.skip("Skip PQ") status = connect.create_partition(collection, tag) vectors, ids = self.init_data(connect, collection) status = connect.create_index(collection, index_type, index_param) query_vec = [vectors[0]] - top_k = 10 search_param = get_search_param(index_type) status, result = connect.search_vectors(collection, top_k, query_vec, params=search_param) logging.getLogger().info(result) assert status.OK() - if(index_type != IndexType.IVF_PQ): - assert len(result[0]) == min(len(vectors), top_k) - assert check_result(result[0], ids[0]) - assert result[0][0].distance <= epsilon + assert len(result[0]) == min(len(vectors), top_k) + assert check_result(result[0], ids[0]) + assert result[0][0].distance <= epsilon status, result = connect.search_vectors(collection, top_k, query_vec, partition_tags=[tag], params=search_param) logging.getLogger().info(result) assert status.OK() @@ -244,14 +245,17 @@ class TestSearchBase: method: search partition with the given vectors, check the result expected: search status ok, and the length of the result is 0 ''' + top_k = 10 index_param = get_simple_index["index_param"] index_type = get_simple_index["index_type"] logging.getLogger().info(get_simple_index) + if index_type == IndexType.IVF_PQ: + pytest.skip("Skip PQ") + status = connect.create_partition(collection, tag) vectors, ids = self.init_data(connect, collection) status = connect.create_index(collection, index_type, index_param) query_vec = [vectors[0]] - top_k = 10 search_param = get_search_param(index_type) status, result = connect.search_vectors(collection, top_k, query_vec, partition_tags=[tag], params=search_param) logging.getLogger().info(result) @@ -264,29 +268,29 @@ class TestSearchBase: method: search with the given vectors, check the result expected: search status ok, and the length of the result is top_k ''' + top_k = 10 index_param = get_simple_index["index_param"] index_type = get_simple_index["index_type"] logging.getLogger().info(get_simple_index) + if index_type == IndexType.IVF_PQ: + pytest.skip("Skip PQ") status = connect.create_partition(collection, tag) vectors, ids = self.init_data(connect, collection, partition_tags=tag) status = connect.create_index(collection, index_type, index_param) query_vec = [vectors[0]] - top_k = 10 search_param = get_search_param(index_type) status, result = connect.search_vectors(collection, top_k, query_vec, params=search_param) logging.getLogger().info(result) assert status.OK() - if(index_type != IndexType.IVF_PQ): - assert len(result[0]) == min(len(vectors), top_k) - assert check_result(result[0], ids[0]) - assert result[0][0].distance <= epsilon + assert len(result[0]) == min(len(vectors), top_k) + assert check_result(result[0], ids[0]) + assert result[0][0].distance <= epsilon status, result = connect.search_vectors(collection, top_k, query_vec, partition_tags=[tag], params=search_param) logging.getLogger().info(result) assert status.OK() - if(index_type != IndexType.IVF_PQ): - assert len(result[0]) == min(len(vectors), top_k) - assert check_result(result[0], ids[0]) - assert result[0][0].distance <= epsilon + assert len(result[0]) == min(len(vectors), top_k) + assert check_result(result[0], ids[0]) + assert result[0][0].distance <= epsilon def test_search_l2_index_params_partition_C(self, connect, collection, get_simple_index): ''' @@ -297,6 +301,8 @@ class TestSearchBase: index_param = get_simple_index["index_param"] index_type = get_simple_index["index_type"] logging.getLogger().info(get_simple_index) + if index_type == IndexType.IVF_PQ: + pytest.skip("Skip PQ") status = connect.create_partition(collection, tag) vectors, ids = self.init_data(connect, collection, partition_tags=tag) status = connect.create_index(collection, index_type, index_param) @@ -306,10 +312,9 @@ class TestSearchBase: status, result = connect.search_vectors(collection, top_k, query_vec, partition_tags=[tag, "new_tag"], params=search_param) logging.getLogger().info(result) assert status.OK() - if(index_type != IndexType.IVF_PQ): - assert len(result[0]) == min(len(vectors), top_k) - assert check_result(result[0], ids[0]) - assert result[0][0].distance <= epsilon + assert len(result[0]) == min(len(vectors), top_k) + assert check_result(result[0], ids[0]) + assert result[0][0].distance <= epsilon def test_search_l2_index_params_partition_D(self, connect, collection, get_simple_index): ''' @@ -336,9 +341,12 @@ class TestSearchBase: method: search collection with the given vectors and tags, check the result expected: search status ok, and the length of the result is top_k ''' + top_k = 10 new_tag = "new_tag" - index_param = get_simple_index["index_param"] index_type = get_simple_index["index_type"] + index_param = get_simple_index["index_param"] + if index_type == IndexType.IVF_PQ: + pytest.skip("Skip PQ") logging.getLogger().info(get_simple_index) status = connect.create_partition(collection, tag) status = connect.create_partition(collection, new_tag) @@ -346,24 +354,21 @@ class TestSearchBase: new_vectors, new_ids = self.init_data(connect, collection, nb=6001, partition_tags=new_tag) status = connect.create_index(collection, index_type, index_param) query_vec = [vectors[0], new_vectors[0]] - top_k = 10 search_param = get_search_param(index_type) status, result = connect.search_vectors(collection, top_k, query_vec, partition_tags=[tag, new_tag], params=search_param) logging.getLogger().info(result) assert status.OK() - if(index_type != IndexType.IVF_PQ): - assert len(result[0]) == min(len(vectors), top_k) - assert check_result(result[0], ids[0]) - assert check_result(result[1], new_ids[0]) - assert result[0][0].distance <= epsilon - assert result[1][0].distance <= epsilon + assert len(result[0]) == min(len(vectors), top_k) + assert check_result(result[0], ids[0]) + assert check_result(result[1], new_ids[0]) + assert result[0][0].distance <= epsilon + assert result[1][0].distance <= epsilon status, result = connect.search_vectors(collection, top_k, query_vec, partition_tags=[new_tag], params=search_param) logging.getLogger().info(result) assert status.OK() - if(index_type != IndexType.IVF_PQ): - assert len(result[0]) == min(len(vectors), top_k) - assert check_result(result[1], new_ids[0]) - assert result[1][0].distance <= epsilon + assert len(result[0]) == min(len(vectors), top_k) + assert check_result(result[1], new_ids[0]) + assert result[1][0].distance <= epsilon def test_search_l2_index_params_partition_F(self, connect, collection, get_simple_index): ''' @@ -376,6 +381,8 @@ class TestSearchBase: index_param = get_simple_index["index_param"] index_type = get_simple_index["index_type"] logging.getLogger().info(get_simple_index) + if index_type == IndexType.IVF_PQ: + pytest.skip("Skip PQ") status = connect.create_partition(collection, tag) status = connect.create_partition(collection, new_tag) vectors, ids = self.init_data(connect, collection, partition_tags=tag) @@ -387,15 +394,13 @@ class TestSearchBase: status, result = connect.search_vectors(collection, top_k, query_vec, partition_tags=["new(.*)"], params=search_param) logging.getLogger().info(result) assert status.OK() - if(index_type != IndexType.IVF_PQ): - assert result[0][0].distance > epsilon - assert result[1][0].distance <= epsilon + assert result[0][0].distance > epsilon + assert result[1][0].distance <= epsilon status, result = connect.search_vectors(collection, top_k, query_vec, partition_tags=["(.*)tag"], params=search_param) logging.getLogger().info(result) assert status.OK() - if(index_type != IndexType.IVF_PQ): - assert result[0][0].distance <= epsilon - assert result[1][0].distance <= epsilon + assert result[0][0].distance <= epsilon + assert result[1][0].distance <= epsilon def test_search_ip_index_params(self, connect, ip_collection, get_simple_index): ''' @@ -403,27 +408,23 @@ class TestSearchBase: method: search with the given vectors, check the result expected: search status ok, and the length of the result is top_k ''' + top_k = 10 index_param = get_simple_index["index_param"] index_type = get_simple_index["index_type"] logging.getLogger().info(get_simple_index) - if index_type == IndexType.RNSG: - pytest.skip("rnsg not support in ip") + if index_type in [IndexType.RNSG, IndexType.IVF_PQ]: + pytest.skip("rnsg not support in ip, skip pq") + vectors, ids = self.init_data(connect, ip_collection) status = connect.create_index(ip_collection, index_type, index_param) query_vec = [vectors[0]] - top_k = 10 search_param = get_search_param(index_type) status, result = connect.search_vectors(ip_collection, top_k, query_vec, params=search_param) logging.getLogger().info(result) - - if top_k <= 1024: - assert status.OK() - if(index_type != IndexType.IVF_PQ): - assert len(result[0]) == min(len(vectors), top_k) - assert check_result(result[0], ids[0]) - assert result[0][0].distance >= 1 - gen_inaccuracy(result[0][0].distance) - else: - assert not status.OK() + assert status.OK() + assert len(result[0]) == min(len(vectors), top_k) + assert check_result(result[0], ids[0]) + assert result[0][0].distance >= 1 - gen_inaccuracy(result[0][0].distance) def test_search_ip_large_nq_index_params(self, connect, ip_collection, get_simple_index): ''' @@ -434,8 +435,8 @@ class TestSearchBase: index_param = get_simple_index["index_param"] index_type = get_simple_index["index_type"] logging.getLogger().info(get_simple_index) - if index_type == IndexType.RNSG: - pytest.skip("rnsg not support in ip") + if index_type in [IndexType.RNSG, IndexType.IVF_PQ]: + pytest.skip("rnsg not support in ip, skip pq") vectors, ids = self.init_data(connect, ip_collection) status = connect.create_index(ip_collection, index_type, index_param) query_vec = [] @@ -446,10 +447,9 @@ class TestSearchBase: status, result = connect.search_vectors(ip_collection, top_k, query_vec, params=search_param) logging.getLogger().info(result) assert status.OK() - if(index_type != IndexType.IVF_PQ): - assert len(result[0]) == min(len(vectors), top_k) - assert check_result(result[0], ids[0]) - assert result[0][0].distance >= 1 - gen_inaccuracy(result[0][0].distance) + assert len(result[0]) == min(len(vectors), top_k) + assert check_result(result[0], ids[0]) + assert result[0][0].distance >= 1 - gen_inaccuracy(result[0][0].distance) def test_search_ip_index_params_partition(self, connect, ip_collection, get_simple_index): ''' @@ -457,24 +457,24 @@ class TestSearchBase: method: search with the given vectors, check the result expected: search status ok, and the length of the result is top_k ''' + top_k = 10 index_param = get_simple_index["index_param"] index_type = get_simple_index["index_type"] logging.getLogger().info(index_param) - if index_type == IndexType.RNSG: - pytest.skip("rnsg not support in ip") + if index_type in [IndexType.RNSG, IndexType.IVF_PQ]: + pytest.skip("rnsg not support in ip, skip pq") + status = connect.create_partition(ip_collection, tag) vectors, ids = self.init_data(connect, ip_collection) status = connect.create_index(ip_collection, index_type, index_param) query_vec = [vectors[0]] - top_k = 10 search_param = get_search_param(index_type) status, result = connect.search_vectors(ip_collection, top_k, query_vec, params=search_param) logging.getLogger().info(result) assert status.OK() - if(index_type != IndexType.IVF_PQ): - assert len(result[0]) == min(len(vectors), top_k) - assert check_result(result[0], ids[0]) - assert result[0][0].distance >= 1 - gen_inaccuracy(result[0][0].distance) + assert len(result[0]) == min(len(vectors), top_k) + assert check_result(result[0], ids[0]) + assert result[0][0].distance >= 1 - gen_inaccuracy(result[0][0].distance) status, result = connect.search_vectors(ip_collection, top_k, query_vec, partition_tags=[tag], params=search_param) logging.getLogger().info(result) assert status.OK() @@ -486,24 +486,24 @@ class TestSearchBase: method: search with the given vectors and tag, check the result expected: search status ok, and the length of the result is top_k ''' + top_k = 10 index_param = get_simple_index["index_param"] index_type = get_simple_index["index_type"] logging.getLogger().info(index_param) - if index_type == IndexType.RNSG: - pytest.skip("rnsg not support in ip") + if index_type in [IndexType.RNSG, IndexType.IVF_PQ]: + pytest.skip("rnsg not support in ip, skip pq") + status = connect.create_partition(ip_collection, tag) vectors, ids = self.init_data(connect, ip_collection, partition_tags=tag) status = connect.create_index(ip_collection, index_type, index_param) query_vec = [vectors[0]] - top_k = 10 search_param = get_search_param(index_type) status, result = connect.search_vectors(ip_collection, top_k, query_vec, partition_tags=[tag], params=search_param) logging.getLogger().info(result) assert status.OK() - if(index_type != IndexType.IVF_PQ): - assert len(result[0]) == min(len(vectors), top_k) - assert check_result(result[0], ids[0]) - assert result[0][0].distance >= 1 - gen_inaccuracy(result[0][0].distance) + assert len(result[0]) == min(len(vectors), top_k) + assert check_result(result[0], ids[0]) + assert result[0][0].distance >= 1 - gen_inaccuracy(result[0][0].distance) @pytest.mark.level(2) def test_search_vectors_without_connect(self, dis_connect, collection): @@ -833,6 +833,7 @@ class TestSearchBase: for th in threads: th.join() + @pytest.mark.level(2) @pytest.mark.timeout(30) def test_search_concurrent_multithreads(self, args): ''' @@ -851,8 +852,7 @@ class TestSearchBase: 'index_type': IndexType.FLAT, 'store_raw_vector': False} # create collection - milvus = get_milvus(args["handler"]) - milvus.connect(uri=uri, timeout=5) + milvus = get_milvus(args["ip"], args["port"], handler=args["handler"]) milvus.create_collection(param) vectors, ids = self.init_data(milvus, collection, nb=nb) query_vecs = vectors[nb//2:nb] @@ -864,8 +864,7 @@ class TestSearchBase: assert result[i][0].distance == 0.0 for i in range(threads_num): - milvus = get_milvus(args["handler"]) - milvus.connect(uri=uri, timeout=5) + milvus = get_milvus(args["ip"], args["port"], handler=args["handler"]) t = threading.Thread(target=search, args=(milvus, )) threads.append(t) t.start() @@ -892,8 +891,7 @@ class TestSearchBase: 'index_type': IndexType.FLAT, 'store_raw_vector': False} # create collection - milvus = get_milvus(args["handler"]) - milvus.connect(uri=uri) + milvus = get_milvus(args["ip"], args["port"], handler=args["handler"]) milvus.create_collection(param) vectors, ids = self.init_data(milvus, collection, nb=nb) query_vecs = vectors[nb//2:nb] @@ -905,8 +903,7 @@ class TestSearchBase: assert result[i][0].distance == 0.0 for i in range(process_num): - milvus = get_milvus(args["handler"]) - milvus.connect(uri=uri) + milvus = get_milvus(args["ip"], args["port"], handler=args["handler"]) p = Process(target=search, args=(milvus, )) processes.append(p) p.start() @@ -932,8 +929,7 @@ class TestSearchBase: 'index_file_size': 10, 'metric_type': MetricType.L2} # create collection - milvus = get_milvus(args["handler"]) - milvus.connect(uri=uri, timeout=5) + milvus = get_milvus(args["ip"], args["port"], handler=args["handler"]) milvus.create_collection(param) status, ids = milvus.add_vectors(collection, vectors) assert status.OK() @@ -973,8 +969,7 @@ class TestSearchBase: 'index_file_size': 10, 'metric_type': MetricType.L2} # create collection - milvus = get_milvus(args["handler"]) - milvus.connect(uri=uri, timeout=5) + milvus = get_milvus(args["ip"], args["port"], handler=args["handler"]) milvus.create_collection(param) status, ids = milvus.add_vectors(collection, vectors) assert status.OK() diff --git a/tests/milvus_python_test/utils.py b/tests/milvus_python_test/utils.py index 61a3e01f..5780a006 100644 --- a/tests/milvus_python_test/utils.py +++ b/tests/milvus_python_test/utils.py @@ -24,10 +24,14 @@ all_index_types = [ ] -def get_milvus(handler=None): +def get_milvus(host, port, uri=None, handler=None): if handler is None: handler = "GRPC" - return Milvus(handler=handler) + if uri is not None: + milvus = Milvus(uri=uri, handler=handler) + else: + milvus = Milvus(host=host, port=port, handler=handler) + return milvus def gen_inaccuracy(num): -- GitLab