# test_collection_count.py
import pdb
2
import copy
X
Xiaohai Xu 已提交
3 4 5 6 7
import logging
import itertools
from time import sleep
import threading
from multiprocessing import Process
8 9 10
import sklearn.preprocessing

import pytest
X
Xiaohai Xu 已提交
11 12 13
from utils import *

# Number of entities generated for the module-level data sets below.
nb = 6000
dim = 128
# Partition tag shared by the partition tests.
tag = "tag"
# Prefix passed to gen_unique_str() when the tests create extra collections.
collection_id = "count_collection"
add_interval_time = 3
segment_row_count = 5000
# Module-level data built once at import time with helpers from utils.
default_fields = gen_default_fields()
entities = gen_entities(nb)
raw_vectors, binary_entities = gen_binary_entities(nb)
# NOTE(review): "fload_vector" looks like a typo of "float_vector"; it is only
# used by the disabled create_index tests - confirm against the schema
# produced by gen_default_fields() before changing it.
field_name = "fload_vector"
index_name = "index_name"


class TestCollectionCount:
    """
    Check that count_entities reports the number of entities inserted into
    a collection, with and without partitions.
    """

    @pytest.fixture(
        scope="function",
        params=[
            1,
            4000,
            6001
        ],
    )
    def insert_count(self, request):
        """
        params means different nb, the nb value may trigger merge, or not
        """
        yield request.param

    @pytest.fixture(
        scope="function",
        params=gen_simple_index()
    )
    def get_simple_index(self, request, connect):
        """
        generate valid create_index params (metric_type forced to "L2");
        index types listed by index_cpu_not_support() are skipped when the
        server reports CPU mode
        """
        cpu_mode = str(connect._cmd("mode")[1]) == "CPU"
        if cpu_mode and request.param["index_type"] in index_cpu_not_support():
            pytest.skip("sq8h not support in cpu mode")
        # Return a copy so the dict shared through parametrize is not mutated.
        index_params = dict(request.param)
        index_params["metric_type"] = "L2"
        return index_params

    def test_collection_count(self, connect, collection, insert_count):
        '''
        target: test collection rows_count is correct or not
        method: create collection and add entities in it,
            assert the value returned by count_entities method is equal to length of entities
        expected: the count is equal to the length of entities
        '''
        entities = gen_entities(insert_count)
        connect.insert(collection, entities)
        connect.flush([collection])
        assert connect.count_entities(collection) == insert_count

    def test_collection_count_partition(self, connect, collection, insert_count):
        '''
        target: test collection rows_count is correct or not
        method: create collection, create partition and add entities in it,
            assert the value returned by count_entities method is equal to length of entities
        expected: the count is equal to the length of entities
        '''
        entities = gen_entities(insert_count)
        connect.create_partition(collection, tag)
        connect.insert(collection, entities, partition_tag=tag)
        connect.flush([collection])
        assert connect.count_entities(collection) == insert_count

    def test_collection_count_multi_partitions_A(self, connect, collection, insert_count):
        '''
        target: test collection rows_count is correct or not
        method: create collection, create two partitions and add entities
            without a partition tag,
            assert the value returned by count_entities method is equal to length of entities
        expected: the count is equal to the length of entities
        '''
        new_tag = "new_tag"
        entities = gen_entities(insert_count)
        connect.create_partition(collection, tag)
        connect.create_partition(collection, new_tag)
        connect.insert(collection, entities)
        connect.flush([collection])
        assert connect.count_entities(collection) == insert_count

    def test_collection_count_multi_partitions_B(self, connect, collection, insert_count):
        '''
        target: test collection rows_count is correct or not
        method: create collection, create partitions and add entities in one of the partitions,
            assert the value returned by count_entities method is equal to length of entities
        expected: the count is equal to the length of entities
        '''
        new_tag = "new_tag"
        entities = gen_entities(insert_count)
        connect.create_partition(collection, tag)
        connect.create_partition(collection, new_tag)
        connect.insert(collection, entities, partition_tag=tag)
        connect.flush([collection])
        assert connect.count_entities(collection) == insert_count

    def test_collection_count_multi_partitions_C(self, connect, collection, insert_count):
        '''
        target: test collection rows_count is correct or not
        method: create collection, create partitions, add the entities once
            without a partition tag and once into one of the partitions,
            assert the value returned by count_entities method is twice the length of entities
        expected: the count is equal to twice the length of entities
        '''
        new_tag = "new_tag"
        entities = gen_entities(insert_count)
        connect.create_partition(collection, tag)
        connect.create_partition(collection, new_tag)
        connect.insert(collection, entities)
        connect.insert(collection, entities, partition_tag=tag)
        connect.flush([collection])
        assert connect.count_entities(collection) == insert_count * 2

    def test_collection_count_multi_partitions_D(self, connect, collection, insert_count):
        '''
        target: test collection rows_count is correct or not
        method: create collection, create partitions and add the entities into each of the two partitions,
            assert the value returned by count_entities method is twice the length of entities
        expected: the collection count is equal to twice the length of entities
        '''
        new_tag = "new_tag"
        entities = gen_entities(insert_count)
        connect.create_partition(collection, tag)
        connect.create_partition(collection, new_tag)
        connect.insert(collection, entities, partition_tag=tag)
        connect.insert(collection, entities, partition_tag=new_tag)
        connect.flush([collection])
        assert connect.count_entities(collection) == insert_count * 2

    # The leading underscore keeps pytest from collecting this test
    # (presumably disabled on purpose - confirm before re-enabling).
    def _test_collection_count_after_index_created(self, connect, collection, get_simple_index, insert_count):
        '''
        target: test count_entities, after index have been created
        method: add entities in db, and create index, then calling count_entities with correct params
        expected: the count is equal to the length of entities
        '''
        entities = gen_entities(insert_count)
        connect.insert(collection, entities)
        connect.flush([collection])
        connect.create_index(collection, field_name, get_simple_index)
        assert connect.count_entities(collection) == insert_count

    @pytest.mark.level(2)
    def test_count_without_connection(self, collection, dis_connect):
        '''
        target: test count_entities, without connection
        method: calling count_entities with correct params, with a disconnected instance
        expected: count_entities raise exception
        '''
        with pytest.raises(Exception):
            dis_connect.count_entities(collection)

    def test_collection_count_no_vectors(self, connect, collection):
        '''
        target: test collection rows_count is correct or not, if collection is empty
        method: create collection and no vectors in it,
            assert the value returned by count_entities method is equal to 0
        expected: the count is equal to 0
        '''
        assert connect.count_entities(collection) == 0


class TestCollectionCountIP:
    """
    count_entities check after building an index with the IP metric.
    """

    @pytest.fixture(
        scope="function",
        params=[
            1,
            4000,
            6001
        ],
    )
    def insert_count(self, request):
        """
        params means different nb, the nb value may trigger merge, or not
        """
        yield request.param

    @pytest.fixture(
        scope="function",
        params=gen_simple_index()
    )
    def get_simple_index(self, request, connect):
        """
        generate valid create_index params (metric_type forced to "IP");
        index types listed by index_cpu_not_support() are skipped when the
        server reports CPU mode
        """
        cpu_mode = str(connect._cmd("mode")[1]) == "CPU"
        if cpu_mode and request.param["index_type"] in index_cpu_not_support():
            pytest.skip("sq8h not support in cpu mode")
        # Return a copy so the dict shared through parametrize is not mutated.
        index_params = dict(request.param)
        index_params["metric_type"] = "IP"
        return index_params

    # The leading underscore keeps pytest from collecting this test
    # (presumably disabled on purpose - confirm before re-enabling).
    def _test_collection_count_after_index_created(self, connect, collection, get_simple_index, insert_count):
        '''
        target: test count_entities, after index have been created
        method: add entities in db, and create index, then calling count_entities with correct params
        expected: the count is equal to the length of entities
        '''
        entities = gen_entities(insert_count)
        connect.insert(collection, entities)
        connect.flush([collection])
        connect.create_index(collection, field_name, get_simple_index)
        assert connect.count_entities(collection) == insert_count


class TestCollectionCountBinary:
    """
    Check that count_entities reports the number of binary entities inserted
    into a binary collection, with and without partitions.
    """

    @pytest.fixture(
        scope="function",
        params=[
            1,
            4000,
            6001
        ],
    )
    def insert_count(self, request):
        """
        params means different nb, the nb value may trigger merge, or not
        """
        yield request.param

    @pytest.fixture(
        scope="function",
        params=gen_simple_index()
    )
    def get_jaccard_index(self, request, connect):
        """
        index params with metric_type "JACCARD"; index types outside
        binary_support() are skipped
        """
        if request.param["index_type"] not in binary_support():
            pytest.skip("Skip index")
        # Return a copy so the dict shared through parametrize is not mutated.
        index_params = dict(request.param)
        index_params["metric_type"] = "JACCARD"
        return index_params

    @pytest.fixture(
        scope="function",
        params=gen_simple_index()
    )
    def get_hamming_index(self, request, connect):
        """
        index params with metric_type "HAMMING"; index types outside
        binary_support() are skipped
        """
        if request.param["index_type"] not in binary_support():
            pytest.skip("Skip index")
        index_params = dict(request.param)
        index_params["metric_type"] = "HAMMING"
        return index_params

    @pytest.fixture(
        scope="function",
        params=gen_simple_index()
    )
    def get_substructure_index(self, request, connect):
        """
        index params with metric_type "SUBSTRUCTURE"; every index type
        except "FLAT" is skipped
        """
        if request.param["index_type"] != "FLAT":
            pytest.skip("Skip index")
        index_params = dict(request.param)
        index_params["metric_type"] = "SUBSTRUCTURE"
        return index_params

    @pytest.fixture(
        scope="function",
        params=gen_simple_index()
    )
    def get_superstructure_index(self, request, connect):
        """
        index params with metric_type "SUPERSTRUCTURE"; every index type
        except "FLAT" is skipped
        """
        if request.param["index_type"] != "FLAT":
            pytest.skip("Skip index")
        index_params = dict(request.param)
        index_params["metric_type"] = "SUPERSTRUCTURE"
        return index_params

    def test_collection_count(self, connect, binary_collection, insert_count):
        '''
        target: test collection rows_count is correct or not
        method: create collection and add entities in it,
            assert the value returned by count_entities method is equal to length of entities
        expected: the count is equal to the length of entities
        '''
        _, entities = gen_binary_entities(insert_count)
        res = connect.insert(binary_collection, entities)
        logging.getLogger().info(len(res))
        connect.flush([binary_collection])
        assert connect.count_entities(binary_collection) == insert_count

    def test_collection_count_partition(self, connect, binary_collection, insert_count):
        '''
        target: test collection rows_count is correct or not
        method: create collection, create partition and add entities in it,
            assert the value returned by count_entities method is equal to length of entities
        expected: the count is equal to the length of entities
        '''
        _, entities = gen_binary_entities(insert_count)
        connect.create_partition(binary_collection, tag)
        connect.insert(binary_collection, entities, partition_tag=tag)
        connect.flush([binary_collection])
        # Fixed: the original queried the undefined name `binary_collections`.
        assert connect.count_entities(binary_collection) == insert_count

    @pytest.mark.level(2)
    def test_collection_count_multi_partitions_A(self, connect, binary_collection, insert_count):
        '''
        target: test collection rows_count is correct or not
        method: create collection, create two partitions and add entities
            without a partition tag,
            assert the value returned by count_entities method is equal to length of entities
        expected: the count is equal to the length of entities
        '''
        new_tag = "new_tag"
        _, entities = gen_binary_entities(insert_count)
        connect.create_partition(binary_collection, tag)
        connect.create_partition(binary_collection, new_tag)
        connect.insert(binary_collection, entities)
        connect.flush([binary_collection])
        assert connect.count_entities(binary_collection) == insert_count

    @pytest.mark.level(2)
    def test_collection_count_multi_partitions_B(self, connect, binary_collection, insert_count):
        '''
        target: test collection rows_count is correct or not
        method: create collection, create partitions and add entities in one of the partitions,
            assert the value returned by count_entities method is equal to length of entities
        expected: the count is equal to the length of entities
        '''
        new_tag = "new_tag"
        _, entities = gen_binary_entities(insert_count)
        connect.create_partition(binary_collection, tag)
        connect.create_partition(binary_collection, new_tag)
        connect.insert(binary_collection, entities, partition_tag=tag)
        connect.flush([binary_collection])
        assert connect.count_entities(binary_collection) == insert_count

    def test_collection_count_multi_partitions_C(self, connect, binary_collection, insert_count):
        '''
        target: test collection rows_count is correct or not
        method: create collection, create partitions, add the entities once
            without a partition tag and once into one of the partitions,
            assert the value returned by count_entities method is twice the length of entities
        expected: the count is equal to twice the length of entities
        '''
        new_tag = "new_tag"
        _, entities = gen_binary_entities(insert_count)
        connect.create_partition(binary_collection, tag)
        connect.create_partition(binary_collection, new_tag)
        connect.insert(binary_collection, entities)
        connect.insert(binary_collection, entities, partition_tag=tag)
        connect.flush([binary_collection])
        assert connect.count_entities(binary_collection) == insert_count * 2

    @pytest.mark.level(2)
    def test_collection_count_multi_partitions_D(self, connect, binary_collection, insert_count):
        '''
        target: test collection rows_count is correct or not
        method: create collection, create partitions and add the entities into each of the two partitions,
            assert the value returned by count_entities method is twice the length of entities
        expected: the collection count is equal to twice the length of entities
        '''
        new_tag = "new_tag"
        _, entities = gen_binary_entities(insert_count)
        connect.create_partition(binary_collection, tag)
        connect.create_partition(binary_collection, new_tag)
        connect.insert(binary_collection, entities, partition_tag=tag)
        connect.insert(binary_collection, entities, partition_tag=new_tag)
        connect.flush([binary_collection])
        assert connect.count_entities(binary_collection) == insert_count * 2

    # The two index tests below are not collected by pytest because of the
    # leading underscore (presumably disabled on purpose - confirm before
    # re-enabling). They used to share one name, so the JACCARD variant was
    # silently overwritten; it now carries a distinct suffix while the
    # surviving name still refers to the HAMMING variant.
    def _test_collection_count_after_index_created_jaccard(self, connect, binary_collection, get_jaccard_index, insert_count):
        '''
        target: test count_entities, after a JACCARD index has been created
        method: add entities in db, and create index, then calling count_entities with correct params
        expected: the count is equal to the length of entities
        '''
        _, entities = gen_binary_entities(insert_count)
        connect.insert(binary_collection, entities)
        connect.flush([binary_collection])
        connect.create_index(binary_collection, field_name, get_jaccard_index)
        assert connect.count_entities(binary_collection) == insert_count

    def _test_collection_count_after_index_created(self, connect, binary_collection, get_hamming_index, insert_count):
        '''
        target: test count_entities, after a HAMMING index has been created
        method: add entities in db, and create index, then calling count_entities with correct params
        expected: the count is equal to the length of entities
        '''
        _, entities = gen_binary_entities(insert_count)
        connect.insert(binary_collection, entities)
        connect.flush([binary_collection])
        connect.create_index(binary_collection, field_name, get_hamming_index)
        assert connect.count_entities(binary_collection) == insert_count

    def test_collection_count_no_entities(self, connect, binary_collection):
        '''
        target: test collection rows_count is correct or not, if collection is empty
        method: create collection and no vectors in it,
            assert the value returned by count_entities method is equal to 0
        expected: the count is equal to 0
        '''
        assert connect.count_entities(binary_collection) == 0


class TestCollectionMultiCollections:
    """
    Check count_entities when the same data is inserted into many
    collections.
    """

    @pytest.fixture(
        scope="function",
        params=[
            1,
            4000,
            6001
        ],
    )
    def insert_count(self, request):
        """
        params means different nb, the nb value may trigger merge, or not
        """
        yield request.param

    @pytest.fixture(
        scope="function",
        params=gen_simple_index()
    )
    def get_jaccard_index(self, request, connect):
        """
        index params with metric_type "JACCARD"; index types outside
        binary_support() are skipped
        """
        if request.param["index_type"] not in binary_support():
            pytest.skip("Skip index")
        # Return a copy so the dict shared through parametrize is not mutated.
        index_params = dict(request.param)
        index_params["metric_type"] = "JACCARD"
        return index_params

    @pytest.fixture(
        scope="function",
        params=gen_simple_index()
    )
    def get_hamming_index(self, request, connect):
        """
        index params with metric_type "HAMMING"; index types outside
        binary_support() are skipped
        """
        if request.param["index_type"] not in binary_support():
            pytest.skip("Skip index")
        index_params = dict(request.param)
        index_params["metric_type"] = "HAMMING"
        return index_params

    @pytest.fixture(
        scope="function",
        params=gen_simple_index()
    )
    def get_substructure_index(self, request, connect):
        """
        index params with metric_type "SUBSTRUCTURE"; every index type
        except "FLAT" is skipped
        """
        if request.param["index_type"] != "FLAT":
            pytest.skip("Skip index")
        index_params = dict(request.param)
        index_params["metric_type"] = "SUBSTRUCTURE"
        return index_params

    @pytest.fixture(
        scope="function",
        params=gen_simple_index()
    )
    def get_superstructure_index(self, request, connect):
        """
        index params with metric_type "SUPERSTRUCTURE"; every index type
        except "FLAT" is skipped
        """
        if request.param["index_type"] != "FLAT":
            pytest.skip("Skip index")
        index_params = dict(request.param)
        index_params["metric_type"] = "SUPERSTRUCTURE"
        return index_params

    def test_collection_count_multi_collections_l2(self, connect, insert_count):
        '''
        target: test collection rows_count is correct or not with multiple collections of L2
        method: create 20 collections and add the same entities in each of them,
            assert the value returned by count_entities method is equal to length of entities
        expected: the count of every collection is equal to the length of entities
        '''
        entities = gen_entities(insert_count)
        collection_list = []
        collection_num = 20
        for _ in range(collection_num):
            collection_name = gen_unique_str(collection_id)
            collection_list.append(collection_name)
            connect.create_collection(collection_name, default_fields)
            connect.insert(collection_name, entities)
        connect.flush(collection_list)
        for collection_name in collection_list:
            assert connect.count_entities(collection_name) == insert_count

    # TODO:
    # Not collected by pytest because of the leading underscore.
    def _test_collection_count_multi_collections_binary(self, connect, binary_collection, insert_count):
        '''
        target: test collection rows_count is correct or not with multiple binary collections
        method: create 20 collections and add the same binary entities in each of them,
            assert the value returned by count_entities method is equal to length of entities
        expected: the count of every collection is equal to the length of entities
        '''
        _, entities = gen_binary_entities(insert_count)
        connect.insert(binary_collection, entities)
        collection_list = []
        collection_num = 20
        for _ in range(collection_num):
            collection_name = gen_unique_str(collection_id)
            collection_list.append(collection_name)
            # NOTE(review): `fields` is not defined anywhere in the visible
            # file; unless `from utils import *` provides it this raises
            # NameError - supply the binary collection schema here.
            connect.create_collection(collection_name, fields)
            connect.insert(collection_name, entities)
        connect.flush(collection_list)
        for collection_name in collection_list:
            assert connect.count_entities(collection_name) == insert_count

    # TODO:
    # Not collected by pytest because of the leading underscore.
    def _test_collection_count_multi_collections_mix(self, connect):
        '''
        target: test collection rows_count is correct or not with a mix of float and binary collections
        method: create 10 collections holding the module-level entities and
            10 holding the module-level binary_entities,
            assert the value returned by count_entities method is equal to nb
        expected: the count of every collection is equal to nb
        '''
        collection_list = []
        collection_num = 20
        half = collection_num // 2
        for _ in range(half):
            collection_name = gen_unique_str(collection_id)
            collection_list.append(collection_name)
            connect.create_collection(collection_name, default_fields)
            connect.insert(collection_name, entities)
        for _ in range(half, collection_num):
            collection_name = gen_unique_str(collection_id)
            collection_list.append(collection_name)
            # NOTE(review): `fields` is not defined anywhere in the visible
            # file; unless `from utils import *` provides it this raises
            # NameError - supply the binary collection schema here.
            connect.create_collection(collection_name, fields)
            connect.insert(collection_name, binary_entities)
        connect.flush(collection_list)
        for collection_name in collection_list:
            assert connect.count_entities(collection_name) == nb