提交 74da43ad 编写于 作者: D dongshuilong

Add L2 normalization (l2norm) for recognition (rec) features

上级 ba65fa9a
...@@ -33,11 +33,11 @@ ...@@ -33,11 +33,11 @@
using namespace paddle_infer; using namespace paddle_infer;
namespace PaddleClas { namespace Feature {
class Classifier { class FeatureExtracter {
public: public:
explicit Classifier(const YAML::Node &config_file) { explicit FeatureExtracter(const YAML::Node &config_file) {
this->use_gpu_ = config_file["Global"]["use_gpu"].as<bool>(); this->use_gpu_ = config_file["Global"]["use_gpu"].as<bool>();
if (config_file["Global"]["gpu_id"].IsDefined()) if (config_file["Global"]["gpu_id"].IsDefined())
this->gpu_id_ = config_file["Global"]["gpu_id"].as<int>(); this->gpu_id_ = config_file["Global"]["gpu_id"].as<int>();
...@@ -68,6 +68,9 @@ public: ...@@ -68,6 +68,9 @@ public:
this->std_ = config_file["RecPreProcess"]["transform_ops"][1] this->std_ = config_file["RecPreProcess"]["transform_ops"][1]
["NormalizeImage"]["std"] ["NormalizeImage"]["std"]
.as<std::vector<float>>(); .as<std::vector<float>>();
if (config_file["Global"]["rec_feature_normlize"].IsDefined())
this->feature_norm =
config_file["Global"]["rec_feature_normlize"].as<bool>();
LoadModel(cls_model_path_, cls_params_path_); LoadModel(cls_model_path_, cls_params_path_);
} }
...@@ -78,6 +81,7 @@ public: ...@@ -78,6 +81,7 @@ public:
// Run predictor // Run predictor
void Run(cv::Mat &img, std::vector<float> &out_data, void Run(cv::Mat &img, std::vector<float> &out_data,
std::vector<double> &times); std::vector<double> &times);
void FeatureNorm(std::vector<float> &feature);
std::shared_ptr<Predictor> predictor_; std::shared_ptr<Predictor> predictor_;
...@@ -88,6 +92,7 @@ private: ...@@ -88,6 +92,7 @@ private:
int cpu_math_library_num_threads_ = 4; int cpu_math_library_num_threads_ = 4;
bool use_mkldnn_ = false; bool use_mkldnn_ = false;
bool use_tensorrt_ = false; bool use_tensorrt_ = false;
bool feature_norm = true;
bool use_fp16_ = false; bool use_fp16_ = false;
std::vector<float> mean_ = {0.485f, 0.456f, 0.406f}; std::vector<float> mean_ = {0.485f, 0.456f, 0.406f};
std::vector<float> std_ = {0.229f, 0.224f, 0.225f}; std::vector<float> std_ = {0.229f, 0.224f, 0.225f};
...@@ -103,4 +108,4 @@ private: ...@@ -103,4 +108,4 @@ private:
Permute permute_op_; Permute permute_op_;
}; };
} // namespace PaddleClas } // namespace Feature
...@@ -23,8 +23,8 @@ static inline bool SortScorePairDescend(const std::pair<float, T> &pair1, ...@@ -23,8 +23,8 @@ static inline bool SortScorePairDescend(const std::pair<float, T> &pair1,
return pair1.first > pair2.first; return pair1.first > pair2.first;
} }
float RectOverlap(const PaddleDetection::ObjectResult &a, float RectOverlap(const Detection::ObjectResult &a,
const PaddleDetection::ObjectResult &b) { const Detection::ObjectResult &b) {
float Aa = (a.rect[2] - a.rect[0] + 1) * (a.rect[3] - a.rect[1] + 1); float Aa = (a.rect[2] - a.rect[0] + 1) * (a.rect[3] - a.rect[1] + 1);
float Ab = (b.rect[2] - b.rect[0] + 1) * (b.rect[3] - b.rect[1] + 1); float Ab = (b.rect[2] - b.rect[0] + 1) * (b.rect[3] - b.rect[1] + 1);
...@@ -40,7 +40,7 @@ float RectOverlap(const PaddleDetection::ObjectResult &a, ...@@ -40,7 +40,7 @@ float RectOverlap(const PaddleDetection::ObjectResult &a,
// top_k: if -1, keep all; otherwise, keep at most top_k. // top_k: if -1, keep all; otherwise, keep at most top_k.
// score_index_vec: store the sorted (score, index) pair. // score_index_vec: store the sorted (score, index) pair.
inline void inline void
GetMaxScoreIndex(const std::vector<PaddleDetection::ObjectResult> &det_result, GetMaxScoreIndex(const std::vector<Detection::ObjectResult> &det_result,
const float threshold, const float threshold,
std::vector<std::pair<float, int>> &score_index_vec) { std::vector<std::pair<float, int>> &score_index_vec) {
// Generate index score pairs. // Generate index score pairs.
...@@ -61,7 +61,7 @@ GetMaxScoreIndex(const std::vector<PaddleDetection::ObjectResult> &det_result, ...@@ -61,7 +61,7 @@ GetMaxScoreIndex(const std::vector<PaddleDetection::ObjectResult> &det_result,
// } // }
} }
void NMSBoxes(const std::vector<PaddleDetection::ObjectResult> det_result, void NMSBoxes(const std::vector<Detection::ObjectResult> det_result,
const float score_threshold, const float nms_threshold, const float score_threshold, const float nms_threshold,
std::vector<int> &indices) { std::vector<int> &indices) {
int a = 1; int a = 1;
......
...@@ -31,7 +31,7 @@ ...@@ -31,7 +31,7 @@
using namespace paddle_infer; using namespace paddle_infer;
namespace PaddleDetection { namespace Detection {
// Object Detection Result // Object Detection Result
struct ObjectResult { struct ObjectResult {
// Rectangle coordinates of detected object: left, right, top, down // Rectangle coordinates of detected object: left, right, top, down
...@@ -132,4 +132,4 @@ private: ...@@ -132,4 +132,4 @@ private:
std::vector<int> out_bbox_num_data_; std::vector<int> out_bbox_num_data_;
}; };
} // namespace PaddleDetection } // namespace Detection
...@@ -29,7 +29,7 @@ ...@@ -29,7 +29,7 @@
using namespace std; using namespace std;
namespace PaddleClas { namespace Feature {
class Normalize { class Normalize {
public: public:
...@@ -54,4 +54,4 @@ public: ...@@ -54,4 +54,4 @@ public:
int size = 0); int size = 0);
}; };
} // namespace PaddleClas } // namespace Feature
...@@ -28,7 +28,7 @@ ...@@ -28,7 +28,7 @@
#include <opencv2/highgui/highgui.hpp> #include <opencv2/highgui/highgui.hpp>
#include <opencv2/imgproc/imgproc.hpp> #include <opencv2/imgproc/imgproc.hpp>
namespace PaddleDetection { namespace Detection {
// Object for storing all preprocessed data // Object for storing all preprocessed data
class ImageBlob { class ImageBlob {
...@@ -152,4 +152,4 @@ private: ...@@ -152,4 +152,4 @@ private:
std::unordered_map<std::string, std::shared_ptr<PreprocessOp>> ops_; std::unordered_map<std::string, std::shared_ptr<PreprocessOp>> ops_;
}; };
} // namespace PaddleDetection } // namespace Detection
...@@ -12,12 +12,14 @@ ...@@ -12,12 +12,14 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include <include/cls.h> #include <cmath>
#include <include/feature_extracter.h>
#include <numeric>
namespace PaddleClas { namespace Feature {
void Classifier::LoadModel(const std::string &model_path, void FeatureExtracter::LoadModel(const std::string &model_path,
const std::string &params_path) { const std::string &params_path) {
paddle_infer::Config config; paddle_infer::Config config;
config.SetModel(model_path, params_path); config.SetModel(model_path, params_path);
...@@ -52,11 +54,9 @@ void Classifier::LoadModel(const std::string &model_path, ...@@ -52,11 +54,9 @@ void Classifier::LoadModel(const std::string &model_path,
this->predictor_ = CreatePredictor(config); this->predictor_ = CreatePredictor(config);
} }
void Classifier::Run(cv::Mat &img, std::vector<float> &out_data, void FeatureExtracter::Run(cv::Mat &img, std::vector<float> &out_data,
std::vector<double> &times) { std::vector<double> &times) {
cv::Mat srcimg;
cv::Mat resize_img; cv::Mat resize_img;
img.copyTo(srcimg);
std::vector<double> time; std::vector<double> time;
auto preprocess_start = std::chrono::system_clock::now(); auto preprocess_start = std::chrono::system_clock::now();
...@@ -86,10 +86,10 @@ void Classifier::Run(cv::Mat &img, std::vector<float> &out_data, ...@@ -86,10 +86,10 @@ void Classifier::Run(cv::Mat &img, std::vector<float> &out_data,
output_t->CopyToCpu(out_data.data()); output_t->CopyToCpu(out_data.data());
auto infer_end = std::chrono::system_clock::now(); auto infer_end = std::chrono::system_clock::now();
// auto postprocess_start = std::chrono::system_clock::now(); auto postprocess_start = std::chrono::system_clock::now();
// int maxPosition = if (this->feature_norm)
// max_element(out_data.begin(), out_data.end()) - out_data.begin(); FeatureNorm(out_data);
// auto postprocess_end = std::chrono::system_clock::now(); auto postprocess_end = std::chrono::system_clock::now();
std::chrono::duration<float> preprocess_diff = std::chrono::duration<float> preprocess_diff =
preprocess_end - preprocess_start; preprocess_end - preprocess_start;
...@@ -110,4 +110,10 @@ void Classifier::Run(cv::Mat &img, std::vector<float> &out_data, ...@@ -110,4 +110,10 @@ void Classifier::Run(cv::Mat &img, std::vector<float> &out_data,
times[2] += time[2]; times[2] += time[2];
} }
} // namespace PaddleClas void FeatureExtracter::FeatureNorm(std::vector<float> &featuer) {
float featuer_sqrt = std::sqrt(std::inner_product(
featuer.begin(), featuer.end(), featuer.begin(), 0.0f));
for (int i = 0; i < featuer.size(); ++i)
featuer[i] /= featuer_sqrt;
}
} // namespace Feature
...@@ -28,7 +28,7 @@ ...@@ -28,7 +28,7 @@
#include <auto_log/autolog.h> #include <auto_log/autolog.h>
#include <gflags/gflags.h> #include <gflags/gflags.h>
#include <include/cls.h> #include <include/feature_extracter.h>
#include <include/nms.h> #include <include/nms.h>
#include <include/object_detector.h> #include <include/object_detector.h>
#include <include/vector_search.h> #include <include/vector_search.h>
...@@ -42,8 +42,8 @@ DEFINE_string(c, "", "Path of yaml file"); ...@@ -42,8 +42,8 @@ DEFINE_string(c, "", "Path of yaml file");
void DetPredictImage(const std::vector<cv::Mat> &batch_imgs, void DetPredictImage(const std::vector<cv::Mat> &batch_imgs,
const std::vector<std::string> &all_img_paths, const std::vector<std::string> &all_img_paths,
const int batch_size, PaddleDetection::ObjectDetector *det, const int batch_size, Detection::ObjectDetector *det,
std::vector<PaddleDetection::ObjectResult> &im_result, std::vector<Detection::ObjectResult> &im_result,
std::vector<int> &im_bbox_num, std::vector<double> &det_t, std::vector<int> &im_bbox_num, std::vector<double> &det_t,
const bool visual_det = false, const bool visual_det = false,
const bool run_benchmark = false, const bool run_benchmark = false,
...@@ -63,7 +63,7 @@ void DetPredictImage(const std::vector<cv::Mat> &batch_imgs, ...@@ -63,7 +63,7 @@ void DetPredictImage(const std::vector<cv::Mat> &batch_imgs,
// } // }
// Store all detected result // Store all detected result
std::vector<PaddleDetection::ObjectResult> result; std::vector<Detection::ObjectResult> result;
std::vector<int> bbox_num; std::vector<int> bbox_num;
std::vector<double> det_times; std::vector<double> det_times;
bool is_rbox = false; bool is_rbox = false;
...@@ -73,7 +73,7 @@ void DetPredictImage(const std::vector<cv::Mat> &batch_imgs, ...@@ -73,7 +73,7 @@ void DetPredictImage(const std::vector<cv::Mat> &batch_imgs,
det->Predict(batch_imgs, 0, 1, &result, &bbox_num, &det_times); det->Predict(batch_imgs, 0, 1, &result, &bbox_num, &det_times);
// get labels and colormap // get labels and colormap
auto labels = det->GetLabelList(); auto labels = det->GetLabelList();
auto colormap = PaddleDetection::GenerateColorMap(labels.size()); auto colormap = Detection::GenerateColorMap(labels.size());
int item_start_idx = 0; int item_start_idx = 0;
for (int i = 0; i < left_image_cnt; i++) { for (int i = 0; i < left_image_cnt; i++) {
...@@ -81,7 +81,7 @@ void DetPredictImage(const std::vector<cv::Mat> &batch_imgs, ...@@ -81,7 +81,7 @@ void DetPredictImage(const std::vector<cv::Mat> &batch_imgs,
int detect_num = 0; int detect_num = 0;
for (int j = 0; j < bbox_num[i]; j++) { for (int j = 0; j < bbox_num[i]; j++) {
PaddleDetection::ObjectResult item = result[item_start_idx + j]; Detection::ObjectResult item = result[item_start_idx + j];
if (item.confidence < det->GetThreshold() || item.class_id == -1) { if (item.confidence < det->GetThreshold() || item.class_id == -1) {
continue; continue;
} }
...@@ -110,8 +110,8 @@ void DetPredictImage(const std::vector<cv::Mat> &batch_imgs, ...@@ -110,8 +110,8 @@ void DetPredictImage(const std::vector<cv::Mat> &batch_imgs,
std::cout << all_img_paths.at(idx * batch_size + i) std::cout << all_img_paths.at(idx * batch_size + i)
<< " The number of detected box: " << detect_num << " The number of detected box: " << detect_num
<< std::endl; << std::endl;
cv::Mat vis_img = PaddleDetection::VisualizeResult( cv::Mat vis_img = Detection::VisualizeResult(im, im_result, labels,
im, im_result, labels, colormap, is_rbox); colormap, is_rbox);
std::vector<int> compression_params; std::vector<int> compression_params;
compression_params.push_back(CV_IMWRITE_JPEG_QUALITY); compression_params.push_back(CV_IMWRITE_JPEG_QUALITY);
compression_params.push_back(95); compression_params.push_back(95);
...@@ -134,7 +134,7 @@ void DetPredictImage(const std::vector<cv::Mat> &batch_imgs, ...@@ -134,7 +134,7 @@ void DetPredictImage(const std::vector<cv::Mat> &batch_imgs,
} }
void PrintResult(std::string &img_path, void PrintResult(std::string &img_path,
std::vector<PaddleDetection::ObjectResult> &det_result, std::vector<Detection::ObjectResult> &det_result,
std::vector<int> &indeices, VectorSearch &vector_search, std::vector<int> &indeices, VectorSearch &vector_search,
SearchResult &search_result) { SearchResult &search_result) {
printf("%s:\n", img_path.c_str()); printf("%s:\n", img_path.c_str());
...@@ -167,8 +167,8 @@ int main(int argc, char **argv) { ...@@ -167,8 +167,8 @@ int main(int argc, char **argv) {
config.PrintConfigInfo(); config.PrintConfigInfo();
// initialize detector, rec_Model, vector_search // initialize detector, rec_Model, vector_search
PaddleClas::Classifier classifier(config.config_file); Feature::FeatureExtracter feature_extracter(config.config_file);
PaddleDetection::ObjectDetector detector(config.config_file); Detection::ObjectDetector detector(config.config_file);
VectorSearch searcher(config.config_file); VectorSearch searcher(config.config_file);
// config // config
...@@ -212,7 +212,7 @@ int main(int argc, char **argv) { ...@@ -212,7 +212,7 @@ int main(int argc, char **argv) {
std::vector<cv::Mat> batch_imgs; std::vector<cv::Mat> batch_imgs;
std::vector<std::string> img_paths; std::vector<std::string> img_paths;
// for detection // for detection
std::vector<PaddleDetection::ObjectResult> det_result; std::vector<Detection::ObjectResult> det_result;
std::vector<int> det_bbox_num; std::vector<int> det_bbox_num;
// for vector search // for vector search
std::vector<float> features; std::vector<float> features;
...@@ -243,7 +243,7 @@ int main(int argc, char **argv) { ...@@ -243,7 +243,7 @@ int main(int argc, char **argv) {
det_result.resize(max_det_results); det_result.resize(max_det_results);
} }
// step2: add the whole image for recognition to improve recall // step2: add the whole image for recognition to improve recall
PaddleDetection::ObjectResult result_whole_img = { Detection::ObjectResult result_whole_img = {
{0, 0, srcimg.cols - 1, srcimg.rows - 1}, 0, 1.0}; {0, 0, srcimg.cols - 1, srcimg.rows - 1}, 0, 1.0};
det_result.push_back(result_whole_img); det_result.push_back(result_whole_img);
det_bbox_num[0] = det_result.size() + 1; det_bbox_num[0] = det_result.size() + 1;
...@@ -255,7 +255,7 @@ int main(int argc, char **argv) { ...@@ -255,7 +255,7 @@ int main(int argc, char **argv) {
int h = det_result[j].rect[3] - det_result[j].rect[1]; int h = det_result[j].rect[3] - det_result[j].rect[1];
cv::Rect rect(det_result[j].rect[0], det_result[j].rect[1], w, h); cv::Rect rect(det_result[j].rect[0], det_result[j].rect[1], w, h);
cv::Mat crop_img = srcimg(rect); cv::Mat crop_img = srcimg(rect);
classifier.Run(crop_img, feature, cls_times); feature_extracter.Run(crop_img, feature, cls_times);
features.insert(features.end(), feature.begin(), feature.end()); features.insert(features.end(), feature.begin(), feature.end());
} }
......
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
using namespace paddle_infer; using namespace paddle_infer;
namespace PaddleDetection { namespace Detection {
// Load Model and create model predictor // Load Model and create model predictor
void ObjectDetector::LoadModel(const std::string &model_dir, void ObjectDetector::LoadModel(const std::string &model_dir,
...@@ -362,4 +362,4 @@ std::vector<int> GenerateColorMap(int num_class) { ...@@ -362,4 +362,4 @@ std::vector<int> GenerateColorMap(int num_class) {
return colormap; return colormap;
} }
} // namespace PaddleDetection } // namespace Detection
...@@ -30,7 +30,7 @@ ...@@ -30,7 +30,7 @@
#include <include/preprocess_op.h> #include <include/preprocess_op.h>
namespace PaddleClas { namespace Feature {
void Permute::Run(const cv::Mat *im, float *data) { void Permute::Run(const cv::Mat *im, float *data) {
int rh = im->rows; int rh = im->rows;
...@@ -88,4 +88,4 @@ void ResizeImg::Run(const cv::Mat &img, cv::Mat &resize_img, ...@@ -88,4 +88,4 @@ void ResizeImg::Run(const cv::Mat &img, cv::Mat &resize_img,
cv::resize(img, resize_img, cv::Size(resize_w, resize_h)); cv::resize(img, resize_img, cv::Size(resize_w, resize_h));
} }
} // namespace PaddleClas } // namespace Feature
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
#include "include/preprocess_op_det.h" #include "include/preprocess_op_det.h"
namespace PaddleDetection { namespace Detection {
void InitInfo::Run(cv::Mat *im, ImageBlob *data) { void InitInfo::Run(cv::Mat *im, ImageBlob *data) {
data->im_shape_ = {static_cast<float>(im->rows), data->im_shape_ = {static_cast<float>(im->rows),
...@@ -127,4 +127,4 @@ void Preprocessor::Run(cv::Mat *im, ImageBlob *data) { ...@@ -127,4 +127,4 @@ void Preprocessor::Run(cv::Mat *im, ImageBlob *data) {
} }
} }
} // namespace PaddleDetection } // namespace Detection
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册