diff --git a/cpp/src/db/DBImpl.cpp b/cpp/src/db/DBImpl.cpp index 64f7d68eb5c09ee2f88cea8e7ebdcfd657915ab4..18e77366323868f78124cdc4e13eb19c5263970c 100644 --- a/cpp/src/db/DBImpl.cpp +++ b/cpp/src/db/DBImpl.cpp @@ -56,6 +56,7 @@ Status DBImpl::add_vectors(const std::string& group_id_, } } +// TODO(XUPENG): add search range based on time Status DBImpl::search(const std::string &group_id, size_t k, size_t nq, const float *vectors, QueryResults &results) { meta::DatePartionedGroupFilesSchema files; @@ -63,75 +64,92 @@ Status DBImpl::search(const std::string &group_id, size_t k, size_t nq, auto status = _pMeta->files_to_search(group_id, partition, files); if (!status.ok()) { return status; } - // TODO: optimized meta::GroupFilesSchema index_files; meta::GroupFilesSchema raw_files; for (auto &day_files : files) { for (auto &file : day_files.second) { - file.file_type == meta::GroupFileSchema::RAW ? - raw_files.push_back(file) : - index_files.push_back(file); + file.file_type == meta::GroupFileSchema::INDEX ? + index_files.push_back(file) : raw_files.push_back(file); } } - int dim = raw_files[0].dimension; + int dim = 0; + if (!index_files.empty()) { + dim = index_files[0].dimension; + } else if (!raw_files.empty()) { + dim = raw_files[0].dimension; + } else { + return Status::OK(); + } - // merge raw files + // merge raw files and build flat index. faiss::Index *index(faiss::index_factory(dim, "IDMap,Flat")); - for (auto &file : raw_files) { auto file_index = dynamic_cast(faiss::read_index(file.location.c_str())); - index->add_with_ids(file_index->ntotal, dynamic_cast(file_index->index)->xb.data(), + index->add_with_ids(file_index->ntotal, + dynamic_cast(file_index->index)->xb.data(), file_index->id_map.data()); } - float *xb = dynamic_cast(index)->xb.data(); - int64_t *ids = dynamic_cast(index)->id_map.data(); - long totoal = index->ntotal; - std::vector distence; - std::vector result_ids; { - // allocate memory + // [{ids, distence}, ...] + using SearchResult = std::pair, std::vector>; + std::vector batchresult(nq); // allocate nq cells. + + auto cluster = [&](long *nns, float *dis) -> void { + for (int i = 0; i < nq; ++i) { + auto f_begin = batchresult[i].first.cbegin(); + auto s_begin = batchresult[i].second.cbegin(); + batchresult[i].first.insert(f_begin, nns + i * k, nns + i * k + k); + batchresult[i].second.insert(s_begin, dis + i * k, dis + i * k + k); + } + }; + + // Allocate Memory float *output_distence; long *output_ids; - output_distence = (float *) malloc(k * sizeof(float)); - output_ids = (long *) malloc(k * sizeof(long)); - - // build and search in raw file - // TODO: HardCode - auto opd = std::make_shared(); - opd->index_type = "IDMap,Flat"; - IndexBuilderPtr builder = GetIndexBuilder(opd); - auto index = builder->build_all(totoal, xb, ids); + output_distence = (float *) malloc(k * nq * sizeof(float)); + output_ids = (long *) malloc(k * nq * sizeof(long)); + memset(output_distence, 0, k * nq * sizeof(float)); + memset(output_ids, 0, k * nq * sizeof(long)); + // search in raw file index->search(nq, vectors, k, output_distence, output_ids); - distence.insert(distence.begin(), output_distence, output_distence + k); - result_ids.insert(result_ids.begin(), output_ids, output_ids + k); - memset(output_distence, 0, k * sizeof(float)); - memset(output_ids, 0, k * sizeof(long)); + cluster(output_ids, output_distence); // cluster to each query + memset(output_distence, 0, k * nq * sizeof(float)); + memset(output_ids, 0, k * nq * sizeof(long)); - // search in index file + // Search in index file for (auto &file : index_files) { auto index = read_index(file.location.c_str()); index->search(nq, vectors, k, output_distence, output_ids); - distence.insert(distence.begin(), output_distence, output_distence + k); - result_ids.insert(result_ids.begin(), output_ids, output_ids + k); - memset(output_distence, 0, k * sizeof(float)); - memset(output_ids, 0, k * sizeof(long)); + cluster(output_ids, output_distence); // cluster to each query + memset(output_distence, 0, k * nq * sizeof(float)); + memset(output_ids, 0, k * nq * sizeof(long)); } - // TopK - TopK(distence.data(), distence.size(), k, output_distence, output_ids); - distence.clear(); - result_ids.clear(); - distence.insert(distence.begin(), output_distence, output_distence + k); - result_ids.insert(result_ids.begin(), output_ids, output_ids + k); + auto cluster_topk = [&]() -> void { + QueryResult res; + for (auto &result_pair : batchresult) { + auto &dis = result_pair.second; + auto &nns = result_pair.first; + TopK(dis.data(), dis.size(), k, output_distence, output_ids); + for (int i = 0; i < k; ++i) { + res.emplace_back(nns[output_ids[i]]); // mapping + } + results.push_back(res); // append to result list + res.clear(); + } + }; + cluster_topk(); - // free free(output_distence); free(output_ids); } + if (results.empty()) { + return Status::NotFound("Group " + group_id + ", search result not found!"); + } return Status::OK(); } diff --git a/cpp/unittest/db/db_tests.cpp b/cpp/unittest/db/db_tests.cpp index 8b8ee602d148302a3ff6f3211c663f1123d82da2..b8106452c6f88775b75382d4978786b8a9df8d16 100644 --- a/cpp/unittest/db/db_tests.cpp +++ b/cpp/unittest/db/db_tests.cpp @@ -4,8 +4,12 @@ // Proprietary and confidential. //////////////////////////////////////////////////////////////////////////////// #include +#include +#include +#include #include "db/DB.h" +#include "faiss/Index.h" using namespace zilliz::vecwise; @@ -51,12 +55,90 @@ TEST(DBTest, DB_TEST) { stat = db->add_vectors(group_name, 1, vec_f.data(), vector_ids); ASSERT_STATS(stat); + //engine::QueryResults results; + //std::vector vec_s = vec_f; + //stat = db->search(group_name, 1, 1, vec_f.data(), results); + //ASSERT_STATS(stat); + //ASSERT_EQ(results.size(), 1); + //ASSERT_EQ(results[0][0], vector_ids[0]); + + delete db; +} + +TEST(SearchTest, DB_TEST) { + static const std::string group_name = "test_group"; + static const int group_dim = 256; + + engine::Options opt; + opt.meta.backend_uri = "http://127.0.0.1"; + opt.meta.path = "/tmp/search_test"; + opt.index_trigger_size = 100000 * group_dim; + opt.memory_sync_interval = 1; + opt.merge_trigger_number = 1; + + engine::DB* db = nullptr; + engine::DB::Open(opt, &db); + ASSERT_TRUE(db != nullptr); + + engine::meta::GroupSchema group_info; + group_info.dimension = group_dim; + group_info.group_id = group_name; + engine::Status stat = db->add_group(group_info); + //ASSERT_STATS(stat); + + engine::meta::GroupSchema group_info_get; + group_info_get.group_id = group_name; + stat = db->get_group(group_info_get); + ASSERT_STATS(stat); + ASSERT_EQ(group_info_get.dimension, group_dim); + + + // prepare raw data + size_t nb = 25000; + size_t nq = 10; + size_t k = 5; + std::vector xb(nb*group_dim); + std::vector xq(nq*group_dim); + std::vector ids(nb); + + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_real_distribution<> dis_xt(-1.0, 1.0); + for (size_t i = 0; i < nb*group_dim; i++) { + xb[i] = dis_xt(gen); + if (i < nb){ + ids[i] = i; + } + } + for (size_t i = 0; i < nq*group_dim; i++) { + xq[i] = dis_xt(gen); + } + + // result data + //std::vector nns_gt(k*nq); + std::vector nns(k*nq); // nns = nearst neg search + //std::vector dis_gt(k*nq); + std::vector dis(k*nq); + + // prepare ground-truth + //faiss::Index* index_gt(faiss::index_factory(group_dim, "IDMap,Flat")); + //index_gt->add_with_ids(nb, xb.data(), ids.data()); + //index_gt->search(nq, xq.data(), 1, dis_gt.data(), nns_gt.data()); + + // insert data + const int batch_size = 100; + for (int j = 0; j < nb / batch_size; ++j) { + stat = db->add_vectors(group_name, batch_size, xb.data()+batch_size*j*group_dim, ids); + ASSERT_STATS(stat); + } + + //sleep(10); // wait until build index finish + engine::QueryResults results; - std::vector vec_s = vec_f; - stat = db->search(group_name, 1, 1, vec_f.data(), results); + stat = db->search(group_name, k, nq, xq.data(), results); ASSERT_STATS(stat); - ASSERT_EQ(results.size(), 1); - ASSERT_EQ(results[0][0], vector_ids[0]); + + // TODO(linxj): add groundTruth assert delete db; } \ No newline at end of file diff --git a/cpp/unittest/faiss_wrapper/wrapper_test.cpp b/cpp/unittest/faiss_wrapper/wrapper_test.cpp index 75a56fc19c01c275d292eb7d89ef5d8ca1f33302..c2d29e30b60f048b78700121782e26aeac2cb892 100644 --- a/cpp/unittest/faiss_wrapper/wrapper_test.cpp +++ b/cpp/unittest/faiss_wrapper/wrapper_test.cpp @@ -92,3 +92,35 @@ TEST(build_test, Wrapper_Test) { delete[] result_ids; } +TEST(search_test, Wrapper_Test) { + const int dim = 256; + + size_t nb = 25000; + size_t nq = 100; + size_t k = 100; + std::vector xb(nb*dim); + std::vector xq(nq*dim); + std::vector ids(nb*dim); + + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_real_distribution<> dis_xt(-1.0, 1.0); + for (size_t i = 0; i < nb*dim; i++) { + xb[i] = dis_xt(gen); + ids[i] = i; + } + for (size_t i = 0; i < nq*dim; i++) { + xq[i] = dis_xt(gen); + } + + // result data + std::vector nns_gt(nq*k); // nns = nearst neg search + std::vector nns(nq*k); + std::vector dis_gt(nq*k); + std::vector dis(nq*k); + faiss::Index* index_gt(faiss::index_factory(dim, "IDMap,Flat")); + index_gt->add_with_ids(nb, xb.data(), ids.data()); + index_gt->search(nq, xq.data(), 10, dis_gt.data(), nns_gt.data()); + std::cout << "data: " << nns_gt[0]; + +}