test_search.py 53.6 KB
Newer Older
J
JinHai-CN 已提交
1
import pdb
G
groot 已提交
2
import struct
3
from random import sample
J
JinHai-CN 已提交
4 5 6 7
import threading
import datetime
import logging
from time import sleep
D
del-zhenwu 已提交
8
import concurrent.futures
J
JinHai-CN 已提交
9
from multiprocessing import Process
D
del-zhenwu 已提交
10
import pytest
J
JinHai-CN 已提交
11
import numpy
12
import sklearn.preprocessing
13
from milvus import IndexType, MetricType
J
JinHai-CN 已提交
14 15 16
from utils import *

# Shared parameters and data for the search tests in this module.
dim = 128
collection_id = "test_search"
add_interval_time = 2

# 6000 float vectors, L2-normalised row by row and kept as plain Python lists.
vectors = sklearn.preprocessing.normalize(gen_vectors(6000, dim), axis=1, norm='l2').tolist()

top_k = 1
nprobe = 1
epsilon = 0.001
tag = "1970-01-01"

# Binary test data; presumably the raw (unpacked) values and their packed form — confirm in utils.
raw_vectors, binary_vectors = gen_binary_vectors(6000, dim)
J
JinHai-CN 已提交
27 28 29


class TestSearchBase:
X
Xiaohai Xu 已提交
30
    def init_data(self, connect, collection, nb=6000, partition_tags=None):
        '''
        Insert float vectors into `collection` and flush so they are searchable.

        With the default nb=6000 the module-level `vectors` are reused; any other nb
        generates nb fresh vectors and L2-normalises them. When partition_tags is given
        it is forwarded to insert() as its `partition_tag` argument.
        Returns (inserted_vectors, ids_returned_by_insert).
        '''
        if nb != 6000:
            add_vectors = sklearn.preprocessing.normalize(
                gen_vectors(nb, dim), axis=1, norm='l2').tolist()
        else:
            add_vectors = vectors
        insert_kwargs = {} if partition_tags is None else {"partition_tag": partition_tags}
        status, ids = connect.insert(collection, add_vectors, **insert_kwargs)
        assert status.OK()
        connect.flush([collection])
        return add_vectors, ids

X
Xiaohai Xu 已提交
50
    def init_binary_data(self, connect, collection, nb=6000, insert=True, partition_tags=None):
        '''
        Prepare binary vectors and, when `insert is True`, insert them into `collection` and flush.

        With the default nb=6000 the module-level `raw_vectors` / `binary_vectors` are reused;
        otherwise gen_binary_vectors(nb, dim) produces a fresh pair. partition_tags, when given,
        is forwarded to insert() as `partition_tag`.
        Returns (raw_vectors, binary_vectors, ids); ids is [] when nothing was inserted.
        '''
        if nb == 6000:
            add_raw_vectors, add_vectors = raw_vectors, binary_vectors
        else:
            add_raw_vectors, add_vectors = gen_binary_vectors(nb, dim)
        # Identity check on purpose: only the literal True triggers an insert.
        if insert is not True:
            return add_raw_vectors, add_vectors, []
        insert_kwargs = {} if partition_tags is None else {"partition_tag": partition_tags}
        status, ids = connect.insert(collection, add_vectors, **insert_kwargs)
        assert status.OK()
        connect.flush([collection])
        return add_raw_vectors, add_vectors, ids

J
JinHai-CN 已提交
72 73 74 75 76
    """
    generate valid create_index params
    """
    @pytest.fixture(
        scope="function",
77
        params=gen_index()
J
JinHai-CN 已提交
78
    )
79
    def get_index(self, request, connect):
G
groot 已提交
80
        if str(connect._cmd("mode")[1]) == "CPU":
81
            if request.param["index_type"] == IndexType.IVF_SQ8H:
82
                pytest.skip("sq8h not support in CPU mode")
83
        return request.param
J
JinHai-CN 已提交
84

Z
zhenwu 已提交
85 86
    # Fixture: every create_index parameter set produced by gen_simple_index().
    @pytest.fixture(scope="function", params=gen_simple_index())
    def get_simple_index(self, request, connect):
        '''
        Provide one simple index configuration per test run; IVF_SQ8H is
        skipped when the server reports it is running in CPU mode.
        '''
        cpu_mode = str(connect._cmd("mode")[1]) == "CPU"
        if cpu_mode and request.param["index_type"] == IndexType.IVF_SQ8H:
            pytest.skip("sq8h not support in CPU mode")
        return request.param
G
groot 已提交
94 95 96

    # Fixture: simple index configurations usable for Jaccard-metric tests.
    @pytest.fixture(scope="function", params=gen_simple_index())
    def get_jaccard_index(self, request, connect):
        '''Provide the IVFLAT and FLAT configurations only; every other index type is skipped.'''
        logging.getLogger().info(request.param)
        if request.param["index_type"] not in (IndexType.IVFLAT, IndexType.FLAT):
            pytest.skip("Skip index Temporary")
        return request.param

    # Fixture: simple index configurations usable for Hamming-metric tests.
    @pytest.fixture(scope="function", params=gen_simple_index())
    def get_hamming_index(self, request, connect):
        '''Provide the IVFLAT and FLAT configurations only; every other index type is skipped.'''
        logging.getLogger().info(request.param)
        if request.param["index_type"] not in (IndexType.IVFLAT, IndexType.FLAT):
            pytest.skip("Skip index Temporary")
        return request.param

D
del-zhenwu 已提交
117 118 119 120 121 122 123 124 125 126 127
    # Fixture: simple index configurations usable for structure-metric tests.
    @pytest.fixture(scope="function", params=gen_simple_index())
    def get_structure_index(self, request, connect):
        '''Provide the FLAT configuration only; every other index type is skipped.'''
        logging.getLogger().info(request.param)
        if request.param["index_type"] != IndexType.FLAT:
            pytest.skip("Skip index Temporary")
        return request.param

J
JinHai-CN 已提交
128 129 130 131 132
    """
    generate top-k params
    """
    @pytest.fixture(
        scope="function",
D
del-zhenwu 已提交
133
        params=[1, 99, 1024, 2049]
J
JinHai-CN 已提交
134 135 136 137 138
    )
    def get_top_k(self, request):
        yield request.param


X
Xiaohai Xu 已提交
139
    def test_search_top_k_flat_index(self, connect, collection, get_top_k):
        '''
        target: test basic search function with all search params correct, varying top_k
        method: insert vectors, search with the first inserted vector, check the result
        expected: for top_k <= 2048 status ok and len(result) is top_k; larger top_k gives status not ok
        '''
        inserted, ids = self.init_data(connect, collection)
        k = get_top_k
        status, result = connect.search(collection, k, [inserted[0]])
        if k > 2048:
            assert not status.OK()
            return
        assert status.OK()
        hits = result[0]
        assert len(hits) == min(len(inserted), k)
        assert hits[0].distance <= epsilon
        assert check_result(hits, ids[0])

X
Xiaohai Xu 已提交
157
    def test_search_l2_index_params(self, connect, collection, get_simple_index):
        '''
        target: test basic search function with all search params correct, for every simple index type
        method: insert vectors, build the index, search with two of the inserted vectors, check the result
        expected: search status ok; each query returns top_k hits ordered by ascending L2 distance
        '''
        top_k = 10
        index_param = get_simple_index["index_param"]
        index_type = get_simple_index["index_type"]
        logging.getLogger().info(get_simple_index)
        vectors, ids = self.init_data(connect, collection)
        connect.create_index(collection, index_type, index_param)
        query_vec = [vectors[0], vectors[1]]
        search_param = get_search_param(index_type)
        status, result = connect.search(collection, top_k, query_vec, params=search_param)
        logging.getLogger().info(result)
        # top_k is the fixed value 10, so only the success path applies.
        assert status.OK()
        assert len(result[0]) == min(len(vectors), top_k)
        assert check_result(result[0], ids[0])
        # L2: smaller distance is closer, so the first hit must beat the second for both queries.
        assert result[0][0].distance < result[0][1].distance
        assert result[1][0].distance < result[1][1].distance

X
Xiaohai Xu 已提交
182
    def test_search_l2_large_nq_index_params(self, connect, collection, get_simple_index):
        '''
        target: test basic search function with all search params correct, for every simple index type, with a large nq
        method: insert vectors, build the index, search with the first 1000 inserted vectors
        expected: search status ok; the first query returns top_k hits, matches ids[0] and has distance within epsilon
        '''
        index_param, index_type = get_simple_index["index_param"], get_simple_index["index_type"]
        logging.getLogger().info(get_simple_index)
        if index_type == IndexType.IVF_PQ:
            pytest.skip("Skip PQ")
        top_k = 10
        inserted, ids = self.init_data(connect, collection)
        connect.create_index(collection, index_type, index_param)
        search_param = get_search_param(index_type)
        status, result = connect.search(collection, top_k, inserted[:1000], params=search_param)
        logging.getLogger().info(result)
        assert status.OK()
        first = result[0]
        assert len(first) == min(len(inserted), top_k)
        assert check_result(first, ids[0])
        assert first[0].distance <= epsilon

X
Xiaohai Xu 已提交
206
    def test_search_l2_index_params_partition(self, connect, collection, get_simple_index):
        '''
        target: test basic search function with all search params correct, for every simple index type
        method: create partition `tag`, insert vectors without a tag, build the index,
                then search the whole collection and the partition `tag`
        expected: the collection search returns top_k hits; the search scoped to `tag` returns an empty result
        '''
        index_type = get_simple_index["index_type"]
        index_param = get_simple_index["index_param"]
        logging.getLogger().info(get_simple_index)
        if index_type == IndexType.IVF_PQ:
            pytest.skip("Skip PQ")
        top_k = 10
        connect.create_partition(collection, tag)
        inserted, ids = self.init_data(connect, collection)
        connect.create_index(collection, index_type, index_param)
        search_param = get_search_param(index_type)
        query = [inserted[0]]

        status, result = connect.search(collection, top_k, query, params=search_param)
        logging.getLogger().info(result)
        assert status.OK()
        assert len(result[0]) == min(len(inserted), top_k)
        assert check_result(result[0], ids[0])
        assert result[0][0].distance <= epsilon

        # Nothing was inserted under `tag`, so searching only that partition finds nothing.
        status, result = connect.search(collection, top_k, query, partition_tags=[tag], params=search_param)
        logging.getLogger().info(result)
        assert status.OK()
        assert len(result) == 0

X
Xiaohai Xu 已提交
234
    def test_search_l2_index_params_partition_A(self, connect, collection, get_simple_index):
        '''
        target: test basic search function with all search params correct, for every simple index type
        method: create partition `tag`, insert vectors without a tag, build the index, search only partition `tag`
        expected: search status ok, and the result is empty
        '''
        index_type = get_simple_index["index_type"]
        index_param = get_simple_index["index_param"]
        logging.getLogger().info(get_simple_index)
        if index_type == IndexType.IVF_PQ:
            pytest.skip("Skip PQ")
        top_k = 10
        connect.create_partition(collection, tag)
        inserted, ids = self.init_data(connect, collection)
        connect.create_index(collection, index_type, index_param)
        status, result = connect.search(collection, top_k, [inserted[0]],
                                        partition_tags=[tag],
                                        params=get_search_param(index_type))
        logging.getLogger().info(result)
        assert status.OK()
        assert len(result) == 0

X
Xiaohai Xu 已提交
257
    def test_search_l2_index_params_partition_B(self, connect, collection, get_simple_index):
        '''
        target: test basic search function with all search params correct, for every simple index type
        method: insert vectors into partition `tag`, build the index, search both the whole collection
                and the partition `tag` with the first inserted vector
        expected: both searches succeed and return top_k hits matching ids[0]
        '''
        index_type = get_simple_index["index_type"]
        index_param = get_simple_index["index_param"]
        logging.getLogger().info(get_simple_index)
        if index_type == IndexType.IVF_PQ:
            pytest.skip("Skip PQ")
        top_k = 10
        connect.create_partition(collection, tag)
        inserted, ids = self.init_data(connect, collection, partition_tags=tag)
        connect.create_index(collection, index_type, index_param)
        search_param = get_search_param(index_type)
        query = [inserted[0]]
        # First the unscoped search, then the same search scoped to `tag`; both must agree.
        for scope in ({}, {"partition_tags": [tag]}):
            status, result = connect.search(collection, top_k, query, params=search_param, **scope)
            logging.getLogger().info(result)
            assert status.OK()
            assert len(result[0]) == min(len(inserted), top_k)
            assert check_result(result[0], ids[0])
            assert result[0][0].distance <= epsilon
Z
zhenwu 已提交
286

X
Xiaohai Xu 已提交
287
    def test_search_l2_index_params_partition_C(self, connect, collection, get_simple_index):
        '''
        target: test basic search function with all search params correct, for every simple index type
        method: insert vectors into partition `tag`, build the index, search with tags [tag, "new_tag"]
                where "new_tag" does not exist in the collection
        expected: search status ok, and the existing partition still yields top_k hits
        '''
        index_type = get_simple_index["index_type"]
        index_param = get_simple_index["index_param"]
        logging.getLogger().info(get_simple_index)
        if index_type == IndexType.IVF_PQ:
            pytest.skip("Skip PQ")
        connect.create_partition(collection, tag)
        inserted, ids = self.init_data(connect, collection, partition_tags=tag)
        connect.create_index(collection, index_type, index_param)
        top_k = 10
        status, result = connect.search(collection, top_k, [inserted[0]],
                                        partition_tags=[tag, "new_tag"],
                                        params=get_search_param(index_type))
        logging.getLogger().info(result)
        assert status.OK()
        nearest = result[0]
        assert len(nearest) == min(len(inserted), top_k)
        assert check_result(nearest, ids[0])
        assert nearest[0].distance <= epsilon
Z
zhenwu 已提交
310

D
del-zhenwu 已提交
311
    @pytest.mark.level(2)
    def test_search_l2_index_params_partition_D(self, connect, collection, get_simple_index):
        '''
        target: test search with a partition tag that does not exist, for every simple index type
        method: insert vectors into partition `tag`, build the index, then search with the single
                tag "new_tag", which was never created in the collection
        expected: search status not ok
        '''
        index_param = get_simple_index["index_param"]
        index_type = get_simple_index["index_type"]
        logging.getLogger().info(get_simple_index)
        connect.create_partition(collection, tag)
        vectors, ids = self.init_data(connect, collection, partition_tags=tag)
        connect.create_index(collection, index_type, index_param)
        query_vec = [vectors[0]]
        top_k = 10
        search_param = get_search_param(index_type)
        status, result = connect.search(collection, top_k, query_vec, partition_tags=["new_tag"], params=search_param)
        logging.getLogger().info(result)
        # Unlike partition_C (one valid tag plus one unknown), a lone unknown tag must be rejected.
        assert not status.OK()
Z
zhenwu 已提交
330

D
del-zhenwu 已提交
331
    @pytest.mark.level(2)
    def test_search_l2_index_params_partition_E(self, connect, collection, get_simple_index):
        '''
        target: test basic search function with all search params correct, for every simple index type
        method: fill two partitions (`tag` and "new_tag"), build the index, search with one vector from each,
                first over both partitions, then over "new_tag" only
        expected: search status ok; each query is matched to its own inserted id where its partition is searched
        '''
        top_k = 10
        new_tag = "new_tag"
        index_type = get_simple_index["index_type"]
        index_param = get_simple_index["index_param"]
        if index_type == IndexType.IVF_PQ:
            pytest.skip("Skip PQ")
        logging.getLogger().info(get_simple_index)
        connect.create_partition(collection, tag)
        connect.create_partition(collection, new_tag)
        old_vecs, old_ids = self.init_data(connect, collection, partition_tags=tag)
        fresh_vecs, fresh_ids = self.init_data(connect, collection, nb=6001, partition_tags=new_tag)
        connect.create_index(collection, index_type, index_param)
        queries = [old_vecs[0], fresh_vecs[0]]
        search_param = get_search_param(index_type)

        # Both partitions searched: both queries should find their own vector.
        status, result = connect.search(collection, top_k, queries, partition_tags=[tag, new_tag], params=search_param)
        logging.getLogger().info(result)
        assert status.OK()
        assert len(result[0]) == min(len(old_vecs), top_k)
        assert check_result(result[0], old_ids[0])
        assert check_result(result[1], fresh_ids[0])
        assert result[0][0].distance <= epsilon
        assert result[1][0].distance <= epsilon

        # Only new_tag searched: the second query still finds its own vector.
        status, result = connect.search(collection, top_k, queries, partition_tags=[new_tag], params=search_param)
        logging.getLogger().info(result)
        assert status.OK()
        assert len(result[0]) == min(len(old_vecs), top_k)
        assert check_result(result[1], fresh_ids[0])
        assert result[1][0].distance <= epsilon
Z
zhenwu 已提交
366

X
Xiaohai Xu 已提交
367
    def test_search_l2_index_params_partition_F(self, connect, collection, get_simple_index):
        '''
        target: test basic search function with all search params correct, for every simple index type
        method: fill partitions "atag" and "new_tag", build the index, search with partition tags given
                as regular expressions ("new(.*)" and "(.*)tag")
        expected: search status ok; a query finds its own vector only when its partition matches the expression
        '''
        tag = "atag"
        new_tag = "new_tag"
        index_type = get_simple_index["index_type"]
        index_param = get_simple_index["index_param"]
        logging.getLogger().info(get_simple_index)
        if index_type == IndexType.IVF_PQ:
            pytest.skip("Skip PQ")
        connect.create_partition(collection, tag)
        connect.create_partition(collection, new_tag)
        a_vecs, a_ids = self.init_data(connect, collection, partition_tags=tag)
        n_vecs, n_ids = self.init_data(connect, collection, nb=6001, partition_tags=new_tag)
        connect.create_index(collection, index_type, index_param)
        queries = [a_vecs[0], n_vecs[0]]
        top_k = 10
        search_param = get_search_param(index_type)

        # "new(.*)" matches only "new_tag": the first query (stored under "atag") has no exact match.
        status, result = connect.search(collection, top_k, queries, partition_tags=["new(.*)"], params=search_param)
        logging.getLogger().info(result)
        assert status.OK()
        assert result[0][0].distance > epsilon
        assert result[1][0].distance <= epsilon

        # "(.*)tag" matches both partitions, so both queries find exact matches.
        status, result = connect.search(collection, top_k, queries, partition_tags=["(.*)tag"], params=search_param)
        logging.getLogger().info(result)
        assert status.OK()
        assert result[0][0].distance <= epsilon
        assert result[1][0].distance <= epsilon
Z
zhenwu 已提交
398

D
del-zhenwu 已提交
399
    @pytest.mark.level(2)
    def test_search_ip_index_params(self, connect, ip_collection, get_simple_index):
        '''
        target: test basic IP search with all search params correct, for every simple index type
        method: insert vectors into an IP collection, build the index, search with the first inserted vector
        expected: search status ok; top_k hits ordered by non-increasing inner product
        '''
        top_k = 10
        index_param = get_simple_index["index_param"]
        index_type = get_simple_index["index_type"]
        logging.getLogger().info(get_simple_index)
        inserted, ids = self.init_data(connect, ip_collection)
        connect.create_index(ip_collection, index_type, index_param)
        status, result = connect.search(ip_collection, top_k, [inserted[0]], params=get_search_param(index_type))
        logging.getLogger().info(result)
        assert status.OK()
        hits = result[0]
        assert len(hits) == min(len(inserted), top_k)
        assert check_result(hits, ids[0])
        # Inner product: larger is closer, so scores must not increase down the list.
        assert hits[0].distance >= hits[1].distance
420

X
Xiaohai Xu 已提交
421
    def test_search_ip_large_nq_index_params(self, connect, ip_collection, get_simple_index):
        '''
        target: test basic IP search with all search params correct, for every simple index type, with a large nq
        method: insert vectors into an IP collection, build the index, search with the first 1200 inserted vectors
        expected: search status ok; the first query returns top_k hits and scores about 1 against itself
        '''
        top_k = 10
        nq = 1200
        index_param = get_simple_index["index_param"]
        index_type = get_simple_index["index_type"]
        logging.getLogger().info(get_simple_index)
        if index_type in [IndexType.RNSG, IndexType.IVF_PQ]:
            pytest.skip("rnsg not support in ip, skip pq")
        vectors, ids = self.init_data(connect, ip_collection)
        connect.create_index(ip_collection, index_type, index_param)
        # Query with the first nq inserted vectors.
        query_vec = vectors[:nq]
        search_param = get_search_param(index_type)
        status, result = connect.search(ip_collection, top_k, query_vec, params=search_param)
        logging.getLogger().info(result)
        assert status.OK()
        assert len(result[0]) == min(len(vectors), top_k)
        assert check_result(result[0], ids[0])
        # Inserted vectors are L2-normalised, so a vector's inner product with itself is ~1.
        assert result[0][0].distance >= 1 - gen_inaccuracy(result[0][0].distance)
J
JinHai-CN 已提交
445

D
del-zhenwu 已提交
446
    @pytest.mark.level(2)
    def test_search_ip_index_params_partition(self, connect, ip_collection, get_simple_index):
        '''
        target: test basic IP search with all search params correct, for every simple index type
        method: create partition `tag`, insert vectors without a tag, build the index,
                then search the whole collection and the partition `tag`
        expected: the collection search returns top_k hits; the search scoped to `tag` returns an empty result
        '''
        top_k = 10
        index_param = get_simple_index["index_param"]
        index_type = get_simple_index["index_type"]
        logging.getLogger().info(index_param)
        if index_type in [IndexType.RNSG, IndexType.IVF_PQ]:
            pytest.skip("rnsg not support in ip, skip pq")
        connect.create_partition(ip_collection, tag)
        inserted, ids = self.init_data(connect, ip_collection)
        connect.create_index(ip_collection, index_type, index_param)
        query = [inserted[0]]
        search_param = get_search_param(index_type)

        status, result = connect.search(ip_collection, top_k, query, params=search_param)
        logging.getLogger().info(result)
        assert status.OK()
        assert len(result[0]) == min(len(inserted), top_k)
        assert check_result(result[0], ids[0])
        best = result[0][0].distance
        assert best >= 1 - gen_inaccuracy(best)

        # Nothing was inserted under `tag`, so searching only that partition finds nothing.
        status, result = connect.search(ip_collection, top_k, query, partition_tags=[tag], params=search_param)
        logging.getLogger().info(result)
        assert status.OK()
        assert len(result) == 0

D
del-zhenwu 已提交
476
    @pytest.mark.level(2)
    def test_search_ip_index_params_partition_A(self, connect, ip_collection, get_simple_index):
        '''
        target: test basic IP search with all search params correct, for every simple index type
        method: insert vectors into partition `tag`, build the index, search only partition `tag`
        expected: search status ok; top_k hits, and the query scores about 1 against itself
        '''
        top_k = 10
        index_param = get_simple_index["index_param"]
        index_type = get_simple_index["index_type"]
        logging.getLogger().info(index_param)
        if index_type in [IndexType.RNSG, IndexType.IVF_PQ]:
            pytest.skip("rnsg not support in ip, skip pq")
        connect.create_partition(ip_collection, tag)
        inserted, ids = self.init_data(connect, ip_collection, partition_tags=tag)
        connect.create_index(ip_collection, index_type, index_param)
        status, result = connect.search(ip_collection, top_k, [inserted[0]],
                                        partition_tags=[tag],
                                        params=get_search_param(index_type))
        logging.getLogger().info(result)
        assert status.OK()
        assert len(result[0]) == min(len(inserted), top_k)
        assert check_result(result[0], ids[0])
        best = result[0][0].distance
        assert best >= 1 - gen_inaccuracy(best)
Z
zhenwu 已提交
501

D
del-zhenwu 已提交
502 503 504 505 506 507 508 509 510 511 512
    @pytest.mark.level(2)
    def test_search_vectors_without_connect(self, dis_connect, collection):
        '''
        target: test search vectors without connection
        method: call search on a disconnected client instance
        expected: raise exception
        '''
        query_vectors = [vectors[0]]
        with pytest.raises(Exception):
            dis_connect.search(collection, top_k, query_vectors)
J
JinHai-CN 已提交
513

X
Xiaohai Xu 已提交
514
    def test_search_collection_name_not_existed(self, connect, collection):
        '''
        target: search collection not existed
        method: search with a random collection_name that is not in the db
        expected: status not ok
        '''
        collection_name = gen_unique_str("not_existed_collection")
        query_vecs = [vectors[0]]
        status, result = connect.search(collection_name, top_k, query_vecs)
        assert not status.OK()

X
Xiaohai Xu 已提交
526
    def test_search_collection_name_None(self, connect, collection):
        '''
        target: search collection that collection name is None
        method: search with collection_name None
        expected: raise exception
        '''
        collection_name = None
        query_vecs = [vectors[0]]
        with pytest.raises(Exception):
            connect.search(collection_name, top_k, query_vecs)
J
JinHai-CN 已提交
537

X
Xiaohai Xu 已提交
538
    def test_search_top_k_query_records(self, connect, collection):
        '''
        target: test search function, with search params: query_records
        method: search with query_records that are a subset of the inserted vectors
        expected: status ok; one result per query, each with top_k hits and an exact nearest match
        '''
        top_k = 10
        vectors, ids = self.init_data(connect, collection)
        query_vecs = [vectors[0], vectors[55], vectors[99]]
        status, result = connect.search(collection, top_k, query_vecs)
        assert status.OK()
        assert len(result) == len(query_vecs)
        for hits in result:
            assert len(hits) == top_k
            # Every query vector was inserted, so its nearest hit is (almost) distance 0.
            assert hits[0].distance <= epsilon

X
Xiaohai Xu 已提交
554
    def test_search_distance_l2_flat_index(self, connect, collection):
        '''
        target: search collection, and check the result: distance
        method: compare the returned distance with the Euclidean distance computed locally
        expected: search status ok, and the returned distance equals the computed value
        '''
        nb = 2
        vectors, ids = self.init_data(connect, collection, nb=nb)
        query_vecs = [[0.50] * dim]
        query = numpy.array(query_vecs[0])
        distance_0 = numpy.linalg.norm(query - numpy.array(vectors[0]))
        distance_1 = numpy.linalg.norm(query - numpy.array(vectors[1]))
        status, result = connect.search(collection, top_k, query_vecs)
        # Fail with a clear status error rather than an IndexError on an empty result.
        assert status.OK()
        # sqrt is applied to the returned value, i.e. this assumes the server reports squared L2.
        assert abs(numpy.sqrt(result[0][0].distance) - min(distance_0, distance_1)) <= gen_inaccuracy(result[0][0].distance)

X
Xiaohai Xu 已提交
568
    def test_search_distance_ip_flat_index(self, connect, ip_collection):
        '''
        target: search ip_collection, and check the result: distance
        method: compare the return distance value with value computed with Inner product
        expected: the return distance equals to the computed value
        '''
        nb = 2
        inserted, ids = self.init_data(connect, ip_collection, nb=nb)
        index_type = IndexType.FLAT
        index_param = {
            "nlist": 16384
        }
        connect.create_index(ip_collection, index_type, index_param)
        logging.getLogger().info(connect.get_index_info(ip_collection))
        query_vecs = [[0.50 for i in range(dim)]]
        query = numpy.array(query_vecs[0])
        distance_0 = numpy.inner(query, numpy.array(inserted[0]))
        distance_1 = numpy.inner(query, numpy.array(inserted[1]))
        search_param = get_search_param(index_type)
        status, result = connect.search(ip_collection, top_k, query_vecs, params=search_param)
        # Fail clearly if the search itself failed.
        assert status.OK()
        assert len(result) == len(query_vecs)
        # For IP the best hit is the one with the largest inner product.
        assert abs(result[0][0].distance - max(distance_0, distance_1)) <= gen_inaccuracy(result[0][0].distance)

    def test_search_distance_jaccard_flat_index(self, connect, jac_collection):
        '''
        target: search jac_collection, and check the result: distance
        method: compare the return distance value with value computed with Jaccard distance
        expected: the return distance equals to the computed value
        '''
        int_vectors, vectors, ids = self.init_binary_data(connect, jac_collection, nb=2)
        index_type = IndexType.FLAT
        index_param = {
            "nlist": 16384
        }
        connect.create_index(jac_collection, index_type, index_param)
        logging.getLogger().info(connect.get_collection_info(jac_collection))
        logging.getLogger().info(connect.get_index_info(jac_collection))
        # insert=False: presumably only generates the query vector, without inserting it
        query_int_vectors, query_vecs, tmp_ids = self.init_binary_data(connect, jac_collection, nb=1, insert=False)
        distance_0 = jaccard(query_int_vectors[0], int_vectors[0])
        distance_1 = jaccard(query_int_vectors[0], int_vectors[1])
        search_param = get_search_param(index_type)
        status, result = connect.search(jac_collection, top_k, query_vecs, params=search_param)
        logging.getLogger().info(status)
        logging.getLogger().info(result)
        assert status.OK()
        # Jaccard is a distance: the nearest hit is the smaller of the two.
        assert abs(result[0][0].distance - min(distance_0, distance_1)) <= epsilon

    def test_search_distance_hamming_flat_index(self, connect, ham_collection):
        '''
        target: search ham_collection, and check the result: distance
        method: compare the return distance value with value computed with Hamming distance
        expected: the return distance equals to the computed value
        '''
        int_vectors, vectors, ids = self.init_binary_data(connect, ham_collection, nb=2)
        index_type = IndexType.FLAT
        index_param = {
            "nlist": 16384
        }
        connect.create_index(ham_collection, index_type, index_param)
        logging.getLogger().info(connect.get_collection_info(ham_collection))
        logging.getLogger().info(connect.get_index_info(ham_collection))
        # insert=False: presumably only generates the query vector, without inserting it
        query_int_vectors, query_vecs, tmp_ids = self.init_binary_data(connect, ham_collection, nb=1, insert=False)
        distance_0 = hamming(query_int_vectors[0], int_vectors[0])
        distance_1 = hamming(query_int_vectors[0], int_vectors[1])
        search_param = get_search_param(index_type)
        status, result = connect.search(ham_collection, top_k, query_vecs, params=search_param)
        logging.getLogger().info(status)
        logging.getLogger().info(result)
        assert status.OK()
        # Hamming is a distance: the nearest hit is the smaller of the two.
        assert abs(result[0][0].distance - min(distance_0, distance_1).astype(float)) <= epsilon

    def test_search_distance_substructure_flat_index(self, connect, substructure_collection):
        '''
        target: search substructure_collection with a freshly generated query vector
        method: insert 2 binary vectors, build a FLAT index, then search with a
                generated (not inserted) binary vector using SUBSTRUCTURE
        expected: search status ok and no hits, since no inserted vector is
                  expected to match the random query under SUBSTRUCTURE
        '''
        int_vectors, vectors, ids = self.init_binary_data(connect, substructure_collection, nb=2)
        index_type = IndexType.FLAT
        index_param = {
            "nlist": 16384
        }
        connect.create_index(substructure_collection, index_type, index_param)
        logging.getLogger().info(connect.get_collection_info(substructure_collection))
        logging.getLogger().info(connect.get_index_info(substructure_collection))
        # insert=False: presumably only generates the query vector, without inserting it
        query_int_vectors, query_vecs, tmp_ids = self.init_binary_data(connect, substructure_collection, nb=1, insert=False)
        search_param = get_search_param(index_type)
        status, result = connect.search(substructure_collection, top_k, query_vecs, params=search_param)
        logging.getLogger().info(status)
        logging.getLogger().info(result)
        assert status.OK()
        assert len(result[0]) == 0

    def test_search_distance_substructure_flat_index_B(self, connect, substructure_collection):
        '''
        target: search substructure_collection, and check the result: ids and distance
        method: build 2 queries with gen_binary_sub_vectors from the inserted vectors
                and search with SUBSTRUCTURE, top_k = 3
        expected: each query returns exactly one hit, the vector at the same
                  position in ids, with distance within epsilon
        '''
        top_k = 3
        int_vectors, vectors, ids = self.init_binary_data(connect, substructure_collection, nb=2)
        index_type = IndexType.FLAT
        index_param = {
            "nlist": 16384
        }
        connect.create_index(substructure_collection, index_type, index_param)
        logging.getLogger().info(connect.get_collection_info(substructure_collection))
        logging.getLogger().info(connect.get_index_info(substructure_collection))
        query_int_vectors, query_vecs = gen_binary_sub_vectors(int_vectors, 2)
        search_param = get_search_param(index_type)
        status, result = connect.search(substructure_collection, top_k, query_vecs, params=search_param)
        logging.getLogger().info(status)
        logging.getLogger().info(result)
        assert status.OK()
        assert len(result[0]) == 1
        assert len(result[1]) == 1
        assert result[0][0].distance <= epsilon
        assert result[0][0].id == ids[0]
        assert result[1][0].distance <= epsilon
        assert result[1][0].id == ids[1]

    def test_search_distance_superstructure_flat_index(self, connect, superstructure_collection):
        '''
        target: search superstructure_collection with a freshly generated query vector
        method: insert 2 binary vectors, build a FLAT index, then search with a
                generated (not inserted) binary vector using SUPERSTRUCTURE
        expected: search status ok and no hits, since no inserted vector is
                  expected to match the random query under SUPERSTRUCTURE
        '''
        int_vectors, vectors, ids = self.init_binary_data(connect, superstructure_collection, nb=2)
        index_type = IndexType.FLAT
        index_param = {
            "nlist": 16384
        }
        connect.create_index(superstructure_collection, index_type, index_param)
        logging.getLogger().info(connect.get_collection_info(superstructure_collection))
        logging.getLogger().info(connect.get_index_info(superstructure_collection))
        # insert=False: presumably only generates the query vector, without inserting it
        query_int_vectors, query_vecs, tmp_ids = self.init_binary_data(connect, superstructure_collection, nb=1, insert=False)
        search_param = get_search_param(index_type)
        status, result = connect.search(superstructure_collection, top_k, query_vecs, params=search_param)
        logging.getLogger().info(status)
        logging.getLogger().info(result)
        assert status.OK()
        assert len(result[0]) == 0

    def test_search_distance_superstructure_flat_index_B(self, connect, superstructure_collection):
        '''
        target: search superstructure_collection, and check the result: ids and distance
        method: build 2 queries with gen_binary_super_vectors from the inserted vectors
                and search with SUPERSTRUCTURE, top_k = 3
        expected: each query returns 2 hits, and its top hit is one of the
                  inserted ids with distance within epsilon
        '''
        top_k = 3
        int_vectors, vectors, ids = self.init_binary_data(connect, superstructure_collection, nb=2)
        index_type = IndexType.FLAT
        index_param = {
            "nlist": 16384
        }
        connect.create_index(superstructure_collection, index_type, index_param)
        logging.getLogger().info(connect.get_collection_info(superstructure_collection))
        logging.getLogger().info(connect.get_index_info(superstructure_collection))
        query_int_vectors, query_vecs = gen_binary_super_vectors(int_vectors, 2)
        search_param = get_search_param(index_type)
        status, result = connect.search(superstructure_collection, top_k, query_vecs, params=search_param)
        logging.getLogger().info(status)
        logging.getLogger().info(result)
        assert status.OK()
        assert len(result[0]) == 2
        assert len(result[1]) == 2
        assert result[0][0].id in ids
        assert result[0][0].distance <= epsilon
        assert result[1][0].id in ids
        assert result[1][0].distance <= epsilon

    def test_search_distance_tanimoto_flat_index(self, connect, tanimoto_collection):
        '''
        target: search tanimoto_collection, and check the result: distance
        method: compare the return distance value with value computed with Tanimoto distance
        expected: the return distance equals to the computed value
        '''
        int_vectors, vectors, ids = self.init_binary_data(connect, tanimoto_collection, nb=2)
        index_type = IndexType.FLAT
        index_param = {
            "nlist": 16384
        }
        connect.create_index(tanimoto_collection, index_type, index_param)
        logging.getLogger().info(connect.get_collection_info(tanimoto_collection))
        logging.getLogger().info(connect.get_index_info(tanimoto_collection))
        # insert=False: presumably only generates the query vector, without inserting it
        query_int_vectors, query_vecs, tmp_ids = self.init_binary_data(connect, tanimoto_collection, nb=1, insert=False)
        distance_0 = tanimoto(query_int_vectors[0], int_vectors[0])
        distance_1 = tanimoto(query_int_vectors[0], int_vectors[1])
        search_param = get_search_param(index_type)
        status, result = connect.search(tanimoto_collection, top_k, query_vecs, params=search_param)
        logging.getLogger().info(status)
        logging.getLogger().info(result)
        assert status.OK()
        # Tanimoto is used as a distance: the nearest hit is the smaller of the two.
        assert abs(result[0][0].distance - min(distance_0, distance_1)) <= epsilon

    def test_search_distance_ip_index_params(self, connect, ip_collection, get_index):
        '''
        target: search ip_collection built with each index from get_index, and check the result: distance
        method: compare the return distance value with value computed with Inner product
        expected: the return distance equals to the computed value (within gen_inaccuracy)
        '''
        top_k = 2
        index_param = get_index["index_param"]
        index_type = get_index["index_type"]
        if index_type == IndexType.RNSG:
            pytest.skip("rnsg not support in ip")
        inserted, ids = self.init_data(connect, ip_collection, nb=2)
        connect.create_index(ip_collection, index_type, index_param)
        logging.getLogger().info(connect.get_index_info(ip_collection))
        query_vecs = [[0.50 for i in range(dim)]]
        search_param = get_search_param(index_type)
        status, result = connect.search(ip_collection, top_k, query_vecs, params=search_param)
        logging.getLogger().debug(status)
        logging.getLogger().debug(result)
        # Fail clearly if the search itself failed.
        assert status.OK()
        assert len(result) == len(query_vecs)
        query = numpy.array(query_vecs[0])
        distance_0 = numpy.inner(query, numpy.array(inserted[0]))
        distance_1 = numpy.inner(query, numpy.array(inserted[1]))
        # For IP the best hit is the one with the largest inner product.
        assert abs(result[0][0].distance - max(distance_0, distance_1)) <= gen_inaccuracy(result[0][0].distance)

    # def test_search_concurrent(self, connect, collection):
    #     vectors, ids = self.init_data(connect, collection, nb=5000)
    #     thread_num = 50
    #     nq = 1
    #     top_k = 2
    #     threads = []
    #     query_vecs = vectors[:nq]
    #     def search(thread_number):
    #         for i in range(1000000):
    #             status, result = connect.search(collection, top_k, query_vecs, timeout=2)
    #             assert len(result) == len(query_vecs)
    #             assert status.OK()
    #             if i % 1000 == 0:
    #                 logging.getLogger().info("In %d, %d" % (thread_number, i))
    #         logging.getLogger().info("%d finished" % thread_number)
    #     # with concurrent.futures.ThreadPoolExecutor(max_workers=thread_num) as executor:
    #     #     future_results = {executor.submit(
    #     #         search): i for i in range(1000000)}
    #     #     for future in concurrent.futures.as_completed(future_results):
    #     #         future.result()
    #     for i in range(thread_num):
    #         t = threading.Thread(target=search, args=(i, ))
    #         threads.append(t)
    #         t.start()
    #     for t in threads:
    #         t.join()
J
JinHai-CN 已提交
823

D
del-zhenwu 已提交
824
    @pytest.mark.level(2)
    @pytest.mark.timeout(30)
    def test_search_concurrent_multithreads(self, args):
        '''
        target: test concurrent search with multiple threads
        method: search with 4 threads, each thread uses an independent connection
        expected: status ok and the returned vectors should be query_records
        '''
        nb = 100
        top_k = 10
        threads_num = 4
        collection = gen_unique_str("test_search_concurrent_multiprocessing")
        param = {'collection_name': collection,
                 'dimension': dim,
                 'index_type': IndexType.FLAT,
                 'store_raw_vector': False}
        # create collection
        milvus = get_milvus(args["ip"], args["port"], handler=args["handler"])
        milvus.create_collection(param)
        vectors, ids = self.init_data(milvus, collection, nb=nb)
        query_vecs = vectors[nb//2:nb]

        def search(client):
            status, result = client.search(collection, top_k, query_vecs)
            assert status.OK()
            assert len(result) == len(query_vecs)
            for i in range(len(query_vecs)):
                assert result[i][0].id in ids
                assert result[i][0].distance == 0.0

        # Run the searches through an executor so that an assertion failing in a
        # worker thread is re-raised here by future.result(). With bare
        # threading.Thread such failures were only printed and the test passed.
        with concurrent.futures.ThreadPoolExecutor(max_workers=threads_num) as executor:
            futures = []
            for _ in range(threads_num):
                client = get_milvus(args["ip"], args["port"], handler=args["handler"])
                futures.append(executor.submit(search, client))
                sleep(0.2)
            for future in concurrent.futures.as_completed(futures):
                future.result()

    # TODO: enable
    @pytest.mark.timeout(30)
    def _test_search_concurrent_multiprocessing(self, args):
        '''
        target: test concurrent search with multiprocessing
        method: search with 4 processes, each process uses an independent connection
        expected: status ok and the returned vectors should be query_records
        '''
        nb = 100
        top_k = 10
        process_num = 4
        processes = []
        collection = gen_unique_str("test_search_concurrent_multiprocessing")
        param = {'collection_name': collection,
                 'dimension': dim,
                 'index_type': IndexType.FLAT,
                 'store_raw_vector': False}
        # create collection
        milvus = get_milvus(args["ip"], args["port"], handler=args["handler"])
        milvus.create_collection(param)
        vectors, ids = self.init_data(milvus, collection, nb=nb)
        query_vecs = vectors[nb//2:nb]

        def search(milvus):
            status, result = milvus.search(collection, top_k, query_vecs)
            assert status.OK()
            assert len(result) == len(query_vecs)
            for i in range(len(query_vecs)):
                assert result[i][0].id in ids
                assert result[i][0].distance == 0.0

        # NOTE(review): a nested function as Process target presumably only works
        # with the "fork" start method (it cannot be pickled for "spawn"); confirm
        # before enabling this test.
        for i in range(process_num):
            milvus = get_milvus(args["ip"], args["port"], handler=args["handler"])
            p = Process(target=search, args=(milvus, ))
            processes.append(p)
            p.start()
            sleep(0.2)
        for p in processes:
            p.join()
        # A failed assertion in a child process only shows up as a non-zero exit
        # code, so check it explicitly; otherwise the test could never fail.
        for p in processes:
            assert p.exitcode == 0

    def test_search_multi_collection_L2(self, args):
        '''
        target: test search multi collections of L2
        method: add vectors into 10 collections, and search each of them
        expected: search status ok, the length of result, and each query's
                  own id among its top hits
        '''
        num = 10
        top_k = 10
        collections = []
        idx = []
        # One connection is enough for creating, inserting and searching.
        milvus = get_milvus(args["ip"], args["port"], handler=args["handler"])
        for i in range(num):
            collection = gen_unique_str("test_add_multicollection_%d" % i)
            param = {'collection_name': collection,
                     'dimension': dim,
                     'index_file_size': 10,
                     'metric_type': MetricType.L2}
            # create collection
            milvus.create_collection(param)
            status, ids = milvus.insert(collection, vectors)
            assert status.OK()
            assert len(ids) == len(vectors)
            collections.append(collection)
            # Remember the ids of the rows used as queries below (0, 10, 20)
            idx.append(ids[0])
            idx.append(ids[10])
            idx.append(ids[20])
            milvus.flush([collection])
        query_vecs = [vectors[0], vectors[10], vectors[20]]
        # query each collection in turn
        for i in range(num):
            collection = collections[i]
            status, result = milvus.search(collection, top_k, query_vecs)
            assert status.OK()
            assert len(result) == len(query_vecs)
            for j in range(len(query_vecs)):
                assert len(result[j]) == top_k
            for j in range(len(query_vecs)):
                assert check_result(result[j], idx[3 * i + j])

    def test_search_multi_collection_IP(self, args):
        '''
        target: test search multi collections of IP
        method: add vectors into 10 IP collections, and search each of them
        expected: search status ok, the length of result, and each query's
                  own id among its top hits
        '''
        num = 10
        top_k = 10
        collections = []
        idx = []
        # One connection is enough for creating, inserting and searching.
        milvus = get_milvus(args["ip"], args["port"], handler=args["handler"])
        for i in range(num):
            collection = gen_unique_str("test_add_multicollection_%d" % i)
            # The collections must use the IP metric; previously this test
            # created L2 collections and so never exercised IP at all.
            param = {'collection_name': collection,
                     'dimension': dim,
                     'index_file_size': 10,
                     'metric_type': MetricType.IP}
            # create collection
            milvus.create_collection(param)
            status, ids = milvus.insert(collection, vectors)
            assert status.OK()
            assert len(ids) == len(vectors)
            collections.append(collection)
            # Remember the ids of the rows used as queries below (0, 10, 20)
            idx.append(ids[0])
            idx.append(ids[10])
            idx.append(ids[20])
            milvus.flush([collection])
        query_vecs = [vectors[0], vectors[10], vectors[20]]
        # query each collection in turn
        for i in range(num):
            collection = collections[i]
            status, result = milvus.search(collection, top_k, query_vecs)
            assert status.OK()
            assert len(result) == len(query_vecs)
            for j in range(len(query_vecs)):
                assert len(result[j]) == top_k
            for j in range(len(query_vecs)):
                assert check_result(result[j], idx[3 * i + j])
"""
******************************************************************
#  The following cases are used to test `search_vectors` function 
X
Xiaohai Xu 已提交
984
#  with invalid collection_name top-k / nprobe / query_range
J
JinHai-CN 已提交
985 986 987 988
******************************************************************
"""

class TestSearchParamsInvalid(object):
    nlist = 16384
    index_type = IndexType.IVF_SQ8
    index_param = {"nlist": nlist}
    logging.getLogger().info(index_param)

    def init_data(self, connect, collection, nb=6000):
        '''
        Generate vectors and add them into collection, before search vectors.

        Reuses the module-level `vectors` when nb == 6000, otherwise generates
        nb new vectors. Returns the inserted vectors and their ids.
        '''
        if nb == 6000:
            insert = vectors
        else:
            insert = gen_vectors(nb, dim)
        status, ids = connect.insert(collection, insert)
        assert status.OK()
        # Flush (as TestSearchBase.init_data does) so the vectors are searchable,
        # instead of sleeping for a fixed interval and hoping they are.
        connect.flush([collection])
        return insert, ids

    @staticmethod
    def _skip_unsupported_index(connect, index_type):
        '''
        Skip the current test when index_type is not usable in the server's
        mode: IVF_SQ8H is skipped in CPU mode, IVF_PQ is skipped in GPU mode.
        '''
        mode = str(connect._cmd("mode")[1])
        if mode == "CPU" and index_type == IndexType.IVF_SQ8H:
            pytest.skip("sq8h not support in CPU mode")
        if mode == "GPU" and index_type == IndexType.IVF_PQ:
            pytest.skip("ivfpq not support in GPU mode")

    @staticmethod
    def _check_invalid_top_k(connect, collection, top_k):
        '''
        Search `collection` with an invalid top_k and check it is rejected:
        an int top_k must yield a non-OK status, any other type must raise.
        '''
        logging.getLogger().info(top_k)
        query_vecs = gen_vectors(1, dim)
        if isinstance(top_k, int):
            status, result = connect.search(collection, top_k, query_vecs)
            assert not status.OK()
        else:
            with pytest.raises(Exception):
                connect.search(collection, top_k, query_vecs)

    def _check_invalid_nprobe(self, connect, collection, nprobe):
        '''
        Build an IVF_SQ8 index on `collection`, search it with an invalid
        nprobe and check that the search status is not ok.
        '''
        connect.create_index(collection, self.index_type, dict(self.index_param))
        search_param = {"nprobe": nprobe}
        logging.getLogger().info(nprobe)
        query_vecs = gen_vectors(1, dim)
        status, result = connect.search(collection, top_k, query_vecs, params=search_param)
        assert not status.OK()

    # Test search collection with invalid collection names
    @pytest.fixture(
        scope="function",
        params=gen_invalid_collection_names()
    )
    def get_collection_name(self, request):
        yield request.param

    @pytest.mark.level(2)
    def test_search_with_invalid_collectionname(self, connect, get_collection_name):
        '''
        target: test search function, with an invalid collection name
        method: search a collection whose name is invalid
        expected: search status not ok
        '''
        collection_name = get_collection_name
        logging.getLogger().info(collection_name)
        query_vecs = gen_vectors(1, dim)
        status, result = connect.search(collection_name, top_k, query_vecs)
        assert not status.OK()

    @pytest.mark.level(1)
    def test_search_with_invalid_tag_format(self, connect, collection):
        '''
        target: test search function, with partition_tags given as a plain string
        method: search with partition_tags="tag" instead of a list of tags
        expected: raise an exception
        '''
        query_vecs = gen_vectors(1, dim)
        with pytest.raises(Exception):
            connect.search(collection, top_k, query_vecs, partition_tags="tag")

    @pytest.mark.level(1)
    def test_search_with_tag_not_existed(self, connect, collection):
        '''
        target: test search function, with a partition tag that was never created
        method: search with partition_tags=["tag"]
        expected: search status not ok
        '''
        query_vecs = gen_vectors(1, dim)
        status, result = connect.search(collection, top_k, query_vecs, partition_tags=["tag"])
        logging.getLogger().info(result)
        assert not status.OK()

    # Test search collection with invalid top-k
    @pytest.fixture(
        scope="function",
        params=gen_invalid_top_ks()
    )
    def get_top_k(self, request):
        yield request.param

    @pytest.mark.level(1)
    def test_search_with_invalid_top_k(self, connect, collection, get_top_k):
        '''
        target: test search function, with the wrong top_k
        method: search with top_k
        expected: raise an error, and the connection is normal
        '''
        self._check_invalid_top_k(connect, collection, get_top_k)

    @pytest.mark.level(2)
    def test_search_with_invalid_top_k_ip(self, connect, ip_collection, get_top_k):
        '''
        target: test search function on ip_collection, with the wrong top_k
        method: search with top_k
        expected: raise an error, and the connection is normal
        '''
        self._check_invalid_top_k(connect, ip_collection, get_top_k)

    # Test search collection with invalid nprobe
    @pytest.fixture(
        scope="function",
        params=gen_invalid_nprobes()
    )
    def get_nprobes(self, request):
        yield request.param

    @pytest.mark.level(1)
    def test_search_with_invalid_nprobe(self, connect, collection, get_nprobes):
        '''
        target: test search function, with the wrong nprobe
        method: search with nprobe
        expected: search status not ok, and the connection is normal
        '''
        self._check_invalid_nprobe(connect, collection, get_nprobes)

    @pytest.mark.level(2)
    def test_search_with_invalid_nprobe_ip(self, connect, ip_collection, get_nprobes):
        '''
        target: test search function on ip_collection, with the wrong nprobe
        method: search with nprobe
        expected: search status not ok, and the connection is normal
        '''
        self._check_invalid_nprobe(connect, ip_collection, get_nprobes)

    @pytest.fixture(
        scope="function",
        params=gen_simple_index()
    )
    def get_simple_index(self, request, connect):
        self._skip_unsupported_index(connect, request.param["index_type"])
        return request.param

    def test_search_with_empty_params(self, connect, collection, args, get_simple_index):
        '''
        target: test search function, with empty search params
        method: search with params
        expected: search status ok for FLAT, not ok for any other index type
        '''
        if args["handler"] == "HTTP":
            pytest.skip("skip in http mode")
        index_type = get_simple_index["index_type"]
        index_param = get_simple_index["index_param"]
        connect.create_index(collection, index_type, index_param)
        query_vecs = gen_vectors(1, dim)
        status, result = connect.search(collection, top_k, query_vecs, params={})
        # FLAT needs no search params; every other index type requires them.
        if index_type == IndexType.FLAT:
            assert status.OK()
        else:
            assert not status.OK()

    @pytest.fixture(
        scope="function",
        params=gen_invaild_search_params()
    )
    def get_invalid_search_param(self, request, connect):
        self._skip_unsupported_index(connect, request.param["index_type"])
        return request.param

    def test_search_with_invalid_params(self, connect, collection, get_invalid_search_param):
        '''
        target: test search function, with invalid search params
        method: search with params
        expected: search status not ok, and the connection is normal
        '''
        index_type = get_invalid_search_param["index_type"]
        search_param = get_invalid_search_param["search_param"]
        # Build the index(es) that gen_simple_index defines for this index type
        for index in gen_simple_index():
            if index["index_type"] == index_type:
                connect.create_index(collection, index_type, index["index_param"])
        query_vecs = gen_vectors(1, dim)
        status, result = connect.search(collection, top_k, query_vecs, params=search_param)
        assert not status.OK()

def check_result(result, id):
    '''
    Return True if `id` is the id of one of the first five hits in `result`.
    When there are fewer than five hits, all of them are checked.
    '''
    for rank, hit in enumerate(result):
        if rank == 5:
            return False
        if hit.id == id:
            return True
    return False