提交 47897299 编写于 作者: Z zhiboniu 提交者: zhiboniu

cpp deploy smooth ok

上级 5c3d64a4
......@@ -33,12 +33,6 @@
using namespace paddle_infer;
namespace PaddleDetection {
// Object KeyPoint Result
struct KeyPointResult {
// Keypoints: shape(N x 3); N: number of Joints; 3: x,y,conf
std::vector<float> keypoints;
int num_joints = -1;
};
// Visualiztion KeyPoint Result
cv::Mat VisualizeKptsResult(const cv::Mat& img,
......
......@@ -14,11 +14,14 @@
#pragma once
#include <math.h>
#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#include <vector>
namespace PaddleDetection {
std::vector<float> get_3rd_point(std::vector<float>& a, std::vector<float>& b);
std::vector<float> get_dir(float src_point_x, float src_point_y, float rot_rad);
......@@ -37,7 +40,8 @@ void transform_preds(std::vector<float>& coords,
std::vector<float>& scale,
std::vector<int>& output_size,
std::vector<int>& dim,
std::vector<float>& target_coords);
std::vector<float>& target_coords,
bool affine = false);
void box_to_center_scale(std::vector<int>& box,
int width,
......@@ -51,7 +55,7 @@ void get_max_preds(float* heatmap,
float* maxvals,
int batchid,
int joint_idx);
void get_final_preds(std::vector<float>& heatmap,
std::vector<int>& dim,
std::vector<int64_t>& idxout,
......@@ -61,3 +65,70 @@ void get_final_preds(std::vector<float>& heatmap,
std::vector<float>& preds,
int batchid,
bool DARK = true);
// Object KeyPoint Result
struct KeyPointResult {
// Keypoints: shape(N x 3); N: number of Joints; 3: x,y,conf
std::vector<float> keypoints;
int num_joints = -1;
};
class PoseSmooth {
public:
explicit PoseSmooth(const int width,
const int height,
std::string filter_type = "OneEuro",
float alpha = 0.5,
float fc_d = 0.1,
float fc_min = 0.1,
float beta = 0.1,
float thres_mult = 0.3)
: width(width),
height(height),
alpha(alpha),
fc_d(fc_d),
fc_min(fc_min),
beta(beta),
filter_type(filter_type),
thres_mult(thres_mult){};
// Run predictor
KeyPointResult smooth_process(KeyPointResult* result);
void PointSmooth(KeyPointResult* result,
KeyPointResult* keypoint_smoothed,
std::vector<float> thresholds,
int index);
float OneEuroFilter(float x_cur, float x_pre, int loc);
float smoothing_factor(float te, float fc);
float ExpSmoothing(float x_cur, float x_pre, int loc = 0);
private:
int width = 0;
int height = 0;
float alpha = 0.;
float fc_d = 1.;
float fc_min = 0.;
float beta = 1.;
float thres_mult = 1.;
std::string filter_type = "OneEuro";
std::vector<float> thresholds = {0.005,
0.005,
0.005,
0.005,
0.005,
0.01,
0.01,
0.01,
0.01,
0.01,
0.01,
0.01,
0.01,
0.01,
0.01,
0.01,
0.01};
KeyPointResult x_prev_hat;
KeyPointResult dx_prev_hat;
};
} // namespace PaddleDetection
......@@ -11,11 +11,13 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <math.h>
#include "include/keypoint_postprocess.h"
#include <math.h>
#define PI 3.1415926535
#define HALF_CIRCLE_DEGREE 180
namespace PaddleDetection {
cv::Point2f get_3rd_point(cv::Point2f& a, cv::Point2f& b) {
cv::Point2f direct{a.x - b.x, a.y - b.y};
return cv::Point2f(a.x - direct.y, a.y + direct.x);
......@@ -52,7 +54,7 @@ void get_affine_transform(std::vector<float>& center,
float dst_h = static_cast<float>(output_size[1]);
float rot_rad = rot * PI / HALF_CIRCLE_DEGREE;
std::vector<float> src_dir = get_dir(-0.5 * src_w, 0, rot_rad);
std::vector<float> dst_dir{-0.5 * dst_w, 0.0};
std::vector<float> dst_dir{-0.5f * dst_w, 0.0};
cv::Point2f srcPoint2f[3], dstPoint2f[3];
srcPoint2f[0] = cv::Point2f(center[0], center[1]);
srcPoint2f[1] = cv::Point2f(center[0] + src_dir[0], center[1] + src_dir[1]);
......@@ -74,11 +76,26 @@ void transform_preds(std::vector<float>& coords,
std::vector<float>& scale,
std::vector<int>& output_size,
std::vector<int>& dim,
std::vector<float>& target_coords) {
cv::Mat trans(2, 3, CV_64FC1);
get_affine_transform(center, scale, 0, output_size, trans, 1);
for (int p = 0; p < dim[1]; ++p) {
affine_tranform(coords[p * 2], coords[p * 2 + 1], trans, target_coords, p);
std::vector<float>& target_coords,
bool affine) {
if (affine) {
cv::Mat trans(2, 3, CV_64FC1);
get_affine_transform(center, scale, 0, output_size, trans, 1);
for (int p = 0; p < dim[1]; ++p) {
affine_tranform(
coords[p * 2], coords[p * 2 + 1], trans, target_coords, p);
}
} else {
float heat_w = static_cast<float>(output_size[0]);
float heat_h = static_cast<float>(output_size[1]);
float x_scale = scale[0] / heat_w;
float y_scale = scale[1] / heat_h;
float offset_x = center[0] - scale[0] / 2.;
float offset_y = center[1] - scale[1] / 2.;
for (int i = 0; i < dim[1]; i++) {
target_coords[i * 3 + 1] = x_scale * coords[i * 2] + offset_x;
target_coords[i * 3 + 2] = y_scale * coords[i * 2 + 1] + offset_y;
}
}
}
......@@ -111,10 +128,10 @@ void get_max_preds(float* heatmap,
void dark_parse(std::vector<float>& heatmap,
std::vector<int>& dim,
std::vector<float>& coords,
int px,
int py,
int px,
int py,
int index,
int ch){
int ch) {
/*DARK postpocessing, Zhang et al. Distribution-Aware Coordinate
Representation for Human Pose Estimation (CVPR 2020).
1) offset = - hassian.inv() * derivative
......@@ -124,16 +141,17 @@ void dark_parse(std::vector<float>& heatmap,
5) hassian = Mat([[dxx, dxy], [dxy, dyy]])
*/
std::vector<float>::const_iterator first1 = heatmap.begin() + index;
std::vector<float>::const_iterator last1 = heatmap.begin() + index + dim[2] * dim[3];
std::vector<float>::const_iterator last1 =
heatmap.begin() + index + dim[2] * dim[3];
std::vector<float> heatmap_ch(first1, last1);
cv::Mat heatmap_mat = cv::Mat(heatmap_ch).reshape(0,dim[2]);
cv::Mat heatmap_mat = cv::Mat(heatmap_ch).reshape(0, dim[2]);
heatmap_mat.convertTo(heatmap_mat, CV_32FC1);
cv::GaussianBlur(heatmap_mat, heatmap_mat, cv::Size(3, 3), 0, 0);
heatmap_mat = heatmap_mat.reshape(1,1);
heatmap_ch = std::vector<float>(heatmap_mat.reshape(1,1));
heatmap_mat = heatmap_mat.reshape(1, 1);
heatmap_ch = std::vector<float>(heatmap_mat.reshape(1, 1));
float epsilon = 1e-10;
//sample heatmap to get values in around target location
// sample heatmap to get values in around target location
float xy = log(fmax(heatmap_ch[py * dim[3] + px], epsilon));
float xr = log(fmax(heatmap_ch[py * dim[3] + px + 1], epsilon));
float xl = log(fmax(heatmap_ch[py * dim[3] + px - 1], epsilon));
......@@ -149,22 +167,23 @@ void dark_parse(std::vector<float>& heatmap,
float xlyu = log(fmax(heatmap_ch[(py + 1) * dim[3] + px - 1], epsilon));
float xlyd = log(fmax(heatmap_ch[(py - 1) * dim[3] + px - 1], epsilon));
//compute dx/dy and dxx/dyy with sampled values
// compute dx/dy and dxx/dyy with sampled values
float dx = 0.5 * (xr - xl);
float dy = 0.5 * (yu - yd);
float dxx = 0.25 * (xr2 - 2*xy + xl2);
float dxx = 0.25 * (xr2 - 2 * xy + xl2);
float dxy = 0.25 * (xryu - xryd - xlyu + xlyd);
float dyy = 0.25 * (yu2 - 2*xy + yd2);
float dyy = 0.25 * (yu2 - 2 * xy + yd2);
//finally get offset by derivative and hassian, which combined by dx/dy and dxx/dyy
if(dxx * dyy - dxy*dxy != 0){
// finally get offset by derivative and hassian, which combined by dx/dy and
// dxx/dyy
if (dxx * dyy - dxy * dxy != 0) {
float M[2][2] = {dxx, dxy, dxy, dyy};
float D[2] = {dx, dy};
cv::Mat hassian(2,2,CV_32F,M);
cv::Mat derivative(2,1,CV_32F,D);
cv::Mat offset = - hassian.inv() * derivative;
coords[ch * 2] += offset.at<float>(0,0);
coords[ch * 2 + 1] += offset.at<float>(1,0);
cv::Mat hassian(2, 2, CV_32F, M);
cv::Mat derivative(2, 1, CV_32F, D);
cv::Mat offset = -hassian.inv() * derivative;
coords[ch * 2] += offset.at<float>(0, 0);
coords[ch * 2 + 1] += offset.at<float>(1, 0);
}
}
......@@ -193,18 +212,18 @@ void get_final_preds(std::vector<float>& heatmap,
int px = int(coords[j * 2] + 0.5);
int py = int(coords[j * 2 + 1] + 0.5);
if(DARK && px > 1 && px < heatmap_width - 2 && py > 1 && py < heatmap_height - 2){
if (DARK && px > 1 && px < heatmap_width - 2 && py > 1 &&
py < heatmap_height - 2) {
dark_parse(heatmap, dim, coords, px, py, index, j);
}
else{
} else {
if (px > 0 && px < heatmap_width - 1) {
float diff_x = heatmap[index + py * dim[3] + px + 1] -
heatmap[index + py * dim[3] + px - 1];
heatmap[index + py * dim[3] + px - 1];
coords[j * 2] += diff_x > 0 ? 1 : -1 * 0.25;
}
if (py > 0 && py < heatmap_height - 1) {
float diff_y = heatmap[index + (py + 1) * dim[3] + px] -
heatmap[index + (py - 1) * dim[3] + px];
heatmap[index + (py - 1) * dim[3] + px];
coords[j * 2 + 1] += diff_y > 0 ? 1 : -1 * 0.25;
}
}
......@@ -213,3 +232,87 @@ void get_final_preds(std::vector<float>& heatmap,
std::vector<int> img_size{heatmap_width, heatmap_height};
transform_preds(coords, center, scale, img_size, dim, preds);
}
// Run predictor
KeyPointResult PoseSmooth::smooth_process(KeyPointResult* result) {
KeyPointResult keypoint_smoothed = *result;
if (this->x_prev_hat.num_joints == -1) {
this->x_prev_hat = *result;
this->dx_prev_hat = *result;
std::fill(dx_prev_hat.keypoints.begin(), dx_prev_hat.keypoints.end(), 0.);
return keypoint_smoothed;
} else {
for (int i = 0; i < result->num_joints; i++) {
this->PointSmooth(result, &keypoint_smoothed, this->thresholds, i);
}
return keypoint_smoothed;
}
}
void PoseSmooth::PointSmooth(KeyPointResult* result,
KeyPointResult* keypoint_smoothed,
std::vector<float> thresholds,
int index) {
float distance = sqrt(pow((result->keypoints[index * 3 + 1] -
this->x_prev_hat.keypoints[index * 3 + 1]) /
this->width,
2) +
pow((result->keypoints[index * 3 + 2] -
this->x_prev_hat.keypoints[index * 3 + 2]) /
this->height,
2));
if (distance < thresholds[index] * this->thres_mult) {
keypoint_smoothed->keypoints[index * 3 + 1] =
this->x_prev_hat.keypoints[index * 3 + 1];
keypoint_smoothed->keypoints[index * 3 + 2] =
this->x_prev_hat.keypoints[index * 3 + 2];
} else {
if (this->filter_type == "OneEuro") {
keypoint_smoothed->keypoints[index * 3 + 1] =
this->OneEuroFilter(result->keypoints[index * 3 + 1],
this->x_prev_hat.keypoints[index * 3 + 1],
index * 3 + 1);
keypoint_smoothed->keypoints[index * 3 + 2] =
this->OneEuroFilter(result->keypoints[index * 3 + 2],
this->x_prev_hat.keypoints[index * 3 + 2],
index * 3 + 2);
} else {
keypoint_smoothed->keypoints[index * 3 + 1] =
this->ExpSmoothing(result->keypoints[index * 3 + 1],
this->x_prev_hat.keypoints[index * 3 + 1],
index * 3 + 1);
keypoint_smoothed->keypoints[index * 3 + 2] =
this->ExpSmoothing(result->keypoints[index * 3 + 2],
this->x_prev_hat.keypoints[index * 3 + 2],
index * 3 + 2);
}
}
return;
}
float PoseSmooth::OneEuroFilter(float x_cur, float x_pre, int loc) {
float te = 1.0;
this->alpha = this->smoothing_factor(te, this->fc_d);
float dx_cur = (x_cur - x_pre) / te;
float dx_cur_hat =
this->ExpSmoothing(dx_cur, this->dx_prev_hat.keypoints[loc]);
float fc = this->fc_min + this->beta * abs(dx_cur_hat);
this->alpha = this->smoothing_factor(te, fc);
float x_cur_hat = this->ExpSmoothing(x_cur, x_pre);
// printf("alpha:%f, x_cur:%f, x_pre:%f, x_cur_hat:%f\n", this->alpha, x_cur,
// x_pre, x_cur_hat);
this->x_prev_hat.keypoints[loc] = x_cur_hat;
this->dx_prev_hat.keypoints[loc] = dx_cur_hat;
return x_cur_hat;
}
float PoseSmooth::smoothing_factor(float te, float fc) {
float r = 2 * PI * fc * te;
return r / (r + 1);
}
float PoseSmooth::ExpSmoothing(float x_cur, float x_pre, int loc) {
return this->alpha * x_cur + (1 - this->alpha) * x_pre;
}
} // namespace PaddleDetection
......@@ -219,6 +219,8 @@ void PredictVideo(const std::string& video_path,
printf("create video writer failed!\n");
return;
}
PaddleDetection::PoseSmooth smoother =
PaddleDetection::PoseSmooth(video_width, video_height);
std::vector<PaddleDetection::ObjectResult> result;
std::vector<int> bbox_num;
......@@ -307,6 +309,13 @@ void PredictVideo(const std::string& video_path,
scale_bs.clear();
}
}
if (result_kpts.size() == 1) {
for (int i = 0; i < result_kpts.size(); i++) {
result_kpts[i] = smoother.smooth_process(&(result_kpts[i]));
}
}
cv::Mat out_im = VisualizeKptsResult(frame, result_kpts, colormap_kpts);
video_out.write(out_im);
} else {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册