//   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "include/keypoint_postprocess.h"
#include <math.h>
#define PI 3.1415926535
#define HALF_CIRCLE_DEGREE 180

namespace PaddleDetection {

// Returns the third anchor point used by cv::getAffineTransform: the vector
// (a - b) rotated by +90 degrees, (x, y) -> (-y, x), and attached at point a.
cv::Point2f get_3rd_point(cv::Point2f& a, cv::Point2f& b) {
  const float delta_x = a.x - b.x;
  const float delta_y = a.y - b.y;
  return cv::Point2f(a.x - delta_y, a.y + delta_x);
}

// Rotates the 2-D vector (src_point_x, src_point_y) counter-clockwise by
// rot_rad radians and returns the rotated vector as {x', y'}.
std::vector<float> get_dir(float src_point_x,
                           float src_point_y,
                           float rot_rad) {
  const float sn = sin(rot_rad);
  const float cs = cos(rot_rad);
  const float rotated_x = src_point_x * cs - src_point_y * sn;
  const float rotated_y = src_point_x * sn + src_point_y * cs;
  return std::vector<float>{rotated_x, rotated_y};
}

void affine_tranform(
    float pt_x, float pt_y, cv::Mat& trans, std::vector<float>& preds, int p) {
  double new1[3] = {pt_x, pt_y, 1.0};
  cv::Mat new_pt(3, 1, trans.type(), new1);
  cv::Mat w = trans * new_pt;
  preds[p * 3 + 1] = static_cast<float>(w.at<double>(0, 0));
  preds[p * 3 + 2] = static_cast<float>(w.at<double>(1, 0));
}

void get_affine_transform(std::vector<float>& center,
                          std::vector<float>& scale,
                          float rot,
                          std::vector<int>& output_size,
                          cv::Mat& trans,
                          int inv) {
  float src_w = scale[0];
  float dst_w = static_cast<float>(output_size[0]);
  float dst_h = static_cast<float>(output_size[1]);
  float rot_rad = rot * PI / HALF_CIRCLE_DEGREE;
  std::vector<float> src_dir = get_dir(-0.5 * src_w, 0, rot_rad);
Z
zhiboniu 已提交
57
  std::vector<float> dst_dir{-0.5f * dst_w, 0.0};
58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78
  cv::Point2f srcPoint2f[3], dstPoint2f[3];
  srcPoint2f[0] = cv::Point2f(center[0], center[1]);
  srcPoint2f[1] = cv::Point2f(center[0] + src_dir[0], center[1] + src_dir[1]);
  srcPoint2f[2] = get_3rd_point(srcPoint2f[0], srcPoint2f[1]);

  dstPoint2f[0] = cv::Point2f(dst_w * 0.5, dst_h * 0.5);
  dstPoint2f[1] =
      cv::Point2f(dst_w * 0.5 + dst_dir[0], dst_h * 0.5 + dst_dir[1]);
  dstPoint2f[2] = get_3rd_point(dstPoint2f[0], dstPoint2f[1]);
  if (inv == 0) {
    trans = cv::getAffineTransform(srcPoint2f, dstPoint2f);
  } else {
    trans = cv::getAffineTransform(dstPoint2f, srcPoint2f);
  }
}

// Converts heatmap-space keypoints back into original-image coordinates.
// `coords` holds dim[1] (x, y) pairs in heatmap space. Results are written to
// target_coords[i * 3 + 1] (x) and target_coords[i * 3 + 2] (y);
// target_coords[i * 3] is left untouched.
//   affine == false: a per-axis linear rescale from the output_size grid onto
//                    a box of size `scale` centred at `center`.
//   affine == true:  the inverse (inv = 1, rotation 0) of
//                    get_affine_transform, built from center/scale/output_size.
void transform_preds(std::vector<float>& coords,
                     std::vector<float>& center,
                     std::vector<float>& scale,
                     std::vector<int>& output_size,
                     std::vector<int>& dim,
                     std::vector<float>& target_coords,
                     bool affine) {
  const int num_points = dim[1];

  if (!affine) {
    const float heat_w = static_cast<float>(output_size[0]);
    const float heat_h = static_cast<float>(output_size[1]);
    const float x_scale = scale[0] / heat_w;
    const float y_scale = scale[1] / heat_h;
    const float offset_x = center[0] - scale[0] / 2.;
    const float offset_y = center[1] - scale[1] / 2.;
    for (int i = 0; i < num_points; ++i) {
      target_coords[i * 3 + 1] = x_scale * coords[i * 2] + offset_x;
      target_coords[i * 3 + 2] = y_scale * coords[i * 2 + 1] + offset_y;
    }
    return;
  }

  cv::Mat trans(2, 3, CV_64FC1);
  get_affine_transform(center, scale, 0, output_size, trans, 1);
  for (int i = 0; i < num_points; ++i) {
    affine_tranform(coords[i * 2], coords[i * 2 + 1], trans, target_coords, i);
  }
}

// Finds the arg-max of every joint heatmap of one batch item (`batchid`).
// `heatmap` is read as a dense [batch, joints, dim[2] rows, dim[3] cols]
// buffer. Outputs are indexed per joint only (no batch offset), so one call
// handles exactly one batch item:
//   preds[j * 2]     = column (x) of the peak
//   preds[j * 2 + 1] = row (y) of the peak
//   maxvals[j]       = peak value
// A peak that is not strictly positive counts as "no detection". Its
// coordinates are reset to 0 instead of keeping whatever the caller's buffer
// held. joint_idx is currently unused.
void get_max_preds(float* heatmap,
                   std::vector<int>& dim,
                   std::vector<float>& preds,
                   float* maxvals,
                   int batchid,
                   int joint_idx) {
  (void)joint_idx;
  const int num_joints = dim[1];
  const int height = dim[2];
  const int width = dim[3];
  const int plane_size = height * width;

  for (int j = 0; j < num_joints; ++j) {
    const float* plane = heatmap + (batchid * num_joints + j) * plane_size;
    const float* peak = std::max_element(plane, plane + plane_size);
    const auto peak_id = std::distance(plane, peak);
    maxvals[j] = *peak;
    if (*peak > 0) {
      preds[j * 2] = static_cast<float>(peak_id % width);
      preds[j * 2 + 1] = static_cast<float>(peak_id / width);
    } else {
      preds[j * 2] = 0.f;
      preds[j * 2 + 1] = 0.f;
    }
  }
}

void dark_parse(std::vector<float>& heatmap,
                std::vector<int>& dim,
                std::vector<float>& coords,
Z
zhiboniu 已提交
131 132
                int px,
                int py,
133
                int index,
Z
zhiboniu 已提交
134
                int ch) {
135 136 137 138 139 140 141 142 143
  /*DARK postpocessing, Zhang et al. Distribution-Aware Coordinate
  Representation for Human Pose Estimation (CVPR 2020).
  1) offset = - hassian.inv() * derivative
  2) dx = (heatmap[x+1] - heatmap[x-1])/2.
  3) dxx = (dx[x+1] - dx[x-1])/2.
  4) derivative = Mat([dx, dy])
  5) hassian = Mat([[dxx, dxy], [dxy, dyy]])
  */
  std::vector<float>::const_iterator first1 = heatmap.begin() + index;
Z
zhiboniu 已提交
144 145
  std::vector<float>::const_iterator last1 =
      heatmap.begin() + index + dim[2] * dim[3];
146
  std::vector<float> heatmap_ch(first1, last1);
Z
zhiboniu 已提交
147
  cv::Mat heatmap_mat = cv::Mat(heatmap_ch).reshape(0, dim[2]);
Z
zhiboniu 已提交
148 149
  heatmap_mat.convertTo(heatmap_mat, CV_32FC1);
  cv::GaussianBlur(heatmap_mat, heatmap_mat, cv::Size(3, 3), 0, 0);
Z
zhiboniu 已提交
150 151
  heatmap_mat = heatmap_mat.reshape(1, 1);
  heatmap_ch = std::vector<float>(heatmap_mat.reshape(1, 1));
152 153

  float epsilon = 1e-10;
Z
zhiboniu 已提交
154
  // sample heatmap to get values in around target location
155 156 157 158 159 160 161 162 163 164 165 166 167 168 169
  float xy = log(fmax(heatmap_ch[py * dim[3] + px], epsilon));
  float xr = log(fmax(heatmap_ch[py * dim[3] + px + 1], epsilon));
  float xl = log(fmax(heatmap_ch[py * dim[3] + px - 1], epsilon));

  float xr2 = log(fmax(heatmap_ch[py * dim[3] + px + 2], epsilon));
  float xl2 = log(fmax(heatmap_ch[py * dim[3] + px - 2], epsilon));
  float yu = log(fmax(heatmap_ch[(py + 1) * dim[3] + px], epsilon));
  float yd = log(fmax(heatmap_ch[(py - 1) * dim[3] + px], epsilon));
  float yu2 = log(fmax(heatmap_ch[(py + 2) * dim[3] + px], epsilon));
  float yd2 = log(fmax(heatmap_ch[(py - 2) * dim[3] + px], epsilon));
  float xryu = log(fmax(heatmap_ch[(py + 1) * dim[3] + px + 1], epsilon));
  float xryd = log(fmax(heatmap_ch[(py - 1) * dim[3] + px + 1], epsilon));
  float xlyu = log(fmax(heatmap_ch[(py + 1) * dim[3] + px - 1], epsilon));
  float xlyd = log(fmax(heatmap_ch[(py - 1) * dim[3] + px - 1], epsilon));

Z
zhiboniu 已提交
170
  // compute dx/dy and dxx/dyy with sampled values
171 172
  float dx = 0.5 * (xr - xl);
  float dy = 0.5 * (yu - yd);
Z
zhiboniu 已提交
173
  float dxx = 0.25 * (xr2 - 2 * xy + xl2);
174
  float dxy = 0.25 * (xryu - xryd - xlyu + xlyd);
Z
zhiboniu 已提交
175
  float dyy = 0.25 * (yu2 - 2 * xy + yd2);
176

Z
zhiboniu 已提交
177 178 179
  // finally get offset by derivative and hassian, which combined by dx/dy and
  // dxx/dyy
  if (dxx * dyy - dxy * dxy != 0) {
180 181
    float M[2][2] = {dxx, dxy, dxy, dyy};
    float D[2] = {dx, dy};
Z
zhiboniu 已提交
182 183 184 185 186
    cv::Mat hassian(2, 2, CV_32F, M);
    cv::Mat derivative(2, 1, CV_32F, D);
    cv::Mat offset = -hassian.inv() * derivative;
    coords[ch * 2] += offset.at<float>(0, 0);
    coords[ch * 2 + 1] += offset.at<float>(1, 0);
187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214
  }
}

void get_final_preds(std::vector<float>& heatmap,
                     std::vector<int>& dim,
                     std::vector<int64_t>& idxout,
                     std::vector<int>& idxdim,
                     std::vector<float>& center,
                     std::vector<float> scale,
                     std::vector<float>& preds,
                     int batchid,
                     bool DARK) {
  std::vector<float> coords;
  coords.resize(dim[1] * 2);
  int heatmap_height = dim[2];
  int heatmap_width = dim[3];

  for (int j = 0; j < dim[1]; ++j) {
    int index = (batchid * dim[1] + j) * dim[2] * dim[3];

    int idx = idxout[batchid * dim[1] + j];
    preds[j * 3] = heatmap[index + idx];
    coords[j * 2] = idx % heatmap_width;
    coords[j * 2 + 1] = idx / heatmap_width;

    int px = int(coords[j * 2] + 0.5);
    int py = int(coords[j * 2 + 1] + 0.5);

Z
zhiboniu 已提交
215 216
    if (DARK && px > 1 && px < heatmap_width - 2 && py > 1 &&
        py < heatmap_height - 2) {
217
      dark_parse(heatmap, dim, coords, px, py, index, j);
Z
zhiboniu 已提交
218
    } else {
219 220
      if (px > 0 && px < heatmap_width - 1) {
        float diff_x = heatmap[index + py * dim[3] + px + 1] -
Z
zhiboniu 已提交
221
                       heatmap[index + py * dim[3] + px - 1];
222 223 224 225
        coords[j * 2] += diff_x > 0 ? 1 : -1 * 0.25;
      }
      if (py > 0 && py < heatmap_height - 1) {
        float diff_y = heatmap[index + (py + 1) * dim[3] + px] -
Z
zhiboniu 已提交
226
                       heatmap[index + (py - 1) * dim[3] + px];
227 228 229 230 231 232 233 234
        coords[j * 2 + 1] += diff_y > 0 ? 1 : -1 * 0.25;
      }
    }
  }

  std::vector<int> img_size{heatmap_width, heatmap_height};
  transform_preds(coords, center, scale, img_size, dim, preds);
}

// Smooths one frame of keypoints and returns the smoothed copy.
// On the first call (x_prev_hat.num_joints == -1) the filter state is only
// seeded, and the input is returned unchanged:
//   - x_prev_hat becomes this frame;
//   - dx_prev_hat becomes a copy of it with every keypoint value zeroed.
// Later calls run PointSmooth on every joint of `result`.
KeyPointResult PoseSmooth::smooth_process(KeyPointResult* result) {
  KeyPointResult keypoint_smoothed = *result;
  const bool is_first_frame = (this->x_prev_hat.num_joints == -1);
  if (is_first_frame) {
    this->x_prev_hat = *result;
    this->dx_prev_hat = *result;
    std::fill(this->dx_prev_hat.keypoints.begin(),
              this->dx_prev_hat.keypoints.end(),
              0.);
    return keypoint_smoothed;
  }

  for (int joint = 0; joint < result->num_joints; ++joint) {
    this->PointSmooth(result, &keypoint_smoothed, this->thresholds, joint);
  }
  return keypoint_smoothed;
}

// Smooths keypoint `index` of `result` into `keypoint_smoothed`.
// Keypoints use three floats per joint. Slots index * 3 + 1 (x, normalised
// by width) and index * 3 + 2 (y, normalised by height) are read and written
// here; slot index * 3 is not touched.
//
// If the normalised displacement from the previous estimate (x_prev_hat) is
// below thresholds[index] * thres_mult, the previous position is reused to
// suppress jitter. Otherwise:
//   - filter_type == "OneEuro": x and y go through OneEuroFilter;
//   - any other filter_type: x and y go through plain exponential smoothing.
void PoseSmooth::PointSmooth(KeyPointResult* result,
                             KeyPointResult* keypoint_smoothed,
                             std::vector<float> thresholds,
                             int index) {
  const int x_loc = index * 3 + 1;
  const int y_loc = index * 3 + 2;
  auto& prev = this->x_prev_hat.keypoints;

  const auto dx_norm = (result->keypoints[x_loc] - prev[x_loc]) / this->width;
  const auto dy_norm = (result->keypoints[y_loc] - prev[y_loc]) / this->height;
  const float distance = sqrt(pow(dx_norm, 2) + pow(dy_norm, 2));

  if (distance < thresholds[index] * this->thres_mult) {
    keypoint_smoothed->keypoints[x_loc] = prev[x_loc];
    keypoint_smoothed->keypoints[y_loc] = prev[y_loc];
    return;
  }

  if (this->filter_type == "OneEuro") {
    // OneEuroFilter stores its own outputs in x_prev_hat / dx_prev_hat.
    keypoint_smoothed->keypoints[x_loc] =
        this->OneEuroFilter(result->keypoints[x_loc], prev[x_loc], x_loc);
    keypoint_smoothed->keypoints[y_loc] =
        this->OneEuroFilter(result->keypoints[y_loc], prev[y_loc], y_loc);
  } else {
    const float x_hat =
        this->ExpSmoothing(result->keypoints[x_loc], prev[x_loc], x_loc);
    const float y_hat =
        this->ExpSmoothing(result->keypoints[y_loc], prev[y_loc], y_loc);
    keypoint_smoothed->keypoints[x_loc] = x_hat;
    keypoint_smoothed->keypoints[y_loc] = y_hat;
    // Carry the smoothed value forward as the new previous estimate. Without
    // this, x_prev_hat stays frozen at the first frame, so every later frame
    // is blended with, and distance-tested against, that first frame.
    prev[x_loc] = x_hat;
    prev[y_loc] = y_hat;
  }
}

// One-Euro filter step for the keypoint value at `loc`, with a fixed sampling
// period te = 1:
//   1. Smooth the derivative (x_cur - x_pre) against dx_prev_hat using the
//      cutoff fc_d.
//   2. Raise the cutoff to fc_min + beta * |smoothed derivative|.
//   3. Smooth x_cur against x_pre with that adaptive cutoff.
// Side effects: overwrites this->alpha, and stores the new estimates in
// x_prev_hat.keypoints[loc] and dx_prev_hat.keypoints[loc].
float PoseSmooth::OneEuroFilter(float x_cur, float x_pre, int loc) {
  const float te = 1.0;
  this->alpha = this->smoothing_factor(te, this->fc_d);
  const float dx_cur = (x_cur - x_pre) / te;
  const float dx_cur_hat =
      this->ExpSmoothing(dx_cur, this->dx_prev_hat.keypoints[loc]);

  // fabs rather than unqualified abs: depending on which headers are
  // visible, abs(float) can resolve to C's int abs(int) and truncate the
  // derivative magnitude to an integer.
  const float fc = this->fc_min + this->beta * fabs(dx_cur_hat);
  this->alpha = this->smoothing_factor(te, fc);
  const float x_cur_hat = this->ExpSmoothing(x_cur, x_pre);

  this->x_prev_hat.keypoints[loc] = x_cur_hat;
  this->dx_prev_hat.keypoints[loc] = dx_cur_hat;
  return x_cur_hat;
}

// Weight of the newest sample in a first-order low-pass filter with cutoff
// frequency fc and sampling period te: r / (r + 1), where r = 2 * PI * fc * te.
float PoseSmooth::smoothing_factor(float te, float fc) {
  const float r = 2 * PI * fc * te;
  const float denominator = r + 1;
  return r / denominator;
}

// Exponential smoothing: returns alpha * x_cur + (1 - alpha) * x_pre, using
// the current value of this->alpha. `loc` is not used.
float PoseSmooth::ExpSmoothing(float x_cur, float x_pre, int loc) {
  const auto weight_new = this->alpha;
  return weight_new * x_cur + (1 - weight_new) * x_pre;
}
}  // namespace PaddleDetection