提交 01e54a1f 编写于 作者: J jack

use google style

上级 afb8620c
...@@ -13,18 +13,18 @@ ...@@ -13,18 +13,18 @@
// limitations under the License. // limitations under the License.
#include <glog/logging.h> #include <glog/logging.h>
#include <omp.h>
#include <algorithm> #include <algorithm>
#include <chrono> #include <chrono> // NOLINT
#include <fstream> #include <fstream>
#include <iostream> #include <iostream>
#include <string> #include <string>
#include <vector> #include <vector>
#include <utility> #include <utility>
#include <omp.h>
#include "include/paddlex/paddlex.h" #include "include/paddlex/paddlex.h"
using namespace std::chrono; using namespace std::chrono; // NOLINT
DEFINE_string(model_dir, "", "Path of inference model"); DEFINE_string(model_dir, "", "Path of inference model");
DEFINE_bool(use_gpu, false, "Infering with GPU or CPU"); DEFINE_bool(use_gpu, false, "Infering with GPU or CPU");
...@@ -34,7 +34,9 @@ DEFINE_string(key, "", "key of encryption"); ...@@ -34,7 +34,9 @@ DEFINE_string(key, "", "key of encryption");
DEFINE_string(image, "", "Path of test image file"); DEFINE_string(image, "", "Path of test image file");
DEFINE_string(image_list, "", "Path of test image list file"); DEFINE_string(image_list, "", "Path of test image list file");
DEFINE_int32(batch_size, 1, "Batch size of infering"); DEFINE_int32(batch_size, 1, "Batch size of infering");
DEFINE_int32(thread_num, omp_get_num_procs(), "Number of preprocessing threads"); DEFINE_int32(thread_num,
omp_get_num_procs(),
"Number of preprocessing threads");
int main(int argc, char** argv) { int main(int argc, char** argv) {
// Parsing command-line // Parsing command-line
...@@ -51,7 +53,12 @@ int main(int argc, char** argv) { ...@@ -51,7 +53,12 @@ int main(int argc, char** argv) {
// 加载模型 // 加载模型
PaddleX::Model model; PaddleX::Model model;
model.Init(FLAGS_model_dir, FLAGS_use_gpu, FLAGS_use_trt, FLAGS_gpu_id, FLAGS_key, FLAGS_batch_size); model.Init(FLAGS_model_dir,
FLAGS_use_gpu,
FLAGS_use_trt,
FLAGS_gpu_id,
FLAGS_key,
FLAGS_batch_size);
// 进行预测 // 进行预测
double total_running_time_s = 0.0; double total_running_time_s = 0.0;
...@@ -70,27 +77,33 @@ int main(int argc, char** argv) { ...@@ -70,27 +77,33 @@ int main(int argc, char** argv) {
image_paths.push_back(image_path); image_paths.push_back(image_path);
} }
imgs = image_paths.size(); imgs = image_paths.size();
for(int i = 0; i < image_paths.size(); i += FLAGS_batch_size) { for (int i = 0; i < image_paths.size(); i += FLAGS_batch_size) {
auto start = system_clock::now(); auto start = system_clock::now();
// 读图像 // 读图像
int im_vec_size = std::min((int)image_paths.size(), i + FLAGS_batch_size); int im_vec_size =
std::min(static_cat<int>(image_paths.size()), i + FLAGS_batch_size);
std::vector<cv::Mat> im_vec(im_vec_size - i); std::vector<cv::Mat> im_vec(im_vec_size - i);
std::vector<PaddleX::ClsResult> results(im_vec_size - i, PaddleX::ClsResult()); std::vector<PaddleX::ClsResult> results(im_vec_size - i,
PaddleX::ClsResult());
int thread_num = std::min(FLAGS_thread_num, im_vec_size - i); int thread_num = std::min(FLAGS_thread_num, im_vec_size - i);
#pragma omp parallel for num_threads(thread_num) #pragma omp parallel for num_threads(thread_num)
for(int j = i; j < im_vec_size; ++j){ for (int j = i; j < im_vec_size; ++j) {
im_vec[j - i] = std::move(cv::imread(image_paths[j], 1)); im_vec[j - i] = std::move(cv::imread(image_paths[j], 1));
} }
auto imread_end = system_clock::now(); auto imread_end = system_clock::now();
model.predict(im_vec, results, thread_num); model.predict(im_vec, &results, thread_num);
auto imread_duration = duration_cast<microseconds>(imread_end - start); auto imread_duration = duration_cast<microseconds>(imread_end - start);
total_imread_time_s += double(imread_duration.count()) * microseconds::period::num / microseconds::period::den; total_imread_time_s += static_cast<double>(imread_duration.count()) *
microseconds::period::num /
microseconds::period::den;
auto end = system_clock::now(); auto end = system_clock::now();
auto duration = duration_cast<microseconds>(end - start); auto duration = duration_cast<microseconds>(end - start);
total_running_time_s += double(duration.count()) * microseconds::period::num / microseconds::period::den; total_running_time_s += static_cast<double>(duration.count()) *
for(int j = i; j < im_vec_size; ++j) { microseconds::period::num /
microseconds::period::den;
for (int j = i; j < im_vec_size; ++j) {
std::cout << "Path:" << image_paths[j] std::cout << "Path:" << image_paths[j]
<< ", predict label: " << results[j - i].category << ", predict label: " << results[j - i].category
<< ", label_id:" << results[j - i].category_id << ", label_id:" << results[j - i].category_id
...@@ -104,21 +117,17 @@ int main(int argc, char** argv) { ...@@ -104,21 +117,17 @@ int main(int argc, char** argv) {
model.predict(im, &result); model.predict(im, &result);
auto end = system_clock::now(); auto end = system_clock::now();
auto duration = duration_cast<microseconds>(end - start); auto duration = duration_cast<microseconds>(end - start);
total_running_time_s += double(duration.count()) * microseconds::period::num / microseconds::period::den; total_running_time_s += static_cast<double>(duration.count()) *
microseconds::period::num /
microseconds::period::den;
std::cout << "Predict label: " << result.category std::cout << "Predict label: " << result.category
<< ", label_id:" << result.category_id << ", label_id:" << result.category_id
<< ", score: " << result.score << std::endl; << ", score: " << result.score << std::endl;
} }
std::cout << "Total running time: " std::cout << "Total running time: " << total_running_time_s
<< total_running_time_s << " s, average running time: " << total_running_time_s / imgs
<< " s, average running time: " << " s/img, total read img time: " << total_imread_time_s
<< total_running_time_s / imgs << " s, average read time: " << total_imread_time_s / imgs
<< " s/img, total read img time: " << " s/img, batch_size = " << FLAGS_batch_size << std::endl;
<< total_imread_time_s
<< " s, average read time: "
<< total_imread_time_s / imgs
<< " s/img, batch_size = "
<< FLAGS_batch_size
<< std::endl;
return 0; return 0;
} }
...@@ -13,20 +13,20 @@ ...@@ -13,20 +13,20 @@
// limitations under the License. // limitations under the License.
#include <glog/logging.h> #include <glog/logging.h>
#include <omp.h>
#include <algorithm> #include <algorithm>
#include <chrono> #include <chrono> // NOLINT
#include <fstream> #include <fstream>
#include <iostream> #include <iostream>
#include <string> #include <string>
#include <vector> #include <vector>
#include <utility> #include <utility>
#include <omp.h>
#include "include/paddlex/paddlex.h" #include "include/paddlex/paddlex.h"
#include "include/paddlex/visualize.h" #include "include/paddlex/visualize.h"
using namespace std::chrono; using namespace std::chrono; // NOLINT
DEFINE_string(model_dir, "", "Path of inference model"); DEFINE_string(model_dir, "", "Path of inference model");
DEFINE_bool(use_gpu, false, "Infering with GPU or CPU"); DEFINE_bool(use_gpu, false, "Infering with GPU or CPU");
...@@ -37,8 +37,12 @@ DEFINE_string(image, "", "Path of test image file"); ...@@ -37,8 +37,12 @@ DEFINE_string(image, "", "Path of test image file");
DEFINE_string(image_list, "", "Path of test image list file"); DEFINE_string(image_list, "", "Path of test image list file");
DEFINE_string(save_dir, "output", "Path to save visualized image"); DEFINE_string(save_dir, "output", "Path to save visualized image");
DEFINE_int32(batch_size, 1, "Batch size of infering"); DEFINE_int32(batch_size, 1, "Batch size of infering");
DEFINE_double(threshold, 0.5, "The minimum scores of target boxes which are shown"); DEFINE_double(threshold,
DEFINE_int32(thread_num, omp_get_num_procs(), "Number of preprocessing threads"); 0.5,
"The minimum scores of target boxes which are shown");
DEFINE_int32(thread_num,
omp_get_num_procs(),
"Number of preprocessing threads");
int main(int argc, char** argv) { int main(int argc, char** argv) {
// 解析命令行参数 // 解析命令行参数
...@@ -55,7 +59,12 @@ int main(int argc, char** argv) { ...@@ -55,7 +59,12 @@ int main(int argc, char** argv) {
std::cout << "Thread num: " << FLAGS_thread_num << std::endl; std::cout << "Thread num: " << FLAGS_thread_num << std::endl;
// 加载模型 // 加载模型
PaddleX::Model model; PaddleX::Model model;
model.Init(FLAGS_model_dir, FLAGS_use_gpu, FLAGS_use_trt, FLAGS_gpu_id, FLAGS_key, FLAGS_batch_size); model.Init(FLAGS_model_dir,
FLAGS_use_gpu,
FLAGS_use_trt,
FLAGS_gpu_id,
FLAGS_key,
FLAGS_batch_size);
double total_running_time_s = 0.0; double total_running_time_s = 0.0;
double total_imread_time_s = 0.0; double total_imread_time_s = 0.0;
...@@ -75,41 +84,47 @@ int main(int argc, char** argv) { ...@@ -75,41 +84,47 @@ int main(int argc, char** argv) {
image_paths.push_back(image_path); image_paths.push_back(image_path);
} }
imgs = image_paths.size(); imgs = image_paths.size();
for(int i = 0; i < image_paths.size(); i += FLAGS_batch_size) { for (int i = 0; i < image_paths.size(); i += FLAGS_batch_size) {
auto start = system_clock::now(); auto start = system_clock::now();
int im_vec_size = std::min((int)image_paths.size(), i + FLAGS_batch_size); int im_vec_size =
std::min(static_cast<int>(image_paths.size()), i + FLAGS_batch_size);
std::vector<cv::Mat> im_vec(im_vec_size - i); std::vector<cv::Mat> im_vec(im_vec_size - i);
std::vector<PaddleX::DetResult> results(im_vec_size - i, PaddleX::DetResult()); std::vector<PaddleX::DetResult> results(im_vec_size - i,
PaddleX::DetResult());
int thread_num = std::min(FLAGS_thread_num, im_vec_size - i); int thread_num = std::min(FLAGS_thread_num, im_vec_size - i);
#pragma omp parallel for num_threads(thread_num) #pragma omp parallel for num_threads(thread_num)
for(int j = i; j < im_vec_size; ++j){ for (int j = i; j < im_vec_size; ++j) {
im_vec[j - i] = std::move(cv::imread(image_paths[j], 1)); im_vec[j - i] = std::move(cv::imread(image_paths[j], 1));
} }
auto imread_end = system_clock::now(); auto imread_end = system_clock::now();
model.predict(im_vec, results, thread_num); model.predict(im_vec, &results, thread_num);
auto imread_duration = duration_cast<microseconds>(imread_end - start); auto imread_duration = duration_cast<microseconds>(imread_end - start);
total_imread_time_s += double(imread_duration.count()) * microseconds::period::num / microseconds::period::den; total_imread_time_s += static_cast<double>(imread_duration.count()) *
microseconds::period::num /
microseconds::period::den;
auto end = system_clock::now(); auto end = system_clock::now();
auto duration = duration_cast<microseconds>(end - start); auto duration = duration_cast<microseconds>(end - start);
total_running_time_s += double(duration.count()) * microseconds::period::num / microseconds::period::den; total_running_time_s += static_cast<double>(duration.count()) *
//输出结果目标框 microseconds::period::num /
for(int j = 0; j < im_vec_size - i; ++j) { microseconds::period::den;
for(int k = 0; k < results[j].boxes.size(); ++k) { // 输出结果目标框
std::cout << "image file: " << image_paths[i + j] << ", ";// << std::endl; for (int j = 0; j < im_vec_size - i; ++j) {
for (int k = 0; k < results[j].boxes.size(); ++k) {
std::cout << "image file: " << image_paths[i + j] << ", ";
std::cout << "predict label: " << results[j].boxes[k].category std::cout << "predict label: " << results[j].boxes[k].category
<< ", label_id:" << results[j].boxes[k].category_id << ", label_id:" << results[j].boxes[k].category_id
<< ", score: " << results[j].boxes[k].score << ", box(xmin, ymin, w, h):(" << ", score: " << results[j].boxes[k].score
<< ", box(xmin, ymin, w, h):("
<< results[j].boxes[k].coordinate[0] << ", " << results[j].boxes[k].coordinate[0] << ", "
<< results[j].boxes[k].coordinate[1] << ", " << results[j].boxes[k].coordinate[1] << ", "
<< results[j].boxes[k].coordinate[2] << ", " << results[j].boxes[k].coordinate[2] << ", "
<< results[j].boxes[k].coordinate[3] << ")" << std::endl; << results[j].boxes[k].coordinate[3] << ")" << std::endl;
} }
} }
// 可视化 // 可视化
for(int j = 0; j < im_vec_size - i; ++j) { for (int j = 0; j < im_vec_size - i; ++j) {
cv::Mat vis_img = cv::Mat vis_img = PaddleX::Visualize(
PaddleX::Visualize(im_vec[j], results[j], model.labels, colormap, FLAGS_threshold); im_vec[j], results[j], model.labels, colormap, FLAGS_threshold);
std::string save_path = std::string save_path =
PaddleX::generate_save_path(FLAGS_save_dir, image_paths[i + j]); PaddleX::generate_save_path(FLAGS_save_dir, image_paths[i + j]);
cv::imwrite(save_path, vis_img); cv::imwrite(save_path, vis_img);
...@@ -124,9 +139,9 @@ int main(int argc, char** argv) { ...@@ -124,9 +139,9 @@ int main(int argc, char** argv) {
std::cout << "image file: " << FLAGS_image << std::endl; std::cout << "image file: " << FLAGS_image << std::endl;
std::cout << ", predict label: " << result.boxes[i].category std::cout << ", predict label: " << result.boxes[i].category
<< ", label_id:" << result.boxes[i].category_id << ", label_id:" << result.boxes[i].category_id
<< ", score: " << result.boxes[i].score << ", box(xmin, ymin, w, h):(" << ", score: " << result.boxes[i].score
<< result.boxes[i].coordinate[0] << ", " << ", box(xmin, ymin, w, h):(" << result.boxes[i].coordinate[0]
<< result.boxes[i].coordinate[1] << ", " << ", " << result.boxes[i].coordinate[1] << ", "
<< result.boxes[i].coordinate[2] << ", " << result.boxes[i].coordinate[2] << ", "
<< result.boxes[i].coordinate[3] << ")" << std::endl; << result.boxes[i].coordinate[3] << ")" << std::endl;
} }
...@@ -141,17 +156,11 @@ int main(int argc, char** argv) { ...@@ -141,17 +156,11 @@ int main(int argc, char** argv) {
std::cout << "Visualized output saved as " << save_path << std::endl; std::cout << "Visualized output saved as " << save_path << std::endl;
} }
std::cout << "Total running time: " std::cout << "Total running time: " << total_running_time_s
<< total_running_time_s << " s, average running time: " << total_running_time_s / imgs
<< " s, average running time: " << " s/img, total read img time: " << total_imread_time_s
<< total_running_time_s / imgs << " s, average read img time: " << total_imread_time_s / imgs
<< " s/img, total read img time: " << " s, batch_size = " << FLAGS_batch_size << std::endl;
<< total_imread_time_s
<< " s, average read img time: "
<< total_imread_time_s / imgs
<< " s, batch_size = "
<< FLAGS_batch_size
<< std::endl;
return 0; return 0;
} }
...@@ -13,19 +13,19 @@ ...@@ -13,19 +13,19 @@
// limitations under the License. // limitations under the License.
#include <glog/logging.h> #include <glog/logging.h>
#include <omp.h>
#include <algorithm> #include <algorithm>
#include <chrono> #include <chrono> // NOLINT
#include <fstream> #include <fstream>
#include <iostream> #include <iostream>
#include <string> #include <string>
#include <vector> #include <vector>
#include <utility> #include <utility>
#include <omp.h>
#include "include/paddlex/paddlex.h" #include "include/paddlex/paddlex.h"
#include "include/paddlex/visualize.h" #include "include/paddlex/visualize.h"
using namespace std::chrono; using namespace std::chrono; // NOLINT
DEFINE_string(model_dir, "", "Path of inference model"); DEFINE_string(model_dir, "", "Path of inference model");
DEFINE_bool(use_gpu, false, "Infering with GPU or CPU"); DEFINE_bool(use_gpu, false, "Infering with GPU or CPU");
...@@ -36,7 +36,9 @@ DEFINE_string(image, "", "Path of test image file"); ...@@ -36,7 +36,9 @@ DEFINE_string(image, "", "Path of test image file");
DEFINE_string(image_list, "", "Path of test image list file"); DEFINE_string(image_list, "", "Path of test image list file");
DEFINE_string(save_dir, "output", "Path to save visualized image"); DEFINE_string(save_dir, "output", "Path to save visualized image");
DEFINE_int32(batch_size, 1, "Batch size of infering"); DEFINE_int32(batch_size, 1, "Batch size of infering");
DEFINE_int32(thread_num, omp_get_num_procs(), "Number of preprocessing threads"); DEFINE_int32(thread_num,
omp_get_num_procs(),
"Number of preprocessing threads");
int main(int argc, char** argv) { int main(int argc, char** argv) {
// 解析命令行参数 // 解析命令行参数
...@@ -53,7 +55,12 @@ int main(int argc, char** argv) { ...@@ -53,7 +55,12 @@ int main(int argc, char** argv) {
// 加载模型 // 加载模型
PaddleX::Model model; PaddleX::Model model;
model.Init(FLAGS_model_dir, FLAGS_use_gpu, FLAGS_use_trt, FLAGS_gpu_id, FLAGS_key, FLAGS_batch_size); model.Init(FLAGS_model_dir,
FLAGS_use_gpu,
FLAGS_use_trt,
FLAGS_gpu_id,
FLAGS_key,
FLAGS_batch_size);
double total_running_time_s = 0.0; double total_running_time_s = 0.0;
double total_imread_time_s = 0.0; double total_imread_time_s = 0.0;
...@@ -72,25 +79,31 @@ int main(int argc, char** argv) { ...@@ -72,25 +79,31 @@ int main(int argc, char** argv) {
image_paths.push_back(image_path); image_paths.push_back(image_path);
} }
imgs = image_paths.size(); imgs = image_paths.size();
for(int i = 0; i < image_paths.size(); i += FLAGS_batch_size){ for (int i = 0; i < image_paths.size(); i += FLAGS_batch_size) {
auto start = system_clock::now(); auto start = system_clock::now();
int im_vec_size = std::min((int)image_paths.size(), i + FLAGS_batch_size); int im_vec_size =
std::min(static_cast<int>(image_paths.size()), i + FLAGS_batch_size);
std::vector<cv::Mat> im_vec(im_vec_size - i); std::vector<cv::Mat> im_vec(im_vec_size - i);
std::vector<PaddleX::SegResult> results(im_vec_size - i, PaddleX::SegResult()); std::vector<PaddleX::SegResult> results(im_vec_size - i,
PaddleX::SegResult());
int thread_num = std::min(FLAGS_thread_num, im_vec_size - i); int thread_num = std::min(FLAGS_thread_num, im_vec_size - i);
#pragma omp parallel for num_threads(thread_num) #pragma omp parallel for num_threads(thread_num)
for(int j = i; j < im_vec_size; ++j){ for (int j = i; j < im_vec_size; ++j) {
im_vec[j - i] = std::move(cv::imread(image_paths[j], 1)); im_vec[j - i] = std::move(cv::imread(image_paths[j], 1));
} }
auto imread_end = system_clock::now(); auto imread_end = system_clock::now();
model.predict(im_vec, results, thread_num); model.predict(im_vec, &results, thread_num);
auto imread_duration = duration_cast<microseconds>(imread_end - start); auto imread_duration = duration_cast<microseconds>(imread_end - start);
total_imread_time_s += double(imread_duration.count()) * microseconds::period::num / microseconds::period::den; total_imread_time_s += static_cast<double>(imread_duration.count()) *
microseconds::period::num /
microseconds::period::den;
auto end = system_clock::now(); auto end = system_clock::now();
auto duration = duration_cast<microseconds>(end - start); auto duration = duration_cast<microseconds>(end - start);
total_running_time_s += double(duration.count()) * microseconds::period::num / microseconds::period::den; total_running_time_s += static_cast<double>(duration.count()) *
microseconds::period::num /
microseconds::period::den;
// 可视化 // 可视化
for(int j = 0; j < im_vec_size - i; ++j) { for (int j = 0; j < im_vec_size - i; ++j) {
cv::Mat vis_img = cv::Mat vis_img =
PaddleX::Visualize(im_vec[j], results[j], model.labels, colormap); PaddleX::Visualize(im_vec[j], results[j], model.labels, colormap);
std::string save_path = std::string save_path =
...@@ -106,7 +119,9 @@ int main(int argc, char** argv) { ...@@ -106,7 +119,9 @@ int main(int argc, char** argv) {
model.predict(im, &result); model.predict(im, &result);
auto end = system_clock::now(); auto end = system_clock::now();
auto duration = duration_cast<microseconds>(end - start); auto duration = duration_cast<microseconds>(end - start);
total_running_time_s += double(duration.count()) * microseconds::period::num / microseconds::period::den; total_running_time_s += static_cast<double>(duration.count()) *
microseconds::period::num /
microseconds::period::den;
// 可视化 // 可视化
cv::Mat vis_img = PaddleX::Visualize(im, result, model.labels, colormap); cv::Mat vis_img = PaddleX::Visualize(im, result, model.labels, colormap);
std::string save_path = std::string save_path =
...@@ -115,17 +130,11 @@ int main(int argc, char** argv) { ...@@ -115,17 +130,11 @@ int main(int argc, char** argv) {
result.clear(); result.clear();
std::cout << "Visualized output saved as " << save_path << std::endl; std::cout << "Visualized output saved as " << save_path << std::endl;
} }
std::cout << "Total running time: " std::cout << "Total running time: " << total_running_time_s
<< total_running_time_s << " s, average running time: " << total_running_time_s / imgs
<< " s, average running time: " << " s/img, total read img time: " << total_imread_time_s
<< total_running_time_s / imgs << " s, average read img time: " << total_imread_time_s / imgs
<< " s/img, total read img time: " << " s, batch_size = " << FLAGS_batch_size << std::endl;
<< total_imread_time_s
<< " s, average read img time: "
<< total_imread_time_s / imgs
<< " s, batch_size = "
<< FLAGS_batch_size
<< std::endl;
return 0; return 0;
} }
...@@ -54,4 +54,4 @@ class ConfigPaser { ...@@ -54,4 +54,4 @@ class ConfigPaser {
YAML::Node Transforms_; YAML::Node Transforms_;
}; };
} // namespace PaddleDetection } // namespace PaddleX
...@@ -16,8 +16,11 @@ ...@@ -16,8 +16,11 @@
#include <functional> #include <functional>
#include <iostream> #include <iostream>
#include <map>
#include <memory>
#include <numeric> #include <numeric>
#include <string>
#include <vector>
#include "yaml-cpp/yaml.h" #include "yaml-cpp/yaml.h"
#ifdef _WIN32 #ifdef _WIN32
...@@ -28,13 +31,13 @@ ...@@ -28,13 +31,13 @@
#include "paddle_inference_api.h" // NOLINT #include "paddle_inference_api.h" // NOLINT
#include "config_parser.h" #include "config_parser.h" // NOLINT
#include "results.h" #include "results.h" // NOLINT
#include "transforms.h" #include "transforms.h" // NOLINT
#ifdef WITH_ENCRYPTION #ifdef WITH_ENCRYPTION
#include "paddle_model_decrypt.h" #include "paddle_model_decrypt.h" // NOLINT
#include "model_code.h" #include "model_code.h" // NOLINT
#endif #endif
namespace PaddleX { namespace PaddleX {
...@@ -119,7 +122,9 @@ class Model { ...@@ -119,7 +122,9 @@ class Model {
* each thread run preprocess on single image matrix * each thread run preprocess on single image matrix
* @return true if preprocess a batch of image matrixs successfully * @return true if preprocess a batch of image matrixs successfully
* */ * */
bool preprocess(const std::vector<cv::Mat> &input_im_batch, std::vector<ImageBlob> &blob_batch, int thread_num = 1); bool preprocess(const std::vector<cv::Mat> &input_im_batch,
std::vector<ImageBlob> *blob_batch,
int thread_num = 1);
/* /*
* @brief * @brief
...@@ -143,7 +148,9 @@ class Model { ...@@ -143,7 +148,9 @@ class Model {
* on single image matrix * on single image matrix
* @return true if predict successfully * @return true if predict successfully
* */ * */
bool predict(const std::vector<cv::Mat> &im_batch, std::vector<ClsResult> &results, int thread_num = 1); bool predict(const std::vector<cv::Mat> &im_batch,
std::vector<ClsResult> *results,
int thread_num = 1);
/* /*
* @brief * @brief
...@@ -167,7 +174,9 @@ class Model { ...@@ -167,7 +174,9 @@ class Model {
* on single image matrix * on single image matrix
* @return true if predict successfully * @return true if predict successfully
* */ * */
bool predict(const std::vector<cv::Mat> &im_batch, std::vector<DetResult> &result, int thread_num = 1); bool predict(const std::vector<cv::Mat> &im_batch,
std::vector<DetResult> *result,
int thread_num = 1);
/* /*
* @brief * @brief
...@@ -191,7 +200,9 @@ class Model { ...@@ -191,7 +200,9 @@ class Model {
* on single image matrix * on single image matrix
* @return true if predict successfully * @return true if predict successfully
* */ * */
bool predict(const std::vector<cv::Mat> &im_batch, std::vector<SegResult> &result, int thread_num = 1); bool predict(const std::vector<cv::Mat> &im_batch,
std::vector<SegResult> *result,
int thread_num = 1);
// model type, include 3 type: classifier, detector, segmenter // model type, include 3 type: classifier, detector, segmenter
std::string type; std::string type;
...@@ -209,4 +220,4 @@ class Model { ...@@ -209,4 +220,4 @@ class Model {
// a predictor which run the model predicting // a predictor which run the model predicting
std::unique_ptr<paddle::PaddlePredictor> predictor_; std::unique_ptr<paddle::PaddlePredictor> predictor_;
}; };
} // namespce of PaddleX } // namespace PaddleX
...@@ -214,6 +214,7 @@ class Padding : public Transform { ...@@ -214,6 +214,7 @@ class Padding : public Transform {
} }
} }
virtual bool Run(cv::Mat* im, ImageBlob* data); virtual bool Run(cv::Mat* im, ImageBlob* data);
private: private:
int coarsest_stride_ = -1; int coarsest_stride_ = -1;
int width_ = 0; int width_ = 0;
...@@ -229,6 +230,7 @@ class Transforms { ...@@ -229,6 +230,7 @@ class Transforms {
void Init(const YAML::Node& node, bool to_rgb = true); void Init(const YAML::Node& node, bool to_rgb = true);
std::shared_ptr<Transform> CreateTransform(const std::string& name); std::shared_ptr<Transform> CreateTransform(const std::string& name);
bool Run(cv::Mat* im, ImageBlob* data); bool Run(cv::Mat* im, ImageBlob* data);
private: private:
std::vector<std::shared_ptr<Transform>> transforms_; std::vector<std::shared_ptr<Transform>> transforms_;
bool to_rgb_ = true; bool to_rgb_ = true;
......
...@@ -94,4 +94,4 @@ cv::Mat Visualize(const cv::Mat& img, ...@@ -94,4 +94,4 @@ cv::Mat Visualize(const cv::Mat& img,
* */ * */
std::string generate_save_path(const std::string& save_dir, std::string generate_save_path(const std::string& save_dir,
const std::string& file_path); const std::string& file_path);
} // namespce of PaddleX } // namespace PaddleX
此差异已折叠。
...@@ -145,4 +145,4 @@ std::string generate_save_path(const std::string& save_dir, ...@@ -145,4 +145,4 @@ std::string generate_save_path(const std::string& save_dir,
std::string image_name(file_path.substr(pos + 1)); std::string image_name(file_path.substr(pos + 1));
return save_dir + OS_PATH_SEP + image_name; return save_dir + OS_PATH_SEP + image_name;
} }
} // namespace of PaddleX } // namespace PaddleX
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册