diff --git a/deploy/cpp/include/object_detector.h b/deploy/cpp/include/object_detector.h index 30dd09ab7ef808314a353c72660f78b368004d25..47bd29362c85eafc3825d25af73694803e2a1504 100644 --- a/deploy/cpp/include/object_detector.h +++ b/deploy/cpp/include/object_detector.h @@ -25,7 +25,7 @@ #include #include -#include "paddle_inference_api.h" // NOLINT +#include "paddle_inference_api.h" // NOLINT #include "include/config_parser.h" #include "include/picodet_postprocess.h" @@ -33,29 +33,25 @@ #include "include/utils.h" using namespace paddle_infer; - namespace PaddleDetection { // Generate visualization colormap for each class std::vector GenerateColorMap(int num_class); // Visualiztion Detection Result -cv::Mat VisualizeResult( - const cv::Mat& img, - const std::vector& results, - const std::vector& lables, - const std::vector& colormap, - const bool is_rbox); +cv::Mat +VisualizeResult(const cv::Mat &img, + const std::vector &results, + const std::vector &lables, + const std::vector &colormap, const bool is_rbox); class ObjectDetector { - public: - explicit ObjectDetector(const std::string& model_dir, - const std::string& device = "CPU", - bool use_mkldnn = false, - int cpu_threads = 1, - const std::string& run_mode = "paddle", - const int batch_size = 1, - const int gpu_id = 0, +public: + explicit ObjectDetector(const std::string &model_dir, + const std::string &device = "CPU", + bool use_mkldnn = false, int cpu_threads = 1, + const std::string &run_mode = "paddle", + const int batch_size = 1, const int gpu_id = 0, const int trt_min_shape = 1, const int trt_max_shape = 1280, const int trt_opt_shape = 640, @@ -78,25 +74,22 @@ class ObjectDetector { } // Load Paddle inference model - void LoadModel(const std::string& model_dir, - const int batch_size = 1, - const std::string& run_mode = "paddle"); + void LoadModel(const std::string &model_dir, const int batch_size = 1, + const std::string &run_mode = "paddle"); // Run predictor - void Predict(const std::vector imgs, - const double threshold = 0.5, - const int warmup = 0, - const int repeats = 1, - std::vector* result = nullptr, - std::vector* bbox_num = nullptr, - std::vector* times = nullptr); + void Predict(const std::vector imgs, const double threshold = 0.5, + const int warmup = 0, const int repeats = 1, + std::vector *result = nullptr, + std::vector *bbox_num = nullptr, + std::vector *times = nullptr); // Get Model Label list - const std::vector& GetLabelList() const { + const std::vector &GetLabelList() const { return config_.label_list_; } - private: +private: std::string device_ = "CPU"; int gpu_id_ = 0; int cpu_math_library_num_threads_ = 1; @@ -108,14 +101,18 @@ class ObjectDetector { int trt_opt_shape_ = 640; bool trt_calib_mode_ = false; // Preprocess image and copy data to input buffer - void Preprocess(const cv::Mat& image_mat); + void Preprocess(const cv::Mat &image_mat); // Postprocess result void Postprocess(const std::vector mats, - std::vector* result, - std::vector bbox_num, - std::vector output_data_, - std::vector output_mask_data_, - bool is_rbox); + std::vector *result, + std::vector bbox_num, std::vector output_data_, + std::vector output_mask_data_, bool is_rbox); + + void SOLOv2Postprocess( + const std::vector mats, std::vector *result, + std::vector *bbox_num, std::vector out_bbox_num_data_, + std::vector out_label_data_, std::vector out_score_data_, + std::vector out_global_mask_data_, float threshold = 0.5); std::shared_ptr predictor_; Preprocessor preprocessor_; @@ -124,4 +121,4 @@ class ObjectDetector { ConfigPaser config_; }; -} // namespace PaddleDetection +} // namespace PaddleDetection diff --git a/deploy/cpp/src/object_detector.cc b/deploy/cpp/src/object_detector.cc index 509928bd15dbbaaf9d8b8614eb61bd3a90da8409..d4f2ceb5d7c07142e51e2b0008148e5d90b55adc 100644 --- a/deploy/cpp/src/object_detector.cc +++ b/deploy/cpp/src/object_detector.cc @@ -41,17 +41,12 @@ void ObjectDetector::LoadModel(const std::string &model_dir, } else if (run_mode == "trt_int8") { precision = paddle_infer::Config::Precision::kInt8; } else { - printf( - "run_mode should be 'paddle', 'trt_fp32', 'trt_fp16' or " - "'trt_int8'"); + printf("run_mode should be 'paddle', 'trt_fp32', 'trt_fp16' or " + "'trt_int8'"); } // set tensorrt - config.EnableTensorRtEngine(1 << 30, - batch_size, - this->min_subgraph_size_, - precision, - false, - this->trt_calib_mode_); + config.EnableTensorRtEngine(1 << 30, batch_size, this->min_subgraph_size_, + precision, false, this->trt_calib_mode_); // set use dynamic shape if (this->use_dynamic_shape_) { @@ -69,8 +64,8 @@ void ObjectDetector::LoadModel(const std::string &model_dir, const std::map> map_opt_input_shape = { {"image", opt_input_shape}}; - config.SetTRTDynamicShapeInfo( - map_min_input_shape, map_max_input_shape, map_opt_input_shape); + config.SetTRTDynamicShapeInfo(map_min_input_shape, map_max_input_shape, + map_opt_input_shape); std::cout << "TensorRT dynamic shape enabled" << std::endl; } } @@ -95,12 +90,11 @@ void ObjectDetector::LoadModel(const std::string &model_dir, } // Visualiztion MaskDetector results -cv::Mat VisualizeResult( - const cv::Mat &img, - const std::vector &results, - const std::vector &lables, - const std::vector &colormap, - const bool is_rbox = false) { +cv::Mat +VisualizeResult(const cv::Mat &img, + const std::vector &results, + const std::vector &lables, + const std::vector &colormap, const bool is_rbox = false) { cv::Mat vis_img = img.clone(); int img_h = vis_img.rows; int img_w = vis_img.cols; @@ -149,16 +143,10 @@ cv::Mat VisualizeResult( std::vector contours; cv::Mat hierarchy; mask.convertTo(mask, CV_8U); - cv::findContours( - mask, contours, hierarchy, cv::RETR_CCOMP, cv::CHAIN_APPROX_SIMPLE); - cv::drawContours(colored_img, - contours, - -1, - roi_color, - -1, - cv::LINE_8, - hierarchy, - 100); + cv::findContours(mask, contours, hierarchy, cv::RETR_CCOMP, + cv::CHAIN_APPROX_SIMPLE); + cv::drawContours(colored_img, contours, -1, roi_color, -1, cv::LINE_8, + hierarchy, 100); cv::Mat debug_roi = vis_img; colored_img = 0.4 * colored_img + 0.6 * vis_img; @@ -170,19 +158,13 @@ cv::Mat VisualizeResult( origin.y = results[i].rect[1]; // Configure text background - cv::Rect text_back = cv::Rect(results[i].rect[0], - results[i].rect[1] - text_size.height, - text_size.width, - text_size.height); + cv::Rect text_back = + cv::Rect(results[i].rect[0], results[i].rect[1] - text_size.height, + text_size.width, text_size.height); // Draw text, and background cv::rectangle(vis_img, text_back, roi_color, -1); - cv::putText(vis_img, - text, - origin, - font_face, - font_scale, - cv::Scalar(255, 255, 255), - thickness); + cv::putText(vis_img, text, origin, font_face, font_scale, + cv::Scalar(255, 255, 255), thickness); } return vis_img; } @@ -197,10 +179,8 @@ void ObjectDetector::Preprocess(const cv::Mat &ori_im) { void ObjectDetector::Postprocess( const std::vector mats, std::vector *result, - std::vector bbox_num, - std::vector output_data_, - std::vector output_mask_data_, - bool is_rbox = false) { + std::vector bbox_num, std::vector output_data_, + std::vector output_mask_data_, bool is_rbox = false) { result->clear(); int start_idx = 0; int total_num = std::accumulate(bbox_num.begin(), bbox_num.end(), 0); @@ -267,9 +247,81 @@ void ObjectDetector::Postprocess( } } +// This function is to convert output result from SOLOv2 to class ObjectResult +void ObjectDetector::SOLOv2Postprocess( + const std::vector mats, std::vector *result, + std::vector *bbox_num, std::vector out_bbox_num_data_, + std::vector out_label_data_, std::vector out_score_data_, + std::vector out_global_mask_data_, float threshold) { + + for (int im_id = 0; im_id < mats.size(); im_id++) { + cv::Mat mat = mats[im_id]; + + int valid_bbox_count = 0; + for (int bbox_id = 0; bbox_id < out_bbox_num_data_[im_id]; ++bbox_id) { + if (out_score_data_[bbox_id] >= threshold) { + ObjectResult result_item; + result_item.class_id = out_label_data_[bbox_id]; + result_item.confidence = out_score_data_[bbox_id]; + std::vector global_mask; + + for (int k = 0; k < mat.rows * mat.cols; ++k) { + global_mask.push_back(static_cast( + out_global_mask_data_[k + bbox_id * mat.rows * mat.cols])); + } + + // find minimize bounding box from mask + cv::Mat mask(mat.rows, mat.cols, CV_32SC1); + std::memcpy(mask.data, global_mask.data(), + global_mask.size() * sizeof(int)); + + cv::Mat mask_fp; + cv::Mat rowSum; + cv::Mat colSum; + std::vector sum_of_row(mat.rows); + std::vector sum_of_col(mat.cols); + + mask.convertTo(mask_fp, CV_32FC1); + cv::reduce(mask_fp, colSum, 0, CV_REDUCE_SUM, CV_32FC1); + cv::reduce(mask_fp, rowSum, 1, CV_REDUCE_SUM, CV_32FC1); + + for (int row_id = 0; row_id < mat.rows; ++row_id) { + sum_of_row[row_id] = rowSum.at(row_id, 0); + } + + for (int col_id = 0; col_id < mat.cols; ++col_id) { + sum_of_col[col_id] = colSum.at(0, col_id); + } + + auto it = std::find_if(sum_of_row.begin(), sum_of_row.end(), + [](int x) { return x > 0.5; }); + int y1 = std::distance(sum_of_row.begin(), it); + + auto it2 = std::find_if(sum_of_col.begin(), sum_of_col.end(), + [](int x) { return x > 0.5; }); + int x1 = std::distance(sum_of_col.begin(), it2); + + auto rit = std::find_if(sum_of_row.rbegin(), sum_of_row.rend(), + [](int x) { return x > 0.5; }); + int y2 = std::distance(rit, sum_of_row.rend()); + + auto rit2 = std::find_if(sum_of_col.rbegin(), sum_of_col.rend(), + [](int x) { return x > 0.5; }); + int x2 = std::distance(rit2, sum_of_col.rend()); + + result_item.rect = {x1, y1, x2, y2}; + result_item.mask = global_mask; + + result->push_back(result_item); + valid_bbox_count++; + } + } + bbox_num->push_back(valid_bbox_count); + } +} + void ObjectDetector::Predict(const std::vector imgs, - const double threshold, - const int warmup, + const double threshold, const int warmup, const int repeats, std::vector *result, std::vector *bbox_num, @@ -285,6 +337,11 @@ void ObjectDetector::Predict(const std::vector imgs, std::vector out_bbox_num_data_; std::vector out_mask_data_; + // these parameters are for SOLOv2 output + std::vector out_score_data_; + std::vector out_global_mask_data_; + std::vector out_label_data_; + // in_net img for each batch std::vector in_net_img_all(batch_size); @@ -298,8 +355,8 @@ void ObjectDetector::Predict(const std::vector imgs, scale_factor_all[bs_idx * 2] = inputs_.scale_factor_[0]; scale_factor_all[bs_idx * 2 + 1] = inputs_.scale_factor_[1]; - in_data_all.insert( - in_data_all.end(), inputs_.im_data_.begin(), inputs_.im_data_.end()); + in_data_all.insert(in_data_all.end(), inputs_.im_data_.begin(), + inputs_.im_data_.end()); // collect in_net img in_net_img_all[bs_idx] = inputs_.in_net_im_; @@ -320,8 +377,8 @@ void ObjectDetector::Predict(const std::vector imgs, pad_data.resize(rc * rh * rw); float *base = pad_data.data(); for (int i = 0; i < rc; ++i) { - cv::extractChannel( - pad_img, cv::Mat(rh, rw, CV_32FC1, base + i * rh * rw), i); + cv::extractChannel(pad_img, + cv::Mat(rh, rw, CV_32FC1, base + i * rh * rw), i); } in_data_all.insert(in_data_all.end(), pad_data.begin(), pad_data.end()); } @@ -354,58 +411,118 @@ void ObjectDetector::Predict(const std::vector imgs, bool is_rbox = false; int reg_max = 7; int num_class = 80; - // warmup - for (int i = 0; i < warmup; i++) { - predictor_->Run(); - // Get output tensor - auto output_names = predictor_->GetOutputNames(); - for (int j = 0; j < output_names.size(); j++) { - auto output_tensor = predictor_->GetOutputHandle(output_names[j]); - std::vector output_shape = output_tensor->shape(); - int out_num = std::accumulate( - output_shape.begin(), output_shape.end(), 1, std::multiplies()); - if (config_.mask_ && (j == 2)) { - out_mask_data_.resize(out_num); - output_tensor->CopyToCpu(out_mask_data_.data()); - } else if (output_tensor->type() == paddle_infer::DataType::INT32) { - out_bbox_num_data_.resize(out_num); - output_tensor->CopyToCpu(out_bbox_num_data_.data()); - } else { - std::vector out_data; - out_data.resize(out_num); - output_tensor->CopyToCpu(out_data.data()); - out_tensor_list.push_back(out_data); + + auto inference_start = std::chrono::steady_clock::now(); + if (config_.arch_ == "SOLOv2") { + // warmup + for (int i = 0; i < warmup; i++) { + predictor_->Run(); + // Get output tensor + auto output_names = predictor_->GetOutputNames(); + for (int j = 0; j < output_names.size(); j++) { + auto output_tensor = predictor_->GetOutputHandle(output_names[j]); + std::vector output_shape = output_tensor->shape(); + int out_num = std::accumulate(output_shape.begin(), output_shape.end(), + 1, std::multiplies()); + if (j == 0) { + out_bbox_num_data_.resize(out_num); + output_tensor->CopyToCpu(out_bbox_num_data_.data()); + } else if (j == 1) { + out_label_data_.resize(out_num); + output_tensor->CopyToCpu(out_label_data_.data()); + } else if (j == 2) { + out_score_data_.resize(out_num); + output_tensor->CopyToCpu(out_score_data_.data()); + } else if (config_.mask_ && (j == 3)) { + out_global_mask_data_.resize(out_num); + output_tensor->CopyToCpu(out_global_mask_data_.data()); + } } } - } - auto inference_start = std::chrono::steady_clock::now(); - for (int i = 0; i < repeats; i++) { - predictor_->Run(); - // Get output tensor - out_tensor_list.clear(); - output_shape_list.clear(); - auto output_names = predictor_->GetOutputNames(); - for (int j = 0; j < output_names.size(); j++) { - auto output_tensor = predictor_->GetOutputHandle(output_names[j]); - std::vector output_shape = output_tensor->shape(); - int out_num = std::accumulate( - output_shape.begin(), output_shape.end(), 1, std::multiplies()); - output_shape_list.push_back(output_shape); - if (config_.mask_ && (j == 2)) { - out_mask_data_.resize(out_num); - output_tensor->CopyToCpu(out_mask_data_.data()); - } else if (output_tensor->type() == paddle_infer::DataType::INT32) { - out_bbox_num_data_.resize(out_num); - output_tensor->CopyToCpu(out_bbox_num_data_.data()); - } else { - std::vector out_data; - out_data.resize(out_num); - output_tensor->CopyToCpu(out_data.data()); - out_tensor_list.push_back(out_data); + inference_start = std::chrono::steady_clock::now(); + for (int i = 0; i < repeats; i++) { + predictor_->Run(); + // Get output tensor + out_tensor_list.clear(); + output_shape_list.clear(); + auto output_names = predictor_->GetOutputNames(); + for (int j = 0; j < output_names.size(); j++) { + auto output_tensor = predictor_->GetOutputHandle(output_names[j]); + std::vector output_shape = output_tensor->shape(); + int out_num = std::accumulate(output_shape.begin(), output_shape.end(), + 1, std::multiplies()); + output_shape_list.push_back(output_shape); + if (j == 0) { + out_bbox_num_data_.resize(out_num); + output_tensor->CopyToCpu(out_bbox_num_data_.data()); + } else if (j == 1) { + out_label_data_.resize(out_num); + output_tensor->CopyToCpu(out_label_data_.data()); + } else if (j == 2) { + out_score_data_.resize(out_num); + output_tensor->CopyToCpu(out_score_data_.data()); + } else if (config_.mask_ && (j == 3)) { + out_global_mask_data_.resize(out_num); + output_tensor->CopyToCpu(out_global_mask_data_.data()); + } + } + } + } else { + // warmup + for (int i = 0; i < warmup; i++) { + predictor_->Run(); + // Get output tensor + auto output_names = predictor_->GetOutputNames(); + for (int j = 0; j < output_names.size(); j++) { + auto output_tensor = predictor_->GetOutputHandle(output_names[j]); + std::vector output_shape = output_tensor->shape(); + int out_num = std::accumulate(output_shape.begin(), output_shape.end(), + 1, std::multiplies()); + if (config_.mask_ && (j == 2)) { + out_mask_data_.resize(out_num); + output_tensor->CopyToCpu(out_mask_data_.data()); + } else if (output_tensor->type() == paddle_infer::DataType::INT32) { + out_bbox_num_data_.resize(out_num); + output_tensor->CopyToCpu(out_bbox_num_data_.data()); + } else { + std::vector out_data; + out_data.resize(out_num); + output_tensor->CopyToCpu(out_data.data()); + out_tensor_list.push_back(out_data); + } + } + } + + inference_start = std::chrono::steady_clock::now(); + for (int i = 0; i < repeats; i++) { + predictor_->Run(); + // Get output tensor + out_tensor_list.clear(); + output_shape_list.clear(); + auto output_names = predictor_->GetOutputNames(); + for (int j = 0; j < output_names.size(); j++) { + auto output_tensor = predictor_->GetOutputHandle(output_names[j]); + std::vector output_shape = output_tensor->shape(); + int out_num = std::accumulate(output_shape.begin(), output_shape.end(), + 1, std::multiplies()); + output_shape_list.push_back(output_shape); + if (config_.mask_ && (j == 2)) { + out_mask_data_.resize(out_num); + output_tensor->CopyToCpu(out_mask_data_.data()); + } else if (output_tensor->type() == paddle_infer::DataType::INT32) { + out_bbox_num_data_.resize(out_num); + output_tensor->CopyToCpu(out_bbox_num_data_.data()); + } else { + std::vector out_data; + out_data.resize(out_num); + output_tensor->CopyToCpu(out_data.data()); + out_tensor_list.push_back(out_data); + } } } } + auto inference_end = std::chrono::steady_clock::now(); auto postprocess_start = std::chrono::steady_clock::now(); // Postprocessing result @@ -420,30 +537,23 @@ void ObjectDetector::Predict(const std::vector imgs, reg_max = output_shape_list[i][2] / 4 - 1; } float *buffer = new float[out_tensor_list[i].size()]; - memcpy(buffer, - &out_tensor_list[i][0], + memcpy(buffer, &out_tensor_list[i][0], out_tensor_list[i].size() * sizeof(float)); output_data_list_.push_back(buffer); } PaddleDetection::PicoDetPostProcess( - result, - output_data_list_, - config_.fpn_stride_, - inputs_.im_shape_, - inputs_.scale_factor_, - config_.nms_info_["score_threshold"].as(), - config_.nms_info_["nms_threshold"].as(), - num_class, - reg_max); + result, output_data_list_, config_.fpn_stride_, inputs_.im_shape_, + inputs_.scale_factor_, config_.nms_info_["score_threshold"].as(), + config_.nms_info_["nms_threshold"].as(), num_class, reg_max); bbox_num->push_back(result->size()); + } else if (config_.arch_ == "SOLOv2") { + SOLOv2Postprocess(imgs, result, bbox_num, out_bbox_num_data_, + out_label_data_, out_score_data_, out_global_mask_data_, + threshold); } else { is_rbox = output_shape_list[0][output_shape_list[0].size() - 1] % 10 == 0; - Postprocess(imgs, - result, - out_bbox_num_data_, - out_tensor_list[0], - out_mask_data_, - is_rbox); + Postprocess(imgs, result, out_bbox_num_data_, out_tensor_list[0], + out_mask_data_, is_rbox); for (int k = 0; k < out_bbox_num_data_.size(); k++) { int tmp = out_bbox_num_data_[k]; bbox_num->push_back(tmp); @@ -479,4 +589,4 @@ std::vector GenerateColorMap(int num_class) { return colormap; } -} // namespace PaddleDetection +} // namespace PaddleDetection