提交 7c9c3c06 编写于 作者: J Jack

add output_dir args and give warning when yaml file is not found

上级 8138f9aa
...@@ -18,6 +18,12 @@ ...@@ -18,6 +18,12 @@
#include <vector> #include <vector>
#include <string> #include <string>
#include <map> #include <map>
#ifdef _WIN32
#include <direct.h>
#include <io.h>
#else // Linux/Unix
#include <unistd.h>
#endif
#include "yaml-cpp/yaml.h" #include "yaml-cpp/yaml.h"
...@@ -38,9 +44,15 @@ class ConfigPaser { ...@@ -38,9 +44,15 @@ class ConfigPaser {
bool load_config(const std::string& model_dir, bool load_config(const std::string& model_dir,
const std::string& cfg = "infer_cfg.yml") { const std::string& cfg = "infer_cfg.yml") {
std::string cfg_file = model_dir + OS_PATH_SEP + cfg;
if (access(cfg_file.c_str(), 0) < 0) {
std::cerr << "[WARNING] Config yaml file is not found, please check "
<< "whether infer_cfg.yml exists in model_dir" << std::endl;
return false;
}
// Load as a YAML::Node // Load as a YAML::Node
YAML::Node config; YAML::Node config;
config = YAML::LoadFile(model_dir + OS_PATH_SEP + cfg); config = YAML::LoadFile(cfg_file);
// Get runtime mode : fluid, trt_fp16, trt_fp32 // Get runtime mode : fluid, trt_fp16, trt_fp32
if (config["mode"].IsDefined()) { if (config["mode"].IsDefined()) {
......
...@@ -58,12 +58,16 @@ class ObjectDetector { ...@@ -58,12 +58,16 @@ class ObjectDetector {
bool use_gpu=false, bool use_gpu=false,
const std::string& run_mode="fluid", const std::string& run_mode="fluid",
const int gpu_id=0) { const int gpu_id=0) {
config_.load_config(model_dir); success_init_ = config_.load_config(model_dir);
threshold_ = config_.draw_threshold_; threshold_ = config_.draw_threshold_;
preprocessor_.Init(config_.preprocess_info_, config_.arch_); preprocessor_.Init(config_.preprocess_info_, config_.arch_);
LoadModel(model_dir, use_gpu, config_.min_subgraph_size_, 1, run_mode, gpu_id); LoadModel(model_dir, use_gpu, config_.min_subgraph_size_, 1, run_mode, gpu_id);
} }
bool GetSuccessInit() const {
return success_init_;
}
// Load Paddle inference model // Load Paddle inference model
void LoadModel( void LoadModel(
const std::string& model_dir, const std::string& model_dir,
...@@ -97,6 +101,7 @@ class ObjectDetector { ...@@ -97,6 +101,7 @@ class ObjectDetector {
std::vector<float> output_data_; std::vector<float> output_data_;
float threshold_; float threshold_;
ConfigPaser config_; ConfigPaser config_;
bool success_init_;
}; };
} // namespace PaddleDetection } // namespace PaddleDetection
...@@ -20,6 +20,22 @@ ...@@ -20,6 +20,22 @@
#include "include/object_detector.h" #include "include/object_detector.h"
#ifdef _WIN32
#include <direct.h>
#include <io.h>
#else // Linux/Unix
#include <dirent.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <unistd.h>
#endif
#ifdef _WIN32
#define OS_PATH_SEP "\\"
#else
#define OS_PATH_SEP "/"
#endif
DEFINE_string(model_dir, "", "Path of inference model"); DEFINE_string(model_dir, "", "Path of inference model");
DEFINE_string(image_path, "", "Path of input image"); DEFINE_string(image_path, "", "Path of input image");
...@@ -27,6 +43,23 @@ DEFINE_string(video_path, "", "Path of input video"); ...@@ -27,6 +43,23 @@ DEFINE_string(video_path, "", "Path of input video");
DEFINE_bool(use_gpu, false, "Infering with GPU or CPU"); DEFINE_bool(use_gpu, false, "Infering with GPU or CPU");
DEFINE_string(run_mode, "fluid", "Mode of running(fluid/trt_fp32/trt_fp16)"); DEFINE_string(run_mode, "fluid", "Mode of running(fluid/trt_fp32/trt_fp16)");
DEFINE_int32(gpu_id, 0, "Device id of GPU to execute"); DEFINE_int32(gpu_id, 0, "Device id of GPU to execute");
DEFINE_string(output_dir, "output", "Path of saved image or video");
std::string generate_save_path(const std::string& save_dir,
const std::string& file_path) {
if (access(save_dir.c_str(), 0) < 0) {
#ifdef _WIN32
mkdir(save_dir.c_str());
#else
if (mkdir(save_dir.c_str(), S_IRWXU) < 0) {
std::cerr << "Fail to create " << save_dir << "directory." << std::endl;
}
#endif
}
int pos = file_path.find_last_of(OS_PATH_SEP);
std::string image_name(file_path.substr(pos + 1));
return save_dir + OS_PATH_SEP + image_name;
}
void PredictVideo(const std::string& video_path, void PredictVideo(const std::string& video_path,
PaddleDetection::ObjectDetector* det) { PaddleDetection::ObjectDetector* det) {
...@@ -45,7 +78,7 @@ void PredictVideo(const std::string& video_path, ...@@ -45,7 +78,7 @@ void PredictVideo(const std::string& video_path,
// Create VideoWriter for output // Create VideoWriter for output
cv::VideoWriter video_out; cv::VideoWriter video_out;
std::string video_out_path = "output.mp4"; std::string video_out_path = generate_save_path(FLAGS_output_dir, "output.mp4");
video_out.open(video_out_path.c_str(), video_out.open(video_out_path.c_str(),
0x00000021, 0x00000021,
video_fps, video_fps,
...@@ -110,7 +143,8 @@ void PredictImage(const std::string& image_path, ...@@ -110,7 +143,8 @@ void PredictImage(const std::string& image_path,
std::vector<int> compression_params; std::vector<int> compression_params;
compression_params.push_back(CV_IMWRITE_JPEG_QUALITY); compression_params.push_back(CV_IMWRITE_JPEG_QUALITY);
compression_params.push_back(95); compression_params.push_back(95);
cv::imwrite("output.jpg", vis_img, compression_params); std::string output_image_path = generate_save_path(FLAGS_output_dir, "output.jpg");
cv::imwrite(output_image_path, vis_img, compression_params);
printf("Visualized output saved as output.jpeg\n"); printf("Visualized output saved as output.jpeg\n");
} }
...@@ -133,10 +167,12 @@ int main(int argc, char** argv) { ...@@ -133,10 +167,12 @@ int main(int argc, char** argv) {
PaddleDetection::ObjectDetector det(FLAGS_model_dir, FLAGS_use_gpu, PaddleDetection::ObjectDetector det(FLAGS_model_dir, FLAGS_use_gpu,
FLAGS_run_mode, FLAGS_gpu_id); FLAGS_run_mode, FLAGS_gpu_id);
// Do inference on input video or image // Do inference on input video or image
if (det.GetSuccessInit()) {
if (!FLAGS_video_path.empty()) { if (!FLAGS_video_path.empty()) {
PredictVideo(FLAGS_video_path, &det); PredictVideo(FLAGS_video_path, &det);
} else if (!FLAGS_image_path.empty()) { } else if (!FLAGS_image_path.empty()) {
PredictImage(FLAGS_image_path, &det); PredictImage(FLAGS_image_path, &det);
} }
}
return 0; return 0;
} }
...@@ -15,6 +15,12 @@ ...@@ -15,6 +15,12 @@
// for setprecision // for setprecision
#include <iomanip> #include <iomanip>
#include "include/object_detector.h" #include "include/object_detector.h"
#ifdef _WIN32
#include <direct.h>
#include <io.h>
#else // Linux/Unix
#include <unistd.h>
#endif
namespace PaddleDetection { namespace PaddleDetection {
...@@ -28,6 +34,11 @@ void ObjectDetector::LoadModel(const std::string& model_dir, ...@@ -28,6 +34,11 @@ void ObjectDetector::LoadModel(const std::string& model_dir,
paddle::AnalysisConfig config; paddle::AnalysisConfig config;
std::string prog_file = model_dir + OS_PATH_SEP + "__model__"; std::string prog_file = model_dir + OS_PATH_SEP + "__model__";
std::string params_file = model_dir + OS_PATH_SEP + "__params__"; std::string params_file = model_dir + OS_PATH_SEP + "__params__";
if (access(prog_file.c_str(), 0) < 0 || access(params_file.c_str(), 0) < 0) {
std::cerr << "[WARNING] Model file or parameter file can't be found." << std::endl;
success_init_ = false;
return;
}
config.SetModel(prog_file, params_file); config.SetModel(prog_file, params_file);
if (use_gpu) { if (use_gpu) {
config.EnableUseGpu(100, gpu_id); config.EnableUseGpu(100, gpu_id);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册