db_tests.cpp 6.3 KB
Newer Older
G
groot 已提交
1 2 3 4 5 6
////////////////////////////////////////////////////////////////////////////////
// Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
// Unauthorized copying of this file, via any medium is strictly prohibited.
// Proprietary and confidential.
////////////////////////////////////////////////////////////////////////////////
#include <atomic>
#include <chrono>
#include <cstdlib>
#include <memory>
#include <random>
#include <sstream>
#include <string>
#include <thread>
#include <vector>

#include <gtest/gtest.h>
#include <faiss/IndexFlat.h>
#include <faiss/MetaIndexes.h>
#include <faiss/AutoTune.h>
#include <easylogging++.h>

#include "db/DB.h"
#include "faiss/Index.h"
G
groot 已提交
16 17 18

using namespace zilliz::vecwise;

19 20 21 22 23 24 25 26 27 28 29 30 31 32 33
// Compile-time switch for the timing helpers below: while TIMING is defined
// the *_TIMER macros expand to std::chrono-based code, otherwise to nothing.
#define TIMING

#ifdef TIMING
// INIT_TIMER declares a local variable named `start` initialised with the
// current high_resolution_clock time. START_TIMER and STOP_TIMER both refer
// to that variable, so INIT_TIMER must appear first in the same scope.
// Note: every macro in this branch already ends in ';', so writing
// `INIT_TIMER;` at the call site produces an additional empty statement.
#define INIT_TIMER auto start = std::chrono::high_resolution_clock::now();
// START_TIMER re-assigns `start` to the current time.
#define START_TIMER  start = std::chrono::high_resolution_clock::now();
// STOP_TIMER(name) emits a LOG(DEBUG) line containing `name` and the number
// of milliseconds elapsed since `start` was last set.
#define STOP_TIMER(name)  LOG(DEBUG) << "RUNTIME of " << name << ": " << \
    std::chrono::duration_cast<std::chrono::milliseconds>( \
            std::chrono::high_resolution_clock::now()-start \
    ).count() << " ms ";
#else
// Timing disabled: the macros expand to nothing, so call sites compile
// unchanged (the argument of STOP_TIMER is not evaluated).
#define INIT_TIMER
#define START_TIMER
#define STOP_TIMER(name)
#endif

34 35 36 37 38 39 40 41 42 43 44 45
/// Test fixture used by the TEST_F cases in this file.
/// Its only job is to configure the logger before each test runs.
class DBTest : public ::testing::Test {
protected:
    /// Resets a fresh easylogging++ configuration to its defaults, overrides
    /// the line format used for Debug-level messages, and applies the result
    /// to the logger registered under the id "default".
    virtual void SetUp() {
        static const char kDebugFormat[] = "[%thread-%datetime-%level]: %msg (%fbase:%line)";

        el::Configurations log_conf;
        log_conf.setToDefault();
        log_conf.set(el::Level::Debug, el::ConfigurationType::Format, kDebugFormat);

        el::Loggers::reconfigureLogger("default", log_conf);
    }
};

G
groot 已提交
46 47 48 49 50 51 52 53 54
namespace {
    /**
     * Records a fatal GoogleTest failure when `stat` is not ok.
     *
     * The textual form of the status (stat.ToString()) is attached to the
     * failure message so the reason is visible in the test report.
     *
     * Caveat: a fatal assertion only returns from the function it is written
     * in. A failure here therefore marks the test as failed but does NOT stop
     * the calling test body; callers that must not continue have to check
     * stat.ok() themselves.
     *
     * @param stat status to verify.
     */
    void ASSERT_STATS(engine::Status& stat) {
        // ASSERT_TRUE leaves this function as soon as it fails, so any
        // diagnostic placed after it would never run. Stream the details into
        // the assertion itself instead.
        ASSERT_TRUE(stat.ok()) << stat.ToString();
    }
}

55 56
/**
 * Inserts vectors from the test thread while a second thread repeatedly
 * searches the same group.
 *
 * One distinguished "target" vector is inserted at a fixed iteration of the
 * insert loop. The search thread queries with exactly that vector and expects
 * the first hit of the first query to be the id returned for the target.
 *
 * NOTE(review): this relies on engine::DB accepting add_vectors/count/search
 * calls from different threads at the same time — that is presumably what the
 * test is meant to exercise; confirm against the DB contract.
 */
TEST_F(DBTest, DB_TEST) {
    static const std::string group_name = "test_group";
    static const int group_dim = 256;

    engine::Options opt;
    opt.memory_sync_interval = 1;
    opt.index_trigger_size = 1024*group_dim;
    opt.meta.backend_uri = "http://127.0.0.1";
    opt.meta.path = "/tmp/vecwise_test/db_test";

    engine::DB* raw_db = nullptr;
    engine::DB::Open(opt, &raw_db);
    ASSERT_TRUE(raw_db != nullptr);
    // Own the handle so it is released even if an assertion below returns early.
    std::unique_ptr<engine::DB> db(raw_db);

    engine::meta::GroupSchema group_info;
    group_info.dimension = group_dim;
    group_info.group_id = group_name;
    // The status of add_group is not asserted; it is overwritten by get_group
    // below. NOTE(review): presumably the group may already exist from a
    // previous run — confirm before adding an assertion here.
    engine::Status stat = db->add_group(group_info);

    engine::meta::GroupSchema group_info_get;
    group_info_get.group_id = group_name;
    stat = db->get_group(group_info_get);
    ASSERT_STATS(stat);
    ASSERT_EQ(group_info_get.dimension, group_dim);

    engine::IDNumbers vector_ids;
    engine::IDNumbers target_ids;

    // Bulk payload: nb random vectors of group_dim floats each.
    const int nb = 100;
    std::vector<float> xb(static_cast<size_t>(group_dim) * nb);
    for(int i = 0; i < nb; i++) {
        for(int j = 0; j < group_dim; j++) xb[group_dim * i + j] = drand48();
        xb[group_dim * i] += i / 2000.;
    }

    // Target payload: the single vector that is also used as the query.
    const int qb = 1;
    std::vector<float> qxb(static_cast<size_t>(group_dim) * qb);
    for(int i = 0; i < qb; i++) {
        for(int j = 0; j < group_dim; j++) qxb[group_dim * i + j] = drand48();
        qxb[group_dim * i] += i / 2000.;
    }

    // Set by the inserting thread once target_ids has been written. The
    // search thread reads target_ids only after observing this flag, which
    // orders the write before the read.
    std::atomic<bool> target_inserted(false);

    std::thread search([&]() {
        engine::QueryResults results;
        int k = 10;
        std::this_thread::sleep_for(std::chrono::seconds(2));

        // Normally already true after the sleep above; wait explicitly so the
        // comparison below never races with the insertion of the target.
        while (!target_inserted.load()) {
            std::this_thread::sleep_for(std::chrono::milliseconds(1));
        }
        // Without an id for the target there is nothing to compare against.
        ASSERT_FALSE(target_ids.empty());

        INIT_TIMER;
        std::stringstream ss;
        long count = 0;

        for (auto j=0; j<5; ++j) {
            ss.str("");
            db->count(group_name, count);

            ss << "Search " << j << " With Size " << count;

            START_TIMER;
            // Thread-local status: the test thread owns `stat`.
            engine::Status search_stat = db->search(group_name, k, qb, qxb.data(), results);
            STOP_TIMER(ss.str());

            ASSERT_STATS(search_stat);
            if (!search_stat.ok()) {
                // ASSERT_STATS cannot abort this lambda; do not touch the
                // results of a failed search.
                return;
            }
            ASSERT_EQ(results[0][0], target_ids[0]);
            std::this_thread::sleep_for(std::chrono::seconds(1));
        }
    });

    const int loop = 40000;
    const int target_iteration = 40;

    for (auto i=0; i<loop; ++i) {
        if (i==target_iteration) {
            db->add_vectors(group_name, qb, qxb.data(), target_ids);
            target_inserted.store(true);
        } else {
            db->add_vectors(group_name, nb, xb.data(), vector_ids);
        }
        std::this_thread::sleep_for(std::chrono::microseconds(100));
    }
    // No effect on the normal path; keeps the search thread from waiting
    // forever should `loop` ever be reduced below `target_iteration`.
    target_inserted.store(true);

    search.join();

    // Close the database, then re-open it only to wipe everything the test
    // created. The pointer is reset first so a failed re-open is detected
    // instead of dereferencing the handle that was just destroyed.
    db.reset();
    raw_db = nullptr;
    engine::DB::Open(opt, &raw_db);
    ASSERT_TRUE(raw_db != nullptr);
    db.reset(raw_db);
    db->drop_all();
}
X
xj.lin 已提交
145

146
/**
 * Bulk-inserts random vectors in fixed-size batches, waits, and then runs a
 * single multi-query search, checking only that every call reports success.
 * The returned neighbours are not validated yet (see the TODO below).
 */
TEST_F(DBTest, SEARCH_TEST) {
    static const std::string group_name = "test_group";
    static const int group_dim = 256;

    engine::Options opt;
    opt.meta.backend_uri = "http://127.0.0.1";
    opt.meta.path = "/tmp/search_test";
    opt.index_trigger_size = 100000 * group_dim;
    opt.memory_sync_interval = 1;
    opt.merge_trigger_number = 1;

    engine::DB* raw_db = nullptr;
    engine::DB::Open(opt, &raw_db);
    ASSERT_TRUE(raw_db != nullptr);
    // Own the handle so it is released even if an assertion below returns early.
    std::unique_ptr<engine::DB> db(raw_db);

    engine::meta::GroupSchema group_info;
    group_info.dimension = group_dim;
    group_info.group_id = group_name;
    // The status of add_group is intentionally left unchecked (the assertion
    // was disabled by the original author).
    engine::Status stat = db->add_group(group_info);
    //ASSERT_STATS(stat);

    engine::meta::GroupSchema group_info_get;
    group_info_get.group_id = group_name;
    stat = db->get_group(group_info_get);
    ASSERT_STATS(stat);
    ASSERT_EQ(group_info_get.dimension, group_dim);

    // prepare raw data
    const size_t nb = 250000;
    const size_t nq = 10;
    const size_t k = 5;
    std::vector<float> xb(nb*group_dim);
    std::vector<float> xq(nq*group_dim);
    std::vector<long> ids(nb);

    std::random_device rd;
    std::mt19937 gen(rd());
    std::uniform_real_distribution<> dis_xt(-1.0, 1.0);
    for (size_t i = 0; i < xb.size(); i++) {
        xb[i] = dis_xt(gen);
    }
    for (size_t i = 0; i < nb; i++) {
        ids[i] = static_cast<long>(i);
    }
    for (size_t i = 0; i < xq.size(); i++) {
        xq[i] = dis_xt(gen);
    }

    // insert data
    // NOTE(review): `ids` is passed in the position where DB_TEST passes an
    // initially empty id container, so add_vectors presumably writes ids into
    // it and the sequence filled in above does not survive — confirm.
    const int batch_size = 100;
    const size_t batch_count = nb / batch_size;
    for (size_t j = 0; j < batch_count; ++j) {
        // Offset computed in size_t so it cannot overflow int for larger nb.
        const size_t offset = static_cast<size_t>(batch_size) * j * group_dim;
        stat = db->add_vectors(group_name, batch_size, xb.data() + offset, ids);
        if (j == 200) {
            std::this_thread::sleep_for(std::chrono::seconds(1));
        }
        ASSERT_STATS(stat);
    }

    // wait until build index finish
    std::this_thread::sleep_for(std::chrono::seconds(3));

    engine::QueryResults results;
    stat = db->search(group_name, k, nq, xq.data(), results);
    ASSERT_STATS(stat);

    // TODO(linxj): add groundTruth assert. Sketch left by the original author:
    //faiss::Index* index_gt(faiss::index_factory(group_dim, "IDMap,Flat"));
    //index_gt->add_with_ids(nb, xb.data(), ids.data());
    //index_gt->search(nq, xq.data(), 1, dis_gt.data(), nns_gt.data());

    // Close the database, then re-open it only to wipe everything the test
    // created. The pointer is reset first so a failed re-open is detected
    // instead of dereferencing the handle that was just destroyed.
    db.reset();
    raw_db = nullptr;
    engine::DB::Open(opt, &raw_db);
    ASSERT_TRUE(raw_db != nullptr);
    db.reset(raw_db);
    db->drop_all();
}