db_tests.cpp 6.0 KB
Newer Older
G
groot 已提交
1 2 3 4 5 6
////////////////////////////////////////////////////////////////////////////////
// Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
// Unauthorized copying of this file, via any medium is strictly prohibited.
// Proprietary and confidential.
////////////////////////////////////////////////////////////////////////////////
#include <gtest/gtest.h>
#include <faiss/IndexFlat.h>
#include <faiss/MetaIndexes.h>
#include <faiss/AutoTune.h>
#include <easylogging++.h>

#include <chrono>
#include <cstdlib>
#include <memory>
#include <random>
#include <sstream>
#include <string>
#include <thread>
#include <vector>

#include "db/DB.h"
#include "faiss/Index.h"
G
groot 已提交
16 17 18

using namespace zilliz::vecwise;

19 20 21 22 23 24 25 26 27 28 29 30 31 32 33
// Enables the timing macros below. Without this definition the #else branch is
// taken and the three macros expand to nothing.
#define TIMING

#ifdef TIMING
// Declares a local variable named `start` holding the current
// high_resolution_clock time. The expansion already ends with ';'.
#define INIT_TIMER auto start = std::chrono::high_resolution_clock::now();
// Re-assigns `start` to the current time. `start` must already be declared in
// the enclosing scope (INIT_TIMER does that).
#define START_TIMER  start = std::chrono::high_resolution_clock::now();
// Writes a DEBUG log line "RUNTIME of <name>: <elapsed> ms ", where <elapsed> is
// the whole number of milliseconds between `start` and now.
#define STOP_TIMER(name)  LOG(DEBUG) << "RUNTIME of " << name << ": " << \
    std::chrono::duration_cast<std::chrono::milliseconds>( \
            std::chrono::high_resolution_clock::now()-start \
    ).count() << " ms ";
#else
// Timing disabled: the macros compile away.
#define INIT_TIMER
#define START_TIMER
#define STOP_TIMER(name)
#endif

34 35 36 37 38 39 40 41 42 43 44 45
// Fixture shared by the DB tests in this file. Its only job is to configure
// logging before each test body runs.
class DBTest : public ::testing::Test {
protected:
    // Reconfigures the "default" easylogging++ logger from a default
    // configuration whose Debug-level format is replaced with a pattern that
    // shows thread, datetime, level, message and source location.
    //
    // Declared with `override` so that a signature mismatch with
    // ::testing::Test::SetUp is a compile error instead of a silently unused
    // method.
    void SetUp() override {
        el::Configurations defaultConf;
        defaultConf.setToDefault();
        defaultConf.set(el::Level::Debug,
                el::ConfigurationType::Format, "[%thread-%datetime-%level]: %msg (%fbase:%line)");
        el::Loggers::reconfigureLogger("default", defaultConf);
    }

};

G
groot 已提交
46 47 48 49 50 51 52 53 54
namespace {
    // Records a fatal test failure, including the status text, when `stat` is
    // not ok.
    //
    // The text is streamed into the assertion itself: a failing ASSERT_TRUE
    // returns from this helper immediately, so a diagnostic printed after it
    // would never run.
    //
    // NOTE: the fatal failure only returns from this helper, not from the test
    // body that called it. A caller that must stop on failure should wrap the
    // call in ASSERT_NO_FATAL_FAILURE(...).
    void ASSERT_STATS(engine::Status& stat) {
        ASSERT_TRUE(stat.ok()) << stat.ToString();
    }
}

55 56
// Inserts random vectors into a group, including one batch that holds the query
// vector itself, then searches repeatedly and checks that the first hit of the
// first result is the id that was assigned to the query vector.
TEST_F(DBTest, DB_TEST) {

    static const std::string group_name = "test_group";
    static const int group_dim = 256;

    engine::Options opt;
    opt.memory_sync_interval = 1;
    opt.meta.backend_uri = "http://127.0.0.1";
    opt.meta.path = "/tmp/vecwise_test/db_test";

    // Open() hands back a raw owning pointer. Wrap it so the DB is released even
    // when an ASSERT_* below returns early from the test body.
    engine::DB* raw_db = nullptr;
    engine::DB::Open(opt, &raw_db);
    ASSERT_TRUE(raw_db != nullptr);
    std::unique_ptr<engine::DB> db(raw_db);

    engine::meta::GroupSchema group_info;
    group_info.dimension = group_dim;
    group_info.group_id = group_name;
    // The add_group result is not asserted; get_group below is what verifies
    // that the group exists with the expected dimension.
    engine::Status stat = db->add_group(group_info);

    engine::meta::GroupSchema group_info_get;
    group_info_get.group_id = group_name;
    stat = db->get_group(group_info_get);
    // ASSERT_STATS is a helper function, so its fatal failure has to be
    // propagated explicitly to stop this test body.
    ASSERT_NO_FATAL_FAILURE(ASSERT_STATS(stat));
    ASSERT_EQ(group_info_get.dimension, group_dim);

    engine::IDNumbers vector_ids;
    engine::IDNumbers target_ids;

    // Base vectors: nb rows of group_dim random values in [0, 1), with a small
    // row-dependent offset added to the first component of each row.
    const int nb = 5;
    std::vector<float> xb(group_dim * nb);
    for (int i = 0; i < nb; i++) {
        for (int j = 0; j < group_dim; j++) xb[group_dim * i + j] = drand48();
        xb[group_dim * i] += i / 2000.;
    }

    // Query vectors, generated the same way.
    const int qb = 1;
    std::vector<float> qxb(group_dim * qb);
    for (int i = 0; i < qb; i++) {
        for (int j = 0; j < group_dim; j++) qxb[group_dim * i + j] = drand48();
        qxb[group_dim * i] += i / 2000.;
    }

    const int loop = 2000000;

    for (int i = 0; i < loop; ++i) {
        if (i == 40) {
            // Iteration 40 inserts the query vector itself; its id is recorded
            // in target_ids and is what the searches below must find.
            db->add_vectors(group_name, qb, qxb.data(), target_ids);
        } else {
            db->add_vectors(group_name, nb, xb.data(), vector_ids);
        }
    }
    // Without a recorded id the comparison below would index an empty container.
    ASSERT_FALSE(target_ids.empty());

    engine::QueryResults results;
    const int k = 10;
    std::this_thread::sleep_for(std::chrono::seconds(2));
    INIT_TIMER;

    std::stringstream ss;
    long count = 0;

    for (int j = 0; j < 15; ++j) {
        ss.str("");
        db->count(group_name, count);

        ss << "Search " << j << " With Size " << count;

        START_TIMER;
        stat = db->search(group_name, k, qb, qxb.data(), results);
        STOP_TIMER(ss.str());

        ASSERT_NO_FATAL_FAILURE(ASSERT_STATS(stat));
        // Guard the indexing: an empty result set must fail the test, not
        // invoke undefined behaviour.
        ASSERT_FALSE(results.empty());
        ASSERT_FALSE(results[0].empty());
        ASSERT_EQ(results[0][0], target_ids[0]);
        std::this_thread::sleep_for(std::chrono::seconds(1));
    }

    // Close the DB, then reopen it only to call drop_all()
    // (presumably removes what this test stored; confirm against engine::DB).
    db.reset();
    raw_db = nullptr;
    engine::DB::Open(opt, &raw_db);
    ASSERT_TRUE(raw_db != nullptr);
    db.reset(raw_db);
    db->drop_all();
}
X
xj.lin 已提交
138

139
// Inserts 250000 random vectors in batches of 100, waits, then runs one search
// with 10 random queries and checks only that the search status is ok.
TEST_F(DBTest, SEARCH_TEST) {
    static const std::string group_name = "test_group";
    static const int group_dim = 256;

    engine::Options opt;
    opt.meta.backend_uri = "http://127.0.0.1";
    opt.meta.path = "/tmp/search_test";
    opt.index_trigger_size = 100000 * group_dim;
    opt.memory_sync_interval = 1;
    opt.merge_trigger_number = 1;

    // Open() hands back a raw owning pointer. Wrap it so the DB is released even
    // when an ASSERT_* below returns early from the test body.
    engine::DB* raw_db = nullptr;
    engine::DB::Open(opt, &raw_db);
    ASSERT_TRUE(raw_db != nullptr);
    std::unique_ptr<engine::DB> db(raw_db);

    engine::meta::GroupSchema group_info;
    group_info.dimension = group_dim;
    group_info.group_id = group_name;
    // The add_group result is not asserted; get_group below is what verifies
    // that the group exists with the expected dimension.
    engine::Status stat = db->add_group(group_info);
    //ASSERT_STATS(stat);

    engine::meta::GroupSchema group_info_get;
    group_info_get.group_id = group_name;
    stat = db->get_group(group_info_get);
    // ASSERT_STATS is a helper function, so its fatal failure has to be
    // propagated explicitly to stop this test body.
    ASSERT_NO_FATAL_FAILURE(ASSERT_STATS(stat));
    ASSERT_EQ(group_info_get.dimension, group_dim);

    // prepare raw data
    const size_t nb = 250000;
    const size_t nq = 10;
    const size_t k = 5;
    std::vector<float> xb(nb*group_dim);
    std::vector<float> xq(nq*group_dim);
    std::vector<long> ids(nb);

    std::random_device rd;
    std::mt19937 gen(rd());
    std::uniform_real_distribution<> dis_xt(-1.0, 1.0);
    for (size_t i = 0; i < xb.size(); i++) {
        xb[i] = dis_xt(gen);
    }
    for (size_t i = 0; i < ids.size(); i++) {
        ids[i] = static_cast<long>(i);
    }
    for (size_t i = 0; i < xq.size(); i++) {
        xq[i] = dis_xt(gen);
    }

    // TODO(linxj): add groundTruth assert. Sketch of the intended reference
    // computation, kept for whoever implements it:
    //std::vector<long> nns_gt(k*nq);
    //std::vector<float> dis_gt(k*nq);
    //faiss::Index* index_gt(faiss::index_factory(group_dim, "IDMap,Flat"));
    //index_gt->add_with_ids(nb, xb.data(), ids.data());
    //index_gt->search(nq, xq.data(), 1, dis_gt.data(), nns_gt.data());

    // insert data
    const size_t batch_size = 100;
    for (size_t j = 0; j < nb / batch_size; ++j) {
        stat = db->add_vectors(group_name, batch_size, xb.data()+batch_size*j*group_dim, ids);
        // NOTE(review): the reason for pausing once after batch 200 is not
        // visible in this file; confirm it is still needed.
        if (j == 200) { std::this_thread::sleep_for(std::chrono::seconds(1)); }
        ASSERT_NO_FATAL_FAILURE(ASSERT_STATS(stat));
    }

    std::this_thread::sleep_for(std::chrono::seconds(3)); // wait until build index finish

    engine::QueryResults results;
    stat = db->search(group_name, k, nq, xq.data(), results);
    ASSERT_NO_FATAL_FAILURE(ASSERT_STATS(stat));

    // Close the DB, then reopen it only to call drop_all()
    // (presumably removes what this test stored; confirm against engine::DB).
    db.reset();
    raw_db = nullptr;
    engine::DB::Open(opt, &raw_db);
    ASSERT_TRUE(raw_db != nullptr);
    db.reset(raw_db);
    db->drop_all();
}