Commit 29f8fb83 authored by dongshuilong

cpp shitu code format

Parent a96305c2
@@ -35,77 +35,76 @@ using namespace paddle_infer;
namespace Feature {
class FeatureExtracter {
public:
  explicit FeatureExtracter(const YAML::Node &config_file) {
    this->use_gpu_ = config_file["Global"]["use_gpu"].as<bool>();
    if (config_file["Global"]["gpu_id"].IsDefined())
      this->gpu_id_ = config_file["Global"]["gpu_id"].as<int>();
    else
      this->gpu_id_ = 0;
    this->gpu_mem_ = config_file["Global"]["gpu_mem"].as<int>();
    this->cpu_math_library_num_threads_ =
        config_file["Global"]["cpu_num_threads"].as<int>();
    this->use_mkldnn_ = config_file["Global"]["enable_mkldnn"].as<bool>();
    this->use_tensorrt_ = config_file["Global"]["use_tensorrt"].as<bool>();
    this->use_fp16_ = config_file["Global"]["use_fp16"].as<bool>();
    this->cls_model_path_ =
        config_file["Global"]["rec_inference_model_dir"].as<std::string>() +
        OS_PATH_SEP + "inference.pdmodel";
    this->cls_params_path_ =
        config_file["Global"]["rec_inference_model_dir"].as<std::string>() +
        OS_PATH_SEP + "inference.pdiparams";
    this->resize_size_ =
        config_file["RecPreProcess"]["transform_ops"][0]["ResizeImage"]["size"]
            .as<int>();
    this->scale_ = config_file["RecPreProcess"]["transform_ops"][1]
                              ["NormalizeImage"]["scale"]
                                  .as<float>();
    this->mean_ = config_file["RecPreProcess"]["transform_ops"][1]
                             ["NormalizeImage"]["mean"]
                                 .as<std::vector<float>>();
    this->std_ = config_file["RecPreProcess"]["transform_ops"][1]
                            ["NormalizeImage"]["std"]
                                .as<std::vector<float>>();
    if (config_file["Global"]["rec_feature_normlize"].IsDefined())
      this->feature_norm =
          config_file["Global"]["rec_feature_normlize"].as<bool>();

    LoadModel(cls_model_path_, cls_params_path_);
  }

  // Load the Paddle inference model
  void LoadModel(const std::string &model_path, const std::string &params_path);

  // Run the predictor
  void Run(cv::Mat &img, std::vector<float> &out_data,
           std::vector<double> &times);

  void FeatureNorm(std::vector<float> &feature);

  std::shared_ptr<Predictor> predictor_;

private:
  bool use_gpu_ = false;
  int gpu_id_ = 0;
  int gpu_mem_ = 4000;
  int cpu_math_library_num_threads_ = 4;
  bool use_mkldnn_ = false;
  bool use_tensorrt_ = false;
  bool feature_norm = true;
  bool use_fp16_ = false;
  std::vector<float> mean_ = {0.485f, 0.456f, 0.406f};
  std::vector<float> std_ = {0.229f, 0.224f, 0.225f};
  float scale_ = 0.00392157f;
  int resize_size_ = 224;
  int resize_short_ = 224;
  std::string cls_model_path_;
  std::string cls_params_path_;

  // pre-process
  ResizeImg resize_op_;
  Normalize normalize_op_;
  Permute permute_op_;
};
} // namespace Feature
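For orientation, here is a minimal usage sketch of this class (not part of the commit; it assumes a config file with the `Global`/`RecPreProcess` keys read above, and the file paths are hypothetical):

```cpp
#include <opencv2/opencv.hpp>
#include <yaml-cpp/yaml.h>

int main() {
  YAML::Node config = YAML::LoadFile("inference_drink.yaml");
  Feature::FeatureExtracter extracter(config); // loads the rec inference model

  cv::Mat img = cv::imread("test.jpg", cv::IMREAD_COLOR);
  cv::cvtColor(img, img, cv::COLOR_BGR2RGB); // assuming an RGB pipeline

  std::vector<float> feature;
  std::vector<double> times(3, 0.0); // Run() accumulates pre/infer/post times
  extracter.Run(img, feature, times); // feature is L2-normalized when enabled
  return 0;
}
```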
@@ -17,21 +17,21 @@
#include <algorithm>
#include <include/object_detector.h>

template <typename T>
static inline bool SortScorePairDescend(const std::pair<float, T> &pair1,
                                        const std::pair<float, T> &pair2) {
  return pair1.first > pair2.first;
}

float RectOverlap(const Detection::ObjectResult &a,
                  const Detection::ObjectResult &b) {
  float Aa = (a.rect[2] - a.rect[0] + 1) * (a.rect[3] - a.rect[1] + 1);
  float Ab = (b.rect[2] - b.rect[0] + 1) * (b.rect[3] - b.rect[1] + 1);

  int iou_w = max(min(a.rect[2], b.rect[2]) - max(a.rect[0], b.rect[0]) + 1, 0);
  int iou_h = max(min(a.rect[3], b.rect[3]) - max(a.rect[1], b.rect[1]) + 1, 0);
  float Aab = iou_w * iou_h;
  return Aab / (Aa + Ab - Aab);
}

// Get max scores with corresponding indices.
@@ -40,46 +40,46 @@ float RectOverlap(const Detection::ObjectResult &a,
//    top_k: if -1, keep all; otherwise, keep at most top_k.
//    score_index_vec: store the sorted (score, index) pair.
inline void
GetMaxScoreIndex(const std::vector<Detection::ObjectResult> &det_result,
                 const float threshold,
                 std::vector<std::pair<float, int>> &score_index_vec) {
  // Generate index score pairs.
  for (size_t i = 0; i < det_result.size(); ++i) {
    if (det_result[i].confidence > threshold) {
      score_index_vec.push_back(std::make_pair(det_result[i].confidence, i));
    }
  }

  // Sort the score pairs in descending order of score.
  std::stable_sort(score_index_vec.begin(), score_index_vec.end(),
                   SortScorePairDescend<int>);

  // // Keep top_k scores if needed.
  // if (top_k > 0 && top_k < (int)score_index_vec.size()) {
  //   score_index_vec.resize(top_k);
  // }
}

void NMSBoxes(const std::vector<Detection::ObjectResult> det_result,
              const float score_threshold, const float nms_threshold,
              std::vector<int> &indices) {
  // Get top_k scores (with corresponding indices).
  std::vector<std::pair<float, int>> score_index_vec;
  GetMaxScoreIndex(det_result, score_threshold, score_index_vec);

  // Do NMS.
  indices.clear();
  for (size_t i = 0; i < score_index_vec.size(); ++i) {
    const int idx = score_index_vec[i].second;
    bool keep = true;
    for (int k = 0; k < (int)indices.size() && keep; ++k) {
      const int kept_idx = indices[k];
      float overlap = RectOverlap(det_result[idx], det_result[kept_idx]);
      keep = overlap <= nms_threshold;
    }
    if (keep)
      indices.push_back(idx);
  }
}
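A quick illustration of the greedy loop above (made-up boxes; `ObjectResult` is the struct declared in the header that follows):

```cpp
// Three detections: two heavily overlapping (IoU ~0.92), one far away.
std::vector<Detection::ObjectResult> dets = {
    {{10, 10, 110, 110}, 0, 0.9f},   // kept: highest score
    {{12, 12, 112, 112}, 0, 0.8f},   // suppressed: IoU with the first box > 0.5
    {{300, 300, 400, 400}, 0, 0.7f}, // kept: no overlap with kept boxes
};
std::vector<int> keep;
NMSBoxes(dets, /*score_threshold=*/0.3f, /*nms_threshold=*/0.5f, keep);
// keep == {0, 2}
```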
@@ -33,103 +33,106 @@ using namespace paddle_infer;
namespace Detection {
// Object detection result
struct ObjectResult {
  // Rectangle coordinates of the detected object: left, top, right, bottom
  std::vector<int> rect;
  // Class id of the detected object
  int class_id;
  // Confidence of the detected object
  float confidence;
};

// Generate visualization colormap for each class
std::vector<int> GenerateColorMap(int num_class);

// Visualize detection results
cv::Mat VisualizeResult(const cv::Mat &img,
                        const std::vector<ObjectResult> &results,
                        const std::vector<std::string> &lables,
                        const std::vector<int> &colormap, const bool is_rbox);

class ObjectDetector {
public:
  explicit ObjectDetector(const YAML::Node &config_file) {
    this->use_gpu_ = config_file["Global"]["use_gpu"].as<bool>();
    if (config_file["Global"]["gpu_id"].IsDefined())
      this->gpu_id_ = config_file["Global"]["gpu_id"].as<int>();
    this->gpu_mem_ = config_file["Global"]["gpu_mem"].as<int>();
    this->cpu_math_library_num_threads_ =
        config_file["Global"]["cpu_num_threads"].as<int>();
    this->use_mkldnn_ = config_file["Global"]["enable_mkldnn"].as<bool>();
    this->use_tensorrt_ = config_file["Global"]["use_tensorrt"].as<bool>();
    this->use_fp16_ = config_file["Global"]["use_fp16"].as<bool>();
    this->model_dir_ =
        config_file["Global"]["det_inference_model_dir"].as<std::string>();
    this->threshold_ = config_file["Global"]["threshold"].as<float>();
    this->max_det_results_ = config_file["Global"]["max_det_results"].as<int>();
    this->image_shape_ =
        config_file["Global"]["image_shape"].as<std::vector<int>>();
    this->label_list_ =
        config_file["Global"]["labe_list"].as<std::vector<std::string>>();
    this->ir_optim_ = config_file["Global"]["ir_optim"].as<bool>();
    this->batch_size_ = config_file["Global"]["batch_size"].as<int>();

    preprocessor_.Init(config_file["DetPreProcess"]["transform_ops"]);
    LoadModel(model_dir_, batch_size_, run_mode);
  }

  // Load the Paddle inference model
  void LoadModel(const std::string &model_dir, const int batch_size = 1,
                 const std::string &run_mode = "fluid");

  // Run the predictor
  void Predict(const std::vector<cv::Mat> imgs, const int warmup = 0,
               const int repeats = 1,
               std::vector<ObjectResult> *result = nullptr,
               std::vector<int> *bbox_num = nullptr,
               std::vector<double> *times = nullptr);

  const std::vector<std::string> &GetLabelList() const {
    return this->label_list_;
  }

  const float &GetThreshold() const { return this->threshold_; }

private:
  bool use_gpu_ = true;
  int gpu_id_ = 0;
  int gpu_mem_ = 800;
  int cpu_math_library_num_threads_ = 6;
  std::string run_mode = "fluid";
  bool use_mkldnn_ = false;
  bool use_tensorrt_ = false;
  int batch_size_ = 1;
  bool use_fp16_ = false;
  std::string model_dir_;
  float threshold_ = 0.5;
  int max_det_results_ = 5;
  std::vector<int> image_shape_ = {3, 640, 640};
  std::vector<std::string> label_list_;
  bool ir_optim_ = true;
  bool det_permute_ = true;
  bool det_postprocess_ = true;
  int min_subgraph_size_ = 30;
  bool use_dynamic_shape_ = false;
  int trt_min_shape_ = 1;
  int trt_max_shape_ = 1280;
  int trt_opt_shape_ = 640;
  bool trt_calib_mode_ = false;

  // Preprocess image and copy data to input buffer
  void Preprocess(const cv::Mat &image_mat);

  // Postprocess result
  void Postprocess(const std::vector<cv::Mat> mats,
                   std::vector<ObjectResult> *result, std::vector<int> bbox_num,
                   bool is_rbox);

  std::shared_ptr<Predictor> predictor_;
  Preprocessor preprocessor_;
  ImageBlob inputs_;
  std::vector<float> output_data_;
  std::vector<int> out_bbox_num_data_;
};

} // namespace Detection
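For orientation, a sketch of driving this detector (not part of the commit; it mirrors how the demo's main program feeds detector crops to the feature extractor, with hypothetical file paths):

```cpp
YAML::Node config = YAML::LoadFile("inference_drink.yaml");
Detection::ObjectDetector detector(config); // loads the det inference model

std::vector<cv::Mat> batch = {cv::imread("test.jpg")};
std::vector<Detection::ObjectResult> results;
std::vector<int> bbox_num;
std::vector<double> times(3, 0.0);
detector.Predict(batch, /*warmup=*/0, /*repeats=*/1, &results, &bbox_num, &times);

for (const auto &obj : results) {
  if (obj.confidence < detector.GetThreshold())
    continue;
  // rect = {left, top, right, bottom} in original-image coordinates
  cv::Rect roi(obj.rect[0], obj.rect[1], obj.rect[2] - obj.rect[0],
               obj.rect[3] - obj.rect[1]);
  cv::Mat crop = batch[0](roi); // candidate region for the feature extractor
}
```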
@@ -31,27 +31,27 @@ using namespace std;
namespace Feature {

class Normalize {
public:
  virtual void Run(cv::Mat *im, const std::vector<float> &mean,
                   const std::vector<float> &std, float scale);
};

// RGB -> CHW
class Permute {
public:
  virtual void Run(const cv::Mat *im, float *data);
};

class CenterCropImg {
public:
  virtual void Run(cv::Mat &im, const int crop_size = 224);
};

class ResizeImg {
public:
  virtual void Run(const cv::Mat &img, cv::Mat &resize_img, int max_size_len,
                   int size = 0);
};

} // namespace Feature
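To show how these ops compose, a sketch under the assumption of the 224 input size and ImageNet statistics that appear elsewhere in this diff:

```cpp
Feature::ResizeImg resize_op;
Feature::Normalize normalize_op;
Feature::Permute permute_op;

cv::Mat img = cv::imread("test.jpg"), resized;
resize_op.Run(img, resized, /*max_size_len=*/224, /*size=*/224); // fixed 224x224
normalize_op.Run(&resized, {0.485f, 0.456f, 0.406f},
                 {0.229f, 0.224f, 0.225f}, 1 / 255.f); // x*scale, then (x-mean)/std
std::vector<float> chw(3 * resized.rows * resized.cols);
permute_op.Run(&resized, chw.data()); // HWC -> CHW float buffer for the predictor
```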
@@ -31,125 +31,128 @@
namespace Detection {

// Object for storing all preprocessed data
class ImageBlob {
public:
  // image width and height
  std::vector<float> im_shape_;
  // Buffer for image data after preprocessing
  std::vector<float> im_data_;
  // in net data shape (after padding)
  std::vector<float> in_net_shape_;
  // Evaluation image width and height
  // std::vector<float> eval_im_size_f_;
  // Scale factor for image size to origin image size
  std::vector<float> scale_factor_;
};

// Abstraction of the preprocessing operation classes
class PreprocessOp {
public:
  virtual void Init(const YAML::Node &item) = 0;

  virtual void Run(cv::Mat *im, ImageBlob *data) = 0;
};

class InitInfo : public PreprocessOp {
public:
  virtual void Init(const YAML::Node &item) {}

  virtual void Run(cv::Mat *im, ImageBlob *data);
};

class NormalizeImage : public PreprocessOp {
public:
  virtual void Init(const YAML::Node &item) {
    mean_ = item["mean"].as<std::vector<float>>();
    scale_ = item["std"].as<std::vector<float>>();
    is_scale_ = item["is_scale"].as<bool>();
  }

  virtual void Run(cv::Mat *im, ImageBlob *data);

private:
  // CHW or HWC
  std::vector<float> mean_;
  std::vector<float> scale_;
  bool is_scale_;
};

class Permute : public PreprocessOp {
public:
  virtual void Init(const YAML::Node &item) {}

  virtual void Run(cv::Mat *im, ImageBlob *data);
};

class Resize : public PreprocessOp {
public:
  virtual void Init(const YAML::Node &item) {
    interp_ = item["interp"].as<int>();
    // max_size_ = item["target_size"].as<int>();
    keep_ratio_ = item["keep_ratio"].as<bool>();
    target_size_ = item["target_size"].as<std::vector<int>>();
  }

  // Compute the best resize scale for x-dimension, y-dimension
  std::pair<double, double> GenerateScale(const cv::Mat &im);

  virtual void Run(cv::Mat *im, ImageBlob *data);

private:
  int interp_ = 2;
  bool keep_ratio_;
  std::vector<int> target_size_;
  std::vector<int> in_net_shape_;
};

// Models with FPN need input shape % stride == 0
class PadStride : public PreprocessOp {
public:
  virtual void Init(const YAML::Node &item) {
    stride_ = item["stride"].as<int>();
  }

  virtual void Run(cv::Mat *im, ImageBlob *data);

private:
  int stride_;
};

class Preprocessor {
public:
  void Init(const YAML::Node &config_node) {
    // initialize image info at first
    ops_["InitInfo"] = std::make_shared<InitInfo>();
    for (int i = 0; i < config_node.size(); ++i) {
      if (config_node[i]["DetResize"].IsDefined()) {
        ops_["Resize"] = std::make_shared<Resize>();
        ops_["Resize"]->Init(config_node[i]["DetResize"]);
      }
      if (config_node[i]["DetNormalizeImage"].IsDefined()) {
        ops_["NormalizeImage"] = std::make_shared<NormalizeImage>();
        ops_["NormalizeImage"]->Init(config_node[i]["DetNormalizeImage"]);
      }
      if (config_node[i]["DetPermute"].IsDefined()) {
        ops_["Permute"] = std::make_shared<Permute>();
        ops_["Permute"]->Init(config_node[i]["DetPermute"]);
      }
      if (config_node[i]["DetPadStrid"].IsDefined()) {
        ops_["PadStride"] = std::make_shared<PadStride>();
        ops_["PadStride"]->Init(config_node[i]["DetPadStrid"]);
      }
    }
  }

  void Run(cv::Mat *im, ImageBlob *data);

public:
  static const std::vector<std::string> RUN_ORDER;

private:
  std::unordered_map<std::string, std::shared_ptr<PreprocessOp>> ops_;
};

} // namespace Detection
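A short sketch of how `Preprocessor` is driven (assuming the `DetPreProcess.transform_ops` list from the demo config; paths are hypothetical):

```cpp
YAML::Node ops =
    YAML::LoadFile("inference_drink.yaml")["DetPreProcess"]["transform_ops"];
Detection::Preprocessor preprocessor;
preprocessor.Init(ops); // instantiates Resize / NormalizeImage / Permute / PadStride

cv::Mat img = cv::imread("test.jpg");
cv::cvtColor(img, img, cv::COLOR_BGR2RGB); // assuming an RGB pipeline
Detection::ImageBlob blob;
preprocessor.Run(&img, &blob); // applies ops in RUN_ORDER, fills blob.im_data_ (CHW)
```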
@@ -26,40 +26,45 @@
#include <map>

struct SearchResult {
  std::vector<faiss::Index::idx_t> I;
  std::vector<float> D;
  int return_k;
};

class VectorSearch {
public:
  explicit VectorSearch(const YAML::Node &config_file) {
    // IndexProcess
    this->index_dir =
        config_file["IndexProcess"]["index_dir"].as<std::string>();
    this->return_k = config_file["IndexProcess"]["return_k"].as<int>();
    this->score_thres = config_file["IndexProcess"]["score_thres"].as<float>();
    this->max_query_number =
        config_file["Global"]["max_det_results"].as<int>() + 1;
    LoadIdMap();
    LoadIndexFile();
    this->I.resize(this->return_k * this->max_query_number);
    this->D.resize(this->return_k * this->max_query_number);
  };

  void LoadIdMap();

  void LoadIndexFile();

  const SearchResult &Search(float *feature, int query_number);

  const std::string &GetLabel(faiss::Index::idx_t ind);

  const float &GetThreshold() { return this->score_thres; }

private:
  std::string index_dir;
  int return_k = 5;
  float score_thres = 0.5;
  std::map<long int, std::string> id_map;
  faiss::Index *index;
  int max_query_number = 6;
  std::vector<float> D;
  std::vector<faiss::Index::idx_t> I;
  SearchResult sr;
};
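And a sketch of querying the index (the feature layout is `query_number` L2-normalized rows from the extractor, concatenated; the dimension below is a hypothetical value):

```cpp
int query_number = 2; // e.g. number of detected boxes (+1 for the whole image)
int dim = 512;        // hypothetical feature dimension of the rec model
std::vector<float> features(query_number * dim); // filled by FeatureExtracter::Run

VectorSearch searcher(config); // loads vector.index and id_map.txt from index_dir
const SearchResult &res = searcher.Search(features.data(), query_number);
for (int q = 0; q < query_number; ++q) {
  float top1_score = res.D[q * res.return_k];            // best score for query q
  faiss::Index::idx_t top1_id = res.I[q * res.return_k]; // matching index entry
  if (top1_score >= searcher.GetThreshold()) {
    const std::string &label = searcher.GetLabel(top1_id); // name from id_map.txt
  }
}
```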
@@ -42,12 +42,17 @@
class YamlConfig {
public:
  explicit YamlConfig(const std::string &path) {
    config_file = ReadYamlConfig(path);
  }

  static std::vector<std::string> ReadDict(const std::string &path);

  static std::map<int, std::string> ReadIndexId(const std::string &path);

  static YAML::Node ReadYamlConfig(const std::string &path);

  void PrintConfigInfo();

  YAML::Node config_file;
};
@@ -6,10 +6,7 @@
## 1. Prepare the environment

### Runtime requirements
- Linux environment; an Ubuntu docker image is recommended.

### 1.1 Build the opencv library

@@ -103,7 +100,7 @@ make -j
make inference_lib_dist
```

For more build options, see the [official Paddle C++ inference library documentation](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/guides/05_inference_deployment/inference/build_and_install_lib_cn.html#id16).

* After the build finishes, the following files and folders are generated under `build/paddle_inference_install_dir/`.

@@ -137,29 +134,27 @@ tar -xvf paddle_inference.tgz
### 1.3 Install the faiss library

```shell
git clone https://github.com/facebookresearch/faiss.git
cd faiss
cmake -B build . -DFAISS_ENABLE_PYTHON=OFF -DCMAKE_INSTALL_PREFIX=${faiss_install_path}
make -C build -j faiss
make -C build install
```

Before installing `faiss`, install `openblas`. On `ubuntu` the install command is:

```shell
apt-get install libopenblas-dev
```

Note that this tutorial installs the CPU version of faiss as an example. For other variants, follow the official [faiss](https://github.com/facebookresearch/faiss) documentation and install according to your needs.

## 2 Build the code

### 2.2 Build the PaddleClas C++ inference demo

The build commands are as follows; replace the paths of the Paddle C++ inference library, opencv, and the other dependencies with the actual paths on your machine. The build also downloads and compiles C++ libraries such as `yaml-cpp`, so keep the machine connected to the network.

```shell
sh tools/build.sh
```

@@ -169,11 +164,12 @@ sh tools/build.sh
Specifically, the content of `tools/build.sh` is as follows.

```shell
OPENCV_DIR=${opencv_install_dir}
LIB_DIR=${paddle_inference_dir}
CUDA_LIB_DIR=/usr/local/cuda/lib64
CUDNN_LIB_DIR=/usr/lib/x86_64-linux-gnu/
FAISS_DIR=${faiss_install_dir}
FAISS_WITH_MKL=OFF

BUILD_DIR=build
rm -rf ${BUILD_DIR}
@@ -182,14 +178,14 @@ cd ${BUILD_DIR}
cmake .. \
    -DPADDLE_LIB=${LIB_DIR} \
    -DWITH_MKL=ON \
    -DWITH_GPU=OFF \
    -DWITH_STATIC_LIB=OFF \
    -DUSE_TENSORRT=OFF \
    -DOPENCV_DIR=${OPENCV_DIR} \
    -DCUDNN_LIB=${CUDNN_LIB_DIR} \
    -DCUDA_LIB=${CUDA_LIB_DIR} \
    -DFAISS_DIR=${FAISS_DIR} \
    -DFAISS_WITH_MKL=${FAISS_WITH_MKL}
make -j
```

@@ -197,47 +193,75 @@ make -j
In the commands above,

* `OPENCV_DIR` is the installation path of the compiled opencv library (here, the path of the `opencv-3.4.7/opencv3` folder);
* `LIB_DIR` is the path of the downloaded Paddle inference library (the `paddle_inference` folder) or of the one you compiled yourself (the `build/paddle_inference_install_dir` folder);
* `CUDA_LIB_DIR` is the path of the cuda library files, `/usr/local/cuda/lib64` in docker;
* `CUDNN_LIB_DIR` is the path of the cudnn library files, `/usr/lib/x86_64-linux-gnu/` in docker;
* `TENSORRT_DIR` is the path of the tensorrt library files, `/usr/local/TensorRT6-cuda10.0-cudnn7/` in docker; TensorRT must be used together with a GPU;
* `FAISS_DIR` is the installation path of faiss;
* `FAISS_WITH_MKL` indicates whether mkldnn was used when faiss was compiled. This document builds faiss with openblas rather than mkldnn, so it is set to `OFF`; if mkldnn was used, set it to `ON`.

After the commands above finish, a `build` folder is generated in the current directory, containing an executable named `pp_shitu`.

## 3 Run the demo

- Following the [recognition quick start document](../../docs/zh_CN/quick_start/quick_start_recognition.md), download the lightweight general mainbody detection model, the lightweight general recognition model, and the bottled-drinks test data, then extract them.

```shell
mkdir models
cd models
wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/picodet_PPLCNet_x2_5_mainbody_lite_v1.0_infer.tar
tar -xf picodet_PPLCNet_x2_5_mainbody_lite_v1.0_infer.tar
wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/general_PPLCNet_x2_5_lite_v1.0_infer.tar
tar -xf general_PPLCNet_x2_5_lite_v1.0_infer.tar
cd ..

mkdir data
cd data
wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/data/drink_dataset_v1.0.tar
tar -xf drink_dataset_v1.0.tar
cd ..
```

- Copy the corresponding yaml file into the `test` folder.

```shell
cp ../configs/inference_drink.yaml .
```

- Change the relative paths in `inference_drink.yaml` to paths based on this directory, or to absolute paths. The parameters involved are:
  - `Global.infer_imgs`: either the path of a specific image or the directory containing an image set;
  - `Global.det_inference_model_dir`: the directory of the detection model;
  - `Global.rec_inference_model_dir`: the directory of the recognition model;
  - `IndexProcess.index_dir`: the directory of the retrieval index; in this example, the index is part of the downloaded demo data.

- Dictionary conversion

  Because the Python retrieval library serializes its dictionary with `pickle`, which is inconvenient to read from C++, it has to be converted first:

  ```shell
  python tools/transform_id_map.py -c inference_drink.yaml
  ```

  After a successful conversion, `id_map.txt` is generated under `IndexProcess.index_dir` so that C++ can read it.
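  For reference, each line of the generated `id_map.txt` must contain exactly two whitespace-separated fields, a numeric id and a label, since that is what `VectorSearch::LoadIdMap` in this commit parses (e.g. `0 some_label`, where the label value here is purely illustrative).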
- Run the program:

  ```shell
  ./build/pp_shitu -c inference_drink.yaml
  # or
  ./build/pp_shitu -config inference_drink.yaml
  ```

  When retrieving over an image set, you may get results like the following. Note that this output is for illustration only; the actual run takes precedence. Also be aware that, because of opencv version differences, resize introduces tiny differences during preprocessing, so the Python and C++ results differ slightly, e.g. bboxes off by a few pixels or retrieval scores diverging from the third decimal place on; the final retrieval label is unaffected.

![](../../docs/images/quick_start/shitu_c++_result.png)

## 4 Use your own model

To use a model you trained yourself, refer to [model export](../../docs/zh_CN/inference_deployment/export_model.md) to export an `inference model` for prediction.

Also remember to update the corresponding parameters in the `yaml` file.
@@ -18,102 +18,102 @@
namespace Feature {

void FeatureExtracter::LoadModel(const std::string &model_path,
                                 const std::string &params_path) {
  paddle_infer::Config config;
  config.SetModel(model_path, params_path);

  if (this->use_gpu_) {
    config.EnableUseGpu(this->gpu_mem_, this->gpu_id_);
    if (this->use_tensorrt_) {
      config.EnableTensorRtEngine(
          1 << 20, 1, 3,
          this->use_fp16_ ? paddle_infer::Config::Precision::kHalf
                          : paddle_infer::Config::Precision::kFloat32,
          false, false);
    }
  } else {
    config.DisableGpu();
    if (this->use_mkldnn_) {
      config.EnableMKLDNN();
      // cache 10 different shapes for mkldnn to avoid memory leak
      config.SetMkldnnCacheCapacity(10);
    }
    config.SetCpuMathLibraryNumThreads(this->cpu_math_library_num_threads_);
  }

  config.SwitchUseFeedFetchOps(false);
  // true for multiple input
  config.SwitchSpecifyInputNames(true);
  config.SwitchIrOptim(true);
  config.EnableMemoryOptim();
  config.DisableGlogInfo();

  this->predictor_ = CreatePredictor(config);
}

void FeatureExtracter::Run(cv::Mat &img, std::vector<float> &out_data,
                           std::vector<double> &times) {
  cv::Mat resize_img;
  std::vector<double> time;

  auto preprocess_start = std::chrono::system_clock::now();
  this->resize_op_.Run(img, resize_img, this->resize_short_,
                       this->resize_size_);
  this->normalize_op_.Run(&resize_img, this->mean_, this->std_, this->scale_);
  std::vector<float> input(1 * 3 * resize_img.rows * resize_img.cols, 0.0f);
  this->permute_op_.Run(&resize_img, input.data());

  auto input_names = this->predictor_->GetInputNames();
  auto input_t = this->predictor_->GetInputHandle(input_names[0]);
  input_t->Reshape({1, 3, resize_img.rows, resize_img.cols});
  auto preprocess_end = std::chrono::system_clock::now();

  auto infer_start = std::chrono::system_clock::now();
  input_t->CopyFromCpu(input.data());
  this->predictor_->Run();

  auto output_names = this->predictor_->GetOutputNames();
  auto output_t = this->predictor_->GetOutputHandle(output_names[0]);
  std::vector<int> output_shape = output_t->shape();
  int out_num = std::accumulate(output_shape.begin(), output_shape.end(), 1,
                                std::multiplies<int>());

  out_data.resize(out_num);
  output_t->CopyToCpu(out_data.data());
  auto infer_end = std::chrono::system_clock::now();

  auto postprocess_start = std::chrono::system_clock::now();
  if (this->feature_norm)
    FeatureNorm(out_data);
  auto postprocess_end = std::chrono::system_clock::now();

  std::chrono::duration<float> preprocess_diff =
      preprocess_end - preprocess_start;
  time.push_back(double(preprocess_diff.count()));
  std::chrono::duration<float> inference_diff = infer_end - infer_start;
  double inference_cost_time = double(inference_diff.count());
  time.push_back(inference_cost_time);
  // std::chrono::duration<float> postprocess_diff =
  //     postprocess_end - postprocess_start;
  time.push_back(0);

  // std::cout << "result: " << std::endl;
  // std::cout << "\tclass id: " << maxPosition << std::endl;
  // std::cout << std::fixed << std::setprecision(10)
  //           << "\tscore: " << double(out_data[maxPosition]) << std::endl;
  times[0] += time[0];
  times[1] += time[1];
  times[2] += time[2];
}

void FeatureExtracter::FeatureNorm(std::vector<float> &feature) {
  float feature_sqrt = std::sqrt(std::inner_product(
      feature.begin(), feature.end(), feature.begin(), 0.0f));
  for (size_t i = 0; i < feature.size(); ++i)
    feature[i] /= feature_sqrt;
}
} // namespace Feature
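As context for `FeatureNorm` above: it rescales the output vector to unit L2 norm, $f_i \leftarrow f_i / \sqrt{\sum_j f_j^2}$, so that the similarity scores computed later in the faiss index operate on unit-length vectors.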
This diff is collapsed.
@@ -32,60 +32,60 @@
namespace Feature {

void Permute::Run(const cv::Mat *im, float *data) {
  int rh = im->rows;
  int rw = im->cols;
  int rc = im->channels();
  for (int i = 0; i < rc; ++i) {
    cv::extractChannel(*im, cv::Mat(rh, rw, CV_32FC1, data + i * rh * rw), i);
  }
}

void Normalize::Run(cv::Mat *im, const std::vector<float> &mean,
                    const std::vector<float> &std, float scale) {
  (*im).convertTo(*im, CV_32FC3, scale);
  for (int h = 0; h < im->rows; h++) {
    for (int w = 0; w < im->cols; w++) {
      im->at<cv::Vec3f>(h, w)[0] =
          (im->at<cv::Vec3f>(h, w)[0] - mean[0]) / std[0];
      im->at<cv::Vec3f>(h, w)[1] =
          (im->at<cv::Vec3f>(h, w)[1] - mean[1]) / std[1];
      im->at<cv::Vec3f>(h, w)[2] =
          (im->at<cv::Vec3f>(h, w)[2] - mean[2]) / std[2];
    }
  }
}

void CenterCropImg::Run(cv::Mat &img, const int crop_size) {
  int resize_w = img.cols;
  int resize_h = img.rows;
  int w_start = int((resize_w - crop_size) / 2);
  int h_start = int((resize_h - crop_size) / 2);
  cv::Rect rect(w_start, h_start, crop_size, crop_size);
  img = img(rect);
}

void ResizeImg::Run(const cv::Mat &img, cv::Mat &resize_img,
                    int resize_short_size, int size) {
  int resize_h = 0;
  int resize_w = 0;
  if (size > 0) {
    resize_h = size;
    resize_w = size;
  } else {
    int w = img.cols;
    int h = img.rows;

    float ratio = 1.f;
    if (h < w) {
      ratio = float(resize_short_size) / float(h);
    } else {
      ratio = float(resize_short_size) / float(w);
    }
    resize_h = round(float(h) * ratio);
    resize_w = round(float(w) * ratio);
  }
  cv::resize(img, resize_img, cv::Size(resize_w, resize_h));
}

} // namespace Feature
@@ -19,112 +19,112 @@
namespace Detection {

void InitInfo::Run(cv::Mat *im, ImageBlob *data) {
  data->im_shape_ = {static_cast<float>(im->rows),
                     static_cast<float>(im->cols)};
  data->scale_factor_ = {1., 1.};
  data->in_net_shape_ = {static_cast<float>(im->rows),
                         static_cast<float>(im->cols)};
}

void NormalizeImage::Run(cv::Mat *im, ImageBlob *data) {
  double e = 1.0;
  if (is_scale_) {
    e /= 255.0;
  }
  (*im).convertTo(*im, CV_32FC3, e);
  for (int h = 0; h < im->rows; h++) {
    for (int w = 0; w < im->cols; w++) {
      im->at<cv::Vec3f>(h, w)[0] =
          (im->at<cv::Vec3f>(h, w)[0] - mean_[0]) / scale_[0];
      im->at<cv::Vec3f>(h, w)[1] =
          (im->at<cv::Vec3f>(h, w)[1] - mean_[1]) / scale_[1];
      im->at<cv::Vec3f>(h, w)[2] =
          (im->at<cv::Vec3f>(h, w)[2] - mean_[2]) / scale_[2];
    }
  }
}

void Permute::Run(cv::Mat *im, ImageBlob *data) {
  int rh = im->rows;
  int rw = im->cols;
  int rc = im->channels();
  (data->im_data_).resize(rc * rh * rw);
  float *base = (data->im_data_).data();
  for (int i = 0; i < rc; ++i) {
    cv::extractChannel(*im, cv::Mat(rh, rw, CV_32FC1, base + i * rh * rw), i);
  }
}

void Resize::Run(cv::Mat *im, ImageBlob *data) {
  auto resize_scale = GenerateScale(*im);
  data->im_shape_ = {static_cast<float>(im->cols * resize_scale.first),
                     static_cast<float>(im->rows * resize_scale.second)};
  data->in_net_shape_ = {static_cast<float>(im->cols * resize_scale.first),
                         static_cast<float>(im->rows * resize_scale.second)};
  cv::resize(*im, *im, cv::Size(), resize_scale.first, resize_scale.second,
             interp_);
  data->im_shape_ = {
      static_cast<float>(im->rows), static_cast<float>(im->cols),
  };
  data->scale_factor_ = {
      resize_scale.second, resize_scale.first,
  };
}

std::pair<double, double> Resize::GenerateScale(const cv::Mat &im) {
  std::pair<double, double> resize_scale;
  int origin_w = im.cols;
  int origin_h = im.rows;

  if (keep_ratio_) {
    int im_size_max = std::max(origin_w, origin_h);
    int im_size_min = std::min(origin_w, origin_h);
    int target_size_max =
        *std::max_element(target_size_.begin(), target_size_.end());
    int target_size_min =
        *std::min_element(target_size_.begin(), target_size_.end());
    double scale_min =
        static_cast<double>(target_size_min) / static_cast<double>(im_size_min);
    double scale_max =
        static_cast<double>(target_size_max) / static_cast<double>(im_size_max);
    double scale_ratio = std::min(scale_min, scale_max);
    resize_scale = {scale_ratio, scale_ratio};
  } else {
    resize_scale.first =
        static_cast<double>(target_size_[1]) / static_cast<double>(origin_w);
    resize_scale.second =
        static_cast<double>(target_size_[0]) / static_cast<double>(origin_h);
  }
  return resize_scale;
}

void PadStride::Run(cv::Mat *im, ImageBlob *data) {
  if (stride_ <= 0) {
    return;
  }
  int rc = im->channels();
  int rh = im->rows;
  int rw = im->cols;
  int nh = (rh / stride_) * stride_ + (rh % stride_ != 0) * stride_;
  int nw = (rw / stride_) * stride_ + (rw % stride_ != 0) * stride_;
  cv::copyMakeBorder(*im, *im, 0, nh - rh, 0, nw - rw, cv::BORDER_CONSTANT,
                     cv::Scalar(0));
  data->in_net_shape_ = {
      static_cast<float>(im->rows), static_cast<float>(im->cols),
  };
}

// Preprocessor op running order
const std::vector<std::string> Preprocessor::RUN_ORDER = {
    "InitInfo", "Resize", "NormalizeImage", "PadStride", "Permute"};

void Preprocessor::Run(cv::Mat *im, ImageBlob *data) {
  for (const auto &name : RUN_ORDER) {
    if (ops_.find(name) != ops_.end()) {
      ops_[name]->Run(im, data);
    }
  }
}

} // namespace Detection
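To make the rounding in `PadStride::Run` concrete: with `stride_ = 32`, a 37x100 input gives `nh = (37/32)*32 + (37%32 != 0)*32 = 64` and `nw = (100/32)*32 + (100%32 != 0)*32 = 128`, so the image is zero-padded at the bottom and right to 64x128, satisfying the multiple-of-stride requirement of FPN models.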
@@ -20,43 +20,43 @@
#include <regex>

void VectorSearch::LoadIndexFile() {
  std::string file_path = this->index_dir + OS_PATH_SEP + "vector.index";
  const char *fname = file_path.c_str();
  this->index = faiss::read_index(fname, 0);
}

void VectorSearch::LoadIdMap() {
  std::string file_path = this->index_dir + OS_PATH_SEP + "id_map.txt";
  std::ifstream in(file_path);
  std::string line;
  std::vector<std::string> m_vec;
  if (in) {
    while (getline(in, line)) {
      std::regex ws_re("\\s+");
      std::vector<std::string> v(
          std::sregex_token_iterator(line.begin(), line.end(), ws_re, -1),
          std::sregex_token_iterator());
      if (v.size() != 2) {
        std::cout << "The number of elements in each line of " << file_path
                  << " must be 2, exit the program..." << std::endl;
        exit(1);
      } else
        this->id_map.insert(std::pair<long int, std::string>(
            std::stol(v[0], nullptr, 10), v[1]));
    }
  }
}

const SearchResult &VectorSearch::Search(float *feature, int query_number) {
  this->D.resize(this->return_k * query_number);
  this->I.resize(this->return_k * query_number);
  this->index->search(query_number, feature, return_k, D.data(), I.data());
  this->sr.return_k = this->return_k;
  this->sr.D = this->D;
  this->sr.I = this->I;
  return this->sr;
}

const std::string &VectorSearch::GetLabel(faiss::Index::idx_t ind) {
  return this->id_map.at(ind);
}
@@ -19,60 +19,60 @@
#include <include/yaml_config.h>

std::vector<std::string> YamlConfig::ReadDict(const std::string &path) {
  std::ifstream in(path);
  std::string line;
  std::vector<std::string> m_vec;
  if (in) {
    while (getline(in, line)) {
      m_vec.push_back(line);
    }
  } else {
    std::cout << "no such label file: " << path << ", exit the program..."
              << std::endl;
    exit(1);
  }
  return m_vec;
}

std::map<int, std::string> YamlConfig::ReadIndexId(const std::string &path) {
  std::ifstream in(path);
  std::string line;
  std::map<int, std::string> m_vec;
  if (in) {
    while (getline(in, line)) {
      std::regex ws_re("\\s+");
      std::vector<std::string> v(
          std::sregex_token_iterator(line.begin(), line.end(), ws_re, -1),
          std::sregex_token_iterator());
      if (v.size() != 3) {
        std::cout << "The number of elements in each line of " << path
                  << " must be 3, exit the program..." << std::endl;
        exit(1);
      } else
        m_vec.insert(std::pair<int, std::string>(stoi(v[0]), v[2]));
    }
  }
  return m_vec;
}

YAML::Node YamlConfig::ReadYamlConfig(const std::string &path) {
  YAML::Node config;
  try {
    config = YAML::LoadFile(path);
  } catch (YAML::BadFile &e) {
    std::cout << "Something wrong in yaml file, please check yaml file"
              << std::endl;
    exit(1);
  }
  return config;
}

void YamlConfig::PrintConfigInfo() {
  std::cout << this->config_file << std::endl;
  // for (YAML::const_iterator
  //          it = config_file.begin(); it != config_file.end(); ++it) {
  //   std::cout << it->as<std::string>() << "\n";
  // }
}
OPENCV_DIR=${opencv_install_dir}
LIB_DIR=${paddle_inference_dir}
CUDA_LIB_DIR=/usr/local/cuda/lib64
CUDNN_LIB_DIR=/usr/lib/x86_64-linux-gnu/
FAISS_DIR=${faiss_install_dir}
FAISS_WITH_MKL=OFF

BUILD_DIR=build
@@ -21,4 +21,4 @@ cmake .. \
    -DFAISS_DIR=${FAISS_DIR} \
    -DFAISS_WITH_MKL=${FAISS_WITH_MKL}
make -j
\ No newline at end of file
# model load config
use_gpu 0
gpu_id 0
gpu_mem 4000
cpu_threads 10
use_mkldnn 1
use_tensorrt 0
use_fp16 0
# cls config
cls_model_path /PaddleClas/inference/cls_infer.pdmodel
cls_params_path /PaddleClas/inference/cls_infer.pdiparams
resize_short_size 256
crop_size 224
# for log env info
benchmark 0
./build/clas_system ../configs/inference_rec.yaml