未验证 提交 56acdf95 编写于 作者: Z Zeyu Chen 提交者: GitHub

Remove batch_size function parameter of create_predictor

find_package(Git REQUIRED)
include(ExternalProject) include(ExternalProject)
message("${CMAKE_BUILD_TYPE}") message("${CMAKE_BUILD_TYPE}")
......
...@@ -57,8 +57,7 @@ int main(int argc, char** argv) { ...@@ -57,8 +57,7 @@ int main(int argc, char** argv) {
FLAGS_use_gpu, FLAGS_use_gpu,
FLAGS_use_trt, FLAGS_use_trt,
FLAGS_gpu_id, FLAGS_gpu_id,
FLAGS_key, FLAGS_key);
FLAGS_batch_size);
// 进行预测 // 进行预测
double total_running_time_s = 0.0; double total_running_time_s = 0.0;
......
...@@ -62,8 +62,7 @@ int main(int argc, char** argv) { ...@@ -62,8 +62,7 @@ int main(int argc, char** argv) {
FLAGS_use_gpu, FLAGS_use_gpu,
FLAGS_use_trt, FLAGS_use_trt,
FLAGS_gpu_id, FLAGS_gpu_id,
FLAGS_key, FLAGS_key);
FLAGS_batch_size);
double total_running_time_s = 0.0; double total_running_time_s = 0.0;
double total_imread_time_s = 0.0; double total_imread_time_s = 0.0;
......
...@@ -59,8 +59,7 @@ int main(int argc, char** argv) { ...@@ -59,8 +59,7 @@ int main(int argc, char** argv) {
FLAGS_use_gpu, FLAGS_use_gpu,
FLAGS_use_trt, FLAGS_use_trt,
FLAGS_gpu_id, FLAGS_gpu_id,
FLAGS_key, FLAGS_key);
FLAGS_batch_size);
double total_running_time_s = 0.0; double total_running_time_s = 0.0;
double total_imread_time_s = 0.0; double total_imread_time_s = 0.0;
......
...@@ -72,23 +72,20 @@ class Model { ...@@ -72,23 +72,20 @@ class Model {
* @param use_trt: use Tensor RT or not when infering * @param use_trt: use Tensor RT or not when infering
* @param gpu_id: the id of gpu when infering with using gpu * @param gpu_id: the id of gpu when infering with using gpu
* @param key: the key of encryption when using encrypted model * @param key: the key of encryption when using encrypted model
* @param batch_size: batch size of infering
* */ * */
void Init(const std::string& model_dir, void Init(const std::string& model_dir,
bool use_gpu = false, bool use_gpu = false,
bool use_trt = false, bool use_trt = false,
int gpu_id = 0, int gpu_id = 0,
std::string key = "", std::string key = "") {
int batch_size = 1) { create_predictor(model_dir, use_gpu, use_trt, gpu_id, key);
create_predictor(model_dir, use_gpu, use_trt, gpu_id, key, batch_size);
} }
void create_predictor(const std::string& model_dir, void create_predictor(const std::string& model_dir,
bool use_gpu = false, bool use_gpu = false,
bool use_trt = false, bool use_trt = false,
int gpu_id = 0, int gpu_id = 0,
std::string key = "", std::string key = "");
int batch_size = 1);
/* /*
* @brief * @brief
......
...@@ -22,8 +22,7 @@ void Model::create_predictor(const std::string& model_dir, ...@@ -22,8 +22,7 @@ void Model::create_predictor(const std::string& model_dir,
bool use_gpu, bool use_gpu,
bool use_trt, bool use_trt,
int gpu_id, int gpu_id,
std::string key, std::string key) {
int batch_size) {
paddle::AnalysisConfig config; paddle::AnalysisConfig config;
std::string model_file = model_dir + OS_PATH_SEP + "__model__"; std::string model_file = model_dir + OS_PATH_SEP + "__model__";
std::string params_file = model_dir + OS_PATH_SEP + "__params__"; std::string params_file = model_dir + OS_PATH_SEP + "__params__";
...@@ -76,7 +75,6 @@ void Model::create_predictor(const std::string& model_dir, ...@@ -76,7 +75,6 @@ void Model::create_predictor(const std::string& model_dir,
false /* use_calib_mode*/); false /* use_calib_mode*/);
} }
predictor_ = std::move(CreatePaddlePredictor(config)); predictor_ = std::move(CreatePaddlePredictor(config));
inputs_batch_.assign(batch_size, ImageBlob());
} }
bool Model::load_config(const std::string& yaml_input) { bool Model::load_config(const std::string& yaml_input) {
...@@ -192,6 +190,7 @@ bool Model::predict(const std::vector<cv::Mat>& im_batch, ...@@ -192,6 +190,7 @@ bool Model::predict(const std::vector<cv::Mat>& im_batch,
"to function predict()!" << std::endl; "to function predict()!" << std::endl;
return false; return false;
} }
inputs_batch_.assign(im_batch.size(), ImageBlob());
// 处理输入图像 // 处理输入图像
if (!preprocess(im_batch, &inputs_batch_, thread_num)) { if (!preprocess(im_batch, &inputs_batch_, thread_num)) {
std::cerr << "Preprocess failed!" << std::endl; std::cerr << "Preprocess failed!" << std::endl;
...@@ -356,6 +355,7 @@ bool Model::predict(const std::vector<cv::Mat>& im_batch, ...@@ -356,6 +355,7 @@ bool Model::predict(const std::vector<cv::Mat>& im_batch,
return false; return false;
} }
inputs_batch_.assign(im_batch.size(), ImageBlob());
int batch_size = im_batch.size(); int batch_size = im_batch.size();
// 处理输入图像 // 处理输入图像
if (!preprocess(im_batch, &inputs_batch_, thread_num)) { if (!preprocess(im_batch, &inputs_batch_, thread_num)) {
...@@ -637,6 +637,7 @@ bool Model::predict(const std::vector<cv::Mat>& im_batch, ...@@ -637,6 +637,7 @@ bool Model::predict(const std::vector<cv::Mat>& im_batch,
} }
// 处理输入图像 // 处理输入图像
inputs_batch_.assign(im_batch.size(), ImageBlob());
if (!preprocess(im_batch, &inputs_batch_, thread_num)) { if (!preprocess(im_batch, &inputs_batch_, thread_num)) {
std::cerr << "Preprocess failed!" << std::endl; std::cerr << "Preprocess failed!" << std::endl;
return false; return false;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册