diff --git a/deploy/cpp_shitu/include/nms.h b/deploy/cpp_shitu/include/nms.h new file mode 100644 index 0000000000000000000000000000000000000000..f8d6604b010375c3e2fa85c7dc128bfaf1df39c2 --- /dev/null +++ b/deploy/cpp_shitu/include/nms.h @@ -0,0 +1,85 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This code is adpated from opencv(https://github.com/opencv/opencv) + +#include +#include + +template +static inline bool SortScorePairDescend(const std::pair &pair1, + const std::pair &pair2) { + return pair1.first > pair2.first; +} + +float RectOverlap(const PaddleDetection::ObjectResult &a, + const PaddleDetection::ObjectResult &b) { + float Aa = (a.rect[2] - a.rect[0] + 1) * (a.rect[3] - a.rect[1] + 1); + float Ab = (b.rect[2] - b.rect[0] + 1) * (b.rect[3] - b.rect[1] + 1); + + int iou_w = max(min(a.rect[2], b.rect[2]) - max(a.rect[0], b.rect[0]) + 1, 0); + int iou_h = max(min(a.rect[3], b.rect[3]) - max(a.rect[1], b.rect[1]) + 1, 0); + float Aab = iou_w * iou_h; + return Aab / (Aa + Ab - Aab); +} + +// Get max scores with corresponding indices. +// scores: a set of scores. +// threshold: only consider scores higher than the threshold. +// top_k: if -1, keep all; otherwise, keep at most top_k. +// score_index_vec: store the sorted (score, index) pair. +inline void +GetMaxScoreIndex(const std::vector &det_result, + const float threshold, + std::vector> &score_index_vec) { + // Generate index score pairs. + for (size_t i = 0; i < det_result.size(); ++i) { + if (det_result[i].confidence > threshold) { + score_index_vec.push_back(std::make_pair(det_result[i].confidence, i)); + } + } + + // Sort the score pair according to the scores in descending order + std::stable_sort(score_index_vec.begin(), score_index_vec.end(), + SortScorePairDescend); + + // // Keep top_k scores if needed. + // if (top_k > 0 && top_k < (int)score_index_vec.size()) + // { + // score_index_vec.resize(top_k); + // } +} + +void NMSBoxes(const std::vector det_result, + const float score_threshold, const float nms_threshold, + std::vector &indices) { + int a = 1; + // Get top_k scores (with corresponding indices). + std::vector> score_index_vec; + GetMaxScoreIndex(det_result, score_threshold, score_index_vec); + + // Do nms + indices.clear(); + for (size_t i = 0; i < score_index_vec.size(); ++i) { + const int idx = score_index_vec[i].second; + bool keep = true; + for (int k = 0; k < (int)indices.size() && keep; ++k) { + const int kept_idx = indices[k]; + float overlap = RectOverlap(det_result[idx], det_result[kept_idx]); + keep = overlap <= nms_threshold; + } + if (keep) + indices.push_back(idx); + } +} diff --git a/deploy/cpp_shitu/include/object_detector.h b/deploy/cpp_shitu/include/object_detector.h index 015885b1bf533437824638f6861a625975f3ee09..613fce7c51a3f2849fe3d24bac78c5a20de6a0d5 100644 --- a/deploy/cpp_shitu/include/object_detector.h +++ b/deploy/cpp_shitu/include/object_detector.h @@ -65,7 +65,6 @@ public: this->use_fp16_ = config_file["Global"]["use_fp16"].as(); this->model_dir_ = config_file["Global"]["det_inference_model_dir"].as(); - this->nms_thres_ = config_file["Global"]["rec_nms_thresold"].as(); this->threshold_ = config_file["Global"]["threshold"].as(); this->max_det_results_ = config_file["Global"]["max_det_results"].as(); this->image_shape_ = @@ -105,7 +104,6 @@ private: bool batch_size_ = 1; bool use_fp16_ = false; std::string model_dir_; - float nms_thres_ = 0.02; float threshold_ = 0.5; float max_det_results_ = 5; std::vector image_shape_ = {3, 640, 640}; diff --git a/deploy/cpp_shitu/src/main.cpp b/deploy/cpp_shitu/src/main.cpp index 796d03e61e411f5c82aac9702b4e77ca185b8165..ed78c6fe46f8ff78c37ab0a4215d1f6692f8e541 100644 --- a/deploy/cpp_shitu/src/main.cpp +++ b/deploy/cpp_shitu/src/main.cpp @@ -29,6 +29,7 @@ #include #include #include +#include #include #include #include @@ -132,6 +133,21 @@ void DetPredictImage(const std::vector &batch_imgs, } } +void PrintResult(std::string &img_path, + std::vector &det_result, + std::vector &indeices, VectorSearch &vector_search, + SearchResult &search_result) { + printf("%s:\n", img_path.c_str()); + for (int i = 0; i < indeices.size(); ++i) { + int t = indeices[i]; + printf("\tresult%d: bbox[%d, %d, %d, %d], score: %f, label: %s\n", i, + det_result[t].rect[0], det_result[t].rect[1], det_result[t].rect[2], + det_result[t].rect[3], det_result[t].confidence, + vector_search.GetLabel(search_result.I[search_result.return_k * t]) + .c_str()); + } +} + int main(int argc, char **argv) { google::ParseCommandLineFlags(&argc, &argv, true); std::string yaml_path = ""; @@ -169,6 +185,11 @@ int main(int argc, char **argv) { if (config.config_file["Global"]["max_det_results"].IsDefined()) { max_det_results = config.config_file["Global"]["max_det_results"].as(); } + float rec_nms_thresold = 0.05; + if (config.config_file["Global"]["rec_nms_thresold"].IsDefined()) { + rec_nms_thresold = + config.config_file["Global"]["rec_nms_thresold"].as(); + } // load image_file_path std::string path = @@ -184,16 +205,20 @@ int main(int argc, char **argv) { img_files_list.push_back(path); } std::cout << "img_file_list length: " << img_files_list.size() << std::endl; - - double elapsed_time = 0.0; + // for time log std::vector cls_times = {0, 0, 0}; std::vector det_times = {0, 0, 0}; + // for read images std::vector batch_imgs; std::vector img_paths; + // for detection std::vector det_result; std::vector det_bbox_num; + // for vector search std::vector features; std::vector feature; + // for nms + std::vector indeices; int warmup_iter = img_files_list.size() > 5 ? 5 : 0; for (int idx = 0; idx < img_files_list.size(); ++idx) { @@ -214,8 +239,8 @@ int main(int argc, char **argv) { det_bbox_num, det_times, visual_det, run_benchmark); // select max_det_results bbox - while (det_result.size() > max_det_results) { - det_result.pop_back(); + if (det_result.size() > max_det_results) { + det_result.resize(max_det_results); } // step2: add the whole image for recognition to improve recall PaddleDetection::ObjectResult result_whole_img = { @@ -238,6 +263,13 @@ int main(int argc, char **argv) { search_result = searcher.Search(features.data(), det_result.size()); // nms for search result + for (int i = 0; i < det_result.size(); ++i) { + det_result[i].confidence = search_result.D[search_result.return_k * i]; + } + NMSBoxes(det_result, detector.GetThreshold(), rec_nms_thresold, indeices); + + // print result + PrintResult(img_path, det_result, indeices, searcher, search_result); // for postprocess batch_imgs.clear(); @@ -246,6 +278,7 @@ int main(int argc, char **argv) { det_result.clear(); feature.clear(); features.clear(); + indeices.clear(); } std::string presion = "fp32";