diff --git a/deploy/lite/include/keypoint_detector.h b/deploy/lite/include/keypoint_detector.h index 952a910e12cb8d02ecbc529c8d5d96122bac0c41..a4feb01cb9dbe286a386133c07cc90e816957a77 100644 --- a/deploy/lite/include/keypoint_detector.h +++ b/deploy/lite/include/keypoint_detector.h @@ -49,9 +49,11 @@ class KeyPointDetector { public: explicit KeyPointDetector(const std::string& model_dir, int cpu_threads = 1, - const int batch_size = 1) { + const int batch_size = 1, + bool use_dark = true) { config_.load_config(model_dir); threshold_ = config_.draw_threshold_; + use_dark_ = use_dark; preprocessor_.Init(config_.preprocess_info_); printf("before keypoint detector\n"); LoadModel(model_dir, cpu_threads); @@ -76,14 +78,16 @@ class KeyPointDetector { return config_.label_list_; } + bool use_dark(){return this->use_dark_;} + private: // Preprocess image and copy data to input buffer void Preprocess(const cv::Mat& image_mat); // Postprocess result - void Postprocess(const std::vector output, - const std::vector output_shape, - const std::vector idxout, - const std::vector idx_shape, + void Postprocess(std::vector& output, + std::vector& output_shape, + std::vector& idxout, + std::vector& idx_shape, std::vector* result, std::vector>& center, std::vector>& scale); @@ -95,6 +99,7 @@ class KeyPointDetector { std::vector idx_data_; float threshold_; ConfigPaser config_; + bool use_dark_; }; } // namespace PaddleDetection diff --git a/deploy/lite/include/keypoint_postprocess.h b/deploy/lite/include/keypoint_postprocess.h index 85ef8d4b5e81ebc2911a2c897dbc4dbf8c878cd2..0d1e747f306e44679d0500272e80df8a5fb19ab9 100644 --- a/deploy/lite/include/keypoint_postprocess.h +++ b/deploy/lite/include/keypoint_postprocess.h @@ -22,34 +22,35 @@ std::vector get_3rd_point(std::vector& a, std::vector& b); std::vector get_dir(float src_point_x, float src_point_y, float rot_rad); void affine_tranform( - float pt_x, float pt_y, cv::Mat& trans, float* x, int p, int num); + float pt_x, float pt_y, cv::Mat& trans, std::vector& x, int p, int num); cv::Mat get_affine_transform(std::vector& center, std::vector& scale, float rot, std::vector& output_size, int inv); -void transform_preds(float* coords, +void transform_preds(std::vector& coords, std::vector& center, std::vector& scale, std::vector& output_size, std::vector& dim, - float* target_coords); + std::vector& target_coords); void box_to_center_scale(std::vector& box, int width, int height, std::vector& center, std::vector& scale); -void get_max_preds(float* heatmap, +void get_max_preds(std::vector& heatmap, std::vector& dim, - float* preds, - float* maxvals, + std::vector& preds, + std::vector& maxvals, int batchid, int joint_idx); -void get_final_preds(float* heatmap, +void get_final_preds(std::vector& heatmap, std::vector& dim, - int64_t* idxout, + std::vector& idxout, std::vector& idxdim, std::vector& center, std::vector scale, - float* preds, - int batchid); + std::vector& preds, + int batchid, + bool DARK = true); diff --git a/deploy/lite/runtime_config.json b/deploy/lite/runtime_config.json index 47be5af73e7d07773123ef6bd5005c0f930d07b5..80971e51a8c79534704d50be2a8959f631a3cf83 100644 --- a/deploy/lite/runtime_config.json +++ b/deploy/lite/runtime_config.json @@ -5,8 +5,9 @@ "model_dir_keypoint": "./model_keypoint/", "batch_size_keypoint": 8, "threshold_keypoint": 0.5, - "image_file": "", + "image_file": "./demo.jpg", "image_dir": "", "run_benchmark": false, - "cpu_threads": 1 + "cpu_threads": 4, + "use_dark_decode": true } diff --git a/deploy/lite/src/keypoint_detector.cc b/deploy/lite/src/keypoint_detector.cc index 01fab7c3890f9fa2f994400d6601b8c2f6ebf261..f698bd41aa80d86d68acc5a6dffd3fcd76b2eba3 100644 --- a/deploy/lite/src/keypoint_detector.cc +++ b/deploy/lite/src/keypoint_detector.cc @@ -29,7 +29,11 @@ void KeyPointDetector::LoadModel(std::string model_file, int num_theads) { predictor_ = std::move(CreatePaddlePredictor(config)); } -const int edge[][2] = {{0, 1}, +// Visualiztion MaskDetector results +cv::Mat VisualizeKptsResult(const cv::Mat& img, + const std::vector& results, + const std::vector& colormap) { + const int edge[][2] = {{0, 1}, {0, 2}, {1, 3}, {2, 4}, @@ -46,10 +50,6 @@ const int edge[][2] = {{0, 1}, {13, 15}, {14, 16}, {11, 12}}; -// Visualiztion MaskDetector results -cv::Mat VisualizeKptsResult(const cv::Mat& img, - const std::vector& results, - const std::vector& colormap) { cv::Mat vis_img = img.clone(); for (int batchid = 0; batchid < results.size(); batchid++) { for (int i = 0; i < results[batchid].num_joints; i++) { @@ -85,24 +85,25 @@ void KeyPointDetector::Preprocess(const cv::Mat& ori_im) { preprocessor_.Run(&im, &inputs_); } -void KeyPointDetector::Postprocess(std::vector output, - std::vector output_shape, - std::vector idxout, - std::vector idx_shape, +void KeyPointDetector::Postprocess(std::vector& output, + std::vector& output_shape, + std::vector& idxout, + std::vector& idx_shape, std::vector* result, std::vector>& center_bs, std::vector>& scale_bs) { - float* preds = new float[output_shape[1] * 3]{0}; + std::vector preds(output_shape[1] * 3, 0); for (int batchid = 0; batchid < output_shape[0]; batchid++) { - get_final_preds(const_cast(output.data()), + get_final_preds(output, output_shape, - idxout.data(), + idxout, idx_shape, center_bs[batchid], scale_bs[batchid], preds, - batchid); + batchid, + this->use_dark()); KeyPointResult result_item; result_item.num_joints = output_shape[1]; result_item.keypoints.clear(); @@ -113,7 +114,6 @@ void KeyPointDetector::Postprocess(std::vector output, } result->push_back(result_item); } - delete[] preds; } void KeyPointDetector::Predict(const std::vector imgs, diff --git a/deploy/lite/src/keypoint_postprocess.cc b/deploy/lite/src/keypoint_postprocess.cc index 782f9c1931966f843bc77fdc9f4e5637a84630cc..6124e505dfff023e70133131796a503fef5f4de2 100644 --- a/deploy/lite/src/keypoint_postprocess.cc +++ b/deploy/lite/src/keypoint_postprocess.cc @@ -13,6 +13,8 @@ // limitations under the License. #include "include/keypoint_postprocess.h" +#define PI 3.1415926535 +#define HALF_CIRCLE_DEGREE 180 cv::Point2f get_3rd_point(cv::Point2f& a, cv::Point2f& b) { cv::Point2f direct{a.x - b.x, a.y - b.y}; @@ -31,7 +33,7 @@ std::vector get_dir(float src_point_x, } void affine_tranform( - float pt_x, float pt_y, cv::Mat& trans, float* preds, int p) { + float pt_x, float pt_y, cv::Mat& trans, std::vector& preds, int p) { double new1[3] = {pt_x, pt_y, 1.0}; cv::Mat new_pt(3, 1, trans.type(), new1); cv::Mat w = trans * new_pt; @@ -48,7 +50,7 @@ void get_affine_transform(std::vector& center, float src_w = scale[0]; float dst_w = static_cast(output_size[0]); float dst_h = static_cast(output_size[1]); - float rot_rad = rot * 3.1415926535 / 180; + float rot_rad = rot * PI / HALF_CIRCLE_DEGREE; std::vector src_dir = get_dir(-0.5 * src_w, 0, rot_rad); std::vector dst_dir{-0.5 * dst_w, 0.0}; cv::Point2f srcPoint2f[3], dstPoint2f[3]; @@ -67,12 +69,12 @@ void get_affine_transform(std::vector& center, } } -void transform_preds(float* coords, +void transform_preds(std::vector& coords, std::vector& center, std::vector& scale, std::vector& output_size, std::vector& dim, - float* target_coords) { + std::vector& target_coords) { cv::Mat trans(2, 3, CV_64FC1); get_affine_transform(center, scale, 0, output_size, trans, 1); for (int p = 0; p < dim[1]; ++p) { @@ -81,10 +83,10 @@ void transform_preds(float* coords, } // only for batchsize == 1 -void get_max_preds(float* heatmap, +void get_max_preds(std::vector& heatmap, std::vector& dim, - float* preds, - float* maxvals, + std::vector& preds, + std::vector& maxvals, int batchid, int joint_idx) { int num_joints = dim[1]; @@ -106,14 +108,75 @@ void get_max_preds(float* heatmap, } } -void get_final_preds(float* heatmap, + +void dark_parse(std::vector& heatmap, + std::vector& dim, + std::vector& coords, + int px, + int py, + int index, + int ch){ + /*DARK postpocessing, Zhang et al. Distribution-Aware Coordinate + Representation for Human Pose Estimation (CVPR 2020). + 1) offset = - hassian.inv() * derivative + 2) dx = (heatmap[x+1] - heatmap[x-1])/2. + 3) dxx = (dx[x+1] - dx[x-1])/2. + 4) derivative = Mat([dx, dy]) + 5) hassian = Mat([[dxx, dxy], [dxy, dyy]]) + */ + std::vector::const_iterator first1 = heatmap.begin() + index; + std::vector::const_iterator last1 = heatmap.begin() + index + dim[2]*dim[3]; + std::vector heatmap_ch(first1, last1); + cv::Mat heatmap_mat{heatmap_ch}; + heatmap_mat.resize(dim[2],dim[3]); + cv::GaussianBlur(heatmap_mat, heatmap_mat, cv::Size(3,3), 0, 0); + heatmap_ch.assign(heatmap_mat.datastart, heatmap_mat.dataend); + + float epsilon = 1e-10; + //sample heatmap to get values in around target location + float xy = log(fmax(heatmap_ch[py * dim[3] + px], epsilon)); + float xr = log(fmax(heatmap_ch[py * dim[3] + px + 1], epsilon)); + float xl = log(fmax(heatmap_ch[py * dim[3] + px - 1], epsilon)); + + float xr2 = log(fmax(heatmap_ch[py * dim[3] + px + 2], epsilon)); + float xl2 = log(fmax(heatmap_ch[py * dim[3] + px - 2], epsilon)); + float yu = log(fmax(heatmap_ch[(py + 1) * dim[3] + px], epsilon)); + float yd = log(fmax(heatmap_ch[(py - 1) * dim[3] + px], epsilon)); + float yu2 = log(fmax(heatmap_ch[(py + 2) * dim[3] + px], epsilon)); + float yd2 = log(fmax(heatmap_ch[(py - 2) * dim[3] + px], epsilon)); + float xryu = log(fmax(heatmap_ch[(py + 1) * dim[3] + px + 1], epsilon)); + float xryd = log(fmax(heatmap_ch[(py - 1) * dim[3] + px + 1], epsilon)); + float xlyu = log(fmax(heatmap_ch[(py + 1) * dim[3] + px - 1], epsilon)); + float xlyd = log(fmax(heatmap_ch[(py - 1) * dim[3] + px - 1], epsilon)); + + //compute dx/dy and dxx/dyy with sampled values + float dx = 0.5 * (xr - xl); + float dy = 0.5 * (yu - yd); + float dxx = 0.25 * (xr2 - 2*xy + xl2); + float dxy = 0.25 * (xryu - xryd - xlyu + xlyd); + float dyy = 0.25 * (yu2 - 2*xy + yd2); + + //finally get offset by derivative and hassian, which combined by dx/dy and dxx/dyy + if(dxx * dyy - dxy*dxy != 0){ + float M[2][2] = {dxx, dxy, dxy, dyy}; + float D[2] = {dx, dy}; + cv::Mat hassian(2,2,CV_32F,M); + cv::Mat derivative(2,1,CV_32F,D); + cv::Mat offset = - hassian.inv() * derivative; + coords[ch * 2] += offset.at(0,0); + coords[ch * 2 + 1] += offset.at(1,0); + } +} + +void get_final_preds(std::vector& heatmap, std::vector& dim, - int64_t* idxout, + std::vector& idxout, std::vector& idxdim, std::vector& center, std::vector scale, - float* preds, - int batchid) { + std::vector& preds, + int batchid, + bool DARK) { std::vector coords; coords.resize(dim[1] * 2); int heatmap_height = dim[2]; @@ -130,18 +193,23 @@ void get_final_preds(float* heatmap, int px = int(coords[j * 2] + 0.5); int py = int(coords[j * 2 + 1] + 0.5); - if (px > 1 && px < heatmap_width - 1) { - float diff_x = heatmap[index + py * dim[3] + px + 1] - - heatmap[index + py * dim[3] + px - 1]; - coords[j * 2] += diff_x > 0 ? 1 : -1 * 0.25; + if(DARK && px > 1 && px < heatmap_width - 2){ + dark_parse(heatmap, dim, coords, px, py, index, j); } - if (py > 1 && py < heatmap_height - 1) { - float diff_y = heatmap[index + (py + 1) * dim[3] + px] - - heatmap[index + (py - 1) * dim[3] + px]; - coords[j * 2 + 1] += diff_y > 0 ? 1 : -1 * 0.25; + else{ + if (px > 0 && px < heatmap_width - 1) { + float diff_x = heatmap[index + py * dim[3] + px + 1] - + heatmap[index + py * dim[3] + px - 1]; + coords[j * 2] += diff_x > 0 ? 1 : -1 * 0.25; + } + if (py > 0 && py < heatmap_height - 1) { + float diff_y = heatmap[index + (py + 1) * dim[3] + px] - + heatmap[index + (py - 1) * dim[3] + px]; + coords[j * 2 + 1] += diff_y > 0 ? 1 : -1 * 0.25; + } } } - + std::vector img_size{heatmap_width, heatmap_height}; - transform_preds(coords.data(), center, scale, img_size, dim, preds); -} \ No newline at end of file + transform_preds(coords, center, scale, img_size, dim, preds); +} diff --git a/deploy/lite/src/main.cc b/deploy/lite/src/main.cc index b651fba1d26c61b7edede67d728ebdda9288ad64..bbc1193882096b9f5a7f300080d5aee827ab87a3 100644 --- a/deploy/lite/src/main.cc +++ b/deploy/lite/src/main.cc @@ -308,7 +308,8 @@ int main(int argc, char** argv) { keypoint = new PaddleDetection::KeyPointDetector( RT_Config["model_dir_keypoint"].as(), RT_Config["cpu_threads"].as(), - RT_Config["batch_size_keypoint"].as()); + RT_Config["batch_size_keypoint"].as(), + RT_Config["use_dark_decode"].as()); RT_Config["batch_size_det"] = 1; printf( "batchsize of detection forced to be 1 while keypoint model is not " diff --git a/deploy/lite/src/preprocess_op.cc b/deploy/lite/src/preprocess_op.cc index 082a0e0ef18d5b4f77b307975d378367844b0e2b..fbbc5adb1d431c800b0624107d8c281f4b53c9cd 100644 --- a/deploy/lite/src/preprocess_op.cc +++ b/deploy/lite/src/preprocess_op.cc @@ -31,7 +31,7 @@ void InitInfo::Run(cv::Mat* im, ImageBlob* data) { void NormalizeImage::Run(cv::Mat* im, ImageBlob* data) { double e = 1.0; if (is_scale_) { - e /= 255.0; + e *= 1./255.0; } (*im).convertTo(*im, CV_32FC3, e); for (int h = 0; h < im->rows; h++) { @@ -151,15 +151,18 @@ void CropImg(cv::Mat& img, int crop_y1 = std::max(0, area[1]); int crop_x2 = std::min(img.cols - 1, area[2]); int crop_y2 = std::min(img.rows - 1, area[3]); + int center_x = (crop_x1 + crop_x2) / 2.; int center_y = (crop_y1 + crop_y2) / 2.; int half_h = (crop_y2 - crop_y1) / 2.; int half_w = (crop_x2 - crop_x1) / 2.; + if (half_h * 3 > half_w * 4) { half_w = static_cast(half_h * 0.75); } else { half_h = static_cast(half_w * 4 / 3); } + crop_x1 = std::max(0, center_x - static_cast(half_w * (1 + expandratio))); crop_y1 = @@ -170,6 +173,7 @@ void CropImg(cv::Mat& img, static_cast(center_y + half_h * (1 + expandratio))); crop_img = img(cv::Range(crop_y1, crop_y2 + 1), cv::Range(crop_x1, crop_x2 + 1)); + center.clear(); center.emplace_back((crop_x1 + crop_x2) / 2); center.emplace_back((crop_y1 + crop_y2) / 2);