diff --git a/deploy/cpp/CMakeLists.txt b/deploy/cpp/CMakeLists.txt index 27a8093bb7e56a2563ec92d20f2d7a5a02f2ade6..3953df13dcbdfeb8854b9b407effac3dd9eb39ae 100644 --- a/deploy/cpp/CMakeLists.txt +++ b/deploy/cpp/CMakeLists.txt @@ -5,7 +5,7 @@ option(WITH_MKL "Compile demo with MKL/OpenBlas support,defaultuseMKL." option(WITH_GPU "Compile demo with GPU/CPU, default use CPU." ON) option(WITH_TENSORRT "Compile demo with TensorRT." OFF) -option(WITH_KEYPOINT "Whether to Compile KeyPoint detector" ON) +option(WITH_KEYPOINT "Whether to Compile KeyPoint detector" OFF) SET(PADDLE_DIR "" CACHE PATH "Location of libraries") SET(PADDLE_LIB_NAME "" CACHE STRING "libpaddle_inference") @@ -22,9 +22,9 @@ include_directories("${CMAKE_CURRENT_BINARY_DIR}/ext/yaml-cpp/src/ext-yaml-cpp/i link_directories("${CMAKE_CURRENT_BINARY_DIR}/ext/yaml-cpp/lib") if (WITH_KEYPOINT) - set(SRCS src/main_keypoint.cc src/preprocess_op.cc src/object_detector.cc src/keypoint_detector.cc src/keypoint_postprocess.cc) + set(SRCS src/main_keypoint.cc src/preprocess_op.cc src/object_detector.cc src/picodet_postprocess.cc src/utils.cc src/keypoint_detector.cc src/keypoint_postprocess.cc) else () - set(SRCS src/main.cc src/preprocess_op.cc src/object_detector.cc) + set(SRCS src/main.cc src/preprocess_op.cc src/object_detector.cc src/picodet_postprocess.cc src/utils.cc) endif() macro(safe_set_static_flag) diff --git a/deploy/cpp/include/config_parser.h b/deploy/cpp/include/config_parser.h index 661b2d2dc2932990accc8a97b2a3f315716e5f1e..6f54bb762048785369ddd83c4c0360f058c3cf7e 100644 --- a/deploy/cpp/include/config_parser.h +++ b/deploy/cpp/include/config_parser.h @@ -99,6 +99,18 @@ class ConfigPaser { return false; } + // Get NMS for postprocess + if (config["NMS"].IsDefined()) { + nms_info_ = config["NMS"]; + } + // Get fpn_stride in PicoDet + if (config["fpn_stride"].IsDefined()) { + fpn_stride_.clear(); + for (auto item : config["fpn_stride"]) { + fpn_stride_.emplace_back(item.as()); + } + } + return true; } std::string mode_; @@ -106,7 +118,9 @@ class ConfigPaser { std::string arch_; int min_subgraph_size_; YAML::Node preprocess_info_; + YAML::Node nms_info_; std::vector label_list_; + std::vector fpn_stride_; bool use_dynamic_shape_; }; diff --git a/deploy/cpp/include/object_detector.h b/deploy/cpp/include/object_detector.h index 2b86ba94527d2aeefa96269a5cadcffdd7470335..0e207c1199cb19867ebf73e88b1c5e3b45bc55fc 100644 --- a/deploy/cpp/include/object_detector.h +++ b/deploy/cpp/include/object_detector.h @@ -19,6 +19,7 @@ #include #include #include +#include #include #include @@ -28,20 +29,12 @@ #include "include/preprocess_op.h" #include "include/config_parser.h" +#include "include/utils.h" +#include "include/picodet_postprocess.h" using namespace paddle_infer; namespace PaddleDetection { -// Object Detection Result -struct ObjectResult { - // Rectangle coordinates of detected object: left, right, top, down - std::vector rect; - // Class id of detected object - int class_id; - // Confidence of detected object - float confidence; -}; - // Generate visualization colormap for each class std::vector GenerateColorMap(int num_class); @@ -49,7 +42,7 @@ std::vector GenerateColorMap(int num_class); // Visualiztion Detection Result cv::Mat VisualizeResult(const cv::Mat& img, - const std::vector& results, + const std::vector& results, const std::vector& lables, const std::vector& colormap, const bool is_rbox); @@ -96,7 +89,7 @@ class ObjectDetector { const double threshold = 0.5, const int warmup = 0, const int repeats = 1, - std::vector* result = nullptr, + std::vector* result = nullptr, std::vector* bbox_num = nullptr, std::vector* times = nullptr); @@ -121,17 +114,17 @@ class ObjectDetector { // Postprocess result void Postprocess( const std::vector mats, - std::vector* result, + std::vector* result, std::vector bbox_num, + std::vector output_data_, bool is_rbox); std::shared_ptr predictor_; Preprocessor preprocessor_; ImageBlob inputs_; - std::vector output_data_; - std::vector out_bbox_num_data_; float threshold_; ConfigPaser config_; + }; } // namespace PaddleDetection diff --git a/deploy/cpp/include/picodet_postprocess.h b/deploy/cpp/include/picodet_postprocess.h new file mode 100644 index 0000000000000000000000000000000000000000..415ef69e548c9c1ce3b485d391dbe9945c6e0c83 --- /dev/null +++ b/deploy/cpp/include/picodet_postprocess.h @@ -0,0 +1,38 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "include/utils.h" + +namespace PaddleDetection { + +void PicoDetPostProcess(std::vector* results, + std::vector outs, + std::vector fpn_stride, + std::vector im_shape, + std::vector scale_factor, + float score_threshold = 0.3, + float nms_threshold = 0.5, + int num_class = 80, + int reg_max = 7); + +} // namespace PaddleDetection \ No newline at end of file diff --git a/deploy/cpp/include/preprocess_op.h b/deploy/cpp/include/preprocess_op.h index f34bab45ff61a7d1a4f33b03d0a01012d39505be..7a220baae67927c124aa652588e2665adedad550 100644 --- a/deploy/cpp/include/preprocess_op.h +++ b/deploy/cpp/include/preprocess_op.h @@ -86,7 +86,6 @@ class Resize : public PreprocessOp { public: virtual void Init(const YAML::Node& item) { interp_ = item["interp"].as(); - //max_size_ = item["target_size"].as(); keep_ratio_ = item["keep_ratio"].as(); target_size_ = item["target_size"].as>(); } diff --git a/deploy/cpp/include/utils.h b/deploy/cpp/include/utils.h new file mode 100644 index 0000000000000000000000000000000000000000..3802e1267176a050402d1fdf742e54a79f33ffb9 --- /dev/null +++ b/deploy/cpp/include/utils.h @@ -0,0 +1,39 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +namespace PaddleDetection { + +// Object Detection Result +struct ObjectResult { + // Rectangle coordinates of detected object: left, right, top, down + std::vector rect; + // Class id of detected object + int class_id; + // Confidence of detected object + float confidence; +}; + +void nms(std::vector &input_boxes, float nms_threshold); + +} // namespace PaddleDetection \ No newline at end of file diff --git a/deploy/cpp/src/main.cc b/deploy/cpp/src/main.cc index 59e5b8eddcc5d867fd7620e7bfa357c461c84c48..f9a2676674f4acec3017d1e30657bfc9ecdeb9cb 100644 --- a/deploy/cpp/src/main.cc +++ b/deploy/cpp/src/main.cc @@ -241,7 +241,7 @@ void PredictImage(const std::vector all_img_paths, if (run_benchmark) { det->Predict(batch_imgs, threshold, 10, 10, &result, &bbox_num, &det_times); } else { - det->Predict(batch_imgs, 0.5, 0, 1, &result, &bbox_num, &det_times); + det->Predict(batch_imgs, threshold, 0, 1, &result, &bbox_num, &det_times); // get labels and colormap auto labels = det->GetLabelList(); auto colormap = PaddleDetection::GenerateColorMap(labels.size()); @@ -251,7 +251,7 @@ void PredictImage(const std::vector all_img_paths, cv::Mat im = batch_imgs[i]; std::vector im_result; int detect_num = 0; - + for (int j = 0; j < bbox_num[i]; j++) { PaddleDetection::ObjectResult item = result[item_start_idx + j]; if (item.confidence < threshold || item.class_id == -1) { diff --git a/deploy/cpp/src/main_keypoint.cc b/deploy/cpp/src/main_keypoint.cc index 9bd074159ebc668bf121f2f81d1e694e70e58f9e..7c711211711c3f117af5851d489d44364df981df 100644 --- a/deploy/cpp/src/main_keypoint.cc +++ b/deploy/cpp/src/main_keypoint.cc @@ -302,7 +302,7 @@ void PredictImage(const std::vector all_img_paths, if (run_benchmark) { det->Predict(batch_imgs, threshold, 10, 10, &result, &bbox_num, &det_times); } else { - det->Predict(batch_imgs, 0.5, 10, 10, &result, &bbox_num, &det_times); + det->Predict(batch_imgs, threshold, 0, 1, &result, &bbox_num, &det_times); } // get labels and colormap auto labels = det->GetLabelList(); diff --git a/deploy/cpp/src/object_detector.cc b/deploy/cpp/src/object_detector.cc index a5750d363a5953809fc8bdb620d715be20ae6448..134c01092b3f34247f686e3c828a2ba479979471 100644 --- a/deploy/cpp/src/object_detector.cc +++ b/deploy/cpp/src/object_detector.cc @@ -17,7 +17,6 @@ #include #include "include/object_detector.h" - using namespace paddle_infer; namespace PaddleDetection { @@ -94,7 +93,7 @@ void ObjectDetector::LoadModel(const std::string& model_dir, // Visualiztion MaskDetector results cv::Mat VisualizeResult(const cv::Mat& img, - const std::vector& results, + const std::vector& results, const std::vector& lables, const std::vector& colormap, const bool is_rbox=false) { @@ -171,8 +170,9 @@ void ObjectDetector::Preprocess(const cv::Mat& ori_im) { void ObjectDetector::Postprocess( const std::vector mats, - std::vector* result, + std::vector* result, std::vector bbox_num, + std::vector output_data_, bool is_rbox=false) { result->clear(); int start_idx = 0; @@ -199,7 +199,7 @@ void ObjectDetector::Postprocess( int x4 = (output_data_[8 + j * 10] * rw); int y4 = (output_data_[9 + j * 10] * rh); - ObjectResult result_item; + PaddleDetection::ObjectResult result_item; result_item.rect = {x1, y1, x2, y2, x3, y3, x4, y4}; result_item.class_id = class_id; result_item.confidence = score; @@ -217,7 +217,7 @@ void ObjectDetector::Postprocess( int wd = xmax - xmin; int hd = ymax - ymin; - ObjectResult result_item; + PaddleDetection::ObjectResult result_item; result_item.rect = {xmin, ymin, xmax, ymax}; result_item.class_id = class_id; result_item.confidence = score; @@ -232,7 +232,7 @@ void ObjectDetector::Predict(const std::vector imgs, const double threshold, const int warmup, const int repeats, - std::vector* result, + std::vector* result, std::vector* bbox_num, std::vector* times) { auto preprocess_start = std::chrono::steady_clock::now(); @@ -242,6 +242,8 @@ void ObjectDetector::Predict(const std::vector imgs, std::vector in_data_all; std::vector im_shape_all(batch_size * 2); std::vector scale_factor_all(batch_size * 2); + std::vector output_data_list_; + std::vector out_bbox_num_data_; // Preprocess image for (int bs_idx = 0; bs_idx < batch_size; bs_idx++) { @@ -277,77 +279,90 @@ void ObjectDetector::Predict(const std::vector imgs, } // Run predictor + std::vector> out_tensor_list; + std::vector> output_shape_list; + bool is_rbox = false; + int reg_max = 7; + int num_class = 80; // warmup - for (int i = 0; i < warmup; i++) - { + for (int i = 0; i < warmup; i++) { predictor_->Run(); // Get output tensor auto output_names = predictor_->GetOutputNames(); - auto out_tensor = predictor_->GetOutputHandle(output_names[0]); - std::vector output_shape = out_tensor->shape(); - auto out_bbox_num = predictor_->GetOutputHandle(output_names[1]); - std::vector out_bbox_num_shape = out_bbox_num->shape(); - // Calculate output length - int output_size = 1; - for (int j = 0; j < output_shape.size(); ++j) { - output_size *= output_shape[j]; - } - - if (output_size < 6) { - std::cerr << "[WARNING] No object detected." << std::endl; - } - output_data_.resize(output_size); - out_tensor->CopyToCpu(output_data_.data()); - - int out_bbox_num_size = 1; - for (int j = 0; j < out_bbox_num_shape.size(); ++j) { - out_bbox_num_size *= out_bbox_num_shape[j]; + for (int j = 0; j < output_names.size(); j++) { + auto output_tensor = predictor_->GetOutputHandle(output_names[j]); + std::vector output_shape = output_tensor->shape(); + int out_num = std::accumulate(output_shape.begin(), output_shape.end(), + 1, std::multiplies()); + if (output_tensor->type() == paddle_infer::DataType::INT32) { + out_bbox_num_data_.resize(out_num); + output_tensor->CopyToCpu(out_bbox_num_data_.data()); + } else { + std::vector out_data; + out_data.resize(out_num); + output_tensor->CopyToCpu(out_data.data()); + out_tensor_list.push_back(out_data); + } } - out_bbox_num_data_.resize(out_bbox_num_size); - out_bbox_num->CopyToCpu(out_bbox_num_data_.data()); } - - bool is_rbox = false; + auto inference_start = std::chrono::steady_clock::now(); - for (int i = 0; i < repeats; i++) - { + for (int i = 0; i < repeats; i++) { predictor_->Run(); // Get output tensor + out_tensor_list.clear(); + output_shape_list.clear(); auto output_names = predictor_->GetOutputNames(); - auto out_tensor = predictor_->GetOutputHandle(output_names[0]); - std::vector output_shape = out_tensor->shape(); - auto out_bbox_num = predictor_->GetOutputHandle(output_names[1]); - std::vector out_bbox_num_shape = out_bbox_num->shape(); - // Calculate output length - int output_size = 1; - for (int j = 0; j < output_shape.size(); ++j) { - output_size *= output_shape[j]; - } - is_rbox = output_shape[output_shape.size()-1] % 10 == 0; - - if (output_size < 6) { - std::cerr << "[WARNING] No object detected." << std::endl; - } - output_data_.resize(output_size); - out_tensor->CopyToCpu(output_data_.data()); - - int out_bbox_num_size = 1; - for (int j = 0; j < out_bbox_num_shape.size(); ++j) { - out_bbox_num_size *= out_bbox_num_shape[j]; + for (int j = 0; j < output_names.size(); j++) { + auto output_tensor = predictor_->GetOutputHandle(output_names[j]); + std::vector output_shape = output_tensor->shape(); + int out_num = std::accumulate(output_shape.begin(), output_shape.end(), + 1, std::multiplies()); + output_shape_list.push_back(output_shape); + if (output_tensor->type() == paddle_infer::DataType::INT32) { + out_bbox_num_data_.resize(out_num); + output_tensor->CopyToCpu(out_bbox_num_data_.data()); + } else { + std::vector out_data; + out_data.resize(out_num); + output_tensor->CopyToCpu(out_data.data()); + out_tensor_list.push_back(out_data); + } } - out_bbox_num_data_.resize(out_bbox_num_size); - out_bbox_num->CopyToCpu(out_bbox_num_data_.data()); } auto inference_end = std::chrono::steady_clock::now(); auto postprocess_start = std::chrono::steady_clock::now(); // Postprocessing result result->clear(); - Postprocess(imgs, result, out_bbox_num_data_, is_rbox); bbox_num->clear(); - for (int k=0; kpush_back(tmp); + if (config_.arch_ == "PicoDet") { + for (int i = 0; i < out_tensor_list.size(); i++) { + if (i == 0) { + num_class = output_shape_list[i][2]; + } + if (i == config_.fpn_stride_.size()) { + reg_max = output_shape_list[i][2] / 4 - 1; + } + float *buffer = new float[out_tensor_list[i].size()]; + memcpy(buffer, &out_tensor_list[i][0], + out_tensor_list[i].size()*sizeof(float)); + output_data_list_.push_back(buffer); + } + PaddleDetection::PicoDetPostProcess( + result, output_data_list_, config_.fpn_stride_, + inputs_.im_shape_, inputs_.scale_factor_, + config_.nms_info_["score_threshold"].as(), + config_.nms_info_["nms_threshold"].as(), num_class, reg_max); + bbox_num->push_back(result->size()); + } else { + is_rbox = output_shape_list[0][output_shape_list[0].size()-1] % 10 == 0; + Postprocess(imgs, result, out_bbox_num_data_, out_tensor_list[0], is_rbox); + for (int k=0; k < out_bbox_num_data_.size(); k++) { + int tmp = out_bbox_num_data_[k]; + bbox_num->push_back(tmp); + } } + auto postprocess_end = std::chrono::steady_clock::now(); std::chrono::duration preprocess_diff = preprocess_end - preprocess_start; diff --git a/deploy/cpp/src/picodet_postprocess.cc b/deploy/cpp/src/picodet_postprocess.cc new file mode 100644 index 0000000000000000000000000000000000000000..ba73c7d8cd60fb0ef04f678c27680628696fff5f --- /dev/null +++ b/deploy/cpp/src/picodet_postprocess.cc @@ -0,0 +1,127 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "include/picodet_postprocess.h" + +namespace PaddleDetection { + +float fast_exp(float x) { + union { + uint32_t i; + float f; + } v{}; + v.i = (1 << 23) * (1.4426950409 * x + 126.93490512f); + return v.f; +} + +template +int activation_function_softmax(const _Tp *src, _Tp *dst, int length) { + const _Tp alpha = *std::max_element(src, src + length); + _Tp denominator{0}; + + for (int i = 0; i < length; ++i) { + dst[i] = fast_exp(src[i] - alpha); + denominator += dst[i]; + } + + for (int i = 0; i < length; ++i) { + dst[i] /= denominator; + } + + return 0; +} + +// PicoDet decode +PaddleDetection::ObjectResult disPred2Bbox(const float *&dfl_det, int label, float score, + int x, int y, int stride, std::vector im_shape, + int reg_max) { + float ct_x = (x + 0.5) * stride; + float ct_y = (y + 0.5) * stride; + std::vector dis_pred; + dis_pred.resize(4); + for (int i = 0; i < 4; i++) { + float dis = 0; + float* dis_after_sm = new float[reg_max + 1]; + activation_function_softmax(dfl_det + i * (reg_max + 1), dis_after_sm, reg_max + 1); + for (int j = 0; j < reg_max + 1; j++) { + dis += j * dis_after_sm[j]; + } + dis *= stride; + dis_pred[i] = dis; + delete[] dis_after_sm; + } + int xmin = (int)(std::max)(ct_x - dis_pred[0], .0f); + int ymin = (int)(std::max)(ct_y - dis_pred[1], .0f); + int xmax = (int)(std::min)(ct_x + dis_pred[2], (float)im_shape[0]); + int ymax = (int)(std::min)(ct_y + dis_pred[3], (float)im_shape[1]); + + PaddleDetection::ObjectResult result_item; + result_item.rect = {xmin, ymin, xmax, ymax}; + result_item.class_id = label; + result_item.confidence = score; + + return result_item; +} + + +void PicoDetPostProcess(std::vector* results, + std::vector outs, + std::vector fpn_stride, + std::vector im_shape, + std::vector scale_factor, + float score_threshold, + float nms_threshold, + int num_class, + int reg_max) { + std::vector> bbox_results; + bbox_results.resize(num_class); + int in_h = im_shape[0], in_w = im_shape[1]; + for (int i = 0; i < fpn_stride.size(); ++i) { + int feature_h = in_h / fpn_stride[i]; + int feature_w = in_w / fpn_stride[i]; + for (int idx = 0; idx < feature_h * feature_w; idx++) { + const float *scores = outs[i] + (idx * num_class); + + int row = idx / feature_w; + int col = idx % feature_w; + float score = 0; + int cur_label = 0; + for (int label = 0; label < num_class; label++) { + if (scores[label] > score) { + score = scores[label]; + cur_label = label; + } + } + if (score > score_threshold) { + const float *bbox_pred = outs[i + fpn_stride.size()] + + (idx * 4 * (reg_max + 1)); + bbox_results[cur_label].push_back(disPred2Bbox(bbox_pred, + cur_label, score, col, row, fpn_stride[i], im_shape, reg_max)); + } + } + } + for (int i = 0; i < (int)bbox_results.size(); i++) { + PaddleDetection::nms(bbox_results[i], nms_threshold); + + for (auto box : bbox_results[i]) { + box.rect[0] = box.rect[0] / scale_factor[1]; + box.rect[2] = box.rect[2] / scale_factor[1]; + box.rect[1] = box.rect[1] / scale_factor[0]; + box.rect[3] = box.rect[3] / scale_factor[0]; + results->push_back(box); + } + } +} + +} // namespace PaddleDetection \ No newline at end of file diff --git a/deploy/cpp/src/utils.cc b/deploy/cpp/src/utils.cc new file mode 100644 index 0000000000000000000000000000000000000000..7b4731cd9e25b3536417ade20d3f9ce5089755fd --- /dev/null +++ b/deploy/cpp/src/utils.cc @@ -0,0 +1,49 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "include/utils.h" + +namespace PaddleDetection { + +void nms(std::vector &input_boxes, float nms_threshold) { + std::sort(input_boxes.begin(), + input_boxes.end(), + [](ObjectResult a, ObjectResult b) { return a.confidence > b.confidence; }); + std::vector vArea(input_boxes.size()); + for (int i = 0; i < int(input_boxes.size()); ++i) { + vArea[i] = (input_boxes.at(i).rect[2] - input_boxes.at(i).rect[0] + 1) + * (input_boxes.at(i).rect[3] - input_boxes.at(i).rect[1] + 1); + } + for (int i = 0; i < int(input_boxes.size()); ++i) { + for (int j = i + 1; j < int(input_boxes.size());) { + float xx1 = (std::max)(input_boxes[i].rect[0], input_boxes[j].rect[0]); + float yy1 = (std::max)(input_boxes[i].rect[1], input_boxes[j].rect[1]); + float xx2 = (std::min)(input_boxes[i].rect[2], input_boxes[j].rect[2]); + float yy2 = (std::min)(input_boxes[i].rect[3], input_boxes[j].rect[3]); + float w = (std::max)(float(0), xx2 - xx1 + 1); + float h = (std::max)(float(0), yy2 - yy1 + 1); + float inter = w * h; + float ovr = inter / (vArea[i] + vArea[j] - inter); + if (ovr >= nms_threshold) { + input_boxes.erase(input_boxes.begin() + j); + vArea.erase(vArea.begin() + j); + } + else { + j++; + } + } + } +} + +} // namespace PaddleDetection diff --git a/deploy/python/infer.py b/deploy/python/infer.py index 396f570b1ea0dc617a98b524f95041ed6f056308..0af0029774002127ef33b7e4712f36f53b383c08 100644 --- a/deploy/python/infer.py +++ b/deploy/python/infer.py @@ -25,6 +25,7 @@ from paddle.inference import Config from paddle.inference import create_predictor from benchmark_utils import PaddleInferBenchmark +from picodet_postprocess import PicoDetPostProcess from preprocess import preprocess, Resize, NormalizeImage, Permute, PadStride, LetterBoxResize from visualize import visualize_box_mask from utils import argsparser, Timer, get_current_memory_mb @@ -277,6 +278,111 @@ class DetectorSOLOv2(Detector): boxes_num=np_boxes_num) +class DetectorPicoDet(Detector): + """ + Args: + config (object): config of model, defined by `Config(model_dir)` + model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml + device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU + run_mode (str): mode of running(fluid/trt_fp32/trt_fp16) + batch_size (int): size of pre batch in inference + trt_min_shape (int): min shape for dynamic shape in trt + trt_max_shape (int): max shape for dynamic shape in trt + trt_opt_shape (int): opt shape for dynamic shape in trt + trt_calib_mode (bool): If the model is produced by TRT offline quantitative + calibration, trt_calib_mode need to set True + cpu_threads (int): cpu threads + enable_mkldnn (bool): whether to open MKLDNN + """ + + def __init__(self, + pred_config, + model_dir, + device='CPU', + run_mode='fluid', + batch_size=1, + trt_min_shape=1, + trt_max_shape=1280, + trt_opt_shape=640, + trt_calib_mode=False, + cpu_threads=1, + enable_mkldnn=False): + self.pred_config = pred_config + self.predictor, self.config = load_predictor( + model_dir, + run_mode=run_mode, + batch_size=batch_size, + min_subgraph_size=self.pred_config.min_subgraph_size, + device=device, + use_dynamic_shape=self.pred_config.use_dynamic_shape, + trt_min_shape=trt_min_shape, + trt_max_shape=trt_max_shape, + trt_opt_shape=trt_opt_shape, + trt_calib_mode=trt_calib_mode, + cpu_threads=cpu_threads, + enable_mkldnn=enable_mkldnn) + self.det_times = Timer() + self.cpu_mem, self.gpu_mem, self.gpu_util = 0, 0, 0 + + def predict(self, image, threshold=0.5, warmup=0, repeats=1): + ''' + Args: + image (str/np.ndarray): path of image/ np.ndarray read by cv2 + threshold (float): threshold of predicted box' score + Returns: + results (dict): include 'boxes': np.ndarray: shape:[N,6], N: number of box, + matix element:[class, score, x_min, y_min, x_max, y_max] + ''' + self.det_times.preprocess_time_s.start() + inputs = self.preprocess(image) + self.det_times.preprocess_time_s.end() + input_names = self.predictor.get_input_names() + for i in range(len(input_names)): + input_tensor = self.predictor.get_input_handle(input_names[i]) + input_tensor.copy_from_cpu(inputs[input_names[i]]) + np_score_list, np_boxes_list = [], [] + for i in range(warmup): + self.predictor.run() + np_score_list.clear() + np_boxes_list.clear() + output_names = self.predictor.get_output_names() + num_outs = int(len(output_names) / 2) + for out_idx in range(num_outs): + np_score_list.append( + self.predictor.get_output_handle(output_names[out_idx]) + .copy_to_cpu()) + np_boxes_list.append( + self.predictor.get_output_handle(output_names[ + out_idx + num_outs]).copy_to_cpu()) + + self.det_times.inference_time_s.start() + for i in range(repeats): + self.predictor.run() + np_score_list.clear() + np_boxes_list.clear() + output_names = self.predictor.get_output_names() + num_outs = int(len(output_names) / 2) + for out_idx in range(num_outs): + np_score_list.append( + self.predictor.get_output_handle(output_names[out_idx]) + .copy_to_cpu()) + np_boxes_list.append( + self.predictor.get_output_handle(output_names[ + out_idx + num_outs]).copy_to_cpu()) + self.det_times.inference_time_s.end(repeats=repeats) + self.det_times.img_num += 1 + self.det_times.postprocess_time_s.start() + self.postprocess = PicoDetPostProcess( + inputs['image'].shape[2:], + inputs['im_shape'], + inputs['scale_factor'], + strides=self.pred_config.fpn_stride, + nms_threshold=self.pred_config.nms['nms_threshold']) + np_boxes, np_boxes_num = self.postprocess(np_score_list, np_boxes_list) + self.det_times.postprocess_time_s.end() + return dict(boxes=np_boxes, boxes_num=np_boxes_num) + + def create_inputs(imgs, im_info): """generate input for different model type Args: @@ -341,6 +447,10 @@ class PredictConfig(): self.tracker = None if 'tracker' in yml_conf: self.tracker = yml_conf['tracker'] + if 'NMS' in yml_conf: + self.nms = yml_conf['NMS'] + if 'fpn_stride' in yml_conf: + self.fpn_stride = yml_conf['fpn_stride'] self.print_config() def check_model(self, yml_conf): @@ -595,31 +705,23 @@ def predict_video(detector, camera_id): def main(): pred_config = PredictConfig(FLAGS.model_dir) - detector = Detector( - pred_config, - FLAGS.model_dir, - device=FLAGS.device, - run_mode=FLAGS.run_mode, - batch_size=FLAGS.batch_size, - trt_min_shape=FLAGS.trt_min_shape, - trt_max_shape=FLAGS.trt_max_shape, - trt_opt_shape=FLAGS.trt_opt_shape, - trt_calib_mode=FLAGS.trt_calib_mode, - cpu_threads=FLAGS.cpu_threads, - enable_mkldnn=FLAGS.enable_mkldnn) + detector_func = 'Detector' if pred_config.arch == 'SOLOv2': - detector = DetectorSOLOv2( - pred_config, - FLAGS.model_dir, - device=FLAGS.device, - run_mode=FLAGS.run_mode, - batch_size=FLAGS.batch_size, - trt_min_shape=FLAGS.trt_min_shape, - trt_max_shape=FLAGS.trt_max_shape, - trt_opt_shape=FLAGS.trt_opt_shape, - trt_calib_mode=FLAGS.trt_calib_mode, - cpu_threads=FLAGS.cpu_threads, - enable_mkldnn=FLAGS.enable_mkldnn) + detector_func = 'DetectorSOLOv2' + elif pred_config.arch == 'PicoDet': + detector_func = 'DetectorPicoDet' + + detector = eval(detector_func)(pred_config, + FLAGS.model_dir, + device=FLAGS.device, + run_mode=FLAGS.run_mode, + batch_size=FLAGS.batch_size, + trt_min_shape=FLAGS.trt_min_shape, + trt_max_shape=FLAGS.trt_max_shape, + trt_opt_shape=FLAGS.trt_opt_shape, + trt_calib_mode=FLAGS.trt_calib_mode, + cpu_threads=FLAGS.cpu_threads, + enable_mkldnn=FLAGS.enable_mkldnn) # predict from video file or camera video stream if FLAGS.video_file is not None or FLAGS.camera_id != -1: diff --git a/deploy/python/picodet_postprocess.py b/deploy/python/picodet_postprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..1fed8c6411c02b502399963811de516538d802d8 --- /dev/null +++ b/deploy/python/picodet_postprocess.py @@ -0,0 +1,222 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +from scipy.special import softmax + + +def hard_nms(box_scores, iou_threshold, top_k=-1, candidate_size=200): + """ + Args: + box_scores (N, 5): boxes in corner-form and probabilities. + iou_threshold: intersection over union threshold. + top_k: keep top_k results. If k <= 0, keep all the results. + candidate_size: only consider the candidates with the highest scores. + Returns: + picked: a list of indexes of the kept boxes + """ + scores = box_scores[:, -1] + boxes = box_scores[:, :-1] + picked = [] + indexes = np.argsort(scores) + indexes = indexes[-candidate_size:] + while len(indexes) > 0: + current = indexes[-1] + picked.append(current) + if 0 < top_k == len(picked) or len(indexes) == 1: + break + current_box = boxes[current, :] + indexes = indexes[:-1] + rest_boxes = boxes[indexes, :] + iou = iou_of( + rest_boxes, + np.expand_dims( + current_box, axis=0), ) + indexes = indexes[iou <= iou_threshold] + + return box_scores[picked, :] + + +def iou_of(boxes0, boxes1, eps=1e-5): + """Return intersection-over-union (Jaccard index) of boxes. + Args: + boxes0 (N, 4): ground truth boxes. + boxes1 (N or 1, 4): predicted boxes. + eps: a small number to avoid 0 as denominator. + Returns: + iou (N): IoU values. + """ + overlap_left_top = np.maximum(boxes0[..., :2], boxes1[..., :2]) + overlap_right_bottom = np.minimum(boxes0[..., 2:], boxes1[..., 2:]) + + overlap_area = area_of(overlap_left_top, overlap_right_bottom) + area0 = area_of(boxes0[..., :2], boxes0[..., 2:]) + area1 = area_of(boxes1[..., :2], boxes1[..., 2:]) + return overlap_area / (area0 + area1 - overlap_area + eps) + + +def area_of(left_top, right_bottom): + """Compute the areas of rectangles given two corners. + Args: + left_top (N, 2): left top corner. + right_bottom (N, 2): right bottom corner. + Returns: + area (N): return the area. + """ + hw = np.clip(right_bottom - left_top, 0.0, None) + return hw[..., 0] * hw[..., 1] + + +class PicoDetPostProcess(object): + """ + Args: + input_shape (int): network input image size + ori_shape (int): ori image shape of before padding + scale_factor (float): scale factor of ori image + enable_mkldnn (bool): whether to open MKLDNN + """ + + def __init__(self, + input_shape, + ori_shape, + scale_factor, + strides=[8, 16, 32, 64], + score_threshold=0.4, + nms_threshold=0.5, + nms_top_k=1000, + keep_top_k=100): + self.ori_shape = ori_shape + self.input_shape = input_shape + self.scale_factor = scale_factor + self.strides = strides + self.score_threshold = score_threshold + self.nms_threshold = nms_threshold + self.nms_top_k = nms_top_k + self.keep_top_k = keep_top_k + + def warp_boxes(self, boxes, ori_shape): + """Apply transform to boxes + """ + width, height = ori_shape[1], ori_shape[0] + n = len(boxes) + if n: + # warp points + xy = np.ones((n * 4, 3)) + xy[:, :2] = boxes[:, [0, 1, 2, 3, 0, 3, 2, 1]].reshape( + n * 4, 2) # x1y1, x2y2, x1y2, x2y1 + # xy = xy @ M.T # transform + xy = (xy[:, :2] / xy[:, 2:3]).reshape(n, 8) # rescale + # create new boxes + x = xy[:, [0, 2, 4, 6]] + y = xy[:, [1, 3, 5, 7]] + xy = np.concatenate( + (x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T + # clip boxes + xy[:, [0, 2]] = xy[:, [0, 2]].clip(0, width) + xy[:, [1, 3]] = xy[:, [1, 3]].clip(0, height) + return xy.astype(np.float32) + else: + return boxes + + def __call__(self, scores, raw_boxes): + batch_size = raw_boxes[0].shape[0] + reg_max = int(raw_boxes[0].shape[-1] / 4 - 1) + out_boxes_num = [] + out_boxes_list = [] + for batch_id in range(batch_size): + # generate centers + decode_boxes = [] + select_scores = [] + for stride, box_distribute, score in zip(self.strides, raw_boxes, + scores): + box_distribute = box_distribute[batch_id] + score = score[batch_id] + # centers + fm_h = self.input_shape[0] / stride + fm_w = self.input_shape[1] / stride + h_range = np.arange(fm_h) + w_range = np.arange(fm_w) + ww, hh = np.meshgrid(w_range, h_range) + ct_row = (hh.flatten() + 0.5) * stride + ct_col = (ww.flatten() + 0.5) * stride + center = np.stack((ct_col, ct_row, ct_col, ct_row), axis=1) + + # box distribution to distance + reg_range = np.arange(reg_max + 1) + box_distance = box_distribute.reshape((-1, reg_max + 1)) + box_distance = softmax(box_distance, axis=1) + box_distance = box_distance * np.expand_dims(reg_range, axis=0) + box_distance = np.sum(box_distance, axis=1).reshape((-1, 4)) + box_distance = box_distance * stride + + # top K candidate + topk_idx = np.argsort(score.max(axis=1))[::-1] + topk_idx = topk_idx[:self.nms_top_k] + center = center[topk_idx] + score = score[topk_idx] + box_distance = box_distance[topk_idx] + + # decode box + decode_box = center + [-1, -1, 1, 1] * box_distance + + select_scores.append(score) + decode_boxes.append(decode_box) + + # nms + bboxes = np.concatenate(decode_boxes, axis=0) + confidences = np.concatenate(select_scores, axis=0) + picked_box_probs = [] + picked_labels = [] + for class_index in range(0, confidences.shape[1]): + probs = confidences[:, class_index] + mask = probs > self.score_threshold + probs = probs[mask] + if probs.shape[0] == 0: + continue + subset_boxes = bboxes[mask, :] + box_probs = np.concatenate( + [subset_boxes, probs.reshape(-1, 1)], axis=1) + box_probs = hard_nms( + box_probs, + iou_threshold=self.nms_threshold, + top_k=self.keep_top_k, ) + picked_box_probs.append(box_probs) + picked_labels.extend([class_index] * box_probs.shape[0]) + if not picked_box_probs: + return np.array([]), np.array([]), np.array([]) + picked_box_probs = np.concatenate(picked_box_probs) + + # resize output boxes + picked_box_probs[:, :4] = self.warp_boxes(picked_box_probs[:, :4], + self.ori_shape[batch_id]) + im_scale = np.concatenate([ + self.scale_factor[batch_id][::-1], + self.scale_factor[batch_id][::-1] + ]) + picked_box_probs[:, :4] /= im_scale + # clas score box + out_boxes_list.append( + np.concatenate( + [ + np.expand_dims( + np.array(picked_labels), axis=-1), np.expand_dims( + picked_box_probs[:, 4], axis=-1), + picked_box_probs[:, :4] + ], + axis=1)) + out_boxes_num.append(len(picked_labels)) + + out_boxes_list = np.concatenate(out_boxes_list, axis=0) + out_boxes_num = np.asarray(out_boxes_num).astype(np.int32) + return out_boxes_list, out_boxes_num diff --git a/ppdet/engine/export_utils.py b/ppdet/engine/export_utils.py index 50cf5b277b9f995eb837f9947a1ae3b4da350a85..602854cb8f72eaa7ed8d2a2a47a3528a7d6fb70c 100644 --- a/ppdet/engine/export_utils.py +++ b/ppdet/engine/export_utils.py @@ -162,5 +162,13 @@ def _dump_infer_config(config, path, image_shape, model): infer_cfg['Preprocess'], infer_cfg['label_list'] = _parse_reader( reader_cfg, dataset_cfg, config['metric'], label_arch, image_shape[1:]) + if infer_arch == 'PicoDet': + infer_cfg['NMS'] = config['PicoHead']['nms'] + # In order to speed up the prediction, the threshold of nms + # is adjusted here, which can be changed in infer_cfg.yml + config['PicoHead']['nms']["score_threshold"] = 0.3 + config['PicoHead']['nms']["nms_threshold"] = 0.5 + infer_cfg['fpn_stride'] = config['PicoHead']['fpn_stride'] + yaml.dump(infer_cfg, open(path, 'w')) logger.info("Export inference config file to {}".format(os.path.join(path))) diff --git a/ppdet/engine/trainer.py b/ppdet/engine/trainer.py index f7835f5158939db74110505812fe30b5e832846e..a97d0f16a1527d69531a75b37678bab51bff0cf1 100644 --- a/ppdet/engine/trainer.py +++ b/ppdet/engine/trainer.py @@ -561,8 +561,6 @@ class Trainer(object): if hasattr(self.model, 'fuse_norm'): self.model.fuse_norm = self.cfg['TestReader'].get('fuse_normalize', False) - if hasattr(self.cfg, 'lite_deploy'): - self.model.lite_deploy = self.cfg.lite_deploy # Save infer cfg _dump_infer_config(self.cfg, diff --git a/ppdet/modeling/architectures/picodet.py b/ppdet/modeling/architectures/picodet.py index 2ec646b24fd0e8017811f85a472b99f80dd4a3f2..7e9382b7c0b83b102e90dc5f1cfa0d490f738248 100644 --- a/ppdet/modeling/architectures/picodet.py +++ b/ppdet/modeling/architectures/picodet.py @@ -41,7 +41,7 @@ class PicoDet(BaseArch): self.backbone = backbone self.neck = neck self.head = head - self.lite_deploy = False + self.deploy = False @classmethod def from_config(cls, cfg, *args, **kwargs): @@ -62,8 +62,8 @@ class PicoDet(BaseArch): def _forward(self): body_feats = self.backbone(self.inputs) fpn_feats = self.neck(body_feats) - head_outs = self.head(fpn_feats) - if self.training or self.lite_deploy: + head_outs = self.head(fpn_feats, self.deploy) + if self.training or self.deploy: return head_outs else: im_shape = self.inputs['im_shape'] @@ -83,7 +83,7 @@ class PicoDet(BaseArch): return loss def get_pred(self): - if self.lite_deploy: + if self.deploy: return {'picodet': self._forward()[0]} else: bbox_pred, bbox_num = self._forward() diff --git a/ppdet/modeling/heads/pico_head.py b/ppdet/modeling/heads/pico_head.py index abcaed1173d4e3cf952be86395b0da03a6c94de2..b51dfe941b46882b2280582480fec7a5146ac0ae 100644 --- a/ppdet/modeling/heads/pico_head.py +++ b/ppdet/modeling/heads/pico_head.py @@ -226,7 +226,7 @@ class PicoHead(GFLHead): bias_attr=ParamAttr(initializer=Constant(value=0)))) self.head_reg_list.append(head_reg) - def forward(self, fpn_feats): + def forward(self, fpn_feats, deploy=False): assert len(fpn_feats) == len( self.fpn_stride ), "The size of fpn_feats is not equal to size of fpn_stride" @@ -243,11 +243,19 @@ class PicoHead(GFLHead): else: cls_score = self.head_cls_list[i](conv_cls_feat) bbox_pred = self.head_reg_list[i](conv_reg_feat) + if self.dgqp_module: quality_score = self.dgqp_module(bbox_pred) cls_score = F.sigmoid(cls_score) * quality_score - if not self.training: + if deploy: + # Now only supports batch size = 1 in deploy + # TODO(ygh): support batch size > 1 + cls_score = F.sigmoid(cls_score).reshape( + [1, self.cls_out_channels, -1]).transpose([0, 2, 1]) + bbox_pred = bbox_pred.reshape([1, (self.reg_max + 1) * 4, + -1]).transpose([0, 2, 1]) + elif not self.training: cls_score = F.sigmoid(cls_score.transpose([0, 2, 3, 1])) bbox_pred = bbox_pred.transpose([0, 2, 3, 1])