未验证 提交 d993bcc9 编写于 作者: D del-zhenwu 提交者: GitHub

remove todo case (#3296)

* remove todo case
Signed-off-by: Nzw <zw@milvus.io>

* remove todo case
Signed-off-by: Nzw <zw@milvus.io>
Co-authored-by: Nzw <zw@milvus.io>
上级 2fc5dc3b
......@@ -16,7 +16,8 @@ tag = "tag"
collection_id = "count_collection"
add_interval_time = 3
segment_row_count = 5000
default_fields = gen_default_fields()
default_fields = gen_default_fields()
default_binary_fields = gen_binary_default_fields()
entities = gen_entities(nb)
raw_vectors, binary_entities = gen_binary_entities(nb)
field_name = "fload_vector"
......@@ -493,8 +494,8 @@ class TestCollectionMultiCollections:
res = connect.count_entities(collection_list[i])
assert res == insert_count
# TODO:
def _test_collection_count_multi_collections_binary(self, connect, binary_collection, insert_count):
@pytest.mark.level(2)
def test_collection_count_multi_collections_binary(self, connect, binary_collection, insert_count):
'''
target: test collection rows_count is correct or not with multiple collections of JACCARD
method: create collection and add entities in it,
......@@ -503,21 +504,20 @@ class TestCollectionMultiCollections:
'''
raw_vectors, entities = gen_binary_entities(insert_count)
res = connect.insert(binary_collection, entities)
# logging.getLogger().info(entities)
collection_list = []
collection_num = 20
for i in range(collection_num):
collection_name = gen_unique_str(collection_id)
collection_list.append(collection_name)
connect.create_collection(collection_name, default_fields)
connect.create_collection(collection_name, default_binary_fields)
res = connect.insert(collection_name, entities)
connect.flush(collection_list)
for i in range(collection_num):
res = connect.count_entities(collection_list[i])
assert res == insert_count
# TODO:
def _test_collection_count_multi_collections_mix(self, connect):
@pytest.mark.level(2)
def test_collection_count_multi_collections_mix(self, connect):
'''
target: test collection rows_count is correct or not with multiple collections of JACCARD
method: create collection and add entities in it,
......@@ -534,7 +534,7 @@ class TestCollectionMultiCollections:
for i in range(int(collection_num / 2), collection_num):
collection_name = gen_unique_str(collection_id)
collection_list.append(collection_name)
connect.create_collection(collection_name, default_fields)
connect.create_collection(collection_name, default_binary_fields)
res = connect.insert(collection_name, binary_entities)
connect.flush(collection_list)
for i in range(collection_num):
......
......@@ -134,9 +134,8 @@ class TestStatsBase:
connect.flush([collection])
stats = connect.get_collection_stats(collection)
assert stats["row_count"] == nb - 2
assert stats["partitions"][0]["row_count"] == nb -2
assert stats["partitions"][0]["segments"][0]["data_size"] > 0
# TODO
# assert stats["partitions"][0]["segments"][0]["index_type"] == "FLAT"
def test_get_collection_stats_after_compact_parts(self, connect, collection):
'''
......@@ -228,10 +227,11 @@ class TestStatsBase:
connect.flush([collection])
connect.create_index(collection, field_name, get_simple_index)
stats = connect.get_collection_stats(collection)
logging.getLogger().info(stats)
assert stats["partitions"][0]["segments"][0]["row_count"] == nb
# TODO
# assert stats["partitions"][0]["segments"][0]["index_name"] == get_simple_index["index_type"]
assert stats["row_count"] == nb
for file in stats["partitions"][0]["segments"][0]["files"]:
if file["field"] == field_name and file["name"] != "_raw":
assert file["data_size"] > 0
assert file["index_type"] == get_simple_index["index_type"]
def test_get_collection_stats_after_index_created_ip(self, connect, collection, get_simple_index):
'''
......@@ -245,10 +245,11 @@ class TestStatsBase:
get_simple_index.update({"metric_type": "IP"})
connect.create_index(collection, field_name, get_simple_index)
stats = connect.get_collection_stats(collection)
logging.getLogger().info(stats)
assert stats["partitions"][0]["segments"][0]["row_count"] == nb
# TODO
# assert stats["partitions"][0]["segments"][0]["index_name"] == get_simple_index["index_type"]
assert stats["row_count"] == nb
for file in stats["partitions"][0]["segments"][0]["files"]:
if file["field"] == field_name and file["name"] != "_raw":
assert file["data_size"] > 0
assert file["index_type"] == get_simple_index["index_type"]
def test_get_collection_stats_after_index_created_jac(self, connect, binary_collection, get_jaccard_index):
'''
......@@ -260,10 +261,11 @@ class TestStatsBase:
connect.flush([binary_collection])
connect.create_index(binary_collection, "binary_vector", get_jaccard_index)
stats = connect.get_collection_stats(binary_collection)
logging.getLogger().info(stats)
assert stats["partitions"][0]["segments"][0]["row_count"] == nb
# TODO
# assert stats["partitions"][0]["segments"][0]["index_name"] == get_jaccard_index["index_type"]
assert stats["row_count"] == nb
for file in stats["partitions"][0]["segments"][0]["files"]:
if file["field"] == field_name and file["name"] != "_raw":
assert file["data_size"] > 0
assert file["index_type"] == get_simple_index["index_type"]
def test_get_collection_stats_after_create_different_index(self, connect, collection):
'''
......@@ -276,10 +278,11 @@ class TestStatsBase:
for index_type in ["IVF_FLAT", "IVF_SQ8"]:
connect.create_index(collection, field_name, {"index_type": index_type, "params":{"nlist": 1024}, "metric_type": "L2"})
stats = connect.get_collection_stats(collection)
logging.getLogger().info(stats)
# TODO
# assert stats["partitions"][0]["segments"][0]["index_name"] == index_type
assert stats["partitions"][0]["segments"][0]["row_count"] == nb
assert stats["row_count"] == nb
for file in stats["partitions"][0]["segments"][0]["files"]:
if file["field"] == field_name and file["name"] != "_raw":
assert file["data_size"] > 0
assert file["index_type"] == index_type
def test_collection_count_multi_collections(self, connect):
'''
......@@ -323,10 +326,12 @@ class TestStatsBase:
connect.create_index(collection_name, field_name, {"index_type": "IVF_FLAT","params":{ "nlist": 1024}, "metric_type": "L2"})
for i in range(collection_num):
stats = connect.get_collection_stats(collection_list[i])
assert stats["partitions"][0]["segments"][0]["row_count"] == nb
# TODO
# if i % 2:
# assert stats["partitions"][0]["segments"][0]["index_name"] == "IVF_SQ8"
# else:
# assert stats["partitions"][0]["segments"][0]["index_name"] == "IVF_FLAT"
if i % 2:
for file in stats["partitions"][0]["segments"][0]["files"]:
if file["field"] == field_name and file["name"] != "_raw":
assert file["index_type"] == "IVF_SQ8"
else:
for file in stats["partitions"][0]["segments"][0]["files"]:
if file["field"] == field_name and file["name"] != "_raw":
assert file["index_type"] == "IVF_FLAT"
connect.drop_collection(collection_list[i])
......@@ -65,7 +65,6 @@ class TestCreateCollection:
connect.create_collection(collection_name, fields)
assert connect.has_collection(collection_name)
# TODO
def test_create_collection_fields_create_index(self, connect, get_filter_field, get_vector_field):
'''
target: test create normal collection with different fields
......@@ -298,20 +297,6 @@ class TestCreateCollectionInvalid(object):
logging.getLogger().info(res)
assert res["segment_row_count"] == default_segment_row_count
# def _test_create_collection_no_metric_type(self, connect):
# '''
# target: test create collection with no metric_type params
# method: create collection with corrent params
# expected: use default L2
# '''
# collection_name = gen_unique_str(collection_id)
# fields = copy.deepcopy(default_fields)
# fields["fields"][-1]["params"].pop("metric_type")
# connect.create_collection(collection_name, fields)
# res = connect.get_collection_info(collection_name)
# logging.getLogger().info(res)
# assert res["metric_type"] == "L2"
# TODO: assert exception
def test_create_collection_limit_fields(self, connect):
collection_name = gen_unique_str(collection_id)
......
......@@ -7,6 +7,7 @@ import threading
from multiprocessing import Process
from utils import *
nb = 1000
collection_id = "info"
default_fields = gen_default_fields()
segment_row_count = 5000
......@@ -53,7 +54,6 @@ class TestInfoBase:
******************************************************************
"""
# TODO
def test_info_collection_fields(self, connect, get_filter_field, get_vector_field):
'''
target: test create normal collection with different fields, check info returned
......@@ -69,13 +69,16 @@ class TestInfoBase:
}
connect.create_collection(collection_name, fields)
res = connect.get_collection_info(collection_name)
# assert field_name
# assert field_type
# assert vector field params
# assert metric type
# assert dimension
assert res['auto_id'] == True
assert res['segment_row_count'] == segment_row_count
assert len(res["fields"]) == 3
for field in res["fields"]:
if field["type"] == filter_field:
assert field["name"] == filter_field["name"]
elif field["type"] == vector_field:
assert field["name"] == vector_field["name"]
assert field["params"] == vector_field["params"]
# TODO
def test_create_collection_segment_row_count(self, connect, get_segment_row_count):
'''
target: test create normal collection with different fields
......@@ -86,7 +89,9 @@ class TestInfoBase:
fields = copy.deepcopy(default_fields)
fields["segment_row_count"] = get_segment_row_count
connect.create_collection(collection_name, fields)
# assert segment size
# assert segment row count
res = connect.get_collection_info(collection_name)
assert res['segment_row_count'] == get_segment_row_count
def test_get_collection_info_after_index_created(self, connect, collection, get_simple_index):
connect.create_index(collection, field_name, get_simple_index)
......@@ -148,7 +153,6 @@ class TestInfoBase:
******************************************************************
"""
# TODO
def test_info_collection_fields_after_insert(self, connect, get_filter_field, get_vector_field):
'''
target: test create normal collection with different fields, check info returned
......@@ -163,15 +167,20 @@ class TestInfoBase:
"segment_row_count": segment_row_count
}
connect.create_collection(collection_name, fields)
# insert
entities = gen_entities_by_fields(fields["fields"], nb, vector_field["params"]["dim"])
res_ids = connect.insert(collection_name, entities)
connect.flush([collection_name])
res = connect.get_collection_info(collection_name)
# assert field_name
# assert field_type
# assert vector field params
# assert metric type
# assert dimension
assert res['auto_id'] == True
assert res['segment_row_count'] == segment_row_count
assert len(res["fields"]) == 3
for field in res["fields"]:
if field["type"] == filter_field:
assert field["name"] == filter_field["name"]
elif field["type"] == vector_field:
assert field["name"] == vector_field["name"]
assert field["params"] == vector_field["params"]
# TODO
def test_create_collection_segment_row_count_after_insert(self, connect, get_segment_row_count):
'''
target: test create normal collection with different fields
......@@ -182,8 +191,12 @@ class TestInfoBase:
fields = copy.deepcopy(default_fields)
fields["segment_row_count"] = get_segment_row_count
connect.create_collection(collection_name, fields)
# insert
# assert segment size
entities = gen_entities_by_fields(fields["fields"], nb, fields["fields"][-1]["params"]["dim"])
res_ids = connect.insert(collection_name, entities)
connect.flush([collection_name])
res = connect.get_collection_info(collection_name)
assert res['auto_id'] == True
assert res['segment_row_count'] == get_segment_row_count
class TestInfoInvalid(object):
......
......@@ -10,10 +10,12 @@ collection_id = "load_collection"
nb = 6000
default_fields = gen_default_fields()
entities = gen_entities(nb)
field_name = "float_vector"
field_name = default_float_vec_field_name
binary_field_name = default_binary_vec_field_name
raw_vectors, binary_entities = gen_binary_entities(nb)
class TestLoadCollection:
class TestLoadBase:
"""
******************************************************************
......@@ -30,11 +32,22 @@ class TestLoadCollection:
pytest.skip("sq8h not support in cpu mode")
return request.param
@pytest.fixture(
scope="function",
params=gen_binary_index()
)
def get_binary_index(self, request, connect):
logging.getLogger().info(request.param)
if request.param["index_type"] in binary_support():
return request.param
else:
pytest.skip("Skip index Temporary")
def test_load_collection_after_index(self, connect, collection, get_simple_index):
'''
target: test load collection, after index created
method: insert and create index, load collection with correct params
expected: describe raise exception
expected: no error raised
'''
connect.insert(collection, entities)
connect.flush([collection])
......@@ -42,20 +55,20 @@ class TestLoadCollection:
connect.create_index(collection, field_name, get_simple_index)
connect.load_collection(collection)
# TODO:
@pytest.mark.level(1)
def test_load_collection_after_index_binary(self, connect, binary_collection):
@pytest.mark.level(2)
def test_load_collection_after_index_binary(self, connect, binary_collection, get_binary_index):
'''
target: test load binary_collection, after index created
method: insert and create index, load binary_collection with correct params
expected: describe raise exception
expected: no error raised
'''
# connect.insert(binary_collection, entities)
# connect.flush([binary_collection])
# logging.getLogger().info(get_simple_index)
# connect.create_index(binary_collection, field_name, get_simple_index)
# connect.load_collection(binary_collection)
pass
connect.insert(binary_collection, binary_entities)
connect.flush([binary_collection])
for metric_type in binary_metrics():
logging.getLogger().info(metric_type)
get_binary_index["metric_type"] = metric_type
connect.create_index(binary_collection, binary_field_name, get_binary_index)
connect.load_collection(binary_collection)
def load_empty_collection(self, connect, collection):
'''
......@@ -86,10 +99,6 @@ class TestLoadCollection:
def test_load_collection_after_search(self, connect, collection):
pass
@pytest.mark.level(2)
def test_load_collection_before_search(self, connect, collection):
pass
class TestLoadCollectionInvalid(object):
"""
......
......@@ -68,7 +68,6 @@ class TestDeleteBase:
status = connect.delete_entity_by_id(collection, [0])
assert status
# TODO
def test_delete_empty_collection(self, connect, collection):
'''
target: test delete entity, params collection_name not existed
......@@ -100,7 +99,6 @@ class TestDeleteBase:
with pytest.raises(Exception) as e:
status = connect.delete_entity_by_id(collection_new, [0])
# TODO:
def test_insert_delete(self, connect, collection, insert_count):
'''
target: test delete entity
......@@ -113,6 +111,9 @@ class TestDeleteBase:
delete_ids = [ids[0]]
status = connect.delete_entity_by_id(collection, delete_ids)
assert status
connect.flush([collection])
res_count = connect.count_entities(collection)
assert res_count == insert_count - 1
def test_insert_delete_A(self, connect, collection):
'''
......@@ -159,7 +160,6 @@ class TestDeleteBase:
res_count = connect.count_entities(collection)
assert res_count == 0
# TODO
def test_flush_after_delete(self, connect, collection):
'''
target: test delete entity
......@@ -190,6 +190,16 @@ class TestDeleteBase:
res_count = connect.count_entities(binary_collection)
assert res_count == nb - len(delete_ids)
def test_insert_delete_binary(self, connect, binary_collection):
'''
method: add entities and delete
expected: status DELETED
'''
ids = connect.insert(binary_collection, binary_entities)
connect.flush([binary_collection])
delete_ids = [ids[0], ids[-1]]
status = connect.delete_entity_by_id(binary_collection, delete_ids)
def test_insert_same_ids_after_delete(self, connect, id_collection):
'''
method: add entities and delete
......@@ -257,7 +267,6 @@ class TestDeleteBase:
connect.create_index(collection, field_name, get_simple_index)
# assert index info
# TODO
def test_delete_multiable_times(self, connect, collection):
'''
method: add entities and delete id serveral times
......@@ -313,7 +322,6 @@ class TestDeleteBase:
The following cases are used to test `delete_entity_by_id` function, with tags
******************************************************************
"""
# TODO:
def test_insert_tag_delete(self, connect, collection):
'''
method: add entitys with given tag, delete entities with the return ids
......@@ -325,8 +333,10 @@ class TestDeleteBase:
delete_ids = [ids[0], ids[-1]]
status = connect.delete_entity_by_id(collection, delete_ids)
assert status
connect.flush([collection])
res_count = connect.count_entities(collection)
assert res_count == nb - 2
# TODO:
def test_insert_default_tag_delete(self, connect, collection):
'''
method: add entitys, delete entities with the return ids
......@@ -338,8 +348,10 @@ class TestDeleteBase:
delete_ids = [ids[0], ids[-1]]
status = connect.delete_entity_by_id(collection, delete_ids)
assert status
connect.flush([collection])
res_count = connect.count_entities(collection)
assert res_count == nb - 2
# TODO:
def test_insert_tags_delete(self, connect, collection):
'''
method: add entitys with given two tags, delete entities with the return ids
......@@ -413,14 +425,3 @@ class TestDeleteInvalid(object):
with pytest.raises(Exception) as e:
status = connect.delete_entity_by_id(collection_name, [1])
# def test_insert_same_ids_after_delete_jac(self, connect, jac_collection):
# '''
# method: add entities and delete
# expected: status DELETED
# '''
# insert_ids = [i for i in range(nb)]
# ids = connect.insert(jac_collection, binary_entities, insert_ids)
# connect.flush([jac_collection])
# delete_ids = [ids[0], ids[-1]]
# with pytest.raises(Exception) as e:
# status = connect.delete_entity_by_id(jac_collection, delete_ids)
......@@ -282,7 +282,6 @@ class TestGetBase:
The following cases are used to test `get_entity_by_id` function, with fields params
******************************************************************
"""
# TODO:
def test_get_entity_field(self, connect, collection, get_pos):
'''
target: test.get_entity_by_id, get one
......@@ -295,8 +294,11 @@ class TestGetBase:
fields = ["int64"]
res = connect.get_entity_by_id(collection, get_ids, fields = fields)
# assert fields
res = res.dict()
assert res[0]["field"] == fields[0]
assert res[0]["values"] == [entities[0]["values"][get_pos]]
assert res[0]["type"] == DataType.INT64
# TODO:
def test_get_entity_fields(self, connect, collection, get_pos):
'''
target: test.get_entity_by_id, get one
......@@ -309,6 +311,15 @@ class TestGetBase:
fields = ["int64", "float", default_float_vec_field_name]
res = connect.get_entity_by_id(collection, get_ids, fields = fields)
# assert fields
res = res.dict()
assert len(res) == len(fields)
for field in res:
if field["field"] == fields[0]:
assert field["values"] == [entities[0]["values"][get_pos]]
elif field["field"] == fields[1]:
assert field["values"] == [entities[1]["values"][get_pos]]
else:
assert_equal_vector(field["values"][0], entities[-1]["values"][get_pos])
# TODO: assert exception
def test_get_entity_field_not_match(self, connect, collection, get_pos):
......
......@@ -194,7 +194,6 @@ class TestInsertBase:
res_count = connect.count_entities(id_collection)
assert res_count == nb
# TODO
@pytest.mark.timeout(ADD_TIMEOUT)
def test_insert_ids_fields(self, connect, get_filter_field, get_vector_field):
'''
......@@ -221,8 +220,9 @@ class TestInsertBase:
assert res_count == nb
# TODO: assert exception && enable
@pytest.mark.level(2)
@pytest.mark.timeout(ADD_TIMEOUT)
def _test_insert_twice_ids_no_ids(self, connect, id_collection):
def test_insert_twice_ids_no_ids(self, connect, id_collection):
'''
target: check the result of insert, with params ids and no ids
method: test insert vectors twice, use customize ids first, and then use no ids
......@@ -234,8 +234,9 @@ class TestInsertBase:
res_ids_new = connect.insert(id_collection, entities)
# TODO: assert exception && enable
@pytest.mark.level(2)
@pytest.mark.timeout(ADD_TIMEOUT)
def _test_insert_twice_not_ids_ids(self, connect, id_collection):
def test_insert_twice_not_ids_ids(self, connect, id_collection):
'''
target: check the result of insert, with params ids and no ids
method: test insert vectors twice, use not ids first, and then use customize ids
......@@ -268,7 +269,6 @@ class TestInsertBase:
with pytest.raises(Exception) as e:
res_ids = connect.insert(collection, entity, ids)
# TODO
@pytest.mark.timeout(ADD_TIMEOUT)
def test_insert_ids_fields(self, connect, get_filter_field, get_vector_field):
'''
......@@ -394,25 +394,25 @@ class TestInsertBase:
with pytest.raises(Exception):
connect.insert(collection, tmp_entity)
# TODO: Python sdk needs to do check
def _test_insert_with_field_type_not_match(self, connect, collection):
@pytest.mark.level(2)
def test_insert_with_field_type_not_match(self, connect, collection):
'''
target: test insert entities, with the entity field type updated
method: update entity field type
expected: error raised
'''
tmp_entity = update_field_type(copy.deepcopy(entity), DataType.INT64, DataType.FLOAT)
tmp_entity = update_field_type(copy.deepcopy(entity), "int64", DataType.FLOAT)
with pytest.raises(Exception):
connect.insert(collection, tmp_entity)
# TODO: Python sdk needs to do check
def _test_insert_with_field_value_not_match(self, connect, collection):
@pytest.mark.level(2)
def test_insert_with_field_value_not_match(self, connect, collection):
'''
target: test insert entities, with the entity field value updated
method: update entity field value
expected: error raised
'''
tmp_entity = update_field_value(copy.deepcopy(entity), 'int64', 's')
tmp_entity = update_field_value(copy.deepcopy(entity), DataType.FLOAT, 's')
with pytest.raises(Exception):
connect.insert(collection, tmp_entity)
......@@ -519,7 +519,7 @@ class TestInsertBase:
assert res_count == thread_num * nb
class TestAddAsync:
class TestInsertAsync:
@pytest.fixture(scope="function", autouse=True)
def skip_http_check(self, args):
if args["handler"] == "HTTP":
......@@ -539,9 +539,9 @@ class TestAddAsync:
logging.getLogger().info("In callback check status")
assert not result
def check_status_not_ok(self, status, result):
def check_result(self, result):
logging.getLogger().info("In callback check status")
assert not status.OK()
assert result
def test_insert_async(self, connect, collection, insert_count):
'''
......@@ -578,14 +578,15 @@ class TestAddAsync:
future = connect.insert(collection, gen_entities(nb), _async=True, _callback=self.check_status)
future.done()
def _test_insert_async_long(self, connect, collection):
@pytest.mark.level(2)
def test_insert_async_long(self, connect, collection):
'''
target: test insert vectors with different length of vectors
method: set different vectors as insert method params
expected: length of ids is equal to the length of vectors
'''
nb = 50000
future = connect.insert(collection, gen_entities(nb), _async=True, _callback=self.check_status)
future = connect.insert(collection, gen_entities(nb), _async=True, _callback=self.check_result)
result = future.result()
assert len(result) == nb
connect.flush([collection])
......@@ -593,14 +594,14 @@ class TestAddAsync:
logging.getLogger().info(count)
assert count == nb
# TODO:
def _test_insert_async_callback_timeout(self, connect, collection):
@pytest.mark.level(2)
def test_insert_async_callback_timeout(self, connect, collection):
'''
target: test insert vectors with different length of vectors
method: set different vectors as insert method params
expected: length of ids is equal to the length of vectors
'''
nb = 500000
nb = 100000
future = connect.insert(collection, gen_entities(nb), _async=True, _callback=self.check_status, timeout=1)
with pytest.raises(Exception) as e:
result = future.result()
......
......@@ -168,8 +168,7 @@ class TestCompactBase:
logging.getLogger().info(size_after)
assert(size_before >= size_after)
# TODO
@pytest.mark.skip("not implement")
@pytest.mark.level(2)
@pytest.mark.timeout(COMPACT_TIMEOUT)
def test_insert_delete_all_and_compact(self, connect, collection):
'''
......@@ -419,33 +418,6 @@ class TestCompactBase:
assert res[1]._distances[0] < epsilon
assert res[2]._distances[0] < epsilon
# TODO: enable
def _test_compact_server_crashed_recovery(self, connect, collection):
'''
target: test compact when server crashed unexpectedly and restarted
method: add entities, delete and compact collection; server stopped and restarted during compact
expected: status ok, request recovered
'''
entities = gen_vectors(nb * 100, dim)
status, ids = connect.insert(collection, entities)
assert status.OK()
status = connect.flush([collection])
assert status.OK()
delete_ids = ids[0:1000]
status = connect.delete_entity_by_id(collection, delete_ids)
assert status.OK()
status = connect.flush([collection])
assert status.OK()
# start to compact, kill and restart server
logging.getLogger().info("compact starting...")
status = connect.compact(collection)
# pdb.set_trace()
assert status.OK()
# get collection info after compact
status, info = connect.get_collection_stats(collection)
assert status.OK()
assert info["partitions"][0].count == nb * 100 - 1000
class TestCompactBinary:
"""
......@@ -521,8 +493,6 @@ class TestCompactBinary:
logging.getLogger().info(size_after)
assert(size_before >= size_after)
# TODO
@pytest.mark.skip("not implement")
@pytest.mark.level(2)
@pytest.mark.timeout(COMPACT_TIMEOUT)
def test_insert_delete_all_and_compact(self, connect, binary_collection):
......@@ -693,7 +663,6 @@ class TestCompactBinary:
res = connect.search(binary_collection, query)
assert abs(res[0]._distances[0]-distance) <= epsilon
# TODO:
@pytest.mark.timeout(COMPACT_TIMEOUT)
def test_search_after_compact_ip(self, connect, collection):
'''
......
......@@ -189,7 +189,6 @@ class TestFlushBase:
logging.getLogger().debug(res)
assert res
# TODO: stable case
def test_add_flush_auto(self, connect, id_collection):
'''
method: add entities
......@@ -198,7 +197,7 @@ class TestFlushBase:
# vectors = gen_vectors(nb, dim)
ids = [i for i in range(nb)]
ids = connect.insert(id_collection, entities, ids)
timeout = 10
timeout = 20
start_time = time.time()
while (time.time() - start_time < timeout):
time.sleep(1)
......@@ -312,8 +311,7 @@ class TestFlushAsync:
future = connect.flush([collection], _async=True)
status = future.result()
# TODO:
def _test_flush_async(self, connect, collection):
def test_flush_async(self, connect, collection):
nb = 100000
vectors = gen_vectors(nb, dim)
connect.insert(collection, entities)
......
......@@ -66,6 +66,10 @@ def ivf():
return ["FLAT", "IVF_FLAT", "IVF_SQ8", "IVF_SQ8_HYBRID", "IVF_PQ"]
def binary_metrics():
return ["JACCARD", "HAMMING", "TANIMOTO", "SUBSTRUCTURE", "SUPERSTRUCTURE"]
def l2(x, y):
return np.linalg.norm(np.array(x) - np.array(y))
......@@ -430,25 +434,28 @@ def remove_vector_field(entities):
def update_field_name(entities, old_name, new_name):
for item in entities:
tmp_entities = copy.deepcopy(entities)
for item in tmp_entities:
if item["field"] == old_name:
item["field"] = new_name
return entities
return tmp_entities
def update_field_type(entities, old_name, new_name):
for item in entities:
tmp_entities = copy.deepcopy(entities)
for item in tmp_entities:
if item["field"] == old_name:
item["type"] = new_name
return entities
return tmp_entities
def update_field_value(entities, old_type, new_value):
for item in entities:
tmp_entities = copy.deepcopy(entities)
for item in tmp_entities:
if item["type"] == old_type:
for i in item["values"]:
item["values"][i] = new_value
return entities
for index, value in enumerate(item["values"]):
item["values"][index] = new_value
return tmp_entities
def add_vector_field(nb, dimension=dimension):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册