未验证 提交 19bc3294 编写于 作者: Z zhiboniu 提交者: GitHub

add darkpose support (#4232)

上级 e226a78d
...@@ -49,9 +49,11 @@ class KeyPointDetector { ...@@ -49,9 +49,11 @@ class KeyPointDetector {
public: public:
explicit KeyPointDetector(const std::string& model_dir, explicit KeyPointDetector(const std::string& model_dir,
int cpu_threads = 1, int cpu_threads = 1,
const int batch_size = 1) { const int batch_size = 1,
bool use_dark = true) {
config_.load_config(model_dir); config_.load_config(model_dir);
threshold_ = config_.draw_threshold_; threshold_ = config_.draw_threshold_;
use_dark_ = use_dark;
preprocessor_.Init(config_.preprocess_info_); preprocessor_.Init(config_.preprocess_info_);
printf("before keypoint detector\n"); printf("before keypoint detector\n");
LoadModel(model_dir, cpu_threads); LoadModel(model_dir, cpu_threads);
...@@ -76,14 +78,16 @@ class KeyPointDetector { ...@@ -76,14 +78,16 @@ class KeyPointDetector {
return config_.label_list_; return config_.label_list_;
} }
bool use_dark(){return this->use_dark_;}
private: private:
// Preprocess image and copy data to input buffer // Preprocess image and copy data to input buffer
void Preprocess(const cv::Mat& image_mat); void Preprocess(const cv::Mat& image_mat);
// Postprocess result // Postprocess result
void Postprocess(const std::vector<float> output, void Postprocess(std::vector<float>& output,
const std::vector<int64_t> output_shape, std::vector<int64_t>& output_shape,
const std::vector<int64_t> idxout, std::vector<int64_t>& idxout,
const std::vector<int64_t> idx_shape, std::vector<int64_t>& idx_shape,
std::vector<KeyPointResult>* result, std::vector<KeyPointResult>* result,
std::vector<std::vector<float>>& center, std::vector<std::vector<float>>& center,
std::vector<std::vector<float>>& scale); std::vector<std::vector<float>>& scale);
...@@ -95,6 +99,7 @@ class KeyPointDetector { ...@@ -95,6 +99,7 @@ class KeyPointDetector {
std::vector<int64_t> idx_data_; std::vector<int64_t> idx_data_;
float threshold_; float threshold_;
ConfigPaser config_; ConfigPaser config_;
bool use_dark_;
}; };
} // namespace PaddleDetection } // namespace PaddleDetection
...@@ -22,34 +22,35 @@ ...@@ -22,34 +22,35 @@
std::vector<float> get_3rd_point(std::vector<float>& a, std::vector<float>& b); std::vector<float> get_3rd_point(std::vector<float>& a, std::vector<float>& b);
std::vector<float> get_dir(float src_point_x, float src_point_y, float rot_rad); std::vector<float> get_dir(float src_point_x, float src_point_y, float rot_rad);
void affine_tranform( void affine_tranform(
float pt_x, float pt_y, cv::Mat& trans, float* x, int p, int num); float pt_x, float pt_y, cv::Mat& trans, std::vector<float>& x, int p, int num);
cv::Mat get_affine_transform(std::vector<float>& center, cv::Mat get_affine_transform(std::vector<float>& center,
std::vector<float>& scale, std::vector<float>& scale,
float rot, float rot,
std::vector<int>& output_size, std::vector<int>& output_size,
int inv); int inv);
void transform_preds(float* coords, void transform_preds(std::vector<float>& coords,
std::vector<float>& center, std::vector<float>& center,
std::vector<float>& scale, std::vector<float>& scale,
std::vector<int>& output_size, std::vector<int>& output_size,
std::vector<int>& dim, std::vector<int>& dim,
float* target_coords); std::vector<float>& target_coords);
void box_to_center_scale(std::vector<int>& box, void box_to_center_scale(std::vector<int>& box,
int width, int width,
int height, int height,
std::vector<float>& center, std::vector<float>& center,
std::vector<float>& scale); std::vector<float>& scale);
void get_max_preds(float* heatmap, void get_max_preds(std::vector<float>& heatmap,
std::vector<int64_t>& dim, std::vector<int64_t>& dim,
float* preds, std::vector<float>& preds,
float* maxvals, std::vector<float>& maxvals,
int batchid, int batchid,
int joint_idx); int joint_idx);
void get_final_preds(float* heatmap, void get_final_preds(std::vector<float>& heatmap,
std::vector<int64_t>& dim, std::vector<int64_t>& dim,
int64_t* idxout, std::vector<int64_t>& idxout,
std::vector<int64_t>& idxdim, std::vector<int64_t>& idxdim,
std::vector<float>& center, std::vector<float>& center,
std::vector<float> scale, std::vector<float> scale,
float* preds, std::vector<float>& preds,
int batchid); int batchid,
bool DARK = true);
...@@ -5,8 +5,9 @@ ...@@ -5,8 +5,9 @@
"model_dir_keypoint": "./model_keypoint/", "model_dir_keypoint": "./model_keypoint/",
"batch_size_keypoint": 8, "batch_size_keypoint": 8,
"threshold_keypoint": 0.5, "threshold_keypoint": 0.5,
"image_file": "", "image_file": "./demo.jpg",
"image_dir": "", "image_dir": "",
"run_benchmark": false, "run_benchmark": false,
"cpu_threads": 1 "cpu_threads": 4,
"use_dark_decode": true
} }
...@@ -29,7 +29,11 @@ void KeyPointDetector::LoadModel(std::string model_file, int num_theads) { ...@@ -29,7 +29,11 @@ void KeyPointDetector::LoadModel(std::string model_file, int num_theads) {
predictor_ = std::move(CreatePaddlePredictor<MobileConfig>(config)); predictor_ = std::move(CreatePaddlePredictor<MobileConfig>(config));
} }
const int edge[][2] = {{0, 1}, // Visualiztion MaskDetector results
cv::Mat VisualizeKptsResult(const cv::Mat& img,
const std::vector<KeyPointResult>& results,
const std::vector<int>& colormap) {
const int edge[][2] = {{0, 1},
{0, 2}, {0, 2},
{1, 3}, {1, 3},
{2, 4}, {2, 4},
...@@ -46,10 +50,6 @@ const int edge[][2] = {{0, 1}, ...@@ -46,10 +50,6 @@ const int edge[][2] = {{0, 1},
{13, 15}, {13, 15},
{14, 16}, {14, 16},
{11, 12}}; {11, 12}};
// Visualiztion MaskDetector results
cv::Mat VisualizeKptsResult(const cv::Mat& img,
const std::vector<KeyPointResult>& results,
const std::vector<int>& colormap) {
cv::Mat vis_img = img.clone(); cv::Mat vis_img = img.clone();
for (int batchid = 0; batchid < results.size(); batchid++) { for (int batchid = 0; batchid < results.size(); batchid++) {
for (int i = 0; i < results[batchid].num_joints; i++) { for (int i = 0; i < results[batchid].num_joints; i++) {
...@@ -85,24 +85,25 @@ void KeyPointDetector::Preprocess(const cv::Mat& ori_im) { ...@@ -85,24 +85,25 @@ void KeyPointDetector::Preprocess(const cv::Mat& ori_im) {
preprocessor_.Run(&im, &inputs_); preprocessor_.Run(&im, &inputs_);
} }
void KeyPointDetector::Postprocess(std::vector<float> output, void KeyPointDetector::Postprocess(std::vector<float>& output,
std::vector<int64_t> output_shape, std::vector<int64_t>& output_shape,
std::vector<int64_t> idxout, std::vector<int64_t>& idxout,
std::vector<int64_t> idx_shape, std::vector<int64_t>& idx_shape,
std::vector<KeyPointResult>* result, std::vector<KeyPointResult>* result,
std::vector<std::vector<float>>& center_bs, std::vector<std::vector<float>>& center_bs,
std::vector<std::vector<float>>& scale_bs) { std::vector<std::vector<float>>& scale_bs) {
float* preds = new float[output_shape[1] * 3]{0}; std::vector<float> preds(output_shape[1] * 3, 0);
for (int batchid = 0; batchid < output_shape[0]; batchid++) { for (int batchid = 0; batchid < output_shape[0]; batchid++) {
get_final_preds(const_cast<float*>(output.data()), get_final_preds(output,
output_shape, output_shape,
idxout.data(), idxout,
idx_shape, idx_shape,
center_bs[batchid], center_bs[batchid],
scale_bs[batchid], scale_bs[batchid],
preds, preds,
batchid); batchid,
this->use_dark());
KeyPointResult result_item; KeyPointResult result_item;
result_item.num_joints = output_shape[1]; result_item.num_joints = output_shape[1];
result_item.keypoints.clear(); result_item.keypoints.clear();
...@@ -113,7 +114,6 @@ void KeyPointDetector::Postprocess(std::vector<float> output, ...@@ -113,7 +114,6 @@ void KeyPointDetector::Postprocess(std::vector<float> output,
} }
result->push_back(result_item); result->push_back(result_item);
} }
delete[] preds;
} }
void KeyPointDetector::Predict(const std::vector<cv::Mat> imgs, void KeyPointDetector::Predict(const std::vector<cv::Mat> imgs,
......
...@@ -13,6 +13,8 @@ ...@@ -13,6 +13,8 @@
// limitations under the License. // limitations under the License.
#include "include/keypoint_postprocess.h" #include "include/keypoint_postprocess.h"
#define PI 3.1415926535
#define HALF_CIRCLE_DEGREE 180
cv::Point2f get_3rd_point(cv::Point2f& a, cv::Point2f& b) { cv::Point2f get_3rd_point(cv::Point2f& a, cv::Point2f& b) {
cv::Point2f direct{a.x - b.x, a.y - b.y}; cv::Point2f direct{a.x - b.x, a.y - b.y};
...@@ -31,7 +33,7 @@ std::vector<float> get_dir(float src_point_x, ...@@ -31,7 +33,7 @@ std::vector<float> get_dir(float src_point_x,
} }
void affine_tranform( void affine_tranform(
float pt_x, float pt_y, cv::Mat& trans, float* preds, int p) { float pt_x, float pt_y, cv::Mat& trans, std::vector<float>& preds, int p) {
double new1[3] = {pt_x, pt_y, 1.0}; double new1[3] = {pt_x, pt_y, 1.0};
cv::Mat new_pt(3, 1, trans.type(), new1); cv::Mat new_pt(3, 1, trans.type(), new1);
cv::Mat w = trans * new_pt; cv::Mat w = trans * new_pt;
...@@ -48,7 +50,7 @@ void get_affine_transform(std::vector<float>& center, ...@@ -48,7 +50,7 @@ void get_affine_transform(std::vector<float>& center,
float src_w = scale[0]; float src_w = scale[0];
float dst_w = static_cast<float>(output_size[0]); float dst_w = static_cast<float>(output_size[0]);
float dst_h = static_cast<float>(output_size[1]); float dst_h = static_cast<float>(output_size[1]);
float rot_rad = rot * 3.1415926535 / 180; float rot_rad = rot * PI / HALF_CIRCLE_DEGREE;
std::vector<float> src_dir = get_dir(-0.5 * src_w, 0, rot_rad); std::vector<float> src_dir = get_dir(-0.5 * src_w, 0, rot_rad);
std::vector<float> dst_dir{-0.5 * dst_w, 0.0}; std::vector<float> dst_dir{-0.5 * dst_w, 0.0};
cv::Point2f srcPoint2f[3], dstPoint2f[3]; cv::Point2f srcPoint2f[3], dstPoint2f[3];
...@@ -67,12 +69,12 @@ void get_affine_transform(std::vector<float>& center, ...@@ -67,12 +69,12 @@ void get_affine_transform(std::vector<float>& center,
} }
} }
void transform_preds(float* coords, void transform_preds(std::vector<float>& coords,
std::vector<float>& center, std::vector<float>& center,
std::vector<float>& scale, std::vector<float>& scale,
std::vector<int>& output_size, std::vector<int>& output_size,
std::vector<int64_t>& dim, std::vector<int64_t>& dim,
float* target_coords) { std::vector<float>& target_coords) {
cv::Mat trans(2, 3, CV_64FC1); cv::Mat trans(2, 3, CV_64FC1);
get_affine_transform(center, scale, 0, output_size, trans, 1); get_affine_transform(center, scale, 0, output_size, trans, 1);
for (int p = 0; p < dim[1]; ++p) { for (int p = 0; p < dim[1]; ++p) {
...@@ -81,10 +83,10 @@ void transform_preds(float* coords, ...@@ -81,10 +83,10 @@ void transform_preds(float* coords,
} }
// only for batchsize == 1 // only for batchsize == 1
void get_max_preds(float* heatmap, void get_max_preds(std::vector<float>& heatmap,
std::vector<int>& dim, std::vector<int>& dim,
float* preds, std::vector<float>& preds,
float* maxvals, std::vector<float>& maxvals,
int batchid, int batchid,
int joint_idx) { int joint_idx) {
int num_joints = dim[1]; int num_joints = dim[1];
...@@ -106,14 +108,75 @@ void get_max_preds(float* heatmap, ...@@ -106,14 +108,75 @@ void get_max_preds(float* heatmap,
} }
} }
void get_final_preds(float* heatmap,
void dark_parse(std::vector<float>& heatmap,
std::vector<int64_t>& dim,
std::vector<float>& coords,
int px,
int py,
int index,
int ch){
/*DARK postpocessing, Zhang et al. Distribution-Aware Coordinate
Representation for Human Pose Estimation (CVPR 2020).
1) offset = - hassian.inv() * derivative
2) dx = (heatmap[x+1] - heatmap[x-1])/2.
3) dxx = (dx[x+1] - dx[x-1])/2.
4) derivative = Mat([dx, dy])
5) hassian = Mat([[dxx, dxy], [dxy, dyy]])
*/
std::vector<float>::const_iterator first1 = heatmap.begin() + index;
std::vector<float>::const_iterator last1 = heatmap.begin() + index + dim[2]*dim[3];
std::vector<float> heatmap_ch(first1, last1);
cv::Mat heatmap_mat{heatmap_ch};
heatmap_mat.resize(dim[2],dim[3]);
cv::GaussianBlur(heatmap_mat, heatmap_mat, cv::Size(3,3), 0, 0);
heatmap_ch.assign(heatmap_mat.datastart, heatmap_mat.dataend);
float epsilon = 1e-10;
//sample heatmap to get values in around target location
float xy = log(fmax(heatmap_ch[py * dim[3] + px], epsilon));
float xr = log(fmax(heatmap_ch[py * dim[3] + px + 1], epsilon));
float xl = log(fmax(heatmap_ch[py * dim[3] + px - 1], epsilon));
float xr2 = log(fmax(heatmap_ch[py * dim[3] + px + 2], epsilon));
float xl2 = log(fmax(heatmap_ch[py * dim[3] + px - 2], epsilon));
float yu = log(fmax(heatmap_ch[(py + 1) * dim[3] + px], epsilon));
float yd = log(fmax(heatmap_ch[(py - 1) * dim[3] + px], epsilon));
float yu2 = log(fmax(heatmap_ch[(py + 2) * dim[3] + px], epsilon));
float yd2 = log(fmax(heatmap_ch[(py - 2) * dim[3] + px], epsilon));
float xryu = log(fmax(heatmap_ch[(py + 1) * dim[3] + px + 1], epsilon));
float xryd = log(fmax(heatmap_ch[(py - 1) * dim[3] + px + 1], epsilon));
float xlyu = log(fmax(heatmap_ch[(py + 1) * dim[3] + px - 1], epsilon));
float xlyd = log(fmax(heatmap_ch[(py - 1) * dim[3] + px - 1], epsilon));
//compute dx/dy and dxx/dyy with sampled values
float dx = 0.5 * (xr - xl);
float dy = 0.5 * (yu - yd);
float dxx = 0.25 * (xr2 - 2*xy + xl2);
float dxy = 0.25 * (xryu - xryd - xlyu + xlyd);
float dyy = 0.25 * (yu2 - 2*xy + yd2);
//finally get offset by derivative and hassian, which combined by dx/dy and dxx/dyy
if(dxx * dyy - dxy*dxy != 0){
float M[2][2] = {dxx, dxy, dxy, dyy};
float D[2] = {dx, dy};
cv::Mat hassian(2,2,CV_32F,M);
cv::Mat derivative(2,1,CV_32F,D);
cv::Mat offset = - hassian.inv() * derivative;
coords[ch * 2] += offset.at<float>(0,0);
coords[ch * 2 + 1] += offset.at<float>(1,0);
}
}
void get_final_preds(std::vector<float>& heatmap,
std::vector<int64_t>& dim, std::vector<int64_t>& dim,
int64_t* idxout, std::vector<int64_t>& idxout,
std::vector<int64_t>& idxdim, std::vector<int64_t>& idxdim,
std::vector<float>& center, std::vector<float>& center,
std::vector<float> scale, std::vector<float> scale,
float* preds, std::vector<float>& preds,
int batchid) { int batchid,
bool DARK) {
std::vector<float> coords; std::vector<float> coords;
coords.resize(dim[1] * 2); coords.resize(dim[1] * 2);
int heatmap_height = dim[2]; int heatmap_height = dim[2];
...@@ -130,18 +193,23 @@ void get_final_preds(float* heatmap, ...@@ -130,18 +193,23 @@ void get_final_preds(float* heatmap,
int px = int(coords[j * 2] + 0.5); int px = int(coords[j * 2] + 0.5);
int py = int(coords[j * 2 + 1] + 0.5); int py = int(coords[j * 2 + 1] + 0.5);
if (px > 1 && px < heatmap_width - 1) { if(DARK && px > 1 && px < heatmap_width - 2){
float diff_x = heatmap[index + py * dim[3] + px + 1] - dark_parse(heatmap, dim, coords, px, py, index, j);
heatmap[index + py * dim[3] + px - 1];
coords[j * 2] += diff_x > 0 ? 1 : -1 * 0.25;
} }
if (py > 1 && py < heatmap_height - 1) { else{
float diff_y = heatmap[index + (py + 1) * dim[3] + px] - if (px > 0 && px < heatmap_width - 1) {
heatmap[index + (py - 1) * dim[3] + px]; float diff_x = heatmap[index + py * dim[3] + px + 1] -
coords[j * 2 + 1] += diff_y > 0 ? 1 : -1 * 0.25; heatmap[index + py * dim[3] + px - 1];
coords[j * 2] += diff_x > 0 ? 1 : -1 * 0.25;
}
if (py > 0 && py < heatmap_height - 1) {
float diff_y = heatmap[index + (py + 1) * dim[3] + px] -
heatmap[index + (py - 1) * dim[3] + px];
coords[j * 2 + 1] += diff_y > 0 ? 1 : -1 * 0.25;
}
} }
} }
std::vector<int> img_size{heatmap_width, heatmap_height}; std::vector<int> img_size{heatmap_width, heatmap_height};
transform_preds(coords.data(), center, scale, img_size, dim, preds); transform_preds(coords, center, scale, img_size, dim, preds);
} }
\ No newline at end of file
...@@ -308,7 +308,8 @@ int main(int argc, char** argv) { ...@@ -308,7 +308,8 @@ int main(int argc, char** argv) {
keypoint = new PaddleDetection::KeyPointDetector( keypoint = new PaddleDetection::KeyPointDetector(
RT_Config["model_dir_keypoint"].as<std::string>(), RT_Config["model_dir_keypoint"].as<std::string>(),
RT_Config["cpu_threads"].as<int>(), RT_Config["cpu_threads"].as<int>(),
RT_Config["batch_size_keypoint"].as<int>()); RT_Config["batch_size_keypoint"].as<int>(),
RT_Config["use_dark_decode"].as<bool>());
RT_Config["batch_size_det"] = 1; RT_Config["batch_size_det"] = 1;
printf( printf(
"batchsize of detection forced to be 1 while keypoint model is not " "batchsize of detection forced to be 1 while keypoint model is not "
......
...@@ -31,7 +31,7 @@ void InitInfo::Run(cv::Mat* im, ImageBlob* data) { ...@@ -31,7 +31,7 @@ void InitInfo::Run(cv::Mat* im, ImageBlob* data) {
void NormalizeImage::Run(cv::Mat* im, ImageBlob* data) { void NormalizeImage::Run(cv::Mat* im, ImageBlob* data) {
double e = 1.0; double e = 1.0;
if (is_scale_) { if (is_scale_) {
e /= 255.0; e *= 1./255.0;
} }
(*im).convertTo(*im, CV_32FC3, e); (*im).convertTo(*im, CV_32FC3, e);
for (int h = 0; h < im->rows; h++) { for (int h = 0; h < im->rows; h++) {
...@@ -151,15 +151,18 @@ void CropImg(cv::Mat& img, ...@@ -151,15 +151,18 @@ void CropImg(cv::Mat& img,
int crop_y1 = std::max(0, area[1]); int crop_y1 = std::max(0, area[1]);
int crop_x2 = std::min(img.cols - 1, area[2]); int crop_x2 = std::min(img.cols - 1, area[2]);
int crop_y2 = std::min(img.rows - 1, area[3]); int crop_y2 = std::min(img.rows - 1, area[3]);
int center_x = (crop_x1 + crop_x2) / 2.; int center_x = (crop_x1 + crop_x2) / 2.;
int center_y = (crop_y1 + crop_y2) / 2.; int center_y = (crop_y1 + crop_y2) / 2.;
int half_h = (crop_y2 - crop_y1) / 2.; int half_h = (crop_y2 - crop_y1) / 2.;
int half_w = (crop_x2 - crop_x1) / 2.; int half_w = (crop_x2 - crop_x1) / 2.;
if (half_h * 3 > half_w * 4) { if (half_h * 3 > half_w * 4) {
half_w = static_cast<int>(half_h * 0.75); half_w = static_cast<int>(half_h * 0.75);
} else { } else {
half_h = static_cast<int>(half_w * 4 / 3); half_h = static_cast<int>(half_w * 4 / 3);
} }
crop_x1 = crop_x1 =
std::max(0, center_x - static_cast<int>(half_w * (1 + expandratio))); std::max(0, center_x - static_cast<int>(half_w * (1 + expandratio)));
crop_y1 = crop_y1 =
...@@ -170,6 +173,7 @@ void CropImg(cv::Mat& img, ...@@ -170,6 +173,7 @@ void CropImg(cv::Mat& img,
static_cast<int>(center_y + half_h * (1 + expandratio))); static_cast<int>(center_y + half_h * (1 + expandratio)));
crop_img = crop_img =
img(cv::Range(crop_y1, crop_y2 + 1), cv::Range(crop_x1, crop_x2 + 1)); img(cv::Range(crop_y1, crop_y2 + 1), cv::Range(crop_x1, crop_x2 + 1));
center.clear(); center.clear();
center.emplace_back((crop_x1 + crop_x2) / 2); center.emplace_back((crop_x1 + crop_x2) / 2);
center.emplace_back((crop_y1 + crop_y2) / 2); center.emplace_back((crop_y1 + crop_y2) / 2);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册