未验证 提交 85218f9a 编写于 作者: Z zhiboniu 提交者: GitHub

lite deploy: fix pose visualize (#4349)

上级 4f0aa7ef
......@@ -43,7 +43,8 @@ struct KeyPointResult {
// Visualiztion KeyPoint Result
cv::Mat VisualizeKptsResult(const cv::Mat& img,
const std::vector<KeyPointResult>& results,
const std::vector<int>& colormap);
const std::vector<int>& colormap,
float threshold = 0.2);
class KeyPointDetector {
public:
......@@ -67,7 +68,6 @@ class KeyPointDetector {
void Predict(const std::vector<cv::Mat> imgs,
std::vector<std::vector<float>>& center,
std::vector<std::vector<float>>& scale,
const double threshold = 0.5,
const int warmup = 0,
const int repeats = 1,
std::vector<KeyPointResult>* result = nullptr,
......@@ -80,6 +80,8 @@ class KeyPointDetector {
bool use_dark(){return this->use_dark_;}
inline float get_threshold() {return threshold_;};
private:
// Preprocess image and copy data to input buffer
void Preprocess(const cv::Mat& image_mat);
......
......@@ -32,7 +32,8 @@ void KeyPointDetector::LoadModel(std::string model_file, int num_theads) {
// Visualiztion MaskDetector results
cv::Mat VisualizeKptsResult(const cv::Mat& img,
const std::vector<KeyPointResult>& results,
const std::vector<int>& colormap) {
const std::vector<int>& colormap,
float threshold) {
const int edge[][2] = {{0, 1},
{0, 2},
{1, 3},
......@@ -53,7 +54,7 @@ cv::Mat VisualizeKptsResult(const cv::Mat& img,
cv::Mat vis_img = img.clone();
for (int batchid = 0; batchid < results.size(); batchid++) {
for (int i = 0; i < results[batchid].num_joints; i++) {
if (results[batchid].keypoints[i * 3] > 0.5) {
if (results[batchid].keypoints[i * 3] > threshold) {
int x_coord = int(results[batchid].keypoints[i * 3 + 1]);
int y_coord = int(results[batchid].keypoints[i * 3 + 2]);
cv::circle(vis_img,
......@@ -64,6 +65,8 @@ cv::Mat VisualizeKptsResult(const cv::Mat& img,
}
}
for (int i = 0; i < results[batchid].num_joints; i++) {
if (results[batchid].keypoints[edge[i][0] * 3] > threshold &&
results[batchid].keypoints[edge[i][1] * 3] > threshold) {
int x_start = int(results[batchid].keypoints[edge[i][0] * 3 + 1]);
int y_start = int(results[batchid].keypoints[edge[i][0] * 3 + 2]);
int x_end = int(results[batchid].keypoints[edge[i][1] * 3 + 1]);
......@@ -75,6 +78,7 @@ cv::Mat VisualizeKptsResult(const cv::Mat& img,
1);
}
}
}
return vis_img;
}
......@@ -119,7 +123,6 @@ void KeyPointDetector::Postprocess(std::vector<float>& output,
void KeyPointDetector::Predict(const std::vector<cv::Mat> imgs,
std::vector<std::vector<float>>& center_bs,
std::vector<std::vector<float>>& scale_bs,
const double threshold,
const int warmup,
const int repeats,
std::vector<KeyPointResult>* result,
......
......@@ -238,7 +238,6 @@ void PredictImage(const std::vector<std::string> all_img_paths,
keypoint->Predict(imgs_kpts,
center_bs,
scale_bs,
0.5,
10,
10,
&result_kpts,
......@@ -247,7 +246,6 @@ void PredictImage(const std::vector<std::string> all_img_paths,
keypoint->Predict(imgs_kpts,
center_bs,
scale_bs,
0.5,
0,
1,
&result_kpts,
......@@ -265,7 +263,7 @@ void PredictImage(const std::vector<std::string> all_img_paths,
output_path + "keypoint_" +
image_file_path.substr(image_file_path.find_last_of('/') + 1);
cv::Mat kpts_vis_img =
VisualizeKptsResult(im, result_kpts, colormap_kpts);
VisualizeKptsResult(im, result_kpts, colormap_kpts, keypoint->get_threshold());
cv::imwrite(kpts_savepath, kpts_vis_img, compression_params);
printf("Visualized output saved as %s\n", kpts_savepath.c_str());
} else {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册