未验证 提交 85218f9a 编写于 作者: Z zhiboniu 提交者: GitHub

lite deploy: fix pose visualize (#4349)

上级 4f0aa7ef
...@@ -43,7 +43,8 @@ struct KeyPointResult { ...@@ -43,7 +43,8 @@ struct KeyPointResult {
// Visualiztion KeyPoint Result // Visualiztion KeyPoint Result
cv::Mat VisualizeKptsResult(const cv::Mat& img, cv::Mat VisualizeKptsResult(const cv::Mat& img,
const std::vector<KeyPointResult>& results, const std::vector<KeyPointResult>& results,
const std::vector<int>& colormap); const std::vector<int>& colormap,
float threshold = 0.2);
class KeyPointDetector { class KeyPointDetector {
public: public:
...@@ -67,7 +68,6 @@ class KeyPointDetector { ...@@ -67,7 +68,6 @@ class KeyPointDetector {
void Predict(const std::vector<cv::Mat> imgs, void Predict(const std::vector<cv::Mat> imgs,
std::vector<std::vector<float>>& center, std::vector<std::vector<float>>& center,
std::vector<std::vector<float>>& scale, std::vector<std::vector<float>>& scale,
const double threshold = 0.5,
const int warmup = 0, const int warmup = 0,
const int repeats = 1, const int repeats = 1,
std::vector<KeyPointResult>* result = nullptr, std::vector<KeyPointResult>* result = nullptr,
...@@ -80,6 +80,8 @@ class KeyPointDetector { ...@@ -80,6 +80,8 @@ class KeyPointDetector {
bool use_dark(){return this->use_dark_;} bool use_dark(){return this->use_dark_;}
inline float get_threshold() {return threshold_;};
private: private:
// Preprocess image and copy data to input buffer // Preprocess image and copy data to input buffer
void Preprocess(const cv::Mat& image_mat); void Preprocess(const cv::Mat& image_mat);
......
...@@ -32,28 +32,29 @@ void KeyPointDetector::LoadModel(std::string model_file, int num_theads) { ...@@ -32,28 +32,29 @@ void KeyPointDetector::LoadModel(std::string model_file, int num_theads) {
// Visualiztion MaskDetector results // Visualiztion MaskDetector results
cv::Mat VisualizeKptsResult(const cv::Mat& img, cv::Mat VisualizeKptsResult(const cv::Mat& img,
const std::vector<KeyPointResult>& results, const std::vector<KeyPointResult>& results,
const std::vector<int>& colormap) { const std::vector<int>& colormap,
float threshold) {
const int edge[][2] = {{0, 1}, const int edge[][2] = {{0, 1},
{0, 2}, {0, 2},
{1, 3}, {1, 3},
{2, 4}, {2, 4},
{3, 5}, {3, 5},
{4, 6}, {4, 6},
{5, 7}, {5, 7},
{6, 8}, {6, 8},
{7, 9}, {7, 9},
{8, 10}, {8, 10},
{5, 11}, {5, 11},
{6, 12}, {6, 12},
{11, 13}, {11, 13},
{12, 14}, {12, 14},
{13, 15}, {13, 15},
{14, 16}, {14, 16},
{11, 12}}; {11, 12}};
cv::Mat vis_img = img.clone(); cv::Mat vis_img = img.clone();
for (int batchid = 0; batchid < results.size(); batchid++) { for (int batchid = 0; batchid < results.size(); batchid++) {
for (int i = 0; i < results[batchid].num_joints; i++) { for (int i = 0; i < results[batchid].num_joints; i++) {
if (results[batchid].keypoints[i * 3] > 0.5) { if (results[batchid].keypoints[i * 3] > threshold) {
int x_coord = int(results[batchid].keypoints[i * 3 + 1]); int x_coord = int(results[batchid].keypoints[i * 3 + 1]);
int y_coord = int(results[batchid].keypoints[i * 3 + 2]); int y_coord = int(results[batchid].keypoints[i * 3 + 2]);
cv::circle(vis_img, cv::circle(vis_img,
...@@ -64,15 +65,18 @@ cv::Mat VisualizeKptsResult(const cv::Mat& img, ...@@ -64,15 +65,18 @@ cv::Mat VisualizeKptsResult(const cv::Mat& img,
} }
} }
for (int i = 0; i < results[batchid].num_joints; i++) { for (int i = 0; i < results[batchid].num_joints; i++) {
int x_start = int(results[batchid].keypoints[edge[i][0] * 3 + 1]); if (results[batchid].keypoints[edge[i][0] * 3] > threshold &&
int y_start = int(results[batchid].keypoints[edge[i][0] * 3 + 2]); results[batchid].keypoints[edge[i][1] * 3] > threshold) {
int x_end = int(results[batchid].keypoints[edge[i][1] * 3 + 1]); int x_start = int(results[batchid].keypoints[edge[i][0] * 3 + 1]);
int y_end = int(results[batchid].keypoints[edge[i][1] * 3 + 2]); int y_start = int(results[batchid].keypoints[edge[i][0] * 3 + 2]);
cv::line(vis_img, int x_end = int(results[batchid].keypoints[edge[i][1] * 3 + 1]);
cv::Point2d(x_start, y_start), int y_end = int(results[batchid].keypoints[edge[i][1] * 3 + 2]);
cv::Point2d(x_end, y_end), cv::line(vis_img,
colormap[i], cv::Point2d(x_start, y_start),
1); cv::Point2d(x_end, y_end),
colormap[i],
1);
}
} }
} }
return vis_img; return vis_img;
...@@ -119,7 +123,6 @@ void KeyPointDetector::Postprocess(std::vector<float>& output, ...@@ -119,7 +123,6 @@ void KeyPointDetector::Postprocess(std::vector<float>& output,
void KeyPointDetector::Predict(const std::vector<cv::Mat> imgs, void KeyPointDetector::Predict(const std::vector<cv::Mat> imgs,
std::vector<std::vector<float>>& center_bs, std::vector<std::vector<float>>& center_bs,
std::vector<std::vector<float>>& scale_bs, std::vector<std::vector<float>>& scale_bs,
const double threshold,
const int warmup, const int warmup,
const int repeats, const int repeats,
std::vector<KeyPointResult>* result, std::vector<KeyPointResult>* result,
......
...@@ -238,7 +238,6 @@ void PredictImage(const std::vector<std::string> all_img_paths, ...@@ -238,7 +238,6 @@ void PredictImage(const std::vector<std::string> all_img_paths,
keypoint->Predict(imgs_kpts, keypoint->Predict(imgs_kpts,
center_bs, center_bs,
scale_bs, scale_bs,
0.5,
10, 10,
10, 10,
&result_kpts, &result_kpts,
...@@ -247,7 +246,6 @@ void PredictImage(const std::vector<std::string> all_img_paths, ...@@ -247,7 +246,6 @@ void PredictImage(const std::vector<std::string> all_img_paths,
keypoint->Predict(imgs_kpts, keypoint->Predict(imgs_kpts,
center_bs, center_bs,
scale_bs, scale_bs,
0.5,
0, 0,
1, 1,
&result_kpts, &result_kpts,
...@@ -265,7 +263,7 @@ void PredictImage(const std::vector<std::string> all_img_paths, ...@@ -265,7 +263,7 @@ void PredictImage(const std::vector<std::string> all_img_paths,
output_path + "keypoint_" + output_path + "keypoint_" +
image_file_path.substr(image_file_path.find_last_of('/') + 1); image_file_path.substr(image_file_path.find_last_of('/') + 1);
cv::Mat kpts_vis_img = cv::Mat kpts_vis_img =
VisualizeKptsResult(im, result_kpts, colormap_kpts); VisualizeKptsResult(im, result_kpts, colormap_kpts, keypoint->get_threshold());
cv::imwrite(kpts_savepath, kpts_vis_img, compression_params); cv::imwrite(kpts_savepath, kpts_vis_img, compression_params);
printf("Visualized output saved as %s\n", kpts_savepath.c_str()); printf("Visualized output saved as %s\n", kpts_savepath.c_str());
} else { } else {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册