提交 eeec4b15 编写于 作者: J jack

add yolov3 batch predict

上级 26cefa3d
......@@ -14,14 +14,19 @@
#include <glog/logging.h>
#include <algorithm>
#include <chrono>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>
#include <utility>
#include "include/paddlex/paddlex.h"
#include "include/paddlex/visualize.h"
using namespace std::chrono;
DEFINE_string(model_dir, "", "Path of inference model");
DEFINE_bool(use_gpu, false, "Infering with GPU or CPU");
DEFINE_bool(use_trt, false, "Infering with TensorRT");
......@@ -30,6 +35,7 @@ DEFINE_string(key, "", "key of encryption");
DEFINE_string(image, "", "Path of test image file");
DEFINE_string(image_list, "", "Path of test image list file");
DEFINE_string(save_dir, "output", "Path to save visualized image");
DEFINE_int32(batch_size, 1, "");
int main(int argc, char** argv) {
// 解析命令行参数
......@@ -46,8 +52,11 @@ int main(int argc, char** argv) {
// 加载模型
PaddleX::Model model;
model.Init(FLAGS_model_dir, FLAGS_use_gpu, FLAGS_use_trt, FLAGS_gpu_id, FLAGS_key);
model.Init(FLAGS_model_dir, FLAGS_use_gpu, FLAGS_use_trt, FLAGS_gpu_id, FLAGS_key, FLAGS_batch_size);
double total_running_time_s = 0.0;
double total_imread_time_s = 0.0;
int imgs = 1;
auto colormap = PaddleX::GenerateColorMap(model.labels.size());
std::string save_dir = "output";
// 进行预测
......@@ -58,35 +67,57 @@ int main(int argc, char** argv) {
return -1;
}
std::string image_path;
std::vector<std::string> image_paths;
while (getline(inf, image_path)) {
PaddleX::DetResult result;
cv::Mat im = cv::imread(image_path, 1);
model.predict(im, &result);
for (int i = 0; i < result.boxes.size(); ++i) {
std::cout << "image file: " << image_path
<< ", predict label: " << result.boxes[i].category
<< ", label_id:" << result.boxes[i].category_id
<< ", score: " << result.boxes[i].score << ", box(xmin, ymin, w, h):("
<< result.boxes[i].coordinate[0] << ", "
<< result.boxes[i].coordinate[1] << ", "
<< result.boxes[i].coordinate[2] << ", "
<< result.boxes[i].coordinate[3] << ")" << std::endl;
image_paths.push_back(image_path);
}
imgs = image_paths.size();
for(int i = 0; i < image_paths.size(); i += FLAGS_batch_size) {
auto start = system_clock::now();
int im_vec_size = std::min((int)image_paths.size(), i + FLAGS_batch_size);
std::vector<cv::Mat> im_vec(im_vec_size - i);
std::vector<PaddleX::DetResult> results(im_vec_size - i, PaddleX::DetResult());
#pragma omp parallel for num_threads(im_vec_size - i)
for(int j = i; j < im_vec_size; ++j){
im_vec[j - i] = std::move(cv::imread(image_paths[j], 1));
}
auto imread_end = system_clock::now();
model.predict(im_vec, results);
auto imread_duration = duration_cast<microseconds>(imread_end - start);
total_imread_time_s += double(imread_duration.count()) * microseconds::period::num / microseconds::period::den;
auto end = system_clock::now();
auto duration = duration_cast<microseconds>(end - start);
total_running_time_s += double(duration.count()) * microseconds::period::num / microseconds::period::den;
//输出结果目标框
for(int j = 0; j < im_vec_size - i; ++j) {
std::cout << "image file: " << image_paths[i + j] << std::endl;
for(int k = 0; k < results[j].boxes.size(); ++k) {
std::cout << "predict label: " << results[j].boxes[k].category
<< ", label_id:" << results[j].boxes[k].category_id
<< ", score: " << results[j].boxes[k].score << ", box(xmin, ymin, w, h):("
<< results[j].boxes[k].coordinate[0] << ", "
<< results[j].boxes[k].coordinate[1] << ", "
<< results[j].boxes[k].coordinate[2] << ", "
<< results[j].boxes[k].coordinate[3] << ")" << std::endl;
}
}
// 可视化
cv::Mat vis_img =
PaddleX::Visualize(im, result, model.labels, colormap, 0.5);
std::string save_path =
PaddleX::generate_save_path(FLAGS_save_dir, image_path);
cv::imwrite(save_path, vis_img);
result.clear();
std::cout << "Visualized output saved as " << save_path << std::endl;
for(int j = 0; j < im_vec_size - i; ++j) {
cv::Mat vis_img =
PaddleX::Visualize(im_vec[j], results[j], model.labels, colormap, 0.5);
std::string save_path =
PaddleX::generate_save_path(FLAGS_save_dir, image_paths[i + j]);
cv::imwrite(save_path, vis_img);
std::cout << "Visualized output saved as " << save_path << std::endl;
}
}
} else {
PaddleX::DetResult result;
cv::Mat im = cv::imread(FLAGS_image, 1);
model.predict(im, &result);
for (int i = 0; i < result.boxes.size(); ++i) {
std::cout << "image file: " << FLAGS_image << std::endl;
std::cout << ", predict label: " << result.boxes[i].category
<< ", label_id:" << result.boxes[i].category_id
<< ", score: " << result.boxes[i].score << ", box(xmin, ymin, w, h):("
......@@ -105,6 +136,18 @@ int main(int argc, char** argv) {
result.clear();
std::cout << "Visualized output saved as " << save_path << std::endl;
}
std::cout << "Total running time: "
<< total_running_time_s
<< " s, average running time: "
<< total_running_time_s / imgs
<< " s/img, total read img time: "
<< total_imread_time_s
<< " s, average read img time: "
<< total_imread_time_s / imgs
<< " s, batch_size = "
<< FLAGS_batch_size
<< std::endl;
return 0;
}
......@@ -83,7 +83,6 @@ int main(int argc, char** argv) {
model.predict(im_vec, results);
auto imread_duration = duration_cast<microseconds>(imread_end - start);
total_imread_time_s += double(imread_duration.count()) * microseconds::period::num / microseconds::period::den;
auto end = system_clock::now();
auto duration = duration_cast<microseconds>(end - start);
total_running_time_s += double(duration.count()) * microseconds::period::num / microseconds::period::den;
......
......@@ -58,6 +58,7 @@ class Transform {
public:
virtual void Init(const YAML::Node& item) = 0;
virtual bool Run(cv::Mat* im, ImageBlob* data) = 0;
virtual void SetPaddingSize(int max_h, int max_w) {}
};
class Normalize : public Transform {
......@@ -169,11 +170,13 @@ class Padding : public Transform {
}
}
virtual bool Run(cv::Mat* im, ImageBlob* data);
virtual void SetPaddingSize(int max_h, int max_w);
private:
int coarsest_stride_ = -1;
int width_ = 0;
int height_ = 0;
int max_height_ = 0;
int max_width_ = 0;
};
class Transforms {
......@@ -181,10 +184,12 @@ class Transforms {
void Init(const YAML::Node& node, bool to_rgb = true);
std::shared_ptr<Transform> CreateTransform(const std::string& name);
bool Run(cv::Mat* im, ImageBlob* data);
void SetPaddingSize(int max_h, int max_w);
private:
std::vector<std::shared_ptr<Transform>> transforms_;
bool to_rgb_ = true;
int max_h_ = 0;
int max_w_ = 0;
};
} // namespace PaddleX
......@@ -14,6 +14,7 @@
#include <algorithm>
#include <omp.h>
#include "include/paddlex/paddlex.h"
#include <fstream>
namespace PaddleX {
void Model::create_predictor(const std::string& model_dir,
......@@ -100,6 +101,9 @@ bool Model::load_config(const std::string& model_dir) {
bool Model::preprocess(const cv::Mat& input_im, ImageBlob* blob) {
cv::Mat im = input_im.clone();
int max_h = im.rows;
int max_w = im.cols;
transforms_.SetPaddingSize(max_h, max_w);
if (!transforms_.Run(&im, blob)) {
return false;
}
......@@ -110,7 +114,13 @@ bool Model::preprocess(const cv::Mat& input_im, ImageBlob* blob) {
bool Model::preprocess(const std::vector<cv::Mat> &input_im_batch, std::vector<ImageBlob> &blob_batch) {
int batch_size = inputs_batch_.size();
bool success = true;
//int i;
int max_h = -1;
int max_w = -1;
for(int i = 0; i < input_im_batch.size(); ++i) {
max_h = std::max(max_h, input_im_batch[i].rows);
max_w = std::max(max_w, input_im_batch[i].cols);
}
transforms_.SetPaddingSize(max_h, max_w);
#pragma omp parallel for num_threads(batch_size)
for(int i = 0; i < input_im_batch.size(); ++i) {
cv::Mat im = input_im_batch[i].clone();
......@@ -126,10 +136,6 @@ bool Model::predict(const cv::Mat& im, ClsResult* result) {
if (type == "detector") {
std::cerr << "Loading model is a 'detector', DetResult should be passed to "
"function predict()!"
<< std::endl;
return false;
} else if (type == "segmenter") {
std::cerr << "Loading model is a 'segmenter', SegResult should be passed "
"to function predict()!"
<< std::endl;
return false;
......@@ -224,7 +230,6 @@ bool Model::predict(const std::vector<cv::Mat> &im_batch, std::vector<ClsResult>
bool Model::predict(const cv::Mat& im, DetResult* result) {
result->clear();
inputs_.clear();
if (type == "classifier") {
std::cerr << "Loading model is a 'classifier', ClsResult should be passed "
"to function predict()!"
......@@ -248,9 +253,15 @@ bool Model::predict(const cv::Mat& im, DetResult* result) {
auto im_tensor = predictor_->GetInputTensor("image");
im_tensor->Reshape({1, 3, h, w});
im_tensor->copy_from_cpu(inputs_.im_data_.data());
std::ofstream fout("test_single.dat", std::ios::out);
if (name == "YOLOv3") {
auto im_size_tensor = predictor_->GetInputTensor("im_size");
im_size_tensor->Reshape({1, 2});
for(int i = 0; i < inputs_.ori_im_size_.size(); ++i) {
fout << inputs_.ori_im_size_[i] << " ";
}
fout << std::endl;
im_size_tensor->copy_from_cpu(inputs_.ori_im_size_.data());
} else if (name == "FasterRCNN" || name == "MaskRCNN") {
auto im_info_tensor = predictor_->GetInputTensor("im_info");
......@@ -283,6 +294,9 @@ bool Model::predict(const cv::Mat& im, DetResult* result) {
std::cerr << "[WARNING] There's no object detected." << std::endl;
return true;
}
for(int i = 0; i < output_box.size(); ++i) {
fout << output_box[i] << " ";
}
int num_boxes = size / 6;
// 解析预测框box
for (int i = 0; i < num_boxes; ++i) {
......@@ -326,6 +340,141 @@ bool Model::predict(const cv::Mat& im, DetResult* result) {
return true;
}
bool Model::predict(const std::vector<cv::Mat> &im_batch, std::vector<DetResult> &result) {
if (type == "classifier") {
std::cerr << "Loading model is a 'classifier', ClsResult should be passed "
"to function predict()!"
<< std::endl;
return false;
} else if (type == "segmenter") {
std::cerr << "Loading model is a 'segmenter', SegResult should be passed "
"to function predict()!"
<< std::endl;
return false;
}
// 处理输入图像
if (!preprocess(im_batch, inputs_batch_)) {
std::cerr << "Preprocess failed!" << std::endl;
return false;
}
int batch_size = im_batch.size();
int h = inputs_batch_[0].new_im_size_[0];
int w = inputs_batch_[0].new_im_size_[1];
auto im_tensor = predictor_->GetInputTensor("image");
im_tensor->Reshape({batch_size, 3, h, w});
std::vector<float> inputs_data(batch_size * 3 * h * w);
for(int i = 0; i < inputs_batch_.size(); ++i) {
std::copy(inputs_batch_[i].im_data_.begin(), inputs_batch_[i].im_data_.end(), inputs_data.begin() + i * 3 * h * w);
}
im_tensor->copy_from_cpu(inputs_data.data());
std::ofstream fout("test_batch.dat", std::ios::out);
if (name == "YOLOv3") {
auto im_size_tensor = predictor_->GetInputTensor("im_size");
im_size_tensor->Reshape({batch_size, 2});
std::vector<int> inputs_data_size(batch_size * 2);
for(int i = 0; i < inputs_batch_.size(); ++i){
std::copy(inputs_batch_[i].ori_im_size_.begin(), inputs_batch_[i].ori_im_size_.end(), inputs_data_size.begin() + 2 * i);
}
for(int i = 0; i < inputs_data_size.size(); ++i) {
fout << inputs_data_size[i] << " ";
}
fout << std::endl;
im_size_tensor->copy_from_cpu(inputs_data_size.data());
} else if (name == "FasterRCNN" || name == "MaskRCNN") {
auto im_info_tensor = predictor_->GetInputTensor("im_info");
auto im_shape_tensor = predictor_->GetInputTensor("im_shape");
im_info_tensor->Reshape({batch_size, 3});
im_shape_tensor->Reshape({batch_size, 3});
std::vector<float> im_info(3 * batch_size);
std::vector<float> im_shape(3 * batch_size);
for(int i = 0; i < inputs_batch_.size(); ++i) {
float ori_h = static_cast<float>(inputs_batch_[i].ori_im_size_[0]);
float ori_w = static_cast<float>(inputs_batch_[i].ori_im_size_[1]);
float new_h = static_cast<float>(inputs_batch_[i].new_im_size_[0]);
float new_w = static_cast<float>(inputs_batch_[i].new_im_size_[1]);
im_info[i * 3] = new_h;
im_info[i * 3 + 1] = new_w;
im_info[i * 3 + 2] = inputs_batch_[i].scale;
im_shape[i * 3] = ori_h;
im_shape[i * 3 + 1] = ori_w;
im_shape[i * 3 + 2] = 1.0;
}
im_info_tensor->copy_from_cpu(im_info.data());
im_shape_tensor->copy_from_cpu(im_shape.data());
}
// 使用加载的模型进行预测
predictor_->ZeroCopyRun();
// 读取所有box
std::vector<float> output_box;
auto output_names = predictor_->GetOutputNames();
auto output_box_tensor = predictor_->GetOutputTensor(output_names[0]);
std::vector<int> output_box_shape = output_box_tensor->shape();
int size = 1;
for (const auto& i : output_box_shape) {
size *= i;
}
output_box.resize(size);
output_box_tensor->copy_to_cpu(output_box.data());
if (size < 6) {
std::cerr << "[WARNING] There's no object detected." << std::endl;
return true;
}
for(int i = 0; i < output_box.size(); ++i) {
fout << output_box[i] << " ";
}
auto lod_vector = output_box_tensor->lod();
int num_boxes = size / 6;
// 解析预测框box
for (int i = 0; i < lod_vector[0].size() - 1; ++i) {
for(int j = lod_vector[0][i]; j < lod_vector[0][i + 1]; ++j) {
Box box;
box.category_id = static_cast<int> (round(output_box[j * 6]));
box.category = labels[box.category_id];
box.score = output_box[j * 6 + 1];
float xmin = output_box[j * 6 + 2];
float ymin = output_box[j * 6 + 3];
float xmax = output_box[j * 6 + 4];
float ymax = output_box[j * 6 + 5];
float w = xmax - xmin + 1;
float h = ymax - ymin + 1;
box.coordinate = {xmin, ymin, w, h};
result[i].boxes.push_back(std::move(box));
}
}
// 实例分割需解析mask
if (name == "MaskRCNN") {
std::vector<float> output_mask;
auto output_mask_tensor = predictor_->GetOutputTensor(output_names[1]);
std::vector<int> output_mask_shape = output_mask_tensor->shape();
int masks_size = 1;
for (const auto& i : output_mask_shape) {
masks_size *= i;
}
int mask_pixels = output_mask_shape[2] * output_mask_shape[3];
int classes = output_mask_shape[1];
output_mask.resize(masks_size);
output_mask_tensor->copy_to_cpu(output_mask.data());
int mask_idx = 0;
for(int i = 0; i < lod_vector[0].size() - 1; ++i) {
result[i].mask_resolution = output_mask_shape[2];
for(int j = 0; j < result[i].boxes.size(); ++j) {
Box* box = &result[i].boxes[j];
auto begin_mask = output_mask.begin() + (mask_idx * classes + box->category_id) * mask_pixels;
auto end_mask = begin_mask + mask_pixels;
box->mask.data.assign(begin_mask, end_mask);
box->mask.shape = {static_cast<int>(box->coordinate[2]),
static_cast<int>(box->coordinate[3])};
mask_idx++;
}
}
}
return true;
}
bool Model::predict(const cv::Mat& im, SegResult* result) {
result->clear();
inputs_.clear();
......
......@@ -95,11 +95,11 @@ bool Padding::Run(cv::Mat* im, ImageBlob* data) {
if (width_ > 1 & height_ > 1) {
padding_w = width_ - im->cols;
padding_h = height_ - im->rows;
} else if (coarsest_stride_ > 1) {
} else if (coarsest_stride_ >= 1) {
padding_h =
ceil(im->rows * 1.0 / coarsest_stride_) * coarsest_stride_ - im->rows;
ceil(max_height_ * 1.0 / coarsest_stride_) * coarsest_stride_ - im->rows;
padding_w =
ceil(im->cols * 1.0 / coarsest_stride_) * coarsest_stride_ - im->cols;
ceil(max_width_ * 1.0 / coarsest_stride_) * coarsest_stride_ - im->cols;
}
if (padding_h < 0 || padding_w < 0) {
......@@ -115,6 +115,11 @@ bool Padding::Run(cv::Mat* im, ImageBlob* data) {
return true;
}
void Padding::SetPaddingSize(int max_h, int max_w) {
max_height_ = max_h;
max_width_ = max_w;
}
bool ResizeByLong::Run(cv::Mat* im, ImageBlob* data) {
if (long_size_ <= 0) {
std::cerr << "[ResizeByLong] long_size should be greater than 0"
......@@ -201,6 +206,7 @@ bool Transforms::Run(cv::Mat* im, ImageBlob* data) {
data->new_im_size_[0] = im->rows;
data->new_im_size_[1] = im->cols;
for (int i = 0; i < transforms_.size(); ++i) {
transforms_[i]->SetPaddingSize(max_h_, max_w_);
if (!transforms_[i]->Run(im, data)) {
std::cerr << "Apply transforms to image failed!" << std::endl;
return false;
......@@ -219,4 +225,10 @@ bool Transforms::Run(cv::Mat* im, ImageBlob* data) {
}
return true;
}
void Transforms::SetPaddingSize(int max_h, int max_w) {
max_h_ = max_h;
max_w_ = max_w;
}
} // namespace PaddleX
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册