提交 4cd9dc0e 编写于 作者: Y yunyaoXYY 提交者: Walter

fix faiss bug

上级 d0c01a97
...@@ -26,45 +26,45 @@ ...@@ -26,45 +26,45 @@
#include <map> #include <map>
struct SearchResult { struct SearchResult {
std::vector <faiss::Index::idx_t> I; std::vector<faiss::idx_t> I;
std::vector<float> D; std::vector<float> D;
int return_k; int return_k;
}; };
class VectorSearch { class VectorSearch {
public: public:
explicit VectorSearch(const YAML::Node &config_file) { explicit VectorSearch(const YAML::Node &config_file) {
// IndexProcess // IndexProcess
this->index_dir = this->index_dir =
config_file["IndexProcess"]["index_dir"].as<std::string>(); config_file["IndexProcess"]["index_dir"].as<std::string>();
this->return_k = config_file["IndexProcess"]["return_k"].as<int>(); this->return_k = config_file["IndexProcess"]["return_k"].as<int>();
this->score_thres = config_file["IndexProcess"]["score_thres"].as<float>(); this->score_thres = config_file["IndexProcess"]["score_thres"].as<float>();
this->max_query_number = this->max_query_number =
config_file["Global"]["max_det_results"].as<int>() + 1; config_file["Global"]["max_det_results"].as<int>() + 1;
LoadIdMap(); LoadIdMap();
LoadIndexFile(); LoadIndexFile();
this->I.resize(this->return_k * this->max_query_number); this->I.resize(this->return_k * this->max_query_number);
this->D.resize(this->return_k * this->max_query_number); this->D.resize(this->return_k * this->max_query_number);
}; };
void LoadIdMap(); void LoadIdMap();
void LoadIndexFile(); void LoadIndexFile();
const SearchResult &Search(float *feature, int query_number); const SearchResult &Search(float *feature, int query_number);
const std::string &GetLabel(faiss::Index::idx_t ind); const std::string &GetLabel(faiss::idx_t ind);
const float &GetThreshold() { return this->score_thres; } const float &GetThreshold() { return this->score_thres; }
private: private:
std::string index_dir; std::string index_dir;
int return_k = 5; int return_k = 5;
float score_thres = 0.5; float score_thres = 0.5;
std::map<long int, std::string> id_map; std::map<long int, std::string> id_map;
faiss::Index *index; faiss::Index *index;
int max_query_number = 6; int max_query_number = 6;
std::vector<float> D; std::vector<float> D;
std::vector <faiss::Index::idx_t> I; std::vector<faiss::idx_t> I;
SearchResult sr; SearchResult sr;
}; };
...@@ -20,43 +20,43 @@ ...@@ -20,43 +20,43 @@
#include <regex> #include <regex>
void VectorSearch::LoadIndexFile() { void VectorSearch::LoadIndexFile() {
std::string file_path = this->index_dir + OS_PATH_SEP + "vector.index"; std::string file_path = this->index_dir + OS_PATH_SEP + "vector.index";
const char *fname = file_path.c_str(); const char *fname = file_path.c_str();
this->index = faiss::read_index(fname, 0); this->index = faiss::read_index(fname, 0);
} }
void VectorSearch::LoadIdMap() { void VectorSearch::LoadIdMap() {
std::string file_path = this->index_dir + OS_PATH_SEP + "id_map.txt"; std::string file_path = this->index_dir + OS_PATH_SEP + "id_map.txt";
std::ifstream in(file_path); std::ifstream in(file_path);
std::string line; std::string line;
std::vector <std::string> m_vec; std::vector<std::string> m_vec;
if (in) { if (in) {
while (getline(in, line)) { while (getline(in, line)) {
std::regex ws_re("\\s+"); std::regex ws_re("\\s+");
std::vector <std::string> v( std::vector<std::string> v(
std::sregex_token_iterator(line.begin(), line.end(), ws_re, -1), std::sregex_token_iterator(line.begin(), line.end(), ws_re, -1),
std::sregex_token_iterator()); std::sregex_token_iterator());
if (v.size() != 2) { if (v.size() != 2) {
std::cout << "The number of element for each line in : " << file_path std::cout << "The number of element for each line in : " << file_path
<< "must be 2, exit the program..." << std::endl; << "must be 2, exit the program..." << std::endl;
exit(1); exit(1);
} else } else
this->id_map.insert(std::pair<long int, std::string>( this->id_map.insert(std::pair<long int, std::string>(
std::stol(v[0], nullptr, 10), v[1])); std::stol(v[0], nullptr, 10), v[1]));
}
} }
}
} }
const SearchResult &VectorSearch::Search(float *feature, int query_number) { const SearchResult &VectorSearch::Search(float *feature, int query_number) {
this->D.resize(this->return_k * query_number); this->D.resize(this->return_k * query_number);
this->I.resize(this->return_k * query_number); this->I.resize(this->return_k * query_number);
this->index->search(query_number, feature, return_k, D.data(), I.data()); this->index->search(query_number, feature, return_k, D.data(), I.data());
this->sr.return_k = this->return_k; this->sr.return_k = this->return_k;
this->sr.D = this->D; this->sr.D = this->D;
this->sr.I = this->I; this->sr.I = this->I;
return this->sr; return this->sr;
} }
const std::string &VectorSearch::GetLabel(faiss::Index::idx_t ind) { const std::string &VectorSearch::GetLabel(faiss::idx_t ind) {
return this->id_map.at(ind); return this->id_map.at(ind);
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册