db_tests.cpp 5.4 KB
Newer Older
G
groot 已提交
1 2 3 4 5 6
////////////////////////////////////////////////////////////////////////////////
// Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
// Unauthorized copying of this file, via any medium is strictly prohibited.
// Proprietary and confidential.
////////////////////////////////////////////////////////////////////////////////
#include <gtest/gtest.h>

#include <faiss/IndexFlat.h>
#include <faiss/MetaIndexes.h>
#include <faiss/AutoTune.h>
#include <easylogging++.h>

#include <chrono>
#include <cstdlib>
#include <iostream>
#include <memory>
#include <random>
#include <string>
#include <thread>
#include <vector>

#include "db/DB.h"
#include "faiss/Index.h"
G
groot 已提交
15 16 17

using namespace zilliz::vecwise;

18 19 20 21 22 23 24 25 26 27 28 29
// Test fixture shared by the DB tests in this file.
class DBTest : public ::testing::Test {
protected:
    // Runs before every test: rebuilds the "default" easylogging++ logger from
    // default settings, overriding only the Debug-level line format.
    virtual void SetUp() {
        const char* const debug_format = "[%thread-%datetime-%level]: %msg (%fbase:%line)";

        el::Configurations log_conf;
        log_conf.setToDefault();
        log_conf.set(el::Level::Debug, el::ConfigurationType::Format, debug_format);
        el::Loggers::reconfigureLogger("default", log_conf);
    }

};

G
groot 已提交
30 31 32 33 34 35 36 37 38
namespace {
    // Records a fatal gtest failure when `stat` is not OK and attaches the
    // status text to the failure message.
    //
    // The previous version printed the status in an `if` placed after
    // ASSERT_TRUE; a failing ASSERT_TRUE returns from this function right away,
    // so that print could never run. Streaming into the assertion keeps the
    // diagnostic and only evaluates ToString() on failure.
    //
    // NOTE: a fatal assertion inside a helper only returns from the helper, not
    // from the calling test body, so the caller keeps executing afterwards.
    void ASSERT_STATS(engine::Status& stat) {
        ASSERT_TRUE(stat.ok()) << stat.ToString();
    }
}

39 40
// End-to-end smoke test: creates a group, inserts many random batches plus one
// single "target" vector, then searches with that same vector and expects the
// target id to be the top hit.
TEST_F(DBTest, DB_TEST) {
    static const std::string group_name = "test_group";
    static const int group_dim = 256;

    engine::Options opt;
    opt.memory_sync_interval = 1;
    opt.meta.backend_uri = "http://127.0.0.1";
    opt.meta.path = "/tmp/vecwise_test/db_test";

    // DB::Open hands back a raw owning pointer. Wrap it immediately so that an
    // early return from a failed ASSERT_* cannot leak the instance.
    auto open_db = [&opt]() {
        engine::DB* raw = nullptr;
        engine::DB::Open(opt, &raw);
        return std::unique_ptr<engine::DB>(raw);
    };

    std::unique_ptr<engine::DB> db = open_db();
    ASSERT_TRUE(db != nullptr);

    engine::meta::GroupSchema group_info;
    group_info.dimension = group_dim;
    group_info.group_id = group_name;
    // The add_group status is deliberately not asserted, as in the original
    // test (presumably the group may already exist from a previous run --
    // TODO confirm). The get_group check below validates the outcome.
    engine::Status stat = db->add_group(group_info);

    engine::meta::GroupSchema group_info_get;
    group_info_get.group_id = group_name;
    stat = db->get_group(group_info_get);
    ASSERT_STATS(stat);
    ASSERT_EQ(group_info_get.dimension, group_dim);

    engine::IDNumbers vector_ids;
    engine::IDNumbers target_ids;

    // Bulk batch: nb random vectors of dimension d. Buffers are std::vector so
    // they are released on every exit path.
    const int d = group_dim;
    const int nb = 10;
    std::vector<float> xb(d * nb);
    for (int i = 0; i < nb; i++) {
        for (int j = 0; j < d; j++) xb[d * i + j] = drand48();
        xb[d * i] += i / 2000.;
    }

    // Query batch: the single vector that is inserted once and searched for.
    const int qb = 1;
    std::vector<float> qxb(d * qb);
    for (int i = 0; i < qb; i++) {
        for (int j = 0; j < d; j++) qxb[d * i + j] = drand48();
        qxb[d * i] += i / 2000.;
    }

    const int insert_rounds = 500000;
    const int target_round = 40;  // round in which the query vector is inserted

    for (int i = 0; i < insert_rounds; ++i) {
        if (i == target_round) {
            db->add_vectors(group_name, qb, qxb.data(), target_ids);
        } else {
            db->add_vectors(group_name, nb, xb.data(), vector_ids);
        }
    }

    // The search below compares against target_ids[0]; fail cleanly instead of
    // indexing an empty container if the target insert produced no id.
    ASSERT_FALSE(target_ids.empty());

    std::this_thread::sleep_for(std::chrono::seconds(3));

    long count = 0;
    db->count(group_name, count);
    LOG(DEBUG) << "Count=" << count;

    engine::QueryResults results;
    const int k = 10;
    for (int i = 0; i < 5; ++i) {
        LOG(DEBUG) << "PRE" << i;
        stat = db->search(group_name, k, qb, qxb.data(), results);
        LOG(DEBUG) << "POST" << i;
        ASSERT_STATS(stat);
        // ASSERT_STATS cannot abort this test body, so guard the indexing
        // explicitly before reading the top hit.
        ASSERT_FALSE(results.empty());
        ASSERT_FALSE(results[0].empty());
        ASSERT_EQ(results[0][0], target_ids[0]);
    }

    // Close the database, then reopen it only to wipe the on-disk test data.
    db.reset();
    db = open_db();
    ASSERT_TRUE(db != nullptr);
    db->drop_all();
}
X
xj.lin 已提交
115

116
// Inserts a large random data set in fixed-size batches, waits for background
// work, then runs a multi-query search and checks only that it succeeds.
TEST_F(DBTest, SEARCH_TEST) {
    static const std::string group_name = "test_group";
    static const int group_dim = 256;

    engine::Options opt;
    opt.meta.backend_uri = "http://127.0.0.1";
    opt.meta.path = "/tmp/search_test";
    opt.index_trigger_size = 100000 * group_dim;
    opt.memory_sync_interval = 1;
    opt.merge_trigger_number = 1;

    // DB::Open hands back a raw owning pointer. Wrap it immediately so that an
    // early return from a failed ASSERT_* cannot leak the instance.
    auto open_db = [&opt]() {
        engine::DB* raw = nullptr;
        engine::DB::Open(opt, &raw);
        return std::unique_ptr<engine::DB>(raw);
    };

    std::unique_ptr<engine::DB> db = open_db();
    ASSERT_TRUE(db != nullptr);

    engine::meta::GroupSchema group_info;
    group_info.dimension = group_dim;
    group_info.group_id = group_name;
    // Status of add_group is deliberately not asserted (the check was already
    // commented out in the original test).
    engine::Status stat = db->add_group(group_info);

    engine::meta::GroupSchema group_info_get;
    group_info_get.group_id = group_name;
    stat = db->get_group(group_info_get);
    ASSERT_STATS(stat);
    ASSERT_EQ(group_info_get.dimension, group_dim);

    // prepare raw data
    const size_t nb = 250000;
    const size_t nq = 10;
    const size_t k = 5;
    std::vector<float> xb(nb * group_dim);
    std::vector<float> xq(nq * group_dim);
    std::vector<long> ids(nb);

    std::random_device rd;
    std::mt19937 gen(rd());
    std::uniform_real_distribution<> dis_xt(-1.0, 1.0);
    for (size_t i = 0; i < xb.size(); i++) {
        xb[i] = dis_xt(gen);
    }
    for (size_t i = 0; i < nb; i++) {
        ids[i] = i;
    }
    for (size_t i = 0; i < xq.size(); i++) {
        xq[i] = dis_xt(gen);
    }

    // TODO(linxj): add groundTruth assert. Sketch kept from the original:
    //   faiss::Index* index_gt(faiss::index_factory(group_dim, "IDMap,Flat"));
    //   index_gt->add_with_ids(nb, xb.data(), ids.data());
    //   index_gt->search(nq, xq.data(), 1, dis_gt.data(), nns_gt.data());

    // insert data
    const size_t batch_size = 100;
    const size_t batch_count = nb / batch_size;
    for (size_t j = 0; j < batch_count; ++j) {
        stat = db->add_vectors(group_name, batch_size, xb.data() + batch_size * j * group_dim, ids);
        // One-second pause after the 201st batch, kept from the original test.
        if (j == 200) { std::this_thread::sleep_for(std::chrono::seconds(1)); }
        ASSERT_STATS(stat);
    }

    std::this_thread::sleep_for(std::chrono::seconds(3)); // wait until build index finish

    engine::QueryResults results;
    stat = db->search(group_name, k, nq, xq.data(), results);
    ASSERT_STATS(stat);

    // Close the database, then reopen it only to wipe the on-disk test data.
    db.reset();
    db = open_db();
    ASSERT_TRUE(db != nullptr);
    db->drop_all();
}