diff --git a/deploy/lite/README.md b/deploy/lite/README.md index c076f8b2d9f6c838d129da54b63a8d7a226166e2..32c69ed6ea36d8292cd29c1ed02da1eba21748e2 100644 --- a/deploy/lite/README.md +++ b/deploy/lite/README.md @@ -24,7 +24,7 @@ Paddle Lite是飞桨轻量化推理引擎,为手机、IOT端提供高效推理 1. [**建议**]直接下载,预测库下载链接如下: |平台|预测库下载链接| |-|-| - |Android|[arm7](https://github.com/PaddlePaddle/Paddle-Lite/releases/download/v2.8/inference_lite_lib.android.armv7.gcc.c++_static.with_extra.with_cv.tar.gz) / [arm8](https://github.com/PaddlePaddle/Paddle-Lite/releases/download/v2.8/inference_lite_lib.android.armv8.gcc.c++_static.with_extra.with_cv.tar.gz)| + |Android|[arm7](https://github.com/PaddlePaddle/Paddle-Lite/releases/download/v2.9.1/inference_lite_lib.android.armv7.clang.c++_static.with_extra.with_cv.tar.gz) / [arm8](https://github.com/PaddlePaddle/Paddle-Lite/releases/download/v2.9.1/inference_lite_lib.android.armv8.clang.c++_static.with_extra.with_cv.tar.gz)| **注意**:1. 如果是从 Paddle-Lite [官方文档](https://paddle-lite.readthedocs.io/zh/latest/quick_start/release_lib.html#android-toolchain-gcc)下载的预测库,注意选择`with_extra=ON,with_cv=ON`的下载链接。2. 目前只提供Android端demo,IOS端demo可以参考[Paddle-Lite IOS demo](https://github.com/PaddlePaddle/Paddle-Lite-Demo/tree/master/PaddleLite-ios-demo) @@ -40,7 +40,7 @@ git checkout develop **注意**:编译Paddle-Lite获得预测库时,需要打开`--with_cv=ON --with_extra=ON`两个选项,`--arch`表示`arm`版本,这里指定为armv8,更多编译命令介绍请参考[链接](https://paddle-lite.readthedocs.io/zh/latest/source_compile/compile_andriod.html#id2)。 -直接下载预测库并解压后,可以得到`inference_lite_lib.android.armv8.gcc.c++_static.with_extra.with_cv/`文件夹,通过编译Paddle-Lite得到的预测库位于`Paddle-Lite/build.lite.android.armv8.gcc/inference_lite_lib.android.armv8/`文件夹下。 +直接下载预测库并解压后,可以得到`inference_lite_lib.android.armv8.clang.c++_static.with_extra.with_cv/`文件夹,通过编译Paddle-Lite得到的预测库位于`Paddle-Lite/build.lite.android.armv8.gcc/inference_lite_lib.android.armv8/`文件夹下。 预测库的文件目录如下: ``` @@ -120,7 +120,7 @@ Paddle-Lite 提供了多种策略来自动优化原始的模型,其中包括 #### 2.1.3 转换示例 -下面以PaddleDetection中的 `PP-YOLO-tiny` 模型为例,介绍使用`paddle_lite_opt`完成预训练模型到inference模型,再到Paddle-Lite优化模型的转换。 +下面以PaddleDetection中的 `ppyolo` 模型为例,介绍使用`paddle_lite_opt`完成预训练模型到inference模型,再到Paddle-Lite优化模型的转换。 ```shell # 进入PaddleDetection根目录 @@ -130,7 +130,7 @@ cd PaddleDetection_root_path python tools/export_model.py -c configs/ppyolo/ppyolo_tiny_650e_coco.yml -o weights=https://paddledet.bj.bcebos.com/models/ppyolo_tiny_650e_coco.pdparams # 将inference模型转化为Paddle-Lite优化模型 -paddle_lite_opt --valid_targets=arm --optimize_out_type=naive_buffe --model_file=output_inference/ppyolo_tiny_650e_coco/model.pdmodel --param_file=output_inference/ppyolo_tiny_650e_coco/model.pdiparams --optimize_out=output_inference/ppyolo_tiny_650e_coco/model +paddle_lite_opt --valid_targets=arm --model_file=output_inference/ppyolo_tiny_650e_coco/model.pdmodel --param_file=output_inference/ppyolo_tiny_650e_coco/model.pdiparams --optimize_out=output_inference/ppyolo_tiny_650e_coco/model # 将inference模型配置转化为json格式 python deploy/lite/convert_yml_to_json.py output_inference/ppyolo_tiny_650e_coco/infer_cfg.yml @@ -180,7 +180,7 @@ cd deploy/lite/ inference_lite_path=/{lite prediction library path}/inference_lite_lib.android.armv8.gcc.c++_static.with_extra.with_cv/ mkdir $inference_lite_path/demo/cxx/lite -cp -r Makefile src/ include/ runtime_config.json $inference_lite_path/demo/cxx/lite +cp -r Makefile src/ include/ *runtime_config.json $inference_lite_path/demo/cxx/lite cd $inference_lite_path/demo/cxx/lite @@ -194,7 +194,7 @@ make ARM_ABI = arm8 ```shell mdkir deploy -cp main runtime_config.json 
deploy/ +cp main *runtime_config.json deploy/ cd deploy mkdir model_det mkdir model_keypoint @@ -219,31 +219,42 @@ cp ../../../cxx/lib/libpaddle_light_api_shared.so ./ ``` deploy/ |-- model_det/ -| |--mdoel.nb 优化后的检测模型文件 -| |--infer_cfg.json 检测器模型配置文件 +| |--mdoel.nb 优化后的检测模型文件 +| |--infer_cfg.json 检测器模型配置文件 |-- model_keypoint/ -| |--mdoel.nb 优化后的关键点模型文件 -| |--infer_cfg.json 关键点模型配置文件 -|-- main 生成的移动端执行文件 -|-- runtime_config.json 移动端执行时参数配置文件 -|-- libpaddle_light_api_shared.so Paddle-Lite库文件 +| |--mdoel.nb 优化后的关键点模型文件 +| |--infer_cfg.json 关键点模型配置文件 +|-- main 生成的移动端执行文件 +|-- det_runtime_config.json 目标检测执行时参数配置文件 +|-- keypoint_runtime_config.json 关键点检测执行时参数配置文件 +|-- libpaddle_light_api_shared.so Paddle-Lite库文件 ``` **注意:** -* `runtime_config.json` 包含了检测器的超参数,请按需进行修改(注意配置中路径及文件需存在): +* `det_runtime_config.json` 包含了目标检测的超参数,请按需进行修改: ```shell { "model_dir_det": "./model_det/", #检测器模型路径 "batch_size_det": 1, #检测预测时batchsize "threshold_det": 0.5, #检测器输出阈值 + "image_file": "demo.jpg", #测试图片 + "image_dir": "", #测试图片文件夹 + "run_benchmark": false, #性能测试开关 + "cpu_threads": 4 #线程数 +} +``` + +* `keypoint_runtime_config.json` 包含了关键点检测的超参数,请按需进行修改: +```shell +{ "model_dir_keypoint": "./model_keypoint/", #关键点模型路径(不使用需为空字符) "batch_size_keypoint": 8, #关键点预测时batchsize "threshold_keypoint": 0.5, #关键点输出阈值 "image_file": "demo.jpg", #测试图片 "image_dir": "", #测试图片文件夹 "run_benchmark": false, #性能测试开关 - "cpu_threads": 1 #线程数 + "cpu_threads": 4 #线程数 } ``` @@ -259,8 +270,8 @@ export LD_LIBRARY_PATH=/data/local/tmp/deploy:$LD_LIBRARY_PATH # 修改权限为可执行 chmod 777 main -# 执行程序 -./main +# 以检测为例,执行程序 +./main det_runtime_config.json ``` 如果对代码做了修改,则需要重新编译并push到手机上。 diff --git a/deploy/lite/det_runtime_config.json b/deploy/lite/det_runtime_config.json new file mode 100644 index 0000000000000000000000000000000000000000..a1bc4ec3bdcf226f8c31caf2ff7b00e7b832050d --- /dev/null +++ b/deploy/lite/det_runtime_config.json @@ -0,0 +1,10 @@ +{ + "model_dir_det": "./model_det/", + "batch_size_det": 1, + "threshold_det": 0.5, + "image_file": "./demo.jpg", + "image_dir": "", + "run_benchmark": false, + "cpu_threads": 4 + } + \ No newline at end of file diff --git a/deploy/lite/include/config_parser.h b/deploy/lite/include/config_parser.h index 8c56846bd6249c218ee4167326fc2aa96101a9fd..67f662e7221fa71325b47995489af8902de090c0 100644 --- a/deploy/lite/include/config_parser.h +++ b/deploy/lite/include/config_parser.h @@ -78,12 +78,26 @@ class ConfigPaser { return false; } + // Get NMS for postprocess + if (config.isMember("NMS")) { + nms_info_ = config["NMS"]; + } + // Get fpn_stride in PicoDet + if (config.isMember("fpn_stride")) { + fpn_stride_.clear(); + for (auto item : config["fpn_stride"]) { + fpn_stride_.emplace_back(item.as()); + } + } + return true; } float draw_threshold_; std::string arch_; Json::Value preprocess_info_; + Json::Value nms_info_; std::vector label_list_; + std::vector fpn_stride_; }; } // namespace PaddleDetection diff --git a/deploy/lite/include/object_detector.h b/deploy/lite/include/object_detector.h index 220e562541f475fd5cc3dd365191db441d714bb2..7874a9b8bba087f5731ac9d91ebd308a8e0d5ef2 100644 --- a/deploy/lite/include/object_detector.h +++ b/deploy/lite/include/object_detector.h @@ -28,26 +28,19 @@ #include "include/config_parser.h" #include "include/preprocess_op.h" +#include "include/utils.h" +#include "include/picodet_postprocess.h" using namespace paddle::lite_api; // NOLINT namespace PaddleDetection { -// Object Detection Result -struct ObjectResult { - // Rectangle coordinates of detected object: left, right, top, 
down
-  std::vector<int> rect;
-  // Class id of detected object
-  int class_id;
-  // Confidence of detected object
-  float confidence;
-};

 // Generate visualization colormap for each class
 std::vector<int> GenerateColorMap(int num_class);

 // Visualiztion Detection Result
 cv::Mat VisualizeResult(const cv::Mat& img,
-                        const std::vector<ObjectResult>& results,
+                        const std::vector<PaddleDetection::ObjectResult>& results,
                         const std::vector<std::string>& lables,
                         const std::vector<int>& colormap,
                         const bool is_rbox);
@@ -74,7 +67,7 @@ class ObjectDetector {
                const double threshold = 0.5,
                const int warmup = 0,
                const int repeats = 1,
-               std::vector<ObjectResult>* result = nullptr,
+               std::vector<PaddleDetection::ObjectResult>* result = nullptr,
                std::vector<int>* bbox_num = nullptr,
                std::vector<double>* times = nullptr);
@@ -88,7 +81,7 @@ class ObjectDetector {
   void Preprocess(const cv::Mat& image_mat);
   // Postprocess result
   void Postprocess(const std::vector<cv::Mat> mats,
-                   std::vector<ObjectResult>* result,
+                   std::vector<PaddleDetection::ObjectResult>* result,
                    std::vector<int> bbox_num,
                    bool is_rbox);
@@ -99,6 +92,7 @@ class ObjectDetector {
   std::vector<int> out_bbox_num_data_;
   float threshold_;
   ConfigPaser config_;
+
 };

 } // namespace PaddleDetection
diff --git a/deploy/lite/include/picodet_postprocess.h b/deploy/lite/include/picodet_postprocess.h
new file mode 100644
index 0000000000000000000000000000000000000000..415ef69e548c9c1ce3b485d391dbe9945c6e0c83
--- /dev/null
+++ b/deploy/lite/include/picodet_postprocess.h
@@ -0,0 +1,38 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "include/utils.h"
+
+namespace PaddleDetection {
+
+void PicoDetPostProcess(std::vector<PaddleDetection::ObjectResult>* results,
+                        std::vector<const float*> outs,
+                        std::vector<int> fpn_stride,
+                        std::vector<float> im_shape,
+                        std::vector<float> scale_factor,
+                        float score_threshold = 0.3,
+                        float nms_threshold = 0.5,
+                        int num_class = 80,
+                        int reg_max = 7);
+
+} // namespace PaddleDetection
\ No newline at end of file
diff --git a/deploy/lite/include/utils.h b/deploy/lite/include/utils.h
new file mode 100644
index 0000000000000000000000000000000000000000..3802e1267176a050402d1fdf742e54a79f33ffb9
--- /dev/null
+++ b/deploy/lite/include/utils.h
@@ -0,0 +1,39 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+namespace PaddleDetection {
+
+// Object Detection Result
+struct ObjectResult {
+  // Rectangle coordinates of detected object: left, right, top, down
+  std::vector<int> rect;
+  // Class id of detected object
+  int class_id;
+  // Confidence of detected object
+  float confidence;
+};
+
+void nms(std::vector<ObjectResult> &input_boxes, float nms_threshold);
+
+} // namespace PaddleDetection
\ No newline at end of file
diff --git a/deploy/lite/runtime_config.json b/deploy/lite/keypoint_runtime_config.json
similarity index 73%
rename from deploy/lite/runtime_config.json
rename to deploy/lite/keypoint_runtime_config.json
index 80971e51a8c79534704d50be2a8959f631a3cf83..f8c9647afb9df444da3c4a0c5f6a2eea41c1567c 100644
--- a/deploy/lite/runtime_config.json
+++ b/deploy/lite/keypoint_runtime_config.json
@@ -1,7 +1,4 @@
 {
-    "model_dir_det": "./model_det/",
-    "batch_size_det": 1,
-    "threshold_det": 0.5,
     "model_dir_keypoint": "./model_keypoint/",
     "batch_size_keypoint": 8,
     "threshold_keypoint": 0.5,
diff --git a/deploy/lite/src/keypoint_postprocess.cc b/deploy/lite/src/keypoint_postprocess.cc
index 6124e505dfff023e70133131796a503fef5f4de2..f1a44bcc05590362604601fd66bda3a8a20920c3 100644
--- a/deploy/lite/src/keypoint_postprocess.cc
+++ b/deploy/lite/src/keypoint_postprocess.cc
@@ -52,7 +52,7 @@ void get_affine_transform(std::vector<float>& center,
   float dst_h = static_cast<float>(output_size[1]);
   float rot_rad = rot * PI / HALF_CIRCLE_DEGREE;
   std::vector<float> src_dir = get_dir(-0.5 * src_w, 0, rot_rad);
-  std::vector<float> dst_dir{-0.5 * dst_w, 0.0};
+  std::vector<float> dst_dir{static_cast<float>(-0.5) * dst_w, 0.0};
   cv::Point2f srcPoint2f[3], dstPoint2f[3];
   srcPoint2f[0] = cv::Point2f(center[0], center[1]);
   srcPoint2f[1] = cv::Point2f(center[0] + src_dir[0], center[1] + src_dir[1]);
diff --git a/deploy/lite/src/main.cc b/deploy/lite/src/main.cc
index 57cabecf208020fef2d49b41a998858306984858..cf0651091ebc7fea1251d0b53600f8de16815d87 100644
--- a/deploy/lite/src/main.cc
+++ b/deploy/lite/src/main.cc
@@ -152,9 +152,10 @@ void PredictImage(const std::vector<std::string> all_img_paths,
   bool is_rbox = false;
   if (run_benchmark) {
     det->Predict(
-        batch_imgs, threshold_det, 10, 10, &result, &bbox_num, &det_times);
+        batch_imgs, threshold_det, 50, 50, &result, &bbox_num, &det_times);
   } else {
-    det->Predict(batch_imgs, 0.5, 0, 1, &result, &bbox_num, &det_times);
+    det->Predict(
+        batch_imgs, threshold_det, 0, 1, &result, &bbox_num, &det_times);
   }
   // get labels and colormap
@@ -272,7 +273,7 @@ void PredictImage(const std::vector<std::string> all_img_paths,
     cv::Mat vis_img = PaddleDetection::VisualizeResult(
         im, im_result, labels, colormap, is_rbox);
     std::string det_savepath =
-        output_path +
+        output_path + "result_" +
         image_file_path.substr(image_file_path.find_last_of('/') + 1);
     cv::imwrite(det_savepath, vis_img, compression_params);
     printf("Visualized output saved as %s\n", det_savepath.c_str());
@@ -284,7 +285,9 @@ void PredictImage(const std::vector<std::string> all_img_paths,
     det_t[2] += det_times[2];
   }
   PrintBenchmarkLog(det_t, all_img_paths.size());
-  PrintBenchmarkLog(keypoint_t, kpts_imgs);
+  if (keypoint) {
+    PrintBenchmarkLog(keypoint_t, kpts_imgs);
+  }
   PrintTotalIimeLog((det_t[0] + det_t[1] + det_t[2]) / all_img_paths.size(),
                     (keypoint_t[0] + keypoint_t[1] + keypoint_t[2]) / kpts_imgs,
                     midtimecost / all_img_paths.size());
@@ -293,13 +296,15 @@ int main(int argc, char** argv) {
   std::cout << "Usage: " << argv[0] << " [config_path](option) 
[image_dir](option)\n"; - std::string config_path = "runtime_config.json"; + if (argc < 2) { + std::cout << "Usage: ./main det_runtime_config.json" << std::endl; + return -1; + } + std::string config_path = argv[1]; std::string img_path = ""; - if (argc >= 2) { - config_path = argv[1]; - if (argc >= 3) { - img_path = argv[2]; - } + + if (argc >= 3) { + img_path = argv[2]; } // Parsing command-line PaddleDetection::load_jsonf(config_path, RT_Config); diff --git a/deploy/lite/src/object_detector.cc b/deploy/lite/src/object_detector.cc index 4842ecef5690259e5204949eaa4dbd3c5a23c73d..0909bd9194679485fd2a8b735ff6f7ffdb0bb2c9 100644 --- a/deploy/lite/src/object_detector.cc +++ b/deploy/lite/src/object_detector.cc @@ -31,7 +31,7 @@ void ObjectDetector::LoadModel(std::string model_file, int num_theads) { // Visualiztion MaskDetector results cv::Mat VisualizeResult(const cv::Mat& img, - const std::vector& results, + const std::vector& results, const std::vector& lables, const std::vector& colormap, const bool is_rbox = false) { @@ -100,7 +100,7 @@ void ObjectDetector::Preprocess(const cv::Mat& ori_im) { } void ObjectDetector::Postprocess(const std::vector mats, - std::vector* result, + std::vector* result, std::vector bbox_num, bool is_rbox = false) { result->clear(); @@ -128,7 +128,7 @@ void ObjectDetector::Postprocess(const std::vector mats, int x4 = (output_data_[8 + j * 10] * rw); int y4 = (output_data_[9 + j * 10] * rh); - ObjectResult result_item; + PaddleDetection::ObjectResult result_item; result_item.rect = {x1, y1, x2, y2, x3, y3, x4, y4}; result_item.class_id = class_id; result_item.confidence = score; @@ -145,7 +145,7 @@ void ObjectDetector::Postprocess(const std::vector mats, int wd = xmax - xmin; int hd = ymax - ymin; - ObjectResult result_item; + PaddleDetection::ObjectResult result_item; result_item.rect = {xmin, ymin, xmax, ymax}; result_item.class_id = class_id; result_item.confidence = score; @@ -160,7 +160,7 @@ void ObjectDetector::Predict(const std::vector& imgs, const double threshold, const int warmup, const int repeats, - std::vector* result, + std::vector* result, std::vector* bbox_num, std::vector* times) { auto preprocess_start = std::chrono::steady_clock::now(); @@ -185,6 +185,7 @@ void ObjectDetector::Predict(const std::vector& imgs, in_data_all.end(), inputs_.im_data_.begin(), inputs_.im_data_.end()); } auto preprocess_end = std::chrono::steady_clock::now(); + std::vector output_data_list_; // Prepare input tensor auto input_names = predictor_->GetInputNames(); @@ -213,16 +214,46 @@ void ObjectDetector::Predict(const std::vector& imgs, predictor_->Run(); // Get output tensor auto output_names = predictor_->GetOutputNames(); - auto out_tensor = predictor_->GetTensor(output_names[0]); - auto out_bbox_num = predictor_->GetTensor(output_names[1]); + if (config_.arch_ == "PicoDet") { + for (int j = 0; j < output_names.size(); j++) { + auto output_tensor = predictor_->GetTensor(output_names[j]); + const float* outptr = output_tensor->data(); + std::vector output_shape = output_tensor->shape(); + output_data_list_.push_back(outptr); + } + } else { + auto out_tensor = predictor_->GetTensor(output_names[0]); + auto out_bbox_num = predictor_->GetTensor(output_names[1]); + } } bool is_rbox = false; auto inference_start = std::chrono::steady_clock::now(); for (int i = 0; i < repeats; i++) { predictor_->Run(); - // Get output tensor - auto output_names = predictor_->GetOutputNames(); + } + auto inference_end = std::chrono::steady_clock::now(); + auto postprocess_start = 
std::chrono::steady_clock::now(); + // Get output tensor + output_data_list_.clear(); + int num_class = 80; + int reg_max = 7; + auto output_names = predictor_->GetOutputNames(); + // TODO: Unified model output. + if (config_.arch_ == "PicoDet") { + for (int i = 0; i < output_names.size(); i++) { + auto output_tensor = predictor_->GetTensor(output_names[i]); + const float* outptr = output_tensor->data(); + std::vector output_shape = output_tensor->shape(); + if (i == 0) { + num_class = output_shape[2]; + } + if (i == config_.fpn_stride_.size()) { + reg_max = output_shape[2] / 4 - 1; + } + output_data_list_.push_back(outptr); + } + } else { auto output_tensor = predictor_->GetTensor(output_names[0]); auto output_shape = output_tensor->shape(); auto out_bbox_num = predictor_->GetTensor(output_names[1]); @@ -250,15 +281,22 @@ void ObjectDetector::Predict(const std::vector& imgs, out_bbox_num_size, out_bbox_num_data_.data()); } - auto inference_end = std::chrono::steady_clock::now(); - auto postprocess_start = std::chrono::steady_clock::now(); // Postprocessing result result->clear(); - Postprocess(imgs, result, out_bbox_num_data_, is_rbox); - bbox_num->clear(); - for (int k = 0; k < out_bbox_num_data_.size(); k++) { - int tmp = out_bbox_num_data_[k]; - bbox_num->push_back(tmp); + if (config_.arch_ == "PicoDet") { + PaddleDetection::PicoDetPostProcess( + result, output_data_list_, config_.fpn_stride_, + inputs_.im_shape_, inputs_.scale_factor_, + config_.nms_info_["score_threshold"].as(), + config_.nms_info_["nms_threshold"].as(), num_class, reg_max); + bbox_num->push_back(result->size()); + } else { + Postprocess(imgs, result, out_bbox_num_data_, is_rbox); + bbox_num->clear(); + for (int k = 0; k < out_bbox_num_data_.size(); k++) { + int tmp = out_bbox_num_data_[k]; + bbox_num->push_back(tmp); + } } auto postprocess_end = std::chrono::steady_clock::now(); diff --git a/deploy/lite/src/picodet_postprocess.cc b/deploy/lite/src/picodet_postprocess.cc new file mode 100644 index 0000000000000000000000000000000000000000..ba73c7d8cd60fb0ef04f678c27680628696fff5f --- /dev/null +++ b/deploy/lite/src/picodet_postprocess.cc @@ -0,0 +1,127 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#include "include/picodet_postprocess.h"
+
+namespace PaddleDetection {
+
+float fast_exp(float x) {
+  union {
+    uint32_t i;
+    float f;
+  } v{};
+  v.i = (1 << 23) * (1.4426950409 * x + 126.93490512f);
+  return v.f;
+}
+
+template <typename _Tp>
+int activation_function_softmax(const _Tp *src, _Tp *dst, int length) {
+  const _Tp alpha = *std::max_element(src, src + length);
+  _Tp denominator{0};
+
+  for (int i = 0; i < length; ++i) {
+    dst[i] = fast_exp(src[i] - alpha);
+    denominator += dst[i];
+  }
+
+  for (int i = 0; i < length; ++i) {
+    dst[i] /= denominator;
+  }
+
+  return 0;
+}
+
+// PicoDet decode
+PaddleDetection::ObjectResult disPred2Bbox(const float *&dfl_det, int label, float score,
+                                           int x, int y, int stride, std::vector<float> im_shape,
+                                           int reg_max) {
+  float ct_x = (x + 0.5) * stride;
+  float ct_y = (y + 0.5) * stride;
+  std::vector<float> dis_pred;
+  dis_pred.resize(4);
+  for (int i = 0; i < 4; i++) {
+    float dis = 0;
+    float* dis_after_sm = new float[reg_max + 1];
+    activation_function_softmax(dfl_det + i * (reg_max + 1), dis_after_sm, reg_max + 1);
+    for (int j = 0; j < reg_max + 1; j++) {
+      dis += j * dis_after_sm[j];
+    }
+    dis *= stride;
+    dis_pred[i] = dis;
+    delete[] dis_after_sm;
+  }
+  int xmin = (int)(std::max)(ct_x - dis_pred[0], .0f);
+  int ymin = (int)(std::max)(ct_y - dis_pred[1], .0f);
+  int xmax = (int)(std::min)(ct_x + dis_pred[2], (float)im_shape[0]);
+  int ymax = (int)(std::min)(ct_y + dis_pred[3], (float)im_shape[1]);
+
+  PaddleDetection::ObjectResult result_item;
+  result_item.rect = {xmin, ymin, xmax, ymax};
+  result_item.class_id = label;
+  result_item.confidence = score;
+
+  return result_item;
+}
+
+
+void PicoDetPostProcess(std::vector<PaddleDetection::ObjectResult>* results,
+                        std::vector<const float*> outs,
+                        std::vector<int> fpn_stride,
+                        std::vector<float> im_shape,
+                        std::vector<float> scale_factor,
+                        float score_threshold,
+                        float nms_threshold,
+                        int num_class,
+                        int reg_max) {
+  std::vector<std::vector<PaddleDetection::ObjectResult>> bbox_results;
+  bbox_results.resize(num_class);
+  int in_h = im_shape[0], in_w = im_shape[1];
+  for (int i = 0; i < fpn_stride.size(); ++i) {
+    int feature_h = in_h / fpn_stride[i];
+    int feature_w = in_w / fpn_stride[i];
+    for (int idx = 0; idx < feature_h * feature_w; idx++) {
+      const float *scores = outs[i] + (idx * num_class);
+
+      int row = idx / feature_w;
+      int col = idx % feature_w;
+      float score = 0;
+      int cur_label = 0;
+      for (int label = 0; label < num_class; label++) {
+        if (scores[label] > score) {
+          score = scores[label];
+          cur_label = label;
+        }
+      }
+      if (score > score_threshold) {
+        const float *bbox_pred = outs[i + fpn_stride.size()]
+                                 + (idx * 4 * (reg_max + 1));
+        bbox_results[cur_label].push_back(disPred2Bbox(bbox_pred,
+            cur_label, score, col, row, fpn_stride[i], im_shape, reg_max));
+      }
+    }
+  }
+  for (int i = 0; i < (int)bbox_results.size(); i++) {
+    PaddleDetection::nms(bbox_results[i], nms_threshold);
+
+    for (auto box : bbox_results[i]) {
+      box.rect[0] = box.rect[0] / scale_factor[1];
+      box.rect[2] = box.rect[2] / scale_factor[1];
+      box.rect[1] = box.rect[1] / scale_factor[0];
+      box.rect[3] = box.rect[3] / scale_factor[0];
+      results->push_back(box);
+    }
+  }
+}
+
+} // namespace PaddleDetection
\ No newline at end of file
diff --git a/deploy/lite/src/utils.cc b/deploy/lite/src/utils.cc
new file mode 100644
index 0000000000000000000000000000000000000000..7b4731cd9e25b3536417ade20d3f9ce5089755fd
--- /dev/null
+++ b/deploy/lite/src/utils.cc
@@ -0,0 +1,49 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "include/utils.h" + +namespace PaddleDetection { + +void nms(std::vector &input_boxes, float nms_threshold) { + std::sort(input_boxes.begin(), + input_boxes.end(), + [](ObjectResult a, ObjectResult b) { return a.confidence > b.confidence; }); + std::vector vArea(input_boxes.size()); + for (int i = 0; i < int(input_boxes.size()); ++i) { + vArea[i] = (input_boxes.at(i).rect[2] - input_boxes.at(i).rect[0] + 1) + * (input_boxes.at(i).rect[3] - input_boxes.at(i).rect[1] + 1); + } + for (int i = 0; i < int(input_boxes.size()); ++i) { + for (int j = i + 1; j < int(input_boxes.size());) { + float xx1 = (std::max)(input_boxes[i].rect[0], input_boxes[j].rect[0]); + float yy1 = (std::max)(input_boxes[i].rect[1], input_boxes[j].rect[1]); + float xx2 = (std::min)(input_boxes[i].rect[2], input_boxes[j].rect[2]); + float yy2 = (std::min)(input_boxes[i].rect[3], input_boxes[j].rect[3]); + float w = (std::max)(float(0), xx2 - xx1 + 1); + float h = (std::max)(float(0), yy2 - yy1 + 1); + float inter = w * h; + float ovr = inter / (vArea[i] + vArea[j] - inter); + if (ovr >= nms_threshold) { + input_boxes.erase(input_boxes.begin() + j); + vArea.erase(vArea.begin() + j); + } + else { + j++; + } + } + } +} + +} // namespace PaddleDetection
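---

Reviewer note: the `nms` helper added in `deploy/lite/src/utils.cc` above implements greedy IoU suppression: boxes are sorted by confidence, and any lower-confidence box whose IoU with an already-kept box reaches the threshold is erased in place. The standalone sketch below (not part of this patch; the include path and build setup are assumptions) feeds it two heavily overlapping boxes plus one distant box, so with a 0.5 threshold only two detections should survive.

```cpp
// Hypothetical test harness, not added by this PR. Assumes it is compiled
// from deploy/lite/ together with src/utils.cc, so that "include/utils.h"
// (ObjectResult + nms from this patch) resolves.
#include <cstdio>
#include <vector>

#include "include/utils.h"

int main() {
  std::vector<PaddleDetection::ObjectResult> boxes;

  // Two strongly overlapping detections of the same object ...
  PaddleDetection::ObjectResult a;
  a.rect = {100, 100, 200, 200};  // xmin, ymin, xmax, ymax
  a.class_id = 0;
  a.confidence = 0.9f;

  PaddleDetection::ObjectResult b;
  b.rect = {105, 105, 205, 205};  // IoU with `a` is about 0.82
  b.class_id = 0;
  b.confidence = 0.6f;

  // ... and one detection elsewhere in the image (IoU 0 with the others).
  PaddleDetection::ObjectResult c;
  c.rect = {400, 50, 480, 120};
  c.class_id = 0;
  c.confidence = 0.8f;

  boxes.push_back(a);
  boxes.push_back(b);
  boxes.push_back(c);

  // nms sorts by confidence and erases suppressed boxes in place.
  PaddleDetection::nms(boxes, 0.5f);

  // Expected: 2 boxes remain (confidence 0.9 and 0.8); `b` is suppressed.
  for (const auto& box : boxes) {
    printf("class=%d conf=%.2f rect=[%d %d %d %d]\n",
           box.class_id, box.confidence,
           box.rect[0], box.rect[1], box.rect[2], box.rect[3]);
  }
  return 0;
}
```

The same helper is what `PicoDetPostProcess` calls per class before mapping boxes back to the original image scale via `scale_factor`, so exercising it in isolation is a cheap way to validate the threshold semantics used by `det_runtime_config.json`.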