提交 74da43ad 编写于 作者: D dongshuilong

add l2norm for rec feature

上级 ba65fa9a
......@@ -33,11 +33,11 @@
using namespace paddle_infer;
namespace PaddleClas {
namespace Feature {
class Classifier {
class FeatureExtracter {
public:
explicit Classifier(const YAML::Node &config_file) {
explicit FeatureExtracter(const YAML::Node &config_file) {
this->use_gpu_ = config_file["Global"]["use_gpu"].as<bool>();
if (config_file["Global"]["gpu_id"].IsDefined())
this->gpu_id_ = config_file["Global"]["gpu_id"].as<int>();
......@@ -68,6 +68,9 @@ public:
this->std_ = config_file["RecPreProcess"]["transform_ops"][1]
["NormalizeImage"]["std"]
.as<std::vector<float>>();
if (config_file["Global"]["rec_feature_normlize"].IsDefined())
this->feature_norm =
config_file["Global"]["rec_feature_normlize"].as<bool>();
LoadModel(cls_model_path_, cls_params_path_);
}
......@@ -78,6 +81,7 @@ public:
// Run predictor
void Run(cv::Mat &img, std::vector<float> &out_data,
std::vector<double> &times);
void FeatureNorm(std::vector<float> &feature);
std::shared_ptr<Predictor> predictor_;
......@@ -88,6 +92,7 @@ private:
int cpu_math_library_num_threads_ = 4;
bool use_mkldnn_ = false;
bool use_tensorrt_ = false;
bool feature_norm = true;
bool use_fp16_ = false;
std::vector<float> mean_ = {0.485f, 0.456f, 0.406f};
std::vector<float> std_ = {0.229f, 0.224f, 0.225f};
......@@ -103,4 +108,4 @@ private:
Permute permute_op_;
};
} // namespace PaddleClas
} // namespace Feature
......@@ -23,8 +23,8 @@ static inline bool SortScorePairDescend(const std::pair<float, T> &pair1,
return pair1.first > pair2.first;
}
float RectOverlap(const PaddleDetection::ObjectResult &a,
const PaddleDetection::ObjectResult &b) {
float RectOverlap(const Detection::ObjectResult &a,
const Detection::ObjectResult &b) {
float Aa = (a.rect[2] - a.rect[0] + 1) * (a.rect[3] - a.rect[1] + 1);
float Ab = (b.rect[2] - b.rect[0] + 1) * (b.rect[3] - b.rect[1] + 1);
......@@ -40,7 +40,7 @@ float RectOverlap(const PaddleDetection::ObjectResult &a,
// top_k: if -1, keep all; otherwise, keep at most top_k.
// score_index_vec: store the sorted (score, index) pair.
inline void
GetMaxScoreIndex(const std::vector<PaddleDetection::ObjectResult> &det_result,
GetMaxScoreIndex(const std::vector<Detection::ObjectResult> &det_result,
const float threshold,
std::vector<std::pair<float, int>> &score_index_vec) {
// Generate index score pairs.
......@@ -61,7 +61,7 @@ GetMaxScoreIndex(const std::vector<PaddleDetection::ObjectResult> &det_result,
// }
}
void NMSBoxes(const std::vector<PaddleDetection::ObjectResult> det_result,
void NMSBoxes(const std::vector<Detection::ObjectResult> det_result,
const float score_threshold, const float nms_threshold,
std::vector<int> &indices) {
int a = 1;
......
......@@ -31,7 +31,7 @@
using namespace paddle_infer;
namespace PaddleDetection {
namespace Detection {
// Object Detection Result
struct ObjectResult {
// Rectangle coordinates of detected object: left, right, top, down
......@@ -132,4 +132,4 @@ private:
std::vector<int> out_bbox_num_data_;
};
} // namespace PaddleDetection
} // namespace Detection
......@@ -29,7 +29,7 @@
using namespace std;
namespace PaddleClas {
namespace Feature {
class Normalize {
public:
......@@ -54,4 +54,4 @@ public:
int size = 0);
};
} // namespace PaddleClas
} // namespace Feature
......@@ -28,7 +28,7 @@
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/imgproc/imgproc.hpp>
namespace PaddleDetection {
namespace Detection {
// Object for storing all preprocessed data
class ImageBlob {
......@@ -152,4 +152,4 @@ private:
std::unordered_map<std::string, std::shared_ptr<PreprocessOp>> ops_;
};
} // namespace PaddleDetection
} // namespace Detection
......@@ -12,12 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <include/cls.h>
#include <cmath>
#include <include/feature_extracter.h>
#include <numeric>
namespace PaddleClas {
namespace Feature {
void Classifier::LoadModel(const std::string &model_path,
const std::string &params_path) {
void FeatureExtracter::LoadModel(const std::string &model_path,
const std::string &params_path) {
paddle_infer::Config config;
config.SetModel(model_path, params_path);
......@@ -52,11 +54,9 @@ void Classifier::LoadModel(const std::string &model_path,
this->predictor_ = CreatePredictor(config);
}
void Classifier::Run(cv::Mat &img, std::vector<float> &out_data,
std::vector<double> &times) {
cv::Mat srcimg;
void FeatureExtracter::Run(cv::Mat &img, std::vector<float> &out_data,
std::vector<double> &times) {
cv::Mat resize_img;
img.copyTo(srcimg);
std::vector<double> time;
auto preprocess_start = std::chrono::system_clock::now();
......@@ -86,10 +86,10 @@ void Classifier::Run(cv::Mat &img, std::vector<float> &out_data,
output_t->CopyToCpu(out_data.data());
auto infer_end = std::chrono::system_clock::now();
// auto postprocess_start = std::chrono::system_clock::now();
// int maxPosition =
// max_element(out_data.begin(), out_data.end()) - out_data.begin();
// auto postprocess_end = std::chrono::system_clock::now();
auto postprocess_start = std::chrono::system_clock::now();
if (this->feature_norm)
FeatureNorm(out_data);
auto postprocess_end = std::chrono::system_clock::now();
std::chrono::duration<float> preprocess_diff =
preprocess_end - preprocess_start;
......@@ -110,4 +110,10 @@ void Classifier::Run(cv::Mat &img, std::vector<float> &out_data,
times[2] += time[2];
}
} // namespace PaddleClas
void FeatureExtracter::FeatureNorm(std::vector<float> &featuer) {
float featuer_sqrt = std::sqrt(std::inner_product(
featuer.begin(), featuer.end(), featuer.begin(), 0.0f));
for (int i = 0; i < featuer.size(); ++i)
featuer[i] /= featuer_sqrt;
}
} // namespace Feature
......@@ -28,7 +28,7 @@
#include <auto_log/autolog.h>
#include <gflags/gflags.h>
#include <include/cls.h>
#include <include/feature_extracter.h>
#include <include/nms.h>
#include <include/object_detector.h>
#include <include/vector_search.h>
......@@ -42,8 +42,8 @@ DEFINE_string(c, "", "Path of yaml file");
void DetPredictImage(const std::vector<cv::Mat> &batch_imgs,
const std::vector<std::string> &all_img_paths,
const int batch_size, PaddleDetection::ObjectDetector *det,
std::vector<PaddleDetection::ObjectResult> &im_result,
const int batch_size, Detection::ObjectDetector *det,
std::vector<Detection::ObjectResult> &im_result,
std::vector<int> &im_bbox_num, std::vector<double> &det_t,
const bool visual_det = false,
const bool run_benchmark = false,
......@@ -63,7 +63,7 @@ void DetPredictImage(const std::vector<cv::Mat> &batch_imgs,
// }
// Store all detected result
std::vector<PaddleDetection::ObjectResult> result;
std::vector<Detection::ObjectResult> result;
std::vector<int> bbox_num;
std::vector<double> det_times;
bool is_rbox = false;
......@@ -73,7 +73,7 @@ void DetPredictImage(const std::vector<cv::Mat> &batch_imgs,
det->Predict(batch_imgs, 0, 1, &result, &bbox_num, &det_times);
// get labels and colormap
auto labels = det->GetLabelList();
auto colormap = PaddleDetection::GenerateColorMap(labels.size());
auto colormap = Detection::GenerateColorMap(labels.size());
int item_start_idx = 0;
for (int i = 0; i < left_image_cnt; i++) {
......@@ -81,7 +81,7 @@ void DetPredictImage(const std::vector<cv::Mat> &batch_imgs,
int detect_num = 0;
for (int j = 0; j < bbox_num[i]; j++) {
PaddleDetection::ObjectResult item = result[item_start_idx + j];
Detection::ObjectResult item = result[item_start_idx + j];
if (item.confidence < det->GetThreshold() || item.class_id == -1) {
continue;
}
......@@ -110,8 +110,8 @@ void DetPredictImage(const std::vector<cv::Mat> &batch_imgs,
std::cout << all_img_paths.at(idx * batch_size + i)
<< " The number of detected box: " << detect_num
<< std::endl;
cv::Mat vis_img = PaddleDetection::VisualizeResult(
im, im_result, labels, colormap, is_rbox);
cv::Mat vis_img = Detection::VisualizeResult(im, im_result, labels,
colormap, is_rbox);
std::vector<int> compression_params;
compression_params.push_back(CV_IMWRITE_JPEG_QUALITY);
compression_params.push_back(95);
......@@ -134,7 +134,7 @@ void DetPredictImage(const std::vector<cv::Mat> &batch_imgs,
}
void PrintResult(std::string &img_path,
std::vector<PaddleDetection::ObjectResult> &det_result,
std::vector<Detection::ObjectResult> &det_result,
std::vector<int> &indeices, VectorSearch &vector_search,
SearchResult &search_result) {
printf("%s:\n", img_path.c_str());
......@@ -167,8 +167,8 @@ int main(int argc, char **argv) {
config.PrintConfigInfo();
// initialize detector, rec_Model, vector_search
PaddleClas::Classifier classifier(config.config_file);
PaddleDetection::ObjectDetector detector(config.config_file);
Feature::FeatureExtracter feature_extracter(config.config_file);
Detection::ObjectDetector detector(config.config_file);
VectorSearch searcher(config.config_file);
// config
......@@ -212,7 +212,7 @@ int main(int argc, char **argv) {
std::vector<cv::Mat> batch_imgs;
std::vector<std::string> img_paths;
// for detection
std::vector<PaddleDetection::ObjectResult> det_result;
std::vector<Detection::ObjectResult> det_result;
std::vector<int> det_bbox_num;
// for vector search
std::vector<float> features;
......@@ -243,7 +243,7 @@ int main(int argc, char **argv) {
det_result.resize(max_det_results);
}
// step2: add the whole image for recognition to improve recall
PaddleDetection::ObjectResult result_whole_img = {
Detection::ObjectResult result_whole_img = {
{0, 0, srcimg.cols - 1, srcimg.rows - 1}, 0, 1.0};
det_result.push_back(result_whole_img);
det_bbox_num[0] = det_result.size() + 1;
......@@ -255,7 +255,7 @@ int main(int argc, char **argv) {
int h = det_result[j].rect[3] - det_result[j].rect[1];
cv::Rect rect(det_result[j].rect[0], det_result[j].rect[1], w, h);
cv::Mat crop_img = srcimg(rect);
classifier.Run(crop_img, feature, cls_times);
feature_extracter.Run(crop_img, feature, cls_times);
features.insert(features.end(), feature.begin(), feature.end());
}
......
......@@ -19,7 +19,7 @@
using namespace paddle_infer;
namespace PaddleDetection {
namespace Detection {
// Load Model and create model predictor
void ObjectDetector::LoadModel(const std::string &model_dir,
......@@ -362,4 +362,4 @@ std::vector<int> GenerateColorMap(int num_class) {
return colormap;
}
} // namespace PaddleDetection
} // namespace Detection
......@@ -30,7 +30,7 @@
#include <include/preprocess_op.h>
namespace PaddleClas {
namespace Feature {
void Permute::Run(const cv::Mat *im, float *data) {
int rh = im->rows;
......@@ -88,4 +88,4 @@ void ResizeImg::Run(const cv::Mat &img, cv::Mat &resize_img,
cv::resize(img, resize_img, cv::Size(resize_w, resize_h));
}
} // namespace PaddleClas
} // namespace Feature
......@@ -17,7 +17,7 @@
#include "include/preprocess_op_det.h"
namespace PaddleDetection {
namespace Detection {
void InitInfo::Run(cv::Mat *im, ImageBlob *data) {
data->im_shape_ = {static_cast<float>(im->rows),
......@@ -127,4 +127,4 @@ void Preprocessor::Run(cv::Mat *im, ImageBlob *data) {
}
}
} // namespace PaddleDetection
} // namespace Detection
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册