未验证 提交 e226a78d 编写于 作者: Z zhiboniu 提交者: GitHub

Paddle-Infer cpp deploy support keypoint (#4211)

* add keypoint top-down cpp infer; total:5ms,pre:0.7ms,infer:3.8ms,post:0.01ms

* add dark_pose support
上级 9c655bad
......@@ -139,3 +139,4 @@ TestReader:
is_scale: true
- Permute: {}
batch_size: 1
fuse_normalize: false #whether to fuse nomalize layer into model while export model
......@@ -5,6 +5,7 @@ option(WITH_MKL "Compile demo with MKL/OpenBlas support,defaultuseMKL."
option(WITH_GPU "Compile demo with GPU/CPU, default use CPU." ON)
option(WITH_TENSORRT "Compile demo with TensorRT." OFF)
option(WITH_KEYPOINT "Whether to Compile KeyPoint detector" ON)
SET(PADDLE_DIR "" CACHE PATH "Location of libraries")
SET(PADDLE_LIB_NAME "" CACHE STRING "libpaddle_inference")
......@@ -20,6 +21,12 @@ include_directories("${CMAKE_SOURCE_DIR}/")
include_directories("${CMAKE_CURRENT_BINARY_DIR}/ext/yaml-cpp/src/ext-yaml-cpp/include")
link_directories("${CMAKE_CURRENT_BINARY_DIR}/ext/yaml-cpp/lib")
if (WITH_KEYPOINT)
set(SRCS src/main_keypoint.cc src/preprocess_op.cc src/object_detector.cc src/keypoint_detector.cc src/keypoint_postprocess.cc)
else ()
set(SRCS src/main.cc src/preprocess_op.cc src/object_detector.cc)
endif()
macro(safe_set_static_flag)
foreach(flag_var
CMAKE_CXX_FLAGS CMAKE_CXX_FLAGS_DEBUG CMAKE_CXX_FLAGS_RELEASE
......@@ -37,7 +44,7 @@ endif()
if (NOT DEFINED PADDLE_DIR OR ${PADDLE_DIR} STREQUAL "")
message(FATAL_ERROR "please set PADDLE_DIR with -DPADDLE_DIR=/path/paddle_influence_dir")
endif()
message("PADDLE_DIR IS:"${PADDLE_DIR})
message("PADDLE_DIR IS:" ${PADDLE_DIR})
if (NOT DEFINED OPENCV_DIR OR ${OPENCV_DIR} STREQUAL "")
message(FATAL_ERROR "please set OPENCV_DIR with -DOPENCV_DIR=/path/opencv")
......@@ -217,7 +224,7 @@ if (NOT WIN32)
endif()
set(DEPS ${DEPS} ${OpenCV_LIBS})
add_executable(main src/main.cc src/preprocess_op.cc src/object_detector.cc)
add_executable(main ${SRCS})
ADD_DEPENDENCIES(main ext-yaml-cpp)
message("DEPS:" $DEPS)
target_link_libraries(main ${DEPS})
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <ctime>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#include "paddle_inference_api.h" // NOLINT
#include "include/config_parser.h"
#include "include/keypoint_postprocess.h"
#include "include/preprocess_op.h"
using namespace paddle_infer;
namespace PaddleDetection {
// Object KeyPoint Result
struct KeyPointResult {
// Keypoints: shape(N x 3); N: number of Joints; 3: x,y,conf
std::vector<float> keypoints;
int num_joints = -1;
};
// Visualiztion KeyPoint Result
cv::Mat VisualizeKptsResult(const cv::Mat& img,
const std::vector<KeyPointResult>& results,
const std::vector<int>& colormap);
class KeyPointDetector {
public:
explicit KeyPointDetector(const std::string& model_dir,
const std::string& device = "CPU",
bool use_mkldnn = false,
int cpu_threads = 1,
const std::string& run_mode = "fluid",
const int batch_size = 1,
const int gpu_id = 0,
const int trt_min_shape = 1,
const int trt_max_shape = 1280,
const int trt_opt_shape = 640,
bool trt_calib_mode = false,
bool use_dark = true) {
this->device_ = device;
this->gpu_id_ = gpu_id;
this->cpu_math_library_num_threads_ = cpu_threads;
this->use_mkldnn_ = use_mkldnn;
this->use_dark = use_dark;
this->trt_min_shape_ = trt_min_shape;
this->trt_max_shape_ = trt_max_shape;
this->trt_opt_shape_ = trt_opt_shape;
this->trt_calib_mode_ = trt_calib_mode;
config_.load_config(model_dir);
this->use_dynamic_shape_ = config_.use_dynamic_shape_;
this->min_subgraph_size_ = config_.min_subgraph_size_;
threshold_ = config_.draw_threshold_;
preprocessor_.Init(config_.preprocess_info_);
LoadModel(model_dir, batch_size, run_mode);
}
// Load Paddle inference model
void LoadModel(const std::string& model_dir,
const int batch_size = 1,
const std::string& run_mode = "fluid");
// Run predictor
void Predict(const std::vector<cv::Mat> imgs,
std::vector<std::vector<float>>& center,
std::vector<std::vector<float>>& scale,
const double threshold = 0.5,
const int warmup = 0,
const int repeats = 1,
std::vector<KeyPointResult>* result = nullptr,
std::vector<double>* times = nullptr);
// Get Model Label list
const std::vector<std::string>& GetLabelList() const {
return config_.label_list_;
}
private:
std::string device_ = "CPU";
int gpu_id_ = 0;
int cpu_math_library_num_threads_ = 1;
bool use_dark = true;
bool use_mkldnn_ = false;
int min_subgraph_size_ = 3;
bool use_dynamic_shape_ = false;
int trt_min_shape_ = 1;
int trt_max_shape_ = 1280;
int trt_opt_shape_ = 640;
bool trt_calib_mode_ = false;
// Preprocess image and copy data to input buffer
void Preprocess(const cv::Mat& image_mat);
// Postprocess result
void Postprocess(std::vector<float>& output,
std::vector<int> output_shape,
std::vector<int64_t>& idxout,
std::vector<int> idx_shape,
std::vector<KeyPointResult>* result,
std::vector<std::vector<float>>& center,
std::vector<std::vector<float>>& scale);
std::shared_ptr<Predictor> predictor_;
Preprocessor preprocessor_;
ImageBlob inputs_;
std::vector<float> output_data_;
std::vector<int64_t> idx_data_;
float threshold_;
ConfigPaser config_;
};
} // namespace PaddleDetection
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#include <vector>
std::vector<float> get_3rd_point(std::vector<float>& a, std::vector<float>& b);
std::vector<float> get_dir(float src_point_x, float src_point_y, float rot_rad);
void affine_tranform(
float pt_x, float pt_y, cv::Mat& trans, std::vector<float>& preds, int p);
cv::Mat get_affine_transform(std::vector<float>& center,
std::vector<float>& scale,
float rot,
std::vector<int>& output_size,
int inv);
void transform_preds(std::vector<float>& coords,
std::vector<float>& center,
std::vector<float>& scale,
std::vector<int>& output_size,
std::vector<int>& dim,
std::vector<float>& target_coords);
void box_to_center_scale(std::vector<int>& box,
int width,
int height,
std::vector<float>& center,
std::vector<float>& scale);
void get_max_preds(float* heatmap,
std::vector<int>& dim,
std::vector<float>& preds,
float* maxvals,
int batchid,
int joint_idx);
void get_final_preds(std::vector<float>& heatmap,
std::vector<int>& dim,
std::vector<int64_t>& idxout,
std::vector<int>& idxdim,
std::vector<float>& center,
std::vector<float> scale,
std::vector<float>& preds,
int batchid,
bool DARK = true);
......@@ -116,6 +116,21 @@ class PadStride : public PreprocessOp {
int stride_;
};
class TopDownEvalAffine : public PreprocessOp {
public:
virtual void Init(const YAML::Node& item) {
trainsize_ = item["trainsize"].as<std::vector<int>>();
}
virtual void Run(cv::Mat* im, ImageBlob* data);
private:
int interp_ = 1;
std::vector<int> trainsize_;
};
void CropImg(cv::Mat &img, cv::Mat &crop_img, std::vector<int> &area, std::vector<float> &center, std::vector<float> &scale, float expandratio=0.15);
class Preprocessor {
public:
void Init(const YAML::Node& config_node) {
......@@ -139,6 +154,8 @@ class Preprocessor {
} else if (name == "PadStride") {
// use PadStride instead of PadBatch
return std::make_shared<PadStride>();
} else if (name == "TopDownEvalAffine") {
return std::make_shared<TopDownEvalAffine>();
}
std::cerr << "can not find function of OP: " << name << " and return: nullptr" << std::endl;
return nullptr;
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <sstream>
// for setprecision
#include <chrono>
#include <iomanip>
#include "include/keypoint_detector.h"
using namespace paddle_infer;
namespace PaddleDetection {
// Load Model and create model predictor
void KeyPointDetector::LoadModel(const std::string& model_dir,
const int batch_size,
const std::string& run_mode) {
paddle_infer::Config config;
std::string prog_file = model_dir + OS_PATH_SEP + "model.pdmodel";
std::string params_file = model_dir + OS_PATH_SEP + "model.pdiparams";
config.SetModel(prog_file, params_file);
if (this->device_ == "GPU") {
config.EnableUseGpu(200, this->gpu_id_);
config.SwitchIrOptim(true);
// use tensorrt
if (run_mode != "fluid") {
auto precision = paddle_infer::Config::Precision::kFloat32;
if (run_mode == "trt_fp32") {
precision = paddle_infer::Config::Precision::kFloat32;
} else if (run_mode == "trt_fp16") {
precision = paddle_infer::Config::Precision::kHalf;
} else if (run_mode == "trt_int8") {
precision = paddle_infer::Config::Precision::kInt8;
} else {
printf(
"run_mode should be 'fluid', 'trt_fp32', 'trt_fp16' or 'trt_int8'");
}
// set tensorrt
config.EnableTensorRtEngine(1 << 30,
batch_size,
this->min_subgraph_size_,
precision,
false,
this->trt_calib_mode_);
// set use dynamic shape
if (this->use_dynamic_shape_) {
// set DynamicShsape for image tensor
const std::vector<int> min_input_shape = {
1, 3, this->trt_min_shape_, this->trt_min_shape_};
const std::vector<int> max_input_shape = {
1, 3, this->trt_max_shape_, this->trt_max_shape_};
const std::vector<int> opt_input_shape = {
1, 3, this->trt_opt_shape_, this->trt_opt_shape_};
const std::map<std::string, std::vector<int>> map_min_input_shape = {
{"image", min_input_shape}};
const std::map<std::string, std::vector<int>> map_max_input_shape = {
{"image", max_input_shape}};
const std::map<std::string, std::vector<int>> map_opt_input_shape = {
{"image", opt_input_shape}};
config.SetTRTDynamicShapeInfo(
map_min_input_shape, map_max_input_shape, map_opt_input_shape);
std::cout << "TensorRT dynamic shape enabled" << std::endl;
}
}
} else if (this->device_ == "XPU") {
config.EnableXpu(10 * 1024 * 1024);
} else {
config.DisableGpu();
if (this->use_mkldnn_) {
config.EnableMKLDNN();
// cache 10 different shapes for mkldnn to avoid memory leak
config.SetMkldnnCacheCapacity(10);
}
config.SetCpuMathLibraryNumThreads(this->cpu_math_library_num_threads_);
}
config.SwitchUseFeedFetchOps(false);
config.SwitchIrOptim(true);
config.DisableGlogInfo();
// Memory optimization
config.EnableMemoryOptim();
predictor_ = std::move(CreatePredictor(config));
}
// Visualiztion MaskDetector results
cv::Mat VisualizeKptsResult(const cv::Mat& img,
const std::vector<KeyPointResult>& results,
const std::vector<int>& colormap) {
const int edge[][2] = {{0, 1},
{0, 2},
{1, 3},
{2, 4},
{3, 5},
{4, 6},
{5, 7},
{6, 8},
{7, 9},
{8, 10},
{5, 11},
{6, 12},
{11, 13},
{12, 14},
{13, 15},
{14, 16},
{11, 12}};
cv::Mat vis_img = img.clone();
for (int batchid = 0; batchid < results.size(); batchid++) {
for (int i = 0; i < results[batchid].num_joints; i++) {
if (results[batchid].keypoints[i * 3] > 0.5) {
int x_coord = int(results[batchid].keypoints[i * 3 + 1]);
int y_coord = int(results[batchid].keypoints[i * 3 + 2]);
cv::circle(vis_img,
cv::Point2d(x_coord, y_coord),
1,
cv::Scalar(0, 0, 255),
2);
}
}
for (int i = 0; i < results[batchid].num_joints; i++) {
int x_start = int(results[batchid].keypoints[edge[i][0] * 3 + 1]);
int y_start = int(results[batchid].keypoints[edge[i][0] * 3 + 2]);
int x_end = int(results[batchid].keypoints[edge[i][1] * 3 + 1]);
int y_end = int(results[batchid].keypoints[edge[i][1] * 3 + 2]);
cv::line(vis_img,
cv::Point2d(x_start, y_start),
cv::Point2d(x_end, y_end),
colormap[i],
1);
}
}
return vis_img;
}
void KeyPointDetector::Preprocess(const cv::Mat& ori_im) {
// Clone the image : keep the original mat for postprocess
cv::Mat im = ori_im.clone();
cv::cvtColor(im, im, cv::COLOR_BGR2RGB);
preprocessor_.Run(&im, &inputs_);
}
void KeyPointDetector::Postprocess(std::vector<float>& output,
std::vector<int> output_shape,
std::vector<int64_t>& idxout,
std::vector<int> idx_shape,
std::vector<KeyPointResult>* result,
std::vector<std::vector<float>>& center_bs,
std::vector<std::vector<float>>& scale_bs) {
std::vector<float> preds(output_shape[1] * 3, 0);
for (int batchid = 0; batchid < output_shape[0]; batchid++) {
get_final_preds(output,
output_shape,
idxout,
idx_shape,
center_bs[batchid],
scale_bs[batchid],
preds,
batchid,
this->use_dark);
KeyPointResult result_item;
result_item.num_joints = output_shape[1];
result_item.keypoints.clear();
for (int i = 0; i < output_shape[1]; i++) {
result_item.keypoints.emplace_back(preds[i * 3]);
result_item.keypoints.emplace_back(preds[i * 3 + 1]);
result_item.keypoints.emplace_back(preds[i * 3 + 2]);
}
result->push_back(result_item);
}
}
void KeyPointDetector::Predict(const std::vector<cv::Mat> imgs,
std::vector<std::vector<float>>& center_bs,
std::vector<std::vector<float>>& scale_bs,
const double threshold,
const int warmup,
const int repeats,
std::vector<KeyPointResult>* result,
std::vector<double>* times) {
auto preprocess_start = std::chrono::steady_clock::now();
int batch_size = imgs.size();
// in_data_batch
std::vector<float> in_data_all;
std::vector<float> im_shape_all(batch_size * 2);
std::vector<float> scale_factor_all(batch_size * 2);
// Preprocess image
for (int bs_idx = 0; bs_idx < batch_size; bs_idx++) {
cv::Mat im = imgs.at(bs_idx);
Preprocess(im);
im_shape_all[bs_idx * 2] = inputs_.im_shape_[0];
im_shape_all[bs_idx * 2 + 1] = inputs_.im_shape_[1];
scale_factor_all[bs_idx * 2] = inputs_.scale_factor_[0];
scale_factor_all[bs_idx * 2 + 1] = inputs_.scale_factor_[1];
// TODO: reduce cost time
in_data_all.insert(
in_data_all.end(), inputs_.im_data_.begin(), inputs_.im_data_.end());
}
// Prepare input tensor
auto input_names = predictor_->GetInputNames();
for (const auto& tensor_name : input_names) {
auto in_tensor = predictor_->GetInputHandle(tensor_name);
if (tensor_name == "image") {
int rh = inputs_.in_net_shape_[0];
int rw = inputs_.in_net_shape_[1];
in_tensor->Reshape({batch_size, 3, rh, rw});
in_tensor->CopyFromCpu(in_data_all.data());
} else if (tensor_name == "im_shape") {
in_tensor->Reshape({batch_size, 2});
in_tensor->CopyFromCpu(im_shape_all.data());
} else if (tensor_name == "scale_factor") {
in_tensor->Reshape({batch_size, 2});
in_tensor->CopyFromCpu(scale_factor_all.data());
}
}
auto preprocess_end = std::chrono::steady_clock::now();
std::vector<int> output_shape, idx_shape;
// Run predictor
// warmup
for (int i = 0; i < warmup; i++) {
predictor_->Run();
// Get output tensor
auto output_names = predictor_->GetOutputNames();
auto out_tensor = predictor_->GetOutputHandle(output_names[0]);
output_shape = out_tensor->shape();
// Calculate output length
int output_size = 1;
for (int j = 0; j < output_shape.size(); ++j) {
output_size *= output_shape[j];
}
output_data_.resize(output_size);
out_tensor->CopyToCpu(output_data_.data());
auto idx_tensor = predictor_->GetOutputHandle(output_names[1]);
idx_shape = idx_tensor->shape();
// Calculate output length
output_size = 1;
for (int j = 0; j < idx_shape.size(); ++j) {
output_size *= idx_shape[j];
}
idx_data_.resize(output_size);
idx_tensor->CopyToCpu(idx_data_.data());
}
auto inference_start = std::chrono::steady_clock::now();
for (int i = 0; i < repeats; i++) {
predictor_->Run();
// Get output tensor
auto output_names = predictor_->GetOutputNames();
auto out_tensor = predictor_->GetOutputHandle(output_names[0]);
output_shape = out_tensor->shape();
// Calculate output length
int output_size = 1;
for (int j = 0; j < output_shape.size(); ++j) {
output_size *= output_shape[j];
}
if (output_size < 6) {
std::cerr << "[WARNING] No object detected." << std::endl;
}
output_data_.resize(output_size);
out_tensor->CopyToCpu(output_data_.data());
auto idx_tensor = predictor_->GetOutputHandle(output_names[1]);
idx_shape = idx_tensor->shape();
// Calculate output length
output_size = 1;
for (int j = 0; j < idx_shape.size(); ++j) {
output_size *= idx_shape[j];
}
idx_data_.resize(output_size);
idx_tensor->CopyToCpu(idx_data_.data());
}
auto inference_end = std::chrono::steady_clock::now();
auto postprocess_start = std::chrono::steady_clock::now();
// Postprocessing result
Postprocess(output_data_,
output_shape,
idx_data_,
idx_shape,
result,
center_bs,
scale_bs);
auto postprocess_end = std::chrono::steady_clock::now();
std::chrono::duration<float> preprocess_diff =
preprocess_end - preprocess_start;
times->push_back(double(preprocess_diff.count() * 1000));
std::chrono::duration<float> inference_diff = inference_end - inference_start;
times->push_back(double(inference_diff.count() / repeats * 1000));
std::chrono::duration<float> postprocess_diff =
postprocess_end - postprocess_start;
times->push_back(double(postprocess_diff.count() * 1000));
}
} // namespace PaddleDetection
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <math.h>
#include "include/keypoint_postprocess.h"
#define PI 3.1415926535
#define HALF_CIRCLE_DEGREE 180
cv::Point2f get_3rd_point(cv::Point2f& a, cv::Point2f& b) {
cv::Point2f direct{a.x - b.x, a.y - b.y};
return cv::Point2f(a.x - direct.y, a.y + direct.x);
}
std::vector<float> get_dir(float src_point_x,
float src_point_y,
float rot_rad) {
float sn = sin(rot_rad);
float cs = cos(rot_rad);
std::vector<float> src_result{0.0, 0.0};
src_result[0] = src_point_x * cs - src_point_y * sn;
src_result[1] = src_point_x * sn + src_point_y * cs;
return src_result;
}
void affine_tranform(
float pt_x, float pt_y, cv::Mat& trans, std::vector<float>& preds, int p) {
double new1[3] = {pt_x, pt_y, 1.0};
cv::Mat new_pt(3, 1, trans.type(), new1);
cv::Mat w = trans * new_pt;
preds[p * 3 + 1] = static_cast<float>(w.at<double>(0, 0));
preds[p * 3 + 2] = static_cast<float>(w.at<double>(1, 0));
}
void get_affine_transform(std::vector<float>& center,
std::vector<float>& scale,
float rot,
std::vector<int>& output_size,
cv::Mat& trans,
int inv) {
float src_w = scale[0];
float dst_w = static_cast<float>(output_size[0]);
float dst_h = static_cast<float>(output_size[1]);
float rot_rad = rot * PI / HALF_CIRCLE_DEGREE;
std::vector<float> src_dir = get_dir(-0.5 * src_w, 0, rot_rad);
std::vector<float> dst_dir{-0.5 * dst_w, 0.0};
cv::Point2f srcPoint2f[3], dstPoint2f[3];
srcPoint2f[0] = cv::Point2f(center[0], center[1]);
srcPoint2f[1] = cv::Point2f(center[0] + src_dir[0], center[1] + src_dir[1]);
srcPoint2f[2] = get_3rd_point(srcPoint2f[0], srcPoint2f[1]);
dstPoint2f[0] = cv::Point2f(dst_w * 0.5, dst_h * 0.5);
dstPoint2f[1] =
cv::Point2f(dst_w * 0.5 + dst_dir[0], dst_h * 0.5 + dst_dir[1]);
dstPoint2f[2] = get_3rd_point(dstPoint2f[0], dstPoint2f[1]);
if (inv == 0) {
trans = cv::getAffineTransform(srcPoint2f, dstPoint2f);
} else {
trans = cv::getAffineTransform(dstPoint2f, srcPoint2f);
}
}
void transform_preds(std::vector<float>& coords,
std::vector<float>& center,
std::vector<float>& scale,
std::vector<int>& output_size,
std::vector<int>& dim,
std::vector<float>& target_coords) {
cv::Mat trans(2, 3, CV_64FC1);
get_affine_transform(center, scale, 0, output_size, trans, 1);
for (int p = 0; p < dim[1]; ++p) {
affine_tranform(coords[p * 2], coords[p * 2 + 1], trans, target_coords, p);
}
}
// only for batchsize == 1
void get_max_preds(float* heatmap,
std::vector<int>& dim,
std::vector<float>& preds,
float* maxvals,
int batchid,
int joint_idx) {
int num_joints = dim[1];
int width = dim[3];
std::vector<int> idx;
idx.resize(num_joints * 2);
for (int j = 0; j < dim[1]; j++) {
float* index = &(
heatmap[batchid * num_joints * dim[2] * dim[3] + j * dim[2] * dim[3]]);
float* end = index + dim[2] * dim[3];
float* max_dis = std::max_element(index, end);
auto max_id = std::distance(index, max_dis);
maxvals[j] = *max_dis;
if (*max_dis > 0) {
preds[j * 2] = static_cast<float>(max_id % width);
preds[j * 2 + 1] = static_cast<float>(max_id / width);
}
}
}
void dark_parse(std::vector<float>& heatmap,
std::vector<int>& dim,
std::vector<float>& coords,
int px,
int py,
int index,
int ch){
/*DARK postpocessing, Zhang et al. Distribution-Aware Coordinate
Representation for Human Pose Estimation (CVPR 2020).
1) offset = - hassian.inv() * derivative
2) dx = (heatmap[x+1] - heatmap[x-1])/2.
3) dxx = (dx[x+1] - dx[x-1])/2.
4) derivative = Mat([dx, dy])
5) hassian = Mat([[dxx, dxy], [dxy, dyy]])
*/
std::vector<float>::const_iterator first1 = heatmap.begin() + index;
std::vector<float>::const_iterator last1 = heatmap.begin() + index + dim[2]*dim[3];
std::vector<float> heatmap_ch(first1, last1);
cv::Mat heatmap_mat{heatmap_ch};
heatmap_mat.resize(dim[2],dim[3]);
cv::GaussianBlur(heatmap_mat, heatmap_mat, cv::Size(3,3), 0, 0);
heatmap_ch.assign(heatmap_mat.datastart, heatmap_mat.dataend);
float epsilon = 1e-10;
//sample heatmap to get values in around target location
float xy = log(fmax(heatmap_ch[py * dim[3] + px], epsilon));
float xr = log(fmax(heatmap_ch[py * dim[3] + px + 1], epsilon));
float xl = log(fmax(heatmap_ch[py * dim[3] + px - 1], epsilon));
float xr2 = log(fmax(heatmap_ch[py * dim[3] + px + 2], epsilon));
float xl2 = log(fmax(heatmap_ch[py * dim[3] + px - 2], epsilon));
float yu = log(fmax(heatmap_ch[(py + 1) * dim[3] + px], epsilon));
float yd = log(fmax(heatmap_ch[(py - 1) * dim[3] + px], epsilon));
float yu2 = log(fmax(heatmap_ch[(py + 2) * dim[3] + px], epsilon));
float yd2 = log(fmax(heatmap_ch[(py - 2) * dim[3] + px], epsilon));
float xryu = log(fmax(heatmap_ch[(py + 1) * dim[3] + px + 1], epsilon));
float xryd = log(fmax(heatmap_ch[(py - 1) * dim[3] + px + 1], epsilon));
float xlyu = log(fmax(heatmap_ch[(py + 1) * dim[3] + px - 1], epsilon));
float xlyd = log(fmax(heatmap_ch[(py - 1) * dim[3] + px - 1], epsilon));
//compute dx/dy and dxx/dyy with sampled values
float dx = 0.5 * (xr - xl);
float dy = 0.5 * (yu - yd);
float dxx = 0.25 * (xr2 - 2*xy + xl2);
float dxy = 0.25 * (xryu - xryd - xlyu + xlyd);
float dyy = 0.25 * (yu2 - 2*xy + yd2);
//finally get offset by derivative and hassian, which combined by dx/dy and dxx/dyy
if(dxx * dyy - dxy*dxy != 0){
float M[2][2] = {dxx, dxy, dxy, dyy};
float D[2] = {dx, dy};
cv::Mat hassian(2,2,CV_32F,M);
cv::Mat derivative(2,1,CV_32F,D);
cv::Mat offset = - hassian.inv() * derivative;
coords[ch * 2] += offset.at<float>(0,0);
coords[ch * 2 + 1] += offset.at<float>(1,0);
}
}
void get_final_preds(std::vector<float>& heatmap,
std::vector<int>& dim,
std::vector<int64_t>& idxout,
std::vector<int>& idxdim,
std::vector<float>& center,
std::vector<float> scale,
std::vector<float>& preds,
int batchid,
bool DARK) {
std::vector<float> coords;
coords.resize(dim[1] * 2);
int heatmap_height = dim[2];
int heatmap_width = dim[3];
for (int j = 0; j < dim[1]; ++j) {
int index = (batchid * dim[1] + j) * dim[2] * dim[3];
int idx = idxout[batchid * dim[1] + j];
preds[j * 3] = heatmap[index + idx];
coords[j * 2] = idx % heatmap_width;
coords[j * 2 + 1] = idx / heatmap_width;
int px = int(coords[j * 2] + 0.5);
int py = int(coords[j * 2 + 1] + 0.5);
if(DARK && px > 1 && px < heatmap_width - 2){
dark_parse(heatmap, dim, coords, px, py, index, j);
}
else{
if (px > 0 && px < heatmap_width - 1) {
float diff_x = heatmap[index + py * dim[3] + px + 1] -
heatmap[index + py * dim[3] + px - 1];
coords[j * 2] += diff_x > 0 ? 1 : -1 * 0.25;
}
if (py > 0 && py < heatmap_height - 1) {
float diff_y = heatmap[index + (py + 1) * dim[3] + px] -
heatmap[index + (py - 1) * dim[3] + px];
coords[j * 2 + 1] += diff_y > 0 ? 1 : -1 * 0.25;
}
}
}
std::vector<int> img_size{heatmap_width, heatmap_height};
transform_preds(coords, center, scale, img_size, dim, preds);
}
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <glog/logging.h>
#include <iostream>
#include <string>
#include <vector>
#include <numeric>
#include <sys/types.h>
#include <sys/stat.h>
#include <math.h>
#include <algorithm>
#ifdef _WIN32
#include <direct.h>
#include <io.h>
#elif LINUX
#include <stdarg.h>
#include <sys/stat.h>
#endif
#include "include/object_detector.h"
#include "include/keypoint_detector.h"
#include "include/preprocess_op.h"
#include <gflags/gflags.h>
DEFINE_string(model_dir, "", "Path of object detector inference model");
DEFINE_string(model_dir_keypoint, "", "Path of keypoint detector inference model");
DEFINE_string(image_file, "", "Path of input image");
DEFINE_string(image_dir, "", "Dir of input image, `image_file` has a higher priority.");
DEFINE_int32(batch_size, 1, "batch_size of object detector");
DEFINE_int32(batch_size_keypoint, 8, "batch_size of keypoint detector");
DEFINE_string(video_file, "", "Path of input video, `video_file` or `camera_id` has a highest priority.");
DEFINE_int32(camera_id, -1, "Device id of camera to predict");
DEFINE_bool(use_gpu, false, "Deprecated, please use `--device` to set the device you want to run.");
DEFINE_string(device, "CPU", "Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU.");
DEFINE_double(threshold, 0.5, "Threshold of score.");
DEFINE_double(threshold_keypoint, 0.5, "Threshold of score.");
DEFINE_string(output_dir, "output", "Directory of output visualization files.");
DEFINE_string(run_mode, "fluid", "Mode of running(fluid/trt_fp32/trt_fp16/trt_int8)");
DEFINE_int32(gpu_id, 0, "Device id of GPU to execute");
DEFINE_bool(run_benchmark, false, "Whether to predict a image_file repeatedly for benchmark");
DEFINE_bool(use_mkldnn, false, "Whether use mkldnn with CPU");
DEFINE_int32(cpu_threads, 1, "Num of threads with CPU");
DEFINE_int32(trt_min_shape, 1, "Min shape of TRT DynamicShapeI");
DEFINE_int32(trt_max_shape, 1280, "Max shape of TRT DynamicShapeI");
DEFINE_int32(trt_opt_shape, 640, "Opt shape of TRT DynamicShapeI");
DEFINE_bool(trt_calib_mode, false, "If the model is produced by TRT offline quantitative calibration, trt_calib_mode need to set True");
DEFINE_bool(use_dark, true, "Whether use dark decode in keypoint postprocess");
void PrintBenchmarkLog(std::vector<double> det_time, int img_num){
LOG(INFO) << "----------------------- Config info -----------------------";
LOG(INFO) << "runtime_device: " << FLAGS_device;
LOG(INFO) << "ir_optim: " << "True";
LOG(INFO) << "enable_memory_optim: " << "True";
int has_trt = FLAGS_run_mode.find("trt");
if (has_trt >= 0) {
LOG(INFO) << "enable_tensorrt: " << "True";
std::string precision = FLAGS_run_mode.substr(4, 8);
LOG(INFO) << "precision: " << precision;
} else {
LOG(INFO) << "enable_tensorrt: " << "False";
LOG(INFO) << "precision: " << "fp32";
}
LOG(INFO) << "enable_mkldnn: " << (FLAGS_use_mkldnn ? "True" : "False");
LOG(INFO) << "cpu_math_library_num_threads: " << FLAGS_cpu_threads;
LOG(INFO) << "----------------------- Data info -----------------------";
LOG(INFO) << "batch_size: " << FLAGS_batch_size;
LOG(INFO) << "batch_size_keypoint: " << FLAGS_batch_size_keypoint;
LOG(INFO) << "input_shape: " << "dynamic shape";
LOG(INFO) << "----------------------- Model info -----------------------";
FLAGS_model_dir.erase(FLAGS_model_dir.find_last_not_of("/") + 1);
LOG(INFO) << "model_name: " << FLAGS_model_dir.substr(FLAGS_model_dir.find_last_of('/') + 1);
FLAGS_model_dir_keypoint.erase(FLAGS_model_dir_keypoint.find_last_not_of("/") + 1);
LOG(INFO) << "model_name: " << FLAGS_model_dir_keypoint.substr(FLAGS_model_dir_keypoint.find_last_of('/') + 1);
LOG(INFO) << "----------------------- Perf info ------------------------";
LOG(INFO) << "Total number of predicted data: " << img_num
<< " and total time spent(ms): "
<< std::accumulate(det_time.begin(), det_time.end(), 0);
LOG(INFO) << "preproce_time(ms): " << det_time[0] / img_num
<< ", inference_time(ms): " << det_time[1] / img_num
<< ", postprocess_time(ms): " << det_time[2] / img_num;
}
static std::string DirName(const std::string &filepath) {
auto pos = filepath.rfind(OS_PATH_SEP);
if (pos == std::string::npos) {
return "";
}
return filepath.substr(0, pos);
}
static bool PathExists(const std::string& path){
#ifdef _WIN32
struct _stat buffer;
return (_stat(path.c_str(), &buffer) == 0);
#else
struct stat buffer;
return (stat(path.c_str(), &buffer) == 0);
#endif // !_WIN32
}
static void MkDir(const std::string& path) {
if (PathExists(path)) return;
int ret = 0;
#ifdef _WIN32
ret = _mkdir(path.c_str());
#else
ret = mkdir(path.c_str(), 0755);
#endif // !_WIN32
if (ret != 0) {
std::string path_error(path);
path_error += " mkdir failed!";
throw std::runtime_error(path_error);
}
}
static void MkDirs(const std::string& path) {
if (path.empty()) return;
if (PathExists(path)) return;
MkDirs(DirName(path));
MkDir(path);
}
void PredictVideo(const std::string& video_path,
PaddleDetection::ObjectDetector* det,
PaddleDetection::KeyPointDetector* keypoint) {
// Open video
cv::VideoCapture capture;
if (FLAGS_camera_id != -1){
capture.open(FLAGS_camera_id);
}else{
capture.open(video_path.c_str());
}
if (!capture.isOpened()) {
printf("can not open video : %s\n", video_path.c_str());
return;
}
// Get Video info : resolution, fps
int video_width = static_cast<int>(capture.get(CV_CAP_PROP_FRAME_WIDTH));
int video_height = static_cast<int>(capture.get(CV_CAP_PROP_FRAME_HEIGHT));
int video_fps = static_cast<int>(capture.get(CV_CAP_PROP_FPS));
// Create VideoWriter for output
cv::VideoWriter video_out;
std::string video_out_path = "output.mp4";
video_out.open(video_out_path.c_str(),
0x00000021,
video_fps,
cv::Size(video_width, video_height),
true);
if (!video_out.isOpened()) {
printf("create video writer failed!\n");
return;
}
std::vector<PaddleDetection::ObjectResult> result;
std::vector<int> bbox_num;
std::vector<double> det_times;
auto labels = det->GetLabelList();
auto colormap = PaddleDetection::GenerateColorMap(labels.size());
// Store keypoint results
std::vector<PaddleDetection::KeyPointResult> result_kpts;
std::vector<cv::Mat> imgs_kpts;
std::vector<std::vector<float>> center_bs;
std::vector<std::vector<float>> scale_bs;
std::vector<int> colormap_kpts = PaddleDetection::GenerateColorMap(20);
// Capture all frames and do inference
cv::Mat frame;
int frame_id = 0;
bool is_rbox = false;
while (capture.read(frame)) {
if (frame.empty()) {
break;
}
std::vector<cv::Mat> imgs;
imgs.push_back(frame);
det->Predict(imgs, 0.5, 0, 1, &result, &bbox_num, &det_times);
for (const auto& item : result) {
if (item.rect.size() > 6){
is_rbox = true;
printf("class=%d confidence=%.4f rect=[%d %d %d %d %d %d %d %d]\n",
item.class_id,
item.confidence,
item.rect[0],
item.rect[1],
item.rect[2],
item.rect[3],
item.rect[4],
item.rect[5],
item.rect[6],
item.rect[7]);
}
else{
printf("class=%d confidence=%.4f rect=[%d %d %d %d]\n",
item.class_id,
item.confidence,
item.rect[0],
item.rect[1],
item.rect[2],
item.rect[3]);
}
}
if(keypoint)
{
int imsize = result.size();
for (int i=0; i<imsize; i++){
auto item = result[i];
cv::Mat crop_img;
std::vector<double> keypoint_times;
std::vector<int> rect = {item.rect[0], item.rect[1], item.rect[2], item.rect[3]};
std::vector<float> center;
std::vector<float> scale;
if(item.class_id == 0)
{
PaddleDetection::CropImg(frame, crop_img, rect, center, scale);
center_bs.emplace_back(center);
scale_bs.emplace_back(scale);
imgs_kpts.emplace_back(crop_img);
}
if (imgs_kpts.size()==FLAGS_batch_size_keypoint || ((i==imsize-1)&&!imgs_kpts.empty()))
{
keypoint->Predict(imgs_kpts, center_bs, scale_bs, 0.5, 0, 1, &result_kpts, &keypoint_times);
imgs_kpts.clear();
center_bs.clear();
scale_bs.clear();
}
}
cv::Mat out_im = VisualizeKptsResult(frame, result_kpts, colormap_kpts);
video_out.write(out_im);
}
else{
// Visualization result
cv::Mat out_im = PaddleDetection::VisualizeResult(
frame, result, labels, colormap, is_rbox);
video_out.write(out_im);
}
frame_id += 1;
}
capture.release();
video_out.release();
}
void PredictImage(const std::vector<std::string> all_img_paths,
const int batch_size,
const double threshold,
const bool run_benchmark,
PaddleDetection::ObjectDetector* det,
PaddleDetection::KeyPointDetector* keypoint,
const std::string& output_dir = "output") {
std::vector<double> det_t = {0, 0, 0};
int steps = ceil(float(all_img_paths.size()) / batch_size);
int kpts_imgs = 0;
std::vector<double> keypoint_t = {0, 0, 0};
printf("total images = %d, batch_size = %d, total steps = %d\n",
all_img_paths.size(), batch_size, steps);
for (int idx = 0; idx < steps; idx++) {
std::vector<cv::Mat> batch_imgs;
int left_image_cnt = all_img_paths.size() - idx * batch_size;
if (left_image_cnt > batch_size) {
left_image_cnt = batch_size;
}
for (int bs = 0; bs < left_image_cnt; bs++) {
std::string image_file_path = all_img_paths.at(idx * batch_size+bs);
cv::Mat im = cv::imread(image_file_path, 1);
batch_imgs.insert(batch_imgs.end(), im);
}
// Store all detected result
std::vector<PaddleDetection::ObjectResult> result;
std::vector<int> bbox_num;
std::vector<double> det_times;
// Store keypoint results
std::vector<PaddleDetection::KeyPointResult> result_kpts;
std::vector<cv::Mat> imgs_kpts;
std::vector<std::vector<float>> center_bs;
std::vector<std::vector<float>> scale_bs;
std::vector<int> colormap_kpts = PaddleDetection::GenerateColorMap(20);
bool is_rbox = false;
if (run_benchmark) {
det->Predict(batch_imgs, threshold, 10, 10, &result, &bbox_num, &det_times);
} else {
det->Predict(batch_imgs, 0.5, 10, 10, &result, &bbox_num, &det_times);
}
// get labels and colormap
auto labels = det->GetLabelList();
auto colormap = PaddleDetection::GenerateColorMap(labels.size());
int item_start_idx = 0;
for (int i = 0; i < left_image_cnt; i++) {
cv::Mat im = batch_imgs[i];
std::vector<PaddleDetection::ObjectResult> im_result;
int detect_num = 0;
for (int j = 0; j < bbox_num[i]; j++) {
PaddleDetection::ObjectResult item = result[item_start_idx + j];
if (item.confidence < threshold || item.class_id == -1) {
continue;
}
detect_num += 1;
im_result.push_back(item);
if (item.rect.size() > 6){
is_rbox = true;
printf("class=%d confidence=%.4f rect=[%d %d %d %d %d %d %d %d]\n",
item.class_id,
item.confidence,
item.rect[0],
item.rect[1],
item.rect[2],
item.rect[3],
item.rect[4],
item.rect[5],
item.rect[6],
item.rect[7]);
}
else{
printf("class=%d confidence=%.4f rect=[%d %d %d %d]\n",
item.class_id,
item.confidence,
item.rect[0],
item.rect[1],
item.rect[2],
item.rect[3]);
}
}
std::cout << all_img_paths.at(idx * batch_size + i) << " The number of detected box: " << detect_num << std::endl;
item_start_idx = item_start_idx + bbox_num[i];
std::vector<int> compression_params;
compression_params.push_back(CV_IMWRITE_JPEG_QUALITY);
compression_params.push_back(95);
std::string output_path(output_dir);
if (output_dir.rfind(OS_PATH_SEP) != output_dir.size() - 1) {
output_path += OS_PATH_SEP;
}
std::string image_file_path = all_img_paths.at(idx * batch_size + i);
if(keypoint)
{
int imsize = im_result.size();
for (int i=0; i<imsize; i++){
auto item = im_result[i];
cv::Mat crop_img;
std::vector<double> keypoint_times;
std::vector<int> rect = {item.rect[0], item.rect[1], item.rect[2], item.rect[3]};
std::vector<float> center;
std::vector<float> scale;
if(item.class_id == 0)
{
PaddleDetection::CropImg(im, crop_img, rect, center, scale);
center_bs.emplace_back(center);
scale_bs.emplace_back(scale);
imgs_kpts.emplace_back(crop_img);
kpts_imgs += 1;
}
if (imgs_kpts.size()==FLAGS_batch_size_keypoint || ((i==imsize-1)&&!imgs_kpts.empty()))
{
if (run_benchmark) {
keypoint->Predict(imgs_kpts, center_bs, scale_bs, 0.5, 10, 10, &result_kpts, &keypoint_times);
}
else{
keypoint->Predict(imgs_kpts, center_bs, scale_bs, 0.5, 0, 1, &result_kpts, &keypoint_times);
}
imgs_kpts.clear();
center_bs.clear();
scale_bs.clear();
keypoint_t[0] += keypoint_times[0];
keypoint_t[1] += keypoint_times[1];
keypoint_t[2] += keypoint_times[2];
}
}
std::string kpts_savepath = output_path + "keypoint_" + image_file_path.substr(image_file_path.find_last_of('/') + 1);
cv::Mat kpts_vis_img = VisualizeKptsResult(im, result_kpts, colormap_kpts);
cv::imwrite(kpts_savepath, kpts_vis_img, compression_params);
printf("Visualized output saved as %s\n", kpts_savepath.c_str());
}
else{
// Visualization result
cv::Mat vis_img = PaddleDetection::VisualizeResult(
im, im_result, labels, colormap, is_rbox);
std::string det_savepath = output_path + image_file_path.substr(image_file_path.find_last_of('/') + 1);
cv::imwrite(det_savepath, vis_img, compression_params);
printf("Visualized output saved as %s\n", det_savepath.c_str());
}
}
det_t[0] += det_times[0];
det_t[1] += det_times[1];
det_t[2] += det_times[2];
}
PrintBenchmarkLog(det_t, all_img_paths.size());
PrintBenchmarkLog(keypoint_t, kpts_imgs);
}
int main(int argc, char** argv) {
// Parsing command-line
google::ParseCommandLineFlags(&argc, &argv, true);
if (FLAGS_model_dir.empty()
|| (FLAGS_image_file.empty() && FLAGS_image_dir.empty() && FLAGS_video_file.empty())) {
std::cout << "Usage: ./main --model_dir=/PATH/TO/INFERENCE_MODEL/ (--model_dir_keypoint=/PATH/TO/INFERENCE_MODEL/)"
<< "--image_file=/PATH/TO/INPUT/IMAGE/" << std::endl;
return -1;
}
if (!(FLAGS_run_mode == "fluid" || FLAGS_run_mode == "trt_fp32"
|| FLAGS_run_mode == "trt_fp16" || FLAGS_run_mode == "trt_int8")) {
std::cout << "run_mode should be 'fluid', 'trt_fp32', 'trt_fp16' or 'trt_int8'.";
return -1;
}
transform(FLAGS_device.begin(),FLAGS_device.end(),FLAGS_device.begin(),::toupper);
if (!(FLAGS_device == "CPU" || FLAGS_device == "GPU" || FLAGS_device == "XPU")) {
std::cout << "device should be 'CPU', 'GPU' or 'XPU'.";
return -1;
}
if (FLAGS_use_gpu) {
std::cout << "Deprecated, please use `--device` to set the device you want to run.";
return -1;
}
// Load model and create a object detector
PaddleDetection::ObjectDetector det(FLAGS_model_dir, FLAGS_device, FLAGS_use_mkldnn,
FLAGS_cpu_threads, FLAGS_run_mode, FLAGS_batch_size,FLAGS_gpu_id,
FLAGS_trt_min_shape, FLAGS_trt_max_shape, FLAGS_trt_opt_shape,
FLAGS_trt_calib_mode);
PaddleDetection::KeyPointDetector* keypoint = nullptr;
if (!FLAGS_model_dir_keypoint.empty())
{
keypoint = new PaddleDetection::KeyPointDetector(FLAGS_model_dir_keypoint, FLAGS_device, FLAGS_use_mkldnn,
FLAGS_cpu_threads, FLAGS_run_mode, FLAGS_batch_size,FLAGS_gpu_id,
FLAGS_trt_min_shape, FLAGS_trt_max_shape, FLAGS_trt_opt_shape,
FLAGS_trt_calib_mode, FLAGS_use_dark);
}
// Do inference on input video or image
if (!FLAGS_video_file.empty() || FLAGS_camera_id != -1) {
PredictVideo(FLAGS_video_file, &det, keypoint);
} else if (!FLAGS_image_file.empty() || !FLAGS_image_dir.empty()) {
if (!PathExists(FLAGS_output_dir)) {
MkDirs(FLAGS_output_dir);
}
std::vector<std::string> all_img_paths;
std::vector<cv::String> cv_all_img_paths;
if (!FLAGS_image_file.empty()) {
all_img_paths.push_back(FLAGS_image_file);
if (FLAGS_batch_size > 1) {
std::cout << "batch_size should be 1, when set `image_file`." << std::endl;
return -1;
}
} else {
cv::glob(FLAGS_image_dir, cv_all_img_paths);
for (const auto & img_path : cv_all_img_paths) {
all_img_paths.push_back(img_path);
}
}
PredictImage(all_img_paths, FLAGS_batch_size, FLAGS_threshold,
FLAGS_run_benchmark, &det, keypoint, FLAGS_output_dir);
}
delete keypoint;
keypoint = nullptr;
return 0;
}
......@@ -256,9 +256,9 @@ void ObjectDetector::Predict(const std::vector<cv::Mat> imgs,
// TODO: reduce cost time
in_data_all.insert(in_data_all.end(), inputs_.im_data_.begin(), inputs_.im_data_.end());
}
auto preprocess_end = std::chrono::steady_clock::now();
// Prepare input tensor
auto input_names = predictor_->GetInputNames();
for (const auto& tensor_name : input_names) {
auto in_tensor = predictor_->GetInputHandle(tensor_name);
......@@ -276,7 +276,6 @@ void ObjectDetector::Predict(const std::vector<cv::Mat> imgs,
}
}
auto preprocess_end = std::chrono::steady_clock::now();
// Run predictor
// warmup
for (int i = 0; i < warmup; i++)
......
......@@ -14,6 +14,7 @@
#include <vector>
#include <string>
#include <thread>
#include "include/preprocess_op.h"
......@@ -50,6 +51,7 @@ void NormalizeImage::Run(cv::Mat* im, ImageBlob* data) {
}
void Permute::Run(cv::Mat* im, ImageBlob* data) {
(*im).convertTo(*im, CV_32FC3);
int rh = im->rows;
int rw = im->cols;
int rc = im->channels();
......@@ -131,10 +133,19 @@ void PadStride::Run(cv::Mat* im, ImageBlob* data) {
};
}
void TopDownEvalAffine::Run(cv::Mat* im, ImageBlob* data) {
cv::resize(
*im, *im, cv::Size(trainsize_[0],trainsize_[1]), 0, 0, interp_);
// todo: Simd::ResizeBilinear();
data->in_net_shape_ = {
static_cast<float>(trainsize_[1]),
static_cast<float>(trainsize_[0]),
};
}
// Preprocessor op running order
const std::vector<std::string> Preprocessor::RUN_ORDER = {
"InitInfo", "Resize", "NormalizeImage", "PadStride", "Permute"
"InitInfo", "TopDownEvalAffine", "Resize", "NormalizeImage", "PadStride", "Permute"
};
void Preprocessor::Run(cv::Mat* im, ImageBlob* data) {
......@@ -145,4 +156,37 @@ void Preprocessor::Run(cv::Mat* im, ImageBlob* data) {
}
}
void CropImg(cv::Mat &img, cv::Mat &crop_img, std::vector<int> &area, std::vector<float> &center, std::vector<float> &scale, float expandratio) {
int crop_x1 = std::max(0, area[0]);
int crop_y1 = std::max(0, area[1]);
int crop_x2 = std::min(img.cols -1, area[2]);
int crop_y2 = std::min(img.rows - 1, area[3]);
int center_x = (crop_x1 + crop_x2)/2.;
int center_y = (crop_y1 + crop_y2)/2.;
int half_h = (crop_y2 - crop_y1)/2.;
int half_w = (crop_x2 - crop_x1)/2.;
//adjust h or w to keep image ratio, expand the shorter edge
if (half_h*3 > half_w*4){
half_w = static_cast<int>(half_h*0.75);
}
else{
half_h = static_cast<int>(half_w*4/3);
}
crop_x1 = std::max(0, center_x - static_cast<int>(half_w*(1+expandratio)));
crop_y1 = std::max(0, center_y - static_cast<int>(half_h*(1+expandratio)));
crop_x2 = std::min(img.cols -1, static_cast<int>(center_x + half_w*(1+expandratio)));
crop_y2 = std::min(img.rows - 1, static_cast<int>(center_y + half_h*(1+expandratio)));
crop_img = img(cv::Range(crop_y1, crop_y2+1), cv::Range(crop_x1, crop_x2 + 1));
center.clear();
center.emplace_back((crop_x1+crop_x2)/2);
center.emplace_back((crop_y1+crop_y2)/2);
scale.clear();
scale.emplace_back((crop_x2-crop_x1));
scale.emplace_back((crop_y2-crop_y1));
}
} // namespace PaddleDetection
......@@ -59,6 +59,7 @@ def _parse_reader(reader_cfg, dataset_cfg, metric, arch, image_shape):
label_list = [str(cat) for cat in catid2name.values()]
fuse_normalize = reader_cfg.get('fuse_normalize', False)
sample_transforms = reader_cfg['sample_transforms']
for st in sample_transforms[1:]:
for key, value in st.items():
......@@ -66,6 +67,8 @@ def _parse_reader(reader_cfg, dataset_cfg, metric, arch, image_shape):
if key == 'Resize':
if int(image_shape[1]) != -1:
value['target_size'] = image_shape[1:]
if fuse_normalize and key == 'NormalizeImage':
continue
p.update(value)
preprocess_list.append(p)
batch_transforms = reader_cfg.get('batch_transforms', None)
......
......@@ -88,6 +88,9 @@ class Trainer(object):
self.model = self.cfg.model
self.is_loaded_weights = True
#normalize params for deploy
self.model.load_meanstd(cfg['TestReader']['sample_transforms'])
self.use_ema = ('use_ema' in cfg and cfg['use_ema'])
if self.use_ema:
ema_decay = self.cfg.get('ema_decay', 0.9998)
......@@ -552,7 +555,11 @@ class Trainer(object):
if image_shape is None:
image_shape = [3, -1, -1]
if hasattr(self.model, 'deploy'): self.model.deploy = True
if hasattr(self.model, 'deploy'):
self.model.deploy = True
if hasattr(self.model, 'fuse_norm'):
self.model.fuse_norm = self.cfg['TestReader'].get('fuse_normalize',
False)
if hasattr(self.cfg, 'lite_deploy'):
self.model.lite_deploy = self.cfg.lite_deploy
......
......@@ -76,7 +76,12 @@ class TopDownHRNet(BaseArch):
if self.training:
return self.loss(hrnet_outputs, self.inputs)
elif self.deploy:
return hrnet_outputs
outshape = hrnet_outputs.shape
max_idx = paddle.argmax(
hrnet_outputs.reshape(
(outshape[0], outshape[1], outshape[2] * outshape[3])),
axis=-1)
return hrnet_outputs, max_idx
else:
if self.flip:
self.inputs['image'] = self.inputs['image'].flip([3])
......@@ -199,6 +204,10 @@ class HRNetPostProcess(object):
return coord
def dark_postprocess(self, hm, coords, kernelsize):
'''DARK postpocessing, Zhang et al. Distribution-Aware Coordinate
Representation for Human Pose Estimation (CVPR 2020).
'''
hm = self.gaussian_blur(hm, kernelsize)
hm = np.maximum(hm, 1e-10)
hm = np.log(hm)
......
......@@ -14,12 +14,40 @@ class BaseArch(nn.Layer):
def __init__(self, data_format='NCHW'):
super(BaseArch, self).__init__()
self.data_format = data_format
self.inputs = {}
self.fuse_norm = False
def load_meanstd(self, cfg_transform):
self.scale = 1.
self.mean = paddle.to_tensor([0.485, 0.456, 0.406]).reshape(
(1, 3, 1, 1))
self.std = paddle.to_tensor([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1))
for item in cfg_transform:
if 'NormalizeImage' in item:
self.mean = paddle.to_tensor(item['NormalizeImage'][
'mean']).reshape((1, 3, 1, 1))
self.std = paddle.to_tensor(item['NormalizeImage'][
'std']).reshape((1, 3, 1, 1))
if item['NormalizeImage']['is_scale']:
self.scale = 1. / 255.
break
if self.data_format == 'NHWC':
self.mean = self.mean.reshape(1, 1, 1, 3)
self.std = self.std.reshape(1, 1, 1, 3)
def forward(self, inputs):
if self.data_format == 'NHWC':
image = inputs['image']
inputs['image'] = paddle.transpose(image, [0, 2, 3, 1])
self.inputs = inputs
if self.fuse_norm:
image = inputs['image']
self.inputs['image'] = (image * self.scale - self.mean) / self.std
self.inputs['im_shape'] = inputs['im_shape']
self.inputs['scale_factor'] = inputs['scale_factor']
else:
self.inputs = inputs
self.model_arch()
if self.training:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册