Unverified commit 83a22702, authored by del-zhenwu, committed by GitHub

[skip ci] add auto_id (#3142)

* [skip ci] add auto_id
Signed-off-by: zw <zw@milvus.io>

* [skip ci] fix compact bug
Signed-off-by: zw <zw@milvus.io>

* [skip ci] update ub/2
Signed-off-by: zw <zw@milvus.io>

* [skip ci] remove jac_collection
Signed-off-by: zw <zw@milvus.io>
Co-authored-by: zw <zw@milvus.io>
Parent 2b2ebc27
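In short, this commit threads an auto_id flag through the test helpers and fixtures: gen_default_fields()/gen_binary_default_fields() gain an auto_id parameter, new id_collection and binary_id_collection fixtures create collections with it, the compact tests drop the jac_collection fixture in favor of binary_collection, and tests that pass explicit ids are switched to the id fixtures. A minimal self-contained sketch of what the flag does (the real helper appears near the end of this diff; the segment_row_count value here is illustrative only):

def gen_default_fields_sketch(auto_id=False):
    # Mirrors the changed helper: the "auto_id" key is only added on request.
    fields = {
        "fields": [{"field": "int64", "type": "INT64"}],  # DataType.INT64 in the suite
        "segment_row_count": 4096,                        # illustrative value
    }
    if auto_id:
        fields["auto_id"] = True
    return fields

assert "auto_id" not in gen_default_fields_sketch()
assert gen_default_fields_sketch(auto_id=True)["auto_id"] is True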
@@ -314,9 +314,9 @@ class TestStatsBase:
res = connect.insert(collection_name, entities)
connect.flush(collection_list)
if i % 2:
connect.create_index(collection_name, field_name, {"index_type": "IVF_SQ8", "nlist": 1024, "metric_type": "L2"})
connect.create_index(collection_name, field_name, {"index_type": "IVF_SQ8", "params": {"nlist": 1024}, "metric_type": "L2"})
else:
connect.create_index(collection_name, field_name, {"index_type": "IVF_FLAT", "nlist": 1024, "metric_type": "L2"})
connect.create_index(collection_name, field_name, {"index_type": "IVF_FLAT", "params": {"nlist": 1024}, "metric_type": "L2"})
for i in range(collection_num):
stats = connect.get_collection_stats(collection_list[i])
assert stats["partitions"][0]["segments"][0]["row_count"] == nb
......
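The paired lines above show the index-parameter schema fix applied throughout this commit: per-index settings such as nlist move under a nested "params" key instead of sitting at the top level. Side by side:

# Form the suite moves away from:
old_index = {"index_type": "IVF_SQ8", "nlist": 1024, "metric_type": "L2"}
# Form used after this commit, with index-specific knobs nested under "params":
new_index = {"index_type": "IVF_SQ8", "params": {"nlist": 1024}, "metric_type": "L2"}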
@@ -9,7 +9,6 @@ from utils import *
timeout = 60
dimension = 128
delete_timeout = 60
default_fields = gen_default_fields()
def pytest_addoption(parser):
@@ -101,6 +100,7 @@ def collection(request, connect):
ori_collection_name = getattr(request.module, "collection_id", "test")
collection_name = gen_unique_str(ori_collection_name)
try:
default_fields = gen_default_fields()
connect.create_collection(collection_name, default_fields)
except Exception as e:
pytest.exit(str(e))
@@ -113,14 +113,48 @@ def collection(request, connect):
return collection_name
@pytest.fixture(scope="function")
def id_collection(request, connect):
ori_collection_name = getattr(request.module, "collection_id", "test")
collection_name = gen_unique_str(ori_collection_name)
try:
fields = gen_default_fields(auto_id=True)
connect.create_collection(collection_name, fields)
except Exception as e:
pytest.exit(str(e))
def teardown():
collection_names = connect.list_collections()
for collection_name in collection_names:
connect.drop_collection(collection_name, timeout=delete_timeout)
request.addfinalizer(teardown)
assert connect.has_collection(collection_name)
return collection_name
@pytest.fixture(scope="function")
def binary_collection(request, connect):
ori_collection_name = getattr(request.module, "collection_id", "test")
collection_name = gen_unique_str(ori_collection_name)
fields = gen_default_fields()
fields["fields"][-1] = {"field": "binary_vector", "type": DataType.BINARY_VECTOR, "params": {"dim": dimension}}
logging.getLogger().info(fields)
try:
fields = gen_binary_default_fields()
connect.create_collection(collection_name, fields)
except Exception as e:
pytest.exit(str(e))
def teardown():
collection_names = connect.list_collections()
for collection_name in collection_names:
connect.drop_collection(collection_name, timeout=delete_timeout)
request.addfinalizer(teardown)
assert connect.has_collection(collection_name)
return collection_name
@pytest.fixture(scope="function")
def binary_id_collection(request, connect):
ori_collection_name = getattr(request.module, "collection_id", "test")
collection_name = gen_unique_str(ori_collection_name)
try:
fields = gen_binary_default_fields(auto_id=True)
connect.create_collection(collection_name, fields)
except Exception as e:
pytest.exit(str(e))
......
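A usage sketch assembled from the tests changed below (not a new test in this commit): suites that supply explicit ids request the new id fixtures, while the plain collection/binary_collection fixtures keep the default fields. nb and entities are module-level values from the suite's utils.

def test_insert_with_explicit_ids(connect, id_collection):
    # id_collection is built from gen_default_fields(auto_id=True); the tests
    # in this commit pass caller-supplied ids to it and get them back.
    ids = [i for i in range(nb)]
    res_ids = connect.insert(id_collection, entities, ids)
    assert res_ids == ids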
@@ -129,20 +129,20 @@ class TestDeleteBase:
res_count = connect.count_entities(collection)
assert res_count == nb - 1
def test_insert_delete_B(self, connect, collection):
def test_insert_delete_B(self, connect, id_collection):
'''
target: test delete entity
method: add entities with the same ids, and delete the id in collection
expected: no error raised, all entities deleted
'''
ids = [1 for i in range(nb)]
res_ids = connect.insert(collection, entities, ids)
res_ids = connect.insert(id_collection, entities, ids)
connect.flush([id_collection])
delete_ids = [1]
status = connect.delete_entity_by_id(collection, delete_ids)
status = connect.delete_entity_by_id(id_collection, delete_ids)
assert status
connect.flush([collection])
res_count = connect.count_entities(collection)
connect.flush([id_collection])
res_count = connect.count_entities(id_collection)
assert res_count == 0
def test_delete_exceed_limit(self, connect, collection):
......
@@ -111,17 +111,17 @@ class TestGetBase:
with pytest.raises(Exception) as e:
res = connect.get_entity_by_id(collection, ids)
def test_get_entity_same_ids(self, connect, collection):
def test_get_entity_same_ids(self, connect, id_collection):
'''
target: test get_entity_by_id, with the same ids
method: add entity, and get one id
expected: entity returned equals insert
'''
ids = [1 for i in range(nb)]
res_ids = connect.insert(collection, entities, ids)
connect.flush([collection])
res_ids = connect.insert(id_collection, entities, ids)
connect.flush([id_collection])
get_ids = [ids[0]]
res = connect.get_entity_by_id(collection, get_ids)
res = connect.get_entity_by_id(id_collection, get_ids)
assert len(res) == 1
assert_equal_vector(res[0].get(default_float_vec_field_name), entities[-1]["values"][0])
@@ -464,7 +464,7 @@ class TestGetBase:
enable_flush(connect)
# TODO:
def test_get_entities_after_delete_same_ids(self, connect, collection):
def test_get_entities_after_delete_same_ids(self, connect, id_collection):
'''
target: test get_entity_by_id
method: add entities with the same ids, and delete, get entity by the given id
@@ -472,12 +472,12 @@ class TestGetBase:
'''
ids = [i for i in range(nb)]
ids[0] = 1
res_ids = connect.insert(collection, entities, ids)
connect.flush([collection])
status = connect.delete_entity_by_id(collection, [1])
connect.flush([collection])
res_ids = connect.insert(id_collection, entities, ids)
connect.flush([id_collection])
status = connect.delete_entity_by_id(id_collection, [1])
connect.flush([id_collection])
get_ids = [1]
res = connect.get_entity_by_id(collection, get_ids)
res = connect.get_entity_by_id(id_collection, get_ids)
assert res[0] is None
def test_get_entity_after_delete_with_partition(self, connect, collection, get_pos):
......
@@ -208,7 +208,8 @@ class TestInsertBase:
collection_name = gen_unique_str("test_collection")
fields = {
"fields": [filter_field, vector_field],
"segment_row_count": segment_row_count
"segment_row_count": segment_row_count,
"auto_id": True
}
connect.create_collection(collection_name, fields)
ids = [i for i in range(nb)]
@@ -221,32 +222,30 @@ class TestInsertBase:
# TODO: assert exception && enable
@pytest.mark.timeout(ADD_TIMEOUT)
def _test_insert_twice_ids_no_ids(self, connect, collection):
def _test_insert_twice_ids_no_ids(self, connect, id_collection):
'''
target: check the result of insert, with params ids and no ids
method: insert vectors twice, using custom ids first and then no ids
expected: error raised
'''
ids = [i for i in range(nb)]
res_ids = connect.insert(collection, entities, ids)
res_ids = connect.insert(id_collection, entities, ids)
with pytest.raises(Exception) as e:
res_ids_new = connect.insert(collection, entities)
res_ids_new = connect.insert(id_collection, entities)
# TODO: assert exception && enable
@pytest.mark.timeout(ADD_TIMEOUT)
def _test_insert_twice_not_ids_ids(self, connect, collection):
def _test_insert_twice_not_ids_ids(self, connect, id_collection):
'''
target: check the result of insert, with params ids and no ids
method: insert vectors twice, using no ids first and then custom ids
expected: error raised
'''
res_ids = connect.insert(collection, entities)
ids = [i for i in range(nb)]
with pytest.raises(Exception) as e:
res_ids_new = connect.insert(collection, entities, ids)
res_ids = connect.insert(id_collection, entities)
@pytest.mark.timeout(ADD_TIMEOUT)
def test_insert_ids_length_not_match_batch(self, connect, collection):
def test_insert_ids_length_not_match_batch(self, connect, id_collection):
'''
target: test insert vectors in collection with custom ids, len(ids) != len(vectors)
method: create collection and insert vectors in it
@@ -255,7 +254,7 @@ class TestInsertBase:
ids = [i for i in range(1, nb)]
logging.getLogger().info(len(ids))
with pytest.raises(Exception) as e:
res_ids = connect.insert(collection, entities, ids)
res_ids = connect.insert(id_collection, entities, ids)
@pytest.mark.timeout(ADD_TIMEOUT)
def test_insert_ids_length_not_match_single(self, connect, collection):
@@ -304,15 +303,15 @@ class TestInsertBase:
assert len(ids) == nb
@pytest.mark.timeout(ADD_TIMEOUT)
def test_insert_tag_with_ids(self, connect, collection):
def test_insert_tag_with_ids(self, connect, id_collection):
'''
target: test insert entities in collection created before, insert with ids
method: create collection and insert entities in it, with the partition_tag param
expected: the collection row count equals nq
'''
connect.create_partition(collection, tag)
connect.create_partition(id_collection, tag)
ids = [i for i in range(nb)]
res_ids = connect.insert(collection, entities, ids, partition_tag=tag)
res_ids = connect.insert(id_collection, entities, ids, partition_tag=tag)
assert res_ids == ids
@pytest.mark.timeout(ADD_TIMEOUT)
@@ -821,7 +820,7 @@ class TestInsertInvalid(object):
def get_field_vectors_value(self, request):
yield request.param
def test_insert_ids_invalid(self, connect, collection, get_entity_id):
def test_insert_ids_invalid(self, connect, id_collection, get_entity_id):
'''
target: test insert with custom ids that are not int64
method: create collection and insert entities in it
@@ -830,7 +829,7 @@ class TestInsertInvalid(object):
entity_id = get_entity_id
ids = [entity_id for _ in range(nb)]
with pytest.raises(Exception):
connect.insert(collection, entities, ids)
connect.insert(id_collection, entities, ids)
def test_insert_with_invalid_collection_name(self, connect, get_collection_name):
collection_name = get_collection_name
@@ -927,7 +926,7 @@ class TestInsertInvalidBinary(object):
yield request.param
@pytest.mark.level(2)
def test_insert_ids_invalid(self, connect, binary_collection, get_entity_id):
def test_insert_ids_invalid(self, connect, binary_id_collection, get_entity_id):
'''
target: test insert with custom ids that are not int64
method: create collection and insert entities in it
@@ -936,7 +935,7 @@ class TestInsertInvalidBinary(object):
entity_id = get_entity_id
ids = [entity_id for _ in range(nb)]
with pytest.raises(Exception):
connect.insert(binary_collection, binary_entities, ids)
connect.insert(binary_id_collection, binary_entities, ids)
@pytest.mark.level(2)
def test_insert_with_invalid_tag_name(self, connect, binary_collection, get_tag_name):
@@ -971,7 +970,7 @@ class TestInsertInvalidBinary(object):
connect.insert(binary_collection, tmp_entity)
@pytest.mark.level(2)
def test_insert_ids_invalid(self, connect, binary_collection, get_entity_id):
def test_insert_ids_invalid(self, connect, binary_id_collection, get_entity_id):
'''
target: test insert with custom ids that are not int64
method: create collection and insert entities in it
@@ -980,7 +979,7 @@ class TestInsertInvalidBinary(object):
entity_id = get_entity_id
ids = [entity_id for _ in range(nb)]
with pytest.raises(Exception):
connect.insert(binary_collection, binary_entities, ids)
connect.insert(binary_id_collection, binary_entities, ids)
@pytest.mark.level(2)
def test_insert_with_invalid_field_name(self, connect, binary_collection, get_field_name):
......
@@ -229,7 +229,7 @@ class TestSearchBase:
entities, ids = init_data(connect, collection)
connect.create_index(collection, field_name, get_simple_index)
search_param = get_search_param(index_type)
query, vecs = gen_query_vectors_(field_name, entities, top_k, nq, search_params=search_param)
query, vecs = gen_query_vectors(field_name, entities, top_k, nq, search_params=search_param)
if top_k > top_k_limit:
with pytest.raises(Exception) as e:
res = connect.search(collection, query)
@@ -964,7 +964,7 @@ class TestSearchDSL(object):
'''
entities, ids = init_data(connect, collection)
expr = {"must": [gen_default_vector_expr(default_query),
gen_default_term_expr(values=[i for i in range(nb / 2, nb + nb / 2)])]}
gen_default_term_expr(values=[i for i in range(nb // 2, nb + nb // 2)])]}
query = update_query_expr(default_query, expr=expr)
res = connect.search(collection, query)
# TODO:
......
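The nb / 2 → nb // 2 change here (and in utils.py at the end of this diff) is a Python 3 fix: / always yields a float, and range() rejects floats. Taking nb = 6000 for illustration:

nb = 6000
# range(nb / 2) raises TypeError on Python 3:
#   'float' object cannot be interpreted as an integer
values = [i for i in range(nb // 2)]  # floor division keeps an int
assert len(values) == 3000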
@@ -19,6 +19,7 @@ entities = gen_entities(nb)
raw_vector, binary_entity = gen_binary_entities(1)
raw_vectors, binary_entities = gen_binary_entities(nb)
default_fields = gen_default_fields()
default_binary_fields = gen_binary_default_fields()
field_name = default_float_vec_field_name
default_single_query = {
"bool": {
@@ -420,7 +421,7 @@ class TestCompactBase:
method: add entities, delete and compact collection; server stopped and restarted during compact
expected: status ok, request recovered
'''
entities = gen_vector(nb * 100, dim)
entities = gen_vectors(nb * 100, dim)
status, ids = connect.insert(collection, entities)
assert status.OK()
status = connect.flush([collection])
@@ -441,76 +442,76 @@ class TestCompactBase:
assert info["partitions"][0].count == nb * 100 - 1000
class TestCompactJAC:
class TestCompactBinary:
"""
******************************************************************
The following cases are used to test `compact` function
******************************************************************
"""
@pytest.mark.timeout(COMPACT_TIMEOUT)
def test_add_entity_and_compact(self, connect, jac_collection):
def test_add_entity_and_compact(self, connect, binary_collection):
'''
target: test add binary vector and compact
method: add vector and compact collection
expected: status ok, vector added
'''
ids = connect.insert(jac_collection, binary_entity)
ids = connect.insert(binary_collection, binary_entity)
assert len(ids) == 1
connect.flush([jac_collection])
connect.flush([binary_collection])
# get collection info before compact
info = connect.get_collection_stats(jac_collection)
info = connect.get_collection_stats(binary_collection)
size_before = info["partitions"][0]["segments"][0]["data_size"]
status = connect.compact(jac_collection)
status = connect.compact(binary_collection)
assert status.OK()
# get collection info after compact
info = connect.get_collection_stats(jac_collection)
info = connect.get_collection_stats(binary_collection)
size_after = info["partitions"][0]["segments"][0]["data_size"]
assert(size_before == size_after)
@pytest.mark.timeout(COMPACT_TIMEOUT)
def test_insert_and_compact(self, connect, jac_collection):
def test_insert_and_compact(self, connect, binary_collection):
'''
target: test add entities with binary vector and compact
method: add entities and compact collection
expected: status ok, entities added
'''
ids = connect.insert(jac_collection, binary_entities)
ids = connect.insert(binary_collection, binary_entities)
assert len(ids) == nb
connect.flush([jac_collection])
connect.flush([binary_collection])
# get collection info before compact
info = connect.get_collection_stats(jac_collection)
info = connect.get_collection_stats(binary_collection)
size_before = info["partitions"][0]["segments"][0]["data_size"]
status = connect.compact(jac_collection)
status = connect.compact(binary_collection)
assert status.OK()
# get collection info after compact
info = connect.get_collection_stats(jac_collection)
info = connect.get_collection_stats(binary_collection)
size_after = info["partitions"][0]["segments"][0]["data_size"]
assert(size_before == size_after)
@pytest.mark.timeout(COMPACT_TIMEOUT)
@pytest.mark.skip(reason="delete not support yet")
def test_insert_delete_part_and_compact(self, connect, jac_collection):
def test_insert_delete_part_and_compact(self, connect, binary_collection):
'''
target: test add entities, delete part of them and compact
method: add entities, delete a few and compact collection
expected: status ok, data size is smaller after compact
'''
ids = connect.insert(jac_collection, binary_entities)
ids = connect.insert(binary_collection, binary_entities)
assert len(ids) == nb
connect.flush([jac_collection])
connect.flush([binary_collection])
delete_ids = [ids[0], ids[-1]]
status = connect.delete_entity_by_id(jac_collection, delete_ids)
status = connect.delete_entity_by_id(binary_collection, delete_ids)
assert status.OK()
connect.flush([jac_collection])
connect.flush([binary_collection])
# get collection info before compact
info = connect.get_collection_stats(jac_collection)
info = connect.get_collection_stats(binary_collection)
logging.getLogger().info(info["partitions"])
size_before = info["partitions"][0]["segments"][0]["data_size"]
logging.getLogger().info(size_before)
status = connect.compact(jac_collection)
status = connect.compact(binary_collection)
assert status.OK()
# get collection info after compact
info = connect.get_collection_stats(jac_collection)
info = connect.get_collection_stats(binary_collection)
logging.getLogger().info(info["partitions"])
size_after = info["partitions"][0]["segments"][0]["data_size"]
logging.getLogger().info(size_after)
@@ -518,82 +519,82 @@ class TestCompactJAC:
@pytest.mark.timeout(COMPACT_TIMEOUT)
@pytest.mark.skip(reason="delete not support yet")
def test_insert_delete_all_and_compact(self, connect, jac_collection):
def test_insert_delete_all_and_compact(self, connect, binary_collection):
'''
target: test add entities, delete them and compact
method: add entities, delete all and compact collection
expected: status ok, no data size in collection info because collection is empty
'''
ids = connect.insert(jac_collection, binary_entities)
ids = connect.insert(binary_collection, binary_entities)
assert len(ids) == nb
connect.flush([jac_collection])
status = connect.delete_entity_by_id(jac_collection, ids)
connect.flush([binary_collection])
status = connect.delete_entity_by_id(binary_collection, ids)
assert status.OK()
connect.flush([jac_collection])
connect.flush([binary_collection])
# get collection info before compact
info = connect.get_collection_stats(jac_collection)
status = connect.compact(jac_collection)
info = connect.get_collection_stats(binary_collection)
status = connect.compact(binary_collection)
assert status.OK()
# get collection info after compact
info = connect.get_collection_stats(jac_collection)
info = connect.get_collection_stats(binary_collection)
assert status.OK()
logging.getLogger().info(info["partitions"])
assert not info["partitions"][0]["segments"]
@pytest.mark.timeout(COMPACT_TIMEOUT)
def test_add_entity_and_compact_twice(self, connect, jac_collection):
def test_add_entity_and_compact_twice(self, connect, binary_collection):
'''
target: test add entity and compact twice
method: add entity and compact collection twice
expected: status ok
'''
ids = connect.insert(jac_collection, binary_entity)
ids = connect.insert(binary_collection, binary_entity)
assert len(ids) == 1
connect.flush([jac_collection])
connect.flush([binary_collection])
# get collection info before compact
info = connect.get_collection_stats(jac_collection)
info = connect.get_collection_stats(binary_collection)
size_before = info["partitions"][0]["segments"][0]["data_size"]
status = connect.compact(jac_collection)
status = connect.compact(binary_collection)
assert status.OK()
# get collection info after compact
info = connect.get_collection_stats(jac_collection)
info = connect.get_collection_stats(binary_collection)
size_after = info["partitions"][0]["segments"][0]["data_size"]
assert(size_before == size_after)
status = connect.compact(jac_collection)
status = connect.compact(binary_collection)
assert status.OK()
# get collection info after compact twice
info = connect.get_collection_stats(jac_collection)
info = connect.get_collection_stats(binary_collection)
size_after_twice = info["partitions"][0]["segments"][0]["data_size"]
assert(size_after == size_after_twice)
@pytest.mark.timeout(COMPACT_TIMEOUT)
@pytest.mark.skip(reason="delete not support yet")
def test_insert_delete_part_and_compact_twice(self, connect, jac_collection):
def test_insert_delete_part_and_compact_twice(self, connect, binary_collection):
'''
target: test add entities, delete part of them and compact twice
method: add entities, delete part and compact collection twice
expected: status ok, data size smaller after first compact, no change after second
'''
ids = connect.insert(jac_collection, binary_entities)
ids = connect.insert(binary_collection, binary_entities)
assert len(ids) == nb
connect.flush([jac_collection])
connect.flush([binary_collection])
delete_ids = [ids[0], ids[-1]]
status = connect.delete_entity_by_id(jac_collection, delete_ids)
status = connect.delete_entity_by_id(binary_collection, delete_ids)
assert status.OK()
connect.flush([jac_collection])
connect.flush([binary_collection])
# get collection info before compact
info = connect.get_collection_stats(jac_collection)
info = connect.get_collection_stats(binary_collection)
size_before = info["partitions"][0]["segments"][0]["data_size"]
status = connect.compact(jac_collection)
status = connect.compact(binary_collection)
assert status.OK()
# get collection info after compact
info = connect.get_collection_stats(jac_collection)
info = connect.get_collection_stats(binary_collection)
size_after = info["partitions"][0]["segments"][0]["data_size"]
assert(size_before >= size_after)
status = connect.compact(jac_collection)
status = connect.compact(binary_collection)
assert status.OK()
# get collection info after compact twice
info = connect.get_collection_stats(jac_collection)
info = connect.get_collection_stats(binary_collection)
size_after_twice = info["partitions"][0]["segments"][0]["data_size"]
assert(size_after == size_after_twice)
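All of these compaction checks walk the same get_collection_stats() result. Pieced together from the assertions in this diff, the returned structure looks roughly like the sketch below (only keys asserted in this commit are shown; real responses may carry more, and the values here are illustrative):

stats = {
    "partitions": [
        {
            "segments": [
                {
                    "row_count": 6000,        # compared against nb in the stats tests
                    "data_size": 1048576,     # compared before/after compact
                    "index_name": "IVF_SQ8",  # checked by the index tests
                }
            ]
        }
    ]
}
size_before = stats["partitions"][0]["segments"][0]["data_size"]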
@@ -612,8 +613,7 @@ class TestCompactJAC:
for i in range(num_collections):
collection_name = gen_unique_str("test_compact_multi_collection_%d" % i)
collection_list.append(collection_name)
fields = update_fields_metric_type(default_fields, "JACCARD")
connect.create_collection(collection_name, fields)
connect.create_collection(collection_name, default_fields)
for i in range(num_collections):
ids = connect.insert(collection_list[i], entities)
assert len(ids) == nq
@@ -627,66 +627,66 @@ class TestCompactJAC:
@pytest.mark.level(2)
@pytest.mark.timeout(COMPACT_TIMEOUT)
def test_add_entity_after_compact(self, connect, jac_collection):
def test_add_entity_after_compact(self, connect, binary_collection):
'''
target: test add entity after compact
method: after compact operation, add entity
expected: status ok, entity added
'''
ids = connect.insert(jac_collection, binary_entities)
connect.flush([jac_collection])
ids = connect.insert(binary_collection, binary_entities)
connect.flush([binary_collection])
# get collection info before compact
info = connect.get_collection_stats(jac_collection)
info = connect.get_collection_stats(binary_collection)
size_before = info["partitions"][0]["segments"][0]["data_size"]
status = connect.compact(jac_collection)
status = connect.compact(binary_collection)
assert status.OK()
# get collection info after compact
info = connect.get_collection_stats(jac_collection)
info = connect.get_collection_stats(binary_collection)
size_after = info["partitions"][0]["segments"][0]["data_size"]
assert(size_before == size_after)
ids = connect.insert(jac_collection, binary_entity)
connect.flush([jac_collection])
res = connect.count_entities(jac_collection)
ids = connect.insert(binary_collection, binary_entity)
connect.flush([binary_collection])
res = connect.count_entities(binary_collection)
assert res == nb + 1
@pytest.mark.timeout(COMPACT_TIMEOUT)
@pytest.mark.skip(reason="delete not support yet")
def test_delete_entities_after_compact(self, connect, jac_collection):
def test_delete_entities_after_compact(self, connect, binary_collection):
'''
target: test delete entities after compact
method: after compact operation, delete entities
expected: status ok, entities deleted
'''
ids = connect.insert(jac_collection, binary_entities)
connect.flush([jac_collection])
status = connect.compact(jac_collection)
ids = connect.insert(binary_collection, binary_entities)
connect.flush([binary_collection])
status = connect.compact(binary_collection)
assert status.OK()
connect.flush([jac_collection])
status = connect.delete_entity_by_id(jac_collection, ids)
connect.flush([binary_collection])
status = connect.delete_entity_by_id(binary_collection, ids)
assert status.OK()
connect.flush([jac_collection])
res = connect.count_entities(jac_collection)
connect.flush([binary_collection])
res = connect.count_entities(binary_collection)
assert res == 0
@pytest.mark.skip(reason="search not support yet")
@pytest.mark.timeout(COMPACT_TIMEOUT)
def test_search_after_compact(self, connect, jac_collection):
def test_search_after_compact(self, connect, binary_collection):
'''
target: test search after compact
method: after compact operation, search vector
expected: status ok
'''
ids = connect.insert(jac_collection, binary_entities)
ids = connect.insert(binary_collection, binary_entities)
assert len(ids) == nb
connect.flush([jac_collection])
status = connect.compact(jac_collection)
connect.flush([binary_collection])
status = connect.compact(binary_collection)
assert status.OK()
query_vecs = [raw_vectors[0]]
distance = jaccard(query_vecs[0], raw_vectors[0])
query = copy.deepcopy(default_single_query)
query["bool"]["must"][0]["vector"][field_name]["query"] = [binary_entities[-1]["values"][0],
binary_entities[-1]["values"][-1]]
res = connect.search(jac_collection, query)
res = connect.search(binary_collection, query)
assert abs(res[0]._distances[0]-distance) <= epsilon
# TODO:
......
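The skipped search case above compares the first returned distance against jaccard(...) from the suite's utils. A minimal sketch of such a helper, assuming the raw vectors are 0/1 bit sequences:

import numpy as np

def jaccard(x, y):
    # Jaccard distance between two binary vectors: 1 - |x AND y| / |x OR y|.
    x = np.asarray(x, np.bool_)
    y = np.asarray(y, np.bool_)
    inter = np.double(np.bitwise_and(x, y).sum())
    union = np.double(np.bitwise_or(x, y).sum())
    return 1 - inter / union

assert abs(jaccard([1, 0, 1], [1, 1, 0]) - (1 - 1 / 3)) < 1e-9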
@@ -87,23 +87,23 @@ class TestFlushBase:
# with pytest.raises(Exception) as e:
# connect.flush([collection])
def test_add_partition_flush(self, connect, collection):
def test_add_partition_flush(self, connect, id_collection):
'''
method: add entities into a partition of the collection, flush several times
expected: the length of ids and the collection row count
'''
# vector = gen_vector(nb, dim)
connect.create_partition(collection, tag)
connect.create_partition(id_collection, tag)
# vectors = gen_vectors(nb, dim)
ids = [i for i in range(nb)]
ids = connect.insert(collection, entities, ids)
connect.flush([collection])
res_count = connect.count_entities(collection)
ids = connect.insert(id_collection, entities, ids)
connect.flush([id_collection])
res_count = connect.count_entities(id_collection)
assert res_count == nb
ids = connect.insert(collection, entities, ids, partition_tag=tag)
ids = connect.insert(id_collection, entities, ids, partition_tag=tag)
assert len(ids) == nb
connect.flush([collection])
res_count = connect.count_entities(collection)
connect.flush([id_collection])
res_count = connect.count_entities(id_collection)
assert res_count == nb * 2
def test_add_partitions_flush(self, connect, collection):
@@ -190,19 +190,19 @@ class TestFlushBase:
assert res
# TODO: stable case
def test_add_flush_auto(self, connect, collection):
def test_add_flush_auto(self, connect, id_collection):
'''
method: add entities
expected: no error raised
'''
# vectors = gen_vectors(nb, dim)
ids = [i for i in range(nb)]
ids = connect.insert(collection, entities, ids)
ids = connect.insert(id_collection, entities, ids)
timeout = 10
start_time = time.time()
while (time.time() - start_time < timeout):
time.sleep(1)
res = connect.count_entities(collection)
res = connect.count_entities(id_collection)
if res == nb:
break
if time.time() - start_time > timeout:
@@ -218,7 +218,7 @@ class TestFlushBase:
def same_ids(self, request):
yield request.param
def test_add_flush_same_ids(self, connect, collection, same_ids):
def test_add_flush_same_ids(self, connect, id_collection, same_ids):
'''
method: add entities, with same ids, count(same ids) < 15, > 15
expected: the length of ids and the collection row count
@@ -228,9 +228,9 @@ class TestFlushBase:
for i, item in enumerate(ids):
if item <= same_ids:
ids[i] = 0
ids = connect.insert(collection, entities, ids)
connect.flush([collection])
res = connect.count_entities(collection)
ids = connect.insert(id_collection, entities, ids)
connect.flush([id_collection])
res = connect.count_entities(id_collection)
assert res == nb
@pytest.mark.skip(reason="search not support yet")
......
@@ -489,7 +489,7 @@ class TestIndexBase:
connect.drop_index(collection, field_name)
class TestIndexJAC:
class TestIndexBinary:
@pytest.fixture(
scope="function",
params=gen_simple_index()
@@ -529,29 +529,29 @@ class TestIndexJAC:
"""
@pytest.mark.timeout(BUILD_TIMEOUT)
def test_create_index(self, connect, jac_collection, get_jaccard_index):
def test_create_index(self, connect, binary_collection, get_jaccard_index):
'''
target: test create index interface
method: create collection and add entities in it, create index
expected: return search success
'''
ids = connect.insert(jac_collection, binary_entities)
connect.create_index(jac_collection, binary_field_name, get_jaccard_index)
ids = connect.insert(binary_collection, binary_entities)
connect.create_index(binary_collection, binary_field_name, get_jaccard_index)
@pytest.mark.timeout(BUILD_TIMEOUT)
def test_create_index_partition(self, connect, jac_collection, get_jaccard_index):
def test_create_index_partition(self, connect, binary_collection, get_jaccard_index):
'''
target: test create index interface
method: create collection, create partition, and add entities in it, create index
expected: return search success
'''
connect.create_partition(jac_collection, tag)
ids = connect.insert(jac_collection, binary_entities, partition_tag=tag)
connect.create_index(jac_collection, binary_field_name, get_jaccard_index)
connect.create_partition(binary_collection, tag)
ids = connect.insert(binary_collection, binary_entities, partition_tag=tag)
connect.create_index(binary_collection, binary_field_name, get_jaccard_index)
# TODO:
@pytest.mark.timeout(BUILD_TIMEOUT)
def _test_create_index_search_with_query_vectors(self, connect, jac_collection, get_jaccard_index, get_nq):
def _test_create_index_search_with_query_vectors(self, connect, binary_collection, get_jaccard_index, get_nq):
'''
target: test create index interface, search with more query vectors
method: create collection and add entities in it, create index
@@ -559,11 +559,11 @@ class TestIndexJAC:
'''
nq = get_nq
pdb.set_trace()
ids = connect.insert(jac_collection, binary_entities)
connect.create_index(jac_collection, binary_field_name, get_jaccard_index)
ids = connect.insert(binary_collection, binary_entities)
connect.create_index(binary_collection, binary_field_name, get_jaccard_index)
query, vecs = gen_query_vectors(binary_field_name, binary_entities, top_k, nq)
search_param = get_search_param(get_jaccard_index["index_type"])
res = connect.search(jac_collection, query, search_params=search_param)
search_param = get_search_param(get_jaccard_index["index_type"])
res = connect.search(binary_collection, query, search_params=search_param)
logging.getLogger().info(res)
assert len(res) == nq
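get_jaccard_index comes from a parametrized fixture defined outside this excerpt. Judging from the index types referenced here (BIN_FLAT is skipped in the stats assertions) and the nested params format fixed earlier in this commit, one parametrization presumably looks like:

# Assumed shape of a get_jaccard_index value; not shown in this diff.
get_jaccard_index = {
    "index_type": "BIN_IVF_FLAT",   # the non-FLAT binary index of this era
    "params": {"nlist": 1024},
    "metric_type": "JACCARD",
}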
@@ -573,7 +573,7 @@ class TestIndexJAC:
******************************************************************
"""
def test_get_index_info(self, connect, jac_collection, get_jaccard_index):
def test_get_index_info(self, connect, binary_collection, get_jaccard_index):
'''
target: test describe index interface
method: create collection and add entities in it, create index, call describe index
@@ -581,14 +581,14 @@ class TestIndexJAC:
'''
if get_jaccard_index["index_type"] == "BIN_FLAT":
pytest.skip("GetCollectionStats skip BIN_FLAT")
ids = connect.insert(jac_collection, binary_entities)
connect.flush([jac_collection])
connect.create_index(jac_collection, binary_field_name, get_jaccard_index)
stats = connect.get_collection_stats(jac_collection)
ids = connect.insert(binary_collection, binary_entities)
connect.flush([binary_collection])
connect.create_index(binary_collection, binary_field_name, get_jaccard_index)
stats = connect.get_collection_stats(binary_collection)
logging.getLogger().info(stats)
assert stats['partitions'][0]['segments'][0]['index_name'] == get_jaccard_index['index_type']
def test_get_index_info_partition(self, connect, jac_collection, get_jaccard_index):
def test_get_index_info_partition(self, connect, binary_collection, get_jaccard_index):
'''
target: test describe index interface
method: create collection, create partition and add entities in it, create index, call describe index
@@ -596,11 +596,11 @@ class TestIndexJAC:
'''
if get_jaccard_index["index_type"] == "BIN_FLAT":
pytest.skip("GetCollectionStats skip BIN_FLAT")
connect.create_partition(jac_collection, tag)
ids = connect.insert(jac_collection, binary_entities, partition_tag=tag)
connect.flush([jac_collection])
connect.create_index(jac_collection, binary_field_name, get_jaccard_index)
stats = connect.get_collection_stats(jac_collection)
connect.create_partition(binary_collection, tag)
ids = connect.insert(binary_collection, binary_entities, partition_tag=tag)
connect.flush([binary_collection])
connect.create_index(binary_collection, binary_field_name, get_jaccard_index)
stats = connect.get_collection_stats(binary_collection)
logging.getLogger().info(stats)
assert stats['partitions'][1]['segments'][0]['index_name'] == get_jaccard_index['index_type']
@@ -642,10 +642,6 @@ class TestIndexJAC:
assert stats["partitions"][1]["segments"][0]["index_name"] == default_index_type
class TestIndexBinary:
pass
class TestIndexMultiCollections(object):
@pytest.mark.level(2)
......
@@ -98,15 +98,15 @@ class TestCreateBase:
assert tag_name in tag_list
assert "_default" in tag_list
def test_create_partition_insert_default(self, connect, collection):
def test_create_partition_insert_default(self, connect, id_collection):
'''
target: test create partition, and insert vectors, check status returned
method: call function: create_partition
expected: status ok
'''
connect.create_partition(collection, tag)
connect.create_partition(id_collection, tag)
ids = [i for i in range(nb)]
insert_ids = connect.insert(collection, entities, ids)
insert_ids = connect.insert(id_collection, entities, ids)
assert len(insert_ids) == len(ids)
def test_create_partition_insert_with_tag(self, connect, collection):
......
@@ -206,7 +206,7 @@ def gen_single_vector_fields():
return fields
def gen_default_fields():
def gen_default_fields(auto_id=False):
default_fields = {
"fields": [
{"field": "int64", "type": DataType.INT64},
@@ -215,6 +215,22 @@ def gen_default_fields():
],
"segment_row_count": segment_row_count
}
if auto_id is True:
default_fields["auto_id"] = True
return default_fields
def gen_binary_default_fields(auto_id=False):
default_fields = {
"fields": [
{"field": "int64", "type": DataType.INT64},
{"field": "float", "type": DataType.FLOAT},
{"field": default_binary_vec_field_name, "type": DataType.BINARY_VECTOR, "params": {"dim": dimension}}
],
"segment_row_count": segment_row_count
}
if auto_id is True:
default_fields["auto_id"] = True
return default_fields
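A quick, hedged check of the two helpers as they stand after this change (runnable against this utils module; the assertions reflect only what the diff shows):

fields = gen_default_fields()                  # no "auto_id" key by default
assert "auto_id" not in fields
id_fields = gen_default_fields(auto_id=True)   # consumed by the id_collection fixture
assert id_fields["auto_id"] is True
bin_fields = gen_binary_default_fields()
assert bin_fields["fields"][-1]["type"] == DataType.BINARY_VECTOR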
@@ -291,14 +307,14 @@ def gen_default_vector_expr(default_query):
def gen_default_term_expr(keyword="term", values=None):
if values is None:
values = [i for i in range(nb / 2)]
values = [i for i in range(nb // 2)]
expr = {keyword: {"int64": {"values": values}}}
return expr
def gen_default_range_expr(ranges=None):
if ranges is None:
ranges = {"GT": 1, "LT": nb / 2}
ranges = {"GT": 1, "LT": nb // 2}
expr = {"range": {"int64": {"ranges": ranges}}}
return expr
......