From 4cd9dc0e057331a65257d53938f64804fca020ac Mon Sep 17 00:00:00 2001 From: yunyaoXYY Date: Mon, 6 Mar 2023 06:57:00 +0000 Subject: [PATCH] fix faiss bug --- deploy/cpp_shitu/include/vector_search.h | 60 ++++++++++++------------ deploy/cpp_shitu/src/vector_search.cpp | 60 ++++++++++++------------ 2 files changed, 60 insertions(+), 60 deletions(-) diff --git a/deploy/cpp_shitu/include/vector_search.h b/deploy/cpp_shitu/include/vector_search.h index 4b354f78..6753bc7d 100644 --- a/deploy/cpp_shitu/include/vector_search.h +++ b/deploy/cpp_shitu/include/vector_search.h @@ -26,45 +26,45 @@ #include struct SearchResult { - std::vector I; - std::vector D; - int return_k; + std::vector I; + std::vector D; + int return_k; }; class VectorSearch { public: - explicit VectorSearch(const YAML::Node &config_file) { - // IndexProcess - this->index_dir = - config_file["IndexProcess"]["index_dir"].as(); - this->return_k = config_file["IndexProcess"]["return_k"].as(); - this->score_thres = config_file["IndexProcess"]["score_thres"].as(); - this->max_query_number = - config_file["Global"]["max_det_results"].as() + 1; - LoadIdMap(); - LoadIndexFile(); - this->I.resize(this->return_k * this->max_query_number); - this->D.resize(this->return_k * this->max_query_number); - }; + explicit VectorSearch(const YAML::Node &config_file) { + // IndexProcess + this->index_dir = + config_file["IndexProcess"]["index_dir"].as(); + this->return_k = config_file["IndexProcess"]["return_k"].as(); + this->score_thres = config_file["IndexProcess"]["score_thres"].as(); + this->max_query_number = + config_file["Global"]["max_det_results"].as() + 1; + LoadIdMap(); + LoadIndexFile(); + this->I.resize(this->return_k * this->max_query_number); + this->D.resize(this->return_k * this->max_query_number); + }; - void LoadIdMap(); + void LoadIdMap(); - void LoadIndexFile(); + void LoadIndexFile(); - const SearchResult &Search(float *feature, int query_number); + const SearchResult &Search(float *feature, int query_number); - const std::string &GetLabel(faiss::Index::idx_t ind); + const std::string &GetLabel(faiss::idx_t ind); - const float &GetThreshold() { return this->score_thres; } + const float &GetThreshold() { return this->score_thres; } private: - std::string index_dir; - int return_k = 5; - float score_thres = 0.5; - std::map id_map; - faiss::Index *index; - int max_query_number = 6; - std::vector D; - std::vector I; - SearchResult sr; + std::string index_dir; + int return_k = 5; + float score_thres = 0.5; + std::map id_map; + faiss::Index *index; + int max_query_number = 6; + std::vector D; + std::vector I; + SearchResult sr; }; diff --git a/deploy/cpp_shitu/src/vector_search.cpp b/deploy/cpp_shitu/src/vector_search.cpp index 85c487a7..5732bbb1 100644 --- a/deploy/cpp_shitu/src/vector_search.cpp +++ b/deploy/cpp_shitu/src/vector_search.cpp @@ -20,43 +20,43 @@ #include void VectorSearch::LoadIndexFile() { - std::string file_path = this->index_dir + OS_PATH_SEP + "vector.index"; - const char *fname = file_path.c_str(); - this->index = faiss::read_index(fname, 0); + std::string file_path = this->index_dir + OS_PATH_SEP + "vector.index"; + const char *fname = file_path.c_str(); + this->index = faiss::read_index(fname, 0); } void VectorSearch::LoadIdMap() { - std::string file_path = this->index_dir + OS_PATH_SEP + "id_map.txt"; - std::ifstream in(file_path); - std::string line; - std::vector m_vec; - if (in) { - while (getline(in, line)) { - std::regex ws_re("\\s+"); - std::vector v( - std::sregex_token_iterator(line.begin(), line.end(), ws_re, -1), - std::sregex_token_iterator()); - if (v.size() != 2) { - std::cout << "The number of element for each line in : " << file_path - << "must be 2, exit the program..." << std::endl; - exit(1); - } else - this->id_map.insert(std::pair( - std::stol(v[0], nullptr, 10), v[1])); - } + std::string file_path = this->index_dir + OS_PATH_SEP + "id_map.txt"; + std::ifstream in(file_path); + std::string line; + std::vector m_vec; + if (in) { + while (getline(in, line)) { + std::regex ws_re("\\s+"); + std::vector v( + std::sregex_token_iterator(line.begin(), line.end(), ws_re, -1), + std::sregex_token_iterator()); + if (v.size() != 2) { + std::cout << "The number of element for each line in : " << file_path + << "must be 2, exit the program..." << std::endl; + exit(1); + } else + this->id_map.insert(std::pair( + std::stol(v[0], nullptr, 10), v[1])); } + } } const SearchResult &VectorSearch::Search(float *feature, int query_number) { - this->D.resize(this->return_k * query_number); - this->I.resize(this->return_k * query_number); - this->index->search(query_number, feature, return_k, D.data(), I.data()); - this->sr.return_k = this->return_k; - this->sr.D = this->D; - this->sr.I = this->I; - return this->sr; + this->D.resize(this->return_k * query_number); + this->I.resize(this->return_k * query_number); + this->index->search(query_number, feature, return_k, D.data(), I.data()); + this->sr.return_k = this->return_k; + this->sr.D = this->D; + this->sr.I = this->I; + return this->sr; } -const std::string &VectorSearch::GetLabel(faiss::Index::idx_t ind) { - return this->id_map.at(ind); +const std::string &VectorSearch::GetLabel(faiss::idx_t ind) { + return this->id_map.at(ind); } -- GitLab