提交 2b32e6c9 编写于 作者: neza2017 提交者: yefu.chen

Load collection manually

Signed-off-by: neza2017 <yefu.chen@zilliz.com>
上级 000e30d2
......@@ -171,6 +171,7 @@ def collection(request, connect):
try:
default_fields = gen_default_fields()
connect.create_collection(collection_name, default_fields)
connect.load_collection(collection_name)
except Exception as e:
pytest.exit(str(e))
def teardown():
......@@ -189,6 +190,7 @@ def id_collection(request, connect):
try:
fields = gen_default_fields(auto_id=False)
connect.create_collection(collection_name, fields)
connect.load_collection(collection_name)
except Exception as e:
pytest.exit(str(e))
def teardown():
......@@ -206,6 +208,7 @@ def binary_collection(request, connect):
try:
fields = gen_binary_default_fields()
connect.create_collection(collection_name, fields)
connect.load_collection(collection_name)
except Exception as e:
pytest.exit(str(e))
def teardown():
......@@ -225,6 +228,7 @@ def binary_id_collection(request, connect):
try:
fields = gen_binary_default_fields(auto_id=False)
connect.create_collection(collection_name, fields)
connect.load_collection(collection_name)
except Exception as e:
pytest.exit(str(e))
def teardown():
......
此差异已折叠。
......@@ -97,7 +97,7 @@ class TestCreateCollection:
expected: error raised
'''
# pdb.set_trace()
connect.bulk_insert(collection, default_entity)
connect.insert(collection, default_entity)
with pytest.raises(Exception) as e:
connect.create_collection(collection, default_fields)
......@@ -108,7 +108,7 @@ class TestCreateCollection:
method: insert vector and create collection
expected: error raised
'''
connect.bulk_insert(collection, default_entity)
connect.insert(collection, default_entity)
connect.flush([collection])
with pytest.raises(Exception) as e:
connect.create_collection(collection, default_fields)
......
......@@ -33,9 +33,9 @@ class TestInfoBase:
)
def get_simple_index(self, request, connect):
logging.getLogger().info(request.param)
if str(connect._cmd("mode")) == "CPU":
if request.param["index_type"] in index_cpu_not_support():
pytest.skip("sq8h not support in CPU mode")
# if str(connect._cmd("mode")) == "CPU":
if request.param["index_type"] in index_cpu_not_support():
pytest.skip("sq8h not support in CPU mode")
return request.param
"""
......@@ -88,7 +88,7 @@ class TestInfoBase:
@pytest.mark.skip("no create Index")
def test_get_collection_info_after_index_created(self, connect, collection, get_simple_index):
connect.create_index(collection, default_float_vec_field_name, get_simple_index)
info = connect.get_index_info(collection, field_name)
info = connect.describe_index(collection, field_name)
assert info == get_simple_index
res = connect.get_collection_info(collection, default_float_vec_field_name)
assert index["index_type"] == get_simple_index["index_type"]
......@@ -161,7 +161,7 @@ class TestInfoBase:
}
connect.create_collection(collection_name, fields)
entities = gen_entities_by_fields(fields["fields"], default_nb, vector_field["params"]["dim"])
res_ids = connect.bulk_insert(collection_name, entities)
res_ids = connect.insert(collection_name, entities)
connect.flush([collection_name])
res = connect.get_collection_info(collection_name)
assert res['auto_id'] == True
......@@ -186,7 +186,7 @@ class TestInfoBase:
fields["segment_row_limit"] = get_segment_row_limit
connect.create_collection(collection_name, fields)
entities = gen_entities_by_fields(fields["fields"], default_nb, fields["fields"][-1]["params"]["dim"])
res_ids = connect.bulk_insert(collection_name, entities)
res_ids = connect.insert(collection_name, entities)
connect.flush([collection_name])
res = connect.get_collection_info(collection_name)
assert res['auto_id'] == True
......
......@@ -26,9 +26,9 @@ class TestIndexBase:
def get_simple_index(self, request, connect):
import copy
logging.getLogger().info(request.param)
if str(connect._cmd("mode")) == "CPU":
if request.param["index_type"] in index_cpu_not_support():
pytest.skip("sq8h not support in CPU mode")
#if str(connect._cmd("mode")) == "CPU":
if request.param["index_type"] in index_cpu_not_support():
pytest.skip("sq8h not support in CPU mode")
return copy.deepcopy(request.param)
@pytest.fixture(
......@@ -55,7 +55,7 @@ class TestIndexBase:
method: create collection and add entities in it, create index
expected: return search success
'''
ids = connect.bulk_insert(collection, default_entities)
ids = connect.insert(collection, default_entities)
connect.create_index(collection, field_name, get_simple_index)
def test_create_index_on_field_not_existed(self, connect, collection, get_simple_index):
......@@ -65,7 +65,7 @@ class TestIndexBase:
expected: error raised
'''
tmp_field_name = gen_unique_str()
ids = connect.bulk_insert(collection, default_entities)
ids = connect.insert(collection, default_entities)
with pytest.raises(Exception) as e:
connect.create_index(collection, tmp_field_name, get_simple_index)
......@@ -77,7 +77,7 @@ class TestIndexBase:
expected: error raised
'''
tmp_field_name = "int64"
ids = connect.bulk_insert(collection, default_entities)
ids = connect.insert(collection, default_entities)
with pytest.raises(Exception) as e:
connect.create_index(collection, tmp_field_name, get_simple_index)
......@@ -98,7 +98,7 @@ class TestIndexBase:
expected: return search success
'''
connect.create_partition(collection, default_tag)
ids = connect.bulk_insert(collection, default_entities, partition_tag=default_tag)
ids = connect.insert(collection, default_entities, partition_tag=default_tag)
connect.flush([collection])
connect.create_index(collection, field_name, get_simple_index)
......@@ -110,7 +110,7 @@ class TestIndexBase:
expected: return search success
'''
connect.create_partition(collection, default_tag)
ids = connect.bulk_insert(collection, default_entities, partition_tag=default_tag)
ids = connect.insert(collection, default_entities, partition_tag=default_tag)
connect.flush()
connect.create_index(collection, field_name, get_simple_index)
......@@ -131,7 +131,7 @@ class TestIndexBase:
method: create collection and add entities in it, create index
expected: return search success
'''
ids = connect.bulk_insert(collection, default_entities)
ids = connect.insert(collection, default_entities)
connect.create_index(collection, field_name, get_simple_index)
# logging.getLogger().info(connect.get_collection_stats(collection))
nq = get_nq
......@@ -150,7 +150,7 @@ class TestIndexBase:
method: create collection and add entities in it, create index
expected: return search success
'''
connect.bulk_insert(collection, default_entities)
connect.insert(collection, default_entities)
def build(connect):
connect.create_index(collection, field_name, default_index)
......@@ -187,7 +187,7 @@ class TestIndexBase:
expected: create index ok, and count correct
'''
connect.create_index(collection, field_name, get_simple_index)
ids = connect.bulk_insert(collection, default_entities)
ids = connect.insert(collection, default_entities)
connect.flush([collection])
count = connect.count_entities(collection)
assert count == default_nb
......@@ -213,7 +213,7 @@ class TestIndexBase:
method: create another index with different index_params after index have been built
expected: return code 0, and describe index result equals with the second index params
'''
ids = connect.bulk_insert(collection, default_entities)
ids = connect.insert(collection, default_entities)
indexs = [default_index, {"metric_type":"L2", "index_type": "FLAT", "params":{"nlist": 1024}}]
for index in indexs:
connect.create_index(collection, field_name, index)
......@@ -228,7 +228,7 @@ class TestIndexBase:
method: create collection and add entities in it, create index
expected: return search success
'''
ids = connect.bulk_insert(collection, default_entities)
ids = connect.insert(collection, default_entities)
get_simple_index["metric_type"] = "IP"
connect.create_index(collection, field_name, get_simple_index)
......@@ -250,7 +250,7 @@ class TestIndexBase:
expected: return search success
'''
connect.create_partition(collection, default_tag)
ids = connect.bulk_insert(collection, default_entities, partition_tag=default_tag)
ids = connect.insert(collection, default_entities, partition_tag=default_tag)
connect.flush([collection])
get_simple_index["metric_type"] = "IP"
connect.create_index(collection, field_name, get_simple_index)
......@@ -263,7 +263,7 @@ class TestIndexBase:
expected: return search success
'''
connect.create_partition(collection, default_tag)
ids = connect.bulk_insert(collection, default_entities, partition_tag=default_tag)
ids = connect.insert(collection, default_entities, partition_tag=default_tag)
connect.flush()
get_simple_index["metric_type"] = "IP"
connect.create_index(collection, field_name, get_simple_index)
......@@ -277,7 +277,7 @@ class TestIndexBase:
expected: return search success
'''
metric_type = "IP"
ids = connect.bulk_insert(collection, default_entities)
ids = connect.insert(collection, default_entities)
get_simple_index["metric_type"] = metric_type
connect.create_index(collection, field_name, get_simple_index)
# logging.getLogger().info(connect.get_collection_stats(collection))
......@@ -297,7 +297,7 @@ class TestIndexBase:
method: create collection and add entities in it, create index
expected: return search success
'''
connect.bulk_insert(collection, default_entities)
connect.insert(collection, default_entities)
def build(connect):
default_index["metric_type"] = "IP"
......@@ -336,7 +336,7 @@ class TestIndexBase:
'''
default_index["metric_type"] = "IP"
connect.create_index(collection, field_name, get_simple_index)
ids = connect.bulk_insert(collection, default_entities)
ids = connect.insert(collection, default_entities)
connect.flush([collection])
count = connect.count_entities(collection)
assert count == default_nb
......@@ -364,7 +364,7 @@ class TestIndexBase:
method: create another index with different index_params after index have been built
expected: return code 0, and describe index result equals with the second index params
'''
ids = connect.bulk_insert(collection, default_entities)
ids = connect.insert(collection, default_entities)
indexs = [default_index, {"index_type": "FLAT", "params": {"nlist": 1024}, "metric_type": "IP"}]
for index in indexs:
connect.create_index(collection, field_name, index)
......@@ -385,7 +385,7 @@ class TestIndexBase:
method: create collection and add entities in it, create index, call drop index
expected: return code 0, and default index param
'''
# ids = connect.bulk_insert(collection, entities)
# ids = connect.insert(collection, entities)
connect.create_index(collection, field_name, get_simple_index)
connect.drop_index(collection, field_name)
stats = connect.get_collection_stats(collection)
......@@ -439,7 +439,7 @@ class TestIndexBase:
method: create collection and add entities in it, create index
expected: return code not equals to 0, drop index failed
'''
# ids = connect.bulk_insert(collection, entities)
# ids = connect.insert(collection, entities)
# no create index
connect.drop_index(collection, field_name)
......@@ -462,7 +462,7 @@ class TestIndexBase:
method: create collection and add entities in it, create index, call drop index
expected: return code 0, and default index param
'''
# ids = connect.bulk_insert(collection, entities)
# ids = connect.insert(collection, entities)
get_simple_index["metric_type"] = "IP"
connect.create_index(collection, field_name, get_simple_index)
connect.drop_index(collection, field_name)
......@@ -506,7 +506,7 @@ class TestIndexBase:
method: create collection and add entities in it, create index
expected: return code not equals to 0, drop index failed
'''
# ids = connect.bulk_insert(collection, entities)
# ids = connect.insert(collection, entities)
# no create index
connect.drop_index(collection, field_name)
......@@ -579,7 +579,7 @@ class TestIndexBinary:
method: create collection and add entities in it, create index
expected: return search success
'''
ids = connect.bulk_insert(binary_collection, default_binary_entities)
ids = connect.insert(binary_collection, default_binary_entities)
connect.create_index(binary_collection, binary_field_name, get_jaccard_index)
@pytest.mark.timeout(BUILD_TIMEOUT)
......@@ -590,7 +590,7 @@ class TestIndexBinary:
expected: return search success
'''
connect.create_partition(binary_collection, default_tag)
ids = connect.bulk_insert(binary_collection, default_binary_entities, partition_tag=default_tag)
ids = connect.insert(binary_collection, default_binary_entities, partition_tag=default_tag)
connect.create_index(binary_collection, binary_field_name, get_jaccard_index)
@pytest.mark.skip("r0.3-test")
......@@ -602,7 +602,7 @@ class TestIndexBinary:
expected: return search success
'''
nq = get_nq
ids = connect.bulk_insert(binary_collection, default_binary_entities)
ids = connect.insert(binary_collection, default_binary_entities)
connect.create_index(binary_collection, binary_field_name, get_jaccard_index)
query, vecs = gen_query_vectors(binary_field_name, default_binary_entities, default_top_k, nq, metric_type="JACCARD")
search_param = get_search_param(get_jaccard_index["index_type"], metric_type="JACCARD")
......@@ -619,7 +619,7 @@ class TestIndexBinary:
expected: return create_index failure
'''
# insert 6000 vectors
ids = connect.bulk_insert(binary_collection, default_binary_entities)
ids = connect.insert(binary_collection, default_binary_entities)
connect.flush([binary_collection])
if get_l2_index["index_type"] == "BIN_FLAT":
......@@ -641,7 +641,7 @@ class TestIndexBinary:
method: create collection and add entities in it, create index, call describe index
expected: return code 0, and index instructure
'''
ids = connect.bulk_insert(binary_collection, default_binary_entities)
ids = connect.insert(binary_collection, default_binary_entities)
connect.flush([binary_collection])
connect.create_index(binary_collection, binary_field_name, get_jaccard_index)
stats = connect.get_collection_stats(binary_collection)
......@@ -662,7 +662,7 @@ class TestIndexBinary:
expected: return code 0, and index instructure
'''
connect.create_partition(binary_collection, default_tag)
ids = connect.bulk_insert(binary_collection, default_binary_entities, partition_tag=default_tag)
ids = connect.insert(binary_collection, default_binary_entities, partition_tag=default_tag)
connect.flush([binary_collection])
connect.create_index(binary_collection, binary_field_name, get_jaccard_index)
stats = connect.get_collection_stats(binary_collection)
......@@ -706,7 +706,7 @@ class TestIndexBinary:
expected: return code 0, and default index param
'''
connect.create_partition(binary_collection, default_tag)
ids = connect.bulk_insert(binary_collection, default_binary_entities, partition_tag=default_tag)
ids = connect.insert(binary_collection, default_binary_entities, partition_tag=default_tag)
connect.flush([binary_collection])
connect.create_index(binary_collection, binary_field_name, get_jaccard_index)
stats = connect.get_collection_stats(binary_collection)
......@@ -802,7 +802,7 @@ class TestIndexAsync:
method: create collection and add entities in it, create index
expected: return search success
'''
ids = connect.bulk_insert(collection, default_entities)
ids = connect.insert(collection, default_entities)
logging.getLogger().info("start index")
future = connect.create_index(collection, field_name, get_simple_index, _async=True)
logging.getLogger().info("before result")
......@@ -817,7 +817,7 @@ class TestIndexAsync:
method: create collection and add entities in it, create index
expected: return search success
'''
ids = connect.bulk_insert(collection, default_entities)
ids = connect.insert(collection, default_entities)
logging.getLogger().info("start index")
future = connect.create_index(collection, field_name, get_simple_index, _async=True)
logging.getLogger().info("DROP")
......@@ -837,7 +837,7 @@ class TestIndexAsync:
method: create collection and add entities in it, create index
expected: return search success
'''
ids = connect.bulk_insert(collection, default_entities)
ids = connect.insert(collection, default_entities)
logging.getLogger().info("start index")
future = connect.create_index(collection, field_name, get_simple_index, _async=True,
_callback=self.check_result)
......
此差异已折叠。
......@@ -101,7 +101,7 @@ class TestCreateBase:
'''
connect.create_partition(id_collection, default_tag)
ids = [i for i in range(default_nb)]
insert_ids = connect.bulk_insert(id_collection, default_entities, ids)
insert_ids = connect.insert(id_collection, default_entities, ids)
assert len(insert_ids) == len(ids)
@pytest.mark.skip("not support custom id")
......@@ -113,7 +113,7 @@ class TestCreateBase:
'''
connect.create_partition(id_collection, default_tag)
ids = [i for i in range(default_nb)]
insert_ids = connect.bulk_insert(id_collection, default_entities, ids, partition_tag=default_tag)
insert_ids = connect.insert(id_collection, default_entities, ids, partition_tag=default_tag)
assert len(insert_ids) == len(ids)
def test_create_partition_insert_with_tag_not_existed(self, connect, collection):
......@@ -126,7 +126,7 @@ class TestCreateBase:
connect.create_partition(collection, default_tag)
ids = [i for i in range(default_nb)]
with pytest.raises(Exception) as e:
insert_ids = connect.bulk_insert(collection, default_entities, ids, partition_tag=tag_new)
insert_ids = connect.insert(collection, default_entities, ids, partition_tag=tag_new)
@pytest.mark.skip("not support custom id")
def test_create_partition_insert_same_tags(self, connect, id_collection):
......@@ -137,9 +137,9 @@ class TestCreateBase:
'''
connect.create_partition(id_collection, default_tag)
ids = [i for i in range(default_nb)]
insert_ids = connect.bulk_insert(id_collection, default_entities, ids, partition_tag=default_tag)
insert_ids = connect.insert(id_collection, default_entities, ids, partition_tag=default_tag)
ids = [(i+default_nb) for i in range(default_nb)]
new_insert_ids = connect.bulk_insert(id_collection, default_entities, ids, partition_tag=default_tag)
new_insert_ids = connect.insert(id_collection, default_entities, ids, partition_tag=default_tag)
connect.flush([id_collection])
res = connect.count_entities(id_collection)
assert res == default_nb * 2
......@@ -156,8 +156,8 @@ class TestCreateBase:
collection_new = gen_unique_str()
connect.create_collection(collection_new, default_fields)
connect.create_partition(collection_new, default_tag)
ids = connect.bulk_insert(collection, default_entities, partition_tag=default_tag)
ids = connect.bulk_insert(collection_new, default_entities, partition_tag=default_tag)
ids = connect.insert(collection, default_entities, partition_tag=default_tag)
ids = connect.insert(collection_new, default_entities, partition_tag=default_tag)
connect.flush([collection, collection_new])
res = connect.count_entities(collection)
assert res == default_nb
......
......@@ -36,14 +36,14 @@ def init_data(connect, collection, nb=1200, partition_tags=None, auto_id=True):
insert_entities = gen_entities(nb, is_normal=True)
if partition_tags is None:
if auto_id:
ids = connect.bulk_insert(collection, insert_entities)
ids = connect.insert(collection, insert_entities)
else:
ids = connect.bulk_insert(collection, insert_entities, ids=[i for i in range(nb)])
ids = connect.insert(collection, insert_entities, ids=[i for i in range(nb)])
else:
if auto_id:
ids = connect.bulk_insert(collection, insert_entities, partition_tag=partition_tags)
ids = connect.insert(collection, insert_entities, partition_tag=partition_tags)
else:
ids = connect.bulk_insert(collection, insert_entities, ids=[i for i in range(nb)], partition_tag=partition_tags)
ids = connect.insert(collection, insert_entities, ids=[i for i in range(nb)], partition_tag=partition_tags)
# connect.flush([collection])
return insert_entities, ids
......@@ -62,9 +62,9 @@ def init_binary_data(connect, collection, nb=1200, insert=True, partition_tags=N
insert_raw_vectors, insert_entities = gen_binary_entities(nb)
if insert is True:
if partition_tags is None:
ids = connect.bulk_insert(collection, insert_entities)
ids = connect.insert(collection, insert_entities)
else:
ids = connect.bulk_insert(collection, insert_entities, partition_tag=partition_tags)
ids = connect.insert(collection, insert_entities, partition_tag=partition_tags)
connect.flush([collection])
return insert_raw_vectors, insert_entities, ids
......@@ -79,9 +79,9 @@ class TestSearchBase:
params=gen_index()
)
def get_index(self, request, connect):
if str(connect._cmd("mode")) == "CPU":
if request.param["index_type"] in index_cpu_not_support():
pytest.skip("sq8h not support in CPU mode")
# if str(connect._cmd("mode")) == "CPU":
if request.param["index_type"] in index_cpu_not_support():
pytest.skip("sq8h not support in CPU mode")
return request.param
@pytest.fixture(
......@@ -90,9 +90,9 @@ class TestSearchBase:
)
def get_simple_index(self, request, connect):
import copy
if str(connect._cmd("mode")) == "CPU":
if request.param["index_type"] in index_cpu_not_support():
pytest.skip("sq8h not support in CPU mode")
# if str(connect._cmd("mode")) == "CPU":
if request.param["index_type"] in index_cpu_not_support():
pytest.skip("sq8h not support in CPU mode")
return copy.deepcopy(request.param)
@pytest.fixture(
......@@ -1250,7 +1250,7 @@ class TestSearchDSL(object):
collection_term = gen_unique_str("term")
connect.create_collection(collection_term, term_fields)
term_entities = add_field(entities, field_name="term")
ids = connect.bulk_insert(collection_term, term_entities)
ids = connect.insert(collection_term, term_entities)
assert len(ids) == default_nb
connect.flush([collection_term])
count = connect.count_entities(collection_term) # count_entities is not impelmented
......@@ -1695,9 +1695,9 @@ class TestSearchInvalid(object):
params=gen_simple_index()
)
def get_simple_index(self, request, connect):
if str(connect._cmd("mode")) == "CPU":
if request.param["index_type"] in index_cpu_not_support():
pytest.skip("sq8h not support in CPU mode")
# if str(connect._cmd("mode")) == "CPU":
if request.param["index_type"] in index_cpu_not_support():
pytest.skip("sq8h not support in CPU mode")
return request.param
# PASS
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册