# test_mix.py
import pdb
import copy
import pytest
import threading
import datetime
import logging
from time import sleep
from multiprocessing import Process
import sklearn.preprocessing
from milvus import Milvus, IndexType, MetricType
from utils import *

# Table / data parameters.
dim = 128
index_file_size = 10
table_id = "test_mix"
add_interval_time = 2
# 100000 random vectors, L2-normalised per row (axis=1), then turned back into
# nested Python lists via .tolist().
vectors = sklearn.preprocessing.normalize(
    gen_vectors(100000, dim), axis=1, norm='l2'
).tolist()

# Search / index parameters.
top_k = 1
nprobe = 1
epsilon = 0.0001
index_params = {'index_type': IndexType.IVFLAT, 'nlist': 16384}


class TestMixBase:
    '''
    Tests mixing table creation, vector insertion, index building,
    describe_index and search against a Milvus server.
    '''

    # TODO: enable
    @pytest.mark.level(2)
    def test_search_during_createIndex(self, args):
        '''
        target: test search while an index is being built on the same table
        method: add vectors into a table, then start one process that searches the table
                in a loop and a second process that calls create_index on it
        expected: every search returns a copy of the query vectors as its top-1 hit,
                  create_index succeeds, and both child processes exit cleanly
        '''
        loops = 100000
        table = gen_unique_str()
        query_vecs = [vectors[0], vectors[1]]
        uri = "tcp://%s:%s" % (args["ip"], args["port"])
        milvus_instance = Milvus()
        milvus_instance.connect(uri=uri)
        status = milvus_instance.create_table({'table_name': table,
             'dimension': dim,
             'index_file_size': index_file_size,
             'metric_type': MetricType.L2})
        assert status.OK()
        # The same `vectors` batch is inserted 10 times, so each query vector has
        # 10 exact copies in the table; any of them is a valid top-1 hit.
        ids_0 = set()
        ids_1 = set()
        for _ in range(10):
            status, ids = milvus_instance.add_vectors(table, vectors)
            assert status.OK()
            ids_0.add(ids[0])
            ids_1.add(ids[1])
        # NOTE(review): assumes inserted vectors become searchable asynchronously;
        # wait before the search loop starts asserting on results.
        sleep(add_interval_time)

        def create_index(milvus_instance):
            logging.getLogger().info("In create index")
            status = milvus_instance.create_index(table, index_params)
            logging.getLogger().info(status)
            assert status.OK()
            status, result = milvus_instance.describe_index(table)
            logging.getLogger().info(result)

        def search(milvus_instance):
            logging.getLogger().info("In search vectors")
            for _ in range(loops):
                status, result = milvus_instance.search_vectors(table, top_k, nprobe, query_vecs)
                logging.getLogger().info(status)
                assert status.OK()
                assert result[0][0].id in ids_0
                assert result[1][0].id in ids_1

        # Each child process gets its own connection. Nested functions as Process
        # targets rely on the 'fork' start method; they cannot be pickled under 'spawn'.
        search_instance = Milvus()
        search_instance.connect(uri=uri)
        p_search = Process(target=search, args=(search_instance, ))
        p_search.start()
        index_instance = Milvus()
        index_instance.connect(uri=uri)
        p_create = Process(target=create_index, args=(index_instance, ))
        p_create.start()
        p_create.join()
        p_search.join()
        # An exception raised inside a child process does not propagate to this
        # process; a non-zero exit code is the only failure signal that reaches us.
        assert p_create.exitcode == 0
        assert p_search.exitcode == 0

    @pytest.mark.level(2)
    def test_mix_multi_tables(self, connect):
        '''
        target: test functions with multiple tables of different metric_types and index_types
        method: create 60 tables, 30 with the L2 metric and 30 with IP, and add vectors into them;
                in each metric group build FLAT, IVFLAT and IVF_SQ8 indexes on 10 tables each,
                then test describe index and search
        expected: status ok
        '''
        nq = 10000
        nlist = 16384
        # Normalise rows like the module-level data. Distinct unit vectors have a
        # dot product < 1, so under IP a query's own vector scores highest, as it
        # does under L2.
        table_vectors = sklearn.preprocessing.normalize(
            gen_vectors(nq, dim), axis=1, norm='l2').tolist()
        query_offsets = [0, 10, 20]
        metric_types = [MetricType.L2, MetricType.IP]
        index_types = [IndexType.FLAT, IndexType.IVFLAT, IndexType.IVF_SQ8]
        tables_per_metric = 30
        tables_per_index = 10
        table_list = []
        # expected_ids[k] holds the ids of table_vectors[query_offsets] in table k
        expected_ids = []

        # create tables and add vectors: tables 0-29 use L2, tables 30-59 use IP
        for metric_type in metric_types:
            for _ in range(tables_per_metric):
                table_name = gen_unique_str('test_mix_multi_tables')
                table_list.append(table_name)
                param = {'table_name': table_name,
                         'dimension': dim,
                         'index_file_size': index_file_size,
                         'metric_type': metric_type}
                status = connect.create_table(param)
                assert status.OK()
                status, ids = connect.add_vectors(table_name=table_name, records=table_vectors)
                assert status.OK()
                expected_ids.append([ids[offset] for offset in query_offsets])
        sleep(2)

        # Within each metric group of 30 tables: first 10 FLAT, next 10 IVFLAT,
        # last 10 IVF_SQ8.
        table_index_types = [index_types[(pos % tables_per_metric) // tables_per_index]
                             for pos in range(len(table_list))]

        # create index
        for table_name, index_type in zip(table_list, table_index_types):
            status = connect.create_index(table_name, {'index_type': index_type, 'nlist': nlist})
            assert status.OK()

        # describe index
        for table_name, index_type in zip(table_list, table_index_types):
            status, result = connect.describe_index(table_name)
            logging.getLogger().info(result)
            assert status.OK()
            assert result._nlist == nlist
            assert result._table_name == table_name
            assert result._index_type == index_type

        # search
        query_vecs = [table_vectors[offset] for offset in query_offsets]
        for table_name, table_ids in zip(table_list, expected_ids):
            status, result = connect.search_vectors(table_name, top_k, nprobe, query_vecs)
            assert status.OK()
            assert len(result) == len(query_vecs)
            for j, expected_id in enumerate(table_ids):
                assert len(result[j]) == top_k
                assert check_result(result[j], expected_id)

def check_result(result, id):
    '''
    Return True when ``id`` matches the ``.id`` of one of the first five hits
    in ``result``; if ``result`` holds fewer than five hits, all are checked.
    '''
    window = min(len(result), 5)
    return id in [result[pos].id for pos in range(window)]