diff --git a/deploy/cpp/demo/detector.cpp b/deploy/cpp/demo/detector.cpp index e42288fbccd434ef5953c606696af623323aa80d..428505e30f7073ea3e66450dacb62e7f83962ea7 100644 --- a/deploy/cpp/demo/detector.cpp +++ b/deploy/cpp/demo/detector.cpp @@ -14,14 +14,19 @@ #include +#include +#include #include #include #include #include +#include #include "include/paddlex/paddlex.h" #include "include/paddlex/visualize.h" +using namespace std::chrono; + DEFINE_string(model_dir, "", "Path of inference model"); DEFINE_bool(use_gpu, false, "Infering with GPU or CPU"); DEFINE_bool(use_trt, false, "Infering with TensorRT"); @@ -30,6 +35,7 @@ DEFINE_string(key, "", "key of encryption"); DEFINE_string(image, "", "Path of test image file"); DEFINE_string(image_list, "", "Path of test image list file"); DEFINE_string(save_dir, "output", "Path to save visualized image"); +DEFINE_int32(batch_size, 1, ""); int main(int argc, char** argv) { // 解析命令行参数 @@ -46,8 +52,11 @@ int main(int argc, char** argv) { // 加载模型 PaddleX::Model model; - model.Init(FLAGS_model_dir, FLAGS_use_gpu, FLAGS_use_trt, FLAGS_gpu_id, FLAGS_key); + model.Init(FLAGS_model_dir, FLAGS_use_gpu, FLAGS_use_trt, FLAGS_gpu_id, FLAGS_key, FLAGS_batch_size); + double total_running_time_s = 0.0; + double total_imread_time_s = 0.0; + int imgs = 1; auto colormap = PaddleX::GenerateColorMap(model.labels.size()); std::string save_dir = "output"; // 进行预测 @@ -58,35 +67,57 @@ int main(int argc, char** argv) { return -1; } std::string image_path; + std::vector image_paths; while (getline(inf, image_path)) { - PaddleX::DetResult result; - cv::Mat im = cv::imread(image_path, 1); - model.predict(im, &result); - for (int i = 0; i < result.boxes.size(); ++i) { - std::cout << "image file: " << image_path - << ", predict label: " << result.boxes[i].category - << ", label_id:" << result.boxes[i].category_id - << ", score: " << result.boxes[i].score << ", box(xmin, ymin, w, h):(" - << result.boxes[i].coordinate[0] << ", " - << result.boxes[i].coordinate[1] << ", " - << result.boxes[i].coordinate[2] << ", " - << result.boxes[i].coordinate[3] << ")" << std::endl; + image_paths.push_back(image_path); + } + imgs = image_paths.size(); + for(int i = 0; i < image_paths.size(); i += FLAGS_batch_size) { + auto start = system_clock::now(); + int im_vec_size = std::min((int)image_paths.size(), i + FLAGS_batch_size); + std::vector im_vec(im_vec_size - i); + std::vector results(im_vec_size - i, PaddleX::DetResult()); + #pragma omp parallel for num_threads(im_vec_size - i) + for(int j = i; j < im_vec_size; ++j){ + im_vec[j - i] = std::move(cv::imread(image_paths[j], 1)); + } + auto imread_end = system_clock::now(); + model.predict(im_vec, results); + auto imread_duration = duration_cast(imread_end - start); + total_imread_time_s += double(imread_duration.count()) * microseconds::period::num / microseconds::period::den; + auto end = system_clock::now(); + auto duration = duration_cast(end - start); + total_running_time_s += double(duration.count()) * microseconds::period::num / microseconds::period::den; + //输出结果目标框 + for(int j = 0; j < im_vec_size - i; ++j) { + std::cout << "image file: " << image_paths[i + j] << std::endl; + for(int k = 0; k < results[j].boxes.size(); ++k) { + std::cout << "predict label: " << results[j].boxes[k].category + << ", label_id:" << results[j].boxes[k].category_id + << ", score: " << results[j].boxes[k].score << ", box(xmin, ymin, w, h):(" + << results[j].boxes[k].coordinate[0] << ", " + << results[j].boxes[k].coordinate[1] << ", " + << results[j].boxes[k].coordinate[2] << ", " + << results[j].boxes[k].coordinate[3] << ")" << std::endl; + + } } - // 可视化 - cv::Mat vis_img = - PaddleX::Visualize(im, result, model.labels, colormap, 0.5); - std::string save_path = - PaddleX::generate_save_path(FLAGS_save_dir, image_path); - cv::imwrite(save_path, vis_img); - result.clear(); - std::cout << "Visualized output saved as " << save_path << std::endl; + for(int j = 0; j < im_vec_size - i; ++j) { + cv::Mat vis_img = + PaddleX::Visualize(im_vec[j], results[j], model.labels, colormap, 0.5); + std::string save_path = + PaddleX::generate_save_path(FLAGS_save_dir, image_paths[i + j]); + cv::imwrite(save_path, vis_img); + std::cout << "Visualized output saved as " << save_path << std::endl; + } } } else { PaddleX::DetResult result; cv::Mat im = cv::imread(FLAGS_image, 1); model.predict(im, &result); for (int i = 0; i < result.boxes.size(); ++i) { + std::cout << "image file: " << FLAGS_image << std::endl; std::cout << ", predict label: " << result.boxes[i].category << ", label_id:" << result.boxes[i].category_id << ", score: " << result.boxes[i].score << ", box(xmin, ymin, w, h):(" @@ -105,6 +136,18 @@ int main(int argc, char** argv) { result.clear(); std::cout << "Visualized output saved as " << save_path << std::endl; } + + std::cout << "Total running time: " + << total_running_time_s + << " s, average running time: " + << total_running_time_s / imgs + << " s/img, total read img time: " + << total_imread_time_s + << " s, average read img time: " + << total_imread_time_s / imgs + << " s, batch_size = " + << FLAGS_batch_size + << std::endl; return 0; } diff --git a/deploy/cpp/demo/segmenter.cpp b/deploy/cpp/demo/segmenter.cpp index 31feb7730920e004f616cfd41003225687afee28..fb49cd2c642b679afa00f79f4482cc092de52a34 100644 --- a/deploy/cpp/demo/segmenter.cpp +++ b/deploy/cpp/demo/segmenter.cpp @@ -83,7 +83,6 @@ int main(int argc, char** argv) { model.predict(im_vec, results); auto imread_duration = duration_cast(imread_end - start); total_imread_time_s += double(imread_duration.count()) * microseconds::period::num / microseconds::period::den; - auto end = system_clock::now(); auto duration = duration_cast(end - start); total_running_time_s += double(duration.count()) * microseconds::period::num / microseconds::period::den; diff --git a/deploy/cpp/include/paddlex/transforms.h b/deploy/cpp/include/paddlex/transforms.h index f8265db447f693d084c5a789504bc4b0ccc14d28..7c5892ece5569f598b83704c2122f7cf2fda2b9e 100644 --- a/deploy/cpp/include/paddlex/transforms.h +++ b/deploy/cpp/include/paddlex/transforms.h @@ -58,6 +58,7 @@ class Transform { public: virtual void Init(const YAML::Node& item) = 0; virtual bool Run(cv::Mat* im, ImageBlob* data) = 0; + virtual void SetPaddingSize(int max_h, int max_w) {} }; class Normalize : public Transform { @@ -169,11 +170,13 @@ class Padding : public Transform { } } virtual bool Run(cv::Mat* im, ImageBlob* data); - + virtual void SetPaddingSize(int max_h, int max_w); private: int coarsest_stride_ = -1; int width_ = 0; int height_ = 0; + int max_height_ = 0; + int max_width_ = 0; }; class Transforms { @@ -181,10 +184,12 @@ class Transforms { void Init(const YAML::Node& node, bool to_rgb = true); std::shared_ptr CreateTransform(const std::string& name); bool Run(cv::Mat* im, ImageBlob* data); - + void SetPaddingSize(int max_h, int max_w); private: std::vector> transforms_; bool to_rgb_ = true; + int max_h_ = 0; + int max_w_ = 0; }; } // namespace PaddleX diff --git a/deploy/cpp/src/paddlex.cpp b/deploy/cpp/src/paddlex.cpp index df22c76361c565f0a55ad19cb078a4e7af1dbf35..99757a5b0981ca55640dd44f02262eff1ae44585 100644 --- a/deploy/cpp/src/paddlex.cpp +++ b/deploy/cpp/src/paddlex.cpp @@ -14,6 +14,7 @@ #include #include #include "include/paddlex/paddlex.h" +#include namespace PaddleX { void Model::create_predictor(const std::string& model_dir, @@ -100,6 +101,9 @@ bool Model::load_config(const std::string& model_dir) { bool Model::preprocess(const cv::Mat& input_im, ImageBlob* blob) { cv::Mat im = input_im.clone(); + int max_h = im.rows; + int max_w = im.cols; + transforms_.SetPaddingSize(max_h, max_w); if (!transforms_.Run(&im, blob)) { return false; } @@ -110,7 +114,13 @@ bool Model::preprocess(const cv::Mat& input_im, ImageBlob* blob) { bool Model::preprocess(const std::vector &input_im_batch, std::vector &blob_batch) { int batch_size = inputs_batch_.size(); bool success = true; - //int i; + int max_h = -1; + int max_w = -1; + for(int i = 0; i < input_im_batch.size(); ++i) { + max_h = std::max(max_h, input_im_batch[i].rows); + max_w = std::max(max_w, input_im_batch[i].cols); + } + transforms_.SetPaddingSize(max_h, max_w); #pragma omp parallel for num_threads(batch_size) for(int i = 0; i < input_im_batch.size(); ++i) { cv::Mat im = input_im_batch[i].clone(); @@ -126,10 +136,6 @@ bool Model::predict(const cv::Mat& im, ClsResult* result) { if (type == "detector") { std::cerr << "Loading model is a 'detector', DetResult should be passed to " "function predict()!" - << std::endl; - return false; - } else if (type == "segmenter") { - std::cerr << "Loading model is a 'segmenter', SegResult should be passed " "to function predict()!" << std::endl; return false; @@ -224,7 +230,6 @@ bool Model::predict(const std::vector &im_batch, std::vector bool Model::predict(const cv::Mat& im, DetResult* result) { result->clear(); - inputs_.clear(); if (type == "classifier") { std::cerr << "Loading model is a 'classifier', ClsResult should be passed " "to function predict()!" @@ -248,9 +253,15 @@ bool Model::predict(const cv::Mat& im, DetResult* result) { auto im_tensor = predictor_->GetInputTensor("image"); im_tensor->Reshape({1, 3, h, w}); im_tensor->copy_from_cpu(inputs_.im_data_.data()); + + std::ofstream fout("test_single.dat", std::ios::out); if (name == "YOLOv3") { auto im_size_tensor = predictor_->GetInputTensor("im_size"); im_size_tensor->Reshape({1, 2}); + for(int i = 0; i < inputs_.ori_im_size_.size(); ++i) { + fout << inputs_.ori_im_size_[i] << " "; + } + fout << std::endl; im_size_tensor->copy_from_cpu(inputs_.ori_im_size_.data()); } else if (name == "FasterRCNN" || name == "MaskRCNN") { auto im_info_tensor = predictor_->GetInputTensor("im_info"); @@ -283,6 +294,9 @@ bool Model::predict(const cv::Mat& im, DetResult* result) { std::cerr << "[WARNING] There's no object detected." << std::endl; return true; } + for(int i = 0; i < output_box.size(); ++i) { + fout << output_box[i] << " "; + } int num_boxes = size / 6; // 解析预测框box for (int i = 0; i < num_boxes; ++i) { @@ -326,6 +340,141 @@ bool Model::predict(const cv::Mat& im, DetResult* result) { return true; } +bool Model::predict(const std::vector &im_batch, std::vector &result) { + if (type == "classifier") { + std::cerr << "Loading model is a 'classifier', ClsResult should be passed " + "to function predict()!" + << std::endl; + return false; + } else if (type == "segmenter") { + std::cerr << "Loading model is a 'segmenter', SegResult should be passed " + "to function predict()!" + << std::endl; + return false; + } + + // 处理输入图像 + if (!preprocess(im_batch, inputs_batch_)) { + std::cerr << "Preprocess failed!" << std::endl; + return false; + } + int batch_size = im_batch.size(); + int h = inputs_batch_[0].new_im_size_[0]; + int w = inputs_batch_[0].new_im_size_[1]; + auto im_tensor = predictor_->GetInputTensor("image"); + im_tensor->Reshape({batch_size, 3, h, w}); + std::vector inputs_data(batch_size * 3 * h * w); + for(int i = 0; i < inputs_batch_.size(); ++i) { + std::copy(inputs_batch_[i].im_data_.begin(), inputs_batch_[i].im_data_.end(), inputs_data.begin() + i * 3 * h * w); + } + im_tensor->copy_from_cpu(inputs_data.data()); + std::ofstream fout("test_batch.dat", std::ios::out); + if (name == "YOLOv3") { + auto im_size_tensor = predictor_->GetInputTensor("im_size"); + im_size_tensor->Reshape({batch_size, 2}); + std::vector inputs_data_size(batch_size * 2); + for(int i = 0; i < inputs_batch_.size(); ++i){ + std::copy(inputs_batch_[i].ori_im_size_.begin(), inputs_batch_[i].ori_im_size_.end(), inputs_data_size.begin() + 2 * i); + } + for(int i = 0; i < inputs_data_size.size(); ++i) { + fout << inputs_data_size[i] << " "; + } + fout << std::endl; + im_size_tensor->copy_from_cpu(inputs_data_size.data()); + } else if (name == "FasterRCNN" || name == "MaskRCNN") { + auto im_info_tensor = predictor_->GetInputTensor("im_info"); + auto im_shape_tensor = predictor_->GetInputTensor("im_shape"); + im_info_tensor->Reshape({batch_size, 3}); + im_shape_tensor->Reshape({batch_size, 3}); + + std::vector im_info(3 * batch_size); + std::vector im_shape(3 * batch_size); + for(int i = 0; i < inputs_batch_.size(); ++i) { + float ori_h = static_cast(inputs_batch_[i].ori_im_size_[0]); + float ori_w = static_cast(inputs_batch_[i].ori_im_size_[1]); + float new_h = static_cast(inputs_batch_[i].new_im_size_[0]); + float new_w = static_cast(inputs_batch_[i].new_im_size_[1]); + im_info[i * 3] = new_h; + im_info[i * 3 + 1] = new_w; + im_info[i * 3 + 2] = inputs_batch_[i].scale; + im_shape[i * 3] = ori_h; + im_shape[i * 3 + 1] = ori_w; + im_shape[i * 3 + 2] = 1.0; + } + im_info_tensor->copy_from_cpu(im_info.data()); + im_shape_tensor->copy_from_cpu(im_shape.data()); + } + // 使用加载的模型进行预测 + predictor_->ZeroCopyRun(); + + // 读取所有box + std::vector output_box; + auto output_names = predictor_->GetOutputNames(); + auto output_box_tensor = predictor_->GetOutputTensor(output_names[0]); + std::vector output_box_shape = output_box_tensor->shape(); + int size = 1; + for (const auto& i : output_box_shape) { + size *= i; + } + output_box.resize(size); + output_box_tensor->copy_to_cpu(output_box.data()); + if (size < 6) { + std::cerr << "[WARNING] There's no object detected." << std::endl; + return true; + } + for(int i = 0; i < output_box.size(); ++i) { + fout << output_box[i] << " "; + } + auto lod_vector = output_box_tensor->lod(); + int num_boxes = size / 6; + // 解析预测框box + for (int i = 0; i < lod_vector[0].size() - 1; ++i) { + for(int j = lod_vector[0][i]; j < lod_vector[0][i + 1]; ++j) { + Box box; + box.category_id = static_cast (round(output_box[j * 6])); + box.category = labels[box.category_id]; + box.score = output_box[j * 6 + 1]; + float xmin = output_box[j * 6 + 2]; + float ymin = output_box[j * 6 + 3]; + float xmax = output_box[j * 6 + 4]; + float ymax = output_box[j * 6 + 5]; + float w = xmax - xmin + 1; + float h = ymax - ymin + 1; + box.coordinate = {xmin, ymin, w, h}; + result[i].boxes.push_back(std::move(box)); + } + } + + // 实例分割需解析mask + if (name == "MaskRCNN") { + std::vector output_mask; + auto output_mask_tensor = predictor_->GetOutputTensor(output_names[1]); + std::vector output_mask_shape = output_mask_tensor->shape(); + int masks_size = 1; + for (const auto& i : output_mask_shape) { + masks_size *= i; + } + int mask_pixels = output_mask_shape[2] * output_mask_shape[3]; + int classes = output_mask_shape[1]; + output_mask.resize(masks_size); + output_mask_tensor->copy_to_cpu(output_mask.data()); + int mask_idx = 0; + for(int i = 0; i < lod_vector[0].size() - 1; ++i) { + result[i].mask_resolution = output_mask_shape[2]; + for(int j = 0; j < result[i].boxes.size(); ++j) { + Box* box = &result[i].boxes[j]; + auto begin_mask = output_mask.begin() + (mask_idx * classes + box->category_id) * mask_pixels; + auto end_mask = begin_mask + mask_pixels; + box->mask.data.assign(begin_mask, end_mask); + box->mask.shape = {static_cast(box->coordinate[2]), + static_cast(box->coordinate[3])}; + mask_idx++; + } + } + } + return true; +} + bool Model::predict(const cv::Mat& im, SegResult* result) { result->clear(); inputs_.clear(); diff --git a/deploy/cpp/src/transforms.cpp b/deploy/cpp/src/transforms.cpp index 9224367d3522ebe4e323a40a1af92be7cfeae9d3..27977ef44fa79b8486ef717c4ee0a8f4f653f62e 100644 --- a/deploy/cpp/src/transforms.cpp +++ b/deploy/cpp/src/transforms.cpp @@ -95,11 +95,11 @@ bool Padding::Run(cv::Mat* im, ImageBlob* data) { if (width_ > 1 & height_ > 1) { padding_w = width_ - im->cols; padding_h = height_ - im->rows; - } else if (coarsest_stride_ > 1) { + } else if (coarsest_stride_ >= 1) { padding_h = - ceil(im->rows * 1.0 / coarsest_stride_) * coarsest_stride_ - im->rows; + ceil(max_height_ * 1.0 / coarsest_stride_) * coarsest_stride_ - im->rows; padding_w = - ceil(im->cols * 1.0 / coarsest_stride_) * coarsest_stride_ - im->cols; + ceil(max_width_ * 1.0 / coarsest_stride_) * coarsest_stride_ - im->cols; } if (padding_h < 0 || padding_w < 0) { @@ -115,6 +115,11 @@ bool Padding::Run(cv::Mat* im, ImageBlob* data) { return true; } +void Padding::SetPaddingSize(int max_h, int max_w) { + max_height_ = max_h; + max_width_ = max_w; +} + bool ResizeByLong::Run(cv::Mat* im, ImageBlob* data) { if (long_size_ <= 0) { std::cerr << "[ResizeByLong] long_size should be greater than 0" @@ -201,6 +206,7 @@ bool Transforms::Run(cv::Mat* im, ImageBlob* data) { data->new_im_size_[0] = im->rows; data->new_im_size_[1] = im->cols; for (int i = 0; i < transforms_.size(); ++i) { + transforms_[i]->SetPaddingSize(max_h_, max_w_); if (!transforms_[i]->Run(im, data)) { std::cerr << "Apply transforms to image failed!" << std::endl; return false; @@ -219,4 +225,10 @@ bool Transforms::Run(cv::Mat* im, ImageBlob* data) { } return true; } + +void Transforms::SetPaddingSize(int max_h, int max_w) { + max_h_ = max_h; + max_w_ = max_w; +} + } // namespace PaddleX