diff --git a/cpp/src/db/DBImpl.cpp b/cpp/src/db/DBImpl.cpp index 5fb02e0b837f3e9b2528750e9b72d0f738bb0a7a..c272abe6494a8db4a6694fdbe4b6780ae8cbbd5f 100644 --- a/cpp/src/db/DBImpl.cpp +++ b/cpp/src/db/DBImpl.cpp @@ -7,6 +7,8 @@ #include #include #include +#include +#include #include "DBImpl.h" #include "DBMetaImpl.h" #include "Env.h" @@ -53,9 +55,82 @@ Status DBImpl::add_vectors(const std::string& group_id_, } } -Status DBImpl::search(const std::string& group_id, size_t k, size_t nq, - const float* vectors, QueryResults& results) { - // PXU TODO +Status DBImpl::search(const std::string &group_id, size_t k, size_t nq, + const float *vectors, QueryResults &results) { + meta::DatePartionedGroupFilesSchema files; + std::vector partition; + auto status = _pMeta->files_to_search(group_id, partition, files); + if (!status.ok()) { return status; } + + // TODO: optimized + meta::GroupFilesSchema index_files; + meta::GroupFilesSchema raw_files; + for (auto &day_files : files) { + for (auto &file : day_files.second) { + file.file_type == meta::GroupFileSchema::RAW ? + raw_files.push_back(file) : + index_files.push_back(file); + } + } + int dim = raw_files[0].dimension; + + + // merge raw files + faiss::Index *index(faiss::index_factory(dim, "IDMap,Flat")); + + for (auto &file : raw_files) { + auto file_index = dynamic_cast(faiss::read_index(file.location.c_str())); + index->add_with_ids(file_index->ntotal, dynamic_cast(file_index->index)->xb.data(), + file_index->id_map.data()); + } + float *xb = dynamic_cast(index)->xb.data(); + int64_t *ids = dynamic_cast(index)->id_map.data(); + long totoal = index->ntotal; + + std::vector distence; + std::vector result_ids; + { + // allocate memory + float *output_distence; + long *output_ids; + output_distence = (float *) malloc(k * sizeof(float)); + output_ids = (long *) malloc(k * sizeof(long)); + + // build and search in raw file + // TODO: HardCode + auto opd = std::make_shared(); + opd->index_type = "IDMap,Flat"; + IndexBuilderPtr builder = GetIndexBuilder(opd); + auto index = builder->build_all(totoal, xb, ids); + + index->search(nq, vectors, k, output_distence, output_ids); + distence.insert(distence.begin(), output_distence, output_distence + k); + result_ids.insert(result_ids.begin(), output_ids, output_ids + k); + memset(output_distence, 0, k * sizeof(float)); + memset(output_ids, 0, k * sizeof(long)); + + // search in index file + for (auto &file : index_files) { + auto index = read_index(file.location.c_str()); + index->search(nq, vectors, k, output_distence, output_ids); + distence.insert(distence.begin(), output_distence, output_distence + k); + result_ids.insert(result_ids.begin(), output_ids, output_ids + k); + memset(output_distence, 0, k * sizeof(float)); + memset(output_ids, 0, k * sizeof(long)); + } + + // TopK + TopK(distence.data(), distence.size(), k, output_distence, output_ids); + distence.clear(); + result_ids.clear(); + distence.insert(distence.begin(), output_distence, output_distence + k); + result_ids.insert(result_ids.begin(), output_ids, output_ids + k); + + // free + free(output_distence); + free(output_ids); + } + return Status::OK(); } diff --git a/cpp/src/db/DBMetaImpl.cpp b/cpp/src/db/DBMetaImpl.cpp index 3b0f520673edfde1e8b750fe8fa6549294f93880..fe539365053ca2ffbbcddfa2b7507ed685267554 100644 --- a/cpp/src/db/DBMetaImpl.cpp +++ b/cpp/src/db/DBMetaImpl.cpp @@ -226,6 +226,48 @@ Status DBMetaImpl::files_to_index(GroupFilesSchema& files) { return Status::OK(); } +Status DBMetaImpl::files_to_search(const std::string &group_id, + std::vector partition, + DatePartionedGroupFilesSchema &files) { + // TODO: support data partition + files.clear(); + auto selected = ConnectorPtr->select(columns(&GroupFileSchema::id, + &GroupFileSchema::group_id, + &GroupFileSchema::file_id, + &GroupFileSchema::file_type, + &GroupFileSchema::rows, + &GroupFileSchema::date), + where(c(&GroupFileSchema::group_id) == group_id and + (c(&GroupFileSchema::file_type) == (int) GroupFileSchema::RAW or + c(&GroupFileSchema::file_type) == (int) GroupFileSchema::INDEX))); + + GroupSchema group_info; + group_info.group_id = group_id; + auto status = get_group_no_lock(group_info); + if (!status.ok()) { + return status; + } + + for (auto& file : selected) { + GroupFileSchema group_file; + group_file.id = std::get<0>(file); + group_file.group_id = std::get<1>(file); + group_file.file_id = std::get<2>(file); + group_file.file_type = std::get<3>(file); + group_file.rows = std::get<4>(file); + group_file.date = std::get<5>(file); + group_file.dimension = group_info.dimension; + GetGroupFilePath(group_file); + auto dateItr = files.find(group_file.date); + if (dateItr == files.end()) { + files[group_file.date] = GroupFilesSchema(); + } + files[group_file.date].push_back(group_file); + } + + return Status::OK(); +} + Status DBMetaImpl::files_to_merge(const std::string& group_id, DatePartionedGroupFilesSchema& files) { files.clear(); diff --git a/cpp/src/db/DBMetaImpl.h b/cpp/src/db/DBMetaImpl.h index 86d1cd56bac12fa8839fe52984d7d1dca2479f7f..8def7aea32d345c60f2d1d19f33d90157c64ab17 100644 --- a/cpp/src/db/DBMetaImpl.h +++ b/cpp/src/db/DBMetaImpl.h @@ -38,6 +38,10 @@ public: virtual Status files_to_merge(const std::string& group_id, DatePartionedGroupFilesSchema& files) override; + virtual Status files_to_search(const std::string& group_id, + std::vector partition, + DatePartionedGroupFilesSchema& files) override; + virtual Status files_to_index(GroupFilesSchema&) override; virtual Status cleanup() override; diff --git a/cpp/src/db/Meta.h b/cpp/src/db/Meta.h index 84ba414708248fae3b43146f33180c08f40f2a7d..0a795cdd83f8cf0a883971467fc940c77c9db085 100644 --- a/cpp/src/db/Meta.h +++ b/cpp/src/db/Meta.h @@ -70,6 +70,10 @@ public: virtual Status update_files(const GroupFilesSchema& files) = 0; + virtual Status files_to_search(const std::string& group_id, + std::vector partition, + DatePartionedGroupFilesSchema& files) = 0; + virtual Status files_to_merge(const std::string& group_id, DatePartionedGroupFilesSchema& files) = 0; diff --git a/cpp/src/wrapper/Index.cpp b/cpp/src/wrapper/Index.cpp index 351baf00d3c16ceb92b9bfc43e4bcbb038ff74f2..0dcc8fc7ea7deceeb93f1b038aad0b45f7e48746 100644 --- a/cpp/src/wrapper/Index.cpp +++ b/cpp/src/wrapper/Index.cpp @@ -11,6 +11,7 @@ #endif #include "Index.h" +#include "faiss/index_io.h" namespace zilliz { namespace vecwise { @@ -66,6 +67,12 @@ void write_index(const Index_ptr &index, const std::string &file_name) { write_index(index->index_.get(), file_name.c_str()); } +Index_ptr read_index(const std::string &file_name) { + std::shared_ptr raw_index = nullptr; + raw_index.reset(faiss::read_index(file_name.c_str())); + return std::make_shared(raw_index); +} + } } }