提交 ea2148ab 编写于 作者: Z zhiboniu 提交者: zhiboniu

add tinypose openvino deploy support

上级 147fb450
......@@ -53,6 +53,7 @@ python tools/export_model.py -c configs/yolov3/yolov3_mobilenet_v1_roadsign.yml
## 4.第三方部署(MNN、NCNN、Openvino)
- 第三方部署提供PicoDet、TinyPose案例,其他模型请参考修改
- TinyPose部署推荐工具:Intel CPU端推荐使用Openvino,GPU端推荐使用PaddleInference,ARM/ANDROID端推荐使用PaddleLite或者MNN
| Third_Engine | MNN | NCNN | OPENVINO |
| ------------ | ---- | ----- | ---------- |
......
......@@ -53,6 +53,7 @@ For details on model export, please refer to the documentation [Tutorial on Padd
## 4.Third-Engine deploy(MNN、NCNN、Openvino)
- The Third-Engine deploy take example of PicoDet、TinyPose,the others model is the same
- Suggestion for TinyPose: For Intel CPU Openvino is recommended,for Nvidia GPU PaddleInference is recommended,and for ARM/ANDROID PaddleLite or MNN is recommended.
| Third_Engine | MNN | NCNN | OPENVINO |
| ------------ | ------------------------------------------------------ | -------------------------------------------------- | ------------------------------------------------------------ |
......
......@@ -103,7 +103,7 @@ make
## Run demo
Download PicoDet openvino model [PicoDet openvino model download link](https://paddledet.bj.bcebos.com/deploy/third_engine/picodet_m_416_openvino.zip).
move picodet openvino model files to the demo's weight folder. Then run these commands:
move picodet openvino model files to the demo's weight folder.
### Edit file
```
......
cmake_minimum_required(VERSION 3.4.1)
set(CMAKE_CXX_STANDARD 14)
project(tinypose_demo)
find_package(OpenCV REQUIRED)
find_package(InferenceEngine REQUIRED)
find_package(ngraph REQUIRED)
include_directories(
${OpenCV_INCLUDE_DIRS}
${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_BINARY_DIR}
)
add_executable(tinypose_demo main.cpp picodet_openvino.cpp keypoint_detector.cpp keypoint_postprocess.cpp)
target_link_libraries(
tinypose_demo
${InferenceEngine_LIBRARIES}
${NGRAPH_LIBRARIES}
${OpenCV_LIBS}
)
# TinyPose OpenVINO Demo
This fold provides TinyPose inference code using
[Intel's OpenVINO Toolkit](https://software.intel.com/content/www/us/en/develop/tools/openvino-toolkit.html). Most of the implements in this fold are same as *demo_ncnn*.
**Recommand** to use the xxx.tar.gz file to install instead of github method, [link](https://registrationcenter-download.intel.com/akdlm/irc_nas/18096/l_openvino_toolkit_p_2021.4.689.tgz).
## Install OpenVINO Toolkit
Go to [OpenVINO HomePage](https://software.intel.com/content/www/us/en/develop/tools/openvino-toolkit.html)
Download a suitable version and install.
Follow the official Get Started Guides: https://docs.openvinotoolkit.org/latest/get_started_guides.html
## Set the Environment Variables
### Windows:
Run this command in cmd. (Every time before using OpenVINO)
```cmd
<INSTSLL_DIR>\openvino_2021\bin\setupvars.bat
```
Or set the system environment variables once for all:
Name |Value
:--------------------:|:--------:
INTEL_OPENVINO_DIR | <INSTSLL_DIR>\openvino_2021
INTEL_CVSDK_DIR | %INTEL_OPENVINO_DIR%
InferenceEngine_DIR | %INTEL_OPENVINO_DIR%\deployment_tools\inference_engine\share
HDDL_INSTALL_DIR | %INTEL_OPENVINO_DIR%\deployment_tools\inference_engine\external\hddl
ngraph_DIR | %INTEL_OPENVINO_DIR%\deployment_tools\ngraph\cmake
And add this to ```Path```
```
%INTEL_OPENVINO_DIR%\deployment_tools\inference_engine\bin\intel64\Debug;%INTEL_OPENVINO_DIR%\deployment_tools\inference_engine\bin\intel64\Release;%HDDL_INSTALL_DIR%\bin;%INTEL_OPENVINO_DIR%\deployment_tools\inference_engine\external\tbb\bin;%INTEL_OPENVINO_DIR%\deployment_tools\ngraph\lib
```
### Linux
Run this command in shell. (Every time before using OpenVINO)
```shell
source /opt/intel/openvino_2021/bin/setupvars.sh
```
Or edit .bashrc
```shell
vi ~/.bashrc
```
Add this line to the end of the file
```shell
source /opt/intel/openvino_2021/bin/setupvars.sh
```
## Convert model
Convert to OpenVINO
``` shell
cd <INSTSLL_DIR>/openvino_2021/deployment_tools/model_optimizer
```
Install requirements for convert tool
```shell
cd ./install_prerequisites
sudo install_prerequisites_onnx.sh
```
Then convert model. Notice: mean_values and scale_values should be the same with your training settings in YAML config file.
```shell
python3 mo_onnx.py --input_model <ONNX_MODEL> --mean_values [103.53,116.28,123.675] --scale_values [57.375,57.12,58.395]
```
## Build
### Windows
```cmd
<OPENVINO_INSTSLL_DIR>\openvino_2021\bin\setupvars.bat
mkdir -p build
cd build
cmake ..
msbuild tinypose_demo.vcxproj /p:configuration=release /p:platform=x64
```
### Linux
```shell
source /opt/intel/openvino_2021/bin/setupvars.sh
mkdir build
cd build
cmake ..
make
```
## Run demo
Download PicoDet openvino model [PicoDet openvino model download link](https://paddledet.bj.bcebos.com/deploy/third_engine/picodet_m_416_openvino.zip).
Download TinyPose openvino model [TinyPose openvino model download link](https://paddledet.bj.bcebos.com/deploy/third_engine/tinypose_256_openvino.zip).
move picodet and tinypose openvino model files to the demo's weight folder.
### Edit file
```
step1:
main.cpp
#define image_size 416
...
cv::Mat image(256, 192, CV_8UC3, cv::Scalar(1, 1, 1));
std::vector<float> center = {128, 96};
std::vector<float> scale = {256, 192};
...
auto detector = PicoDet("../weight/picodet_m_416.xml");
auto kpts_detector = new KeyPointDetector("../weight/tinypose256.xml", -1, 256, 192);
...
step2:
picodet_openvino.h
#define image_size 416
```
### Run
Run command:
``` shell
./tinypose_demo [mode] [image_file]
```
| param | detail |
| ---- | ---- |
| --mode | input mode,0:camera;1:image;2:video;3:benchmark |
| --image_file | input image path |
#### Webcam
```shell
tinypose_demo 0 0
```
#### Inference images
```shell
tinypose_demo 1 IMAGE_FOLDER/*.jpg
```
#### Inference video
```shell
tinypose_demo 2 VIDEO_PATH
```
### Benchmark
```shell
tinypose_demo 3 0
```
Plateform: Intel(R) Xeon(R) CPU E5-2650 v4 @ 2.20GHz x 24(核)
Model: [Tinypose256_Openvino](https://paddledet.bj.bcebos.com/deploy/third_engine/tinypose_256_openvino.zip)
| param | Min | Max | Avg |
| ------------- | ----- | ----- | ----- |
| infer time(s) | 0.018 | 0.062 | 0.028 |
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <sstream>
// for setprecision
#include <chrono>
#include <iomanip>
#include "keypoint_detector.h"
namespace PaddleDetection {
// Visualiztion MaskDetector results
cv::Mat VisualizeKptsResult(const cv::Mat& img,
const std::vector<KeyPointResult>& results,
const std::vector<int>& colormap,
float threshold) {
const int edge[][2] = {{0, 1},
{0, 2},
{1, 3},
{2, 4},
{3, 5},
{4, 6},
{5, 7},
{6, 8},
{7, 9},
{8, 10},
{5, 11},
{6, 12},
{11, 13},
{12, 14},
{13, 15},
{14, 16},
{11, 12}};
cv::Mat vis_img = img.clone();
for (int batchid = 0; batchid < results.size(); batchid++) {
for (int i = 0; i < results[batchid].num_joints; i++) {
if (results[batchid].keypoints[i * 3] > threshold) {
int x_coord = int(results[batchid].keypoints[i * 3 + 1]);
int y_coord = int(results[batchid].keypoints[i * 3 + 2]);
cv::circle(vis_img,
cv::Point2d(x_coord, y_coord),
1,
cv::Scalar(0, 0, 255),
2);
}
}
for (int i = 0; i < results[batchid].num_joints; i++) {
if (results[batchid].keypoints[edge[i][0] * 3] > threshold &&
results[batchid].keypoints[edge[i][1] * 3] > threshold) {
int x_start = int(results[batchid].keypoints[edge[i][0] * 3 + 1]);
int y_start = int(results[batchid].keypoints[edge[i][0] * 3 + 2]);
int x_end = int(results[batchid].keypoints[edge[i][1] * 3 + 1]);
int y_end = int(results[batchid].keypoints[edge[i][1] * 3 + 2]);
cv::line(vis_img,
cv::Point2d(x_start, y_start),
cv::Point2d(x_end, y_end),
colormap[i],
1);
}
}
}
return vis_img;
}
void KeyPointDetector::Postprocess(std::vector<float>& output,
std::vector<uint64_t>& output_shape,
std::vector<int>& idxout,
std::vector<uint64_t>& idx_shape,
std::vector<KeyPointResult>* result,
std::vector<std::vector<float>>& center_bs,
std::vector<std::vector<float>>& scale_bs) {
std::vector<float> preds(output_shape[1] * 3, 0);
for (int batchid = 0; batchid < output_shape[0]; batchid++) {
get_final_preds(output,
output_shape,
idxout,
idx_shape,
center_bs[batchid],
scale_bs[batchid],
preds,
batchid,
this->use_dark());
KeyPointResult result_item;
result_item.num_joints = output_shape[1];
result_item.keypoints.clear();
for (int i = 0; i < output_shape[1]; i++) {
result_item.keypoints.emplace_back(preds[i * 3]);
result_item.keypoints.emplace_back(preds[i * 3 + 1]);
result_item.keypoints.emplace_back(preds[i * 3 + 2]);
}
result->push_back(result_item);
}
}
void KeyPointDetector::Predict(const std::vector<cv::Mat> imgs,
std::vector<std::vector<float>>& center_bs,
std::vector<std::vector<float>>& scale_bs,
std::vector<KeyPointResult>* result) {
int batch_size = imgs.size();
auto insize = 3 * in_h * in_w;
InferenceEngine::Blob::Ptr input_blob = infer_request_.GetBlob(input_name_);
// Preprocess image
InferenceEngine::MemoryBlob::Ptr mblob =
InferenceEngine::as<InferenceEngine::MemoryBlob>(input_blob);
if (!mblob) {
THROW_IE_EXCEPTION
<< "We expect blob to be inherited from MemoryBlob in matU8ToBlob, "
<< "but by fact we were not able to cast inputBlob to MemoryBlob";
}
auto mblobHolder = mblob->wmap();
float* blob_data = mblobHolder.as<float*>();
cv::Mat resized_im;
for (int bs_idx = 0; bs_idx < batch_size; bs_idx++) {
cv::Mat im = imgs.at(bs_idx);
cv::resize(im, resized_im, cv::Size(in_w, in_h));
for (size_t c = 0; c < 3; c++) {
for (size_t h = 0; h < in_h; h++) {
for (size_t w = 0; w < in_w; w++) {
blob_data[c * in_w * in_h + h * in_w + w] =
(float)resized_im.at<cv::Vec3b>(h, w)[c];
}
}
}
}
// Run predictor
auto inference_start = std::chrono::steady_clock::now();
// do inference
infer_request_.Infer();
InferenceEngine::Blob::Ptr output_blob =
infer_request_.GetBlob("save_infer_model/scale_0.tmp_1");
auto output_shape = output_blob->getTensorDesc().getDims();
InferenceEngine::MemoryBlob::Ptr moutput =
InferenceEngine::as<InferenceEngine::MemoryBlob>(output_blob);
if (moutput) {
// locked memory holder should be alive all time while access to its
// buffer happens
auto minputHolder = moutput->rmap();
auto data = minputHolder.as<const InferenceEngine::PrecisionTrait<
InferenceEngine::Precision::FP32>::value_type*>();
// Calculate output length
int output_size = 1;
for (int j = 0; j < output_shape.size(); ++j) {
output_size *= output_shape[j];
}
output_data_.resize(output_size);
std::copy_n(data, output_size, output_data_.data());
}
InferenceEngine::Blob::Ptr output_blob2 =
infer_request_.GetBlob("save_infer_model/scale_1.tmp_1");
auto idx_shape = output_blob2->getTensorDesc().getDims();
InferenceEngine::MemoryBlob::Ptr moutput2 =
InferenceEngine::as<InferenceEngine::MemoryBlob>(output_blob2);
if (moutput2) {
// locked memory holder should be alive all time while access to its
// buffer happens
auto minputHolder = moutput2->rmap();
// Original I64 precision was converted to I32
auto data = minputHolder.as<const InferenceEngine::PrecisionTrait<
InferenceEngine::Precision::I32>::value_type*>();
// Calculate output length
int output_size = 1;
for (int j = 0; j < idx_shape.size(); ++j) {
output_size *= idx_shape[j];
}
idx_data_.resize(output_size);
std::copy_n(data, output_size, idx_data_.data());
}
auto inference_end = std::chrono::steady_clock::now();
std::chrono::duration<double> elapsed = inference_end - inference_start;
printf("keypoint inference time: %f s\n", elapsed.count());
// Postprocessing result
Postprocess(output_data_,
output_shape,
idx_data_,
idx_shape,
result,
center_bs,
scale_bs);
}
} // namespace PaddleDetection
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <ctime>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#include <inference_engine.hpp>
#include "keypoint_postprocess.h"
namespace PaddleDetection {
// Object KeyPoint Result
struct KeyPointResult {
// Keypoints: shape(N x 3); N: number of Joints; 3: x,y,conf
std::vector<float> keypoints;
int num_joints = -1;
};
// Visualiztion KeyPoint Result
cv::Mat VisualizeKptsResult(const cv::Mat& img,
const std::vector<KeyPointResult>& results,
const std::vector<int>& colormap,
float threshold = 0.2);
class KeyPointDetector {
public:
explicit KeyPointDetector(const std::string& model_path,
int input_height = 256,
int input_width = 192,
float score_threshold = 0.3,
const int batch_size = 1,
bool use_dark = true) {
use_dark_ = use_dark;
in_w = input_width;
in_h = input_height;
threshold_ = score_threshold;
InferenceEngine::Core ie;
auto model = ie.ReadNetwork(model_path);
// prepare input settings
InferenceEngine::InputsDataMap inputs_map(model.getInputsInfo());
input_name_ = inputs_map.begin()->first;
InferenceEngine::InputInfo::Ptr input_info = inputs_map.begin()->second;
// prepare output settings
InferenceEngine::OutputsDataMap outputs_map(model.getOutputsInfo());
int idx = 0;
for (auto& output_info : outputs_map) {
if (idx == 0) {
output_info.second->setPrecision(InferenceEngine::Precision::FP32);
} else {
output_info.second->setPrecision(InferenceEngine::Precision::I32);
}
idx++;
}
// get network
network_ = ie.LoadNetwork(model, "CPU");
infer_request_ = network_.CreateInferRequest();
}
// Load Paddle inference model
void LoadModel(std::string model_file, int num_theads);
// Run predictor
void Predict(const std::vector<cv::Mat> imgs,
std::vector<std::vector<float>>& center,
std::vector<std::vector<float>>& scale,
std::vector<KeyPointResult>* result = nullptr);
bool use_dark() { return this->use_dark_; }
inline float get_threshold() { return threshold_; };
int in_w = 128;
int in_h = 256;
private:
// Postprocess result
void Postprocess(std::vector<float>& output,
std::vector<uint64_t>& output_shape,
std::vector<int>& idxout,
std::vector<uint64_t>& idx_shape,
std::vector<KeyPointResult>* result,
std::vector<std::vector<float>>& center,
std::vector<std::vector<float>>& scale);
std::vector<float> output_data_;
std::vector<int> idx_data_;
float threshold_;
bool use_dark_;
InferenceEngine::ExecutableNetwork network_;
InferenceEngine::InferRequest infer_request_;
std::string input_name_;
};
} // namespace PaddleDetection
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "keypoint_postprocess.h"
#define PI 3.1415926535
#define HALF_CIRCLE_DEGREE 180
cv::Point2f get_3rd_point(cv::Point2f& a, cv::Point2f& b) {
cv::Point2f direct{a.x - b.x, a.y - b.y};
return cv::Point2f(a.x - direct.y, a.y + direct.x);
}
std::vector<float> get_dir(float src_point_x,
float src_point_y,
float rot_rad) {
float sn = sin(rot_rad);
float cs = cos(rot_rad);
std::vector<float> src_result{0.0, 0.0};
src_result[0] = src_point_x * cs - src_point_y * sn;
src_result[1] = src_point_x * sn + src_point_y * cs;
return src_result;
}
void affine_tranform(
float pt_x, float pt_y, cv::Mat& trans, std::vector<float>& preds, int p) {
double new1[3] = {pt_x, pt_y, 1.0};
cv::Mat new_pt(3, 1, trans.type(), new1);
cv::Mat w = trans * new_pt;
preds[p * 3 + 1] = static_cast<float>(w.at<double>(0, 0));
preds[p * 3 + 2] = static_cast<float>(w.at<double>(1, 0));
}
void get_affine_transform(std::vector<float>& center,
std::vector<float>& scale,
float rot,
std::vector<int>& output_size,
cv::Mat& trans,
int inv) {
float src_w = scale[0];
float dst_w = static_cast<float>(output_size[0]);
float dst_h = static_cast<float>(output_size[1]);
float rot_rad = rot * PI / HALF_CIRCLE_DEGREE;
std::vector<float> src_dir = get_dir(-0.5 * src_w, 0, rot_rad);
std::vector<float> dst_dir{static_cast<float>(-0.5) * dst_w, 0.0};
cv::Point2f srcPoint2f[3], dstPoint2f[3];
srcPoint2f[0] = cv::Point2f(center[0], center[1]);
srcPoint2f[1] = cv::Point2f(center[0] + src_dir[0], center[1] + src_dir[1]);
srcPoint2f[2] = get_3rd_point(srcPoint2f[0], srcPoint2f[1]);
dstPoint2f[0] = cv::Point2f(dst_w * 0.5, dst_h * 0.5);
dstPoint2f[1] =
cv::Point2f(dst_w * 0.5 + dst_dir[0], dst_h * 0.5 + dst_dir[1]);
dstPoint2f[2] = get_3rd_point(dstPoint2f[0], dstPoint2f[1]);
if (inv == 0) {
trans = cv::getAffineTransform(srcPoint2f, dstPoint2f);
} else {
trans = cv::getAffineTransform(dstPoint2f, srcPoint2f);
}
}
void transform_preds(std::vector<float>& coords,
std::vector<float>& center,
std::vector<float>& scale,
std::vector<int>& output_size,
std::vector<uint64_t>& dim,
std::vector<float>& target_coords) {
cv::Mat trans(2, 3, CV_64FC1);
get_affine_transform(center, scale, 0, output_size, trans, 1);
for (int p = 0; p < dim[1]; ++p) {
affine_tranform(coords[p * 2], coords[p * 2 + 1], trans, target_coords, p);
}
}
// only for batchsize == 1
void get_max_preds(std::vector<float>& heatmap,
std::vector<int>& dim,
std::vector<float>& preds,
std::vector<float>& maxvals,
int batchid,
int joint_idx) {
int num_joints = dim[1];
int width = dim[3];
std::vector<int> idx;
idx.resize(num_joints * 2);
for (int j = 0; j < dim[1]; j++) {
float* index = &(
heatmap[batchid * num_joints * dim[2] * dim[3] + j * dim[2] * dim[3]]);
float* end = index + dim[2] * dim[3];
float* max_dis = std::max_element(index, end);
auto max_id = std::distance(index, max_dis);
maxvals[j] = *max_dis;
if (*max_dis > 0) {
preds[j * 2] = static_cast<float>(max_id % width);
preds[j * 2 + 1] = static_cast<float>(max_id / width);
}
}
}
void dark_parse(std::vector<float>& heatmap,
std::vector<uint64_t>& dim,
std::vector<float>& coords,
int px,
int py,
int index,
int ch) {
/*DARK postpocessing, Zhang et al. Distribution-Aware Coordinate
Representation for Human Pose Estimation (CVPR 2020).
1) offset = - hassian.inv() * derivative
2) dx = (heatmap[x+1] - heatmap[x-1])/2.
3) dxx = (dx[x+1] - dx[x-1])/2.
4) derivative = Mat([dx, dy])
5) hassian = Mat([[dxx, dxy], [dxy, dyy]])
*/
std::vector<float>::const_iterator first1 = heatmap.begin() + index;
std::vector<float>::const_iterator last1 =
heatmap.begin() + index + dim[2] * dim[3];
std::vector<float> heatmap_ch(first1, last1);
cv::Mat heatmap_mat = cv::Mat(heatmap_ch).reshape(0, dim[2]);
heatmap_mat.convertTo(heatmap_mat, CV_32FC1);
cv::GaussianBlur(heatmap_mat, heatmap_mat, cv::Size(3, 3), 0, 0);
heatmap_mat = heatmap_mat.reshape(1, 1);
heatmap_ch = std::vector<float>(heatmap_mat.reshape(1, 1));
float epsilon = 1e-10;
// sample heatmap to get values in around target location
float xy = log(fmax(heatmap_ch[py * dim[3] + px], epsilon));
float xr = log(fmax(heatmap_ch[py * dim[3] + px + 1], epsilon));
float xl = log(fmax(heatmap_ch[py * dim[3] + px - 1], epsilon));
float xr2 = log(fmax(heatmap_ch[py * dim[3] + px + 2], epsilon));
float xl2 = log(fmax(heatmap_ch[py * dim[3] + px - 2], epsilon));
float yu = log(fmax(heatmap_ch[(py + 1) * dim[3] + px], epsilon));
float yd = log(fmax(heatmap_ch[(py - 1) * dim[3] + px], epsilon));
float yu2 = log(fmax(heatmap_ch[(py + 2) * dim[3] + px], epsilon));
float yd2 = log(fmax(heatmap_ch[(py - 2) * dim[3] + px], epsilon));
float xryu = log(fmax(heatmap_ch[(py + 1) * dim[3] + px + 1], epsilon));
float xryd = log(fmax(heatmap_ch[(py - 1) * dim[3] + px + 1], epsilon));
float xlyu = log(fmax(heatmap_ch[(py + 1) * dim[3] + px - 1], epsilon));
float xlyd = log(fmax(heatmap_ch[(py - 1) * dim[3] + px - 1], epsilon));
// compute dx/dy and dxx/dyy with sampled values
float dx = 0.5 * (xr - xl);
float dy = 0.5 * (yu - yd);
float dxx = 0.25 * (xr2 - 2 * xy + xl2);
float dxy = 0.25 * (xryu - xryd - xlyu + xlyd);
float dyy = 0.25 * (yu2 - 2 * xy + yd2);
// finally get offset by derivative and hassian, which combined by dx/dy and
// dxx/dyy
if (dxx * dyy - dxy * dxy != 0) {
float M[2][2] = {dxx, dxy, dxy, dyy};
float D[2] = {dx, dy};
cv::Mat hassian(2, 2, CV_32F, M);
cv::Mat derivative(2, 1, CV_32F, D);
cv::Mat offset = -hassian.inv() * derivative;
coords[ch * 2] += offset.at<float>(0, 0);
coords[ch * 2 + 1] += offset.at<float>(1, 0);
}
}
void get_final_preds(std::vector<float>& heatmap,
std::vector<uint64_t>& dim,
std::vector<int>& idxout,
std::vector<uint64_t>& idxdim,
std::vector<float>& center,
std::vector<float> scale,
std::vector<float>& preds,
int batchid,
bool DARK) {
std::vector<float> coords;
coords.resize(dim[1] * 2);
int heatmap_height = dim[2];
int heatmap_width = dim[3];
for (int j = 0; j < dim[1]; ++j) {
int index = (batchid * dim[1] + j) * dim[2] * dim[3];
int idx = idxout[batchid * dim[1] + j];
preds[j * 3] = heatmap[index + idx];
coords[j * 2] = idx % heatmap_width;
coords[j * 2 + 1] = idx / heatmap_width;
int px = int(coords[j * 2] + 0.5);
int py = int(coords[j * 2 + 1] + 0.5);
if (DARK && px > 1 && px < heatmap_width - 2 && py > 1 &&
py < heatmap_height - 2) {
dark_parse(heatmap, dim, coords, px, py, index, j);
} else {
if (px > 0 && px < heatmap_width - 1) {
float diff_x = heatmap[index + py * dim[3] + px + 1] -
heatmap[index + py * dim[3] + px - 1];
coords[j * 2] += diff_x > 0 ? 1 : -1 * 0.25;
}
if (py > 0 && py < heatmap_height - 1) {
float diff_y = heatmap[index + (py + 1) * dim[3] + px] -
heatmap[index + (py - 1) * dim[3] + px];
coords[j * 2 + 1] += diff_y > 0 ? 1 : -1 * 0.25;
}
}
}
std::vector<int> img_size{heatmap_width, heatmap_height};
transform_preds(coords, center, scale, img_size, dim, preds);
}
void CropImg(cv::Mat& img,
cv::Mat& crop_img,
std::vector<int>& area,
std::vector<float>& center,
std::vector<float>& scale,
float expandratio) {
int crop_x1 = std::max(0, area[0]);
int crop_y1 = std::max(0, area[1]);
int crop_x2 = std::min(img.cols - 1, area[2]);
int crop_y2 = std::min(img.rows - 1, area[3]);
int center_x = (crop_x1 + crop_x2) / 2.;
int center_y = (crop_y1 + crop_y2) / 2.;
int half_h = (crop_y2 - crop_y1) / 2.;
int half_w = (crop_x2 - crop_x1) / 2.;
if (half_h * 3 > half_w * 4) {
half_w = static_cast<int>(half_h * 0.75);
} else {
half_h = static_cast<int>(half_w * 4 / 3);
}
crop_x1 =
std::max(0, center_x - static_cast<int>(half_w * (1 + expandratio)));
crop_y1 =
std::max(0, center_y - static_cast<int>(half_h * (1 + expandratio)));
crop_x2 = std::min(img.cols - 1,
static_cast<int>(center_x + half_w * (1 + expandratio)));
crop_y2 = std::min(img.rows - 1,
static_cast<int>(center_y + half_h * (1 + expandratio)));
crop_img =
img(cv::Range(crop_y1, crop_y2 + 1), cv::Range(crop_x1, crop_x2 + 1));
center.clear();
center.emplace_back((crop_x1 + crop_x2) / 2);
center.emplace_back((crop_y1 + crop_y2) / 2);
scale.clear();
scale.emplace_back((crop_x2 - crop_x1));
scale.emplace_back((crop_y2 - crop_y1));
}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#include <vector>
std::vector<float> get_3rd_point(std::vector<float>& a, std::vector<float>& b);
std::vector<float> get_dir(float src_point_x, float src_point_y, float rot_rad);
void affine_tranform(float pt_x,
float pt_y,
cv::Mat& trans,
std::vector<float>& x,
int p,
int num);
cv::Mat get_affine_transform(std::vector<float>& center,
std::vector<float>& scale,
float rot,
std::vector<int>& output_size,
int inv);
void transform_preds(std::vector<float>& coords,
std::vector<float>& center,
std::vector<float>& scale,
std::vector<uint64_t>& output_size,
std::vector<int>& dim,
std::vector<float>& target_coords);
void box_to_center_scale(std::vector<int>& box,
int width,
int height,
std::vector<float>& center,
std::vector<float>& scale);
void get_max_preds(std::vector<float>& heatmap,
std::vector<int>& dim,
std::vector<float>& preds,
std::vector<float>& maxvals,
int batchid,
int joint_idx);
void get_final_preds(std::vector<float>& heatmap,
std::vector<uint64_t>& dim,
std::vector<int>& idxout,
std::vector<uint64_t>& idxdim,
std::vector<float>& center,
std::vector<float> scale,
std::vector<float>& preds,
int batchid,
bool DARK = true);
void CropImg(cv::Mat& img,
cv::Mat& crop_img,
std::vector<int>& area,
std::vector<float>& center,
std::vector<float>& scale,
float expandratio = 0.25);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// reference from https://github.com/RangiLyu/nanodet
#include <iostream>
#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#define image_size 416
#include "keypoint_detector.h"
#include "picodet_openvino.h"
using namespace PaddleDetection;
struct object_rect {
int x;
int y;
int width;
int height;
};
int resize_uniform(cv::Mat& src,
cv::Mat& dst,
cv::Size dst_size,
object_rect& effect_area) {
int w = src.cols;
int h = src.rows;
int dst_w = dst_size.width;
int dst_h = dst_size.height;
dst = cv::Mat(cv::Size(dst_w, dst_h), CV_8UC3, cv::Scalar(0));
float ratio_src = w * 1.0 / h;
float ratio_dst = dst_w * 1.0 / dst_h;
int tmp_w = 0;
int tmp_h = 0;
if (ratio_src > ratio_dst) {
tmp_w = dst_w;
tmp_h = floor((dst_w * 1.0 / w) * h);
} else if (ratio_src < ratio_dst) {
tmp_h = dst_h;
tmp_w = floor((dst_h * 1.0 / h) * w);
} else {
cv::resize(src, dst, dst_size);
effect_area.x = 0;
effect_area.y = 0;
effect_area.width = dst_w;
effect_area.height = dst_h;
return 0;
}
cv::Mat tmp;
cv::resize(src, tmp, cv::Size(tmp_w, tmp_h));
if (tmp_w != dst_w) {
int index_w = floor((dst_w - tmp_w) / 2.0);
for (int i = 0; i < dst_h; i++) {
memcpy(dst.data + i * dst_w * 3 + index_w * 3,
tmp.data + i * tmp_w * 3,
tmp_w * 3);
}
effect_area.x = index_w;
effect_area.y = 0;
effect_area.width = tmp_w;
effect_area.height = tmp_h;
} else if (tmp_h != dst_h) {
int index_h = floor((dst_h - tmp_h) / 2.0);
memcpy(dst.data + index_h * dst_w * 3, tmp.data, tmp_w * tmp_h * 3);
effect_area.x = 0;
effect_area.y = index_h;
effect_area.width = tmp_w;
effect_area.height = tmp_h;
} else {
printf("error\n");
}
return 0;
}
const int color_list[80][3] = {
{216, 82, 24}, {236, 176, 31}, {125, 46, 141}, {118, 171, 47},
{76, 189, 237}, {238, 19, 46}, {76, 76, 76}, {153, 153, 153},
{255, 0, 0}, {255, 127, 0}, {190, 190, 0}, {0, 255, 0},
{0, 0, 255}, {170, 0, 255}, {84, 84, 0}, {84, 170, 0},
{84, 255, 0}, {170, 84, 0}, {170, 170, 0}, {170, 255, 0},
{255, 84, 0}, {255, 170, 0}, {255, 255, 0}, {0, 84, 127},
{0, 170, 127}, {0, 255, 127}, {84, 0, 127}, {84, 84, 127},
{84, 170, 127}, {84, 255, 127}, {170, 0, 127}, {170, 84, 127},
{170, 170, 127}, {170, 255, 127}, {255, 0, 127}, {255, 84, 127},
{255, 170, 127}, {255, 255, 127}, {0, 84, 255}, {0, 170, 255},
{0, 255, 255}, {84, 0, 255}, {84, 84, 255}, {84, 170, 255},
{84, 255, 255}, {170, 0, 255}, {170, 84, 255}, {170, 170, 255},
{170, 255, 255}, {255, 0, 255}, {255, 84, 255}, {255, 170, 255},
{42, 0, 0}, {84, 0, 0}, {127, 0, 0}, {170, 0, 0},
{212, 0, 0}, {255, 0, 0}, {0, 42, 0}, {0, 84, 0},
{0, 127, 0}, {0, 170, 0}, {0, 212, 0}, {0, 255, 0},
{0, 0, 42}, {0, 0, 84}, {0, 0, 127}, {0, 0, 170},
{0, 0, 212}, {0, 0, 255}, {0, 0, 0}, {36, 36, 36},
{72, 72, 72}, {109, 109, 109}, {145, 145, 145}, {182, 182, 182},
{218, 218, 218}, {0, 113, 188}, {80, 182, 188}, {127, 127, 0},
};
void draw_bboxes(const cv::Mat& bgr,
const std::vector<BoxInfo>& bboxes,
object_rect effect_roi) {
static const char* class_names[] = {
"person", "bicycle", "car",
"motorcycle", "airplane", "bus",
"train", "truck", "boat",
"traffic light", "fire hydrant", "stop sign",
"parking meter", "bench", "bird",
"cat", "dog", "horse",
"sheep", "cow", "elephant",
"bear", "zebra", "giraffe",
"backpack", "umbrella", "handbag",
"tie", "suitcase", "frisbee",
"skis", "snowboard", "sports ball",
"kite", "baseball bat", "baseball glove",
"skateboard", "surfboard", "tennis racket",
"bottle", "wine glass", "cup",
"fork", "knife", "spoon",
"bowl", "banana", "apple",
"sandwich", "orange", "broccoli",
"carrot", "hot dog", "pizza",
"donut", "cake", "chair",
"couch", "potted plant", "bed",
"dining table", "toilet", "tv",
"laptop", "mouse", "remote",
"keyboard", "cell phone", "microwave",
"oven", "toaster", "sink",
"refrigerator", "book", "clock",
"vase", "scissors", "teddy bear",
"hair drier", "toothbrush"};
cv::Mat image = bgr.clone();
int src_w = image.cols;
int src_h = image.rows;
int dst_w = effect_roi.width;
int dst_h = effect_roi.height;
float width_ratio = (float)src_w / (float)dst_w;
float height_ratio = (float)src_h / (float)dst_h;
for (size_t i = 0; i < bboxes.size(); i++) {
const BoxInfo& bbox = bboxes[i];
cv::Scalar color = cv::Scalar(color_list[bbox.label][0],
color_list[bbox.label][1],
color_list[bbox.label][2]);
cv::rectangle(image,
cv::Rect(cv::Point((bbox.x1 - effect_roi.x) * width_ratio,
(bbox.y1 - effect_roi.y) * height_ratio),
cv::Point((bbox.x2 - effect_roi.x) * width_ratio,
(bbox.y2 - effect_roi.y) * height_ratio)),
color);
char text[256];
sprintf(text, "%s %.1f%%", class_names[bbox.label], bbox.score * 100);
int baseLine = 0;
cv::Size label_size =
cv::getTextSize(text, cv::FONT_HERSHEY_SIMPLEX, 0.4, 1, &baseLine);
int x = (bbox.x1 - effect_roi.x) * width_ratio;
int y =
(bbox.y1 - effect_roi.y) * height_ratio - label_size.height - baseLine;
if (y < 0) y = 0;
if (x + label_size.width > image.cols) x = image.cols - label_size.width;
cv::rectangle(
image,
cv::Rect(cv::Point(x, y),
cv::Size(label_size.width, label_size.height + baseLine)),
color,
-1);
cv::putText(image,
text,
cv::Point(x, y + label_size.height),
cv::FONT_HERSHEY_SIMPLEX,
0.4,
cv::Scalar(255, 255, 255));
}
cv::imwrite("../predict.jpg", image);
}
std::vector<BoxInfo> coordsback(const cv::Mat image,
const object_rect effect_roi,
const std::vector<BoxInfo>& bboxes) {
int src_w = image.cols;
int src_h = image.rows;
int dst_w = effect_roi.width;
int dst_h = effect_roi.height;
float width_ratio = (float)src_w / (float)dst_w;
float height_ratio = (float)src_h / (float)dst_h;
std::vector<BoxInfo> bboxes_oimg;
for (int i = 0; i < bboxes.size(); i++) {
auto bbox = bboxes[i];
bbox.x1 = (bbox.x1 - effect_roi.x) * width_ratio;
bbox.y1 = (bbox.y1 - effect_roi.y) * height_ratio;
bbox.x2 = (bbox.x2 - effect_roi.x) * width_ratio;
bbox.y2 = (bbox.y2 - effect_roi.y) * height_ratio;
bboxes_oimg.emplace_back(bbox);
}
return bboxes_oimg;
}
void image_infer_kpts(KeyPointDetector* kpts_detector,
cv::Mat image,
const object_rect effect_roi,
const std::vector<BoxInfo>& results,
std::string img_name = "kpts_vis",
bool save_img = true) {
std::vector<cv::Mat> cropimgs;
std::vector<std::vector<float>> center_bs;
std::vector<std::vector<float>> scale_bs;
std::vector<KeyPointResult> kpts_results;
auto results_oimg = coordsback(image, effect_roi, results);
for (int i = 0; i < results_oimg.size(); i++) {
auto rect = results_oimg[i];
if (rect.label == 0) {
cv::Mat cropimg;
std::vector<float> center, scale;
std::vector<int> area = {static_cast<int>(rect.x1),
static_cast<int>(rect.y1),
static_cast<int>(rect.x2),
static_cast<int>(rect.y2)};
CropImg(image, cropimg, area, center, scale);
cropimgs.emplace_back(cropimg);
center_bs.emplace_back(center);
scale_bs.emplace_back(scale);
}
if (cropimgs.size() == 1 ||
(cropimgs.size() > 0 && i == results_oimg.size() - 1)) {
kpts_detector->Predict(cropimgs, center_bs, scale_bs, &kpts_results);
cropimgs.clear();
center_bs.clear();
scale_bs.clear();
}
}
std::vector<int> compression_params;
compression_params.push_back(cv::IMWRITE_JPEG_QUALITY);
compression_params.push_back(95);
std::string kpts_savepath =
"keypoint_" + img_name.substr(img_name.find_last_of('/') + 1);
cv::Mat kpts_vis_img =
VisualizeKptsResult(image, kpts_results, {0, 255, 0}, 0.1);
if (save_img) {
cv::imwrite(kpts_savepath, kpts_vis_img, compression_params);
printf("Visualized output saved as %s\n", kpts_savepath.c_str());
} else {
cv::imshow("image", kpts_vis_img);
}
}
int image_demo(PicoDet& detector,
KeyPointDetector* kpts_detector,
const char* imagepath) {
std::vector<std::string> filenames;
cv::glob(imagepath, filenames, false);
for (auto img_name : filenames) {
cv::Mat image = cv::imread(img_name);
if (image.empty()) {
return -1;
}
object_rect effect_roi;
cv::Mat resized_img;
resize_uniform(
image, resized_img, cv::Size(image_size, image_size), effect_roi);
auto results = detector.detect(resized_img, 0.4, 0.5);
if (kpts_detector) {
image_infer_kpts(kpts_detector, image, effect_roi, results, img_name);
}
}
return 0;
}
int webcam_demo(PicoDet& detector,
KeyPointDetector* kpts_detector,
int cam_id) {
cv::Mat image;
cv::VideoCapture cap(cam_id);
while (true) {
cap >> image;
object_rect effect_roi;
cv::Mat resized_img;
resize_uniform(
image, resized_img, cv::Size(image_size, image_size), effect_roi);
auto results = detector.detect(resized_img, 0.4, 0.5);
if (kpts_detector) {
image_infer_kpts(kpts_detector, image, effect_roi, results, "", false);
}
}
return 0;
}
int video_demo(PicoDet& detector,
KeyPointDetector* kpts_detector,
const char* path) {
cv::Mat image;
cv::VideoCapture cap(path);
while (true) {
cap >> image;
object_rect effect_roi;
cv::Mat resized_img;
resize_uniform(
image, resized_img, cv::Size(image_size, image_size), effect_roi);
auto results = detector.detect(resized_img, 0.4, 0.5);
if (kpts_detector) {
image_infer_kpts(kpts_detector, image, effect_roi, results, "", false);
}
}
return 0;
}
int benchmark(KeyPointDetector* kpts_detector) {
int loop_num = 100;
int warm_up = 8;
double time_min = DBL_MAX;
double time_max = -DBL_MAX;
double time_avg = 0;
cv::Mat image(256, 192, CV_8UC3, cv::Scalar(1, 1, 1));
std::vector<float> center = {128, 96};
std::vector<float> scale = {256, 192};
std::vector<cv::Mat> cropimgs = {image};
std::vector<std::vector<float>> center_bs = {center};
std::vector<std::vector<float>> scale_bs = {scale};
std::vector<KeyPointResult> kpts_results;
for (int i = 0; i < warm_up + loop_num; i++) {
auto start = std::chrono::steady_clock::now();
std::vector<BoxInfo> results;
kpts_detector->Predict(cropimgs, center_bs, scale_bs, &kpts_results);
auto end = std::chrono::steady_clock::now();
std::chrono::duration<double> elapsed = end - start;
double time = elapsed.count();
if (i >= warm_up) {
time_min = (std::min)(time_min, time);
time_max = (std::max)(time_max, time);
time_avg += time;
}
}
time_avg /= loop_num;
fprintf(stderr,
"%20s min = %7.4f max = %7.4f avg = %7.4f\n",
"tinypose",
time_min,
time_max,
time_avg);
return 0;
}
int main(int argc, char** argv) {
if (argc != 3) {
fprintf(stderr,
"usage: %s [mode] [path]. \n For webcam mode=0, path is cam id; \n "
"For image demo, mode=1, path=xxx/xxx/*.jpg; \n For video, mode=2; "
"\n For benchmark, mode=3 path=0.\n",
argv[0]);
return -1;
}
std::cout << "start init model" << std::endl;
auto detector = PicoDet("../weight/picodet_m_416.xml");
auto kpts_detector =
new KeyPointDetector("../weight/tinypose256.xml", 256, 192);
std::cout << "success" << std::endl;
int mode = atoi(argv[1]);
switch (mode) {
case 0: {
int cam_id = atoi(argv[2]);
webcam_demo(detector, kpts_detector, cam_id);
break;
}
case 1: {
const char* images = argv[2];
image_demo(detector, kpts_detector, images);
break;
}
case 2: {
const char* path = argv[2];
video_demo(detector, kpts_detector, path);
break;
}
case 3: {
benchmark(kpts_detector);
break;
}
default: {
fprintf(stderr,
"usage: %s [mode] [path]. \n For webcam mode=0, path is cam id; "
"\n For image demo, mode=1, path=xxx/xxx/*.jpg; \n For video, "
"mode=2; \n For benchmark, mode=3 path=0.\n",
argv[0]);
break;
}
}
delete kpts_detector;
kpts_detector = nullptr;
}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// reference from https://github.com/RangiLyu/nanodet/tree/main/demo_openvino
#include "picodet_openvino.h"
inline float fast_exp(float x) {
union {
uint32_t i;
float f;
} v{};
v.i = (1 << 23) * (1.4426950409 * x + 126.93490512f);
return v.f;
}
inline float sigmoid(float x) { return 1.0f / (1.0f + fast_exp(-x)); }
template <typename _Tp>
int activation_function_softmax(const _Tp* src, _Tp* dst, int length) {
const _Tp alpha = *std::max_element(src, src + length);
_Tp denominator{0};
for (int i = 0; i < length; ++i) {
dst[i] = fast_exp(src[i] - alpha);
denominator += dst[i];
}
for (int i = 0; i < length; ++i) {
dst[i] /= denominator;
}
return 0;
}
PicoDet::PicoDet(const char* model_path) {
InferenceEngine::Core ie;
InferenceEngine::CNNNetwork model = ie.ReadNetwork(model_path);
// prepare input settings
InferenceEngine::InputsDataMap inputs_map(model.getInputsInfo());
input_name_ = inputs_map.begin()->first;
InferenceEngine::InputInfo::Ptr input_info = inputs_map.begin()->second;
// prepare output settings
InferenceEngine::OutputsDataMap outputs_map(model.getOutputsInfo());
for (auto& output_info : outputs_map) {
output_info.second->setPrecision(InferenceEngine::Precision::FP32);
}
// get network
network_ = ie.LoadNetwork(model, "CPU");
infer_request_ = network_.CreateInferRequest();
}
PicoDet::~PicoDet() {}
void PicoDet::preprocess(cv::Mat& image, InferenceEngine::Blob::Ptr& blob) {
int img_w = image.cols;
int img_h = image.rows;
int channels = 3;
InferenceEngine::MemoryBlob::Ptr mblob =
InferenceEngine::as<InferenceEngine::MemoryBlob>(blob);
if (!mblob) {
THROW_IE_EXCEPTION
<< "We expect blob to be inherited from MemoryBlob in matU8ToBlob, "
<< "but by fact we were not able to cast inputBlob to MemoryBlob";
}
auto mblobHolder = mblob->wmap();
float* blob_data = mblobHolder.as<float*>();
for (size_t c = 0; c < channels; c++) {
for (size_t h = 0; h < img_h; h++) {
for (size_t w = 0; w < img_w; w++) {
blob_data[c * img_w * img_h + h * img_w + w] =
(float)image.at<cv::Vec3b>(h, w)[c];
}
}
}
}
std::vector<BoxInfo> PicoDet::detect(cv::Mat image,
float score_threshold,
float nms_threshold) {
InferenceEngine::Blob::Ptr input_blob = infer_request_.GetBlob(input_name_);
preprocess(image, input_blob);
// do inference
infer_request_.Infer();
// get output
std::vector<std::vector<BoxInfo>> results;
results.resize(this->num_class_);
for (const auto& head_info : this->heads_info_) {
const InferenceEngine::Blob::Ptr dis_pred_blob =
infer_request_.GetBlob(head_info.dis_layer);
const InferenceEngine::Blob::Ptr cls_pred_blob =
infer_request_.GetBlob(head_info.cls_layer);
auto mdis_pred =
InferenceEngine::as<InferenceEngine::MemoryBlob>(dis_pred_blob);
auto mdis_pred_holder = mdis_pred->rmap();
const float* dis_pred = mdis_pred_holder.as<const float*>();
auto mcls_pred =
InferenceEngine::as<InferenceEngine::MemoryBlob>(cls_pred_blob);
auto mcls_pred_holder = mcls_pred->rmap();
const float* cls_pred = mcls_pred_holder.as<const float*>();
this->decode_infer(
cls_pred, dis_pred, head_info.stride, score_threshold, results);
}
std::vector<BoxInfo> dets;
for (int i = 0; i < (int)results.size(); i++) {
this->nms(results[i], nms_threshold);
for (auto& box : results[i]) {
dets.push_back(box);
}
}
return dets;
}
void PicoDet::decode_infer(const float*& cls_pred,
const float*& dis_pred,
int stride,
float threshold,
std::vector<std::vector<BoxInfo>>& results) {
int feature_h = input_size_ / stride;
int feature_w = input_size_ / stride;
for (int idx = 0; idx < feature_h * feature_w; idx++) {
int row = idx / feature_w;
int col = idx % feature_w;
float score = 0;
int cur_label = 0;
for (int label = 0; label < num_class_; label++) {
if (cls_pred[idx * num_class_ + label] > score) {
score = cls_pred[idx * num_class_ + label];
cur_label = label;
}
}
if (score > threshold) {
const float* bbox_pred = dis_pred + idx * (reg_max_ + 1) * 4;
results[cur_label].push_back(
this->disPred2Bbox(bbox_pred, cur_label, score, col, row, stride));
}
}
}
BoxInfo PicoDet::disPred2Bbox(
const float*& dfl_det, int label, float score, int x, int y, int stride) {
float ct_x = (x + 0.5) * stride;
float ct_y = (y + 0.5) * stride;
std::vector<float> dis_pred;
dis_pred.resize(4);
for (int i = 0; i < 4; i++) {
float dis = 0;
float* dis_after_sm = new float[reg_max_ + 1];
activation_function_softmax(
dfl_det + i * (reg_max_ + 1), dis_after_sm, reg_max_ + 1);
for (int j = 0; j < reg_max_ + 1; j++) {
dis += j * dis_after_sm[j];
}
dis *= stride;
dis_pred[i] = dis;
delete[] dis_after_sm;
}
float xmin = (std::max)(ct_x - dis_pred[0], .0f);
float ymin = (std::max)(ct_y - dis_pred[1], .0f);
float xmax = (std::min)(ct_x + dis_pred[2], (float)this->input_size_);
float ymax = (std::min)(ct_y + dis_pred[3], (float)this->input_size_);
return BoxInfo{xmin, ymin, xmax, ymax, score, label};
}
void PicoDet::nms(std::vector<BoxInfo>& input_boxes, float NMS_THRESH) {
std::sort(input_boxes.begin(), input_boxes.end(), [](BoxInfo a, BoxInfo b) {
return a.score > b.score;
});
std::vector<float> vArea(input_boxes.size());
for (int i = 0; i < int(input_boxes.size()); ++i) {
vArea[i] = (input_boxes.at(i).x2 - input_boxes.at(i).x1 + 1) *
(input_boxes.at(i).y2 - input_boxes.at(i).y1 + 1);
}
for (int i = 0; i < int(input_boxes.size()); ++i) {
for (int j = i + 1; j < int(input_boxes.size());) {
float xx1 = (std::max)(input_boxes[i].x1, input_boxes[j].x1);
float yy1 = (std::max)(input_boxes[i].y1, input_boxes[j].y1);
float xx2 = (std::min)(input_boxes[i].x2, input_boxes[j].x2);
float yy2 = (std::min)(input_boxes[i].y2, input_boxes[j].y2);
float w = (std::max)(float(0), xx2 - xx1 + 1);
float h = (std::max)(float(0), yy2 - yy1 + 1);
float inter = w * h;
float ovr = inter / (vArea[i] + vArea[j] - inter);
if (ovr >= NMS_THRESH) {
input_boxes.erase(input_boxes.begin() + j);
vArea.erase(vArea.begin() + j);
} else {
j++;
}
}
}
}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// reference from https://github.com/RangiLyu/nanodet/tree/main/demo_openvino
#ifndef _PICODET_OPENVINO_H_
#define _PICODET_OPENVINO_H_
#include <inference_engine.hpp>
#include <opencv2/core.hpp>
#include <string>
#define image_size 416
typedef struct HeadInfo {
std::string cls_layer;
std::string dis_layer;
int stride;
} HeadInfo;
typedef struct BoxInfo {
float x1;
float y1;
float x2;
float y2;
float score;
int label;
} BoxInfo;
class PicoDet {
public:
PicoDet(const char* param);
~PicoDet();
InferenceEngine::ExecutableNetwork network_;
InferenceEngine::InferRequest infer_request_;
std::vector<HeadInfo> heads_info_{
// cls_pred|dis_pred|stride
{"save_infer_model/scale_0.tmp_1", "save_infer_model/scale_4.tmp_1", 8},
{"save_infer_model/scale_1.tmp_1", "save_infer_model/scale_5.tmp_1", 16},
{"save_infer_model/scale_2.tmp_1", "save_infer_model/scale_6.tmp_1", 32},
{"save_infer_model/scale_3.tmp_1", "save_infer_model/scale_7.tmp_1", 64},
};
std::vector<BoxInfo> detect(cv::Mat image,
float score_threshold,
float nms_threshold);
private:
void preprocess(cv::Mat& image, InferenceEngine::Blob::Ptr& blob);
void decode_infer(const float*& cls_pred,
const float*& dis_pred,
int stride,
float threshold,
std::vector<std::vector<BoxInfo>>& results);
BoxInfo disPred2Bbox(
const float*& dfl_det, int label, float score, int x, int y, int stride);
static void nms(std::vector<BoxInfo>& result, float nms_threshold);
std::string input_name_;
int input_size_ = image_size;
int num_class_ = 80;
int reg_max_ = 7;
};
#endif
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册