test_mix.py 7.3 KB
Newer Older
J
JinHai-CN 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180
import pdb
import copy
import pytest
import threading
import datetime
import logging
from time import sleep
from multiprocessing import Process
import numpy
from milvus import Milvus, IndexType, MetricType
from utils import *

# Module-wide settings shared by the tests below.
dim = 128  # dimension of every table created and every vector generated in this module
index_file_size = 10  # passed as 'index_file_size' when creating tables
table_id = "test_mix"  # NOTE(review): not referenced in the code visible here
add_interval_time = 2  # NOTE(review): not referenced in the code visible here
# 100000 vectors from gen_vectors (presumably provided by `utils`), built once
# at import time.
vectors = gen_vectors(100000, dim)
# NOTE(review): without axis/ord arguments numpy.linalg.norm of a 2-D input is
# the Frobenius norm of the whole collection, i.e. one scalar. This therefore
# scales every vector by the same factor instead of making each row
# unit-length. Confirm whether per-row normalisation (axis=1, keepdims=True)
# was intended. Assumes gen_vectors returns a (100000, dim) collection.
vectors /= numpy.linalg.norm(vectors)
# Convert the numpy result into nested Python lists for the client calls.
vectors = vectors.tolist()
top_k = 1  # hits requested per query; the search test asserts each row has this many
nprobe = 1  # passed straight through to search_vectors
epsilon = 0.0001  # NOTE(review): not referenced in the code visible here
# Default index parameters, used by the disabled search-during-index test.
index_params = {'index_type': IndexType.IVFLAT, 'nlist': 16384}


class TestMixBase:
    '''
    Tests that mix several operations (table creation, inserts, index
    building, describe_index and search) against one Milvus server.
    '''

    # TODO: enable (the leading underscore keeps pytest from collecting this test)
    def _test_search_during_createIndex(self, args):
        '''
        target: test search while an index is being built on the same table
        method: insert the module-level vectors 10 times, then run a search loop
                in one process while another process calls create_index
        expected: every search finds an exact copy of each query vector as its top hit,
                  and both worker processes exit cleanly
        '''
        loops = 100000
        table = "test_search_during_createIndex"
        query_vecs = [vectors[0], vectors[1]]
        uri = "tcp://%s:%s" % (args["ip"], args["port"])
        milvus_instance = Milvus()
        milvus_instance.connect(uri=uri)
        status = milvus_instance.create_table({'table_name': table,
             'dimension': dim,
             'index_file_size': index_file_size,
             'metric_type': MetricType.L2})
        assert status.OK()
        # Every batch re-inserts the same vectors, so vectors[0] and vectors[1]
        # are each stored 10 times; any of those copies is a valid exact match.
        ids_0 = []
        ids_1 = []
        for _ in range(10):
            status, ids = milvus_instance.add_vectors(table, vectors)
            assert status.OK()
            ids_0.append(ids[0])
            ids_1.append(ids[1])

        # NOTE(review): nested functions as Process targets only work with the
        # 'fork' start method; they cannot be pickled for 'spawn'.
        def create_index(client):
            logging.getLogger().info("In create index")
            status = client.create_index(table, index_params)
            logging.getLogger().info(status)
            status, result = client.describe_index(table)
            logging.getLogger().info(result)

        def search(client):
            for _ in range(loops):
                status, result = client.search_vectors(table, top_k, nprobe, query_vecs)
                logging.getLogger().info(status)
                assert result[0][0].id in ids_0
                assert result[1][0].id in ids_1

        # Each worker process gets its own client connection.
        search_client = Milvus()
        search_client.connect(uri=uri)
        p_search = Process(target=search, args=(search_client, ))
        p_search.start()
        index_client = Milvus()
        index_client.connect(uri=uri)
        # Build the index while the search loop above is running.
        p_create = Process(target=create_index, args=(index_client, ))
        p_create.start()
        p_create.join()
        p_search.join()
        # Assertion failures inside a child process only surface as a non-zero exit code.
        assert p_create.exitcode == 0
        assert p_search.exitcode == 0

    def test_mix_multi_tables(self, connect):
        '''
        target: test functions with multiple tables of different metric_types and index_types
        method: create 60 tables which 30 are L2 and the other are IP, add vectors into them
                and test describe index and search
        expected: status ok
        '''
        nq = 10000
        records = gen_vectors(nq, dim)
        nlist = 16384
        tables_per_metric = 30
        # Within each metric group, consecutive runs of 10 tables get FLAT,
        # IVFLAT and IVF_SQ8 respectively.
        index_types = [IndexType.FLAT, IndexType.IVFLAT, IndexType.IVF_SQ8]
        tables_per_index = tables_per_metric // len(index_types)
        # Rows whose ids are recorded for every table and later used as queries.
        probe_rows = [0, 10, 20]
        table_list = []
        idx = []

        # create tables (30 L2 followed by 30 IP) and add vectors
        for metric_type in (MetricType.L2, MetricType.IP):
            for _ in range(tables_per_metric):
                table_name = gen_unique_str('test_mix_multi_tables')
                table_list.append(table_name)
                param = {'table_name': table_name,
                         'dimension': dim,
                         'index_file_size': index_file_size,
                         'metric_type': metric_type}
                status = connect.create_table(param)
                assert status.OK()
                status, ids = connect.add_vectors(table_name=table_name, records=records)
                assert status.OK()
                idx.extend(ids[row] for row in probe_rows)
        # NOTE(review): fixed pause after inserting; presumably gives the server
        # time to persist the data before indexing. Confirm.
        sleep(2)

        expected_types = [index_types[(i % tables_per_metric) // tables_per_index]
                          for i in range(len(table_list))]

        # create index
        for table_name, index_type in zip(table_list, expected_types):
            status = connect.create_index(table_name, {'index_type': index_type, 'nlist': nlist})
            assert status.OK()

        # describe index
        for table_name, index_type in zip(table_list, expected_types):
            status, result = connect.describe_index(table_name)
            logging.getLogger().info(result)
            assert status.OK()
            assert result._nlist == nlist
            assert result._table_name == table_name
            assert result._index_type == index_type

        # search: each table's probe rows must find their own ids
        query_vecs = [records[row] for row in probe_rows]
        per_table = len(probe_rows)
        for i, table_name in enumerate(table_list):
            status, result = connect.search_vectors(table_name, top_k, nprobe, query_vecs)
            assert status.OK()
            assert len(result) == len(query_vecs)
            expected_ids = idx[per_table * i: per_table * (i + 1)]
            for j, expected_id in enumerate(expected_ids):
                assert len(result[j]) == top_k
                assert check_result(result[j], expected_id)

def check_result(result, id):
    '''
    Return True when ``id`` is among the leading hits of ``result``.

    Only the first five hits are inspected; when ``result`` holds fewer than
    five hits, all of them are inspected. Each hit is expected to expose an
    ``id`` attribute (presumably one row of a search result — confirm against
    the client SDK).
    '''
    if len(result) < 5:
        candidates = (hit.id for hit in result)
    else:
        candidates = [result[k].id for k in range(5)]
    return id in candidates