From 478972999b91811d8d4af99c7388a6f748a0b3fc Mon Sep 17 00:00:00 2001 From: zhiboniu Date: Thu, 30 Jun 2022 07:49:23 +0000 Subject: [PATCH] cpp deploy smooth ok --- deploy/cpp/include/keypoint_detector.h | 6 - deploy/cpp/include/keypoint_postprocess.h | 75 +++++++++- deploy/cpp/src/keypoint_postprocess.cc | 163 ++++++++++++++++++---- deploy/cpp/src/main_keypoint.cc | 9 ++ 4 files changed, 215 insertions(+), 38 deletions(-) diff --git a/deploy/cpp/include/keypoint_detector.h b/deploy/cpp/include/keypoint_detector.h index 55eed8f91..ce6aa0e06 100644 --- a/deploy/cpp/include/keypoint_detector.h +++ b/deploy/cpp/include/keypoint_detector.h @@ -33,12 +33,6 @@ using namespace paddle_infer; namespace PaddleDetection { -// Object KeyPoint Result -struct KeyPointResult { - // Keypoints: shape(N x 3); N: number of Joints; 3: x,y,conf - std::vector keypoints; - int num_joints = -1; -}; // Visualiztion KeyPoint Result cv::Mat VisualizeKptsResult(const cv::Mat& img, diff --git a/deploy/cpp/include/keypoint_postprocess.h b/deploy/cpp/include/keypoint_postprocess.h index 4239cdf73..fa0c7d55f 100644 --- a/deploy/cpp/include/keypoint_postprocess.h +++ b/deploy/cpp/include/keypoint_postprocess.h @@ -14,11 +14,14 @@ #pragma once +#include #include #include #include #include +namespace PaddleDetection { + std::vector get_3rd_point(std::vector& a, std::vector& b); std::vector get_dir(float src_point_x, float src_point_y, float rot_rad); @@ -37,7 +40,8 @@ void transform_preds(std::vector& coords, std::vector& scale, std::vector& output_size, std::vector& dim, - std::vector& target_coords); + std::vector& target_coords, + bool affine = false); void box_to_center_scale(std::vector& box, int width, @@ -51,7 +55,7 @@ void get_max_preds(float* heatmap, float* maxvals, int batchid, int joint_idx); - + void get_final_preds(std::vector& heatmap, std::vector& dim, std::vector& idxout, @@ -61,3 +65,70 @@ void get_final_preds(std::vector& heatmap, std::vector& preds, int batchid, bool DARK = true); + +// Object KeyPoint Result +struct KeyPointResult { + // Keypoints: shape(N x 3); N: number of Joints; 3: x,y,conf + std::vector keypoints; + int num_joints = -1; +}; + +class PoseSmooth { + public: + explicit PoseSmooth(const int width, + const int height, + std::string filter_type = "OneEuro", + float alpha = 0.5, + float fc_d = 0.1, + float fc_min = 0.1, + float beta = 0.1, + float thres_mult = 0.3) + : width(width), + height(height), + alpha(alpha), + fc_d(fc_d), + fc_min(fc_min), + beta(beta), + filter_type(filter_type), + thres_mult(thres_mult){}; + + // Run predictor + KeyPointResult smooth_process(KeyPointResult* result); + void PointSmooth(KeyPointResult* result, + KeyPointResult* keypoint_smoothed, + std::vector thresholds, + int index); + float OneEuroFilter(float x_cur, float x_pre, int loc); + float smoothing_factor(float te, float fc); + float ExpSmoothing(float x_cur, float x_pre, int loc = 0); + + private: + int width = 0; + int height = 0; + float alpha = 0.; + float fc_d = 1.; + float fc_min = 0.; + float beta = 1.; + float thres_mult = 1.; + std::string filter_type = "OneEuro"; + std::vector thresholds = {0.005, + 0.005, + 0.005, + 0.005, + 0.005, + 0.01, + 0.01, + 0.01, + 0.01, + 0.01, + 0.01, + 0.01, + 0.01, + 0.01, + 0.01, + 0.01, + 0.01}; + KeyPointResult x_prev_hat; + KeyPointResult dx_prev_hat; +}; +} // namespace PaddleDetection diff --git a/deploy/cpp/src/keypoint_postprocess.cc b/deploy/cpp/src/keypoint_postprocess.cc index 405195c3e..4cb53d8aa 100644 --- a/deploy/cpp/src/keypoint_postprocess.cc +++ b/deploy/cpp/src/keypoint_postprocess.cc @@ -11,11 +11,13 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -#include #include "include/keypoint_postprocess.h" +#include #define PI 3.1415926535 #define HALF_CIRCLE_DEGREE 180 +namespace PaddleDetection { + cv::Point2f get_3rd_point(cv::Point2f& a, cv::Point2f& b) { cv::Point2f direct{a.x - b.x, a.y - b.y}; return cv::Point2f(a.x - direct.y, a.y + direct.x); @@ -52,7 +54,7 @@ void get_affine_transform(std::vector& center, float dst_h = static_cast(output_size[1]); float rot_rad = rot * PI / HALF_CIRCLE_DEGREE; std::vector src_dir = get_dir(-0.5 * src_w, 0, rot_rad); - std::vector dst_dir{-0.5 * dst_w, 0.0}; + std::vector dst_dir{-0.5f * dst_w, 0.0}; cv::Point2f srcPoint2f[3], dstPoint2f[3]; srcPoint2f[0] = cv::Point2f(center[0], center[1]); srcPoint2f[1] = cv::Point2f(center[0] + src_dir[0], center[1] + src_dir[1]); @@ -74,11 +76,26 @@ void transform_preds(std::vector& coords, std::vector& scale, std::vector& output_size, std::vector& dim, - std::vector& target_coords) { - cv::Mat trans(2, 3, CV_64FC1); - get_affine_transform(center, scale, 0, output_size, trans, 1); - for (int p = 0; p < dim[1]; ++p) { - affine_tranform(coords[p * 2], coords[p * 2 + 1], trans, target_coords, p); + std::vector& target_coords, + bool affine) { + if (affine) { + cv::Mat trans(2, 3, CV_64FC1); + get_affine_transform(center, scale, 0, output_size, trans, 1); + for (int p = 0; p < dim[1]; ++p) { + affine_tranform( + coords[p * 2], coords[p * 2 + 1], trans, target_coords, p); + } + } else { + float heat_w = static_cast(output_size[0]); + float heat_h = static_cast(output_size[1]); + float x_scale = scale[0] / heat_w; + float y_scale = scale[1] / heat_h; + float offset_x = center[0] - scale[0] / 2.; + float offset_y = center[1] - scale[1] / 2.; + for (int i = 0; i < dim[1]; i++) { + target_coords[i * 3 + 1] = x_scale * coords[i * 2] + offset_x; + target_coords[i * 3 + 2] = y_scale * coords[i * 2 + 1] + offset_y; + } } } @@ -111,10 +128,10 @@ void get_max_preds(float* heatmap, void dark_parse(std::vector& heatmap, std::vector& dim, std::vector& coords, - int px, - int py, + int px, + int py, int index, - int ch){ + int ch) { /*DARK postpocessing, Zhang et al. Distribution-Aware Coordinate Representation for Human Pose Estimation (CVPR 2020). 1) offset = - hassian.inv() * derivative @@ -124,16 +141,17 @@ void dark_parse(std::vector& heatmap, 5) hassian = Mat([[dxx, dxy], [dxy, dyy]]) */ std::vector::const_iterator first1 = heatmap.begin() + index; - std::vector::const_iterator last1 = heatmap.begin() + index + dim[2] * dim[3]; + std::vector::const_iterator last1 = + heatmap.begin() + index + dim[2] * dim[3]; std::vector heatmap_ch(first1, last1); - cv::Mat heatmap_mat = cv::Mat(heatmap_ch).reshape(0,dim[2]); + cv::Mat heatmap_mat = cv::Mat(heatmap_ch).reshape(0, dim[2]); heatmap_mat.convertTo(heatmap_mat, CV_32FC1); cv::GaussianBlur(heatmap_mat, heatmap_mat, cv::Size(3, 3), 0, 0); - heatmap_mat = heatmap_mat.reshape(1,1); - heatmap_ch = std::vector(heatmap_mat.reshape(1,1)); + heatmap_mat = heatmap_mat.reshape(1, 1); + heatmap_ch = std::vector(heatmap_mat.reshape(1, 1)); float epsilon = 1e-10; - //sample heatmap to get values in around target location + // sample heatmap to get values in around target location float xy = log(fmax(heatmap_ch[py * dim[3] + px], epsilon)); float xr = log(fmax(heatmap_ch[py * dim[3] + px + 1], epsilon)); float xl = log(fmax(heatmap_ch[py * dim[3] + px - 1], epsilon)); @@ -149,22 +167,23 @@ void dark_parse(std::vector& heatmap, float xlyu = log(fmax(heatmap_ch[(py + 1) * dim[3] + px - 1], epsilon)); float xlyd = log(fmax(heatmap_ch[(py - 1) * dim[3] + px - 1], epsilon)); - //compute dx/dy and dxx/dyy with sampled values + // compute dx/dy and dxx/dyy with sampled values float dx = 0.5 * (xr - xl); float dy = 0.5 * (yu - yd); - float dxx = 0.25 * (xr2 - 2*xy + xl2); + float dxx = 0.25 * (xr2 - 2 * xy + xl2); float dxy = 0.25 * (xryu - xryd - xlyu + xlyd); - float dyy = 0.25 * (yu2 - 2*xy + yd2); + float dyy = 0.25 * (yu2 - 2 * xy + yd2); - //finally get offset by derivative and hassian, which combined by dx/dy and dxx/dyy - if(dxx * dyy - dxy*dxy != 0){ + // finally get offset by derivative and hassian, which combined by dx/dy and + // dxx/dyy + if (dxx * dyy - dxy * dxy != 0) { float M[2][2] = {dxx, dxy, dxy, dyy}; float D[2] = {dx, dy}; - cv::Mat hassian(2,2,CV_32F,M); - cv::Mat derivative(2,1,CV_32F,D); - cv::Mat offset = - hassian.inv() * derivative; - coords[ch * 2] += offset.at(0,0); - coords[ch * 2 + 1] += offset.at(1,0); + cv::Mat hassian(2, 2, CV_32F, M); + cv::Mat derivative(2, 1, CV_32F, D); + cv::Mat offset = -hassian.inv() * derivative; + coords[ch * 2] += offset.at(0, 0); + coords[ch * 2 + 1] += offset.at(1, 0); } } @@ -193,18 +212,18 @@ void get_final_preds(std::vector& heatmap, int px = int(coords[j * 2] + 0.5); int py = int(coords[j * 2 + 1] + 0.5); - if(DARK && px > 1 && px < heatmap_width - 2 && py > 1 && py < heatmap_height - 2){ + if (DARK && px > 1 && px < heatmap_width - 2 && py > 1 && + py < heatmap_height - 2) { dark_parse(heatmap, dim, coords, px, py, index, j); - } - else{ + } else { if (px > 0 && px < heatmap_width - 1) { float diff_x = heatmap[index + py * dim[3] + px + 1] - - heatmap[index + py * dim[3] + px - 1]; + heatmap[index + py * dim[3] + px - 1]; coords[j * 2] += diff_x > 0 ? 1 : -1 * 0.25; } if (py > 0 && py < heatmap_height - 1) { float diff_y = heatmap[index + (py + 1) * dim[3] + px] - - heatmap[index + (py - 1) * dim[3] + px]; + heatmap[index + (py - 1) * dim[3] + px]; coords[j * 2 + 1] += diff_y > 0 ? 1 : -1 * 0.25; } } @@ -213,3 +232,87 @@ void get_final_preds(std::vector& heatmap, std::vector img_size{heatmap_width, heatmap_height}; transform_preds(coords, center, scale, img_size, dim, preds); } + +// Run predictor +KeyPointResult PoseSmooth::smooth_process(KeyPointResult* result) { + KeyPointResult keypoint_smoothed = *result; + if (this->x_prev_hat.num_joints == -1) { + this->x_prev_hat = *result; + this->dx_prev_hat = *result; + std::fill(dx_prev_hat.keypoints.begin(), dx_prev_hat.keypoints.end(), 0.); + return keypoint_smoothed; + } else { + for (int i = 0; i < result->num_joints; i++) { + this->PointSmooth(result, &keypoint_smoothed, this->thresholds, i); + } + return keypoint_smoothed; + } +} + +void PoseSmooth::PointSmooth(KeyPointResult* result, + KeyPointResult* keypoint_smoothed, + std::vector thresholds, + int index) { + float distance = sqrt(pow((result->keypoints[index * 3 + 1] - + this->x_prev_hat.keypoints[index * 3 + 1]) / + this->width, + 2) + + pow((result->keypoints[index * 3 + 2] - + this->x_prev_hat.keypoints[index * 3 + 2]) / + this->height, + 2)); + if (distance < thresholds[index] * this->thres_mult) { + keypoint_smoothed->keypoints[index * 3 + 1] = + this->x_prev_hat.keypoints[index * 3 + 1]; + keypoint_smoothed->keypoints[index * 3 + 2] = + this->x_prev_hat.keypoints[index * 3 + 2]; + } else { + if (this->filter_type == "OneEuro") { + keypoint_smoothed->keypoints[index * 3 + 1] = + this->OneEuroFilter(result->keypoints[index * 3 + 1], + this->x_prev_hat.keypoints[index * 3 + 1], + index * 3 + 1); + keypoint_smoothed->keypoints[index * 3 + 2] = + this->OneEuroFilter(result->keypoints[index * 3 + 2], + this->x_prev_hat.keypoints[index * 3 + 2], + index * 3 + 2); + } else { + keypoint_smoothed->keypoints[index * 3 + 1] = + this->ExpSmoothing(result->keypoints[index * 3 + 1], + this->x_prev_hat.keypoints[index * 3 + 1], + index * 3 + 1); + keypoint_smoothed->keypoints[index * 3 + 2] = + this->ExpSmoothing(result->keypoints[index * 3 + 2], + this->x_prev_hat.keypoints[index * 3 + 2], + index * 3 + 2); + } + } + return; +} + +float PoseSmooth::OneEuroFilter(float x_cur, float x_pre, int loc) { + float te = 1.0; + this->alpha = this->smoothing_factor(te, this->fc_d); + float dx_cur = (x_cur - x_pre) / te; + float dx_cur_hat = + this->ExpSmoothing(dx_cur, this->dx_prev_hat.keypoints[loc]); + + float fc = this->fc_min + this->beta * abs(dx_cur_hat); + this->alpha = this->smoothing_factor(te, fc); + float x_cur_hat = this->ExpSmoothing(x_cur, x_pre); + // printf("alpha:%f, x_cur:%f, x_pre:%f, x_cur_hat:%f\n", this->alpha, x_cur, + // x_pre, x_cur_hat); + this->x_prev_hat.keypoints[loc] = x_cur_hat; + this->dx_prev_hat.keypoints[loc] = dx_cur_hat; + return x_cur_hat; +} + +float PoseSmooth::smoothing_factor(float te, float fc) { + float r = 2 * PI * fc * te; + return r / (r + 1); +} + +float PoseSmooth::ExpSmoothing(float x_cur, float x_pre, int loc) { + return this->alpha * x_cur + (1 - this->alpha) * x_pre; +} +} // namespace PaddleDetection diff --git a/deploy/cpp/src/main_keypoint.cc b/deploy/cpp/src/main_keypoint.cc index 7701d5ebb..da333f6eb 100644 --- a/deploy/cpp/src/main_keypoint.cc +++ b/deploy/cpp/src/main_keypoint.cc @@ -219,6 +219,8 @@ void PredictVideo(const std::string& video_path, printf("create video writer failed!\n"); return; } + PaddleDetection::PoseSmooth smoother = + PaddleDetection::PoseSmooth(video_width, video_height); std::vector result; std::vector bbox_num; @@ -307,6 +309,13 @@ void PredictVideo(const std::string& video_path, scale_bs.clear(); } } + + if (result_kpts.size() == 1) { + for (int i = 0; i < result_kpts.size(); i++) { + result_kpts[i] = smoother.smooth_process(&(result_kpts[i])); + } + } + cv::Mat out_im = VisualizeKptsResult(frame, result_kpts, colormap_kpts); video_out.write(out_im); } else { -- GitLab