From cb6edab61887023cbbb453cc90351f7f1bb73cc1 Mon Sep 17 00:00:00 2001
From: jack <136876878@qq.com>
Date: Mon, 15 Jun 2020 14:44:08 +0800
Subject: [PATCH] add batch predict in classification task

---
 deploy/cpp/demo/classifier.cpp       | 59 +++++++++++++++---
 deploy/cpp/include/paddlex/paddlex.h | 13 +++-
 deploy/cpp/scripts/build.sh          |  6 +-
 deploy/cpp/src/paddlex.cpp           | 89 ++++++++++++++++++++++++++--
 4 files changed, 146 insertions(+), 21 deletions(-)

diff --git a/deploy/cpp/demo/classifier.cpp b/deploy/cpp/demo/classifier.cpp
index badb835..4393043 100644
--- a/deploy/cpp/demo/classifier.cpp
+++ b/deploy/cpp/demo/classifier.cpp
@@ -14,13 +14,17 @@
 
 #include <glog/logging.h>
 
+#include <algorithm>
+#include <chrono>
 #include <fstream>
 #include <iostream>
 #include <string>
 #include <vector>
-
+#include <utility>
 #include "include/paddlex/paddlex.h"
 
+using namespace std::chrono;
+
 DEFINE_string(model_dir, "", "Path of inference model");
 DEFINE_bool(use_gpu, false, "Inferring with GPU or CPU");
 DEFINE_bool(use_trt, false, "Inferring with TensorRT");
@@ -28,6 +32,7 @@ DEFINE_int32(gpu_id, 0, "GPU card id");
 DEFINE_string(key, "", "key of encryption");
 DEFINE_string(image, "", "Path of test image file");
 DEFINE_string(image_list, "", "Path of test image list file");
+DEFINE_int32(batch_size, 1, "Batch size when inferring");
 
 int main(int argc, char** argv) {
   // Parsing command-line
@@ -44,32 +49,68 @@ int main(int argc, char** argv) {
 
   // Load the model
   PaddleX::Model model;
-  model.Init(FLAGS_model_dir, FLAGS_use_gpu, FLAGS_use_trt, FLAGS_gpu_id, FLAGS_key);
+  model.Init(FLAGS_model_dir, FLAGS_use_gpu, FLAGS_use_trt, FLAGS_gpu_id, FLAGS_key, FLAGS_batch_size);
 
   // Run prediction
+  double total_running_time_s = 0.0;
+  double total_imread_time_s = 0.0;
+
   if (FLAGS_image_list != "") {
     std::ifstream inf(FLAGS_image_list);
     if (!inf) {
       std::cerr << "Failed to open file " << FLAGS_image_list << std::endl;
       return -1;
     }
+    // Batch prediction
     std::string image_path;
+    std::vector<std::string> image_path_vec;
     while (getline(inf, image_path)) {
-      PaddleX::ClsResult result;
-      cv::Mat im = cv::imread(image_path, 1);
-      model.predict(im, &result);
-      std::cout << "Predict label: " << result.category
-                << ", label_id:" << result.category_id
-                << ", score: " << result.score << std::endl;
+      image_path_vec.push_back(image_path);
+    }
+    for (int i = 0; i < image_path_vec.size(); i += FLAGS_batch_size) {
+      auto start = system_clock::now();
+      // Read one batch of images, one OpenMP thread per image
+      int im_vec_size = std::min(static_cast<int>(image_path_vec.size()), i + FLAGS_batch_size);
+      std::vector<cv::Mat> im_vec(im_vec_size - i);
+      std::vector<PaddleX::ClsResult> results(im_vec_size - i, PaddleX::ClsResult());
+      #pragma omp parallel for num_threads(im_vec_size - i)
+      for (int j = i; j < im_vec_size; ++j) {
+        im_vec[j - i] = cv::imread(image_path_vec[j], 1);
+      }
+      auto imread_end = system_clock::now();
+      model.predict(im_vec, results);
+
+      auto imread_duration = duration_cast<microseconds>(imread_end - start);
+      total_imread_time_s += double(imread_duration.count()) * microseconds::period::num / microseconds::period::den;
+
+      auto end = system_clock::now();
+      auto duration = duration_cast<microseconds>(end - start);
+      total_running_time_s += double(duration.count()) * microseconds::period::num / microseconds::period::den;
+      for (int j = i; j < im_vec_size; ++j) {
+        std::cout << "Path:" << image_path_vec[j]
+                  << ", predict label: " << results[j - i].category
+                  << ", label_id:" << results[j - i].category_id
+                  << ", score: " << results[j - i].score << std::endl;
+      }
     }
   } else {
+    auto start = system_clock::now();
     PaddleX::ClsResult result;
     cv::Mat im = cv::imread(FLAGS_image, 1);
     model.predict(im, &result);
+    auto end = system_clock::now();
+    auto duration = duration_cast<microseconds>(end - start);
+    total_running_time_s += double(duration.count()) * microseconds::period::num / microseconds::period::den;
     std::cout << "Predict label: " << result.category
               << ", label_id:" << result.category_id
               << ", score: " << result.score << std::endl;
   }
-
+  std::cout << "Total running time: "
+            << total_running_time_s
+            << " s, total image read time: "
+            << total_imread_time_s
+            << " s, batch_size = "
+            << FLAGS_batch_size
+            << std::endl;
   return 0;
 }

diff --git a/deploy/cpp/include/paddlex/paddlex.h b/deploy/cpp/include/paddlex/paddlex.h
index d000728..3b63343 100644
--- a/deploy/cpp/include/paddlex/paddlex.h
+++ b/deploy/cpp/include/paddlex/paddlex.h
@@ -45,22 +45,28 @@ class Model {
             bool use_gpu = false,
             bool use_trt = false,
             int gpu_id = 0,
-            std::string key = "") {
-    create_predictor(model_dir, use_gpu, use_trt, gpu_id, key);
+            std::string key = "",
+            int batch_size = 1) {
+    create_predictor(model_dir, use_gpu, use_trt, gpu_id, key, batch_size);
   }
 
   void create_predictor(const std::string& model_dir,
                         bool use_gpu = false,
                         bool use_trt = false,
                         int gpu_id = 0,
-                        std::string key = "");
+                        std::string key = "",
+                        int batch_size = 1);
 
   bool load_config(const std::string& model_dir);
 
   bool preprocess(const cv::Mat& input_im, ImageBlob* blob);
+
+  bool preprocess(const std::vector<cv::Mat>& input_im_batch, std::vector<ImageBlob>& blob_batch);
 
   bool predict(const cv::Mat& im, ClsResult* result);
 
+  bool predict(const std::vector<cv::Mat>& im_batch, std::vector<ClsResult>& results);
+
   bool predict(const cv::Mat& im, DetResult* result);
 
   bool predict(const cv::Mat& im, SegResult* result);
@@ -74,6 +80,7 @@ class Model {
   std::map<int, std::string> labels;
   Transforms transforms_;
   ImageBlob inputs_;
+  std::vector<ImageBlob> inputs_batch_;
   std::vector<float> outputs_;
   std::unique_ptr<paddle::PaddlePredictor> predictor_;
 };

diff --git a/deploy/cpp/scripts/build.sh b/deploy/cpp/scripts/build.sh
index 74ab96a..d2d29cf 100644
--- a/deploy/cpp/scripts/build.sh
+++ b/deploy/cpp/scripts/build.sh
@@ -1,5 +1,5 @@
 # Whether to use GPU (i.e. CUDA)
-WITH_GPU=OFF
+WITH_GPU=ON
 # Use MKL or openblas
 WITH_MKL=ON
 # Whether to integrate TensorRT (valid only when WITH_GPU=ON)
@@ -7,7 +7,7 @@ WITH_TENSORRT=OFF
 # Path to TensorRT
 TENSORRT_DIR=/path/to/TensorRT/
 # Path to the Paddle inference library
-PADDLE_DIR=/docker/jiangjiajun/PaddleDetection/deploy/cpp/fluid_inference
+PADDLE_DIR=/mnt/zhoushunjie/projects/fluid_inference
 # Whether to link the Paddle inference library statically
 # When using TensorRT, the Paddle inference library is usually dynamic
 WITH_STATIC_LIB=OFF
@@ -42,4 +42,4 @@ cmake .. \
     -DCUDNN_LIB=${CUDNN_LIB} \
     -DENCRYPTION_DIR=${ENCRYPTION_DIR} \
     -DOPENCV_DIR=${OPENCV_DIR}
-make
+make -j4

diff --git a/deploy/cpp/src/paddlex.cpp b/deploy/cpp/src/paddlex.cpp
index 90a4a44..6d6de8b 100644
--- a/deploy/cpp/src/paddlex.cpp
+++ b/deploy/cpp/src/paddlex.cpp
@@ -11,16 +11,17 @@
 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 // See the License for the specific language governing permissions and
 // limitations under the License.
-
+#include <omp.h>
+#include <algorithm>
 #include "include/paddlex/paddlex.h"
-
 namespace PaddleX {
 
 void Model::create_predictor(const std::string& model_dir,
                              bool use_gpu,
                              bool use_trt,
                              int gpu_id,
-                             std::string key) {
+                             std::string key,
+                             int batch_size) {
   // Parse the config file
   if (!load_config(model_dir)) {
     std::cerr << "Parse file 'model.yml' failed!"
               << std::endl;
@@ -58,6 +59,7 @@ void Model::create_predictor(const std::string& model_dir,
                     false /* use_calib_mode*/);
   }
   predictor_ = std::move(CreatePaddlePredictor(config));
+  inputs_batch_.assign(batch_size, ImageBlob());
 }
 
 bool Model::load_config(const std::string& model_dir) {
@@ -104,6 +106,21 @@ bool Model::preprocess(const cv::Mat& input_im, ImageBlob* blob) {
   return true;
 }
 
+// Preprocess a batch of images in parallel with OpenMP
+bool Model::preprocess(const std::vector<cv::Mat>& input_im_batch, std::vector<ImageBlob>& blob_batch) {
+  int batch_size = inputs_batch_.size();
+  bool success = true;
+  #pragma omp parallel for num_threads(batch_size)
+  for (int i = 0; i < input_im_batch.size(); ++i) {
+    cv::Mat im = input_im_batch[i].clone();
+    if (!transforms_.Run(&im, &blob_batch[i])) {
+      success = false;
+    }
+  }
+  return success;
+}
+
 bool Model::predict(const cv::Mat& im, ClsResult* result) {
   inputs_.clear();
   if (type == "detector") {
@@ -146,6 +163,64 @@ bool Model::predict(const cv::Mat& im, ClsResult* result) {
   result->category = labels[result->category_id];
 }
 
+bool Model::predict(const std::vector<cv::Mat>& im_batch, std::vector<ClsResult>& results) {
+  for (auto& inputs : inputs_batch_) {
+    inputs.clear();
+  }
+  if (type == "detector") {
+    std::cerr << "Loaded model is a 'detector', DetResult should be passed to "
+                 "function predict()!" << std::endl;
+    return false;
+  } else if (type == "segmenter") {
+    std::cerr << "Loaded model is a 'segmenter', SegResult should be passed "
+                 "to function predict()!" << std::endl;
+    return false;
+  }
+  // Preprocess the input images
+  if (!preprocess(im_batch, inputs_batch_)) {
+    std::cerr << "Preprocess failed!" << std::endl;
+    return false;
+  }
+  // Run inference with the loaded model
+  int batch_size = im_batch.size();
+  auto in_tensor = predictor_->GetInputTensor("image");
+  int h = inputs_batch_[0].new_im_size_[0];
+  int w = inputs_batch_[0].new_im_size_[1];
+  in_tensor->Reshape({batch_size, 3, h, w});
+  // Pack all preprocessed images into one contiguous input buffer
+  std::vector<float> inputs_data(batch_size * 3 * h * w);
+  for (int i = 0; i < batch_size; ++i) {
+    std::copy(inputs_batch_[i].im_data_.begin(),
+              inputs_batch_[i].im_data_.end(),
+              inputs_data.begin() + i * 3 * h * w);
+  }
+  in_tensor->copy_from_cpu(inputs_data.data());
+  predictor_->ZeroCopyRun();
+  // Fetch the model output
+  auto output_names = predictor_->GetOutputNames();
+  auto output_tensor = predictor_->GetOutputTensor(output_names[0]);
+  std::vector<int> output_shape = output_tensor->shape();
+  int size = 1;
+  for (const auto& i : output_shape) {
+    size *= i;
+  }
+  outputs_.resize(size);
+  output_tensor->copy_to_cpu(outputs_.data());
+  // Post-process: take the argmax over each sample's scores
+  int single_batch_size = size / batch_size;
+  for (int i = 0; i < batch_size; ++i) {
+    auto start_ptr = std::begin(outputs_);
+    auto end_ptr = std::begin(outputs_);
+    std::advance(start_ptr, i * single_batch_size);
+    std::advance(end_ptr, (i + 1) * single_batch_size);
+    auto ptr = std::max_element(start_ptr, end_ptr);
+    results[i].category_id = std::distance(start_ptr, ptr);
+    results[i].score = *ptr;
+    results[i].category = labels[results[i].category_id];
+  }
+  return true;
+}
+
 bool Model::predict(const cv::Mat& im, DetResult* result) {
   result->clear();
   inputs_.clear();
@@ -288,6 +363,7 @@ bool Model::predict(const cv::Mat& im, SegResult* result) {
     size *= i;
     result->label_map.shape.push_back(i);
   }
+
   result->label_map.data.resize(size);
   output_label_tensor->copy_to_cpu(result->label_map.data.data());
 
@@ -299,6 +375,7 @@ bool Model::predict(const cv::Mat& im, SegResult* result) {
     size *= i;
     result->score_map.shape.push_back(i);
   }
+
   result->score_map.data.resize(size);
   output_score_tensor->copy_to_cpu(result->score_map.data.data());
 
@@ -325,8 +402,8 @@ bool Model::predict(const cv::Mat& im, SegResult* result) {
       inputs_.im_size_before_resize_.pop_back();
       auto padding_w = before_shape[0];
       auto padding_h = before_shape[1];
-      mask_label = mask_label(cv::Rect(0, 0, padding_w, padding_h));
-      mask_score = mask_score(cv::Rect(0, 0, padding_w, padding_h));
+      mask_label = mask_label(cv::Rect(0, 0, padding_h, padding_w));
+      mask_score = mask_score(cv::Rect(0, 0, padding_h, padding_w));
     } else if (*iter == "resize") {
       auto before_shape = inputs_.im_size_before_resize_[len_postprocess - idx];
       inputs_.im_size_before_resize_.pop_back();
@@ -343,7 +420,7 @@ bool Model::predict(const cv::Mat& im, SegResult* result) {
                  cv::Size(resize_h, resize_w),
                  0,
                  0,
-                 cv::INTER_NEAREST);
+                 cv::INTER_LINEAR);
     }
     ++idx;
   }
-- 
GitLab
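
Usage note: below is a minimal, self-contained sketch of how the batch interface added by this patch is driven end to end. The model directory, image paths, and batch size are placeholders rather than anything shipped with the patch, and error handling is kept deliberately thin; treat it as an illustration of the new API, not production code.

#include <iostream>
#include <string>
#include <vector>

#include <opencv2/opencv.hpp>

#include "include/paddlex/paddlex.h"

int main() {
  // Placeholder batch size. Batches passed to predict() must not exceed the
  // batch_size given to Init(), since create_predictor() pre-allocates
  // inputs_batch_ with exactly this many ImageBlobs.
  const int batch_size = 4;

  PaddleX::Model model;
  model.Init("/path/to/inference_model",  // placeholder model directory
             false,                       // use_gpu
             false,                       // use_trt
             0,                           // gpu_id
             "",                          // key (unencrypted model)
             batch_size);

  // Collect one batch of images (placeholder paths).
  std::vector<cv::Mat> im_batch;
  for (int i = 0; i < batch_size; ++i) {
    im_batch.push_back(cv::imread("images/" + std::to_string(i) + ".jpg", 1));
  }

  // One ClsResult per image, filled in by the batch predict() overload.
  std::vector<PaddleX::ClsResult> results(im_batch.size());
  if (!model.predict(im_batch, results)) {
    std::cerr << "Batch predict failed!" << std::endl;
    return -1;
  }
  for (std::size_t i = 0; i < results.size(); ++i) {
    std::cout << "label: " << results[i].category
              << ", score: " << results[i].score << std::endl;
  }
  return 0;
}

Once built, the updated demo exercises the same path from the command line, e.g. by passing --image_list together with the new --batch_size flag to the binary built from classifier.cpp (the binary name depends on the build setup).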