db_tests.cpp 6.2 KB
Newer Older
G
groot 已提交
1 2 3 4 5 6
////////////////////////////////////////////////////////////////////////////////
// Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
// Unauthorized copying of this file, via any medium is strictly prohibited.
// Proprietary and confidential.
////////////////////////////////////////////////////////////////////////////////
#include <chrono>
#include <cstdlib>
#include <iostream>
#include <memory>
#include <random>
#include <sstream>
#include <string>
#include <thread>
#include <vector>

#include <gtest/gtest.h>
#include <easylogging++.h>

#include "db/DB.h"

using namespace zilliz::vecwise;

15 16 17 18 19 20 21 22 23 24 25 26 27 28 29
// Wall-clock timing helpers for the tests in this file.
// Defining TIMING enables them; without it every helper expands to nothing.
#define TIMING

// None of the helpers carries its own trailing ';' - the caller supplies it.
// This keeps `if (c) START_TIMER; else ...` well-formed, which a macro body
// ending in ';' would break by producing an extra empty statement.
#ifdef TIMING
// Declares a local time point named `start`, initialised to "now".
#define INIT_TIMER auto start = std::chrono::high_resolution_clock::now()
// Resets the `start` time point declared by INIT_TIMER to "now".
#define START_TIMER start = std::chrono::high_resolution_clock::now()
// Logs, at DEBUG level, the milliseconds elapsed since `start`, labelled with
// `name`. The argument is parenthesised so that any expression can be passed.
#define STOP_TIMER(name) LOG(DEBUG) << "RUNTIME of " << (name) << ": " << \
    std::chrono::duration_cast<std::chrono::milliseconds>( \
            std::chrono::high_resolution_clock::now()-start \
    ).count() << " ms "
#else
#define INIT_TIMER
#define START_TIMER
#define STOP_TIMER(name)
#endif

30 31 32 33 34 35 36 37 38 39 40 41
// Test fixture shared by the DB tests in this file.
// Its only job is to reconfigure the "default" easylogging++ logger so that
// DEBUG messages are printed with thread id, timestamp, level and source
// location.
class DBTest : public ::testing::Test {
protected:
    // Invoked by GoogleTest before each TEST_F that uses this fixture.
    // Marked `override` so a signature mismatch with ::testing::Test::SetUp
    // is a compile error instead of a silently unused method.
    void SetUp() override {
        el::Configurations defaultConf;
        defaultConf.setToDefault();
        defaultConf.set(el::Level::Debug,
                el::ConfigurationType::Format, "[%thread-%datetime-%level]: %msg (%fbase:%line)");
        el::Loggers::reconfigureLogger("default", defaultConf);
    }

};

G
groot 已提交
42 43 44 45 46 47 48 49 50
namespace {
    // Fails the current test when `stat` does not report success, attaching
    // the status text to the failure message.
    //
    // The text is streamed into the assertion itself: a failing ASSERT_TRUE
    // returns from this helper immediately, so any statement placed after it
    // would never run on the failure path.
    //
    // Note: a fatal assertion raised inside a helper only returns from the
    // helper; the calling test body keeps executing afterwards.
    void ASSERT_STATS(engine::Status& stat) {
        ASSERT_TRUE(stat.ok()) << stat.ToString();
    }
}

51 52
// Concurrent insert/search smoke test.
// The main thread inserts batches of random vectors (and, once, a single
// "target" query vector) while a second thread repeatedly searches for that
// target and expects it to be the top hit.
TEST_F(DBTest, DB_TEST) {

    static const std::string group_name = "test_group";
    static const int group_dim = 256;

    engine::Options opt;
    opt.memory_sync_interval = 1;
    opt.index_trigger_size = 1024*group_dim;
    opt.meta.backend_uri = "http://127.0.0.1";
    opt.meta.path = "/tmp/vecwise_test/db_test";

    // Hold the handle in a unique_ptr so an early return from a failed
    // ASSERT_* below does not leak the DB instance.
    engine::DB* raw_db = nullptr;
    engine::DB::Open(opt, &raw_db);
    std::unique_ptr<engine::DB> db(raw_db);
    ASSERT_TRUE(db != nullptr);

    engine::meta::GroupSchema group_info;
    group_info.dimension = group_dim;
    group_info.group_id = group_name;
    engine::Status stat = db->add_group(group_info);

    engine::meta::GroupSchema group_info_get;
    group_info_get.group_id = group_name;
    stat = db->get_group(group_info_get);
    ASSERT_STATS(stat);
    ASSERT_EQ(group_info_get.dimension, group_dim);

    engine::IDNumbers vector_ids;
    engine::IDNumbers target_ids;

    // Bulk payload: nb random vectors of dimension d. The first component of
    // vector i is additionally shifted by i / 2000.
    const int d = 256;
    const int nb = 50;
    std::vector<float> xb(d * nb);
    for(int i = 0; i < nb; i++) {
        for(int j = 0; j < d; j++) xb[d * i + j] = drand48();
        xb[d * i] += i / 2000.;
    }

    // Query payload: qb random vectors built the same way. They are inserted
    // once by the loop below and searched for by the thread.
    const int qb = 1;
    std::vector<float> qxb(d * qb);
    for(int i = 0; i < qb; i++) {
        for(int j = 0; j < d; j++) qxb[d * i + j] = drand48();
        qxb[d * i] += i / 2000.;
    }

    std::thread search([&]() {
        engine::QueryResults results;
        int k = 10;
        // NOTE(review): target_ids is filled by the inserting thread without
        // synchronisation; this sleep is presumably what lets the target
        // vector be inserted before the first search - confirm.
        std::this_thread::sleep_for(std::chrono::seconds(2));

        INIT_TIMER;
        std::stringstream ss;
        long count = 0;

        for (auto j=0; j<8; ++j) {
            ss.str("");
            db->count(group_name, count);

            ss << "Search " << j << " With Size " << count;

            START_TIMER;
            // Thread-local status: the outer `stat` belongs to the main thread.
            engine::Status search_stat = db->search(group_name, k, qb, qxb.data(), results);
            STOP_TIMER(ss.str());

            ASSERT_STATS(search_stat);
            // Guard the indexing below; an empty container would otherwise be
            // read out of bounds instead of failing the test cleanly.
            ASSERT_FALSE(target_ids.empty());
            ASSERT_FALSE(results.empty());
            ASSERT_FALSE(results[0].empty());
            ASSERT_EQ(results[0][0], target_ids[0]);
            std::this_thread::sleep_for(std::chrono::seconds(1));
        }
    });

    const int loop = 100000;

    for (auto i=0; i<loop; ++i) {
        if (i==40) {
            // Insert the vector the search thread is looking for.
            db->add_vectors(group_name, qb, qxb.data(), target_ids);
        } else {
            db->add_vectors(group_name, nb, xb.data(), vector_ids);
        }
        std::this_thread::sleep_for(std::chrono::microseconds(5));
    }

    search.join();

    // Destroy the first instance, then reopen the same store to wipe it.
    db.reset();
    raw_db = nullptr;
    engine::DB::Open(opt, &raw_db);
    db.reset(raw_db);
    ASSERT_TRUE(db != nullptr);
    db->drop_all();
}
X
xj.lin 已提交
140

141
// Bulk-insert then search test.
// Inserts nb random vectors in fixed-size batches, waits, then runs a k-NN
// query for nq random vectors and checks that the search call succeeds.
// The returned neighbours themselves are not validated yet (see TODO).
TEST_F(DBTest, SEARCH_TEST) {
    static const std::string group_name = "test_group";
    static const int group_dim = 256;

    engine::Options opt;
    opt.meta.backend_uri = "http://127.0.0.1";
    opt.meta.path = "/tmp/search_test";
    opt.index_trigger_size = 100000 * group_dim;
    opt.memory_sync_interval = 1;
    opt.merge_trigger_number = 1;

    // Hold the handle in a unique_ptr so an early return from a failed
    // ASSERT_* below does not leak the DB instance.
    engine::DB* raw_db = nullptr;
    engine::DB::Open(opt, &raw_db);
    std::unique_ptr<engine::DB> db(raw_db);
    ASSERT_TRUE(db != nullptr);

    engine::meta::GroupSchema group_info;
    group_info.dimension = group_dim;
    group_info.group_id = group_name;
    engine::Status stat = db->add_group(group_info);
    // NOTE(review): the add_group status is not asserted (the check was
    // already disabled); presumably the group may exist from an earlier
    // run - confirm.
    //ASSERT_STATS(stat);

    engine::meta::GroupSchema group_info_get;
    group_info_get.group_id = group_name;
    stat = db->get_group(group_info_get);
    ASSERT_STATS(stat);
    ASSERT_EQ(group_info_get.dimension, group_dim);


    // prepare raw data
    const size_t nb = 250000;
    const size_t nq = 10;
    const size_t k = 5;
    std::vector<float> xb(nb*group_dim);
    std::vector<float> xq(nq*group_dim);
    std::vector<long> ids(nb);

    std::random_device rd;
    std::mt19937 gen(rd());
    std::uniform_real_distribution<> dis_xt(-1.0, 1.0);
    for (auto& component : xb) {
        component = dis_xt(gen);
    }
    for (size_t i = 0; i < nb; i++) {
        ids[i] = static_cast<long>(i);
    }
    for (auto& component : xq) {
        component = dis_xt(gen);
    }

    // prepare ground-truth (disabled; kept as a sketch for the TODO below)
    //faiss::Index* index_gt(faiss::index_factory(group_dim, "IDMap,Flat"));
    //index_gt->add_with_ids(nb, xb.data(), ids.data());
    //index_gt->search(nq, xq.data(), 1, dis_gt.data(), nns_gt.data());

    // insert data
    const int batch_size = 100;
    const size_t batch_count = nb / batch_size;
    for (size_t j = 0; j < batch_count; ++j) {
        stat = db->add_vectors(group_name, batch_size, xb.data() + j*batch_size*group_dim, ids);
        // Single one-second pause after batch 200, kept from the original
        // test; its purpose is not documented here.
        if (j == 200) { std::this_thread::sleep_for(std::chrono::seconds(1)); }
        ASSERT_STATS(stat);
    }

    std::this_thread::sleep_for(std::chrono::seconds(3)); // wait until build index finish

    engine::QueryResults results;
    stat = db->search(group_name, k, nq, xq.data(), results);
    ASSERT_STATS(stat);

    // TODO(linxj): add groundTruth assert

    // Destroy the first instance, then reopen the same store to wipe it.
    db.reset();
    raw_db = nullptr;
    engine::DB::Open(opt, &raw_db);
    db.reset(raw_db);
    ASSERT_TRUE(db != nullptr);
    db->drop_all();
}