Commit cb6edab6 authored by: J jack

Add batch prediction to the classification task

Parent: e9b8c938
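A minimal usage sketch for the new batch mode (assumption: the classification demo builds to a binary named classifier; the binary name and paths below are illustrative, only the flags come from this commit):

./classifier --model_dir=/path/to/inference_model \
             --image_list=/path/to/image_list.txt \
             --use_gpu=true \
             --batch_size=4

With --image_list set, the demo collects every path in the list and predicts them in groups of --batch_size; passing --image instead still runs the single-image path.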
@@ -14,13 +14,17 @@
#include <glog/logging.h>
#include <algorithm>
#include <chrono>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>
#include <utility>
#include "include/paddlex/paddlex.h" #include "include/paddlex/paddlex.h"
using namespace std::chrono;
DEFINE_string(model_dir, "", "Path of inference model");
DEFINE_bool(use_gpu, false, "Inferring with GPU or CPU");
DEFINE_bool(use_trt, false, "Inferring with TensorRT");
@@ -28,6 +32,7 @@ DEFINE_int32(gpu_id, 0, "GPU card id");
DEFINE_string(key, "", "key of encryption");
DEFINE_string(image, "", "Path of test image file");
DEFINE_string(image_list, "", "Path of test image list file");
DEFINE_int32(batch_size, 1, "Batch size when inferring");
int main(int argc, char** argv) {
// Parse command-line flags
@@ -44,32 +49,68 @@ int main(int argc, char** argv) {
// Load the model
PaddleX::Model model;
model.Init(FLAGS_model_dir, FLAGS_use_gpu, FLAGS_use_trt, FLAGS_gpu_id, FLAGS_key, FLAGS_batch_size);
// Run prediction
double total_running_time_s = 0.0;
double total_imread_time_s = 0.0;
if (FLAGS_image_list != "") {
std::ifstream inf(FLAGS_image_list);
if (!inf) {
std::cerr << "Fail to open file " << FLAGS_image_list << std::endl;
return -1;
}
// Batch prediction over the image list
std::string image_path;
std::vector<std::string> image_path_vec;
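// Collect all image paths first, then run prediction in groups of FLAGS_batch_size.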
while (getline(inf, image_path)) {
image_path_vec.push_back(image_path);
}
for (int i = 0; i < static_cast<int>(image_path_vec.size()); i += FLAGS_batch_size) {
auto start = system_clock::now();
// Read the images of the current batch
// im_vec_size is the end index (exclusive) of the current batch
int im_vec_size = std::min(static_cast<int>(image_path_vec.size()), i + FLAGS_batch_size);
std::vector<cv::Mat> im_vec(im_vec_size - i);
std::vector<PaddleX::ClsResult> results(im_vec_size - i, PaddleX::ClsResult());
#pragma omp parallel for num_threads(im_vec_size - i)
for (int j = i; j < im_vec_size; ++j) {
im_vec[j - i] = cv::imread(image_path_vec[j], 1);
}
auto imread_end = system_clock::now();
model.predict(im_vec, results);
auto imread_duration = duration_cast<microseconds>(imread_end - start);
total_imread_time_s += double(imread_duration.count()) * microseconds::period::num / microseconds::period::den;
auto end = system_clock::now();
auto duration = duration_cast<microseconds>(end - start);
total_running_time_s += double(duration.count()) * microseconds::period::num / microseconds::period::den;
for (int j = i; j < im_vec_size; ++j) {
std::cout << "Path:" << image_path_vec[j]
<< ", predict label: " << results[j - i].category
<< ", label_id:" << results[j - i].category_id
<< ", score: " << results[j - i].score << std::endl;
}
}
} else {
auto start = system_clock::now();
PaddleX::ClsResult result;
cv::Mat im = cv::imread(FLAGS_image, 1);
model.predict(im, &result);
auto end = system_clock::now();
auto duration = duration_cast<microseconds>(end - start);
total_running_time_s += double(duration.count()) * microseconds::period::num / microseconds::period::den;
std::cout << "Predict label: " << result.category std::cout << "Predict label: " << result.category
<< ", label_id:" << result.category_id << ", label_id:" << result.category_id
<< ", score: " << result.score << std::endl; << ", score: " << result.score << std::endl;
} }
std::cout << "Total average running time: "
<< total_running_time_s
<< " s, total average read img time: "
<< total_imreaad_time_s
<< " s, batch_size = "
<< FLAGS_batch_size
<< std::endl;
return 0; return 0;
} }
@@ -45,22 +45,28 @@ class Model {
bool use_gpu = false,
bool use_trt = false,
int gpu_id = 0,
std::string key = "",
int batch_size = 1) {
create_predictor(model_dir, use_gpu, use_trt, gpu_id, key, batch_size);
}
void create_predictor(const std::string& model_dir,
bool use_gpu = false,
bool use_trt = false,
int gpu_id = 0,
std::string key = "",
int batch_size = 1);
bool load_config(const std::string& model_dir);
bool preprocess(const cv::Mat& input_im, ImageBlob* blob);
bool preprocess(const std::vector<cv::Mat>& input_im_batch, std::vector<ImageBlob>& blob_batch);
bool predict(const cv::Mat& im, ClsResult* result);
bool predict(const std::vector<cv::Mat>& im_batch, std::vector<ClsResult>& results);
bool predict(const cv::Mat& im, DetResult* result);
bool predict(const cv::Mat& im, SegResult* result);
@@ -74,6 +80,7 @@ class Model {
std::map<int, std::string> labels;
Transforms transforms_;
ImageBlob inputs_;
std::vector<ImageBlob> inputs_batch_;
std::vector<float> outputs_;
std::unique_ptr<paddle::PaddlePredictor> predictor_;
};
......
# Whether to use GPU (i.e., CUDA)
WITH_GPU=ON
# Use MKL or OpenBLAS
WITH_MKL=ON
# Whether to integrate TensorRT (only valid when WITH_GPU=ON)
@@ -7,7 +7,7 @@ WITH_TENSORRT=OFF
# Path to TensorRT
TENSORRT_DIR=/path/to/TensorRT/
# Path to the Paddle inference library
PADDLE_DIR=/mnt/zhoushunjie/projects/fluid_inference
# Whether the Paddle inference library is compiled as a static library
# When TensorRT is used, the Paddle inference library is usually a dynamic library
WITH_STATIC_LIB=OFF
@@ -42,4 +42,4 @@ cmake .. \
-DCUDNN_LIB=${CUDNN_LIB} \
-DENCRYPTION_DIR=${ENCRYPTION_DIR} \
-DOPENCV_DIR=${OPENCV_DIR}
make -j4
@@ -11,16 +11,17 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <omp.h>
#include "include/paddlex/paddlex.h"
namespace PaddleX {
void Model::create_predictor(const std::string& model_dir,
bool use_gpu,
bool use_trt,
int gpu_id,
std::string key,
int batch_size) {
// Read the config file
if (!load_config(model_dir)) {
std::cerr << "Parse file 'model.yml' failed!" << std::endl;
@@ -58,6 +59,7 @@ void Model::create_predictor(const std::string& model_dir,
false /* use_calib_mode*/);
}
predictor_ = std::move(CreatePaddlePredictor(config));
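// Pre-allocate one ImageBlob per batch slot for the batch prediction path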
inputs_batch_.assign(batch_size, ImageBlob());
}
bool Model::load_config(const std::string& model_dir) {
@@ -104,6 +106,21 @@ bool Model::preprocess(const cv::Mat& input_im, ImageBlob* blob) {
return true;
}
// Batch preprocessing, parallelized with OpenMP
bool Model::preprocess(const std::vector<cv::Mat>& input_im_batch, std::vector<ImageBlob>& blob_batch) {
int batch_size = inputs_batch_.size();
bool success = true;
#pragma omp parallel for num_threads(batch_size)
for (int i = 0; i < static_cast<int>(input_im_batch.size()); ++i) {
cv::Mat im = input_im_batch[i].clone();
if (!transforms_.Run(&im, &blob_batch[i])) {
success = false;
}
}
return success;
}
bool Model::predict(const cv::Mat& im, ClsResult* result) {
inputs_.clear();
if (type == "detector") {
@@ -146,6 +163,64 @@ bool Model::predict(const cv::Mat& im, ClsResult* result) {
result->category = labels[result->category_id];
}
bool Model::predict(const std::vector<cv::Mat>& im_batch, std::vector<ClsResult>& results) {
for (auto& inputs : inputs_batch_) {
inputs.clear();
}
if (type == "detector") {
std::cerr << "Loading model is a 'detector', DetResult should be passed to "
"function predict()!"
<< std::endl;
return false;
} else if (type == "segmenter") {
std::cerr << "Loading model is a 'segmenter', SegResult should be passed "
"to function predict()!"
<< std::endl;
return false;
}
// Preprocess the input images
if (!preprocess(im_batch, inputs_batch_)) {
std::cerr << "Preprocess failed!" << std::endl;
return false;
}
// Run inference with the loaded model
int batch_size = im_batch.size();
auto in_tensor = predictor_->GetInputTensor("image");
int h = inputs_batch_[0].new_im_size_[0];
int w = inputs_batch_[0].new_im_size_[1];
in_tensor->Reshape({batch_size, 3, h, w});
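// Pack each preprocessed image (3 * h * w floats, CHW layout) into one contiguous buffer for the batched NCHW input tensor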
std::vector<float> inputs_data(batch_size * 3 * h * w);
// Copy only the images actually in this batch (the last batch may be smaller than inputs_batch_.size())
for (int i = 0; i < batch_size; ++i) {
std::copy(inputs_batch_[i].im_data_.begin(), inputs_batch_[i].im_data_.end(), inputs_data.begin() + i * 3 * h * w);
}
in_tensor->copy_from_cpu(inputs_data.data());
predictor_->ZeroCopyRun();
// Fetch the model output
auto output_names = predictor_->GetOutputNames();
auto output_tensor = predictor_->GetOutputTensor(output_names[0]);
std::vector<int> output_shape = output_tensor->shape();
int size = 1;
for (const auto& i : output_shape) {
size *= i;
}
outputs_.resize(size);
output_tensor->copy_to_cpu(outputs_.data());
// Postprocess: take the per-image argmax over the class scores
int single_batch_size = size / batch_size;
for (int i = 0; i < batch_size; ++i) {
auto start_ptr = std::begin(outputs_);
auto end_ptr = std::begin(outputs_);
std::advance(start_ptr, i * single_batch_size);
std::advance(end_ptr, (i + 1) * single_batch_size);
auto ptr = std::max_element(start_ptr, end_ptr);
results[i].category_id = std::distance(start_ptr, ptr);
results[i].score = *ptr;
results[i].category = labels[results[i].category_id];
}
return true;
}
bool Model::predict(const cv::Mat& im, DetResult* result) {
result->clear();
inputs_.clear();
@@ -288,6 +363,7 @@ bool Model::predict(const cv::Mat& im, SegResult* result) {
size *= i;
result->label_map.shape.push_back(i);
}
result->label_map.data.resize(size);
output_label_tensor->copy_to_cpu(result->label_map.data.data());
@@ -299,6 +375,7 @@ bool Model::predict(const cv::Mat& im, SegResult* result) {
size *= i;
result->score_map.shape.push_back(i);
}
result->score_map.data.resize(size);
output_score_tensor->copy_to_cpu(result->score_map.data.data());
@@ -325,8 +402,8 @@ bool Model::predict(const cv::Mat& im, SegResult* result) {
inputs_.im_size_before_resize_.pop_back();
auto padding_w = before_shape[0];
auto padding_h = before_shape[1];
mask_label = mask_label(cv::Rect(0, 0, padding_h, padding_w));
mask_score = mask_score(cv::Rect(0, 0, padding_h, padding_w));
} else if (*iter == "resize") {
auto before_shape = inputs_.im_size_before_resize_[len_postprocess - idx];
inputs_.im_size_before_resize_.pop_back();
@@ -343,7 +420,7 @@ bool Model::predict(const cv::Mat& im, SegResult* result) {
cv::Size(resize_h, resize_w),
0,
0,
cv::INTER_LINEAR);
}
++idx;
}
......