test_search.py 53.6 KB
Newer Older
J
JinHai-CN 已提交
1
import pdb
G
groot 已提交
2
import struct
3
from random import sample
G
groot 已提交
4

J
JinHai-CN 已提交
5 6 7 8 9 10 11
import pytest
import threading
import datetime
import logging
from time import sleep
from multiprocessing import Process
import numpy
12
import sklearn.preprocessing
13
from milvus import IndexType, MetricType
J
JinHai-CN 已提交
14 15 16
from utils import *

# Dimension of every float/binary vector generated in this module.
dim = 128
collection_id = "test_search"
add_interval_time = 2
# Shared pool of 6000 float vectors, L2-normalized row-wise and stored as plain lists.
vectors = sklearn.preprocessing.normalize(
    gen_vectors(6000, dim), axis=1, norm='l2'
).tolist()
# Default search arguments used by tests that do not override them locally.
top_k = 1
nprobe = 1
# Tolerance used when asserting that the nearest distance is (close to) zero.
epsilon = 0.001
# Default partition tag.
tag = "1970-01-01"
# Shared pool of 6000 binary vectors together with their raw form.
raw_vectors, binary_vectors = gen_binary_vectors(6000, dim)
J
JinHai-CN 已提交
27 28 29


class TestSearchBase:
X
Xiaohai Xu 已提交
30
    def init_data(self, connect, collection, nb=6000, partition_tags=None):
        '''
        Insert float vectors into `collection` and flush, before searching.

        nb == 6000 reuses the module-level `vectors`; any other value generates
        a fresh batch of `nb` vectors and L2-normalizes it.
        When `partition_tags` is not None it is forwarded to insert as `partition_tag`.
        Returns the inserted vectors and the ids returned by insert.
        '''
        if nb != 6000:
            fresh = gen_vectors(nb, dim)
            fresh = sklearn.preprocessing.normalize(fresh, axis=1, norm='l2')
            add_vectors = fresh.tolist()
        else:
            add_vectors = vectors
        # Only pass the partition keyword when a tag was actually requested.
        insert_kwargs = {} if partition_tags is None else {"partition_tag": partition_tags}
        status, ids = connect.insert(collection, add_vectors, **insert_kwargs)
        assert status.OK()
        connect.flush([collection])
        return add_vectors, ids

X
Xiaohai Xu 已提交
50
    def init_binary_data(self, connect, collection, nb=6000, insert=True, partition_tags=None):
        '''
        Prepare binary vectors and, when `insert` is True, add them to `collection` and flush.

        nb == 6000 reuses the module-level `raw_vectors` / `binary_vectors`; any
        other value generates a fresh pair with gen_binary_vectors.
        Returns (raw vectors, binary vectors, ids); ids is [] when nothing was inserted.
        '''
        if nb == 6000:
            add_raw_vectors, add_vectors = raw_vectors, binary_vectors
        else:
            add_raw_vectors, add_vectors = gen_binary_vectors(nb, dim)
        ids = []
        if insert is not True:
            return add_raw_vectors, add_vectors, ids
        if partition_tags is None:
            status, ids = connect.insert(collection, add_vectors)
        else:
            status, ids = connect.insert(collection, add_vectors, partition_tag=partition_tags)
        assert status.OK()
        connect.flush([collection])
        return add_raw_vectors, add_vectors, ids

J
JinHai-CN 已提交
72 73 74 75 76
    """
    generate valid create_index params
    """
    @pytest.fixture(
        scope="function",
        params=gen_index()
    )
    def get_index(self, request, connect):
        '''
        Parametrized fixture returning one index configuration from gen_index().

        The test is skipped for IVF_SQ8H when `connect._cmd("mode")` reports "CPU",
        and for IVF_PQ when it reports "GPU".
        '''
        # Ask the server for its mode once instead of once per check.
        mode = str(connect._cmd("mode")[1])
        if mode == "CPU" and request.param["index_type"] == IndexType.IVF_SQ8H:
            pytest.skip("sq8h not support in CPU mode")
        if mode == "GPU" and request.param["index_type"] == IndexType.IVF_PQ:
            pytest.skip("ivfpq not support in GPU mode")
        return request.param
J
JinHai-CN 已提交
87

Z
zhenwu 已提交
88 89
    @pytest.fixture(
        scope="function",
        params=gen_simple_index()
    )
    def get_simple_index(self, request, connect):
        '''
        Parametrized fixture returning one index configuration from gen_simple_index().

        The test is skipped for IVF_SQ8H when `connect._cmd("mode")` reports "CPU",
        and for IVF_PQ when it reports "GPU".
        '''
        # Ask the server for its mode once instead of once per check.
        mode = str(connect._cmd("mode")[1])
        if mode == "CPU" and request.param["index_type"] == IndexType.IVF_SQ8H:
            pytest.skip("sq8h not support in CPU mode")
        if mode == "GPU" and request.param["index_type"] == IndexType.IVF_PQ:
            pytest.skip("ivfpq not support in GPU mode")
        return request.param
G
groot 已提交
100 101 102

    @pytest.fixture(scope="function", params=gen_simple_index())
    def get_jaccard_index(self, request, connect):
        '''
        Parametrized fixture for jaccard tests: only IVFLAT and FLAT
        configurations are returned, every other index type is skipped.
        '''
        candidate = request.param
        logging.getLogger().info(candidate)
        if candidate["index_type"] not in (IndexType.IVFLAT, IndexType.FLAT):
            pytest.skip("Skip index Temporary")
        return candidate

    @pytest.fixture(scope="function", params=gen_simple_index())
    def get_hamming_index(self, request, connect):
        '''
        Parametrized fixture for hamming tests: only IVFLAT and FLAT
        configurations are returned, every other index type is skipped.
        '''
        candidate = request.param
        logging.getLogger().info(candidate)
        if candidate["index_type"] not in (IndexType.IVFLAT, IndexType.FLAT):
            pytest.skip("Skip index Temporary")
        return candidate

D
del-zhenwu 已提交
123 124 125 126 127 128 129 130 131 132 133
    @pytest.fixture(scope="function", params=gen_simple_index())
    def get_structure_index(self, request, connect):
        '''
        Parametrized fixture for structure tests: only the FLAT configuration
        is returned, every other index type is skipped.
        '''
        candidate = request.param
        logging.getLogger().info(candidate)
        if candidate["index_type"] == IndexType.FLAT:
            return candidate
        pytest.skip("Skip index Temporary")

J
JinHai-CN 已提交
134 135 136 137 138
    """
    generate top-k params
    """
    @pytest.fixture(scope="function", params=[1, 99, 1024, 2049])
    def get_top_k(self, request):
        '''
        Parametrized fixture yielding each top-k value under test.
        '''
        yield request.param


X
Xiaohai Xu 已提交
145
    def test_search_top_k_flat_index(self, connect, collection, get_top_k):
        '''
        target: test basic search function with valid search params while varying top-k
        method: insert vectors, then search with the first inserted vector
        expected: for top_k <= 2048 the status is ok and top_k results come back;
                  for larger top_k the status is not ok
        '''
        top_k = get_top_k
        vectors, ids = self.init_data(connect, collection)
        status, result = connect.search(collection, top_k, [vectors[0]])
        if top_k > 2048:
            assert not status.OK()
            return
        assert status.OK()
        hits = result[0]
        assert len(hits) == min(len(vectors), top_k)
        assert hits[0].distance <= epsilon
        assert check_result(hits, ids[0])

X
Xiaohai Xu 已提交
163
    def test_search_l2_index_params(self, connect, collection, get_simple_index):
        '''
        target: test basic search function with valid search params for every simple index type
        method: insert vectors, build the index, search with the first inserted vector
        expected: search status ok, top_k results returned, check_result passes for
                  the first inserted id and the nearest distance is within epsilon
        '''
        top_k = 10
        index_param = get_simple_index["index_param"]
        index_type = get_simple_index["index_type"]
        logging.getLogger().info(get_simple_index)
        if index_type == IndexType.IVF_PQ:
            pytest.skip("Skip PQ")

        vectors, ids = self.init_data(connect, collection)
        connect.create_index(collection, index_type, index_param)
        query_vec = [vectors[0]]
        search_param = get_search_param(index_type)
        status, result = connect.search(collection, top_k, query_vec, params=search_param)
        logging.getLogger().info(result)
        # top_k is fixed at 10 in this test, so only the success path applies.
        assert status.OK()
        assert len(result[0]) == min(len(vectors), top_k)
        assert check_result(result[0], ids[0])
        assert result[0][0].distance <= epsilon

X
Xiaohai Xu 已提交
190
    def test_search_l2_large_nq_index_params(self, connect, collection, get_simple_index):
        '''
        target: test search with a large number of query vectors for every simple index type
        method: insert vectors, build the index, search with the first 1000 inserted vectors
        expected: search status ok; for the first query top_k results are returned,
                  check_result passes and the nearest distance is within epsilon
        '''
        top_k = 10
        index_type = get_simple_index["index_type"]
        index_param = get_simple_index["index_param"]
        logging.getLogger().info(get_simple_index)
        if index_type == IndexType.IVF_PQ:
            pytest.skip("Skip PQ")

        vectors, ids = self.init_data(connect, collection)
        status = connect.create_index(collection, index_type, index_param)
        queries = vectors[:1000]
        search_param = get_search_param(index_type)
        status, result = connect.search(collection, top_k, queries, params=search_param)
        logging.getLogger().info(result)
        assert status.OK()
        first_hits = result[0]
        assert len(first_hits) == min(len(vectors), top_k)
        assert check_result(first_hits, ids[0])
        assert first_hits[0].distance <= epsilon

X
Xiaohai Xu 已提交
214
    def test_search_l2_index_params_partition(self, connect, collection, get_simple_index):
        '''
        target: test search on a collection that has an empty partition, for every simple index type
        method: create a partition, insert vectors into the collection without a tag,
                build the index, search the collection and then search the partition tag
        expected: collection search is ok and returns top_k results;
                  searching with the partition tag is ok and returns an empty result
        '''
        top_k = 10
        index_type = get_simple_index["index_type"]
        index_param = get_simple_index["index_param"]
        logging.getLogger().info(get_simple_index)
        if index_type == IndexType.IVF_PQ:
            pytest.skip("Skip PQ")
        status = connect.create_partition(collection, tag)
        vectors, ids = self.init_data(connect, collection)
        status = connect.create_index(collection, index_type, index_param)
        queries = [vectors[0]]
        search_param = get_search_param(index_type)

        # Search over the whole collection: the query vector itself must be found.
        status, result = connect.search(collection, top_k, queries, params=search_param)
        logging.getLogger().info(result)
        assert status.OK()
        hits = result[0]
        assert len(hits) == min(len(vectors), top_k)
        assert check_result(hits, ids[0])
        assert hits[0].distance <= epsilon

        # Search restricted to the partition that received no vectors.
        status, result = connect.search(collection, top_k, queries, partition_tags=[tag], params=search_param)
        logging.getLogger().info(result)
        assert status.OK()
        assert len(result) == 0

X
Xiaohai Xu 已提交
242
    def test_search_l2_index_params_partition_A(self, connect, collection, get_simple_index):
        '''
        target: test search restricted to a partition that received no vectors
        method: create a partition, insert vectors without a tag, build the index,
                search with the partition tag
        expected: search status ok, and the length of the result is 0
        '''
        top_k = 10
        index_type = get_simple_index["index_type"]
        index_param = get_simple_index["index_param"]
        logging.getLogger().info(get_simple_index)
        if index_type == IndexType.IVF_PQ:
            pytest.skip("Skip PQ")

        status = connect.create_partition(collection, tag)
        vectors, ids = self.init_data(connect, collection)
        status = connect.create_index(collection, index_type, index_param)
        queries = [vectors[0]]
        search_param = get_search_param(index_type)
        status, result = connect.search(collection, top_k, queries, partition_tags=[tag], params=search_param)
        logging.getLogger().info(result)
        assert status.OK()
        assert len(result) == 0

X
Xiaohai Xu 已提交
265
    def test_search_l2_index_params_partition_B(self, connect, collection, get_simple_index):
        '''
        target: test search when all vectors live in one partition
        method: create a partition, insert vectors with its tag, build the index,
                search the whole collection and then search with the partition tag
        expected: both searches are ok, return top_k results, and find the query vector
        '''
        top_k = 10
        index_type = get_simple_index["index_type"]
        index_param = get_simple_index["index_param"]
        logging.getLogger().info(get_simple_index)
        if index_type == IndexType.IVF_PQ:
            pytest.skip("Skip PQ")
        status = connect.create_partition(collection, tag)
        vectors, ids = self.init_data(connect, collection, partition_tags=tag)
        status = connect.create_index(collection, index_type, index_param)
        queries = [vectors[0]]
        search_param = get_search_param(index_type)
        # First pass searches the whole collection, second pass only the partition.
        for scope in ({}, {"partition_tags": [tag]}):
            status, result = connect.search(collection, top_k, queries, params=search_param, **scope)
            logging.getLogger().info(result)
            assert status.OK()
            assert len(result[0]) == min(len(vectors), top_k)
            assert check_result(result[0], ids[0])
            assert result[0][0].distance <= epsilon
Z
zhenwu 已提交
294

X
Xiaohai Xu 已提交
295
    def test_search_l2_index_params_partition_C(self, connect, collection, get_simple_index):
        '''
        target: test search with a tag list in which one tag does not exist in the collection
        method: create a partition, insert vectors with its tag, build the index,
                search with [existing tag, "new_tag"]
        expected: search status ok, top_k results returned, and the query vector is found
        '''
        top_k = 10
        index_type = get_simple_index["index_type"]
        index_param = get_simple_index["index_param"]
        logging.getLogger().info(get_simple_index)
        if index_type == IndexType.IVF_PQ:
            pytest.skip("Skip PQ")
        status = connect.create_partition(collection, tag)
        vectors, ids = self.init_data(connect, collection, partition_tags=tag)
        status = connect.create_index(collection, index_type, index_param)
        queries = [vectors[0]]
        search_param = get_search_param(index_type)
        searched_tags = [tag, "new_tag"]
        status, result = connect.search(collection, top_k, queries, partition_tags=searched_tags, params=search_param)
        logging.getLogger().info(result)
        assert status.OK()
        hits = result[0]
        assert len(hits) == min(len(vectors), top_k)
        assert check_result(hits, ids[0])
        assert hits[0].distance <= epsilon
Z
zhenwu 已提交
318

D
del-zhenwu 已提交
319
    @pytest.mark.level(2)
    def test_search_l2_index_params_partition_D(self, connect, collection, get_simple_index):
        '''
        target: test search with a partition tag that does not exist in the collection
        method: create a partition, insert vectors with its tag, build the index,
                search with the single tag "new_tag", which was never created
        expected: search status is not ok
        '''
        index_param = get_simple_index["index_param"]
        index_type = get_simple_index["index_type"]
        logging.getLogger().info(get_simple_index)
        status = connect.create_partition(collection, tag)
        vectors, ids = self.init_data(connect, collection, partition_tags=tag)
        status = connect.create_index(collection, index_type, index_param)
        query_vec = [vectors[0]]
        top_k = 10
        search_param = get_search_param(index_type)
        status, result = connect.search(collection, top_k, query_vec, partition_tags=["new_tag"], params=search_param)
        logging.getLogger().info(result)
        assert not status.OK()
Z
zhenwu 已提交
338

D
del-zhenwu 已提交
339
    @pytest.mark.level(2)
    def test_search_l2_index_params_partition_E(self, connect, collection, get_simple_index):
        '''
        target: test search across two populated partitions
        method: create two partitions, insert a different batch into each, build the index,
                search with one vector from each batch, first over both tags and then over
                the second tag only
        expected: both searches are ok; over both tags each query finds itself,
                  over the second tag only the second query finds itself
        '''
        top_k = 10
        new_tag = "new_tag"
        index_type = get_simple_index["index_type"]
        index_param = get_simple_index["index_param"]
        if index_type == IndexType.IVF_PQ:
            pytest.skip("Skip PQ")
        logging.getLogger().info(get_simple_index)
        for partition in (tag, new_tag):
            status = connect.create_partition(collection, partition)
        vectors, ids = self.init_data(connect, collection, partition_tags=tag)
        new_vectors, new_ids = self.init_data(connect, collection, nb=6001, partition_tags=new_tag)
        status = connect.create_index(collection, index_type, index_param)
        queries = [vectors[0], new_vectors[0]]
        search_param = get_search_param(index_type)

        # Both partitions searched: each query vector must be its own nearest hit.
        status, result = connect.search(collection, top_k, queries, partition_tags=[tag, new_tag], params=search_param)
        logging.getLogger().info(result)
        assert status.OK()
        first_hits, second_hits = result[0], result[1]
        assert len(first_hits) == min(len(vectors), top_k)
        assert check_result(first_hits, ids[0])
        assert check_result(second_hits, new_ids[0])
        assert first_hits[0].distance <= epsilon
        assert second_hits[0].distance <= epsilon

        # Only the second partition searched: checked for the second query.
        status, result = connect.search(collection, top_k, queries, partition_tags=[new_tag], params=search_param)
        logging.getLogger().info(result)
        assert status.OK()
        assert len(result[0]) == min(len(vectors), top_k)
        assert check_result(result[1], new_ids[0])
        assert result[1][0].distance <= epsilon
Z
zhenwu 已提交
374

X
Xiaohai Xu 已提交
375
    def test_search_l2_index_params_partition_F(self, connect, collection, get_simple_index):
        '''
        target: test search with partition tags given as regular expressions
        method: create partitions "atag" and "new_tag", insert a different batch into each,
                build the index, search with one vector from each batch using the tag
                patterns "new(.*)" and "(.*)tag"
        expected: both searches are ok; with "new(.*)" only the second query has a nearest
                  distance within epsilon, with "(.*)tag" both queries do
        '''
        tag = "atag"
        new_tag = "new_tag"
        top_k = 10
        index_type = get_simple_index["index_type"]
        index_param = get_simple_index["index_param"]
        logging.getLogger().info(get_simple_index)
        if index_type == IndexType.IVF_PQ:
            pytest.skip("Skip PQ")
        for partition in (tag, new_tag):
            status = connect.create_partition(collection, partition)
        vectors, ids = self.init_data(connect, collection, partition_tags=tag)
        new_vectors, new_ids = self.init_data(connect, collection, nb=6001, partition_tags=new_tag)
        status = connect.create_index(collection, index_type, index_param)
        queries = [vectors[0], new_vectors[0]]
        search_param = get_search_param(index_type)

        status, result = connect.search(collection, top_k, queries, partition_tags=["new(.*)"], params=search_param)
        logging.getLogger().info(result)
        assert status.OK()
        assert result[0][0].distance > epsilon
        assert result[1][0].distance <= epsilon

        status, result = connect.search(collection, top_k, queries, partition_tags=["(.*)tag"], params=search_param)
        logging.getLogger().info(result)
        assert status.OK()
        for hits in (result[0], result[1]):
            assert hits[0].distance <= epsilon
Z
zhenwu 已提交
406

D
del-zhenwu 已提交
407
    @pytest.mark.level(2)
    def test_search_ip_index_params(self, connect, ip_collection, get_simple_index):
        '''
        target: test basic search on an inner-product collection for every simple index type
        method: insert vectors, build the index, search with the first inserted vector
        expected: search status ok, top_k results returned, check_result passes and the
                  best distance is at least 1 minus the allowed inaccuracy
        '''
        top_k = 10
        index_type = get_simple_index["index_type"]
        index_param = get_simple_index["index_param"]
        logging.getLogger().info(get_simple_index)
        if index_type in [IndexType.RNSG, IndexType.IVF_PQ]:
            pytest.skip("rnsg not support in ip, skip pq")

        vectors, ids = self.init_data(connect, ip_collection)
        status = connect.create_index(ip_collection, index_type, index_param)
        queries = [vectors[0]]
        search_param = get_search_param(index_type)
        status, result = connect.search(ip_collection, top_k, queries, params=search_param)
        logging.getLogger().info(result)
        assert status.OK()
        hits = result[0]
        assert len(hits) == min(len(vectors), top_k)
        assert check_result(hits, ids[0])
        assert hits[0].distance >= 1 - gen_inaccuracy(hits[0].distance)
431

X
Xiaohai Xu 已提交
432
    def test_search_ip_large_nq_index_params(self, connect, ip_collection, get_simple_index):
        '''
        target: test inner-product search with a large number of query vectors
        method: insert vectors, build the index, search with the first 1200 inserted vectors
        expected: search status ok; for the first query top_k results are returned,
                  check_result passes and the best distance is at least 1 minus the
                  allowed inaccuracy
        '''
        index_type = get_simple_index["index_type"]
        index_param = get_simple_index["index_param"]
        logging.getLogger().info(get_simple_index)
        if index_type in [IndexType.RNSG, IndexType.IVF_PQ]:
            pytest.skip("rnsg not support in ip, skip pq")
        vectors, ids = self.init_data(connect, ip_collection)
        status = connect.create_index(ip_collection, index_type, index_param)
        queries = [vectors[pos] for pos in range(1200)]
        top_k = 10
        search_param = get_search_param(index_type)
        status, result = connect.search(ip_collection, top_k, queries, params=search_param)
        logging.getLogger().info(result)
        assert status.OK()
        first_hits = result[0]
        assert len(first_hits) == min(len(vectors), top_k)
        assert check_result(first_hits, ids[0])
        assert first_hits[0].distance >= 1 - gen_inaccuracy(first_hits[0].distance)
J
JinHai-CN 已提交
456

D
del-zhenwu 已提交
457
    @pytest.mark.level(2)
    def test_search_ip_index_params_partition(self, connect, ip_collection, get_simple_index):
        '''
        target: test inner-product search on a collection that has an empty partition
        method: create a partition, insert vectors without a tag, build the index,
                search the collection and then search with the partition tag
        expected: collection search is ok and returns top_k results;
                  searching with the partition tag is ok and returns an empty result
        '''
        top_k = 10
        index_type = get_simple_index["index_type"]
        index_param = get_simple_index["index_param"]
        logging.getLogger().info(index_param)
        if index_type in [IndexType.RNSG, IndexType.IVF_PQ]:
            pytest.skip("rnsg not support in ip, skip pq")

        status = connect.create_partition(ip_collection, tag)
        vectors, ids = self.init_data(connect, ip_collection)
        status = connect.create_index(ip_collection, index_type, index_param)
        queries = [vectors[0]]
        search_param = get_search_param(index_type)

        status, result = connect.search(ip_collection, top_k, queries, params=search_param)
        logging.getLogger().info(result)
        assert status.OK()
        hits = result[0]
        assert len(hits) == min(len(vectors), top_k)
        assert check_result(hits, ids[0])
        assert hits[0].distance >= 1 - gen_inaccuracy(hits[0].distance)

        # The partition received no vectors, so searching it yields nothing.
        status, result = connect.search(ip_collection, top_k, queries, partition_tags=[tag], params=search_param)
        logging.getLogger().info(result)
        assert status.OK()
        assert len(result) == 0

D
del-zhenwu 已提交
487
    @pytest.mark.level(2)
    def test_search_ip_index_params_partition_A(self, connect, ip_collection, get_simple_index):
        '''
        target: test inner-product search restricted to a populated partition
        method: create a partition, insert vectors with its tag, build the index,
                search with the partition tag
        expected: search status ok, top_k results returned, check_result passes and the
                  best distance is at least 1 minus the allowed inaccuracy
        '''
        top_k = 10
        index_type = get_simple_index["index_type"]
        index_param = get_simple_index["index_param"]
        logging.getLogger().info(index_param)
        if index_type in [IndexType.RNSG, IndexType.IVF_PQ]:
            pytest.skip("rnsg not support in ip, skip pq")

        status = connect.create_partition(ip_collection, tag)
        vectors, ids = self.init_data(connect, ip_collection, partition_tags=tag)
        status = connect.create_index(ip_collection, index_type, index_param)
        queries = [vectors[0]]
        search_param = get_search_param(index_type)
        status, result = connect.search(ip_collection, top_k, queries, partition_tags=[tag], params=search_param)
        logging.getLogger().info(result)
        assert status.OK()
        hits = result[0]
        assert len(hits) == min(len(vectors), top_k)
        assert check_result(hits, ids[0])
        assert hits[0].distance >= 1 - gen_inaccuracy(hits[0].distance)
Z
zhenwu 已提交
512

D
del-zhenwu 已提交
513 514 515 516 517 518 519 520 521 522 523
    @pytest.mark.level(2)
    def test_search_vectors_without_connect(self, dis_connect, collection):
        '''
        target: test search vectors without connection
        method: call search on the disconnected instance with one module-level vector
        expected: raise exception
        '''
        query_vectors = [vectors[0]]
        with pytest.raises(Exception):
            dis_connect.search(collection, top_k, query_vectors)
J
JinHai-CN 已提交
524

X
Xiaohai Xu 已提交
525
    def test_search_collection_name_not_existed(self, connect, collection):
        '''
        target: search a collection that does not exist
        method: search with a freshly generated collection name
        expected: status not ok
        '''
        collection_name = gen_unique_str("not_existed_collection")
        query_vecs = [vectors[0]]
        status, result = connect.search(collection_name, top_k, query_vecs)
        assert not status.OK()

X
Xiaohai Xu 已提交
537
    def test_search_collection_name_None(self, connect, collection):
        '''
        target: search with a collection name of None
        method: call search passing None as the collection name
        expected: raise exception
        '''
        collection_name = None
        query_vecs = [vectors[0]]
        with pytest.raises(Exception):
            connect.search(collection_name, top_k, query_vecs)
J
JinHai-CN 已提交
548

X
Xiaohai Xu 已提交
549
    def test_search_top_k_query_records(self, connect, collection):
        '''
        target: test search function with several query records
        method: search with three of the inserted vectors (indexes 0, 55 and 99)
        expected: status ok, one result list per query, each holding top_k hits
                  whose nearest distance is within epsilon
        '''
        top_k = 10
        vectors, ids = self.init_data(connect, collection)
        query_vecs = [vectors[pos] for pos in (0, 55, 99)]
        status, result = connect.search(collection, top_k, query_vecs)
        assert status.OK()
        assert len(result) == len(query_vecs)
        for pos, _ in enumerate(query_vecs):
            hits = result[pos]
            assert len(hits) == top_k
            assert hits[0].distance <= epsilon

X
Xiaohai Xu 已提交
565
    def test_search_distance_l2_flat_index(self, connect, collection):
        '''
        target: search collection, and check the returned distance
        method: insert two vectors, search with a constant 0.5 vector and compare the
                square root of the returned distance with the smaller Euclidean norm
                computed by numpy
        expected: the two values differ by no more than gen_inaccuracy allows
        '''
        nb = 2
        vectors, ids = self.init_data(connect, collection, nb=nb)
        query_vecs = [[0.50] * dim]
        probe = numpy.array(query_vecs[0])
        distance_0 = numpy.linalg.norm(probe - numpy.array(vectors[0]))
        distance_1 = numpy.linalg.norm(probe - numpy.array(vectors[1]))
        nearest = min(distance_0, distance_1)
        status, result = connect.search(collection, top_k, query_vecs)
        top_hit = result[0][0]
        assert abs(numpy.sqrt(top_hit.distance) - nearest) <= gen_inaccuracy(top_hit.distance)

X
Xiaohai Xu 已提交
579
    def test_search_distance_ip_flat_index(self, connect, ip_collection):
        '''
        target: search ip_collection, and check the result: distance
        method: compare the returned distance with the inner product computed by numpy
        expected: the returned distance equals the larger computed inner product,
                  within gen_inaccuracy tolerance
        '''
        nb = 2
        inserted, _ = self.init_data(connect, ip_collection, nb=nb)
        index_type = IndexType.FLAT
        index_param = {
            "nlist": 16384
        }
        connect.create_index(ip_collection, index_type, index_param)
        logging.getLogger().info(connect.get_index_info(ip_collection))
        query_vecs = [[0.50] * dim]
        query = numpy.array(query_vecs[0])
        expected = max(
            numpy.inner(query, numpy.array(inserted[0])),
            numpy.inner(query, numpy.array(inserted[1])),
        )
        search_param = get_search_param(index_type)
        status, result = connect.search(ip_collection, top_k, query_vecs, params=search_param)
        assert status.OK()
        best = result[0][0].distance
        assert abs(best - expected) <= gen_inaccuracy(best)

X
Xiaohai Xu 已提交
601
    def test_search_distance_jaccard_flat_index(self, connect, jac_collection):
        '''
        target: search jac_collection, and check the result: distance
        method: compare the returned distance with the value computed by the jaccard helper
        expected: the returned distance equals the smaller computed value, within epsilon
        '''
        int_vectors, _, _ = self.init_binary_data(connect, jac_collection, nb=2)
        index_type = IndexType.FLAT
        index_param = {
            "nlist": 16384
        }
        connect.create_index(jac_collection, index_type, index_param)
        logging.getLogger().info(connect.get_collection_info(jac_collection))
        logging.getLogger().info(connect.get_index_info(jac_collection))
        # insert=False: generate a query vector without adding it to the collection.
        query_int_vectors, query_vecs, _ = self.init_binary_data(connect, jac_collection, nb=1, insert=False)
        distance_0 = jaccard(query_int_vectors[0], int_vectors[0])
        distance_1 = jaccard(query_int_vectors[0], int_vectors[1])
        search_param = get_search_param(index_type)
        status, result = connect.search(jac_collection, top_k, query_vecs, params=search_param)
        logging.getLogger().info(status)
        logging.getLogger().info(result)
        assert status.OK()
        assert abs(result[0][0].distance - min(distance_0, distance_1)) <= epsilon

X
Xiaohai Xu 已提交
626
    def test_search_distance_hamming_flat_index(self, connect, ham_collection):
        '''
        target: search ham_collection, and check the result: distance
        method: compare the returned distance with the value computed by the hamming helper
        expected: the returned distance equals the smaller computed value, within epsilon
        '''
        int_vectors, _, _ = self.init_binary_data(connect, ham_collection, nb=2)
        index_type = IndexType.FLAT
        index_param = {
            "nlist": 16384
        }
        connect.create_index(ham_collection, index_type, index_param)
        logging.getLogger().info(connect.get_collection_info(ham_collection))
        logging.getLogger().info(connect.get_index_info(ham_collection))
        # insert=False: generate a query vector without adding it to the collection.
        query_int_vectors, query_vecs, _ = self.init_binary_data(connect, ham_collection, nb=1, insert=False)
        distance_0 = hamming(query_int_vectors[0], int_vectors[0])
        distance_1 = hamming(query_int_vectors[0], int_vectors[1])
        search_param = get_search_param(index_type)
        status, result = connect.search(ham_collection, top_k, query_vecs, params=search_param)
        logging.getLogger().info(status)
        logging.getLogger().info(result)
        assert status.OK()
        expected = min(distance_0, distance_1).astype(float)
        assert abs(result[0][0].distance - expected) <= epsilon

D
del-zhenwu 已提交
651 652 653 654 655 656 657 658 659 660 661 662 663 664
    def test_search_distance_substructure_flat_index(self, connect, substructure_collection):
        '''
        target: search substructure_collection with a freshly generated query vector
        method: insert 2 binary vectors, then search with 1 generated vector that was not inserted
        expected: the result row for the query contains no hits
        '''
        self.init_binary_data(connect, substructure_collection, nb=2)
        index_type = IndexType.FLAT
        index_param = {
            "nlist": 16384
        }
        connect.create_index(substructure_collection, index_type, index_param)
        logging.getLogger().info(connect.get_collection_info(substructure_collection))
        logging.getLogger().info(connect.get_index_info(substructure_collection))
        # insert=False: generate a query vector without adding it to the collection.
        _, query_vecs, _ = self.init_binary_data(connect, substructure_collection, nb=1, insert=False)
        search_param = get_search_param(index_type)
        status, result = connect.search(substructure_collection, top_k, query_vecs, params=search_param)
        logging.getLogger().info(status)
        logging.getLogger().info(result)
        assert len(result[0]) == 0
675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690

    def test_search_distance_substructure_flat_index_B(self, connect, substructure_collection):
        '''
        target: search substructure_collection, and check the result: ids and distance
        method: search with query vectors produced by gen_binary_sub_vectors from the inserted vectors
        expected: each query returns exactly one hit, the matching inserted id, with distance <= epsilon
        '''
        top_k = 3
        int_vectors, _, ids = self.init_binary_data(connect, substructure_collection, nb=2)
        index_type = IndexType.FLAT
        index_param = {
            "nlist": 16384
        }
        connect.create_index(substructure_collection, index_type, index_param)
        logging.getLogger().info(connect.get_collection_info(substructure_collection))
        logging.getLogger().info(connect.get_index_info(substructure_collection))
        _, query_vecs = gen_binary_sub_vectors(int_vectors, 2)
        search_param = get_search_param(index_type)
        status, result = connect.search(substructure_collection, top_k, query_vecs, params=search_param)
        logging.getLogger().info(status)
        logging.getLogger().info(result)
        assert status.OK()
        assert len(result[0]) == 1
        assert len(result[1]) == 1
        assert result[0][0].distance <= epsilon
        assert result[0][0].id == ids[0]
        assert result[1][0].distance <= epsilon
        assert result[1][0].id == ids[1]
D
del-zhenwu 已提交
704 705 706 707 708 709 710 711 712 713 714 715 716 717 718

    def test_search_distance_superstructure_flat_index(self, connect, superstructure_collection):
        '''
        target: search superstructure_collection with a freshly generated query vector
        method: insert 2 binary vectors, then search with 1 generated vector that was not inserted
        expected: the result row for the query contains no hits
        '''
        self.init_binary_data(connect, superstructure_collection, nb=2)
        index_type = IndexType.FLAT
        index_param = {
            "nlist": 16384
        }
        connect.create_index(superstructure_collection, index_type, index_param)
        logging.getLogger().info(connect.get_collection_info(superstructure_collection))
        logging.getLogger().info(connect.get_index_info(superstructure_collection))
        # insert=False: generate a query vector without adding it to the collection.
        _, query_vecs, _ = self.init_binary_data(connect, superstructure_collection, nb=1, insert=False)
        search_param = get_search_param(index_type)
        status, result = connect.search(superstructure_collection, top_k, query_vecs, params=search_param)
        logging.getLogger().info(status)
        logging.getLogger().info(result)
        assert len(result[0]) == 0
729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744

    def test_search_distance_superstructure_flat_index_B(self, connect, superstructure_collection):
        '''
        target: search superstructure_collection, and check the result: ids and distance
        method: search with query vectors produced by gen_binary_super_vectors from the inserted vectors
        expected: each query returns two hits; the first hit is an inserted id with distance <= epsilon
        '''
        top_k = 3
        int_vectors, _, ids = self.init_binary_data(connect, superstructure_collection, nb=2)
        index_type = IndexType.FLAT
        index_param = {
            "nlist": 16384
        }
        connect.create_index(superstructure_collection, index_type, index_param)
        logging.getLogger().info(connect.get_collection_info(superstructure_collection))
        logging.getLogger().info(connect.get_index_info(superstructure_collection))
        _, query_vecs = gen_binary_super_vectors(int_vectors, 2)
        search_param = get_search_param(index_type)
        status, result = connect.search(superstructure_collection, top_k, query_vecs, params=search_param)
        logging.getLogger().info(status)
        logging.getLogger().info(result)
        assert status.OK()
        assert len(result[0]) == 2
        assert len(result[1]) == 2
        assert result[0][0].id in ids
        assert result[0][0].distance <= epsilon
        assert result[1][0].id in ids
        assert result[1][0].distance <= epsilon
D
del-zhenwu 已提交
758

X
Xiaohai Xu 已提交
759
    def test_search_distance_tanimoto_flat_index(self, connect, tanimoto_collection):
        '''
        target: search tanimoto_collection, and check the result: distance
        method: compare the returned distance with the value computed by the tanimoto helper
        expected: the returned distance equals the smaller computed value, within epsilon
        '''
        int_vectors, _, _ = self.init_binary_data(connect, tanimoto_collection, nb=2)
        index_type = IndexType.FLAT
        index_param = {
            "nlist": 16384
        }
        connect.create_index(tanimoto_collection, index_type, index_param)
        logging.getLogger().info(connect.get_collection_info(tanimoto_collection))
        logging.getLogger().info(connect.get_index_info(tanimoto_collection))
        # insert=False: generate a query vector without adding it to the collection.
        query_int_vectors, query_vecs, _ = self.init_binary_data(connect, tanimoto_collection, nb=1, insert=False)
        distance_0 = tanimoto(query_int_vectors[0], int_vectors[0])
        distance_1 = tanimoto(query_int_vectors[0], int_vectors[1])
        search_param = get_search_param(index_type)
        status, result = connect.search(tanimoto_collection, top_k, query_vecs, params=search_param)
        logging.getLogger().info(status)
        logging.getLogger().info(result)
        assert status.OK()
        assert abs(result[0][0].distance - min(distance_0, distance_1)) <= epsilon

X
Xiaohai Xu 已提交
784
    def test_search_distance_ip_index_params(self, connect, ip_collection, get_index):
        '''
        target: search ip_collection under each index from get_index, and check the result: distance
        method: compare the returned distance with the inner product computed by numpy
        expected: the returned distance equals the larger computed inner product,
                  within gen_inaccuracy tolerance
        '''
        top_k = 2
        index_param = get_index["index_param"]
        index_type = get_index["index_type"]
        if index_type == IndexType.RNSG:
            pytest.skip("rnsg not support in ip")
        inserted, _ = self.init_data(connect, ip_collection, nb=2)
        connect.create_index(ip_collection, index_type, index_param)
        logging.getLogger().info(connect.get_index_info(ip_collection))
        query_vecs = [[0.50] * dim]
        search_param = get_search_param(index_type)
        status, result = connect.search(ip_collection, top_k, query_vecs, params=search_param)
        logging.getLogger().debug(status)
        logging.getLogger().debug(result)
        assert status.OK()
        query = numpy.array(query_vecs[0])
        expected = max(
            numpy.inner(query, numpy.array(inserted[0])),
            numpy.inner(query, numpy.array(inserted[1])),
        )
        best = result[0][0].distance
        assert abs(best - expected) <= gen_inaccuracy(best)

    # TODO: enable
    # @pytest.mark.repeat(5)
    @pytest.mark.timeout(30)
    def _test_search_concurrent(self, connect, collection):
        '''
        target: test concurrent search through one shared connection
        method: start 10 threads that all search the same query vectors
        expected: every thread gets one result row per query, whose best hit is an
                  inserted id with distance <= epsilon
        '''
        inserted, ids = self.init_data(connect, collection)
        thread_num = 10
        nb = 100
        top_k = 10
        query_vecs = inserted[nb//2:nb]
        # An exception raised inside a worker thread does not reach pytest on its
        # own, so workers record failures here and the main thread re-raises.
        failures = []

        def search():
            try:
                status, result = connect.search(collection, top_k, query_vecs)
                assert len(result) == len(query_vecs)
                for i in range(len(query_vecs)):
                    assert result[i][0].id in ids
                    assert result[i][0].distance <= epsilon
            except Exception as exc:
                failures.append(exc)

        threads = []
        for _ in range(thread_num):
            worker = threading.Thread(target=search, args=())
            threads.append(worker)
            worker.start()
        for worker in threads:
            worker.join()
        if failures:
            raise failures[0]

D
del-zhenwu 已提交
831
    @pytest.mark.level(2)
    @pytest.mark.timeout(30)
    def test_search_concurrent_multithreads(self, args):
        '''
        target: test concurrent search with multiple threads
        method: search with 4 threads, each thread uses its own connection
        expected: every thread gets one result row per query, whose best hit is an
                  inserted id with distance <= epsilon
        '''
        nb = 100
        top_k = 10
        threads_num = 4
        collection = gen_unique_str("test_search_concurrent_multiprocessing")
        param = {'collection_name': collection,
                 'dimension': dim,
                 'index_type': IndexType.FLAT,
                 'store_raw_vector': False}
        # create collection
        milvus = get_milvus(args["ip"], args["port"], handler=args["handler"])
        milvus.create_collection(param)
        inserted, ids = self.init_data(milvus, collection, nb=nb)
        query_vecs = inserted[nb//2:nb]
        # An exception raised inside a worker thread does not reach pytest on its
        # own, so workers record failures here and the main thread re-raises.
        failures = []

        def search(client):
            try:
                status, result = client.search(collection, top_k, query_vecs)
                assert len(result) == len(query_vecs)
                for i in range(len(query_vecs)):
                    assert result[i][0].id in ids
                    assert result[i][0].distance <= epsilon
            except Exception as exc:
                failures.append(exc)

        threads = []
        for _ in range(threads_num):
            client = get_milvus(args["ip"], args["port"], handler=args["handler"])
            worker = threading.Thread(target=search, args=(client, ))
            threads.append(worker)
            worker.start()
            # `sleep` is the name this file imports from `time`.
            sleep(0.2)
        for worker in threads:
            worker.join()
        if failures:
            raise failures[0]

J
JinHai-CN 已提交
870 871 872 873 874 875 876 877 878 879 880 881
    # TODO: enable
    @pytest.mark.timeout(30)
    def _test_search_concurrent_multiprocessing(self, args):
        '''
        target: test concurrent search with multiple processes
        method: search with 4 processes, each process uses its own connection
        expected: every process gets one result row per query, whose best hit is an
                  inserted id with distance <= epsilon, and exits with code 0
        '''
        nb = 100
        top_k = 10
        process_num = 4
        collection = gen_unique_str("test_search_concurrent_multiprocessing")
        param = {'collection_name': collection,
                 'dimension': dim,
                 'index_type': IndexType.FLAT,
                 'store_raw_vector': False}
        # create collection
        milvus = get_milvus(args["ip"], args["port"], handler=args["handler"])
        milvus.create_collection(param)
        inserted, ids = self.init_data(milvus, collection, nb=nb)
        query_vecs = inserted[nb//2:nb]

        def search(client):
            status, result = client.search(collection, top_k, query_vecs)
            assert len(result) == len(query_vecs)
            for i in range(len(query_vecs)):
                assert result[i][0].id in ids
                assert result[i][0].distance <= epsilon

        processes = []
        for _ in range(process_num):
            client = get_milvus(args["ip"], args["port"], handler=args["handler"])
            proc = Process(target=search, args=(client, ))
            processes.append(proc)
            proc.start()
            # `sleep` is the name this file imports from `time`.
            sleep(0.2)
        for proc in processes:
            proc.join()
            # A failed assertion in the child only shows up as a non-zero exit code.
            assert proc.exitcode == 0

X
Xiaohai Xu 已提交
909
    def test_search_multi_collection_L2(self, args):
        '''
        target: test search multi collections of L2
        method: add vectors into 10 collections, and search
        expected: search status ok, top_k hits per query, and check_result finds each
                  query's own id among the leading hits
        '''
        num = 10
        top_k = 10
        # Rows of the module-level `vectors` that are used as queries.
        probe_rows = (0, 10, 20)
        milvus = get_milvus(args["ip"], args["port"], handler=args["handler"])
        collections = []
        expected_ids = []
        for i in range(num):
            collection = gen_unique_str("test_add_multicollection_%d" % i)
            param = {'collection_name': collection,
                     'dimension': dim,
                     'index_file_size': 10,
                     'metric_type': MetricType.L2}
            # create collection
            milvus.create_collection(param)
            status, ids = milvus.insert(collection, vectors)
            assert status.OK()
            assert len(ids) == len(vectors)
            collections.append(collection)
            expected_ids.append([ids[row] for row in probe_rows])
            milvus.flush([collection])
        query_vecs = [vectors[row] for row in probe_rows]
        # query every collection in turn
        for collection, wanted in zip(collections, expected_ids):
            status, result = milvus.search(collection, top_k, query_vecs)
            assert status.OK()
            assert len(result) == len(query_vecs)
            for j, wanted_id in enumerate(wanted):
                assert len(result[j]) == top_k
                assert check_result(result[j], wanted_id)

X
Xiaohai Xu 已提交
949
    def test_search_multi_collection_IP(self, args):
        '''
        target: test search multi collections of IP
        method: add vectors into 10 collections, and search
        expected: search status ok, top_k hits per query, and check_result finds each
                  query's own id among the leading hits
        '''
        num = 10
        top_k = 10
        # Rows of the module-level `vectors` that are used as queries.
        probe_rows = (0, 10, 20)
        milvus = get_milvus(args["ip"], args["port"], handler=args["handler"])
        collections = []
        expected_ids = []
        for i in range(num):
            collection = gen_unique_str("test_add_multicollection_%d" % i)
            # This is the IP variant, so the collections use the inner-product metric.
            param = {'collection_name': collection,
                     'dimension': dim,
                     'index_file_size': 10,
                     'metric_type': MetricType.IP}
            # create collection
            milvus.create_collection(param)
            status, ids = milvus.insert(collection, vectors)
            assert status.OK()
            assert len(ids) == len(vectors)
            collections.append(collection)
            expected_ids.append([ids[row] for row in probe_rows])
            milvus.flush([collection])
        query_vecs = [vectors[row] for row in probe_rows]
        # query every collection in turn
        for collection, wanted in zip(collections, expected_ids):
            status, result = milvus.search(collection, top_k, query_vecs)
            assert status.OK()
            assert len(result) == len(query_vecs)
            for j, wanted_id in enumerate(wanted):
                assert len(result[j]) == top_k
                assert check_result(result[j], wanted_id)
"""
******************************************************************
#  The following cases are used to test `search_vectors` function 
X
Xiaohai Xu 已提交
991
#  with invalid collection_name top-k / nprobe / query_range
J
JinHai-CN 已提交
992 993 994 995
******************************************************************
"""

class TestSearchParamsInvalid(object):
    '''
    Search cases that pass invalid collection names, partition tags, top_k,
    nprobe or search params, and check that the search is rejected.
    '''
    nlist = 16384
    index_type = IndexType.IVF_SQ8
    index_param = {"nlist": nlist}
    logging.getLogger().info(index_param)

    def init_data(self, connect, collection, nb=6000):
        '''
        Generate vectors and add it in collection, before search vectors
        '''
        if nb == 6000:
            # Reuse the module-level vectors for the default size.
            insert = vectors
        else:
            insert = gen_vectors(nb, dim)
        status, ids = connect.insert(collection, insert)
        sleep(add_interval_time)
        return insert, ids

    @staticmethod
    def _skip_unsupported_index(connect, index_type):
        '''
        Skip the current test when the server mode does not support index_type
        '''
        mode = str(connect._cmd("mode")[1])
        if mode == "CPU" and index_type == IndexType.IVF_SQ8H:
            pytest.skip("sq8h not support in CPU mode")
        if mode == "GPU" and index_type == IndexType.IVF_PQ:
            pytest.skip("ivfpq not support in GPU mode")

    # Test search collection with invalid collection names
    @pytest.fixture(
        scope="function",
        params=gen_invalid_collection_names()
    )
    def get_collection_name(self, request):
        yield request.param

    @pytest.mark.level(2)
    def test_search_with_invalid_collectionname(self, connect, get_collection_name):
        '''
        target: test search function, with an invalid collection name
        method: search with each name from gen_invalid_collection_names
        expected: status not ok
        '''
        collection_name = get_collection_name
        logging.getLogger().info(collection_name)
        query_vecs = gen_vectors(1, dim)
        status, result = connect.search(collection_name, top_k, query_vecs)
        assert not status.OK()

    @pytest.mark.level(1)
    def test_search_with_invalid_tag_format(self, connect, collection):
        '''
        target: test search function, with partition_tags given as a plain string
        method: search with partition_tags="tag"
        expected: an exception is raised
        '''
        query_vecs = gen_vectors(1, dim)
        with pytest.raises(Exception) as e:
            status, result = connect.search(collection, top_k, query_vecs, partition_tags="tag")
            logging.getLogger().debug(result)

    @pytest.mark.level(1)
    def test_search_with_tag_not_existed(self, connect, collection):
        '''
        target: test search function, with a partition tag that was never created
        method: search with partition_tags=["tag"]
        expected: status not ok
        '''
        query_vecs = gen_vectors(1, dim)
        status, result = connect.search(collection, top_k, query_vecs, partition_tags=["tag"])
        logging.getLogger().info(result)
        assert not status.OK()

    # Test search collection with invalid top-k
    @pytest.fixture(
        scope="function",
        params=gen_invalid_top_ks()
    )
    def get_top_k(self, request):
        yield request.param

    @pytest.mark.level(1)
    def test_search_with_invalid_top_k(self, connect, collection, get_top_k):
        '''
        target: test search function, with the wrong top_k
        method: search with top_k
        expected: status not ok for an int top_k, an exception for any other type
        '''
        top_k = get_top_k
        logging.getLogger().info(top_k)
        query_vecs = gen_vectors(1, dim)
        if isinstance(top_k, int):
            status, result = connect.search(collection, top_k, query_vecs)
            assert not status.OK()
        else:
            with pytest.raises(Exception) as e:
                status, result = connect.search(collection, top_k, query_vecs)

    @pytest.mark.level(2)
    def test_search_with_invalid_top_k_ip(self, connect, ip_collection, get_top_k):
        '''
        target: test search function on ip_collection, with the wrong top_k
        method: search with top_k
        expected: status not ok for an int top_k, an exception for any other type
        '''
        top_k = get_top_k
        logging.getLogger().info(top_k)
        query_vecs = gen_vectors(1, dim)
        if isinstance(top_k, int):
            status, result = connect.search(ip_collection, top_k, query_vecs)
            assert not status.OK()
        else:
            with pytest.raises(Exception) as e:
                status, result = connect.search(ip_collection, top_k, query_vecs)

    # Test search collection with invalid nprobe
    @pytest.fixture(
        scope="function",
        params=gen_invalid_nprobes()
    )
    def get_nprobes(self, request):
        yield request.param

    @pytest.mark.level(1)
    def test_search_with_invalid_nprobe(self, connect, collection, get_nprobes):
        '''
        target: test search function, with the wrong nprobe
        method: build an IVF_SQ8 index, then search with nprobe
        expected: status not ok
        '''
        index_type = IndexType.IVF_SQ8
        index_param = {"nlist": 16384}
        connect.create_index(collection, index_type, index_param)
        nprobe = get_nprobes
        search_param = {"nprobe": nprobe}
        logging.getLogger().info(nprobe)
        query_vecs = gen_vectors(1, dim)
        status, result = connect.search(collection, top_k, query_vecs, params=search_param)
        assert not status.OK()

    @pytest.mark.level(2)
    def test_search_with_invalid_nprobe_ip(self, connect, ip_collection, get_nprobes):
        '''
        target: test search function on ip_collection, with the wrong nprobe
        method: build an IVF_SQ8 index, then search with nprobe
        expected: status not ok
        '''
        index_type = IndexType.IVF_SQ8
        index_param = {"nlist": 16384}
        connect.create_index(ip_collection, index_type, index_param)
        nprobe = get_nprobes
        search_param = {"nprobe": nprobe}
        logging.getLogger().info(nprobe)
        query_vecs = gen_vectors(1, dim)
        status, result = connect.search(ip_collection, top_k, query_vecs, params=search_param)
        assert not status.OK()

    @pytest.fixture(
        scope="function",
        params=gen_simple_index()
    )
    def get_simple_index(self, request, connect):
        self._skip_unsupported_index(connect, request.param["index_type"])
        return request.param

    def test_search_with_empty_params(self, connect, collection, args, get_simple_index):
        '''
        target: test search function, with empty search params
        method: search with params={}
        expected: status ok for a FLAT index, status not ok for every other index type
        '''
        if args["handler"] == "HTTP":
            pytest.skip("skip in http mode")
        index_type = get_simple_index["index_type"]
        index_param = get_simple_index["index_param"]
        connect.create_index(collection, index_type, index_param)
        query_vecs = gen_vectors(1, dim)
        status, result = connect.search(collection, top_k, query_vecs, params={})

        if index_type == IndexType.FLAT:
            assert status.OK()
        else:
            assert not status.OK()

    @pytest.fixture(
        scope="function",
        params=gen_invaild_search_params()
    )
    def get_invalid_search_param(self, request, connect):
        self._skip_unsupported_index(connect, request.param["index_type"])
        return request.param

    def test_search_with_invalid_params(self, connect, collection, get_invalid_search_param):
        '''
        target: test search function, with invalid search params
        method: build the index matching the params' index_type, then search with params
        expected: search status not ok
        '''
        index_type = get_invalid_search_param["index_type"]
        search_param = get_invalid_search_param["search_param"]
        # Build the index with the index_param that gen_simple_index pairs with this type.
        for index in gen_simple_index():
            if index_type == index["index_type"]:
                connect.create_index(collection, index_type, index["index_param"])
        query_vecs = gen_vectors(1, dim)
        status, result = connect.search(collection, top_k, query_vecs, params=search_param)
        assert not status.OK()
J
JinHai-CN 已提交
1206 1207 1208 1209 1210

def check_result(result, id):
    '''
    Return True when `id` is the id of one of the first five hits in `result`
    (or of any hit when `result` holds fewer than five).
    '''
    return id in [result[k].id for k in range(min(len(result), 5))]