Commit 01e54a1f authored by: J jack

use google style

Parent afb8620c
......@@ -13,18 +13,18 @@
// limitations under the License.
#include <glog/logging.h>
#include <omp.h>
#include <algorithm>
#include <chrono>
#include <chrono> // NOLINT
#include <fstream>
#include <iostream>
#include <string>
#include <vector>
#include <utility>
#include <omp.h>
#include "include/paddlex/paddlex.h"
using namespace std::chrono;
using namespace std::chrono; // NOLINT
DEFINE_string(model_dir, "", "Path of inference model");
DEFINE_bool(use_gpu, false, "Inferring with GPU or CPU");
......@@ -34,7 +34,9 @@ DEFINE_string(key, "", "key of encryption");
DEFINE_string(image, "", "Path of test image file");
DEFINE_string(image_list, "", "Path of test image list file");
DEFINE_int32(batch_size, 1, "Batch size for inference");
DEFINE_int32(thread_num, omp_get_num_procs(), "Number of preprocessing threads");
DEFINE_int32(thread_num,
omp_get_num_procs(),
"Number of preprocessing threads");
int main(int argc, char** argv) {
// Parse command-line flags
......@@ -51,7 +53,12 @@ int main(int argc, char** argv) {
// Load the model
PaddleX::Model model;
model.Init(FLAGS_model_dir, FLAGS_use_gpu, FLAGS_use_trt, FLAGS_gpu_id, FLAGS_key, FLAGS_batch_size);
model.Init(FLAGS_model_dir,
FLAGS_use_gpu,
FLAGS_use_trt,
FLAGS_gpu_id,
FLAGS_key,
FLAGS_batch_size);
// Run prediction
double total_running_time_s = 0.0;
......@@ -70,27 +77,33 @@ int main(int argc, char** argv) {
image_paths.push_back(image_path);
}
imgs = image_paths.size();
for(int i = 0; i < image_paths.size(); i += FLAGS_batch_size) {
for (int i = 0; i < image_paths.size(); i += FLAGS_batch_size) {
auto start = system_clock::now();
// Read images
int im_vec_size = std::min((int)image_paths.size(), i + FLAGS_batch_size);
int im_vec_size =
std::min(static_cast<int>(image_paths.size()), i + FLAGS_batch_size);
std::vector<cv::Mat> im_vec(im_vec_size - i);
std::vector<PaddleX::ClsResult> results(im_vec_size - i, PaddleX::ClsResult());
std::vector<PaddleX::ClsResult> results(im_vec_size - i,
PaddleX::ClsResult());
int thread_num = std::min(FLAGS_thread_num, im_vec_size - i);
#pragma omp parallel for num_threads(thread_num)
for(int j = i; j < im_vec_size; ++j){
for (int j = i; j < im_vec_size; ++j) {
im_vec[j - i] = std::move(cv::imread(image_paths[j], 1));
}
auto imread_end = system_clock::now();
model.predict(im_vec, results, thread_num);
model.predict(im_vec, &results, thread_num);
auto imread_duration = duration_cast<microseconds>(imread_end - start);
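// duration_cast yields microseconds; scaling by microseconds::period::num /
// microseconds::period::den converts the tick count to seconds.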
total_imread_time_s += double(imread_duration.count()) * microseconds::period::num / microseconds::period::den;
total_imread_time_s += static_cast<double>(imread_duration.count()) *
microseconds::period::num /
microseconds::period::den;
auto end = system_clock::now();
auto duration = duration_cast<microseconds>(end - start);
total_running_time_s += double(duration.count()) * microseconds::period::num / microseconds::period::den;
for(int j = i; j < im_vec_size; ++j) {
total_running_time_s += static_cast<double>(duration.count()) *
microseconds::period::num /
microseconds::period::den;
for (int j = i; j < im_vec_size; ++j) {
std::cout << "Path:" << image_paths[j]
<< ", predict label: " << results[j - i].category
<< ", label_id:" << results[j - i].category_id
......@@ -104,21 +117,17 @@ int main(int argc, char** argv) {
model.predict(im, &result);
auto end = system_clock::now();
auto duration = duration_cast<microseconds>(end - start);
total_running_time_s += double(duration.count()) * microseconds::period::num / microseconds::period::den;
total_running_time_s += static_cast<double>(duration.count()) *
microseconds::period::num /
microseconds::period::den;
std::cout << "Predict label: " << result.category
<< ", label_id:" << result.category_id
<< ", score: " << result.score << std::endl;
}
std::cout << "Total running time: "
<< total_running_time_s
<< " s, average running time: "
<< total_running_time_s / imgs
<< " s/img, total read img time: "
<< total_imread_time_s
<< " s, average read time: "
<< total_imread_time_s / imgs
<< " s/img, batch_size = "
<< FLAGS_batch_size
<< std::endl;
std::cout << "Total running time: " << total_running_time_s
<< " s, average running time: " << total_running_time_s / imgs
<< " s/img, total read img time: " << total_imread_time_s
<< " s, average read time: " << total_imread_time_s / imgs
<< " s/img, batch_size = " << FLAGS_batch_size << std::endl;
return 0;
}
......@@ -13,20 +13,20 @@
// limitations under the License.
#include <glog/logging.h>
#include <omp.h>
#include <algorithm>
#include <chrono>
#include <chrono> // NOLINT
#include <fstream>
#include <iostream>
#include <string>
#include <vector>
#include <utility>
#include <omp.h>
#include "include/paddlex/paddlex.h"
#include "include/paddlex/visualize.h"
using namespace std::chrono;
using namespace std::chrono; // NOLINT
DEFINE_string(model_dir, "", "Path of inference model");
DEFINE_bool(use_gpu, false, "Inferring with GPU or CPU");
......@@ -37,8 +37,12 @@ DEFINE_string(image, "", "Path of test image file");
DEFINE_string(image_list, "", "Path of test image list file");
DEFINE_string(save_dir, "output", "Path to save visualized image");
DEFINE_int32(batch_size, 1, "Batch size for inference");
DEFINE_double(threshold, 0.5, "The minimum score of target boxes to be shown");
DEFINE_int32(thread_num, omp_get_num_procs(), "Number of preprocessing threads");
DEFINE_double(threshold,
0.5,
"The minimum score of target boxes to be shown");
DEFINE_int32(thread_num,
omp_get_num_procs(),
"Number of preprocessing threads");
int main(int argc, char** argv) {
// Parse command-line arguments
......@@ -55,7 +59,12 @@ int main(int argc, char** argv) {
std::cout << "Thread num: " << FLAGS_thread_num << std::endl;
// Load the model
PaddleX::Model model;
model.Init(FLAGS_model_dir, FLAGS_use_gpu, FLAGS_use_trt, FLAGS_gpu_id, FLAGS_key, FLAGS_batch_size);
model.Init(FLAGS_model_dir,
FLAGS_use_gpu,
FLAGS_use_trt,
FLAGS_gpu_id,
FLAGS_key,
FLAGS_batch_size);
double total_running_time_s = 0.0;
double total_imread_time_s = 0.0;
......@@ -75,41 +84,47 @@ int main(int argc, char** argv) {
image_paths.push_back(image_path);
}
imgs = image_paths.size();
for(int i = 0; i < image_paths.size(); i += FLAGS_batch_size) {
for (int i = 0; i < image_paths.size(); i += FLAGS_batch_size) {
auto start = system_clock::now();
int im_vec_size = std::min((int)image_paths.size(), i + FLAGS_batch_size);
int im_vec_size =
std::min(static_cast<int>(image_paths.size()), i + FLAGS_batch_size);
std::vector<cv::Mat> im_vec(im_vec_size - i);
std::vector<PaddleX::DetResult> results(im_vec_size - i, PaddleX::DetResult());
std::vector<PaddleX::DetResult> results(im_vec_size - i,
PaddleX::DetResult());
int thread_num = std::min(FLAGS_thread_num, im_vec_size - i);
#pragma omp parallel for num_threads(thread_num)
for(int j = i; j < im_vec_size; ++j){
for (int j = i; j < im_vec_size; ++j) {
im_vec[j - i] = std::move(cv::imread(image_paths[j], 1));
}
auto imread_end = system_clock::now();
model.predict(im_vec, results, thread_num);
model.predict(im_vec, &results, thread_num);
auto imread_duration = duration_cast<microseconds>(imread_end - start);
total_imread_time_s += double(imread_duration.count()) * microseconds::period::num / microseconds::period::den;
total_imread_time_s += static_cast<double>(imread_duration.count()) *
microseconds::period::num /
microseconds::period::den;
auto end = system_clock::now();
auto duration = duration_cast<microseconds>(end - start);
total_running_time_s += double(duration.count()) * microseconds::period::num / microseconds::period::den;
//Print the predicted bounding boxes
for(int j = 0; j < im_vec_size - i; ++j) {
for(int k = 0; k < results[j].boxes.size(); ++k) {
std::cout << "image file: " << image_paths[i + j] << ", ";// << std::endl;
total_running_time_s += static_cast<double>(duration.count()) *
microseconds::period::num /
microseconds::period::den;
// Print the predicted bounding boxes
for (int j = 0; j < im_vec_size - i; ++j) {
for (int k = 0; k < results[j].boxes.size(); ++k) {
std::cout << "image file: " << image_paths[i + j] << ", ";
std::cout << "predict label: " << results[j].boxes[k].category
<< ", label_id:" << results[j].boxes[k].category_id
<< ", score: " << results[j].boxes[k].score << ", box(xmin, ymin, w, h):("
<< ", score: " << results[j].boxes[k].score
<< ", box(xmin, ymin, w, h):("
<< results[j].boxes[k].coordinate[0] << ", "
<< results[j].boxes[k].coordinate[1] << ", "
<< results[j].boxes[k].coordinate[2] << ", "
<< results[j].boxes[k].coordinate[3] << ")" << std::endl;
}
}
// Visualization
for(int j = 0; j < im_vec_size - i; ++j) {
cv::Mat vis_img =
PaddleX::Visualize(im_vec[j], results[j], model.labels, colormap, FLAGS_threshold);
for (int j = 0; j < im_vec_size - i; ++j) {
cv::Mat vis_img = PaddleX::Visualize(
im_vec[j], results[j], model.labels, colormap, FLAGS_threshold);
std::string save_path =
PaddleX::generate_save_path(FLAGS_save_dir, image_paths[i + j]);
cv::imwrite(save_path, vis_img);
......@@ -124,9 +139,9 @@ int main(int argc, char** argv) {
std::cout << "image file: " << FLAGS_image << std::endl;
std::cout << ", predict label: " << result.boxes[i].category
<< ", label_id:" << result.boxes[i].category_id
<< ", score: " << result.boxes[i].score << ", box(xmin, ymin, w, h):("
<< result.boxes[i].coordinate[0] << ", "
<< result.boxes[i].coordinate[1] << ", "
<< ", score: " << result.boxes[i].score
<< ", box(xmin, ymin, w, h):(" << result.boxes[i].coordinate[0]
<< ", " << result.boxes[i].coordinate[1] << ", "
<< result.boxes[i].coordinate[2] << ", "
<< result.boxes[i].coordinate[3] << ")" << std::endl;
}
......@@ -141,17 +156,11 @@ int main(int argc, char** argv) {
std::cout << "Visualized output saved as " << save_path << std::endl;
}
std::cout << "Total running time: "
<< total_running_time_s
<< " s, average running time: "
<< total_running_time_s / imgs
<< " s/img, total read img time: "
<< total_imread_time_s
<< " s, average read img time: "
<< total_imread_time_s / imgs
<< " s, batch_size = "
<< FLAGS_batch_size
<< std::endl;
std::cout << "Total running time: " << total_running_time_s
<< " s, average running time: " << total_running_time_s / imgs
<< " s/img, total read img time: " << total_imread_time_s
<< " s, average read img time: " << total_imread_time_s / imgs
<< " s, batch_size = " << FLAGS_batch_size << std::endl;
return 0;
}
......@@ -13,19 +13,19 @@
// limitations under the License.
#include <glog/logging.h>
#include <omp.h>
#include <algorithm>
#include <chrono>
#include <chrono> // NOLINT
#include <fstream>
#include <iostream>
#include <string>
#include <vector>
#include <utility>
#include <omp.h>
#include "include/paddlex/paddlex.h"
#include "include/paddlex/visualize.h"
using namespace std::chrono;
using namespace std::chrono; // NOLINT
DEFINE_string(model_dir, "", "Path of inference model");
DEFINE_bool(use_gpu, false, "Inferring with GPU or CPU");
......@@ -36,7 +36,9 @@ DEFINE_string(image, "", "Path of test image file");
DEFINE_string(image_list, "", "Path of test image list file");
DEFINE_string(save_dir, "output", "Path to save visualized image");
DEFINE_int32(batch_size, 1, "Batch size for inference");
DEFINE_int32(thread_num, omp_get_num_procs(), "Number of preprocessing threads");
DEFINE_int32(thread_num,
omp_get_num_procs(),
"Number of preprocessing threads");
int main(int argc, char** argv) {
// Parse command-line arguments
......@@ -53,7 +55,12 @@ int main(int argc, char** argv) {
// Load the model
PaddleX::Model model;
model.Init(FLAGS_model_dir, FLAGS_use_gpu, FLAGS_use_trt, FLAGS_gpu_id, FLAGS_key, FLAGS_batch_size);
model.Init(FLAGS_model_dir,
FLAGS_use_gpu,
FLAGS_use_trt,
FLAGS_gpu_id,
FLAGS_key,
FLAGS_batch_size);
double total_running_time_s = 0.0;
double total_imread_time_s = 0.0;
......@@ -72,25 +79,31 @@ int main(int argc, char** argv) {
image_paths.push_back(image_path);
}
imgs = image_paths.size();
for(int i = 0; i < image_paths.size(); i += FLAGS_batch_size){
for (int i = 0; i < image_paths.size(); i += FLAGS_batch_size) {
auto start = system_clock::now();
int im_vec_size = std::min((int)image_paths.size(), i + FLAGS_batch_size);
int im_vec_size =
std::min(static_cast<int>(image_paths.size()), i + FLAGS_batch_size);
std::vector<cv::Mat> im_vec(im_vec_size - i);
std::vector<PaddleX::SegResult> results(im_vec_size - i, PaddleX::SegResult());
std::vector<PaddleX::SegResult> results(im_vec_size - i,
PaddleX::SegResult());
int thread_num = std::min(FLAGS_thread_num, im_vec_size - i);
#pragma omp parallel for num_threads(thread_num)
for(int j = i; j < im_vec_size; ++j){
for (int j = i; j < im_vec_size; ++j) {
im_vec[j - i] = std::move(cv::imread(image_paths[j], 1));
}
auto imread_end = system_clock::now();
model.predict(im_vec, results, thread_num);
model.predict(im_vec, &results, thread_num);
auto imread_duration = duration_cast<microseconds>(imread_end - start);
total_imread_time_s += double(imread_duration.count()) * microseconds::period::num / microseconds::period::den;
total_imread_time_s += static_cast<double>(imread_duration.count()) *
microseconds::period::num /
microseconds::period::den;
auto end = system_clock::now();
auto duration = duration_cast<microseconds>(end - start);
total_running_time_s += double(duration.count()) * microseconds::period::num / microseconds::period::den;
total_running_time_s += static_cast<double>(duration.count()) *
microseconds::period::num /
microseconds::period::den;
// Visualization
for(int j = 0; j < im_vec_size - i; ++j) {
for (int j = 0; j < im_vec_size - i; ++j) {
cv::Mat vis_img =
PaddleX::Visualize(im_vec[j], results[j], model.labels, colormap);
std::string save_path =
......@@ -106,7 +119,9 @@ int main(int argc, char** argv) {
model.predict(im, &result);
auto end = system_clock::now();
auto duration = duration_cast<microseconds>(end - start);
total_running_time_s += double(duration.count()) * microseconds::period::num / microseconds::period::den;
total_running_time_s += static_cast<double>(duration.count()) *
microseconds::period::num /
microseconds::period::den;
// Visualization
cv::Mat vis_img = PaddleX::Visualize(im, result, model.labels, colormap);
std::string save_path =
......@@ -115,17 +130,11 @@ int main(int argc, char** argv) {
result.clear();
std::cout << "Visualized output saved as " << save_path << std::endl;
}
std::cout << "Total running time: "
<< total_running_time_s
<< " s, average running time: "
<< total_running_time_s / imgs
<< " s/img, total read img time: "
<< total_imread_time_s
<< " s, average read img time: "
<< total_imread_time_s / imgs
<< " s, batch_size = "
<< FLAGS_batch_size
<< std::endl;
std::cout << "Total running time: " << total_running_time_s
<< " s, average running time: " << total_running_time_s / imgs
<< " s/img, total read img time: " << total_imread_time_s
<< " s, average read img time: " << total_imread_time_s / imgs
<< " s, batch_size = " << FLAGS_batch_size << std::endl;
return 0;
}
......@@ -54,4 +54,4 @@ class ConfigPaser {
YAML::Node Transforms_;
};
} // namespace PaddleDetection
} // namespace PaddleX
......@@ -16,8 +16,11 @@
#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <numeric>
#include <string>
#include <vector>
#include "yaml-cpp/yaml.h"
#ifdef _WIN32
......@@ -28,13 +31,13 @@
#include "paddle_inference_api.h" // NOLINT
#include "config_parser.h"
#include "results.h"
#include "transforms.h"
#include "config_parser.h" // NOLINT
#include "results.h" // NOLINT
#include "transforms.h" // NOLINT
#ifdef WITH_ENCRYPTION
#include "paddle_model_decrypt.h"
#include "model_code.h"
#include "paddle_model_decrypt.h" // NOLINT
#include "model_code.h" // NOLINT
#endif
namespace PaddleX {
......@@ -119,7 +122,9 @@ class Model {
* each thread runs preprocessing on a single image matrix
* @return true if a batch of image matrices is preprocessed successfully
* */
bool preprocess(const std::vector<cv::Mat> &input_im_batch, std::vector<ImageBlob> &blob_batch, int thread_num = 1);
bool preprocess(const std::vector<cv::Mat> &input_im_batch,
std::vector<ImageBlob> *blob_batch,
int thread_num = 1);
/*
* @brief
......@@ -143,7 +148,9 @@ class Model {
* on single image matrix
* @return true if prediction succeeds
* */
bool predict(const std::vector<cv::Mat> &im_batch, std::vector<ClsResult> &results, int thread_num = 1);
bool predict(const std::vector<cv::Mat> &im_batch,
std::vector<ClsResult> *results,
int thread_num = 1);
/*
* @brief
......@@ -167,7 +174,9 @@ class Model {
* on single image matrix
* @return true if prediction succeeds
* */
bool predict(const std::vector<cv::Mat> &im_batch, std::vector<DetResult> &result, int thread_num = 1);
bool predict(const std::vector<cv::Mat> &im_batch,
std::vector<DetResult> *result,
int thread_num = 1);
/*
* @brief
......@@ -191,7 +200,9 @@ class Model {
* on single image matrix
* @return true if prediction succeeds
* */
bool predict(const std::vector<cv::Mat> &im_batch, std::vector<SegResult> &result, int thread_num = 1);
bool predict(const std::vector<cv::Mat> &im_batch,
std::vector<SegResult> *result,
int thread_num = 1);
// model type, include 3 type: classifier, detector, segmenter
std::string type;
......@@ -209,4 +220,4 @@ class Model {
// a predictor which runs model inference
std::unique_ptr<paddle::PaddlePredictor> predictor_;
};
} // namespce of PaddleX
} // namespace PaddleX
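For reference, a minimal caller-side sketch of the batch classification API after this change, assuming a valid exported model directory and readable image paths (both placeholders), with the result container passed by pointer rather than by reference:

#include <string>
#include <vector>
#include <opencv2/opencv.hpp>
#include "include/paddlex/paddlex.h"

// Hypothetical helper, for illustration only.
int ClassifyBatch(const std::string& model_dir,
                  const std::vector<std::string>& image_paths) {
  PaddleX::Model model;
  model.Init(model_dir, /*use_gpu=*/false, /*use_trt=*/false, /*gpu_id=*/0,
             /*key=*/"", /*batch_size=*/static_cast<int>(image_paths.size()));
  std::vector<cv::Mat> im_batch;
  for (const auto& path : image_paths) {
    im_batch.push_back(cv::imread(path, 1));
  }
  std::vector<PaddleX::ClsResult> results(im_batch.size());
  // Result vector is now passed by pointer per Google style.
  model.predict(im_batch, &results, /*thread_num=*/1);
  return results.empty() ? -1 : results[0].category_id;
}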
......@@ -214,6 +214,7 @@ class Padding : public Transform {
}
}
virtual bool Run(cv::Mat* im, ImageBlob* data);
private:
int coarsest_stride_ = -1;
int width_ = 0;
......@@ -229,6 +230,7 @@ class Transforms {
void Init(const YAML::Node& node, bool to_rgb = true);
std::shared_ptr<Transform> CreateTransform(const std::string& name);
bool Run(cv::Mat* im, ImageBlob* data);
private:
std::vector<std::shared_ptr<Transform>> transforms_;
bool to_rgb_ = true;
......
......@@ -94,4 +94,4 @@ cv::Mat Visualize(const cv::Mat& img,
* */
std::string generate_save_path(const std::string& save_dir,
const std::string& file_path);
} // namespce of PaddleX
} // namespace PaddleX
......@@ -11,10 +11,10 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <omp.h>
#include "include/paddlex/paddlex.h"
#include <algorithm>
#include <cstring>
#include "include/paddlex/paddlex.h"
namespace PaddleX {
void Model::create_predictor(const std::string& model_dir,
......@@ -32,13 +32,14 @@ void Model::create_predictor(const std::string& model_dir,
std::string model_file = model_dir + OS_PATH_SEP + "__model__";
std::string params_file = model_dir + OS_PATH_SEP + "__params__";
#ifdef WITH_ENCRYPTION
if (key != ""){
if (key != "") {
model_file = model_dir + OS_PATH_SEP + "__model__.encrypted";
params_file = model_dir + OS_PATH_SEP + "__params__.encrypted";
paddle_security_load_model(&config, key.c_str(), model_file.c_str(), params_file.c_str());
paddle_security_load_model(
&config, key.c_str(), model_file.c_str(), params_file.c_str());
}
#endif
if (key == ""){
if (key == "") {
config.SetModel(model_file, params_file);
}
if (use_gpu) {
......@@ -70,11 +71,11 @@ bool Model::load_config(const std::string& model_dir) {
name = config["Model"].as<std::string>();
std::string version = config["version"].as<std::string>();
if (version[0] == '0') {
std::cerr << "[Init] Version of the loaded model is lower than 1.0.0, deployment "
<< "cannot be done, please refer to "
<< "https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/tutorials/deploy/upgrade_version.md "
<< "to transfer version."
<< std::endl;
std::cerr << "[Init] Version of the loaded model is lower than 1.0.0, "
<< "deployment cannot be done, please refer to "
<< "https://github.com/PaddlePaddle/PaddleX/blob/develop/docs"
<< "/tutorials/deploy/upgrade_version.md "
<< "to transfer version." << std::endl;
return false;
}
bool to_rgb = true;
......@@ -108,14 +109,16 @@ bool Model::preprocess(const cv::Mat& input_im, ImageBlob* blob) {
}
// use openmp
bool Model::preprocess(const std::vector<cv::Mat> &input_im_batch, std::vector<ImageBlob> &blob_batch, int thread_num) {
bool Model::preprocess(const std::vector<cv::Mat>& input_im_batch,
std::vector<ImageBlob>* blob_batch,
int thread_num) {
int batch_size = input_im_batch.size();
bool success = true;
thread_num = std::min(thread_num, batch_size);
#pragma omp parallel for num_threads(thread_num)
for(int i = 0; i < input_im_batch.size(); ++i) {
for (int i = 0; i < input_im_batch.size(); ++i) {
cv::Mat im = input_im_batch[i].clone();
if(!transforms_.Run(&im, &blob_batch[i])){
if (!transforms_.Run(&im, &(*blob_batch)[i])) {
success = false;
}
}
......@@ -127,8 +130,7 @@ bool Model::predict(const cv::Mat& im, ClsResult* result) {
if (type == "detector") {
std::cerr << "Loading model is a 'detector', DetResult should be passed to "
"function predict()!"
"to function predict()!"
<< std::endl;
"to function predict()!" << std::endl;
return false;
}
// Preprocess the input image
......@@ -161,23 +163,23 @@ bool Model::predict(const cv::Mat& im, ClsResult* result) {
return true;
}
bool Model::predict(const std::vector<cv::Mat> &im_batch, std::vector<ClsResult> &results, int thread_num) {
for(auto &inputs: inputs_batch_) {
bool Model::predict(const std::vector<cv::Mat>& im_batch,
std::vector<ClsResult>* results,
int thread_num) {
for (auto& inputs : inputs_batch_) {
inputs.clear();
}
if (type == "detector") {
std::cerr << "Loading model is a 'detector', DetResult should be passed to "
"function predict()!"
<< std::endl;
"function predict()!" << std::endl;
return false;
} else if (type == "segmenter") {
std::cerr << "Loading model is a 'segmenter', SegResult should be passed "
"to function predict()!"
<< std::endl;
"to function predict()!" << std::endl;
return false;
}
// Preprocess the input images
if (!preprocess(im_batch, inputs_batch_, thread_num)) {
if (!preprocess(im_batch, &inputs_batch_, thread_num)) {
std::cerr << "Preprocess failed!" << std::endl;
return false;
}
......@@ -188,11 +190,13 @@ bool Model::predict(const std::vector<cv::Mat> &im_batch, std::vector<ClsResult>
int w = inputs_batch_[0].new_im_size_[1];
in_tensor->Reshape({batch_size, 3, h, w});
std::vector<float> inputs_data(batch_size * 3 * h * w);
for(int i = 0; i < batch_size; ++i) {
std::copy(inputs_batch_[i].im_data_.begin(), inputs_batch_[i].im_data_.end(), inputs_data.begin() + i * 3 * h * w);
for (int i = 0; i < batch_size; ++i) {
std::copy(inputs_batch_[i].im_data_.begin(),
inputs_batch_[i].im_data_.end(),
inputs_data.begin() + i * 3 * h * w);
}
in_tensor->copy_from_cpu(inputs_data.data());
//in_tensor->copy_from_cpu(inputs_.im_data_.data());
// in_tensor->copy_from_cpu(inputs_.im_data_.data());
predictor_->ZeroCopyRun();
// Fetch the model output
auto output_names = predictor_->GetOutputNames();
......@@ -206,15 +210,15 @@ bool Model::predict(const std::vector<cv::Mat> &im_batch, std::vector<ClsResult>
output_tensor->copy_to_cpu(outputs_.data());
// Post-process the model output
int single_batch_size = size / batch_size;
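// For each sample, the arg-max over its slice of the flat output gives
// the predicted class id and its score.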
for(int i = 0; i < batch_size; ++i) {
for (int i = 0; i < batch_size; ++i) {
auto start_ptr = std::begin(outputs_);
auto end_ptr = std::begin(outputs_);
std::advance(start_ptr, i * single_batch_size);
std::advance(end_ptr, (i + 1) * single_batch_size);
auto ptr = std::max_element(start_ptr, end_ptr);
results[i].category_id = std::distance(start_ptr, ptr);
results[i].score = *ptr;
results[i].category = labels[results[i].category_id];
(*results)[i].category_id = std::distance(start_ptr, ptr);
(*results)[i].score = *ptr;
(*results)[i].category = labels[(*results)[i].category_id];
}
return true;
}
......@@ -224,13 +228,11 @@ bool Model::predict(const cv::Mat& im, DetResult* result) {
result->clear();
if (type == "classifier") {
std::cerr << "Loading model is a 'classifier', ClsResult should be passed "
"to function predict()!"
<< std::endl;
"to function predict()!" << std::endl;
return false;
} else if (type == "segmenter") {
std::cerr << "Loading model is a 'segmenter', SegResult should be passed "
"to function predict()!"
<< std::endl;
"to function predict()!" << std::endl;
return false;
}
......@@ -324,25 +326,25 @@ bool Model::predict(const cv::Mat& im, DetResult* result) {
return true;
}
bool Model::predict(const std::vector<cv::Mat> &im_batch, std::vector<DetResult> &result, int thread_num) {
for(auto &inputs: inputs_batch_) {
bool Model::predict(const std::vector<cv::Mat>& im_batch,
std::vector<DetResult>* result,
int thread_num) {
for (auto& inputs : inputs_batch_) {
inputs.clear();
}
if (type == "classifier") {
std::cerr << "Loading model is a 'classifier', ClsResult should be passed "
"to function predict()!"
<< std::endl;
"to function predict()!" << std::endl;
return false;
} else if (type == "segmenter") {
std::cerr << "Loading model is a 'segmenter', SegResult should be passed "
"to function predict()!"
<< std::endl;
"to function predict()!" << std::endl;
return false;
}
int batch_size = im_batch.size();
// Preprocess the input images
if (!preprocess(im_batch, inputs_batch_, thread_num)) {
if (!preprocess(im_batch, &inputs_batch_, thread_num)) {
std::cerr << "Preprocess failed!" << std::endl;
return false;
}
......@@ -351,28 +353,29 @@ bool Model::predict(const std::vector<cv::Mat> &im_batch, std::vector<DetResult>
if (name == "FasterRCNN" || name == "MaskRCNN") {
int max_h = -1;
int max_w = -1;
for(int i = 0; i < batch_size; ++i) {
for (int i = 0; i < batch_size; ++i) {
max_h = std::max(max_h, inputs_batch_[i].new_im_size_[0]);
max_w = std::max(max_w, inputs_batch_[i].new_im_size_[1]);
//std::cout << "(" << inputs_batch_[i].new_im_size_[0]
// std::cout << "(" << inputs_batch_[i].new_im_size_[0]
// << ", " << inputs_batch_[i].new_im_size_[1]
// << ")" << std::endl;
}
thread_num = std::min(thread_num, batch_size);
#pragma omp parallel for num_threads(thread_num)
for(int i = 0; i < batch_size; ++i) {
for (int i = 0; i < batch_size; ++i) {
int h = inputs_batch_[i].new_im_size_[0];
int w = inputs_batch_[i].new_im_size_[1];
int c = im_batch[i].channels();
if(max_h != h || max_w != w) {
if (max_h != h || max_w != w) {
std::vector<float> temp_buffer(c * max_h * max_w);
float *temp_ptr = temp_buffer.data();
float *ptr = inputs_batch_[i].im_data_.data();
for(int cur_channel = c - 1; cur_channel >= 0; --cur_channel) {
float* temp_ptr = temp_buffer.data();
float* ptr = inputs_batch_[i].im_data_.data();
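// Copy each row of every channel, bottom-up, from the h x w layout into the
// top-left corner of the zero-initialized max_h x max_w buffer, so all images
// in the batch share the same padded tensor shape.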
for (int cur_channel = c - 1; cur_channel >= 0; --cur_channel) {
int ori_pos = cur_channel * h * w + (h - 1) * w;
int des_pos = cur_channel * max_h * max_w + (h - 1) * max_w;
for(int start_pos = ori_pos; start_pos >= cur_channel * h * w; start_pos -= w, des_pos -= max_w) {
memcpy(temp_ptr + des_pos, ptr + start_pos, w * sizeof(float));
int last_pos = cur_channel * h * w;
for (; ori_pos >= last_pos; ori_pos -= w, des_pos -= max_w) {
memcpy(temp_ptr + des_pos, ptr + ori_pos, w * sizeof(float));
}
}
inputs_batch_[i].im_data_.swap(temp_buffer);
......@@ -387,16 +390,20 @@ bool Model::predict(const std::vector<cv::Mat> &im_batch, std::vector<DetResult>
auto im_tensor = predictor_->GetInputTensor("image");
im_tensor->Reshape({batch_size, 3, h, w});
std::vector<float> inputs_data(batch_size * 3 * h * w);
for(int i = 0; i < batch_size; ++i) {
std::copy(inputs_batch_[i].im_data_.begin(), inputs_batch_[i].im_data_.end(), inputs_data.begin() + i * 3 * h * w);
for (int i = 0; i < batch_size; ++i) {
std::copy(inputs_batch_[i].im_data_.begin(),
inputs_batch_[i].im_data_.end(),
inputs_data.begin() + i * 3 * h * w);
}
im_tensor->copy_from_cpu(inputs_data.data());
if (name == "YOLOv3") {
auto im_size_tensor = predictor_->GetInputTensor("im_size");
im_size_tensor->Reshape({batch_size, 2});
std::vector<int> inputs_data_size(batch_size * 2);
for(int i = 0; i < batch_size; ++i){
std::copy(inputs_batch_[i].ori_im_size_.begin(), inputs_batch_[i].ori_im_size_.end(), inputs_data_size.begin() + 2 * i);
for (int i = 0; i < batch_size; ++i) {
std::copy(inputs_batch_[i].ori_im_size_.begin(),
inputs_batch_[i].ori_im_size_.end(),
inputs_data_size.begin() + 2 * i);
}
im_size_tensor->copy_from_cpu(inputs_data_size.data());
} else if (name == "FasterRCNN" || name == "MaskRCNN") {
......@@ -407,7 +414,7 @@ bool Model::predict(const std::vector<cv::Mat> &im_batch, std::vector<DetResult>
std::vector<float> im_info(3 * batch_size);
std::vector<float> im_shape(3 * batch_size);
for(int i = 0; i < batch_size; ++i) {
for (int i = 0; i < batch_size; ++i) {
float ori_h = static_cast<float>(inputs_batch_[i].ori_im_size_[0]);
float ori_w = static_cast<float>(inputs_batch_[i].ori_im_size_[1]);
float new_h = static_cast<float>(inputs_batch_[i].new_im_size_[0]);
......@@ -444,9 +451,9 @@ bool Model::predict(const std::vector<cv::Mat> &im_batch, std::vector<DetResult>
int num_boxes = size / 6;
// Parse the predicted boxes
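// lod_vector[0] stores per-image offsets into the flat box output; each box
// occupies 6 floats: category_id, score, xmin, ymin, xmax, ymax.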
for (int i = 0; i < lod_vector[0].size() - 1; ++i) {
for(int j = lod_vector[0][i]; j < lod_vector[0][i + 1]; ++j) {
for (int j = lod_vector[0][i]; j < lod_vector[0][i + 1]; ++j) {
Box box;
box.category_id = static_cast<int> (round(output_box[j * 6]));
box.category_id = static_cast<int>(round(output_box[j * 6]));
box.category = labels[box.category_id];
box.score = output_box[j * 6 + 1];
float xmin = output_box[j * 6 + 2];
......@@ -456,7 +463,7 @@ bool Model::predict(const std::vector<cv::Mat> &im_batch, std::vector<DetResult>
float w = xmax - xmin + 1;
float h = ymax - ymin + 1;
box.coordinate = {xmin, ymin, w, h};
result[i].boxes.push_back(std::move(box));
(*result)[i].boxes.push_back(std::move(box));
}
}
......@@ -474,11 +481,13 @@ bool Model::predict(const std::vector<cv::Mat> &im_batch, std::vector<DetResult>
output_mask.resize(masks_size);
output_mask_tensor->copy_to_cpu(output_mask.data());
int mask_idx = 0;
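// For each detected box, slice the mask channel of its predicted category
// out of the flat mask output.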
for(int i = 0; i < lod_vector[0].size() - 1; ++i) {
result[i].mask_resolution = output_mask_shape[2];
for(int j = 0; j < result[i].boxes.size(); ++j) {
Box* box = &result[i].boxes[j];
auto begin_mask = output_mask.begin() + (mask_idx * classes + box->category_id) * mask_pixels;
for (int i = 0; i < lod_vector[0].size() - 1; ++i) {
(*result)[i].mask_resolution = output_mask_shape[2];
for (int j = 0; j < (*result)[i].boxes.size(); ++j) {
Box* box = &(*result)[i].boxes[j];
int category_id = box->category_id;
auto begin_mask = output_mask.begin() +
(mask_idx * classes + category_id) * mask_pixels;
auto end_mask = begin_mask + mask_pixels;
box->mask.data.assign(begin_mask, end_mask);
box->mask.shape = {static_cast<int>(box->coordinate[2]),
......@@ -495,13 +504,11 @@ bool Model::predict(const cv::Mat& im, SegResult* result) {
inputs_.clear();
if (type == "classifier") {
std::cerr << "Loading model is a 'classifier', ClsResult should be passed "
"to function predict()!"
<< std::endl;
"to function predict()!" << std::endl;
return false;
} else if (type == "detector") {
std::cerr << "Loading model is a 'detector', DetResult should be passed to "
"function predict()!"
<< std::endl;
"function predict()!" << std::endl;
return false;
}
......@@ -599,41 +606,43 @@ bool Model::predict(const cv::Mat& im, SegResult* result) {
return true;
}
bool Model::predict(const std::vector<cv::Mat> &im_batch, std::vector<SegResult> &result, int thread_num) {
for(auto &inputs: inputs_batch_) {
bool Model::predict(const std::vector<cv::Mat>& im_batch,
std::vector<SegResult>* result,
int thread_num) {
for (auto& inputs : inputs_batch_) {
inputs.clear();
}
if (type == "classifier") {
std::cerr << "Loading model is a 'classifier', ClsResult should be passed "
"to function predict()!"
<< std::endl;
"to function predict()!" << std::endl;
return false;
} else if (type == "detector") {
std::cerr << "Loading model is a 'detector', DetResult should be passed to "
"function predict()!"
<< std::endl;
"function predict()!" << std::endl;
return false;
}
// Preprocess the input images
if (!preprocess(im_batch, inputs_batch_, thread_num)) {
if (!preprocess(im_batch, &inputs_batch_, thread_num)) {
std::cerr << "Preprocess failed!" << std::endl;
return false;
}
int batch_size = im_batch.size();
result.clear();
result.resize(batch_size);
(*result).clear();
(*result).resize(batch_size);
int h = inputs_batch_[0].new_im_size_[0];
int w = inputs_batch_[0].new_im_size_[1];
auto im_tensor = predictor_->GetInputTensor("image");
im_tensor->Reshape({batch_size, 3, h, w});
std::vector<float> inputs_data(batch_size * 3 * h * w);
for(int i = 0; i < batch_size; ++i) {
std::copy(inputs_batch_[i].im_data_.begin(), inputs_batch_[i].im_data_.end(), inputs_data.begin() + i * 3 * h * w);
for (int i = 0; i < batch_size; ++i) {
std::copy(inputs_batch_[i].im_data_.begin(),
inputs_batch_[i].im_data_.end(),
inputs_data.begin() + i * 3 * h * w);
}
im_tensor->copy_from_cpu(inputs_data.data());
//im_tensor->copy_from_cpu(inputs_.im_data_.data());
// im_tensor->copy_from_cpu(inputs_.im_data_.data());
// Run prediction with the loaded model
predictor_->ZeroCopyRun();
......@@ -652,13 +661,15 @@ bool Model::predict(const std::vector<cv::Mat> &im_batch, std::vector<SegResult>
auto output_labels_iter = output_labels.begin();
int single_batch_size = size / batch_size;
for(int i = 0; i < batch_size; ++i) {
result[i].label_map.data.resize(single_batch_size);
result[i].label_map.shape.push_back(1);
for(int j = 1; j < output_label_shape.size(); ++j) {
result[i].label_map.shape.push_back(output_label_shape[j]);
for (int i = 0; i < batch_size; ++i) {
(*result)[i].label_map.data.resize(single_batch_size);
(*result)[i].label_map.shape.push_back(1);
for (int j = 1; j < output_label_shape.size(); ++j) {
(*result)[i].label_map.shape.push_back(output_label_shape[j]);
}
std::copy(output_labels_iter + i * single_batch_size, output_labels_iter + (i + 1) * single_batch_size, result[i].label_map.data.data());
std::copy(output_labels_iter + i * single_batch_size,
output_labels_iter + (i + 1) * single_batch_size,
(*result)[i].label_map.data.data());
}
// Fetch the prediction confidence score map
......@@ -674,28 +685,30 @@ bool Model::predict(const std::vector<cv::Mat> &im_batch, std::vector<SegResult>
auto output_scores_iter = output_scores.begin();
int single_batch_score_size = size / batch_size;
for(int i = 0; i < batch_size; ++i) {
result[i].score_map.data.resize(single_batch_score_size);
result[i].score_map.shape.push_back(1);
for(int j = 1; j < output_score_shape.size(); ++j) {
result[i].score_map.shape.push_back(output_score_shape[j]);
for (int i = 0; i < batch_size; ++i) {
(*result)[i].score_map.data.resize(single_batch_score_size);
(*result)[i].score_map.shape.push_back(1);
for (int j = 1; j < output_score_shape.size(); ++j) {
(*result)[i].score_map.shape.push_back(output_score_shape[j]);
}
std::copy(output_scores_iter + i * single_batch_score_size, output_scores_iter + (i + 1) * single_batch_score_size, result[i].score_map.data.data());
std::copy(output_scores_iter + i * single_batch_score_size,
output_scores_iter + (i + 1) * single_batch_score_size,
(*result)[i].score_map.data.data());
}
// Restore the output to the original image size
for(int i = 0; i < batch_size; ++i) {
std::vector<uint8_t> label_map(result[i].label_map.data.begin(),
result[i].label_map.data.end());
cv::Mat mask_label(result[i].label_map.shape[1],
result[i].label_map.shape[2],
for (int i = 0; i < batch_size; ++i) {
std::vector<uint8_t> label_map((*result)[i].label_map.data.begin(),
(*result)[i].label_map.data.end());
cv::Mat mask_label((*result)[i].label_map.shape[1],
(*result)[i].label_map.shape[2],
CV_8UC1,
label_map.data());
cv::Mat mask_score(result[i].score_map.shape[2],
result[i].score_map.shape[3],
cv::Mat mask_score((*result)[i].score_map.shape[2],
(*result)[i].score_map.shape[3],
CV_32FC1,
result[i].score_map.data.data());
(*result)[i].score_map.data.data());
int idx = 1;
int len_postprocess = inputs_batch_[i].im_size_before_resize_.size();
for (std::vector<std::string>::reverse_iterator iter =
......@@ -703,14 +716,16 @@ bool Model::predict(const std::vector<cv::Mat> &im_batch, std::vector<SegResult>
iter != inputs_batch_[i].reshape_order_.rend();
++iter) {
if (*iter == "padding") {
auto before_shape = inputs_batch_[i].im_size_before_resize_[len_postprocess - idx];
auto before_shape =
inputs_batch_[i].im_size_before_resize_[len_postprocess - idx];
inputs_batch_[i].im_size_before_resize_.pop_back();
auto padding_w = before_shape[0];
auto padding_h = before_shape[1];
mask_label = mask_label(cv::Rect(0, 0, padding_h, padding_w));
mask_score = mask_score(cv::Rect(0, 0, padding_h, padding_w));
} else if (*iter == "resize") {
auto before_shape = inputs_batch_[i].im_size_before_resize_[len_postprocess - idx];
auto before_shape =
inputs_batch_[i].im_size_before_resize_[len_postprocess - idx];
inputs_batch_[i].im_size_before_resize_.pop_back();
auto resize_w = before_shape[0];
auto resize_h = before_shape[1];
......@@ -729,14 +744,14 @@ bool Model::predict(const std::vector<cv::Mat> &im_batch, std::vector<SegResult>
}
++idx;
}
result[i].label_map.data.assign(mask_label.begin<uint8_t>(),
(*result)[i].label_map.data.assign(mask_label.begin<uint8_t>(),
mask_label.end<uint8_t>());
result[i].label_map.shape = {mask_label.rows, mask_label.cols};
result[i].score_map.data.assign(mask_score.begin<float>(),
(*result)[i].label_map.shape = {mask_label.rows, mask_label.cols};
(*result)[i].score_map.data.assign(mask_score.begin<float>(),
mask_score.end<float>());
result[i].score_map.shape = {mask_score.rows, mask_score.cols};
(*result)[i].score_map.shape = {mask_score.rows, mask_score.cols};
}
return true;
}
} // namespce of PaddleX
} // namespace PaddleX
......@@ -145,4 +145,4 @@ std::string generate_save_path(const std::string& save_dir,
std::string image_name(file_path.substr(pos + 1));
return save_dir + OS_PATH_SEP + image_name;
}
} // namespace of PaddleX
} // namespace PaddleX