# test_collection_count.py
import pdb
import copy
import logging
import itertools
from time import sleep
import threading
from multiprocessing import Process
import sklearn.preprocessing

import pytest
from utils import *

# Module-level constants shared by the test classes below. The gen_* helpers
# are presumably provided by the `from utils import *` star import.

# Number of rows in the pre-generated `entities` / `binary_entities` payloads.
nb = 6000
dim = 128
# Partition tag used by the partition tests.
tag = "tag"
# Prefix passed to gen_unique_str() when creating extra collections.
collection_id = "count_collection"
add_interval_time = 3
segment_row_count = 5000
# Float-vector collection schema used when creating extra collections.
default_fields = gen_default_fields()
entities = gen_entities(nb)
raw_vectors, binary_entities = gen_binary_entities(nb)
# Field name passed to create_index() by the (currently disabled) index tests.
# NOTE(review): "fload_vector" looks like a typo of "float_vector"; confirm it
# matches the vector field names of the float and binary schemas before
# re-enabling those tests.
field_name = "fload_vector"
index_name = "index_name"


class TestCollectionCount:
    """
    count_entities tests on a float-vector collection (index tests use L2).

    The ``insert_count`` fixture varies the number of inserted entities; the
    nb value may trigger a merge, or not.
    """
    @pytest.fixture(
        scope="function",
        params=[
            1,
            4000,
            6001
        ],
    )
    def insert_count(self, request):
        yield request.param

    # Generate valid create_index params, forced to the L2 metric.
    @pytest.fixture(
        scope="function",
        params=gen_simple_index()
    )
    def get_simple_index(self, request, connect):
        # In CPU mode, skip index types listed by index_cpu_not_support().
        if str(connect._cmd("mode")[1]) == "CPU":
            if request.param["index_type"] in index_cpu_not_support():
                pytest.skip("sq8h not support in cpu mode")
        # Work on a copy: request.param is the parametrize object itself and
        # is shared by every test that receives this param.
        index_param = copy.deepcopy(request.param)
        index_param.update({"metric_type": "L2"})
        return index_param

    def test_collection_count(self, connect, collection, insert_count):
        '''
        target: test collection rows_count is correct or not
        method: create collection and add vectors in it,
            assert the value returned by count_entities method is equal to length of vectors
        expected: the count is equal to the length of vectors
        '''
        entities = gen_entities(insert_count)
        connect.insert(collection, entities)
        connect.flush([collection])
        res = connect.count_entities(collection)
        assert res == insert_count

    def test_collection_count_partition(self, connect, collection, insert_count):
        '''
        target: test collection rows_count is correct or not
        method: create collection, create partition and add vectors in it,
            assert the value returned by count_entities method is equal to length of vectors
        expected: the count is equal to the length of vectors
        '''
        entities = gen_entities(insert_count)
        connect.create_partition(collection, tag)
        connect.insert(collection, entities, partition_tag=tag)
        connect.flush([collection])
        res = connect.count_entities(collection)
        assert res == insert_count

    def test_collection_count_multi_partitions_A(self, connect, collection, insert_count):
        '''
        target: test collection rows_count is correct or not
        method: create collection and two partitions, then insert entities
            without a partition_tag,
            assert the value returned by count_entities method is equal to length of entities
        expected: the count is equal to the length of entities
        '''
        new_tag = "new_tag"
        entities = gen_entities(insert_count)
        connect.create_partition(collection, tag)
        connect.create_partition(collection, new_tag)
        connect.insert(collection, entities)
        connect.flush([collection])
        res = connect.count_entities(collection)
        assert res == insert_count

    def test_collection_count_multi_partitions_B(self, connect, collection, insert_count):
        '''
        target: test collection rows_count is correct or not
        method: create collection and two partitions, then insert entities into
            one of the partitions,
            assert the value returned by count_entities method is equal to length of entities
        expected: the count is equal to the length of entities
        '''
        new_tag = "new_tag"
        entities = gen_entities(insert_count)
        connect.create_partition(collection, tag)
        connect.create_partition(collection, new_tag)
        connect.insert(collection, entities, partition_tag=tag)
        connect.flush([collection])
        res = connect.count_entities(collection)
        assert res == insert_count

    def test_collection_count_multi_partitions_C(self, connect, collection, insert_count):
        '''
        target: test collection rows_count is correct or not
        method: create collection and two partitions, insert entities once
            without a partition_tag and once into one of the partitions,
            assert the value returned by count_entities method covers both inserts
        expected: the count is equal to twice the length of entities
        '''
        new_tag = "new_tag"
        entities = gen_entities(insert_count)
        connect.create_partition(collection, tag)
        connect.create_partition(collection, new_tag)
        connect.insert(collection, entities)
        connect.insert(collection, entities, partition_tag=tag)
        connect.flush([collection])
        res = connect.count_entities(collection)
        assert res == insert_count * 2

    def test_collection_count_multi_partitions_D(self, connect, collection, insert_count):
        '''
        target: test collection rows_count is correct or not
        method: create collection and two partitions, insert the same entities
            into each partition,
            assert the value returned by count_entities method covers both partitions
        expected: the collection count is equal to twice the length of entities
        '''
        new_tag = "new_tag"
        entities = gen_entities(insert_count)
        connect.create_partition(collection, tag)
        connect.create_partition(collection, new_tag)
        connect.insert(collection, entities, partition_tag=tag)
        connect.insert(collection, entities, partition_tag=new_tag)
        connect.flush([collection])
        res = connect.count_entities(collection)
        assert res == insert_count * 2

    def _test_collection_count_after_index_created(self, connect, collection, get_simple_index, insert_count):
        '''
        target: test count_entities, after index have been created
        method: add vectors in db, and create index, then calling count_entities with correct params
        expected: count_entities still returns insert_count once the index is built
        '''
        # Disabled: the leading underscore keeps pytest from collecting it.
        entities = gen_entities(insert_count)
        connect.insert(collection, entities)
        connect.flush([collection])
        connect.create_index(collection, field_name, get_simple_index)
        res = connect.count_entities(collection)
        assert res == insert_count

    def test_count_without_connection(self, collection, dis_connect):
        '''
        target: test count_entities, without connection
        method: calling count_entities with correct params, with a disconnected instance
        expected: count_entities raise exception
        '''
        with pytest.raises(Exception):
            dis_connect.count_entities(collection)

    def test_collection_count_no_vectors(self, connect, collection):
        '''
        target: test collection rows_count is correct or not, if collection is empty
        method: create collection and no vectors in it,
            assert the value returned by count_entities method is equal to 0
        expected: the count is equal to 0
        '''
        res = connect.count_entities(collection)
        assert res == 0


class TestCollectionCountIP:
    """
    count_entities tests on a float-vector collection, indexing with the IP metric.

    The ``insert_count`` fixture varies the number of inserted entities; the
    nb value may trigger a merge, or not.
    """
    @pytest.fixture(
        scope="function",
        params=[
            1,
            4000,
            6001
        ],
    )
    def insert_count(self, request):
        yield request.param

    # Generate valid create_index params, forced to the IP metric.
    @pytest.fixture(
        scope="function",
        params=gen_simple_index()
    )
    def get_simple_index(self, request, connect):
        # In CPU mode, skip index types listed by index_cpu_not_support().
        if str(connect._cmd("mode")[1]) == "CPU":
            if request.param["index_type"] in index_cpu_not_support():
                pytest.skip("sq8h not support in cpu mode")
        # Work on a copy: request.param is the parametrize object itself and
        # is shared by every test that receives this param.
        index_param = copy.deepcopy(request.param)
        index_param.update({"metric_type": "IP"})
        return index_param

    def _test_collection_count_after_index_created(self, connect, collection, get_simple_index, insert_count):
        '''
        target: test count_entities, after index have been created
        method: add vectors in db, and create index, then calling count_entities with correct params
        expected: count_entities still returns insert_count once the index is built
        '''
        # Disabled: the leading underscore keeps pytest from collecting it.
        entities = gen_entities(insert_count)
        connect.insert(collection, entities)
        connect.flush([collection])
        connect.create_index(collection, field_name, get_simple_index)
        res = connect.count_entities(collection)
        assert res == insert_count


class TestCollectionCountBinary:
    """
    count_entities tests on a binary-vector collection.

    The ``insert_count`` fixture varies the number of inserted entities; the
    nb value may trigger a merge, or not.
    """
    @pytest.fixture(
        scope="function",
        params=[
            1,
            4000,
            6001
        ],
    )
    def insert_count(self, request):
        yield request.param

    # The four index fixtures below copy request.param before setting the
    # metric: request.param is the parametrize object itself and is shared by
    # every test that receives that param.
    @pytest.fixture(
        scope="function",
        params=gen_simple_index()
    )
    def get_jaccard_index(self, request, connect):
        # Only index types listed by binary_support() are kept.
        if request.param["index_type"] not in binary_support():
            pytest.skip("Skip index")
        index_param = copy.deepcopy(request.param)
        index_param["metric_type"] = "JACCARD"
        return index_param

    @pytest.fixture(
        scope="function",
        params=gen_simple_index()
    )
    def get_hamming_index(self, request, connect):
        if request.param["index_type"] not in binary_support():
            pytest.skip("Skip index")
        index_param = copy.deepcopy(request.param)
        index_param["metric_type"] = "HAMMING"
        return index_param

    @pytest.fixture(
        scope="function",
        params=gen_simple_index()
    )
    def get_substructure_index(self, request, connect):
        # SUBSTRUCTURE is only paired with the FLAT index type here.
        if request.param["index_type"] != "FLAT":
            pytest.skip("Skip index")
        index_param = copy.deepcopy(request.param)
        index_param["metric_type"] = "SUBSTRUCTURE"
        return index_param

    @pytest.fixture(
        scope="function",
        params=gen_simple_index()
    )
    def get_superstructure_index(self, request, connect):
        # SUPERSTRUCTURE is only paired with the FLAT index type here.
        if request.param["index_type"] != "FLAT":
            pytest.skip("Skip index")
        index_param = copy.deepcopy(request.param)
        index_param["metric_type"] = "SUPERSTRUCTURE"
        return index_param

    def test_collection_count(self, connect, binary_collection, insert_count):
        '''
        target: test collection rows_count is correct or not
        method: create collection and add entities in it,
            assert the value returned by count_entities method is equal to length of entities
        expected: the count is equal to the length of entities
        '''
        _, entities = gen_binary_entities(insert_count)
        ids = connect.insert(binary_collection, entities)
        logging.getLogger().info(len(ids))
        connect.flush([binary_collection])
        res = connect.count_entities(binary_collection)
        assert res == insert_count

    def test_collection_count_partition(self, connect, binary_collection, insert_count):
        '''
        target: test collection rows_count is correct or not
        method: create collection, create partition and add entities in it,
            assert the value returned by count_entities method is equal to length of entities
        expected: the count is equal to the length of entities
        '''
        _, entities = gen_binary_entities(insert_count)
        connect.create_partition(binary_collection, tag)
        connect.insert(binary_collection, entities, partition_tag=tag)
        connect.flush([binary_collection])
        res = connect.count_entities(binary_collection)
        assert res == insert_count

    @pytest.mark.level(2)
    def test_collection_count_multi_partitions_A(self, connect, binary_collection, insert_count):
        '''
        target: test collection rows_count is correct or not
        method: create collection and two partitions, then insert entities
            without a partition_tag,
            assert the value returned by count_entities method is equal to length of entities
        expected: the count is equal to the length of entities
        '''
        new_tag = "new_tag"
        _, entities = gen_binary_entities(insert_count)
        connect.create_partition(binary_collection, tag)
        connect.create_partition(binary_collection, new_tag)
        connect.insert(binary_collection, entities)
        connect.flush([binary_collection])
        res = connect.count_entities(binary_collection)
        assert res == insert_count

    @pytest.mark.level(2)
    def test_collection_count_multi_partitions_B(self, connect, binary_collection, insert_count):
        '''
        target: test collection rows_count is correct or not
        method: create collection and two partitions, then insert entities into
            one of the partitions,
            assert the value returned by count_entities method is equal to length of entities
        expected: the count is equal to the length of entities
        '''
        new_tag = "new_tag"
        _, entities = gen_binary_entities(insert_count)
        connect.create_partition(binary_collection, tag)
        connect.create_partition(binary_collection, new_tag)
        connect.insert(binary_collection, entities, partition_tag=tag)
        connect.flush([binary_collection])
        res = connect.count_entities(binary_collection)
        assert res == insert_count

    def test_collection_count_multi_partitions_C(self, connect, binary_collection, insert_count):
        '''
        target: test collection rows_count is correct or not
        method: create collection and two partitions, insert entities once
            without a partition_tag and once into one of the partitions,
            assert the value returned by count_entities method covers both inserts
        expected: the count is equal to twice the length of entities
        '''
        new_tag = "new_tag"
        _, entities = gen_binary_entities(insert_count)
        connect.create_partition(binary_collection, tag)
        connect.create_partition(binary_collection, new_tag)
        connect.insert(binary_collection, entities)
        connect.insert(binary_collection, entities, partition_tag=tag)
        connect.flush([binary_collection])
        res = connect.count_entities(binary_collection)
        assert res == insert_count * 2

    @pytest.mark.level(2)
    def test_collection_count_multi_partitions_D(self, connect, binary_collection, insert_count):
        '''
        target: test collection rows_count is correct or not
        method: create collection and two partitions, insert the same entities
            into each partition,
            assert the value returned by count_entities method covers both partitions
        expected: the collection count is equal to twice the length of entities
        '''
        new_tag = "new_tag"
        _, entities = gen_binary_entities(insert_count)
        connect.create_partition(binary_collection, tag)
        connect.create_partition(binary_collection, new_tag)
        connect.insert(binary_collection, entities, partition_tag=tag)
        connect.insert(binary_collection, entities, partition_tag=new_tag)
        connect.flush([binary_collection])
        res = connect.count_entities(binary_collection)
        assert res == insert_count * 2

    # The two index tests below previously shared one method name, so the
    # JACCARD variant was silently replaced by the HAMMING one. They now have
    # distinct names. Both stay disabled: the leading underscore keeps pytest
    # from collecting them.
    def _test_collection_count_after_jaccard_index_created(self, connect, binary_collection, get_jaccard_index, insert_count):
        '''
        target: test count_entities, after a JACCARD index has been created
        method: add vectors in db, and create index, then calling count_entities with correct params
        expected: count_entities still returns insert_count once the index is built
        '''
        _, entities = gen_binary_entities(insert_count)
        connect.insert(binary_collection, entities)
        connect.flush([binary_collection])
        connect.create_index(binary_collection, field_name, get_jaccard_index)
        res = connect.count_entities(binary_collection)
        assert res == insert_count

    def _test_collection_count_after_hamming_index_created(self, connect, binary_collection, get_hamming_index, insert_count):
        '''
        target: test count_entities, after a HAMMING index has been created
        method: add vectors in db, and create index, then calling count_entities with correct params
        expected: count_entities still returns insert_count once the index is built
        '''
        _, entities = gen_binary_entities(insert_count)
        connect.insert(binary_collection, entities)
        connect.flush([binary_collection])
        connect.create_index(binary_collection, field_name, get_hamming_index)
        res = connect.count_entities(binary_collection)
        assert res == insert_count

    def test_collection_count_no_entities(self, connect, binary_collection):
        '''
        target: test collection rows_count is correct or not, if collection is empty
        method: create collection and no entities in it,
            assert the value returned by count_entities method is equal to 0
        expected: the count is equal to 0
        '''
        res = connect.count_entities(binary_collection)
        assert res == 0


class TestCollectionMultiCollections:
    """
    count_entities tests across many collections created in one test.

    The ``insert_count`` fixture varies the number of inserted entities; the
    nb value may trigger a merge, or not.
    """
    @pytest.fixture(
        scope="function",
        params=[
            1,
            4000,
            6001
        ],
    )
    def insert_count(self, request):
        yield request.param

    # The four index fixtures below copy request.param before setting the
    # metric: request.param is the parametrize object itself and is shared by
    # every test that receives that param.
    @pytest.fixture(
        scope="function",
        params=gen_simple_index()
    )
    def get_jaccard_index(self, request, connect):
        # Only index types listed by binary_support() are kept.
        if request.param["index_type"] not in binary_support():
            pytest.skip("Skip index")
        index_param = copy.deepcopy(request.param)
        index_param["metric_type"] = "JACCARD"
        return index_param

    @pytest.fixture(
        scope="function",
        params=gen_simple_index()
    )
    def get_hamming_index(self, request, connect):
        if request.param["index_type"] not in binary_support():
            pytest.skip("Skip index")
        index_param = copy.deepcopy(request.param)
        index_param["metric_type"] = "HAMMING"
        return index_param

    @pytest.fixture(
        scope="function",
        params=gen_simple_index()
    )
    def get_substructure_index(self, request, connect):
        # SUBSTRUCTURE is only paired with the FLAT index type here.
        if request.param["index_type"] != "FLAT":
            pytest.skip("Skip index")
        index_param = copy.deepcopy(request.param)
        index_param["metric_type"] = "SUBSTRUCTURE"
        return index_param

    @pytest.fixture(
        scope="function",
        params=gen_simple_index()
    )
    def get_superstructure_index(self, request, connect):
        # SUPERSTRUCTURE is only paired with the FLAT index type here.
        if request.param["index_type"] != "FLAT":
            pytest.skip("Skip index")
        index_param = copy.deepcopy(request.param)
        index_param["metric_type"] = "SUPERSTRUCTURE"
        return index_param

    def test_collection_count_multi_collections_l2(self, connect, insert_count):
        '''
        target: test collection rows_count is correct or not with multiple collections of L2
        method: create collection and add entities in it,
            assert the value returned by count_entities method is equal to length of entities
        expected: the count is equal to the length of entities
        '''
        entities = gen_entities(insert_count)
        collection_list = []
        collection_num = 20
        for _ in range(collection_num):
            collection_name = gen_unique_str(collection_id)
            collection_list.append(collection_name)
            connect.create_collection(collection_name, default_fields)
            connect.insert(collection_name, entities)
        # One flush covers every collection created above.
        connect.flush(collection_list)
        for collection_name in collection_list:
            res = connect.count_entities(collection_name)
            assert res == insert_count

    # TODO: disabled (leading underscore). NOTE(review): the extra collections
    # are created with the float schema `default_fields` but receive binary
    # entities, and the rows inserted into `binary_collection` are never
    # counted; confirm the intended schema before re-enabling.
    def _test_collection_count_multi_collections_binary(self, connect, binary_collection, insert_count):
        '''
        target: test collection rows_count is correct or not with multiple collections of JACCARD
        method: create collection and add entities in it,
            assert the value returned by count_entities method is equal to length of entities
        expected: the count is equal to the length of entities
        '''
        _, entities = gen_binary_entities(insert_count)
        connect.insert(binary_collection, entities)
        collection_list = []
        collection_num = 20
        for _ in range(collection_num):
            collection_name = gen_unique_str(collection_id)
            collection_list.append(collection_name)
            connect.create_collection(collection_name, default_fields)
            connect.insert(collection_name, entities)
        connect.flush(collection_list)
        for collection_name in collection_list:
            res = connect.count_entities(collection_name)
            assert res == insert_count

    # TODO: disabled (leading underscore). NOTE(review): the second half of
    # the collections is created with the float schema `default_fields` but
    # receives `binary_entities`; confirm the intended schema before
    # re-enabling.
    def _test_collection_count_multi_collections_mix(self, connect):
        '''
        target: test collection rows_count is correct or not with a mix of float and binary collections
        method: create collections, insert the module-level float entities into
            half of them and the module-level binary entities into the other half,
            assert the value returned by count_entities method is equal to length of entities
        expected: the count of every collection is equal to nb
        '''
        collection_list = []
        collection_num = 20
        half = int(collection_num / 2)
        # First half: module-level float `entities` (nb rows).
        for _ in range(0, half):
            collection_name = gen_unique_str(collection_id)
            collection_list.append(collection_name)
            connect.create_collection(collection_name, default_fields)
            connect.insert(collection_name, entities)
        # Second half: module-level `binary_entities` (nb rows).
        for _ in range(half, collection_num):
            collection_name = gen_unique_str(collection_id)
            collection_list.append(collection_name)
            connect.create_collection(collection_name, default_fields)
            connect.insert(collection_name, binary_entities)
        connect.flush(collection_list)
        for collection_name in collection_list:
            res = connect.count_entities(collection_name)
            assert res == nb