Unverified commit 41b2bb58, authored by del-zhenwu and committed by GitHub

[skip ci] update some cases (#2700)

* [skip ci] remove timeout in partition-test case
Signed-off-by: Nzw <zw@milvus.io>

* Update server_versiong
Signed-off-by: Nzw <zw@milvus.io>

* fix client_test.go
Signed-off-by: Nzw <zw@milvus.io>

* Add ci param is_manual_trigger
Signed-off-by: Nzw <zw@milvus.io>

* update ci param
Signed-off-by: Nzw <zw@milvus.io>

* [skip ci] update some cases
Signed-off-by: Nzw <zw@milvus.io>
Co-authored-by: Nzw <zw@milvus.io>
Parent 7f1501a8
......@@ -3,19 +3,22 @@ import copy
import logging
import itertools
from time import sleep
import threading
from multiprocessing import Process
import sklearn.preprocessing
import pytest
from milvus import IndexType, MetricType
from utils import *
nb = 1
dim = 128
collection_id = "create_collection"
default_segment_size = 1024
drop_collection_interval_time = 3
segment_size = 10
vectors = gen_vectors(100, dim)
default_fields = gen_default_fields()
entities = gen_entities(nb)
class TestCreateCollection:
......@@ -52,14 +55,16 @@ class TestCreateCollection:
expected: no exception raised
'''
filter_field = get_filter_field
logging.getLogger().info(filter_field)
vector_field = get_vector_field
collection_name = gen_unique_str("test_collection")
collection_name = gen_unique_str(collection_id)
fields = {
"fields": [filter_field, vector_field],
"segment_size": segment_size
}
logging.getLogger().info(fields)
connect.create_collection(collection_name, fields)
assert_has_collection(collection_name)
assert connect.has_collection(collection_name)
# TODO
def test_create_collection_fields_create_index(self, connect, get_filter_field, get_vector_field):
......@@ -70,13 +75,13 @@ class TestCreateCollection:
'''
filter_field = get_filter_field
vector_field = get_vector_field
collection_name = gen_unique_str("test_collection")
collection_name = gen_unique_str(collection_id)
fields = {
"fields": [filter_field, vector_field],
"segment_size": segment_size
}
connect.create_collection(collection_name, fields)
assert_has_collection(collection_name)
assert connect.has_collection(collection_name)
def test_create_collection_segment_size(self, connect, get_segment_size):
'''
......@@ -85,13 +90,13 @@ class TestCreateCollection:
expected: no exception raised
'''
segment_size = get_segment_size
collection_name = gen_unique_str("test_collection")
collection_name = gen_unique_str(collection_id)
fields = {
"fields": default_fields["fields"],
"segment_size": segment_size
}
connect.create_collection(collection_name, fields)
assert_has_collection(collection_name)
assert connect.has_collection(collection_name)
def test_create_collection_auto_flush_disabled(self, connect):
'''
......@@ -100,11 +105,12 @@ class TestCreateCollection:
expected: create status return ok
'''
disable_flush(connect)
collection_name = gen_unique_str("test_collection")
collection_name = gen_unique_str(collection_id)
try:
connect.create_collection(collection_name, default_fields)
finally:
enable_flush(connect)
# pdb.set_trace()
def test_create_collection_after_insert(self, connect, collection):
'''
......@@ -112,7 +118,9 @@ class TestCreateCollection:
method: insert vector and create collection
expected: error raised
'''
# pdb.set_trace()
connect.insert(collection, entities)
with pytest.raises(Exception) as e:
connect.create_collection(collection, default_fields)
......@@ -135,7 +143,7 @@ class TestCreateCollection:
method: create collection with correct params, with a disconnected instance
expected: create raise exception
'''
collection_name = gen_unique_str("test_collection")
collection_name = gen_unique_str(collection_id)
with pytest.raises(Exception) as e:
connect.create_collection(collection_name, default_fields)
......@@ -145,7 +153,7 @@ class TestCreateCollection:
method: create collection with the same collection_name
expected: create status return not ok
'''
collection_name = gen_unique_str("test_collection")
collection_name = gen_unique_str(collection_id)
connect.create_collection(collection_name, default_fields)
with pytest.raises(Exception) as e:
connect.create_collection(collection_name, default_fields)
......@@ -157,14 +165,14 @@ class TestCreateCollection:
method: create collection using multithread,
expected: collections are created
'''
threads_num = 4
threads_num = 8
threads = []
collection_names = []
def create():
collection_name = gen_unique_str("test_collection")
collection_name = gen_unique_str(collection_id)
collection_names.append(collection_name)
connect.create_collection(collection_name, fields)
connect.create_collection(collection_name, default_fields)
for i in range(threads_num):
t = threading.Thread(target=create, args=())
threads.append(t)
......@@ -174,8 +182,8 @@ class TestCreateCollection:
t.join()
res = connect.list_collections()
for item in res:
assert item in collection_names
for item in collection_names:
assert item in res
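Note that the containment check in the multithreaded create test is flipped: the test now asserts that every name it created appears in list_collections(), rather than requiring the server's list to contain only this test's collections (other suites may have created collections too). A minimal sketch of the corrected direction, assuming the suite's connect fixture and illustrative names:
created = ["coll_a", "coll_b"]        # names this test created (hypothetical)
listed = connect.list_collections()   # may also contain collections from other tests
for name in created:
    assert name in listed             # every created name must show up in the listing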
class TestCreateCollectionInvalid(object):
......@@ -198,14 +206,14 @@ class TestCreateCollectionInvalid(object):
@pytest.fixture(
scope="function",
params=gen_invalid_dims()
params=gen_invalid_ints()
)
def get_dim(self, request):
yield request.param
@pytest.fixture(
scope="function",
params=gen_invalid_strings()
params=gen_invalid_strs()
)
def get_invalid_string(self, request):
yield request.param
......@@ -229,7 +237,7 @@ class TestCreateCollectionInvalid(object):
def test_create_collection_with_invalid_metric_type(self, connect, get_metric_type):
collection_name = gen_unique_str()
fields = copy.deepcopy(default_fields)
fields["fields"][-1]["extra_params"]["metric_type"] = get_metric_type
fields["fields"][-1]["params"]["metric_type"] = get_metric_type
with pytest.raises(Exception) as e:
connect.create_collection(collection_name, fields)
......@@ -238,13 +246,13 @@ class TestCreateCollectionInvalid(object):
dimension = get_dim
collection_name = gen_unique_str()
fields = copy.deepcopy(default_fields)
fields["fields"][-1]["extra_params"]["dimension"] = dimension
fields["fields"][-1]["params"]["dimension"] = dimension
with pytest.raises(Exception) as e:
connect.create_collection(collection_name, fields)
@pytest.mark.level(2)
def test_create_collection_with_invalid_collectionname(self, connect, get_invalid_string):
collection_name = get_collection_name
collection_name = get_invalid_string
with pytest.raises(Exception) as e:
connect.create_collection(collection_name, default_fields)
......@@ -275,9 +283,9 @@ class TestCreateCollectionInvalid(object):
method: create collection with corrent params
expected: create status return ok
'''
collection_name = gen_unique_str("test_collection")
collection_name = gen_unique_str(collection_id)
fields = copy.deepcopy(default_fields)
fields["fields"][-1]["extra_params"].pop("dimension")
fields["fields"][-1]["params"].pop("dimension")
with pytest.raises(Exception) as e:
connect.create_collection(collection_name, fields)
......@@ -287,7 +295,7 @@ class TestCreateCollectionInvalid(object):
method: create collection with corrent params
expected: create status return ok, use default default_segment_size
'''
collection_name = gen_unique_str("test_collection")
collection_name = gen_unique_str(collection_id)
fields = copy.deepcopy(default_fields)
fields.pop("segment_size")
connect.create_collection(collection_name, fields)
......@@ -301,7 +309,7 @@ class TestCreateCollectionInvalid(object):
method: create collection with corrent params
expected: create status return ok, use default L2
'''
collection_name = gen_unique_str("test_collection")
collection_name = gen_unique_str(collection_id)
fields = copy.deepcopy(default_fields)
fields["fields"][-1].pop("metric_type")
connect.create_collection(collection_name, fields)
......@@ -311,7 +319,7 @@ class TestCreateCollectionInvalid(object):
# TODO: assert exception
def test_create_collection_limit_fields(self, connect):
collection_name = gen_unique_str("test_collection")
collection_name = gen_unique_str(collection_id)
limit_num = 64
fields = copy.deepcopy(default_fields)
for i in range(limit_num):
......@@ -323,7 +331,7 @@ class TestCreateCollectionInvalid(object):
# TODO: assert exception
def test_create_collection_invalid_field_name(self, connect, get_invalid_string):
collection_name = gen_unique_str("test_collection")
collection_name = gen_unique_str(collection_id)
fields = copy.deepcopy(default_fields)
field_name = get_invalid_string
field = {"field": field_name, "type": DataType.INT8}
......@@ -333,7 +341,7 @@ class TestCreateCollectionInvalid(object):
# TODO: assert exception
def test_create_collection_invalid_field_type(self, connect, get_field_type):
collection_name = gen_unique_str("test_collection")
collection_name = gen_unique_str(collection_id)
fields = copy.deepcopy(default_fields)
field_type = get_field_type
field = {"field": "test_field", "type": field_type}
......
......@@ -7,7 +7,7 @@ from multiprocessing import Process
from milvus import IndexType, MetricType
from utils import *
uniq_id = "test_drop_collection"
uniq_id = "drop_collection"
default_fields = gen_default_fields()
......@@ -26,7 +26,8 @@ class TestDropCollection:
expected: status ok, and no collection in collections
'''
connect.drop_collection(collection)
assert not assert_has_collection(connect, collection)
time.sleep(2)
assert not connect.has_collection(collection)
@pytest.mark.level(2)
def test_drop_collection_without_connection(self, collection, dis_connect):
......@@ -47,7 +48,7 @@ class TestDropCollection:
'''
collection_name = gen_unique_str(uniq_id)
with pytest.raises(Exception) as e:
assert not assert_has_collection(connect, collection_name)
connect.drop_collection(collection_name)
class TestDropCollectionInvalid(object):
......@@ -56,7 +57,7 @@ class TestDropCollectionInvalid(object):
"""
@pytest.fixture(
scope="function",
params=gen_invalid_collection_names()
params=gen_invalid_strs()
)
def get_collection_name(self, request):
yield request.param
......
......@@ -2,11 +2,13 @@ import pdb
import pytest
import logging
import itertools
import threading
from time import sleep
from multiprocessing import Process
from milvus import IndexType, MetricType
from utils import *
collection_id = "has_collection"
default_fields = gen_default_fields()
......@@ -23,7 +25,7 @@ class TestHasCollection:
method: create collection, assert the value returned by has_collection method
expected: True
'''
assert assert_has_collection(connect, collection)
assert connect.has_collection(collection)
@pytest.mark.level(2)
def test_has_collection_without_connection(self, collection, dis_connect):
......@@ -33,7 +35,7 @@ class TestHasCollection:
expected: has collection raise exception
'''
with pytest.raises(Exception) as e:
assert_has_collection(dis_connect, collection)
assert connect.has_collection(collection)
def test_has_collection_not_existed(self, connect):
'''
......@@ -43,7 +45,7 @@ class TestHasCollection:
expected: False
'''
collection_name = gen_unique_str("test_collection")
assert not assert_has_collection(connect, collection_name)
assert not connect.has_collection(collection_name)
@pytest.mark.level(2)
def test_has_collection_multithread(self, connect):
......@@ -54,8 +56,8 @@ class TestHasCollection:
'''
threads_num = 4
threads = []
collection_name = gen_unique_str("test_collection")
connect.create_collection(collection_name, fields)
collection_name = gen_unique_str(collection_id)
connect.create_collection(collection_name, default_fields)
def has():
assert not assert_collection(connect, collection_name)
......@@ -74,7 +76,7 @@ class TestHasCollectionInvalid(object):
"""
@pytest.fixture(
scope="function",
params=gen_invalid_collection_names()
params=gen_invalid_strs()
)
def get_collection_name(self, request):
yield request.param
......
......@@ -2,12 +2,15 @@ import pdb
import pytest
import logging
import itertools
import threading
from time import sleep
from multiprocessing import Process
from milvus import IndexType, MetricType
from utils import *
drop_collection_interval_time = 3
drop_interval_time = 3
collection_id = "list_collections"
default_fields = gen_default_fields()
......@@ -34,7 +37,7 @@ class TestListCollections:
'''
collection_num = 100
for i in range(collection_num):
collection_name = gen_unique_str("test_list_collections")
collection_name = gen_unique_str(collection_id)
connect.create_collection(collection_name, default_fields)
assert collection_name in connect.list_collections()
......@@ -46,7 +49,7 @@ class TestListCollections:
expected: list collections raise exception
'''
with pytest.raises(Exception) as e:
assert_list_collections(dis_connect)
dis_connect.list_collections()
def test_list_collections_not_existed(self, connect):
'''
......@@ -55,8 +58,8 @@ class TestListCollections:
assert the value returned by list_collections method
expected: False
'''
collection_name = gen_unique_str("test_collection")
assert collection_name not connect.list_collections()
collection_name = gen_unique_str(collection_id)
assert collection_name not in connect.list_collections()
def test_list_collections_no_collection(self, connect):
'''
......@@ -69,7 +72,7 @@ class TestListCollections:
if result:
for collection_name in result:
connect.drop_collection(collection_name)
time.sleep(drop_collection_interval_time)
time.sleep(drop_interval_time)
result = connect.list_collections()
assert len(result) == 0
......@@ -82,11 +85,11 @@ class TestListCollections:
'''
threads_num = 4
threads = []
collection_name = gen_unique_str("test_collection")
connect.create_collection(collection_name, fields)
collection_name = gen_unique_str(collection_id)
connect.create_collection(collection_name, default_fields)
def _list():
assert collection_name not connect.list_collections()
assert collection_name in connect.list_collections()
for i in range(threads_num):
t = threading.Thread(target=_list, args=())
threads.append(t)
......
......@@ -7,10 +7,12 @@ from multiprocessing import Process
from milvus import IndexType, MetricType
from utils import *
uniq_id = "test_load_collection"
collection_id = "load_collection"
index_name = "load_index_name"
nb = 6000
default_fields = gen_default_fields()
entities = gen_entities(6000)
entities = gen_entities(nb)
field_name = "fload_vector"
class TestLoadCollection:
......@@ -26,10 +28,8 @@ class TestLoadCollection:
)
def get_simple_index(self, request, connect):
if str(connect._cmd("mode")[1]) == "CPU":
if request.param["index_type"] == IndexType.IVF_SQ8H:
if request.param["index_type"] == "IVFSQ8H":
pytest.skip("sq8h not support in cpu mode")
if request.param["index_type"] == IndexType.IVF_PQ:
pytest.skip("Skip PQ Temporary")
return request.param
def test_load_collection_after_index(self, connect, collection, get_simple_index):
......@@ -40,7 +40,6 @@ class TestLoadCollection:
'''
connect.insert(collection, entities)
connect.flush([collection])
field_name = "fload_vector"
connect.create_index(collection, field_name, index_name, get_simple_index)
connect.load_collection(collection)
......@@ -64,7 +63,7 @@ class TestLoadCollection:
@pytest.mark.level(2)
def test_load_collection_not_existed(self, connect, collection):
collection_name = gen_unique_str()
collection_name = gen_unique_str(collection_id)
with pytest.raises(Exception) as e:
connect.load_collection(collection_name)
......@@ -75,7 +74,7 @@ class TestLoadCollectionInvalid(object):
"""
@pytest.fixture(
scope="function",
params=gen_invalid_collection_names()
params=gen_invalid_strs()
)
def get_collection_name(self, request):
yield request.param
......@@ -85,15 +84,3 @@ class TestLoadCollectionInvalid(object):
collection_name = get_collection_name
with pytest.raises(Exception) as e:
connect.has_collection(collection_name)
@pytest.mark.level(2)
def test_load_collection_with_empty_collectionname(self, connect):
collection_name = ''
with pytest.raises(Exception) as e:
connect.has_collection(collection_name)
@pytest.mark.level(2)
def test_load_collection_with_none_collectionname(self, connect):
collection_name = None
with pytest.raises(Exception) as e:
connect.has_collection(collection_name)
import socket
import pdb
import logging
import socket
import pytest
from utils import gen_unique_str
from milvus import Milvus, IndexType, MetricType, DataType
......@@ -10,6 +9,7 @@ from utils import *
timeout = 60
dimension = 128
delete_timeout = 60
default_fields = gen_default_fields()
def pytest_addoption(parser):
......@@ -101,14 +101,15 @@ def collection(request, connect):
ori_collection_name = getattr(request.module, "collection_id", "test")
collection_name = gen_unique_str(ori_collection_name)
try:
milvus.create_collection(collection_name, default_fields())
connect.create_collection(collection_name, default_fields)
except Exception as e:
pytest.exit(str(e))
def teardown():
status, collection_names = connect.list_collections()
collection_names = connect.list_collections()
for collection_name in collection_names:
connect.drop_collection(collection_name, timeout=delete_timeout)
request.addfinalizer(teardown())
request.addfinalizer(teardown)
assert connect.has_collection(collection_name)
return collection_name
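The teardown change above matters: request.addfinalizer(teardown()) calls teardown immediately and registers its return value (None), so nothing runs at fixture cleanup; passing the function object defers it until teardown time. A minimal sketch of the corrected pattern, assuming the suite's connect fixture, gen_unique_str helper, and default_fields schema:
import pytest

@pytest.fixture(scope="function")
def collection(request, connect):
    # create a uniquely named collection for this test
    collection_name = gen_unique_str("collection")
    connect.create_collection(collection_name, default_fields)

    def teardown():
        # drop every collection once the test finishes
        for name in connect.list_collections():
            connect.drop_collection(name)

    # register the function itself; addfinalizer(teardown()) would run it right now
    request.addfinalizer(teardown)
    return collection_name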
......@@ -117,16 +118,18 @@ def ip_collection(request, connect):
ori_collection_name = getattr(request.module, "collection_id", "test")
collection_name = gen_unique_str(ori_collection_name)
fields = gen_default_fields()
fields["fields"][-1]["extra_params"]["metric_type"] = MetricType.IP
fields["fields"][-1]["params"]["metric_type"] = "IP"
try:
milvus.create_collection(collection_name, fields)
connect.create_collection(collection_name, fields)
except Exception as e:
logging.getLogger().info(str(e))
pytest.exit(str(e))
def teardown():
status, collection_names = connect.list_collections()
collection_names = connect.list_collections()
for collection_name in collection_names:
connect.drop_collection(collection_name, timeout=delete_timeout)
request.addfinalizer(teardown())
request.addfinalizer(teardown)
assert connect.has_collection(collection_name)
return collection_name
......@@ -135,16 +138,17 @@ def jac_collection(request, connect):
ori_collection_name = getattr(request.module, "collection_id", "test")
collection_name = gen_unique_str(ori_collection_name)
fields = gen_default_fields()
fields["fields"][-1] = {"field": "binary_vector", "type": DataType.BINARY_VECTOR, "dimension": dimension, "extra_params": {"metric_type": MetricType.JACCARD}}
fields["fields"][-1] = {"field": "binary_vector", "type": DataType.BINARY_VECTOR, "dimension": dimension, "params": {"metric_type": "JACCARD"}}
try:
milvus.create_collection(collection_name, fields)
connect.create_collection(collection_name, fields)
except Exception as e:
pytest.exit(str(e))
def teardown():
status, collection_names = connect.list_collections()
collection_names = connect.list_collections()
for collection_name in collection_names:
connect.drop_collection(collection_name, timeout=delete_timeout)
request.addfinalizer(teardown())
request.addfinalizer(teardown)
assert connect.has_collection(collection_name)
return collection_name
......@@ -153,16 +157,17 @@ def ham_collection(request, connect):
ori_collection_name = getattr(request.module, "collection_id", "test")
collection_name = gen_unique_str(ori_collection_name)
fields = gen_default_fields()
fields["fields"][-1] = {"field": "binary_vector", "type": DataType.BINARY_VECTOR, "dimension": dimension, "extra_params": {"metric_type": MetricType.HAMMING}}
fields["fields"][-1] = {"field": "binary_vector", "type": DataType.BINARY_VECTOR, "dimension": dimension, "params": {"metric_type": "HAMMING"}}
try:
milvus.create_collection(collection_name, fields)
connect.create_collection(collection_name, fields)
except Exception as e:
pytest.exit(str(e))
def teardown():
status, collection_names = connect.list_collections()
collection_names = connect.list_collections()
for collection_name in collection_names:
connect.drop_collection(collection_name, timeout=delete_timeout)
request.addfinalizer(teardown())
request.addfinalizer(teardown)
assert connect.has_collection(collection_name)
return collection_name
......@@ -171,16 +176,17 @@ def tanimoto_collection(request, connect):
ori_collection_name = getattr(request.module, "collection_id", "test")
collection_name = gen_unique_str(ori_collection_name)
fields = gen_default_fields()
fields["fields"][-1] = {"field": "binary_vector", "type": DataType.BINARY_VECTOR, "dimension": dimension, "extra_params": {"metric_type": MetricType.TANIMOTO}}
fields["fields"][-1] = {"field": "binary_vector", "type": DataType.BINARY_VECTOR, "dimension": dimension, "params": {"metric_type": "TANIMOTO"}}
try:
milvus.create_collection(collection_name, fields)
connect.create_collection(collection_name, fields)
except Exception as e:
pytest.exit(str(e))
def teardown():
status, collection_names = connect.list_collections()
collection_names = connect.list_collections()
for collection_name in collection_names:
connect.drop_collection(collection_name, timeout=delete_timeout)
request.addfinalizer(teardown())
request.addfinalizer(teardown)
assert connect.has_collection(collection_name)
return collection_name
......@@ -189,16 +195,17 @@ def substructure_collection(request, connect):
ori_collection_name = getattr(request.module, "collection_id", "test")
collection_name = gen_unique_str(ori_collection_name)
fields = gen_default_fields()
fields["fields"][-1] = {"field": "binary_vector", "type": DataType.BINARY_VECTOR, "dimension": dimension, "extra_params": {"metric_type": MetricType.SUBSTRUCTURE}}
fields["fields"][-1] = {"field": "binary_vector", "type": DataType.BINARY_VECTOR, "dimension": dimension, "params": {"metric_type": "SUBSTRUCTURE"}}
try:
milvus.create_collection(collection_name, fields)
connect.create_collection(collection_name, fields)
except Exception as e:
pytest.exit(str(e))
def teardown():
status, collection_names = connect.list_collections()
collection_names = connect.list_collections()
for collection_name in collection_names:
connect.drop_collection(collection_name, timeout=delete_timeout)
request.addfinalizer(teardown())
request.addfinalizer(teardown)
assert connect.has_collection(collection_name)
return collection_name
......@@ -208,14 +215,15 @@ def superstructure_collection(request, connect):
ori_collection_name = getattr(request.module, "collection_id", "test")
collection_name = gen_unique_str(ori_collection_name)
fields = gen_default_fields()
fields["fields"][-1] = {"field": "binary_vector", "type": DataType.BINARY_VECTOR, "dimension": dimension, "extra_params": {"metric_type": MetricType.SUPERSTRUCTURE}}
fields["fields"][-1] = {"field": "binary_vector", "type": DataType.BINARY_VECTOR, "dimension": dimension, "params": {"metric_type": MetricType.SUPERSTRUCTURE}}
try:
milvus.create_collection(collection_name, fields)
connect.create_collection(collection_name, fields)
except Exception as e:
pytest.exit(str(e))
def teardown():
status, collection_names = connect.list_collections()
collection_names = connect.list_collections()
for collection_name in collection_names:
connect.drop_collection(collection_name, timeout=delete_timeout)
request.addfinalizer(teardown())
request.addfinalizer(teardown)
assert connect.has_collection(collection_name)
return collection_name
......@@ -48,10 +48,10 @@ class TestDeleteBase:
)
def get_simple_index(self, request, connect):
if str(connect._cmd("mode")[1]) == "GPU":
if request.param["index_type"] not in [IndexType.IVF_SQ8, IndexType.IVFLAT, IndexType.FLAT, IndexType.IVF_PQ, IndexType.IVF_SQ8H]:
if request.param["index_type"] not in [IndexType.IVF_SQ8, IndexType.IVFLAT, IndexType.FLAT, IndexType.IVF_PQ, IndexType.IVFSQ8H]:
pytest.skip("Only support index_type: idmap/ivf")
elif str(connect._cmd("mode")[1]) == "CPU":
if request.param["index_type"] in [IndexType.IVF_SQ8H]:
if request.param["index_type"] in [IndexType.IVFSQ8H]:
pytest.skip("CPU not support index_type: ivf_sq8h")
return request.param
......
......@@ -42,10 +42,10 @@ class TestGetBase:
)
def get_simple_index(self, request, connect):
if str(connect._cmd("mode")[1]) == "GPU":
if request.param["index_type"] not in [IndexType.IVF_SQ8, IndexType.IVFLAT, IndexType.FLAT, IndexType.IVF_PQ, IndexType.IVF_SQ8H]:
if request.param["index_type"] not in [IndexType.IVF_SQ8, IndexType.IVFLAT, IndexType.FLAT, IndexType.IVF_PQ, IndexType.IVFSQ8H]:
pytest.skip("Only support index_type: idmap/ivf")
elif str(connect._cmd("mode")[1]) == "CPU":
if request.param["index_type"] in [IndexType.IVF_SQ8H]:
if request.param["index_type"] in [IndexType.IVFSQ8H]:
pytest.skip("CPU not support index_type: ivf_sq8h")
return request.param
......
......@@ -42,7 +42,7 @@ class TestInsertBase:
)
def get_simple_index(self, request, connect):
if str(connect._cmd("mode")[1]) == "CPU":
if request.param["index_type"] == IndexType.IVF_SQ8H:
if request.param["index_type"] == "IVFSQ8H":
pytest.skip("sq8h not support in cpu mode")
return request.param
......@@ -597,7 +597,7 @@ class TestInsertMultiCollections:
)
def get_simple_index(self, request, connect):
if str(connect._cmd("mode")[1]) == "CPU":
if request.param["index_type"] == IndexType.IVF_SQ8H:
if request.param["index_type"] == "IVFSQ8H":
pytest.skip("sq8h not support in cpu mode")
def test_insert_vector_multi_collections(self, connect):
......
......@@ -10,8 +10,7 @@ from utils import *
dim = 128
segment_size = 10
collection_id = "test_insert"
ADD_TIMEOUT = 60
collection_id = "search"
tag = "1970-01-01"
insert_interval_time = 1.5
nb = 6000
......@@ -24,23 +23,27 @@ entity = gen_entities(1, is_normal=True)
binary_entity = gen_binary_entities(1)
entities = gen_entities(nb, is_normal=True)
raw_vectors, binary_entities = gen_binary_entities(nb)
default_single_query = {
"bool": {
"must": [
{"vector": {field_name: {"topk": 10, "query": entity, "params": {"index_name": default_index_name, "nprobe": 10}}}}
]
# query = {
# "bool": {
# "must": [
# {"term": {"A": {"values": [1, 2, 5]}}},
# {"range": {"B": {"ranges": {"GT": 1, "LT": 100}}}},
# {"vector": {"Vec": {"topk": 10, "query": vec[: 1], "params": {"index_name": "IVFFLAT", "nprobe": 10}}}}
# ],
# },
# }
def get_query_inside(entities, top_k, nq, search_params={"nprobe": 10}):
query_vectors = entities[-1]["values"][:nq]
query = {
"bool": {
"must": [
{"vector": {field_name: {"topk": top_k, "query": query_vectors, "params": search_params}}}
]
}
}
}
query = {
"bool": {
"must": [
{"term": {"A": {"values": [1, 2, 5]}}},
{"range": {"B": {"ranges": {"GT": 1, "LT": 100}}}},
{"vector": {"Vec": {"topk": 10, "query": vec[: 1], "params": {"index_name": Indextype.IVF_FLAT, "nprobe": 10}}}}
],
},
}
return query
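get_query_inside builds only the vector clause of the search DSL from the last generated field; the commented block above shows how term and range filters would sit alongside it in the same bool.must list. A short usage sketch, assuming this module's connect/collection fixtures and helpers (gen_entities, nb, field_name):
# insert data, then run a 10-NN query over the first generated vector
entities = gen_entities(nb, is_normal=True)
connect.insert(collection, entities)
connect.flush([collection])
query = get_query_inside(entities, top_k=10, nq=1)
res = connect.search(collection, query)
assert len(res[0]) == 10   # one hit list per query vector, top_k hits each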
class TestSearchBase:
def init_data(self, connect, collection, nb=6000, partition_tags=None):
......@@ -88,11 +91,8 @@ class TestSearchBase:
)
def get_index(self, request, connect):
if str(connect._cmd("mode")[1]) == "CPU":
if request.param["index_type"] == IndexType.IVF_SQ8H:
if request.param["index_type"] == "IVFSQ8H":
pytest.skip("sq8h not support in CPU mode")
if str(connect._cmd("mode")[1]) == "GPU":
if request.param["index_type"] == IndexType.IVF_PQ:
pytest.skip("ivfpq not support in GPU mode")
return request.param
@pytest.fixture(
......@@ -101,11 +101,8 @@ class TestSearchBase:
)
def get_simple_index(self, request, connect):
if str(connect._cmd("mode")[1]) == "CPU":
if request.param["index_type"] == IndexType.IVF_SQ8H:
if request.param["index_type"] == "IVFSQ8H":
pytest.skip("sq8h not support in CPU mode")
if str(connect._cmd("mode")[1]) == "GPU":
if request.param["index_type"] == IndexType.IVF_PQ:
pytest.skip("ivfpq not support in GPU mode")
return request.param
@pytest.fixture(
......@@ -114,7 +111,7 @@ class TestSearchBase:
)
def get_jaccard_index(self, request, connect):
logging.getLogger().info(request.param)
if request.param["index_type"] == IndexType.IVFLAT or request.param["index_type"] == IndexType.FLAT:
if request.param["index_type"] in ["IVFFLAT", "FLAT"]:
return request.param
else:
pytest.skip("Skip index Temporary")
......@@ -125,7 +122,7 @@ class TestSearchBase:
)
def get_hamming_index(self, request, connect):
logging.getLogger().info(request.param)
if request.param["index_type"] == IndexType.IVFLAT or request.param["index_type"] == IndexType.FLAT:
if request.param["index_type"] in ["IVFFLAT", "FLAT"]:
return request.param
else:
pytest.skip("Skip index Temporary")
......@@ -136,7 +133,7 @@ class TestSearchBase:
)
def get_structure_index(self, request, connect):
logging.getLogger().info(request.param)
if request.param["index_type"] == IndexType.FLAT:
if request.param["index_type"] == "FLAT":
return request.param
else:
pytest.skip("Skip index Temporary")
......@@ -156,25 +153,28 @@ class TestSearchBase:
'''
target: test basic search fuction, all the search params is corrent, change top-k value
method: search with the given vectors, check the result
expected: search status ok, and the length of the result is top_k
expected: the length of the result is top_k
'''
vectors, ids = self.init_data(connect, collection)
query_vec = [vectors[0]]
top_k = get_top_k
status, result = connect.search(collection, top_k, query_vec)
entities, ids = self.init_data(connect, collection)
query = get_query_inside(entities, top_k, 1)
if top_k <= 2048:
assert status.OK()
assert len(result[0]) == min(len(vectors), top_k)
assert result[0][0].distance <= epsilon
assert check_result(result[0], ids[0])
logging.getLogger().info(query)
res = connect.search(collection, query)
logging.getLogger().info(res)
assert len(res[0]) == top_k
assert res[0][0].distance <= epsilon
assert check_result(res[0], ids[0])
else:
assert not status.OK()
with pytest.raises(Exception) as e:
res = connect.search(collection, query)
def test_search_l2_index_params(self, connect, collection, get_simple_index):
'''
target: test basic search fuction, all the search params is corrent, test all index params, and build
method: search with the given vectors, check the result
expected: search status ok, and the length of the result is top_k
expected: the length of the result is top_k
'''
top_k = 10
index_param = get_simple_index["index_param"]
......@@ -201,7 +201,7 @@ class TestSearchBase:
'''
target: test basic search fuction, all the search params is corrent, test all index params, and build
method: search with the given vectors, check the result
expected: search status ok, and the length of the result is top_k
expected: the length of the result is top_k
'''
top_k = 10
index_param = get_simple_index["index_param"]
......@@ -225,7 +225,7 @@ class TestSearchBase:
'''
target: test basic search fuction, all the search params is corrent, test all index params, and build
method: add vectors into collection, search with the given vectors, check the result
expected: search status ok, and the length of the result is top_k, search collection with partition tag return empty
expected: the length of the result is top_k, search collection with partition tag return empty
'''
top_k = 10
index_param = get_simple_index["index_param"]
......@@ -253,7 +253,7 @@ class TestSearchBase:
'''
target: test basic search fuction, all the search params is corrent, test all index params, and build
method: search partition with the given vectors, check the result
expected: search status ok, and the length of the result is 0
expected: the length of the result is 0
'''
top_k = 10
index_param = get_simple_index["index_param"]
......@@ -276,7 +276,7 @@ class TestSearchBase:
'''
target: test basic search fuction, all the search params is corrent, test all index params, and build
method: search with the given vectors, check the result
expected: search status ok, and the length of the result is top_k
expected: the length of the result is top_k
'''
top_k = 10
index_param = get_simple_index["index_param"]
......@@ -306,7 +306,7 @@ class TestSearchBase:
'''
target: test basic search fuction, all the search params is corrent, test all index params, and build
method: search with the given vectors and tags (one of the tags not existed in collection), check the result
expected: search status ok, and the length of the result is top_k
expected: the length of the result is top_k
'''
index_param = get_simple_index["index_param"]
index_type = get_simple_index["index_type"]
......@@ -331,7 +331,7 @@ class TestSearchBase:
'''
target: test basic search fuction, all the search params is corrent, test all index params, and build
method: search with the given vectors and tag (tag name not existed in collection), check the result
expected: search status ok, and the length of the result is top_k
expected: the length of the result is top_k
'''
index_param = get_simple_index["index_param"]
index_type = get_simple_index["index_type"]
......@@ -351,7 +351,7 @@ class TestSearchBase:
'''
target: test basic search fuction, all the search params is corrent, test all index params, and build
method: search collection with the given vectors and tags, check the result
expected: search status ok, and the length of the result is top_k
expected: the length of the result is top_k
'''
top_k = 10
new_tag = "new_tag"
......@@ -386,7 +386,7 @@ class TestSearchBase:
'''
target: test basic search fuction, all the search params is corrent, test all index params, and build
method: search collection with the given vectors and tags with "re" expr, check the result
expected: search status ok, and the length of the result is top_k
expected: the length of the result is top_k
'''
tag = "atag"
new_tag = "new_tag"
......@@ -419,7 +419,7 @@ class TestSearchBase:
'''
target: test basic search fuction, all the search params is corrent, test all index params, and build
method: search with the given vectors, check the result
expected: search status ok, and the length of the result is top_k
expected: the length of the result is top_k
'''
top_k = 10
index_param = get_simple_index["index_param"]
......@@ -443,7 +443,7 @@ class TestSearchBase:
'''
target: test basic search fuction, all the search params is corrent, test all index params, and build
method: search with the given vectors, check the result
expected: search status ok, and the length of the result is top_k
expected: the length of the result is top_k
'''
index_param = get_simple_index["index_param"]
index_type = get_simple_index["index_type"]
......@@ -469,7 +469,7 @@ class TestSearchBase:
'''
target: test basic search fuction, all the search params is corrent, test all index params, and build
method: search with the given vectors, check the result
expected: search status ok, and the length of the result is top_k
expected: the length of the result is top_k
'''
top_k = 10
index_param = get_simple_index["index_param"]
......@@ -499,7 +499,7 @@ class TestSearchBase:
'''
target: test basic search fuction, all the search params is corrent, test all index params, and build
method: search with the given vectors and tag, check the result
expected: search status ok, and the length of the result is top_k
expected: the length of the result is top_k
'''
top_k = 10
index_param = get_simple_index["index_param"]
......@@ -595,7 +595,7 @@ class TestSearchBase:
nb = 2
nprobe = 1
vectors, ids = self.init_data(connect, ip_collection, nb=nb)
index_type = IndexType.FLAT
index_type = "FLAT"
index_param = {
"nlist": 16384
}
......@@ -617,7 +617,7 @@ class TestSearchBase:
# from scipy.spatial import distance
nprobe = 512
int_vectors, vectors, ids = self.init_binary_data(connect, jac_collection, nb=2)
index_type = IndexType.FLAT
index_type = "FLAT"
index_param = {
"nlist": 16384
}
......@@ -642,7 +642,7 @@ class TestSearchBase:
# from scipy.spatial import distance
nprobe = 512
int_vectors, vectors, ids = self.init_binary_data(connect, ham_collection, nb=2)
index_type = IndexType.FLAT
index_type = "FLAT"
index_param = {
"nlist": 16384
}
......@@ -667,7 +667,7 @@ class TestSearchBase:
# from scipy.spatial import distance
nprobe = 512
int_vectors, vectors, ids = self.init_binary_data(connect, substructure_collection, nb=2)
index_type = IndexType.FLAT
index_type = "FLAT"
index_param = {
"nlist": 16384
}
......@@ -693,7 +693,7 @@ class TestSearchBase:
top_k = 3
nprobe = 512
int_vectors, vectors, ids = self.init_binary_data(connect, substructure_collection, nb=2)
index_type = IndexType.FLAT
index_type = "FLAT"
index_param = {
"nlist": 16384
}
......@@ -721,7 +721,7 @@ class TestSearchBase:
# from scipy.spatial import distance
nprobe = 512
int_vectors, vectors, ids = self.init_binary_data(connect, superstructure_collection, nb=2)
index_type = IndexType.FLAT
index_type = "FLAT"
index_param = {
"nlist": 16384
}
......@@ -747,7 +747,7 @@ class TestSearchBase:
top_k = 3
nprobe = 512
int_vectors, vectors, ids = self.init_binary_data(connect, superstructure_collection, nb=2)
index_type = IndexType.FLAT
index_type = "FLAT"
index_param = {
"nlist": 16384
}
......@@ -775,7 +775,7 @@ class TestSearchBase:
# from scipy.spatial import distance
nprobe = 512
int_vectors, vectors, ids = self.init_binary_data(connect, tanimoto_collection, nb=2)
index_type = IndexType.FLAT
index_type = "FLAT"
index_param = {
"nlist": 16384
}
......@@ -854,7 +854,7 @@ class TestSearchBase:
uri = "tcp://%s:%s" % (args["ip"], args["port"])
param = {'collection_name': collection,
'dimension': dim,
'index_type': IndexType.FLAT,
'index_type': "FLAT",
'store_raw_vector': False}
# create collection
milvus = get_milvus(args["ip"], args["port"], handler=args["handler"])
......@@ -893,7 +893,7 @@ class TestSearchBase:
uri = "tcp://%s:%s" % (args["ip"], args["port"])
param = {'collection_name': collection,
'dimension': dim,
'index_type': IndexType.FLAT,
'index_type': "FLAT",
'store_raw_vector': False}
# create collection
milvus = get_milvus(args["ip"], args["port"], handler=args["handler"])
......@@ -1026,7 +1026,7 @@ class TestSearchParamsInvalid(object):
"""
@pytest.fixture(
scope="function",
params=gen_invalid_collection_names()
params=gen_invalid_strs()
)
def get_collection_name(self, request):
yield request.param
......@@ -1061,7 +1061,7 @@ class TestSearchParamsInvalid(object):
"""
@pytest.fixture(
scope="function",
params=gen_invalid_top_ks()
params=gen_invalid_ints()
)
def get_top_k(self, request):
yield request.param
......@@ -1106,7 +1106,7 @@ class TestSearchParamsInvalid(object):
"""
@pytest.fixture(
scope="function",
params=gen_invalid_nprobes()
params=gen_invalid_ints()
)
def get_nprobes(self, request):
yield request.param
......@@ -1160,7 +1160,7 @@ class TestSearchParamsInvalid(object):
)
def get_simple_index(self, request, connect):
if str(connect._cmd("mode")[1]) == "CPU":
if request.param["index_type"] == IndexType.IVF_SQ8H:
if request.param["index_type"] == "IVFSQ8H":
pytest.skip("sq8h not support in CPU mode")
if str(connect._cmd("mode")[1]) == "GPU":
if request.param["index_type"] == IndexType.IVF_PQ:
......@@ -1181,7 +1181,7 @@ class TestSearchParamsInvalid(object):
query_vecs = gen_vectors(1, dim)
status, result = connect.search(collection, top_k, query_vecs, params={})
if index_type == IndexType.FLAT:
if index_type == "FLAT":
assert status.OK()
else:
assert not status.OK()
......@@ -1192,7 +1192,7 @@ class TestSearchParamsInvalid(object):
)
def get_invalid_search_param(self, request, connect):
if str(connect._cmd("mode")[1]) == "CPU":
if request.param["index_type"] == IndexType.IVF_SQ8H:
if request.param["index_type"] == "IVFSQ8H":
pytest.skip("sq8h not support in CPU mode")
if str(connect._cmd("mode")[1]) == "GPU":
if request.param["index_type"] == IndexType.IVF_PQ:
......
......@@ -32,7 +32,7 @@
# # @pytest.fixture(scope="function", autouse=True)
# # def skip_check(self, connect):
# # if str(connect._cmd("mode")[1]) == "CPU":
# # if request.param["index_type"] == IndexType.IVF_SQ8H:
# # if request.param["index_type"] == IndexType.IVFSQ8H:
# # pytest.skip("sq8h not support in CPU mode")
# # if str(connect._cmd("mode")[1]) == "GPU":
# # if request.param["index_type"] == IndexType.IVF_PQ:
......@@ -106,7 +106,7 @@
# )
# def get_simple_index(self, request, connect):
# if str(connect._cmd("mode")[1]) == "CPU":
# if request.param["index_type"] == IndexType.IVF_SQ8H:
# if request.param["index_type"] == IndexType.IVFSQ8H:
# pytest.skip("sq8h not support in CPU mode")
# if str(connect._cmd("mode")[1]) == "GPU":
# if request.param["index_type"] == IndexType.IVF_PQ:
......
......@@ -8,5 +8,4 @@ pytest-print==0.1.2
pytest-level==0.1.1
pytest-xdist==1.23.2
scikit-learn>=0.19.1
pymilvus-test>=0.2.0
kubernetes==10.0.1
......@@ -32,7 +32,7 @@ class TestFlushBase:
)
def get_simple_index(self, request, connect):
if str(connect._cmd("mode")[1]) == "GPU":
if request.param["index_type"] not in [IndexType.IVF_SQ8, IndexType.IVFLAT, IndexType.FLAT, IndexType.IVF_PQ, IndexType.IVF_SQ8H]:
if request.param["index_type"] not in [IndexType.IVF_SQ8, IndexType.IVFLAT, IndexType.FLAT, IndexType.IVF_PQ, IndexType.IVFSQ8H]:
pytest.skip("Only support index_type: idmap/flat")
return request.param
......
"""
For testing index operations, including `create_index`, `get_index_info` and `drop_index` interfaces
"""
import logging
import pytest
import time
import pdb
import threading
from multiprocessing import Pool, Process
import numpy
import pytest
import sklearn.preprocessing
from milvus import IndexType, MetricType
from utils import *
......@@ -32,7 +29,7 @@ class TestIndexBase:
)
def get_index(self, request, connect):
if str(connect._cmd("mode")[1]) == "CPU":
if request.param["index_type"] == IndexType.IVF_SQ8H:
if request.param["index_type"] == IndexType.IVFSQ8H:
pytest.skip("sq8h not support in CPU mode")
if str(connect._cmd("mode")[1]) == "GPU":
if request.param["index_type"] == IndexType.IVF_PQ:
......@@ -45,7 +42,7 @@ class TestIndexBase:
)
def get_simple_index(self, request, connect):
if str(connect._cmd("mode")[1]) == "CPU":
if request.param["index_type"] == IndexType.IVF_SQ8H:
if request.param["index_type"] == IndexType.IVFSQ8H:
pytest.skip("sq8h not support in CPU mode")
if str(connect._cmd("mode")[1]) == "GPU":
if request.param["index_type"] == IndexType.IVF_PQ:
......@@ -674,7 +671,7 @@ class TestIndexIP:
)
def get_index(self, request, connect):
if str(connect._cmd("mode")[1]) == "CPU":
if request.param["index_type"] == IndexType.IVF_SQ8H:
if request.param["index_type"] == IndexType.IVFSQ8H:
pytest.skip("sq8h not support in CPU mode")
if str(connect._cmd("mode")[1]) == "GPU":
if request.param["index_type"] == IndexType.IVF_PQ:
......@@ -689,7 +686,7 @@ class TestIndexIP:
)
def get_simple_index(self, request, connect):
if str(connect._cmd("mode")[1]) == "CPU":
if request.param["index_type"] == IndexType.IVF_SQ8H:
if request.param["index_type"] == IndexType.IVFSQ8H:
pytest.skip("sq8h not support in CPU mode")
if str(connect._cmd("mode")[1]) == "GPU":
if request.param["index_type"] == IndexType.IVF_PQ:
......@@ -1233,7 +1230,7 @@ class TestIndexJAC:
)
def get_index(self, request, connect):
if str(connect._cmd("mode")[1]) == "CPU":
if request.param["index_type"] == IndexType.IVF_SQ8H:
if request.param["index_type"] == IndexType.IVFSQ8H:
pytest.skip("sq8h not support in CPU mode")
if str(connect._cmd("mode")[1]) == "GPU":
if request.param["index_type"] == IndexType.IVF_PQ:
......@@ -1246,7 +1243,7 @@ class TestIndexJAC:
)
def get_simple_index(self, request, connect):
if str(connect._cmd("mode")[1]) == "CPU":
if request.param["index_type"] == IndexType.IVF_SQ8H:
if request.param["index_type"] == IndexType.IVFSQ8H:
pytest.skip("sq8h not support in CPU mode")
if str(connect._cmd("mode")[1]) == "GPU":
if request.param["index_type"] == IndexType.IVF_PQ:
......@@ -1434,7 +1431,7 @@ class TestIndexBinary:
)
def get_index(self, request, connect):
if str(connect._cmd("mode")[1]) == "CPU":
if request.param["index_type"] == IndexType.IVF_SQ8H:
if request.param["index_type"] == IndexType.IVFSQ8H:
pytest.skip("sq8h not support in CPU mode")
if request.param["index_type"] == IndexType.IVF_PQ or request.param["index_type"] == IndexType.HNSW:
pytest.skip("Skip PQ Temporary")
......@@ -1446,7 +1443,7 @@ class TestIndexBinary:
)
def get_simple_index(self, request, connect):
if str(connect._cmd("mode")[1]) == "CPU":
if request.param["index_type"] == IndexType.IVF_SQ8H:
if request.param["index_type"] == IndexType.IVFSQ8H:
pytest.skip("sq8h not support in CPU mode")
if request.param["index_type"] == IndexType.IVF_PQ or request.param["index_type"] == IndexType.HNSW:
pytest.skip("Skip PQ Temporary")
......@@ -1785,7 +1782,7 @@ class TestCreateIndexParamsInvalid(object):
"""
@pytest.fixture(
scope="function",
params=[IndexType.FLAT,IndexType.IVFLAT,IndexType.IVF_SQ8,IndexType.IVF_SQ8H]
params=[IndexType.FLAT,IndexType.IVFLAT,IndexType.IVF_SQ8,IndexType.IVFSQ8H]
)
def get_index_type(self, request):
yield request.param
......@@ -1826,7 +1823,7 @@ class TestIndexAsync:
)
def get_index(self, request, connect):
if str(connect._cmd("mode")[1]) == "CPU":
if request.param["index_type"] == IndexType.IVF_SQ8H:
if request.param["index_type"] == IndexType.IVFSQ8H:
pytest.skip("sq8h not support in CPU mode")
if str(connect._cmd("mode")[1]) == "GPU":
if request.param["index_type"] == IndexType.IVF_PQ:
......@@ -1839,7 +1836,7 @@ class TestIndexAsync:
)
def get_simple_index(self, request, connect):
if str(connect._cmd("mode")[1]) == "CPU":
if request.param["index_type"] == IndexType.IVF_SQ8H:
if request.param["index_type"] == IndexType.IVFSQ8H:
pytest.skip("sq8h not support in CPU mode")
if str(connect._cmd("mode")[1]) == "GPU":
# if request.param["index_type"] == IndexType.IVF_PQ:
......
import os
import sys
import random
import pdb
import string
import struct
import logging
import time, datetime
import copy
import numpy as np
from sklearn import preprocessing
from milvus import Milvus, IndexType, MetricType, DataType
port = 19530
......@@ -18,14 +20,14 @@ segment_size = 10
all_index_types = [
IndexType.FLAT,
IndexType.IVFLAT,
IndexType.IVF_SQ8,
IndexType.IVF_SQ8H,
IndexType.IVF_PQ,
IndexType.HNSW,
IndexType.RNSG,
IndexType.ANNOY
"FLAT",
"IVFFLAT",
"IVFSQ8",
"IVFSQ8H",
"IVFPQ",
"HNSW",
"RNSG",
"ANNOY"
]
......@@ -71,16 +73,13 @@ def get_milvus(host, port, uri=None, handler=None, **kwargs):
def disable_flush(connect):
status, reply = connect.set_config("storage", "auto_flush_interval", big_flush_interval)
assert status.OK()
connect.set_config("storage", "auto_flush_interval", big_flush_interval)
def enable_flush(connect):
# reset auto_flush_interval=1
status, reply = connect.set_config("storage", "auto_flush_interval", default_flush_interval)
assert status.OK()
status, config_value = connect.get_config("storage", "auto_flush_interval")
assert status.OK()
connect.set_config("storage", "auto_flush_interval", default_flush_interval)
config_value = connect.get_config("storage", "auto_flush_interval")
assert config_value == str(default_flush_interval)
......@@ -90,14 +89,14 @@ def gen_inaccuracy(num):
def gen_vectors(num, dim, is_normal=False):
vectors = [[random.random() for _ in range(dim)] for _ in range(num)]
vectors = sklearn.preprocessing.normalize(vectors, axis=1, norm='l2')
vectors = preprocessing.normalize(vectors, axis=1, norm='l2')
return vectors.tolist()
def gen_vectors(nb, d, seed=np.random.RandomState(1234), is_normal=False):
xb = seed.rand(nb, d).astype("float32")
xb = klearn.preprocessing.normalize(xb, axis=1, norm='l2')
return xb.tolist()
# def gen_vectors(num, dim, seed=np.random.RandomState(1234), is_normal=False):
# xb = seed.rand(num, dim).astype("float32")
# xb = preprocessing.normalize(xb, axis=1, norm='l2')
# return xb.tolist()
def gen_binary_vectors(num, dim):
......@@ -152,18 +151,19 @@ def gen_unique_str(str_value=None):
def gen_single_filter_fields():
fields = []
for data_type in [i.value for i in DataType]:
fields.append({"field": data_type.name, "type": data_type})
for data_type in DataType:
if data_type in [DataType.INT8, DataType.INT16, DataType.INT32, DataType.INT64, DataType.FLOAT, DataType.DOUBLE]:
fields.append({"field": data_type.name, "type": data_type})
return fields
def gen_single_vector_fields():
fields = []
for metric_type in MetricType:
for data_type in [DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR]:
if metric_type in [MetricType.L2, MetricType.IP] and data_type == DataType.BINARY_VECTOR:
for metric_type in ['HAMMING', 'IP', 'JACCARD', 'L2', 'SUBSTRUCTURE', 'SUPERSTRUCTURE', 'TANIMOTO']:
for data_type in [DataType.VECTOR, DataType.BINARY_VECTOR]:
if metric_type in ["L2", "IP"] and data_type == DataType.BINARY_VECTOR:
continue
field = {"field": data_type.name, "type": data_type, "extra_params": {"metric_type": metric_type, "dimension": dimension}}
field = {"field": data_type.name, "type": data_type, "params": {"metric_type": metric_type, "dimension": dimension}}
fields.append(field)
return fields
......@@ -174,7 +174,7 @@ def gen_default_fields():
{"field": "int8", "type": DataType.INT8},
{"field": "int64", "type": DataType.INT64},
{"field": "float", "type": DataType.FLOAT},
{"field": "float_vector", "type": DataType.FLOAT_VECTOR, "extra_params": {"metric_type": MetricType.L2, "dimension": dimension}}
{"field": "vector", "type": DataType.VECTOR, "params": {"metric_type": "L2", "dimension": dimension}}
],
"segment_size": segment_size
}
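With the rename from extra_params to params and the string metric type, the schema returned by gen_default_fields() can be passed straight to create_collection. A minimal sketch, assuming the suite's connect fixture and gen_unique_str helper:
fields = gen_default_fields()
# adjust the vector field in the new "params" format
fields["fields"][-1]["params"]["dimension"] = 128
collection_name = gen_unique_str("schema_demo")
connect.create_collection(collection_name, fields)
assert connect.has_collection(collection_name)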
......@@ -187,7 +187,7 @@ def gen_entities(nb, is_normal=False):
{"field": "int8", "type": DataType.INT8, "values": [1 for i in range(nb)]},
{"field": "int64", "type": DataType.INT64, "values": [2 for i in range(nb)]},
{"field": "float", "type": DataType.FLOAT, "values": [3.0 for i in range(nb)]},
{"field": "float_vector", "type": DataType.FLOAT_VECTOR, "values": vectors}
{"field": "vector", "type": DataType.VECTOR, "values": vectors}
]
return entities
......@@ -217,13 +217,22 @@ def add_vector_field(entities, is_normal=False):
vectors = gen_vectors(nb, dimension, is_normal)
field = {
"field": gen_unique_str(),
"type": DataType.FLOAT_VECTOR,
"type": DataType.VECTOR,
"values": vectors
}
entities.append(field)
return entities
def update_fields_metric_type(fields, metric_type):
if metric_type in ["L2", "IP"]:
fields["fields"][-1]["type"] = DataType.VECTOR
else:
fields["fields"][-1]["type"] = DataType.BINARY_VECTOR
fields["fields"][-1]["params"]["metric_type"] = metric_type
return fields
def remove_field(entities):
del entities[0]
return entities
......@@ -256,11 +265,11 @@ def update_field_value(entities, old_type, new_value):
return entities
def add_float_vector_field(nb, dimension):
def add_vector_field(nb, dimension):
field_name = gen_unique_str()
field = {
"field": field_name,
"type": DataType.FLOAT_VECTOR,
"type": DataType.VECTOR,
"values": gen_vectors(nb, dimension)
}
return field_name
......@@ -359,6 +368,18 @@ def gen_invalid_field_types():
return field_types
def gen_invalid_metric_types():
metric_types = [
1,
"=c",
0,
None,
"",
"a".join("a" for i in range(256))
]
return metric_types
def gen_invalid_ints():
top_ks = [
1.0,
......@@ -435,23 +456,23 @@ def gen_invaild_search_params():
invalid_search_key = 100
search_params = []
for index_type in all_index_types:
if index_type == IndexType.FLAT:
if index_type == "FLAT":
continue
search_params.append({"index_type": index_type, "search_param": {"invalid_key": invalid_search_key}})
if index_type in [IndexType.IVFLAT, IndexType.IVF_SQ8, IndexType.IVF_SQ8H, IndexType.IVF_PQ]:
if index_type in ["IVFFLAT", "IVFSQ8", "IVFSQ8H", "IVFPQ"]:
for nprobe in gen_invalid_params():
ivf_search_params = {"index_type": index_type, "search_param": {"nprobe": nprobe}}
search_params.append(ivf_search_params)
elif index_type == IndexType.HNSW:
elif index_type == "HNSW":
for ef in gen_invalid_params():
hnsw_search_param = {"index_type": index_type, "search_param": {"ef": ef}}
search_params.append(hnsw_search_param)
elif index_type == IndexType.RNSG:
elif index_type == "RNSG":
for search_length in gen_invalid_params():
nsg_search_param = {"index_type": index_type, "search_param": {"search_length": search_length}}
search_params.append(nsg_search_param)
search_params.append({"index_type": index_type, "search_param": {"invalid_key": 100}})
elif index_type == IndexType.ANNOY:
elif index_type == "ANNOY":
for search_k in gen_invalid_params():
if isinstance(search_k, int):
continue
......@@ -466,36 +487,36 @@ def gen_invalid_index():
index_param = {"index_type": index_type, "params": {"nlist": 1024}}
index_params.append(index_param)
for nlist in gen_invalid_params():
index_param = {"index_type": IndexType.IVFLAT, "params": {"nlist": nlist}}
index_param = {"index_type": "IVFFLAT", "params": {"nlist": nlist}}
index_params.append(index_param)
for M in gen_invalid_params():
index_param = {"index_type": IndexType.HNSW, "params": {"M": M, "efConstruction": 100}}
index_param = {"index_type": "HNSW", "params": {"M": M, "efConstruction": 100}}
index_params.append(index_param)
for efConstruction in gen_invalid_params():
index_param = {"index_type": IndexType.HNSW, "params": {"M": 16, "efConstruction": efConstruction}}
index_param = {"index_type": "HNSW", "params": {"M": 16, "efConstruction": efConstruction}}
index_params.append(index_param)
for search_length in gen_invalid_params():
index_param = {"index_type": IndexType.RNSG,
index_param = {"index_type": "RNSG",
"params": {"search_length": search_length, "out_degree": 40, "candidate_pool_size": 50,
"knng": 100}}
index_params.append(index_param)
for out_degree in gen_invalid_params():
index_param = {"index_type": IndexType.RNSG,
index_param = {"index_type": "RNSG",
"params": {"search_length": 100, "out_degree": out_degree, "candidate_pool_size": 50,
"knng": 100}}
index_params.append(index_param)
for candidate_pool_size in gen_invalid_params():
index_param = {"index_type": IndexType.RNSG, "params": {"search_length": 100, "out_degree": 40,
index_param = {"index_type": "RNSG", "params": {"search_length": 100, "out_degree": 40,
"candidate_pool_size": candidate_pool_size,
"knng": 100}}
index_params.append(index_param)
index_params.append({"index_type": IndexType.IVF_FLAT, "params": {"invalid_key": 1024}})
index_params.append({"index_type": IndexType.HNSW, "params": {"invalid_key": 16, "efConstruction": 100}})
index_params.append({"index_type": IndexType.RNSG,
index_params.append({"index_type": "IVFFLAT", "params": {"invalid_key": 1024}})
index_params.append({"index_type": "HNSW", "params": {"invalid_key": 16, "efConstruction": 100}})
index_params.append({"index_type": "RNSG",
"params": {"invalid_key": 100, "out_degree": 40, "candidate_pool_size": 300,
"knng": 100}})
for invalid_n_trees in gen_invalid_params():
index_params.append({"index_type": IndexType.ANNOY, "params": {"n_trees": invalid_n_trees}})
index_params.append({"index_type": "ANNOY", "params": {"n_trees": invalid_n_trees}})
return index_params
......@@ -512,23 +533,23 @@ def gen_index():
index_params = []
for index_type in all_index_types:
if index_type == IndexType.FLAT:
if index_type == "FLAT":
index_params.append({"index_type": index_type, "index_param": {"nlist": 1024}})
elif index_type in [IndexType.IVFLAT, IndexType.IVF_SQ8, IndexType.IVF_SQ8H]:
elif index_type in ["IVFFLAT", "IVFSQ8", "IVFSQ8H"]:
ivf_params = [{"index_type": index_type, "index_param": {"nlist": nlist}} \
for nlist in nlists]
index_params.extend(ivf_params)
elif index_type == IndexType.IVF_PQ:
ivf_pq_params = [{"index_type": index_type, "index_param": {"nlist": nlist, "m": m}} \
elif index_type == "IVFPQ":
IVFPQ_params = [{"index_type": index_type, "index_param": {"nlist": nlist, "m": m}} \
for nlist in nlists \
for m in pq_ms]
index_params.extend(ivf_pq_params)
elif index_type == IndexType.HNSW:
index_params.extend(IVFPQ_params)
elif index_type == "HNSW":
hnsw_params = [{"index_type": index_type, "index_param": {"M": M, "efConstruction": efConstruction}} \
for M in Ms \
for efConstruction in efConstructions]
index_params.extend(hnsw_params)
elif index_type == IndexType.RNSG:
elif index_type == "RNSG":
nsg_params = [{"index_type": index_type,
"index_param": {"search_length": search_length, "out_degree": out_degree,
"candidate_pool_size": candidate_pool_size, "knng": knng}} \
......@@ -559,24 +580,19 @@ def gen_simple_index():
def get_search_param(index_type):
if index_type in [IndexType.FLAT, IndexType.IVFLAT, IndexType.IVF_SQ8, IndexType.IVF_SQ8H, IndexType.IVF_PQ]:
if index_type in ["FLAT", "IVFFLAT", "IVFSQ8", "IVFSQ8H", "IVFPQ"]:
return {"nprobe": 32}
elif index_type == IndexType.HNSW:
elif index_type == "HNSW":
return {"ef": 64}
elif index_type == IndexType.RNSG:
elif index_type == "RNSG":
return {"search_length": 100}
elif index_type == IndexType.ANNOY:
elif index_type == "ANNOY":
return {"search_k": 100}
else:
logging.getLogger().info("Invalid index_type.")
def assert_has_collection(conn, collection_name):
res = conn.has_collection(collection_name)
return res
def assert_equal_vector(v1, v2):
if len(v1) != len(v2):
assert False
......
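The index helpers now key off plain string index types instead of IndexType enum members, so search parameters can be selected by name. A short sketch under that assumption, using only names that appear in all_index_types above:
# pick index-specific search params by the new string names
for index_type in ["FLAT", "IVFFLAT", "HNSW", "RNSG", "ANNOY"]:
    search_param = get_search_param(index_type)
    print(index_type, search_param)   # e.g. HNSW -> {"ef": 64}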