提交 5ef89df9 编写于 作者: D dongshuilong

update for vector search

上级 a89d6ae7
...@@ -19,51 +19,46 @@ ...@@ -19,51 +19,46 @@
#define OS_PATH_SEP "/" #define OS_PATH_SEP "/"
#endif #endif
#include "yaml-cpp/yaml.h"
#include <cstring>
#include <faiss/Index.h> #include <faiss/Index.h>
#include <faiss/index_io.h> #include <faiss/index_io.h>
#include <cstring>
#include <map> #include <map>
#include "yaml-cpp/yaml.h"
struct SearchResult{ struct SearchResult {
faiss::Index::idx_t* I; std::vector<faiss::Index::idx_t> I;
float* D; std::vector<float> D;
int query_number;
int return_k; int return_k;
}; };
class VectorSearch{ class VectorSearch {
public: public:
explicit VectorSearch(const YAML::Node &config_file){ explicit VectorSearch(const YAML::Node &config_file) {
// IndexProcess // IndexProcess
this->index_dir = config_file["IndexProcess"]["index_dir"].as<std::string>(); this->index_dir =
config_file["IndexProcess"]["index_dir"].as<std::string>();
this->return_k = config_file["IndexProcess"]["return_k"].as<int>(); this->return_k = config_file["IndexProcess"]["return_k"].as<int>();
this->score_thres = config_file["IndexProcess"]["score_thres"].as<float>(); this->score_thres = config_file["IndexProcess"]["score_thres"].as<float>();
this->max_query_number = config_file["Global"]["max_det_results"].as<int>() + 1; this->max_query_number =
config_file["Global"]["max_det_results"].as<int>() + 1;
LoadIdMap(); LoadIdMap();
LoadIndexFile(); LoadIndexFile();
this->I = new faiss::Index::idx_t[this->return_k * this->max_query_number]; this->I.resize(this->return_k * this->max_query_number);
this->D = new float[this->return_k * this->max_query_number]; this->D.resize(this->return_k * this->max_query_number);
}
~VectorSearch(){
delete[] I;
delete[] D;
}; };
void LoadIdMap(); void LoadIdMap();
void LoadIndexFile(); void LoadIndexFile();
void Search(float* feature, int query_number); const SearchResult &Search(float *feature, int query_number);
const SearchResult& GetSearchResult(); const std::string &GetLabel(faiss::Index::idx_t ind);
const std::string& GetLabel(faiss::Index::idx_t ind);
private: private:
std::string index_dir; std::string index_dir;
int return_k = 5; int return_k = 5;
float score_thres = 0.5; float score_thres = 0.5;
std::map <long int, std::string> id_map; std::map<long int, std::string> id_map;
faiss::Index * index; faiss::Index *index;
int max_query_number = 6; int max_query_number = 6;
int real_query_number = 0; std::vector<float> D;
float *D = NULL; std::vector<faiss::Index::idx_t> I;
faiss::Index::idx_t* I = NULL;
SearchResult sr; SearchResult sr;
}; };
...@@ -29,8 +29,8 @@ ...@@ -29,8 +29,8 @@
#include <auto_log/autolog.h> #include <auto_log/autolog.h>
#include <include/cls.h> #include <include/cls.h>
#include <include/object_detector.h> #include <include/object_detector.h>
#include <include/yaml_config.h>
#include <include/vector_search.h> #include <include/vector_search.h>
#include <include/yaml_config.h>
using namespace std; using namespace std;
using namespace cv; using namespace cv;
...@@ -137,6 +137,11 @@ int main(int argc, char **argv) { ...@@ -137,6 +137,11 @@ int main(int argc, char **argv) {
YamlConfig config(argv[1]); YamlConfig config(argv[1]);
config.PrintConfigInfo(); config.PrintConfigInfo();
// initialize detector, rec_Model, vector_search
PaddleClas::Classifier classifier(config.config_file);
PaddleDetection::ObjectDetector detector(config.config_file);
VectorSearch searcher(config.config_file);
// config // config
const int batch_size = config.config_file["Global"]["batch_size"].as<int>(); const int batch_size = config.config_file["Global"]["batch_size"].as<int>();
bool visual_det = false; bool visual_det = false;
...@@ -152,6 +157,7 @@ int main(int argc, char **argv) { ...@@ -152,6 +157,7 @@ int main(int argc, char **argv) {
max_det_results = config.config_file["Global"]["max_det_results"].as<int>(); max_det_results = config.config_file["Global"]["max_det_results"].as<int>();
} }
// load image_file_path
std::string path = std::string path =
config.config_file["Global"]["infer_imgs"].as<std::string>(); config.config_file["Global"]["infer_imgs"].as<std::string>();
std::vector<std::string> img_files_list; std::vector<std::string> img_files_list;
...@@ -164,12 +170,8 @@ int main(int argc, char **argv) { ...@@ -164,12 +170,8 @@ int main(int argc, char **argv) {
} else { } else {
img_files_list.push_back(path); img_files_list.push_back(path);
} }
std::cout << "img_file_list length: " << img_files_list.size() << std::endl; std::cout << "img_file_list length: " << img_files_list.size() << std::endl;
PaddleClas::Classifier classifier(config.config_file);
PaddleDetection::ObjectDetector detector(config.config_file);
double elapsed_time = 0.0; double elapsed_time = 0.0;
std::vector<double> cls_times = {0, 0, 0}; std::vector<double> cls_times = {0, 0, 0};
std::vector<double> det_times = {0, 0, 0}; std::vector<double> det_times = {0, 0, 0};
...@@ -177,6 +179,8 @@ int main(int argc, char **argv) { ...@@ -177,6 +179,8 @@ int main(int argc, char **argv) {
std::vector<std::string> img_paths; std::vector<std::string> img_paths;
std::vector<PaddleDetection::ObjectResult> det_result; std::vector<PaddleDetection::ObjectResult> det_result;
std::vector<int> det_bbox_num; std::vector<int> det_bbox_num;
std::vector<float> features;
std::vector<float> feature;
int warmup_iter = img_files_list.size() > 5 ? 5 : 0; int warmup_iter = img_files_list.size() > 5 ? 5 : 0;
for (int idx = 0; idx < img_files_list.size(); ++idx) { for (int idx = 0; idx < img_files_list.size(); ++idx) {
...@@ -206,20 +210,29 @@ int main(int argc, char **argv) { ...@@ -206,20 +210,29 @@ int main(int argc, char **argv) {
det_result.push_back(result_whole_img); det_result.push_back(result_whole_img);
det_bbox_num[0] = det_result.size() + 1; det_bbox_num[0] = det_result.size() + 1;
// step3: recognition process, use score_thres to ensure accuracy // step3: extract feature for all boxes in an inmage
SearchResult search_result;
for (int j = 0; j < det_result.size(); ++j) { for (int j = 0; j < det_result.size(); ++j) {
int w = det_result[j].rect[2] - det_result[j].rect[0]; int w = det_result[j].rect[2] - det_result[j].rect[0];
int h = det_result[j].rect[3] - det_result[j].rect[1]; int h = det_result[j].rect[3] - det_result[j].rect[1];
cv::Rect rect(det_result[j].rect[0], det_result[j].rect[1], w, h); cv::Rect rect(det_result[j].rect[0], det_result[j].rect[1], w, h);
cv::Mat crop_img = srcimg(rect); cv::Mat crop_img = srcimg(rect);
std::vector<float> feature;
classifier.Run(crop_img, feature, cls_times); classifier.Run(crop_img, feature, cls_times);
features.insert(features.end(), feature.begin(), feature.end());
} }
// double run_time = classifier.Run(srcimg, cls_times);
// step4: get search result
search_result = searcher.Search(features.data(), det_result.size());
// nms for search result
// for postprocess
batch_imgs.clear(); batch_imgs.clear();
img_paths.clear(); img_paths.clear();
det_bbox_num.clear(); det_bbox_num.clear();
det_result.clear(); det_result.clear();
feature.clear();
features.clear();
} }
std::string presion = "fp32"; std::string presion = "fp32";
......
...@@ -11,54 +11,52 @@ ...@@ -11,54 +11,52 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "include/vector_search.h"
#include <cstdio>
#include <faiss/index_factory.h> #include <faiss/index_factory.h>
#include <faiss/index_io.h> #include <faiss/index_io.h>
#include <fstream> #include <fstream>
#include <regex>
#include <iostream> #include <iostream>
#include <cstdio> #include <regex>
#include "include/vector_search.h"
void VectorSearch::LoadIndexFile(){ void VectorSearch::LoadIndexFile() {
std::string file_path = this->index_dir + OS_PATH_SEP + "vector.index"; std::string file_path = this->index_dir + OS_PATH_SEP + "vector.index";
const char* fname = file_path.c_str(); const char *fname = file_path.c_str();
this->index = faiss::read_index(fname, 0); this->index = faiss::read_index(fname, 0);
} }
void VectorSearch::LoadIdMap(){ void VectorSearch::LoadIdMap() {
std::string file_path = this->index_dir + OS_PATH_SEP + "id_map.txt"; std::string file_path = this->index_dir + OS_PATH_SEP + "id_map.txt";
std::ifstream in(file_path); std::ifstream in(file_path);
std::string line; std::string line;
std::vector<std::string> m_vec; std::vector<std::string> m_vec;
if (in){ if (in) {
while (getline(in, line)){ while (getline(in, line)) {
std::regex ws_re("\\s+"); std::regex ws_re("\\s+");
std::vector<std::string> v( std::vector<std::string> v(
std::sregex_token_iterator(line.begin(), line.end(), ws_re, -1), std::sregex_token_iterator(line.begin(), line.end(), ws_re, -1),
std::sregex_token_iterator()); std::sregex_token_iterator());
if (v.size() !=2){ if (v.size() != 2) {
std::cout << "The number of element for each line in : " << file_path std::cout << "The number of element for each line in : " << file_path
<< "must be 2, exit the program..." << std::endl; << "must be 2, exit the program..." << std::endl;
exit(1); exit(1);
}else } else
this->id_map.insert(std::pair<long int, std::string>(std::stol(v[0], nullptr, 10), v[1])); this->id_map.insert(std::pair<long int, std::string>(
std::stol(v[0], nullptr, 10), v[1]));
} }
} }
} }
void VectorSearch::Search(float *feature, int query_number){ const SearchResult &VectorSearch::Search(float *feature, int query_number) {
this->index->search(query_number, feature, return_k, D, I); this->D.resize(this->return_k * query_number);
this->real_query_number = query_number; this->I.resize(this->return_k * query_number);
} this->index->search(query_number, feature, return_k, D.data(), I.data());
const SearchResult& VectorSearch::GetSearchResult(){
this->sr.query_number = this->real_query_number;
this->sr.return_k = this->return_k; this->sr.return_k = this->return_k;
this->sr.D = this->D; this->sr.D = this->D;
this->sr.I = this->I; this->sr.I = this->I;
return this->sr; return this->sr;
} }
const std::string& VectorSearch::GetLabel(faiss::Index::idx_t ind){ const std::string &VectorSearch::GetLabel(faiss::Index::idx_t ind) {
return this->id_map.at(ind); return this->id_map.at(ind);
} }
...@@ -3,6 +3,7 @@ LIB_DIR=/work/project/project/cpp_infer/paddle_inference/ ...@@ -3,6 +3,7 @@ LIB_DIR=/work/project/project/cpp_infer/paddle_inference/
CUDA_LIB_DIR=/usr/local/cuda/lib64 CUDA_LIB_DIR=/usr/local/cuda/lib64
CUDNN_LIB_DIR=/usr/lib/x86_64-linux-gnu/ CUDNN_LIB_DIR=/usr/lib/x86_64-linux-gnu/
FAISS_DIR=/work/project/project/cpp_infer/faiss/faiss_install FAISS_DIR=/work/project/project/cpp_infer/faiss/faiss_install
FAISS_WITH_MKL=OFF
BUILD_DIR=build BUILD_DIR=build
rm -rf ${BUILD_DIR} rm -rf ${BUILD_DIR}
...@@ -18,6 +19,6 @@ cmake .. \ ...@@ -18,6 +19,6 @@ cmake .. \
-DCUDNN_LIB=${CUDNN_LIB_DIR} \ -DCUDNN_LIB=${CUDNN_LIB_DIR} \
-DCUDA_LIB=${CUDA_LIB_DIR} \ -DCUDA_LIB=${CUDA_LIB_DIR} \
-DFAISS_DIR=${FAISS_DIR} \ -DFAISS_DIR=${FAISS_DIR} \
-DFAISS_WITH_MKL=OFF -DFAISS_WITH_MKL=${FAISS_WITH_MKL}
make -j make -j
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册