# test_collection_stats.py: test cases for the `get_collection_stats` API.
import time
import pdb
import threading
import logging
from multiprocessing import Pool, Process
import pytest
from utils import *

# NOTE(review): dim, segment_row_count, nprobe, top_k, epsilon are not
# referenced anywhere in this file; kept in case other tooling imports them.
dim = 128
segment_row_count = 5000
nprobe = 1
top_k = 1
epsilon = 0.0001
# Partition tag used by the partition-related cases.
tag = "1970-01-01"
# Number of entities inserted by the batch cases; most row-count assertions compare against it.
nb = 6000
# nlist parameter passed to the IVF index configs built in this file.
nlist = 1024
# Prefix for the random collection names generated with gen_unique_str().
collection_id = "collection_stats"
# Name of the float vector field that indexes are created on.
field_name = "float_vector"
# Shared test data: a single entity and an nb-sized batch, for float and binary vectors.
entity = gen_entities(1)
raw_vector, binary_entity = gen_binary_entities(1)
entities = gen_entities(nb)
raw_vectors, binary_entities = gen_binary_entities(nb)
default_fields = gen_default_fields()


class TestStatsBase:
    """
    ******************************************************************
      The following cases are used to test `collection_stats` function
    ******************************************************************
    """

    @pytest.fixture(
        scope="function",
        params=gen_invalid_strs()
    )
    def get_collection_name(self, request):
        # Each param is one invalid collection name produced by gen_invalid_strs().
        yield request.param

    @pytest.fixture(
        scope="function",
        params=gen_simple_index()
    )
    def get_simple_index(self, request, connect):
        """Return one index config from gen_simple_index().

        Index types listed by index_cpu_not_support() are skipped when the
        server reports CPU mode. A shallow copy is returned so tests that
        modify top-level keys (e.g. "metric_type") never mutate the shared
        parametrize value seen by other tests.
        """
        index_type = request.param["index_type"]
        if str(connect._cmd("mode")) == "CPU" and index_type in index_cpu_not_support():
            pytest.skip("CPU not support index_type: %s" % index_type)
        return dict(request.param)

    @pytest.fixture(
        scope="function",
        params=gen_simple_index()
    )
    def get_jaccard_index(self, request, connect):
        """Return a binary-capable index config with metric_type set to JACCARD.

        Index types not listed in binary_support() are skipped. The param is
        copied before being modified so the parametrize value stays untouched.
        """
        logging.getLogger().info(request.param)
        if request.param["index_type"] not in binary_support():
            pytest.skip("Skip index Temporary")
        index = dict(request.param)
        index["metric_type"] = "JACCARD"
        return index

    def test_get_collection_stats_name_not_existed(self, connect, collection):
        '''
        target: get collection stats where collection name does not exist
        method: call collection_stats with a random collection_name, which is not in db
        expected: an exception is raised
        '''
        collection_name = gen_unique_str(collection_id)
        with pytest.raises(Exception):
            connect.get_collection_stats(collection_name)

    def test_get_collection_stats_name_invalid(self, connect, get_collection_name):
        '''
        target: get collection stats where collection name is invalid
        method: call collection_stats with invalid collection_name
        expected: an exception is raised
        '''
        with pytest.raises(Exception):
            connect.get_collection_stats(get_collection_name)

    def test_get_collection_stats_empty(self, connect, collection):
        '''
        target: get collection stats where no entity in collection
        method: call collection_stats in empty collection
        expected: row_count is 0 and only the "_default" partition exists, also with 0 rows
        '''
        stats = connect.get_collection_stats(collection)
        assert stats["row_count"] == 0
        assert len(stats["partitions"]) == 1
        assert stats["partitions"][0]["tag"] == "_default"
        assert stats["partitions"][0]["row_count"] == 0

    def test_get_collection_stats_batch(self, connect, collection):
        '''
        target: get row count with collection_stats
        method: add entities, check count in collection info
        expected: count as expected
        '''
        connect.insert(collection, entities)
        connect.flush([collection])
        stats = connect.get_collection_stats(collection)
        assert stats["row_count"] == nb
        assert len(stats["partitions"]) == 1
        assert stats["partitions"][0]["tag"] == "_default"
        assert stats["partitions"][0]["row_count"] == nb

    def test_get_collection_stats_single(self, connect, collection):
        '''
        target: get row count with collection_stats
        method: add entity one by one, check count in collection info
        expected: count as expected
        '''
        # Local count; deliberately not named `nb` to avoid shadowing the module constant.
        insert_count = 10
        for _ in range(insert_count):
            connect.insert(collection, entity)
            # Flush after every single-entity insert rather than once at the end.
            connect.flush([collection])
        stats = connect.get_collection_stats(collection)
        assert stats["row_count"] == insert_count
        assert len(stats["partitions"]) == 1
        assert stats["partitions"][0]["tag"] == "_default"
        assert stats["partitions"][0]["row_count"] == insert_count

    def test_get_collection_stats_after_delete(self, connect, collection):
        '''
        target: get row count with collection_stats
        method: add and delete entities, check count in collection info
        expected: status ok, count as expected
        '''
        ids = connect.insert(collection, entities)
        connect.flush([collection])
        # Delete the first and the last inserted entity.
        delete_ids = [ids[0], ids[-1]]
        connect.delete_entity_by_id(collection, delete_ids)
        connect.flush([collection])
        stats = connect.get_collection_stats(collection)
        assert stats["row_count"] == nb - len(delete_ids)
        assert stats["partitions"][0]["segments"][0]["data_size"] > 0
        # TODO
        # assert stats["partitions"][0]["segments"][0]["index_type"] == "FLAT"

    def test_get_collection_stats_after_compact_parts(self, connect, collection):
        '''
        target: get row count with collection_stats
        method: add and delete entities, and compact collection, check count in collection info
        expected: status ok, count as expected, segment data size shrinks after compaction
        '''
        delete_count = 3000
        ids = connect.insert(collection, entities)
        connect.flush([collection])
        connect.delete_entity_by_id(collection, ids[:delete_count])
        connect.flush([collection])
        stats = connect.get_collection_stats(collection)
        logging.getLogger().info(stats)
        assert stats["row_count"] == nb - delete_count
        compact_before = stats["partitions"][0]["segments"][0]["data_size"]
        connect.compact(collection)
        stats = connect.get_collection_stats(collection)
        logging.getLogger().info(stats)
        compact_after = stats["partitions"][0]["segments"][0]["data_size"]
        # Compacting away the deleted entities must shrink the first segment.
        assert compact_before > compact_after

    def test_get_collection_stats_after_compact_delete_one(self, connect, collection):
        '''
        target: get row count with collection_stats
        method: add and delete one entity, and compact collection, check count in collection info
        expected: status ok, data size of the first segment is unchanged by compaction
        '''
        ids = connect.insert(collection, entities)
        connect.flush([collection])
        connect.delete_entity_by_id(collection, ids[:1])
        connect.flush([collection])
        stats = connect.get_collection_stats(collection)
        logging.getLogger().info(stats)
        compact_before = stats["partitions"][0]["segments"][0]["data_size"]
        connect.compact(collection)
        stats = connect.get_collection_stats(collection)
        logging.getLogger().info(stats)
        compact_after = stats["partitions"][0]["segments"][0]["data_size"]
        # NOTE(review): presumably a single deletion is below the server's
        # compaction threshold, so the segment is left as-is — confirm.
        assert compact_before == compact_after

    def test_get_collection_stats_partition(self, connect, collection):
        '''
        target: get partition info in a collection
        method: call collection_stats after partition created and check partition_stats
        expected: status ok, vectors added to partition
        '''
        connect.create_partition(collection, tag)
        connect.insert(collection, entities, partition_tag=tag)
        connect.flush([collection])
        stats = connect.get_collection_stats(collection)
        # "_default" plus the newly created partition.
        assert len(stats["partitions"]) == 2
        assert stats["partitions"][1]["tag"] == tag
        assert stats["partitions"][1]["row_count"] == nb

    def test_get_collection_stats_partitions(self, connect, collection):
        '''
        target: get partition info in a collection
        method: create two partitions, add vectors in one of the partitions, call collection_stats and check
        expected: status ok, vectors added to one partition but not the other
        '''
        new_tag = "new_tag"
        connect.create_partition(collection, tag)
        connect.create_partition(collection, new_tag)
        connect.insert(collection, entities, partition_tag=tag)
        connect.flush([collection])
        stats = connect.get_collection_stats(collection)
        # Only `tag` holds data so far; "_default" and `new_tag` must be empty.
        for partition in stats["partitions"]:
            expected = nb if partition["tag"] == tag else 0
            assert partition["row_count"] == expected
        connect.insert(collection, entities, partition_tag=new_tag)
        connect.flush([collection])
        stats = connect.get_collection_stats(collection)
        # Both tagged partitions now hold nb rows; "_default" is still empty.
        for partition in stats["partitions"]:
            expected = nb if partition["tag"] in (tag, new_tag) else 0
            assert partition["row_count"] == expected

    def test_get_collection_stats_after_index_created(self, connect, collection, get_simple_index):
        '''
        target: test collection info after index created
        method: create collection, add vectors, create index and call collection_stats
        expected: status ok, index created and shown in segments
        '''
        connect.insert(collection, entities)
        connect.flush([collection])
        connect.create_index(collection, field_name, get_simple_index)
        stats = connect.get_collection_stats(collection)
        logging.getLogger().info(stats)
        assert stats["partitions"][0]["segments"][0]["row_count"] == nb
        # TODO
        # assert stats["partitions"][0]["segments"][0]["index_name"] == get_simple_index["index_type"]

    def test_get_collection_stats_after_index_created_ip(self, connect, collection, get_simple_index):
        '''
        target: test collection info after index created
        method: create collection, add vectors, create index with IP metric and call collection_stats
        expected: status ok, index created and shown in segments
        '''
        connect.insert(collection, entities)
        connect.flush([collection])
        # Build a new config instead of mutating the fixture's dict.
        index = dict(get_simple_index, metric_type="IP")
        connect.create_index(collection, field_name, index)
        stats = connect.get_collection_stats(collection)
        logging.getLogger().info(stats)
        assert stats["partitions"][0]["segments"][0]["row_count"] == nb
        # TODO
        # assert stats["partitions"][0]["segments"][0]["index_name"] == get_simple_index["index_type"]

    def test_get_collection_stats_after_index_created_jac(self, connect, binary_collection, get_jaccard_index):
        '''
        target: test collection info after index created
        method: create collection, add binary entities, create index and call collection_stats
        expected: status ok, index created and shown in segments
        '''
        connect.insert(binary_collection, binary_entities)
        connect.flush([binary_collection])
        connect.create_index(binary_collection, "binary_vector", get_jaccard_index)
        stats = connect.get_collection_stats(binary_collection)
        logging.getLogger().info(stats)
        assert stats["partitions"][0]["segments"][0]["row_count"] == nb
        # TODO
        # assert stats["partitions"][0]["segments"][0]["index_name"] == get_jaccard_index["index_type"]

    def test_get_collection_stats_after_create_different_index(self, connect, collection):
        '''
        target: test collection info after index created repeatedly
        method: create collection, add vectors, create index and call collection_stats multiple times
        expected: status ok, index info shown in segments
        '''
        connect.insert(collection, entities)
        connect.flush([collection])
        for index_type in ["IVF_FLAT", "IVF_SQ8"]:
            connect.create_index(collection, field_name,
                                 {"index_type": index_type, "params": {"nlist": nlist}, "metric_type": "L2"})
            stats = connect.get_collection_stats(collection)
            logging.getLogger().info(stats)
            # TODO
            # assert stats["partitions"][0]["segments"][0]["index_name"] == index_type
            assert stats["partitions"][0]["segments"][0]["row_count"] == nb

    def test_collection_count_multi_collections(self, connect):
        '''
        target: test collection row count is correct with multiple collections
        method: create several collections, insert entities into each and call collection_stats on every one
        expected: row count of the first segment equals nb in each collection
        '''
        collection_list = []
        collection_num = 10
        try:
            for _ in range(collection_num):
                collection_name = gen_unique_str(collection_id)
                connect.create_collection(collection_name, default_fields)
                # Record the name only after creation so cleanup never drops a missing collection.
                collection_list.append(collection_name)
                connect.insert(collection_name, entities)
            connect.flush(collection_list)
            for collection_name in collection_list:
                stats = connect.get_collection_stats(collection_name)
                assert stats["partitions"][0]["segments"][0]["row_count"] == nb
        finally:
            # Drop every created collection even when an assertion above failed.
            for collection_name in collection_list:
                connect.drop_collection(collection_name)

    def test_collection_count_multi_collections_indexed(self, connect):
        '''
        target: test collection row count is correct with multiple indexed collections
        method: create several collections, insert entities, build an IVF_SQ8 or IVF_FLAT (L2) index
            on each and call collection_stats on every one
        expected: row count of the first segment equals nb in each collection
        '''
        collection_list = []
        collection_num = 10
        try:
            for i in range(collection_num):
                collection_name = gen_unique_str(collection_id)
                connect.create_collection(collection_name, default_fields)
                collection_list.append(collection_name)
                connect.insert(collection_name, entities)
                # Flush only the new collection; earlier ones were flushed in their own iteration.
                connect.flush([collection_name])
                # Alternate index types: odd positions get IVF_SQ8, even ones IVF_FLAT.
                index_type = "IVF_SQ8" if i % 2 else "IVF_FLAT"
                connect.create_index(collection_name, field_name,
                                     {"index_type": index_type, "params": {"nlist": nlist}, "metric_type": "L2"})
            for i, collection_name in enumerate(collection_list):
                stats = connect.get_collection_stats(collection_name)
                assert stats["partitions"][0]["segments"][0]["row_count"] == nb
                # TODO
                # if i % 2:
                #     assert stats["partitions"][0]["segments"][0]["index_name"] == "IVF_SQ8"
                # else:
                #     assert stats["partitions"][0]["segments"][0]["index_name"] == "IVF_FLAT"
        finally:
            # Drop every created collection even when an assertion above failed.
            for collection_name in collection_list:
                connect.drop_collection(collection_name)