Unverified commit 8d6ce4e0, authored by zhiboniu, committed by GitHub

update openvino codes (#7196)

Parent 256f779b
@@ -2,8 +2,12 @@
This folder provides TinyPose inference code using
[Intel's OpenVINO Toolkit](https://software.intel.com/content/www/us/en/develop/tools/openvino-toolkit.html). Most of the implementations in this folder are the same as *demo_ncnn*.
**Recommendations**
1. Use the xxx.tar.gz file to install instead of the GitHub method, [link](https://registrationcenter-download.intel.com/akdlm/irc_nas/18096/l_openvino_toolkit_p_2021.4.689.tgz).
2. You can also deploy OpenVINO with Docker (a run sketch follows the pull command):
```
docker pull openvino/ubuntu18_dev:2021.4.1
```
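A minimal sketch of working inside that container, assuming the repo is mounted at /workspace (the mount path is illustrative, not part of the original instructions):
```shell
# start an interactive container with the current checkout mounted
docker run -it --rm \
    -v "$PWD":/workspace \
    -w /workspace \
    openvino/ubuntu18_dev:2021.4.1 /bin/bash
# inside the container, load the toolkit environment before building:
# source /opt/intel/openvino_2021/bin/setupvars.sh
```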
## Install OpenVINO Toolkit
@@ -59,7 +63,30 @@ source /opt/intel/openvino_2021/bin/setupvars.sh
## Convert model
**1. Convert to ONNX**

Create picodet_m_416_coco.onnx and tinypose256.onnx.
Example (PicoDet; a TinyPose sketch follows the block):
```shell
modelName=picodet_m_416_coco
# export model
python tools/export_model.py \
-c configs/picodet/${modelName}.yml \
-o weights=${modelName}.pdparams \
--output_dir=inference_model
# convert to onnx
paddle2onnx --model_dir inference_model/${modelName} \
--model_filename model.pdmodel \
--params_filename model.pdiparams \
--opset_version 11 \
--save_file ${modelName}.onnx
# onnxsim
python -m onnxsim ${modelName}.onnx ${modelName}_sim.onnx
```
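TinyPose can be exported the same way. A sketch, assuming the config sits at configs/keypoint/tiny_pose/tinypose_256x192.yml in your PaddleDetection checkout and the weights are the Tinypose256 .pdparams linked in the Run demo section:
```shell
modelName=tinypose_256x192
# export model (config path assumed; adjust to your checkout)
python tools/export_model.py \
       -c configs/keypoint/tiny_pose/${modelName}.yml \
       -o weights=${modelName}.pdparams \
       --output_dir=inference_model
# convert to onnx
paddle2onnx --model_dir inference_model/${modelName} \
            --model_filename model.pdmodel \
            --params_filename model.pdiparams \
            --opset_version 11 \
            --save_file ${modelName}.onnx
# simplify
python -m onnxsim ${modelName}.onnx ${modelName}_sim.onnx
```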
**2. Convert to OpenVINO**
```shell
cd <INSTALL_DIR>/openvino_2021/deployment_tools/model_optimizer
```
@@ -75,9 +102,11 @@ source /opt/intel/openvino_2021/bin/setupvars.sh
Then convert the model. Note: mean_values and scale_values must match the preprocessing settings in your training YAML config file (a PicoDet variant is sketched after the block).
```shell
mo_onnx.py --input_model <ONNX_MODEL> --mean_values [103.53,116.28,123.675] --scale_values [57.375,57.12,58.395] --input_shape [1,3,256,192]
```
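The shape above is for TinyPose (256x192 input). For PicoDet the input shape differs; a sketch, assuming the 416x416 model exported earlier (verify mean/scale against your own config):
```shell
mo_onnx.py --input_model picodet_m_416_coco_sim.onnx \
           --mean_values [103.53,116.28,123.675] \
           --scale_values [57.375,57.12,58.395] \
           --input_shape [1,3,416,416]
```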
**Note: newer versions of the OpenVINO conversion tools may raise an error on the Resize op. If you hit this problem, try version openvino_2021.4.689.**
## Build
### Windows
@@ -101,11 +130,41 @@ make
## Run demo
Download the PicoDet OpenVINO model: [PicoDet openvino model download link](https://paddledet.bj.bcebos.com/deploy/third_engine/picodet_m_416_openvino.zip).

Download the TinyPose OpenVINO model: [TinyPose openvino model download link](https://bj.bcebos.com/v1/paddledet/deploy/third_engine/demo_openvino_kpts.tar.gz); the original PaddlePaddle model is [Tinypose256](https://bj.bcebos.com/v1/paddledet/models/keypoint/tinypose_enhance/tinypose_256x192.pdparams).

Move the PicoDet and TinyPose OpenVINO model files to the demo's weight folder (a download-and-move sketch follows).
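A minimal sketch of fetching both archives and collecting the IR files under ./weight/ (the extracted folder names are assumptions; check what the archives actually contain):
```shell
# download and unpack both model archives
wget https://paddledet.bj.bcebos.com/deploy/third_engine/picodet_m_416_openvino.zip
unzip picodet_m_416_openvino.zip
wget https://bj.bcebos.com/v1/paddledet/deploy/third_engine/demo_openvino_kpts.tar.gz
tar -xzf demo_openvino_kpts.tar.gz
# move the IR files (.xml/.bin) into the demo's weight folder
mkdir -p weight
mv <extracted_picodet_dir>/*.xml <extracted_picodet_dir>/*.bin weight/
mv <extracted_tinypose_dir>/*.xml <extracted_tinypose_dir>/*.bin weight/
```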
Notes:
1. The model output node names may change with newer versions of paddle/paddle2onnx/onnxsim/openvino; check your own model's output nodes if the code cannot find "conv2d_441.tmp_1" or "argmax_0.tmp_0" (see the sketch after this list).
2. If you encounter the error "Cannot find blob with name: transpose_1.tmp_0", your PicoDet model is an old version; modify the code below to fix it:
```
// picodet_openvino.h, lines 50-54
std::vector<HeadInfo> heads_info_{
// cls_pred|dis_pred|stride
{"transpose_0.tmp_0", "transpose_1.tmp_0", 8},
{"transpose_2.tmp_0", "transpose_3.tmp_0", 16},
{"transpose_4.tmp_0", "transpose_5.tmp_0", 32},
{"transpose_6.tmp_0", "transpose_7.tmp_0", 64},
};
modify to:
std::vector<HeadInfo> heads_info_{
// cls_pred|dis_pred|stride
{"save_infer_model/scale_0.tmp_1", "save_infer_model/scale_4.tmp_1", 8},
{"save_infer_model/scale_1.tmp_1", "save_infer_model/scale_5.tmp_1", 16},
{"save_infer_model/scale_2.tmp_1", "save_infer_model/scale_6.tmp_1", 32},
{"save_infer_model/scale_3.tmp_1", "save_infer_model/scale_7.tmp_1", 64},
};
```
3. You can inspect your ONNX model with [Netron](https://netron.app/).
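For a converted IR you can also grep the XML for candidate output names directly (a quick-and-dirty check, assuming the IR v10 layout where each network output is a Result layer fed by the named node; Netron is more reliable):
```shell
# the names just above each Result layer are the blob names GetBlob() expects
grep -B 3 'type="Result"' weight/<MODEL>.xml | grep 'name='
```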
### Edit file
```
step1:
......
@@ -74,7 +74,7 @@ cv::Mat VisualizeKptsResult(const cv::Mat& img,
void KeyPointDetector::Postprocess(std::vector<float>& output,
                                   std::vector<uint64_t>& output_shape,
                                   std::vector<float>& idxout,
                                   std::vector<uint64_t>& idx_shape,
                                   std::vector<KeyPointResult>* result,
                                   std::vector<std::vector<float>>& center_bs,
@@ -141,7 +141,7 @@ void KeyPointDetector::Predict(const std::vector<cv::Mat> imgs,
  infer_request_.Infer();
  InferenceEngine::Blob::Ptr output_blob =
      infer_request_.GetBlob("conv2d_441.tmp_1");
  auto output_shape = output_blob->getTensorDesc().getDims();
  InferenceEngine::MemoryBlob::Ptr moutput =
      InferenceEngine::as<InferenceEngine::MemoryBlob>(output_blob);
@@ -159,12 +159,15 @@ void KeyPointDetector::Predict(const std::vector<cv::Mat> imgs,
    for (int j = 0; j < output_shape.size(); ++j) {
      output_size *= output_shape[j];
    }
    output_data_.resize(output_size);
    std::copy_n(data, output_size, output_data_.data());
  }
  InferenceEngine::Blob::Ptr output_blob2 =
      infer_request_.GetBlob("argmax_0.tmp_0");
  auto idx_shape = output_blob2->getTensorDesc().getDims();
  InferenceEngine::MemoryBlob::Ptr moutput2 =
      InferenceEngine::as<InferenceEngine::MemoryBlob>(output_blob2);
@@ -175,7 +178,7 @@ void KeyPointDetector::Predict(const std::vector<cv::Mat> imgs,
    auto minputHolder = moutput2->rmap();
    // Original I64 precision was converted to FP32
    auto data = minputHolder.as<const InferenceEngine::PrecisionTrait<
        InferenceEngine::Precision::FP32>::value_type*>();
    // Calculate output length
    int output_size = 1;
......
@@ -69,7 +69,7 @@ class KeyPointDetector {
      if (idx == 0) {
        output_info.second->setPrecision(InferenceEngine::Precision::FP32);
      } else {
        // the argmax index output is now also read as FP32 (was I32)
        output_info.second->setPrecision(InferenceEngine::Precision::FP32);
      }
      idx++;
    }
@@ -99,14 +99,14 @@ class KeyPointDetector {
  // Postprocess result
  void Postprocess(std::vector<float>& output,
                   std::vector<uint64_t>& output_shape,
                   std::vector<float>& idxout,
                   std::vector<uint64_t>& idx_shape,
                   std::vector<KeyPointResult>* result,
                   std::vector<std::vector<float>>& center,
                   std::vector<std::vector<float>>& scale);

  std::vector<float> output_data_;
  std::vector<float> idx_data_;
  float threshold_;
  bool use_dark_;
......
@@ -74,11 +74,26 @@ void transform_preds(std::vector<float>& coords,
                     std::vector<float>& scale,
                     std::vector<int>& output_size,
                     std::vector<uint64_t>& dim,
                     std::vector<float>& target_coords,
                     bool affine = false) {
  if (affine) {
    cv::Mat trans(2, 3, CV_64FC1);
    get_affine_transform(center, scale, 0, output_size, trans, 1);
    for (int p = 0; p < dim[1]; ++p) {
      affine_tranform(
          coords[p * 2], coords[p * 2 + 1], trans, target_coords, p);
    }
  } else {
    // plain rescale: map heatmap coords back into the original detection box
    float heat_w = static_cast<float>(output_size[0]);
    float heat_h = static_cast<float>(output_size[1]);
    float x_scale = scale[0] / heat_w;
    float y_scale = scale[1] / heat_h;
    float offset_x = center[0] - scale[0] / 2.;
    float offset_y = center[1] - scale[1] / 2.;
    for (int i = 0; i < dim[1]; i++) {
      // target_coords layout per joint: [score, x, y]
      target_coords[i * 3 + 1] = x_scale * coords[i * 2] + offset_x;
      target_coords[i * 3 + 2] = y_scale * coords[i * 2 + 1] + offset_y;
    }
  }
}
@@ -172,7 +187,7 @@ void dark_parse(std::vector<float>& heatmap,
void get_final_preds(std::vector<float>& heatmap,
                     std::vector<uint64_t>& dim,
                     std::vector<float>& idxout,
                     std::vector<uint64_t>& idxdim,
                     std::vector<float>& center,
                     std::vector<float> scale,
@@ -187,7 +202,7 @@ void get_final_preds(std::vector<float>& heatmap,
  for (int j = 0; j < dim[1]; ++j) {
    int index = (batchid * dim[1] + j) * dim[2] * dim[3];
    // the argmax output now arrives as FP32, so cast the flat index back to int
    int idx = int(idxout[batchid * dim[1] + j]);
    preds[j * 3] = heatmap[index + idx];
    coords[j * 2] = idx % heatmap_width;
    coords[j * 2 + 1] = idx / heatmap_width;
......
@@ -37,7 +37,8 @@ void transform_preds(std::vector<float>& coords,
                     std::vector<float>& scale,
                     std::vector<uint64_t>& output_size,
                     std::vector<int>& dim,
                     std::vector<float>& target_coords,
                     bool affine);
void box_to_center_scale(std::vector<int>& box,
                         int width,
                         int height,
@@ -51,7 +52,7 @@ void get_max_preds(std::vector<float>& heatmap,
                   int joint_idx);
void get_final_preds(std::vector<float>& heatmap,
                     std::vector<uint64_t>& dim,
                     std::vector<float>& idxout,
                     std::vector<uint64_t>& idxdim,
                     std::vector<float>& center,
                     std::vector<float> scale,
......
@@ -375,9 +375,9 @@ int main(int argc, char** argv) {
    return -1;
  }
  std::cout << "start init model" << std::endl;
  auto detector = PicoDet("./weight/picodet_m_416.xml");
  auto kpts_detector =
      new KeyPointDetector("./weight/tinypose256_git2-sim.xml", 256, 192);
  std::cout << "success" << std::endl;
  int mode = atoi(argv[1]);
......