diff --git a/deploy/lite/include/keypoint_detector.h b/deploy/lite/include/keypoint_detector.h index a4feb01cb9dbe286a386133c07cc90e816957a77..d41ba0adde31b81c6a797a7c70cae7ec7fdac37d 100644 --- a/deploy/lite/include/keypoint_detector.h +++ b/deploy/lite/include/keypoint_detector.h @@ -43,7 +43,8 @@ struct KeyPointResult { // Visualiztion KeyPoint Result cv::Mat VisualizeKptsResult(const cv::Mat& img, const std::vector& results, - const std::vector& colormap); + const std::vector& colormap, + float threshold = 0.2); class KeyPointDetector { public: @@ -67,7 +68,6 @@ class KeyPointDetector { void Predict(const std::vector imgs, std::vector>& center, std::vector>& scale, - const double threshold = 0.5, const int warmup = 0, const int repeats = 1, std::vector* result = nullptr, @@ -80,6 +80,8 @@ class KeyPointDetector { bool use_dark(){return this->use_dark_;} + inline float get_threshold() {return threshold_;}; + private: // Preprocess image and copy data to input buffer void Preprocess(const cv::Mat& image_mat); diff --git a/deploy/lite/src/keypoint_detector.cc b/deploy/lite/src/keypoint_detector.cc index f698bd41aa80d86d68acc5a6dffd3fcd76b2eba3..2be7471779355614457f52292443bf05ec73d21c 100644 --- a/deploy/lite/src/keypoint_detector.cc +++ b/deploy/lite/src/keypoint_detector.cc @@ -32,28 +32,29 @@ void KeyPointDetector::LoadModel(std::string model_file, int num_theads) { // Visualiztion MaskDetector results cv::Mat VisualizeKptsResult(const cv::Mat& img, const std::vector& results, - const std::vector& colormap) { + const std::vector& colormap, + float threshold) { const int edge[][2] = {{0, 1}, - {0, 2}, - {1, 3}, - {2, 4}, - {3, 5}, - {4, 6}, - {5, 7}, - {6, 8}, - {7, 9}, - {8, 10}, - {5, 11}, - {6, 12}, - {11, 13}, - {12, 14}, - {13, 15}, - {14, 16}, - {11, 12}}; + {0, 2}, + {1, 3}, + {2, 4}, + {3, 5}, + {4, 6}, + {5, 7}, + {6, 8}, + {7, 9}, + {8, 10}, + {5, 11}, + {6, 12}, + {11, 13}, + {12, 14}, + {13, 15}, + {14, 16}, + {11, 12}}; cv::Mat vis_img = img.clone(); for (int batchid = 0; batchid < results.size(); batchid++) { for (int i = 0; i < results[batchid].num_joints; i++) { - if (results[batchid].keypoints[i * 3] > 0.5) { + if (results[batchid].keypoints[i * 3] > threshold) { int x_coord = int(results[batchid].keypoints[i * 3 + 1]); int y_coord = int(results[batchid].keypoints[i * 3 + 2]); cv::circle(vis_img, @@ -64,15 +65,18 @@ cv::Mat VisualizeKptsResult(const cv::Mat& img, } } for (int i = 0; i < results[batchid].num_joints; i++) { - int x_start = int(results[batchid].keypoints[edge[i][0] * 3 + 1]); - int y_start = int(results[batchid].keypoints[edge[i][0] * 3 + 2]); - int x_end = int(results[batchid].keypoints[edge[i][1] * 3 + 1]); - int y_end = int(results[batchid].keypoints[edge[i][1] * 3 + 2]); - cv::line(vis_img, - cv::Point2d(x_start, y_start), - cv::Point2d(x_end, y_end), - colormap[i], - 1); + if (results[batchid].keypoints[edge[i][0] * 3] > threshold && + results[batchid].keypoints[edge[i][1] * 3] > threshold) { + int x_start = int(results[batchid].keypoints[edge[i][0] * 3 + 1]); + int y_start = int(results[batchid].keypoints[edge[i][0] * 3 + 2]); + int x_end = int(results[batchid].keypoints[edge[i][1] * 3 + 1]); + int y_end = int(results[batchid].keypoints[edge[i][1] * 3 + 2]); + cv::line(vis_img, + cv::Point2d(x_start, y_start), + cv::Point2d(x_end, y_end), + colormap[i], + 1); + } } } return vis_img; @@ -119,7 +123,6 @@ void KeyPointDetector::Postprocess(std::vector& output, void KeyPointDetector::Predict(const std::vector imgs, std::vector>& center_bs, std::vector>& scale_bs, - const double threshold, const int warmup, const int repeats, std::vector* result, diff --git a/deploy/lite/src/main.cc b/deploy/lite/src/main.cc index 0e67b78ccc8c5eddc8edbeb32d976af16cd5e9f1..6d4f214f0d93aa9830d0d3a87a989f11087495a9 100644 --- a/deploy/lite/src/main.cc +++ b/deploy/lite/src/main.cc @@ -238,7 +238,6 @@ void PredictImage(const std::vector all_img_paths, keypoint->Predict(imgs_kpts, center_bs, scale_bs, - 0.5, 10, 10, &result_kpts, @@ -247,7 +246,6 @@ void PredictImage(const std::vector all_img_paths, keypoint->Predict(imgs_kpts, center_bs, scale_bs, - 0.5, 0, 1, &result_kpts, @@ -265,7 +263,7 @@ void PredictImage(const std::vector all_img_paths, output_path + "keypoint_" + image_file_path.substr(image_file_path.find_last_of('/') + 1); cv::Mat kpts_vis_img = - VisualizeKptsResult(im, result_kpts, colormap_kpts); + VisualizeKptsResult(im, result_kpts, colormap_kpts, keypoint->get_threshold()); cv::imwrite(kpts_savepath, kpts_vis_img, compression_params); printf("Visualized output saved as %s\n", kpts_savepath.c_str()); } else {