From b3446197a2496f962bc69eebc95b58783ebeb2b3 Mon Sep 17 00:00:00 2001 From: dongshuilong Date: Thu, 4 Nov 2021 12:51:24 +0000 Subject: [PATCH] add detect infer for pp_shitu cpp infer --- deploy/cpp_shitu/include/cls.h | 5 +- deploy/cpp_shitu/include/object_detector.h | 137 +++++++ deploy/cpp_shitu/include/preprocess_op_det.h | 155 ++++++++ deploy/cpp_shitu/src/cls.cpp | 40 +- deploy/cpp_shitu/src/main.cpp | 163 ++++++++- deploy/cpp_shitu/src/object_detector.cpp | 365 +++++++++++++++++++ deploy/cpp_shitu/src/preprocess_op_det.cpp | 130 +++++++ 7 files changed, 960 insertions(+), 35 deletions(-) create mode 100644 deploy/cpp_shitu/include/object_detector.h create mode 100644 deploy/cpp_shitu/include/preprocess_op_det.h create mode 100644 deploy/cpp_shitu/src/object_detector.cpp create mode 100644 deploy/cpp_shitu/src/preprocess_op_det.cpp diff --git a/deploy/cpp_shitu/include/cls.h b/deploy/cpp_shitu/include/cls.h index e0b7989c..8edef1c4 100644 --- a/deploy/cpp_shitu/include/cls.h +++ b/deploy/cpp_shitu/include/cls.h @@ -76,11 +76,12 @@ public: void LoadModel(const std::string &model_path, const std::string ¶ms_path); // Run predictor - double Run(cv::Mat &img, std::vector *times); + void Run(cv::Mat &img, std::vector &out_data, + std::vector ×); -private: std::shared_ptr predictor_; +private: bool use_gpu_ = false; int gpu_id_ = 0; int gpu_mem_ = 4000; diff --git a/deploy/cpp_shitu/include/object_detector.h b/deploy/cpp_shitu/include/object_detector.h new file mode 100644 index 00000000..015885b1 --- /dev/null +++ b/deploy/cpp_shitu/include/object_detector.h @@ -0,0 +1,137 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "paddle_inference_api.h" // NOLINT + +#include "include/preprocess_op_det.h" +#include "include/yaml_config.h" + +using namespace paddle_infer; + +namespace PaddleDetection { +// Object Detection Result +struct ObjectResult { + // Rectangle coordinates of detected object: left, right, top, down + std::vector rect; + // Class id of detected object + int class_id; + // Confidence of detected object + float confidence; +}; + +// Generate visualization colormap for each class +std::vector GenerateColorMap(int num_class); + +// Visualiztion Detection Result +cv::Mat VisualizeResult(const cv::Mat &img, + const std::vector &results, + const std::vector &lables, + const std::vector &colormap, const bool is_rbox); + +class ObjectDetector { +public: + explicit ObjectDetector(const YAML::Node &config_file) { + this->use_gpu_ = config_file["Global"]["use_gpu"].as(); + if (config_file["Global"]["gpu_id"].IsDefined()) + this->gpu_id_ = config_file["Global"]["gpu_id"].as(); + this->gpu_mem_ = config_file["Global"]["gpu_mem"].as(); + this->cpu_math_library_num_threads_ = + config_file["Global"]["cpu_num_threads"].as(); + this->use_mkldnn_ = config_file["Global"]["enable_mkldnn"].as(); + this->use_tensorrt_ = config_file["Global"]["use_tensorrt"].as(); + this->use_fp16_ = config_file["Global"]["use_fp16"].as(); + this->model_dir_ = + config_file["Global"]["det_inference_model_dir"].as(); + this->nms_thres_ = config_file["Global"]["rec_nms_thresold"].as(); + this->threshold_ = config_file["Global"]["threshold"].as(); + this->max_det_results_ = config_file["Global"]["max_det_results"].as(); + this->image_shape_ = + config_file["Global"]["image_shape"].as>(); + this->label_list_ = + config_file["Global"]["labe_list"].as>(); + this->ir_optim_ = config_file["Global"]["ir_optim"].as(); + this->batch_size_ = config_file["Global"]["batch_size"].as(); + + preprocessor_.Init(config_file["DetPreProcess"]["transform_ops"]); + LoadModel(model_dir_, batch_size_, run_mode); + } + + // Load Paddle inference model + void LoadModel(const std::string &model_dir, const int batch_size = 1, + const std::string &run_mode = "fluid"); + + // Run predictor + void Predict(const std::vector imgs, const int warmup = 0, + const int repeats = 1, + std::vector *result = nullptr, + std::vector *bbox_num = nullptr, + std::vector *times = nullptr); + const std::vector &GetLabelList() const { + return this->label_list_; + } + const float &GetThreshold() const { return this->threshold_; } + +private: + bool use_gpu_ = true; + int gpu_id_ = 0; + int gpu_mem_ = 800; + int cpu_math_library_num_threads_ = 6; + std::string run_mode = "fluid"; + bool use_mkldnn_ = false; + bool use_tensorrt_ = false; + bool batch_size_ = 1; + bool use_fp16_ = false; + std::string model_dir_; + float nms_thres_ = 0.02; + float threshold_ = 0.5; + float max_det_results_ = 5; + std::vector image_shape_ = {3, 640, 640}; + std::vector label_list_; + bool ir_optim_ = true; + bool det_permute_ = true; + bool det_postprocess_ = true; + int min_subgraph_size_ = 30; + bool use_dynamic_shape_ = false; + int trt_min_shape_ = 1; + int trt_max_shape_ = 1280; + int trt_opt_shape_ = 640; + bool trt_calib_mode_ = false; + + // Preprocess image and copy data to input buffer + void Preprocess(const cv::Mat &image_mat); + // Postprocess result + void Postprocess(const std::vector mats, + std::vector *result, std::vector bbox_num, + bool is_rbox); + + std::shared_ptr predictor_; + Preprocessor preprocessor_; + ImageBlob inputs_; + std::vector output_data_; + std::vector out_bbox_num_data_; +}; + +} // namespace PaddleDetection diff --git a/deploy/cpp_shitu/include/preprocess_op_det.h b/deploy/cpp_shitu/include/preprocess_op_det.h new file mode 100644 index 00000000..de23ea87 --- /dev/null +++ b/deploy/cpp_shitu/include/preprocess_op_det.h @@ -0,0 +1,155 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace PaddleDetection { + +// Object for storing all preprocessed data +class ImageBlob { +public: + // image width and height + std::vector im_shape_; + // Buffer for image data after preprocessing + std::vector im_data_; + // in net data shape(after pad) + std::vector in_net_shape_; + // Evaluation image width and height + // std::vector eval_im_size_f_; + // Scale factor for image size to origin image size + std::vector scale_factor_; +}; + +// Abstraction of preprocessing opration class +class PreprocessOp { +public: + virtual void Init(const YAML::Node &item) = 0; + virtual void Run(cv::Mat *im, ImageBlob *data) = 0; +}; + +class InitInfo : public PreprocessOp { +public: + virtual void Init(const YAML::Node &item) {} + virtual void Run(cv::Mat *im, ImageBlob *data); +}; + +class NormalizeImage : public PreprocessOp { +public: + virtual void Init(const YAML::Node &item) { + mean_ = item["mean"].as>(); + scale_ = item["std"].as>(); + is_scale_ = item["is_scale"].as(); + } + + virtual void Run(cv::Mat *im, ImageBlob *data); + +private: + // CHW or HWC + std::vector mean_; + std::vector scale_; + bool is_scale_; +}; + +class Permute : public PreprocessOp { +public: + virtual void Init(const YAML::Node &item) {} + virtual void Run(cv::Mat *im, ImageBlob *data); +}; + +class Resize : public PreprocessOp { +public: + virtual void Init(const YAML::Node &item) { + interp_ = item["interp"].as(); + // max_size_ = item["target_size"].as(); + keep_ratio_ = item["keep_ratio"].as(); + target_size_ = item["target_size"].as>(); + } + + // Compute best resize scale for x-dimension, y-dimension + std::pair GenerateScale(const cv::Mat &im); + + virtual void Run(cv::Mat *im, ImageBlob *data); + +private: + int interp_; + bool keep_ratio_; + std::vector target_size_; + std::vector in_net_shape_; +}; + +// Models with FPN need input shape % stride == 0 +class PadStride : public PreprocessOp { +public: + virtual void Init(const YAML::Node &item) { + stride_ = item["stride"].as(); + } + + virtual void Run(cv::Mat *im, ImageBlob *data); + +private: + int stride_; +}; + +class Preprocessor { +public: + void Init(const YAML::Node &config_node) { + // initialize image info at first + ops_["InitInfo"] = std::make_shared(); + for (int i = 0; i < config_node.size(); ++i) { + if (config_node[i]["DetResize"].IsDefined()) { + ops_["Resize"] = std::make_shared(); + ops_["Resize"]->Init(config_node[i]["DetResize"]); + } + + if (config_node[i]["DetNormalizeImage"].IsDefined()) { + ops_["NormalizeImage"] = std::make_shared(); + ops_["NormalizeImage"]->Init(config_node[i]["DetNormalizeImage"]); + } + + if (config_node[i]["DetPermute"].IsDefined()) { + ops_["Permute"] = std::make_shared(); + ops_["Permute"]->Init(config_node[i]["DetPermute"]); + } + + if (config_node[i]["DetPadStrid"].IsDefined()) { + ops_["PadStride"] = std::make_shared(); + ops_["PadStride"]->Init(config_node[i]["DetPadStrid"]); + } + } + } + + void Run(cv::Mat *im, ImageBlob *data); + +public: + static const std::vector RUN_ORDER; + +private: + std::unordered_map> ops_; +}; + +} // namespace PaddleDetection diff --git a/deploy/cpp_shitu/src/cls.cpp b/deploy/cpp_shitu/src/cls.cpp index a6ca7d5a..c8f1e2b9 100644 --- a/deploy/cpp_shitu/src/cls.cpp +++ b/deploy/cpp_shitu/src/cls.cpp @@ -52,10 +52,12 @@ void Classifier::LoadModel(const std::string &model_path, this->predictor_ = CreatePredictor(config); } -double Classifier::Run(cv::Mat &img, std::vector *times) { +void Classifier::Run(cv::Mat &img, std::vector &out_data, + std::vector ×) { cv::Mat srcimg; cv::Mat resize_img; img.copyTo(srcimg); + std::vector time; auto preprocess_start = std::chrono::system_clock::now(); this->resize_op_.Run(img, resize_img, this->resize_short_, @@ -74,7 +76,6 @@ double Classifier::Run(cv::Mat &img, std::vector *times) { input_t->CopyFromCpu(input.data()); this->predictor_->Run(); - std::vector out_data; auto output_names = this->predictor_->GetOutputNames(); auto output_t = this->predictor_->GetOutputHandle(output_names[0]); std::vector output_shape = output_t->shape(); @@ -85,27 +86,28 @@ double Classifier::Run(cv::Mat &img, std::vector *times) { output_t->CopyToCpu(out_data.data()); auto infer_end = std::chrono::system_clock::now(); - auto postprocess_start = std::chrono::system_clock::now(); - int maxPosition = - max_element(out_data.begin(), out_data.end()) - out_data.begin(); - auto postprocess_end = std::chrono::system_clock::now(); + // auto postprocess_start = std::chrono::system_clock::now(); + // int maxPosition = + // max_element(out_data.begin(), out_data.end()) - out_data.begin(); + // auto postprocess_end = std::chrono::system_clock::now(); std::chrono::duration preprocess_diff = preprocess_end - preprocess_start; - times->push_back(double(preprocess_diff.count() * 1000)); + time.push_back(double(preprocess_diff.count())); std::chrono::duration inference_diff = infer_end - infer_start; - double inference_cost_time = double(inference_diff.count() * 1000); - times->push_back(inference_cost_time); - std::chrono::duration postprocess_diff = - postprocess_end - postprocess_start; - times->push_back(double(postprocess_diff.count() * 1000)); - - std::cout << "result: " << std::endl; - std::cout << "\tclass id: " << maxPosition << std::endl; - std::cout << std::fixed << std::setprecision(10) - << "\tscore: " << double(out_data[maxPosition]) << std::endl; - - return inference_cost_time; + double inference_cost_time = double(inference_diff.count()); + time.push_back(inference_cost_time); + // std::chrono::duration postprocess_diff = + // postprocess_end - postprocess_start; + time.push_back(0); + + // std::cout << "result: " << std::endl; + // std::cout << "\tclass id: " << maxPosition << std::endl; + // std::cout << std::fixed << std::setprecision(10) + // << "\tscore: " << double(out_data[maxPosition]) << std::endl; + times[0] += time[0]; + times[1] += time[1]; + times[2] += time[2]; } } // namespace PaddleClas diff --git a/deploy/cpp_shitu/src/main.cpp b/deploy/cpp_shitu/src/main.cpp index d1f36973..677954fb 100644 --- a/deploy/cpp_shitu/src/main.cpp +++ b/deploy/cpp_shitu/src/main.cpp @@ -28,11 +28,104 @@ #include #include +#include #include using namespace std; using namespace cv; -using namespace PaddleClas; + +void DetPredictImage(const std::vector &batch_imgs, + const std::vector &all_img_paths, + const int batch_size, PaddleDetection::ObjectDetector *det, + std::vector &im_result, + std::vector &im_bbox_num, std::vector &det_t, + const bool visual_det = false, + const bool run_benchmark = false, + const std::string &output_dir = "output") { + int steps = ceil(float(all_img_paths.size()) / batch_size); + // printf("total images = %d, batch_size = %d, total steps = %d\n", + // all_img_paths.size(), batch_size, steps); + for (int idx = 0; idx < steps; idx++) { + int left_image_cnt = all_img_paths.size() - idx * batch_size; + if (left_image_cnt > batch_size) { + left_image_cnt = batch_size; + } + // for (int bs = 0; bs < left_image_cnt; bs++) { + // std::string image_file_path = all_img_paths.at(idx * batch_size+bs); + // cv::Mat im = cv::imread(image_file_path, 1); + // batch_imgs.insert(batch_imgs.end(), im); + // } + + // Store all detected result + std::vector result; + std::vector bbox_num; + std::vector det_times; + bool is_rbox = false; + if (run_benchmark) { + det->Predict(batch_imgs, 10, 10, &result, &bbox_num, &det_times); + } else { + det->Predict(batch_imgs, 0, 1, &result, &bbox_num, &det_times); + // get labels and colormap + auto labels = det->GetLabelList(); + auto colormap = PaddleDetection::GenerateColorMap(labels.size()); + + int item_start_idx = 0; + for (int i = 0; i < left_image_cnt; i++) { + cv::Mat im = batch_imgs[i]; + int detect_num = 0; + + for (int j = 0; j < bbox_num[i]; j++) { + PaddleDetection::ObjectResult item = result[item_start_idx + j]; + if (item.confidence < det->GetThreshold() || item.class_id == -1) { + continue; + } + detect_num += 1; + im_result.push_back(item); + if (visual_det) { + if (item.rect.size() > 6) { + is_rbox = true; + printf( + "class=%d confidence=%.4f rect=[%d %d %d %d %d %d %d %d]\n", + item.class_id, item.confidence, item.rect[0], item.rect[1], + item.rect[2], item.rect[3], item.rect[4], item.rect[5], + item.rect[6], item.rect[7]); + } else { + printf("class=%d confidence=%.4f rect=[%d %d %d %d]\n", + item.class_id, item.confidence, item.rect[0], item.rect[1], + item.rect[2], item.rect[3]); + } + } + } + im_bbox_num.push_back(detect_num); + item_start_idx = item_start_idx + bbox_num[i]; + + // Visualization result + if (visual_det) { + std::cout << all_img_paths.at(idx * batch_size + i) + << " The number of detected box: " << detect_num + << std::endl; + cv::Mat vis_img = PaddleDetection::VisualizeResult( + im, im_result, labels, colormap, is_rbox); + std::vector compression_params; + compression_params.push_back(CV_IMWRITE_JPEG_QUALITY); + compression_params.push_back(95); + std::string output_path(output_dir); + if (output_dir.rfind(OS_PATH_SEP) != output_dir.size() - 1) { + output_path += OS_PATH_SEP; + } + std::string image_file_path = all_img_paths.at(idx * batch_size + i); + output_path += + image_file_path.substr(image_file_path.find_last_of('/') + 1); + cv::imwrite(output_path, vis_img, compression_params); + printf("Visualized output saved as %s\n", output_path.c_str()); + } + } + } + det_t[0] += det_times[0]; + det_t[1] += det_times[1]; + det_t[2] += det_times[2]; + } +} int main(int argc, char **argv) { if (argc != 2) { @@ -40,9 +133,24 @@ int main(int argc, char **argv) { exit(1); } - YamlConfig config(argv[1]); + PaddleClas::YamlConfig config(argv[1]); config.PrintConfigInfo(); + // config + const int batch_size = config.config_file["Global"]["batch_size"].as(); + bool visual_det = false; + if (config.config_file["Global"]["visual_det"].IsDefined()) { + visual_det = config.config_file["Global"]["visual_det"].as(); + } + bool run_benchmark = false; + if (config.config_file["Global"]["benchmark"].IsDefined()) { + run_benchmark = config.config_file["Global"]["benchmark"].as(); + } + int max_det_results = 5; + if (config.config_file["Global"]["max_det_results"].IsDefined()) { + max_det_results = config.config_file["Global"]["max_det_results"].as(); + } + std::string path = config.config_file["Global"]["infer_imgs"].as(); std::vector img_files_list; @@ -58,10 +166,17 @@ int main(int argc, char **argv) { std::cout << "img_file_list length: " << img_files_list.size() << std::endl; - Classifier classifier(config.config_file); + PaddleClas::Classifier classifier(config.config_file); + PaddleDetection::ObjectDetector detector(config.config_file); double elapsed_time = 0.0; - std::vector cls_times; + std::vector cls_times = {0, 0, 0}; + std::vector det_times = {0, 0, 0}; + std::vector batch_imgs; + std::vector img_paths; + std::vector det_result; + std::vector det_bbox_num; + int warmup_iter = img_files_list.size() > 5 ? 5 : 0; for (int idx = 0; idx < img_files_list.size(); ++idx) { std::string img_path = img_files_list[idx]; @@ -71,19 +186,39 @@ int main(int argc, char **argv) { << "\n"; exit(-1); } - cv::cvtColor(srcimg, srcimg, cv::COLOR_BGR2RGB); - double run_time = classifier.Run(srcimg, &cls_times); - if (idx >= warmup_iter) { - elapsed_time += run_time; - std::cout << "Current image path: " << img_path << std::endl; - std::cout << "Current time cost: " << run_time << " s, " - << "average time cost in all: " - << elapsed_time / (idx + 1 - warmup_iter) << " s." << std::endl; - } else { - std::cout << "Current time cost: " << run_time << " s." << std::endl; + batch_imgs.push_back(srcimg); + img_paths.push_back(img_path); + + // step1: get all detection results + DetPredictImage(batch_imgs, img_paths, batch_size, &detector, det_result, + det_bbox_num, det_times, visual_det, run_benchmark); + + // select max_det_results bbox + while (det_result.size() > max_det_results) { + det_result.pop_back(); + } + // step2: add the whole image for recognition to improve recall + PaddleDetection::ObjectResult result_whole_img = { + {0, 0, srcimg.cols - 1, srcimg.rows - 1}, 0, 1.0}; + det_result.push_back(result_whole_img); + det_bbox_num[0] = det_result.size() + 1; + + // step3: recognition process, use score_thres to ensure accuracy + for (int j = 0; j < det_result.size(); ++j) { + int w = det_result[j].rect[2] - det_result[j].rect[0]; + int h = det_result[j].rect[3] - det_result[j].rect[1]; + cv::Rect rect(det_result[j].rect[0], det_result[j].rect[1], w, h); + cv::Mat crop_img = srcimg(rect); + std::vector feature; + classifier.Run(crop_img, feature, cls_times); } + // double run_time = classifier.Run(srcimg, cls_times); + batch_imgs.clear(); + img_paths.clear(); + det_bbox_num.clear(); + det_result.clear(); } std::string presion = "fp32"; diff --git a/deploy/cpp_shitu/src/object_detector.cpp b/deploy/cpp_shitu/src/object_detector.cpp new file mode 100644 index 00000000..257bbcf4 --- /dev/null +++ b/deploy/cpp_shitu/src/object_detector.cpp @@ -0,0 +1,365 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include +// for setprecision +#include "include/object_detector.h" +#include +#include + +using namespace paddle_infer; + +namespace PaddleDetection { + +// Load Model and create model predictor +void ObjectDetector::LoadModel(const std::string &model_dir, + const int batch_size, + const std::string &run_mode) { + paddle_infer::Config config; + std::string prog_file = model_dir + OS_PATH_SEP + "inference.pdmodel"; + std::string params_file = model_dir + OS_PATH_SEP + "inference.pdiparams"; + config.SetModel(prog_file, params_file); + if (this->use_gpu_) { + config.EnableUseGpu(this->gpu_mem_, this->gpu_id_); + config.SwitchIrOptim(this->ir_optim_); + // // use tensorrt + // if (run_mode != "fluid") { + // auto precision = paddle_infer::Config::Precision::kFloat32; + // if (run_mode == "trt_fp32") { + // precision = paddle_infer::Config::Precision::kFloat32; + // } + // else if (run_mode == "trt_fp16") { + // precision = paddle_infer::Config::Precision::kHalf; + // } + // else if (run_mode == "trt_int8") { + // precision = paddle_infer::Config::Precision::kInt8; + // } else { + // printf("run_mode should be 'fluid', 'trt_fp32', 'trt_fp16' or + // 'trt_int8'"); + // } + // set tensorrt + + if (this->use_tensorrt_) { + config.EnableTensorRtEngine( + 1 << 30, batch_size, this->min_subgraph_size_, + this->use_fp16_ ? paddle_infer::Config::Precision::kHalf + : paddle_infer::Config::Precision::kFloat32, + false, this->trt_calib_mode_); + // set use dynamic shape + if (this->use_dynamic_shape_) { + // set DynamicShsape for image tensor + const std::vector min_input_shape = {1, 3, this->trt_min_shape_, + this->trt_min_shape_}; + const std::vector max_input_shape = {1, 3, this->trt_max_shape_, + this->trt_max_shape_}; + const std::vector opt_input_shape = {1, 3, this->trt_opt_shape_, + this->trt_opt_shape_}; + const std::map> map_min_input_shape = { + {"image", min_input_shape}}; + const std::map> map_max_input_shape = { + {"image", max_input_shape}}; + const std::map> map_opt_input_shape = { + {"image", opt_input_shape}}; + + config.SetTRTDynamicShapeInfo(map_min_input_shape, map_max_input_shape, + map_opt_input_shape); + std::cout << "TensorRT dynamic shape enabled" << std::endl; + } + } + + // } else if (this->device_ == "XPU"){ + // config.EnableXpu(10*1024*1024); + } else { + config.DisableGpu(); + if (this->use_mkldnn_) { + config.EnableMKLDNN(); + // cache 10 different shapes for mkldnn to avoid memory leak + config.SetMkldnnCacheCapacity(10); + } + config.SetCpuMathLibraryNumThreads(this->cpu_math_library_num_threads_); + } + config.SwitchUseFeedFetchOps(false); + config.SwitchIrOptim(this->ir_optim_); + config.DisableGlogInfo(); + // Memory optimization + config.EnableMemoryOptim(); + predictor_ = std::move(CreatePredictor(config)); +} + +// Visualiztion MaskDetector results +cv::Mat VisualizeResult(const cv::Mat &img, + const std::vector &results, + const std::vector &lables, + const std::vector &colormap, + const bool is_rbox = false) { + cv::Mat vis_img = img.clone(); + for (int i = 0; i < results.size(); ++i) { + // Configure color and text size + std::ostringstream oss; + oss << std::setiosflags(std::ios::fixed) << std::setprecision(4); + oss << lables[results[i].class_id] << " "; + oss << results[i].confidence; + std::string text = oss.str(); + int c1 = colormap[3 * results[i].class_id + 0]; + int c2 = colormap[3 * results[i].class_id + 1]; + int c3 = colormap[3 * results[i].class_id + 2]; + cv::Scalar roi_color = cv::Scalar(c1, c2, c3); + int font_face = cv::FONT_HERSHEY_COMPLEX_SMALL; + double font_scale = 0.5f; + float thickness = 0.5; + cv::Size text_size = + cv::getTextSize(text, font_face, font_scale, thickness, nullptr); + cv::Point origin; + + if (is_rbox) { + // Draw object, text, and background + for (int k = 0; k < 4; k++) { + cv::Point pt1 = cv::Point(results[i].rect[(k * 2) % 8], + results[i].rect[(k * 2 + 1) % 8]); + cv::Point pt2 = cv::Point(results[i].rect[(k * 2 + 2) % 8], + results[i].rect[(k * 2 + 3) % 8]); + cv::line(vis_img, pt1, pt2, roi_color, 2); + } + } else { + int w = results[i].rect[2] - results[i].rect[0]; + int h = results[i].rect[3] - results[i].rect[1]; + cv::Rect roi = cv::Rect(results[i].rect[0], results[i].rect[1], w, h); + // Draw roi object, text, and background + cv::rectangle(vis_img, roi, roi_color, 2); + } + + origin.x = results[i].rect[0]; + origin.y = results[i].rect[1]; + + // Configure text background + cv::Rect text_back = + cv::Rect(results[i].rect[0], results[i].rect[1] - text_size.height, + text_size.width, text_size.height); + // Draw text, and background + cv::rectangle(vis_img, text_back, roi_color, -1); + cv::putText(vis_img, text, origin, font_face, font_scale, + cv::Scalar(255, 255, 255), thickness); + } + return vis_img; +} + +void ObjectDetector::Preprocess(const cv::Mat &ori_im) { + // Clone the image : keep the original mat for postprocess + cv::Mat im = ori_im.clone(); + cv::cvtColor(im, im, cv::COLOR_BGR2RGB); + preprocessor_.Run(&im, &inputs_); +} + +void ObjectDetector::Postprocess(const std::vector mats, + std::vector *result, + std::vector bbox_num, + bool is_rbox = false) { + result->clear(); + int start_idx = 0; + for (int im_id = 0; im_id < mats.size(); im_id++) { + cv::Mat raw_mat = mats[im_id]; + int rh = 1; + int rw = 1; + // if (config_.arch_ == "Face") { + // rh = raw_mat.rows; + // rw = raw_mat.cols; + // } + for (int j = start_idx; j < start_idx + bbox_num[im_id]; j++) { + if (is_rbox) { + // Class id + int class_id = static_cast(round(output_data_[0 + j * 10])); + // Confidence score + float score = output_data_[1 + j * 10]; + int x1 = (output_data_[2 + j * 10] * rw); + int y1 = (output_data_[3 + j * 10] * rh); + int x2 = (output_data_[4 + j * 10] * rw); + int y2 = (output_data_[5 + j * 10] * rh); + int x3 = (output_data_[6 + j * 10] * rw); + int y3 = (output_data_[7 + j * 10] * rh); + int x4 = (output_data_[8 + j * 10] * rw); + int y4 = (output_data_[9 + j * 10] * rh); + + ObjectResult result_item; + result_item.rect = {x1, y1, x2, y2, x3, y3, x4, y4}; + result_item.class_id = class_id; + result_item.confidence = score; + result->push_back(result_item); + } else { + // Class id + int class_id = static_cast(round(output_data_[0 + j * 6])); + // Confidence score + float score = output_data_[1 + j * 6]; + int xmin = (output_data_[2 + j * 6] * rw); + int ymin = (output_data_[3 + j * 6] * rh); + int xmax = (output_data_[4 + j * 6] * rw); + int ymax = (output_data_[5 + j * 6] * rh); + int wd = xmax - xmin; + int hd = ymax - ymin; + + ObjectResult result_item; + result_item.rect = {xmin, ymin, xmax, ymax}; + result_item.class_id = class_id; + result_item.confidence = score; + result->push_back(result_item); + } + } + start_idx += bbox_num[im_id]; + } +} + +void ObjectDetector::Predict(const std::vector imgs, const int warmup, + const int repeats, + std::vector *result, + std::vector *bbox_num, + std::vector *times) { + auto preprocess_start = std::chrono::steady_clock::now(); + int batch_size = imgs.size(); + + // in_data_batch + std::vector in_data_all; + std::vector im_shape_all(batch_size * 2); + std::vector scale_factor_all(batch_size * 2); + + // Preprocess image + for (int bs_idx = 0; bs_idx < batch_size; bs_idx++) { + cv::Mat im = imgs.at(bs_idx); + Preprocess(im); + im_shape_all[bs_idx * 2] = inputs_.im_shape_[0]; + im_shape_all[bs_idx * 2 + 1] = inputs_.im_shape_[1]; + + scale_factor_all[bs_idx * 2] = inputs_.scale_factor_[0]; + scale_factor_all[bs_idx * 2 + 1] = inputs_.scale_factor_[1]; + + // TODO: reduce cost time + in_data_all.insert(in_data_all.end(), inputs_.im_data_.begin(), + inputs_.im_data_.end()); + } + + // Prepare input tensor + auto input_names = predictor_->GetInputNames(); + for (const auto &tensor_name : input_names) { + auto in_tensor = predictor_->GetInputHandle(tensor_name); + if (tensor_name == "image") { + int rh = inputs_.in_net_shape_[0]; + int rw = inputs_.in_net_shape_[1]; + in_tensor->Reshape({batch_size, 3, rh, rw}); + in_tensor->CopyFromCpu(in_data_all.data()); + } else if (tensor_name == "im_shape") { + in_tensor->Reshape({batch_size, 2}); + in_tensor->CopyFromCpu(im_shape_all.data()); + } else if (tensor_name == "scale_factor") { + in_tensor->Reshape({batch_size, 2}); + in_tensor->CopyFromCpu(scale_factor_all.data()); + } + } + + auto preprocess_end = std::chrono::steady_clock::now(); + // Run predictor + // warmup + for (int i = 0; i < warmup; i++) { + predictor_->Run(); + // Get output tensor + auto output_names = predictor_->GetOutputNames(); + auto out_tensor = predictor_->GetOutputHandle(output_names[0]); + std::vector output_shape = out_tensor->shape(); + auto out_bbox_num = predictor_->GetOutputHandle(output_names[1]); + std::vector out_bbox_num_shape = out_bbox_num->shape(); + // Calculate output length + int output_size = 1; + for (int j = 0; j < output_shape.size(); ++j) { + output_size *= output_shape[j]; + } + + if (output_size < 6) { + std::cerr << "[WARNING] No object detected." << std::endl; + } + output_data_.resize(output_size); + out_tensor->CopyToCpu(output_data_.data()); + + int out_bbox_num_size = 1; + for (int j = 0; j < out_bbox_num_shape.size(); ++j) { + out_bbox_num_size *= out_bbox_num_shape[j]; + } + out_bbox_num_data_.resize(out_bbox_num_size); + out_bbox_num->CopyToCpu(out_bbox_num_data_.data()); + } + + bool is_rbox = false; + auto inference_start = std::chrono::steady_clock::now(); + for (int i = 0; i < repeats; i++) { + predictor_->Run(); + // Get output tensor + auto output_names = predictor_->GetOutputNames(); + auto out_tensor = predictor_->GetOutputHandle(output_names[0]); + std::vector output_shape = out_tensor->shape(); + auto out_bbox_num = predictor_->GetOutputHandle(output_names[1]); + std::vector out_bbox_num_shape = out_bbox_num->shape(); + // Calculate output length + int output_size = 1; + for (int j = 0; j < output_shape.size(); ++j) { + output_size *= output_shape[j]; + } + is_rbox = output_shape[output_shape.size() - 1] % 10 == 0; + + if (output_size < 6) { + std::cerr << "[WARNING] No object detected." << std::endl; + } + output_data_.resize(output_size); + out_tensor->CopyToCpu(output_data_.data()); + + int out_bbox_num_size = 1; + for (int j = 0; j < out_bbox_num_shape.size(); ++j) { + out_bbox_num_size *= out_bbox_num_shape[j]; + } + out_bbox_num_data_.resize(out_bbox_num_size); + out_bbox_num->CopyToCpu(out_bbox_num_data_.data()); + } + auto inference_end = std::chrono::steady_clock::now(); + auto postprocess_start = std::chrono::steady_clock::now(); + // Postprocessing result + result->clear(); + Postprocess(imgs, result, out_bbox_num_data_, is_rbox); + bbox_num->clear(); + for (int k = 0; k < out_bbox_num_data_.size(); k++) { + int tmp = out_bbox_num_data_[k]; + bbox_num->push_back(tmp); + } + auto postprocess_end = std::chrono::steady_clock::now(); + + std::chrono::duration preprocess_diff = + preprocess_end - preprocess_start; + times->push_back(double(preprocess_diff.count() * 1000)); + std::chrono::duration inference_diff = inference_end - inference_start; + times->push_back(double(inference_diff.count() / repeats * 1000)); + std::chrono::duration postprocess_diff = + postprocess_end - postprocess_start; + times->push_back(double(postprocess_diff.count() * 1000)); +} + +std::vector GenerateColorMap(int num_class) { + auto colormap = std::vector(3 * num_class, 0); + for (int i = 0; i < num_class; ++i) { + int j = 0; + int lab = i; + while (lab) { + colormap[i * 3] |= (((lab >> 0) & 1) << (7 - j)); + colormap[i * 3 + 1] |= (((lab >> 1) & 1) << (7 - j)); + colormap[i * 3 + 2] |= (((lab >> 2) & 1) << (7 - j)); + ++j; + lab >>= 3; + } + } + return colormap; +} + +} // namespace PaddleDetection diff --git a/deploy/cpp_shitu/src/preprocess_op_det.cpp b/deploy/cpp_shitu/src/preprocess_op_det.cpp new file mode 100644 index 00000000..16d035e3 --- /dev/null +++ b/deploy/cpp_shitu/src/preprocess_op_det.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "include/preprocess_op_det.h" + +namespace PaddleDetection { + +void InitInfo::Run(cv::Mat *im, ImageBlob *data) { + data->im_shape_ = {static_cast(im->rows), + static_cast(im->cols)}; + data->scale_factor_ = {1., 1.}; + data->in_net_shape_ = {static_cast(im->rows), + static_cast(im->cols)}; +} + +void NormalizeImage::Run(cv::Mat *im, ImageBlob *data) { + double e = 1.0; + if (is_scale_) { + e /= 255.0; + } + (*im).convertTo(*im, CV_32FC3, e); + for (int h = 0; h < im->rows; h++) { + for (int w = 0; w < im->cols; w++) { + im->at(h, w)[0] = + (im->at(h, w)[0] - mean_[0]) / scale_[0]; + im->at(h, w)[1] = + (im->at(h, w)[1] - mean_[1]) / scale_[1]; + im->at(h, w)[2] = + (im->at(h, w)[2] - mean_[2]) / scale_[2]; + } + } +} + +void Permute::Run(cv::Mat *im, ImageBlob *data) { + int rh = im->rows; + int rw = im->cols; + int rc = im->channels(); + (data->im_data_).resize(rc * rh * rw); + float *base = (data->im_data_).data(); + for (int i = 0; i < rc; ++i) { + cv::extractChannel(*im, cv::Mat(rh, rw, CV_32FC1, base + i * rh * rw), i); + } +} + +void Resize::Run(cv::Mat *im, ImageBlob *data) { + auto resize_scale = GenerateScale(*im); + data->im_shape_ = {static_cast(im->cols * resize_scale.first), + static_cast(im->rows * resize_scale.second)}; + data->in_net_shape_ = {static_cast(im->cols * resize_scale.first), + static_cast(im->rows * resize_scale.second)}; + cv::resize(*im, *im, cv::Size(), resize_scale.first, resize_scale.second, + interp_); + data->im_shape_ = { + static_cast(im->rows), static_cast(im->cols), + }; + data->scale_factor_ = { + resize_scale.second, resize_scale.first, + }; +} + +std::pair Resize::GenerateScale(const cv::Mat &im) { + std::pair resize_scale; + int origin_w = im.cols; + int origin_h = im.rows; + + if (keep_ratio_) { + int im_size_max = std::max(origin_w, origin_h); + int im_size_min = std::min(origin_w, origin_h); + int target_size_max = + *std::max_element(target_size_.begin(), target_size_.end()); + int target_size_min = + *std::min_element(target_size_.begin(), target_size_.end()); + float scale_min = + static_cast(target_size_min) / static_cast(im_size_min); + float scale_max = + static_cast(target_size_max) / static_cast(im_size_max); + float scale_ratio = std::min(scale_min, scale_max); + resize_scale = {scale_ratio, scale_ratio}; + } else { + resize_scale.first = + static_cast(target_size_[1]) / static_cast(origin_w); + resize_scale.second = + static_cast(target_size_[0]) / static_cast(origin_h); + } + return resize_scale; +} + +void PadStride::Run(cv::Mat *im, ImageBlob *data) { + if (stride_ <= 0) { + return; + } + int rc = im->channels(); + int rh = im->rows; + int rw = im->cols; + int nh = (rh / stride_) * stride_ + (rh % stride_ != 0) * stride_; + int nw = (rw / stride_) * stride_ + (rw % stride_ != 0) * stride_; + cv::copyMakeBorder(*im, *im, 0, nh - rh, 0, nw - rw, cv::BORDER_CONSTANT, + cv::Scalar(0)); + data->in_net_shape_ = { + static_cast(im->rows), static_cast(im->cols), + }; +} + +// Preprocessor op running order +const std::vector Preprocessor::RUN_ORDER = { + "InitInfo", "Resize", "NormalizeImage", "PadStride", "Permute"}; + +void Preprocessor::Run(cv::Mat *im, ImageBlob *data) { + for (const auto &name : RUN_ORDER) { + if (ops_.find(name) != ops_.end()) { + ops_[name]->Run(im, data); + } + } +} + +} // namespace PaddleDetection -- GitLab