未验证 提交 d9e9d52f 编写于 作者: del-zhenwu 提交者: GitHub

enable some cases (#3325)

Signed-off-by: zw <zw@milvus.io>
Co-authored-by: zw <zw@milvus.io>
上级 55f47def
......@@ -1029,9 +1029,8 @@ class TestSearchDSL(object):
def get_invalid_term(self, request):
return request.param
# TODO
@pytest.mark.level(2)
def _test_query_term_wrong_format(self, connect, collection, get_invalid_term):
def test_query_term_wrong_format(self, connect, collection, get_invalid_term):
'''
method: build query with wrong format term
expected: Exception raised
......@@ -1548,9 +1547,8 @@ class TestSearchInvalid(object):
def get_search_params(self, request):
yield request.param
# TODO: This case can all pass, but it's too slow
@pytest.mark.level(2)
def _test_search_with_invalid_params(self, connect, collection, get_simple_index, get_search_params):
def test_search_with_invalid_params(self, connect, collection, get_simple_index, get_search_params):
'''
target: test search fuction, with the wrong nprobe
method: search with nprobe
......@@ -1560,8 +1558,6 @@ class TestSearchInvalid(object):
index_type = get_simple_index["index_type"]
entities, ids = init_data(connect, collection)
connect.create_index(collection, field_name, get_simple_index)
if search_params["index_type"] != index_type:
pytest.skip("Skip case")
query, vecs = gen_query_vectors(field_name, entities, top_k, 1, search_params=search_params["search_params"])
with pytest.raises(Exception) as e:
res = connect.search(collection, query)
......
......@@ -249,36 +249,32 @@ class TestFlushBase:
assert res
# TODO: CI fail, LOCAL pass
def _test_collection_count_during_flush(self, connect, args):
@pytest.mark.level(2)
def test_collection_count_during_flush(self, connect, collection, args):
'''
method: flush collection at background, call `count_entities`
expected: status ok
expected: no timeout
'''
collection = gen_unique_str("test_flush")
# param = {'collection_name': collection,
# 'dimension': dim,
# 'index_file_size': index_file_size,
# 'metric_type': MetricType.L2}
milvus = get_milvus(args["ip"], args["port"], handler=args["handler"])
milvus.create_collection(collection, default_fields)
# vectors = gen_vector(nb, dim)
ids = milvus.insert(collection, entities, ids=[i for i in range(nb)])
def flush(collection_name):
ids = []
for i in range(5):
tmp_ids = connect.insert(collection, entities)
connect.flush([collection])
ids.extend(tmp_ids)
disable_flush(connect)
status = connect.delete_entity_by_id(collection, ids)
def flush():
milvus = get_milvus(args["ip"], args["port"], handler=args["handler"])
status = milvus.delete_entity_by_id(collection_name, [i for i in range(nb)])
with pytest.raises(Exception) as e:
milvus.flush([collection_name])
p = Process(target=flush, args=(collection,))
logging.error("start flush")
milvus.flush([collection])
logging.error("end flush")
p = threading.Thread(target=flush, args=())
p.start()
res = milvus.count_entities(collection)
assert res == nb
time.sleep(0.2)
logging.error("start count")
res = connect.count_entities(collection, timeout = 10)
p.join()
res = milvus.count_entities(collection)
assert res == nb
logging.getLogger().info(res)
res = connect.count_entities(collection)
assert res == 0
......
......@@ -551,22 +551,19 @@ class TestIndexBinary:
ids = connect.insert(binary_collection, binary_entities, partition_tag=tag)
connect.create_index(binary_collection, binary_field_name, get_jaccard_index)
# TODO:
@pytest.mark.timeout(BUILD_TIMEOUT)
def _test_create_index_search_with_query_vectors(self, connect, binary_collection, get_jaccard_index, get_nq):
def test_create_index_search_with_query_vectors(self, connect, binary_collection, get_jaccard_index, get_nq):
'''
target: test create index interface, search with more query vectors
method: create collection and add entities in it, create index
expected: return search success
'''
nq = get_nq
pdb.set_trace()
ids = connect.insert(binary_collection, binary_entities)
connect.create_index(binary_collection, binary_field_name, get_jaccard_index)
query, vecs = gen_query_vectors(binary_field_name, binary_entities, top_k, nq)
search_param = get_search_param(binary_collection["index_type"])
search_param = get_search_param(get_jaccard_index["index_type"], metric_type="JACCARD")
res = connect.search(binary_collection, query, search_params=search_param)
logging.getLogger().info(res)
assert len(res) == nq
"""
......@@ -581,15 +578,18 @@ class TestIndexBinary:
method: create collection and add entities in it, create index, call describe index
expected: return code 0, and index instructure
'''
if get_jaccard_index["index_type"] == "BIN_FLAT":
pytest.skip("GetCollectionStats skip BIN_FLAT")
ids = connect.insert(binary_collection, binary_entities)
connect.flush([binary_collection])
connect.create_index(binary_collection, binary_field_name, get_jaccard_index)
stats = connect.get_collection_stats(binary_collection)
logging.getLogger().info(stats)
# TODO
# assert stats['partitions'][0]['segments'][0]['index_name'] == get_jaccard_index['index_type']
assert stats["row_count"] == nb
for partition in stats["partitions"]:
segments = partition["segments"]
if segments:
for segment in segments:
for file in segment["files"]:
if "index_type" in file:
assert file["index_type"] == get_jaccard_index["index_type"]
def test_get_index_info_partition(self, connect, binary_collection, get_jaccard_index):
'''
......@@ -597,16 +597,21 @@ class TestIndexBinary:
method: create collection, create partition and add entities in it, create index, call describe index
expected: return code 0, and index instructure
'''
if get_jaccard_index["index_type"] == "BIN_FLAT":
pytest.skip("GetCollectionStats skip BIN_FLAT")
connect.create_partition(binary_collection, tag)
ids = connect.insert(binary_collection, binary_entities, partition_tag=tag)
connect.flush([binary_collection])
connect.create_index(binary_collection, binary_field_name, get_jaccard_index)
stats = connect.get_collection_stats(binary_collection)
logging.getLogger().info(stats)
# TODO
# assert stats['partitions'][1]['segments'][0]['index_name'] == get_jaccard_index['index_type']
assert stats["row_count"] == nb
assert len(stats["partitions"]) == 2
for partition in stats["partitions"]:
segments = partition["segments"]
if segments:
for segment in segments:
for file in segment["files"]:
if "index_type" in file:
assert file["index_type"] == get_jaccard_index["index_type"]
"""
******************************************************************
......@@ -639,65 +644,18 @@ class TestIndexBinary:
connect.flush([binary_collection])
connect.create_index(binary_collection, binary_field_name, get_jaccard_index)
stats = connect.get_collection_stats(binary_collection)
logging.getLogger().info(stats)
connect.drop_index(binary_collection, binary_field_name)
stats = connect.get_collection_stats(binary_collection)
logging.getLogger().info(stats)
# TODO
# assert stats["partitions"][1]["segments"][0]["index_name"] == default_index_type
class TestIndexMultiCollections(object):
@pytest.mark.level(2)
@pytest.mark.timeout(BUILD_TIMEOUT)
def _test_create_index_multithread_multicollection(self, connect, args):
'''
target: test create index interface with multiprocess
method: create collection and add entities in it, create index
expected: return search success
'''
threads_num = 8
loop_num = 8
threads = []
collection = []
j = 0
while j < (threads_num * loop_num):
collection_name = gen_unique_str("test_create_index_multiprocessing")
collection.append(collection_name)
param = {'collection_name': collection_name,
'dimension': dim,
'index_type': IndexType.FLAT,
'store_raw_vector': False}
connect.create_collection(param)
j = j + 1
def create_index():
i = 0
while i < loop_num:
# assert connect.has_collection(collection[ids*process_num+i])
ids = connect.insert(collection[ids * threads_num + i], vectors)
connect.create_index(collection[ids * threads_num + i], IndexType.IVFLAT, {"nlist": NLIST, "metric_type": "L2"})
assert status.OK()
query_vec = [vectors[0]]
top_k = 1
search_param = {"nprobe": nprobe}
status, result = connect.search(collection[ids * threads_num + i], top_k, query_vec,
params=search_param)
assert len(result) == 1
assert len(result[0]) == top_k
assert result[0][0].distance == 0.0
i = i + 1
for i in range(threads_num):
m = get_milvus(host=args["ip"], port=args["port"], handler=args["handler"])
ids = i
t = threading.Thread(target=create_index, args=(m, ids))
threads.append(t)
t.start()
time.sleep(0.2)
for t in threads:
t.join()
assert stats["row_count"] == nb
for partition in stats["partitions"]:
segments = partition["segments"]
if segments:
for segment in segments:
for file in segment["files"]:
if "index_type" not in file:
continue
if file["index_type"] == get_jaccard_index["index_type"]:
assert False
class TestIndexInvalid(object):
......
......@@ -21,6 +21,7 @@ entity = gen_entities(1)
entities = gen_entities(nb)
raw_vector, binary_entity = gen_binary_entities(1)
raw_vectors, binary_entities = gen_binary_entities(nb)
default_fields = gen_default_fields()
class TestCreateBase:
......@@ -38,21 +39,33 @@ class TestCreateBase:
'''
connect.create_partition(collection, tag)
@pytest.mark.level(3)
def _test_create_partition_limit(self, connect, collection, args):
@pytest.mark.level(2)
def test_create_partition_limit(self, connect, collection, args):
'''
target: test create partitions, check status returned
method: call function: create_partition for 4097 times
expected: status not ok
expected: exception raised
'''
threads_num = 16
threads = []
if args["handler"] == "HTTP":
pytest.skip("skip in http mode")
for i in range(4096):
tag_tmp = gen_unique_str()
connect.create_partition(collection, tag_tmp)
def create(connect, threads_num):
for i in range(4096 // threads_num):
tag_tmp = gen_unique_str()
connect.create_partition(collection, tag_tmp)
for i in range(threads_num):
m = get_milvus(host=args["ip"], port=args["port"], handler=args["handler"])
t = threading.Thread(target=create, args=(m, threads_num, ))
threads.append(t)
t.start()
for t in threads:
t.join()
tag_tmp = gen_unique_str()
with pytest.raises(Exception) as e:
connect.create_partition(collection, tag)
connect.create_partition(collection, tag_tmp)
def test_create_partition_repeat(self, connect, collection):
'''
......@@ -147,7 +160,8 @@ class TestCreateBase:
res = connect.count_entities(id_collection)
assert res == nb * 2
def _test_create_partition_insert_same_tags_two_collections(self, connect, collection):
@pytest.mark.level(2)
def test_create_partition_insert_same_tags_two_collections(self, connect, collection):
'''
target: test create two partitions, and insert vectors with the same tag to each collection, check status returned
method: call function: create_partition
......@@ -156,16 +170,13 @@ class TestCreateBase:
connect.create_partition(collection, tag)
collection_new = gen_unique_str()
connect.create_collection(collection_new, default_fields)
connect.create_collection(param)
connect.create_partition(collection_new, tag)
ids = [i for i in range(nb)]
status, ids = connect.insert(collection, entities, ids, partition_tag=tag)
ids = [(i+nb) for i in range(nq)]
status, ids = connect.insert(collection_new, entities, ids, partition_tag=tag)
ids = connect.insert(collection, entities, partition_tag=tag)
ids = connect.insert(collection_new, entities, partition_tag=tag)
connect.flush([collection, collection_new])
status, res = connect.count_entities(collection)
res = connect.count_entities(collection)
assert res == nb
status, res = connect.count_entities(collection_new)
res = connect.count_entities(collection_new)
assert res == nb
......
......@@ -768,8 +768,8 @@ def gen_binary_index():
return index_params
def get_search_param(index_type):
search_params = {"metric_type": "L2"}
def get_search_param(index_type, metric_type="L2"):
search_params = {"metric_type": metric_type}
if index_type in ivf() or index_type in binary_support():
search_params.update({"nprobe": 64})
elif index_type == "HNSW":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册