# test_list_id_in_segment.py: test cases for the `list_id_in_segment` client API.
import time
import random
import pdb
import threading
import logging
from multiprocessing import Pool, Process
import pytest
from utils import *

dim = 128
segment_row_count = 100000
# Default number of entities inserted by the tests below (passed as nb=nb).
nb = 6000
# Partition tag used by the partition-based test cases.
tag = "1970_01_01"
field_name = default_float_vec_field_name
binary_field_name = default_binary_vec_field_name
# Prefix for generating unique (non-existent) collection names.
collection_id = "list_id_in_segment"
entity = gen_entities(1)
raw_vector, binary_entity = gen_binary_entities(1)
entities = gen_entities(nb)
raw_vectors, binary_entities = gen_binary_entities(nb)
default_fields = gen_default_fields()


def get_segment_id(connect, collection, nb=1, vec_type='float', index_params=None):
    """Insert ``nb`` generated entities, flush, optionally build an index, and return a segment id.

    :param connect: client connection (test fixture)
    :param collection: collection to insert into, as accepted by ``connect.insert``
    :param nb: number of entities to generate and insert
    :param vec_type: ``'float'`` inserts float-vector entities; any other value inserts
                     binary-vector entities
    :param index_params: when truthy, ``create_index`` is called with these params on the
                         vector field matching ``vec_type``
    :return: ``(ids, segment_id)``; ``ids`` is whatever ``connect.insert`` returned and
             ``segment_id`` is the id of the first segment of the first partition
             reported by ``get_collection_stats``
    """
    if vec_type == "float":
        entities = gen_entities(nb)
        vec_field = field_name
    else:
        # Only the entity payload is inserted; the raw binary vectors are not needed here.
        _, entities = gen_binary_entities(nb)
        vec_field = binary_field_name
    ids = connect.insert(collection, entities)
    connect.flush([collection])
    if index_params:
        connect.create_index(collection, vec_field, index_params)
    stats = connect.get_collection_stats(collection)
    return ids, stats["partitions"][0]["segments"][0]["id"]


class TestListIdInSegmentBase:
    """
    ******************************************************************
      The following cases are used to test `list_id_in_segment` function
    ******************************************************************
    """

    def test_list_id_in_segment_collection_name_None(self, connect, collection):
        '''
        target: get vector ids where collection name is None
        method: call list_id_in_segment with the collection_name: None
        expected: exception raised
        '''
        collection_name = None
        ids, segment_id = get_segment_id(connect, collection)
        with pytest.raises(Exception):
            connect.list_id_in_segment(collection_name, segment_id)

    def test_list_id_in_segment_collection_name_not_existed(self, connect, collection):
        '''
        target: get vector ids where collection name does not exist
        method: call list_id_in_segment with a random collection_name, which is not in db
        expected: status not ok
        '''
        collection_name = gen_unique_str(collection_id)
        ids, segment_id = get_segment_id(connect, collection)
        with pytest.raises(Exception):
            connect.list_id_in_segment(collection_name, segment_id)

    @pytest.fixture(
        scope="function",
        params=gen_invalid_strs()
    )
    def get_collection_name(self, request):
        # Each value produced by gen_invalid_strs() becomes one test parametrization.
        yield request.param

    def test_list_id_in_segment_collection_name_invalid(self, connect, collection, get_collection_name):
        '''
        target: get vector ids where collection name is invalid
        method: call list_id_in_segment with invalid collection_name
        expected: status not ok
        '''
        collection_name = get_collection_name
        ids, segment_id = get_segment_id(connect, collection)
        with pytest.raises(Exception):
            connect.list_id_in_segment(collection_name, segment_id)

    def test_list_id_in_segment_name_None(self, connect, collection):
        '''
        target: get vector ids where segment name is None
        method: call list_id_in_segment with the name: None
        expected: exception raised
        '''
        ids, segment_id = get_segment_id(connect, collection)
        segment = None
        with pytest.raises(Exception):
            connect.list_id_in_segment(collection, segment)

    def test_list_id_in_segment_name_not_existed(self, connect, collection):
        '''
        target: get vector ids where segment does not exist
        method: call list_id_in_segment with an existing segment id offset by 10000
        expected: status not ok
        '''
        ids, seg_id = get_segment_id(connect, collection)
        # The offset id is assumed not to belong to any segment of this collection.
        with pytest.raises(Exception):
            connect.list_id_in_segment(collection, seg_id + 10000)

    @pytest.mark.level(2)
    def test_list_id_in_segment_without_index_A(self, connect, collection):
        '''
        target: get vector ids when there is no index
        method: call list_id_in_segment and check if the segment contains vectors
        expected: status ok
        '''
        nb = 1
        ids, seg_id = get_segment_id(connect, collection, nb=nb)
        vector_ids = connect.list_id_in_segment(collection, seg_id)
        # vector_ids should match ids
        assert len(vector_ids) == nb
        assert vector_ids[0] == ids[0]

    @pytest.mark.level(2)
    def test_list_id_in_segment_without_index_B(self, connect, collection):
        '''
        target: get vector ids when there is no index but with partition
        method: create partition, add vectors to it and call list_id_in_segment, check if the segment contains vectors
        expected: status ok
        '''
        nb = 10
        entities = gen_entities(nb)
        connect.create_partition(collection, tag)
        ids = connect.insert(collection, entities, partition_tag=tag)
        connect.flush([collection])
        stats = connect.get_collection_stats(collection)
        # The newly created partition is expected at index 1 of the stats.
        assert stats["partitions"][1]["tag"] == tag
        vector_ids = connect.list_id_in_segment(collection, stats["partitions"][1]["segments"][0]["id"])
        # vector_ids should match ids
        assert len(vector_ids) == nb
        for i in range(nb):
            assert vector_ids[i] == ids[i]

    @pytest.fixture(
        scope="function",
        params=gen_simple_index()
    )
    def get_simple_index(self, request, connect):
        # Skip index types the server cannot build in CPU mode.
        if str(connect._cmd("mode")) == "CPU":
            if request.param["index_type"] in index_cpu_not_support():
                pytest.skip("CPU not support index_type: %s" % request.param["index_type"])
        return request.param

    @pytest.mark.level(2)
    def test_list_id_in_segment_with_index_A(self, connect, collection, get_simple_index):
        '''
        target: get vector ids when there is index
        method: call list_id_in_segment and check if the segment contains vectors
        expected: status ok
        '''
        ids, seg_id = get_segment_id(connect, collection, nb=nb, index_params=get_simple_index)
        # Any exception raised here fails the test with its original traceback.
        vector_ids = connect.list_id_in_segment(collection, seg_id)
        # TODO: check that vector_ids match ids

    @pytest.mark.level(2)
    def test_list_id_in_segment_with_index_B(self, connect, collection, get_simple_index):
        '''
        target: get vector ids when there is index and with partition
        method: create partition, add vectors to it and call list_id_in_segment, check if the segment contains vectors
        expected: status ok
        '''
        connect.create_partition(collection, tag)
        ids = connect.insert(collection, entities, partition_tag=tag)
        connect.flush([collection])
        stats = connect.get_collection_stats(collection)
        assert stats["partitions"][1]["tag"] == tag
        # Any exception raised here fails the test with its original traceback.
        vector_ids = connect.list_id_in_segment(collection, stats["partitions"][1]["segments"][0]["id"])
        # TODO: check that vector_ids match ids

    @pytest.mark.level(2)
    def test_list_id_in_segment_after_delete_vectors(self, connect, collection):
        '''
        target: get vector ids after vectors are deleted
        method: add vectors and delete a few, call list_id_in_segment
        expected: status ok, vector_ids decreased after vectors deleted
        '''
        nb = 2
        ids, seg_id = get_segment_id(connect, collection, nb=nb)
        delete_ids = [ids[0]]
        connect.delete_entity_by_id(collection, delete_ids)
        connect.flush([collection])
        # Stats are re-read after delete + flush rather than reusing seg_id.
        stats = connect.get_collection_stats(collection)
        vector_ids = connect.list_id_in_segment(collection, stats["partitions"][0]["segments"][0]["id"])
        assert len(vector_ids) == 1
        assert vector_ids[0] == ids[1]

    @pytest.mark.level(2)
    def test_list_id_in_segment_with_index_ip(self, connect, collection, get_simple_index):
        '''
        target: get vector ids when there is index
        method: call list_id_in_segment and check if the segment contains vectors
        expected: ids returned in ids inserted
        '''
        # Copy before overriding the metric: the fixture returns request.param, the dict
        # shared by every test parametrized from gen_simple_index(), so mutating it in
        # place would leak metric_type="IP" into other tests.
        index_params = dict(get_simple_index, metric_type="IP")
        ids, seg_id = get_segment_id(connect, collection, nb=nb, index_params=index_params)
        vector_ids = connect.list_id_in_segment(collection, seg_id)
        # TODO:
        assert vector_ids == ids


class TestListIdInSegmentBinary:
    """
    ******************************************************************
      The following cases are used to test `list_id_in_segment` function
    ******************************************************************
    """

    @pytest.mark.level(2)
    def test_list_id_in_segment_without_index_A(self, connect, binary_collection):
        '''
        target: get vector ids when there is no index
        method: call list_id_in_segment and check if the segment contains vectors
        expected: status ok
        '''
        nb = 10
        _, entities = gen_binary_entities(nb)
        ids = connect.insert(binary_collection, entities)
        connect.flush([binary_collection])
        stats = connect.get_collection_stats(binary_collection)
        vector_ids = connect.list_id_in_segment(binary_collection, stats["partitions"][0]["segments"][0]["id"])
        # vector_ids should match ids
        assert len(vector_ids) == nb
        for i in range(nb):
            assert vector_ids[i] == ids[i]

    @pytest.mark.level(2)
    def test_list_id_in_segment_without_index_B(self, connect, binary_collection):
        '''
        target: get vector ids when there is no index but with partition
        method: create partition, add vectors to it and call list_id_in_segment, check if the segment contains vectors
        expected: status ok
        '''
        connect.create_partition(binary_collection, tag)
        nb = 10
        _, entities = gen_binary_entities(nb)
        ids = connect.insert(binary_collection, entities, partition_tag=tag)
        connect.flush([binary_collection])
        stats = connect.get_collection_stats(binary_collection)
        # The newly created partition is expected at index 1 of the stats.
        vector_ids = connect.list_id_in_segment(binary_collection, stats["partitions"][1]["segments"][0]["id"])
        # vector_ids should match ids
        assert len(vector_ids) == nb
        for i in range(nb):
            assert vector_ids[i] == ids[i]

    @pytest.fixture(
        scope="function",
        params=gen_binary_index()
    )
    def get_jaccard_index(self, request, connect):
        logging.getLogger().info(request.param)
        if request.param["index_type"] in binary_support():
            # Return a copy: request.param is shared by every parametrized test, so
            # setting metric_type on it directly would mutate the shared params.
            index_params = dict(request.param)
            index_params["metric_type"] = "JACCARD"
            return index_params
        else:
            pytest.skip("not support")

    def test_list_id_in_segment_with_index_A(self, connect, binary_collection, get_jaccard_index):
        '''
        target: get vector ids when there is index
        method: call list_id_in_segment and check if the segment contains vectors
        expected: status ok
        '''
        ids, seg_id = get_segment_id(connect, binary_collection, nb=nb, index_params=get_jaccard_index, vec_type='binary')
        vector_ids = connect.list_id_in_segment(binary_collection, seg_id)
        # TODO: check that vector_ids match ids

    def test_list_id_in_segment_with_index_B(self, connect, binary_collection, get_jaccard_index):
        '''
        target: get vector ids when there is index and with partition
        method: create partition, add vectors to it and call list_id_in_segment, check if the segment contains vectors
        expected: status ok
        '''
        connect.create_partition(binary_collection, tag)
        ids = connect.insert(binary_collection, binary_entities, partition_tag=tag)
        connect.flush([binary_collection])
        stats = connect.get_collection_stats(binary_collection)
        assert stats["partitions"][1]["tag"] == tag
        vector_ids = connect.list_id_in_segment(binary_collection, stats["partitions"][1]["segments"][0]["id"])
        # TODO: check that vector_ids match ids

    def test_list_id_in_segment_after_delete_vectors(self, connect, binary_collection, get_jaccard_index):
        '''
        target: get vector ids after vectors are deleted
        method: add vectors and delete a few, call list_id_in_segment
        expected: status ok, vector_ids decreased after vectors deleted
        '''
        nb = 2
        ids, seg_id = get_segment_id(connect, binary_collection, nb=nb, vec_type='binary', index_params=get_jaccard_index)
        delete_ids = [ids[0]]
        connect.delete_entity_by_id(binary_collection, delete_ids)
        connect.flush([binary_collection])
        # Stats are re-read after delete + flush rather than reusing seg_id.
        stats = connect.get_collection_stats(binary_collection)
        vector_ids = connect.list_id_in_segment(binary_collection, stats["partitions"][0]["segments"][0]["id"])
        assert len(vector_ids) == 1
        assert vector_ids[0] == ids[1]