提交 23a62440 编写于 作者: J jack

remove batch_size args

上级 59d1e941
...@@ -72,23 +72,20 @@ class Model { ...@@ -72,23 +72,20 @@ class Model {
* @param use_trt: use Tensor RT or not when infering * @param use_trt: use Tensor RT or not when infering
* @param gpu_id: the id of gpu when infering with using gpu * @param gpu_id: the id of gpu when infering with using gpu
* @param key: the key of encryption when using encrypted model * @param key: the key of encryption when using encrypted model
* @param batch_size: batch size of infering
* */ * */
void Init(const std::string& model_dir, void Init(const std::string& model_dir,
bool use_gpu = false, bool use_gpu = false,
bool use_trt = false, bool use_trt = false,
int gpu_id = 0, int gpu_id = 0,
std::string key = "", std::string key = "") {
int batch_size = 1) { create_predictor(model_dir, use_gpu, use_trt, gpu_id, key);
create_predictor(model_dir, use_gpu, use_trt, gpu_id, key, batch_size);
} }
void create_predictor(const std::string& model_dir, void create_predictor(const std::string& model_dir,
bool use_gpu = false, bool use_gpu = false,
bool use_trt = false, bool use_trt = false,
int gpu_id = 0, int gpu_id = 0,
std::string key = "", std::string key = "");
int batch_size = 1);
/* /*
* @brief * @brief
......
...@@ -22,8 +22,7 @@ void Model::create_predictor(const std::string& model_dir, ...@@ -22,8 +22,7 @@ void Model::create_predictor(const std::string& model_dir,
bool use_gpu, bool use_gpu,
bool use_trt, bool use_trt,
int gpu_id, int gpu_id,
std::string key, std::string key) {
int batch_size) {
paddle::AnalysisConfig config; paddle::AnalysisConfig config;
std::string model_file = model_dir + OS_PATH_SEP + "__model__"; std::string model_file = model_dir + OS_PATH_SEP + "__model__";
std::string params_file = model_dir + OS_PATH_SEP + "__params__"; std::string params_file = model_dir + OS_PATH_SEP + "__params__";
...@@ -76,7 +75,6 @@ void Model::create_predictor(const std::string& model_dir, ...@@ -76,7 +75,6 @@ void Model::create_predictor(const std::string& model_dir,
false /* use_calib_mode*/); false /* use_calib_mode*/);
} }
predictor_ = std::move(CreatePaddlePredictor(config)); predictor_ = std::move(CreatePaddlePredictor(config));
inputs_batch_.assign(batch_size, ImageBlob());
} }
bool Model::load_config(const std::string& yaml_input) { bool Model::load_config(const std::string& yaml_input) {
...@@ -192,6 +190,7 @@ bool Model::predict(const std::vector<cv::Mat>& im_batch, ...@@ -192,6 +190,7 @@ bool Model::predict(const std::vector<cv::Mat>& im_batch,
"to function predict()!" << std::endl; "to function predict()!" << std::endl;
return false; return false;
} }
inputs_batch_.assign(im_batch.size(), ImageBlob());
// 处理输入图像 // 处理输入图像
if (!preprocess(im_batch, &inputs_batch_, thread_num)) { if (!preprocess(im_batch, &inputs_batch_, thread_num)) {
std::cerr << "Preprocess failed!" << std::endl; std::cerr << "Preprocess failed!" << std::endl;
...@@ -356,6 +355,7 @@ bool Model::predict(const std::vector<cv::Mat>& im_batch, ...@@ -356,6 +355,7 @@ bool Model::predict(const std::vector<cv::Mat>& im_batch,
return false; return false;
} }
inputs_batch_.assign(im_batch.size(), ImageBlob());
int batch_size = im_batch.size(); int batch_size = im_batch.size();
// 处理输入图像 // 处理输入图像
if (!preprocess(im_batch, &inputs_batch_, thread_num)) { if (!preprocess(im_batch, &inputs_batch_, thread_num)) {
...@@ -637,6 +637,7 @@ bool Model::predict(const std::vector<cv::Mat>& im_batch, ...@@ -637,6 +637,7 @@ bool Model::predict(const std::vector<cv::Mat>& im_batch,
} }
// 处理输入图像 // 处理输入图像
inputs_batch_.assign(im_batch.size(), ImageBlob());
if (!preprocess(im_batch, &inputs_batch_, thread_num)) { if (!preprocess(im_batch, &inputs_batch_, thread_num)) {
std::cerr << "Preprocess failed!" << std::endl; std::cerr << "Preprocess failed!" << std::endl;
return false; return false;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册