From 2b035e54a6f0a26204f75ce3369bc2a4d0ff0650 Mon Sep 17 00:00:00 2001
From: groot
Date: Sun, 21 Apr 2019 15:55:48 +0800
Subject: [PATCH] index search unittest

Former-commit-id: 9632c0668ce71d07bba6f13e9ee46df4e1af6b38
---
 cpp/unittest/faiss_wrapper/wrapper_test.cpp | 70 +++++++++++++++------
 1 file changed, 52 insertions(+), 18 deletions(-)

diff --git a/cpp/unittest/faiss_wrapper/wrapper_test.cpp b/cpp/unittest/faiss_wrapper/wrapper_test.cpp
index b410d4a5..75a56fc1 100644
--- a/cpp/unittest/faiss_wrapper/wrapper_test.cpp
+++ b/cpp/unittest/faiss_wrapper/wrapper_test.cpp
@@ -23,36 +23,70 @@ TEST(operand_test, Wrapper_Test) {
 
 TEST(build_test, Wrapper_Test) {
     // dimension of the vectors to index
-    int d = 64;
-
-    // size of the database we plan to index
-    size_t nb = 100000;
+    int d = 3;
 
     // make a set of nt training vectors in the unit cube
-    size_t nt = 150000;
+    size_t nt = 10000;
 
     // a reasonable number of cetroids to index nb vectors
-    int ncentroids = 25;
-
-    srand48(35);  // seed
+    int ncentroids = 16;
 
-    std::vector<float> xb(nb * d);
-    for (size_t i = 0; i < nb * d; i++) {
-        xb[i] = drand48();
-    }
+    std::random_device rd;
+    std::mt19937 gen(rd());
 
-    std::vector<long> ids(nb);
-    for (size_t i = 0; i < nb; i++) {
-        ids[i] = drand48();
-    }
+    std::vector<float> xb;
+    std::vector<long> ids;
 
+    //prepare train data
+    std::uniform_real_distribution<> dis_xt(-1.0, 1.0);
     std::vector<float> xt(nt * d);
     for (size_t i = 0; i < nt * d; i++) {
-        xt[i] = drand48();
+        xt[i] = dis_xt(gen);
     }
 
+    //train the index
     auto opd = std::make_shared<Operand>();
+    opd->index_type = "IVF16,Flat";
+    opd->d = d;
+    opd->ncent = ncentroids;
     IndexBuilderPtr index_builder_1 = GetIndexBuilder(opd);
-    auto index_1 = index_builder_1->build_all(nb, xb, ids, nt, xt);
+    auto index_1 = index_builder_1->build_all(0, xb, ids, nt, xt);
+    ASSERT_TRUE(index_1 != nullptr);
+
+    // size of the database we plan to index
+    size_t nb = 100000;
+
+    //prepare raw data: nb vectors of d components each, stored contiguously
+    xb.resize(nb * d);
+    ids.resize(nb);
+    for (auto& component : xb) {
+        component = dis_xt(gen);
+    }
+    for (size_t i = 0; i < nb; i++) {
+        ids[i] = i;
+    }
+    index_1->add_with_ids(nb, xb.data(), ids.data());
+
+    //search in first quadrant
+    int nq = 1, k = 10;
+    std::vector<float> xq = {0.5, 0.5, 0.5};
+    //vectors (not new[]) so the buffers are released even if an ASSERT_* returns early
+    std::vector<float> result_dists(k);
+    std::vector<long> result_ids(k);
+    index_1->search(nq, xq.data(), k, result_dists.data(), result_ids.data());
+
+    for(int i = 0; i < k; i++) {
+        //a negative id is not a valid row of xb (presumably faiss' -1 padding for missing results)
+        ASSERT_GE(result_ids[i], 0);
+
+        long id = result_ids[i];
+        std::cout << "No." << id << " [" << xb[id*d] << ", " << xb[id*d + 1] << ", "
+            << xb[id*d + 2] <<"] distance = " << result_dists[i] << std::endl;
+
+        //make sure result vector is in first quadrant
+        ASSERT_TRUE(xb[id*d] > 0.0);
+        ASSERT_TRUE(xb[id*d + 1] > 0.0);
+        ASSERT_TRUE(xb[id*d + 2] > 0.0);
+    }
 }
 
-- 
GitLab