diff --git a/deploy/lite_shitu/README.md b/deploy/lite_shitu/README.md
index 52871c3c16dc9990f9cf23de24b24cb54067cac6..e2a03caedd0d4bf63af96d3541d1a8d021206e52 100644
--- a/deploy/lite_shitu/README.md
+++ b/deploy/lite_shitu/README.md
@@ -92,9 +92,9 @@ PaddleClas provides converted and optimized inference models, which can be used
 ```shell
 # Enter the lite_shitu directory
 cd $PaddleClas/deploy/lite_shitu
-wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/lite/ppshitu_lite_models_v1.1.tar
-tar -xf ppshitu_lite_models_v1.1.tar
-rm -f ppshitu_lite_models_v1.1.tar
+wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/lite/ppshitu_lite_models_v1.2.tar
+tar -xf ppshitu_lite_models_v1.2.tar
+rm -f ppshitu_lite_models_v1.2.tar
 ```
 
 #### 2.1.2 Using other models
 
@@ -162,7 +162,7 @@ git clone https://github.com/PaddlePaddle/PaddleDetection.git
 # Enter the PaddleDetection root directory
 cd PaddleDetection
 # Export the pretrained model as an inference model
-python tools/export_model.py -c configs/picodet/application/mainbody_detection/picodet_lcnet_x2_5_640_mainbody.yml -o weights=https://paddledet.bj.bcebos.com/models/picodet_lcnet_x2_5_640_mainbody.pdparams --output_dir=inference
+python tools/export_model.py -c configs/picodet/application/mainbody_detection/picodet_lcnet_x2_5_640_mainbody.yml -o weights=https://paddledet.bj.bcebos.com/models/picodet_lcnet_x2_5_640_mainbody.pdparams export_post_process=False --output_dir=inference
 # Convert the inference model into a Paddle-Lite optimized model
 paddle_lite_opt --model_file=inference/picodet_lcnet_x2_5_640_mainbody/model.pdmodel --param_file=inference/picodet_lcnet_x2_5_640_mainbody/model.pdiparams --optimize_out=inference/picodet_lcnet_x2_5_640_mainbody/mainbody_det
 # Copy the converted model into the lite_shitu directory
@@ -183,24 +183,56 @@ cd deploy/lite_shitu
 
 **Note**: the `--optimize_out` argument is the save path of the optimized model and must not carry the `.nb` suffix; `--model_file` is the path of the model-structure file and `--param_file` is the path of the model-weights file. Mind the file names.
 
-### 2.2 Convert the yaml file into a json file
+### 2.2 Generate a new retrieval library
+
+The lite retrieval library is built against `faiss1.5.3`, whose index format is incompatible with newer faiss versions, so the index library has to be regenerated.
+
+#### 2.2.1 Data and environment setup
+
+```shell
+# Go to the parent directory
+cd ..
+# Download the bottled-drinks dataset
+wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/data/drink_dataset_v1.0.tar && tar -xf drink_dataset_v1.0.tar
+rm -rf drink_dataset_v1.0.tar
+rm -rf drink_dataset_v1.0/index
+
+# Install faiss 1.5.3
+pip install faiss-cpu==1.5.3
+
+# Download the general recognition model; you can replace it with your own inference model
+wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/general_PPLCNet_x2_5_lite_v1.0_infer.tar
+tar -xf general_PPLCNet_x2_5_lite_v1.0_infer.tar
+rm -rf general_PPLCNet_x2_5_lite_v1.0_infer.tar
+```
+
+#### 2.2.2 Generate the new index files
+
+```shell
+# Build the new index library. Point the config at the recognition model and set index_method to Flat;
+# HNSW32 and IVF may be buggy in this faiss version, so use them with caution.
+# If you use your own recognition model, change the inference model directory accordingly.
+python python/build_gallery.py -c configs/inference_drink.yaml -o Global.rec_inference_model_dir=general_PPLCNet_x2_5_lite_v1.0_infer -o IndexProcess.index_method=Flat
+
+# Go back into the lite_shitu directory
+cd lite_shitu
+mv ../drink_dataset_v1.0 .
+```
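A note on what this buys: the lite binary loads `vector.index` through the faiss 1.5.3 C++ API, so the gallery must be written by that same faiss generation. The sketch below shows the load-and-query pattern the C++ side relies on — a minimal sketch, assuming the index path produced by the commands above; the project's real loading code lives in `src/vector_search.cc`:

```cpp
#include <faiss/Index.h>
#include <faiss/index_io.h>
#include <cstdio>
#include <vector>

int main() {
  // Load the Flat index written by python/build_gallery.py (path assumed).
  faiss::Index *index =
      faiss::read_index("drink_dataset_v1.0/index/vector.index");
  int topk = 5;
  // One query feature; in the app this comes from FeatureExtract.
  std::vector<float> query(index->d, 0.0f);
  std::vector<float> scores(topk);
  std::vector<faiss::Index::idx_t> ids(topk);
  index->search(1, query.data(), topk, scores.data(), ids.data());
  for (int i = 0; i < topk; ++i)
    std::printf("id=%ld score=%f\n", static_cast<long>(ids[i]), scores[i]);
  delete index;
  return 0;
}
```

With `index_method=Flat` the search is an exact brute-force scan, which is why it behaves identically on the Python and C++ sides; HNSW32/IVF would be faster but are the variants flagged as possibly buggy above.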
+
+### 2.3 Convert the yaml file into a json file
 
 ```shell
 # To test a single image
-python generate_json_config.py --det_model_path ppshitu_lite_models_v1.1/mainbody_PPLCNet_x2_5_640_quant_v1.1_lite.nb --rec_model_path ppshitu_lite_models_v1.1/general_PPLCNet_x2_5_lite_v1.1_infer.nb --img_path images/demo.jpg
+python generate_json_config.py --det_model_path ppshitu_lite_models_v1.2/mainbody_PPLCNet_x2_5_640_v1.2_lite.nb --rec_model_path ppshitu_lite_models_v1.2/general_PPLCNet_x2_5_lite_v1.2_infer.nb --img_path images/demo.jpeg
 # or
 # To test multiple images
-python generate_json_config.py --det_model_path ppshitu_lite_models_v1.1/mainbody_PPLCNet_x2_5_640_quant_v1.1_lite.nb --rec_model_path ppshitu_lite_models_v1.1/general_PPLCNet_x2_5_lite_v1.1_infer.nb --img_dir images
+python generate_json_config.py --det_model_path ppshitu_lite_models_v1.2/mainbody_PPLCNet_x2_5_640_v1.2_lite.nb --rec_model_path ppshitu_lite_models_v1.2/general_PPLCNet_x2_5_lite_v1.2_infer.nb --img_dir images
 # After it finishes, the shitu_config.json config file is generated under lite_shitu
 ```
 
-### 2.3 Convert the index dictionary
+### 2.4 Convert the index dictionary
 
 The Python retrieval library serializes its id dictionary with `pickle`, which is inconvenient to read from C++, so it needs to be converted.
 
 ```shell
-# Download the bottled-drinks dataset
-wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/data/drink_dataset_v1.0.tar && tar -xf drink_dataset_v1.0.tar
-rm -rf drink_dataset_v1.0.tar
 
 # Convert id_map.pkl into id_map.txt
 python transform_id_map.py -c ../configs/inference_drink.yaml
@@ -208,7 +240,7 @@ python transform_id_map.py -c ../configs/inference_drink.yaml
 
 After a successful conversion, `id_map.txt` is generated under the `IndexProcess.index_dir` directory.
 
-### 2.4 Debugging on the phone
+### 2.5 Debugging on the phone
 
 First, some preparation is required.
 1. Prepare an arm8 Android phone. If the prediction library was compiled for armv7, an arm7 phone is required instead, and `ARM_ABI = arm7` must be set in the Makefile.
@@ -308,8 +340,9 @@ chmod 777 pp_shitu
 
 The running result looks like this:
 ```
-images/demo.jpg:
-  result0: bbox[253, 275, 1146, 872], score: 0.974196, label: 伊藤园_果蔬汁
+images/demo.jpeg:
+  result0: bbox[344, 98, 527, 593], score: 0.811656, label: 红牛-强化型
+  result1: bbox[0, 0, 600, 600], score: 0.729664, label: 红牛-强化型
 ```
 
 ## FAQ
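Section 2.4 exists because C++ has no pickle reader; after `transform_id_map.py` runs, every line of `id_map.txt` pairs a numeric index id with its label. A minimal reader under that assumption — `LoadIdMap` is a hypothetical name, the project's actual loader is in `src/vector_search.cc`, and the label is assumed to contain no whitespace:

```cpp
#include <fstream>
#include <map>
#include <sstream>
#include <string>

// Reads id_map.txt, assuming one "<numeric id> <label>" pair per line.
std::map<long, std::string> LoadIdMap(const std::string &path) {
  std::map<long, std::string> id_map;
  std::ifstream in(path);
  std::string line;
  while (std::getline(in, line)) {
    std::istringstream ss(line);
    long id;
    std::string label;
    if (ss >> id >> label)
      id_map[id] = label;
  }
  return id_map;
}
```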
diff --git a/deploy/lite_shitu/images/demo.jpeg b/deploy/lite_shitu/images/demo.jpeg
new file mode 100644
index 0000000000000000000000000000000000000000..2ef10aae5f7f5ce515cb51f857b66c6195f6664b
Binary files /dev/null and b/deploy/lite_shitu/images/demo.jpeg differ
diff --git a/deploy/lite_shitu/images/demo.jpg b/deploy/lite_shitu/images/demo.jpg
deleted file mode 100644
index 075dc31d4e6b407b792cc8abca82dcd541be8d11..0000000000000000000000000000000000000000
Binary files a/deploy/lite_shitu/images/demo.jpg and /dev/null differ
diff --git a/deploy/lite_shitu/include/feature_extractor.h b/deploy/lite_shitu/include/feature_extractor.h
index 1961459ecfab149695890df60cef550ed5177b52..9eb29215f44c437746ab5f4bab4f042e9117e764 100644
--- a/deploy/lite_shitu/include/feature_extractor.h
+++ b/deploy/lite_shitu/include/feature_extractor.h
@@ -24,6 +24,7 @@
 #include <stdlib.h>
 #include <sys/time.h>
 #include <vector>
+#include "include/preprocess_op.h"
 
 using namespace paddle::lite_api; // NOLINT
 using namespace std;
@@ -48,10 +49,6 @@ public:
         config_file["Global"]["rec_model_path"].as<std::string>());
     this->predictor = CreatePaddlePredictor<MobileConfig>(config);
 
-    if (config_file["Global"]["rec_label_path"].as<std::string>().empty()) {
-      std::cout << "Please set [rec_label_path] in config file" << std::endl;
-      exit(-1);
-    }
     SetPreProcessParam(config_file["RecPreProcess"]["transform_ops"]);
     printf("feature extract model create!\n");
   }
@@ -68,7 +65,7 @@ public:
         this->mean.emplace_back(tmp.as<float>());
       }
       for (auto tmp : item["std"]) {
-        this->std.emplace_back(1 / tmp.as<float>());
+        this->std.emplace_back(tmp.as<float>());
       }
       this->scale = item["scale"].as<double>();
     }
@@ -77,15 +74,19 @@ public:
   void RunRecModel(const cv::Mat &img, double &cost_time,
                    std::vector<float> &feature);
   // void PostProcess(std::vector<float> &feature);
-  cv::Mat ResizeImage(const cv::Mat &img);
-  void NeonMeanScale(const float *din, float *dout, int size);
+  void FeatureNorm(std::vector<float> &feature);
 
 private:
   std::shared_ptr<PaddlePredictor> predictor;
   // std::vector<std::string> label_list;
   std::vector<float> mean = {0.485f, 0.456f, 0.406f};
-  std::vector<float> std = {1 / 0.229f, 1 / 0.224f, 1 / 0.225f};
+  std::vector<float> std = {0.229f, 0.224f, 0.225f};
   double scale = 0.00392157;
-  float size = 224;
+  int size = 224;
+
+  // pre-process
+  Resize resize_op_;
+  NormalizeImage normalize_op_;
+  Permute permute_op_;
 };
 } // namespace PPShiTu
diff --git a/deploy/lite_shitu/include/preprocess_op.h b/deploy/lite_shitu/include/preprocess_op.h
index f7050fa86951bfe80aa4030adabc11ff43f82371..e414219e4e28b7b9d656b25bd926e496bad4da25 100644
--- a/deploy/lite_shitu/include/preprocess_op.h
+++ b/deploy/lite_shitu/include/preprocess_op.h
@@ -71,6 +71,8 @@ class NormalizeImage : public PreprocessOp {
   }
 
   virtual void Run(cv::Mat* im, ImageBlob* data);
+  void Run_feature(cv::Mat *im, const std::vector<float> &mean,
+                   const std::vector<float> &std, float scale);
 
  private:
   // CHW or HWC
@@ -83,6 +85,7 @@ class Permute : public PreprocessOp {
  public:
   virtual void Init(const Json::Value& item) {}
   virtual void Run(cv::Mat* im, ImageBlob* data);
+  void Run_feature(const cv::Mat *im, float *data);
 };
 
 class Resize : public PreprocessOp {
@@ -101,6 +104,7 @@ class Resize : public PreprocessOp {
   std::pair<float, float> GenerateScale(const cv::Mat& im);
   virtual void Run(cv::Mat* im, ImageBlob* data);
+  void Run_feature(const cv::Mat &img, cv::Mat &resize_img, int max_size_len, int size = 0);
 
  private:
   int interp_;
diff --git a/deploy/lite_shitu/include/utils.h b/deploy/lite_shitu/include/utils.h
index a3b57c882561577defff97e384fb775b78204f36..482dc384cfc6d585928d33d11807d76fe59ff342 100644
--- a/deploy/lite_shitu/include/utils.h
+++ b/deploy/lite_shitu/include/utils.h
@@ -38,6 +38,24 @@ struct ObjectResult {
   std::vector<RESULT> rec_result;
 };
 
-void nms(std::vector<ObjectResult> &input_boxes, float nms_threshold, bool rec_nms=false);
+void nms(std::vector<ObjectResult> &input_boxes, float nms_threshold,
+         bool rec_nms = false);
+
+template <typename T>
+static inline bool SortScorePairDescend(const std::pair<float, T> &pair1,
+                                        const std::pair<float, T> &pair2) {
+  return pair1.first > pair2.first;
+}
+
+float RectOverlap(const ObjectResult &a,
+                  const ObjectResult &b);
+
+inline void
+GetMaxScoreIndex(const std::vector<ObjectResult> &det_result,
+                 const float threshold,
+                 std::vector<std::pair<float, int>> &score_index_vec);
+
+void NMSBoxes(const std::vector<ObjectResult> det_result,
+              const float score_threshold, const float nms_threshold,
+              std::vector<int> &indices);
 } // namespace PPShiTu
diff --git a/deploy/lite_shitu/src/feature_extractor.cc b/deploy/lite_shitu/src/feature_extractor.cc
index aca5c1cbbe5c70cd214c922609831e9350be28a0..4cc17186b5791da9c38accc0f947c14b732557e9 100644
--- a/deploy/lite_shitu/src/feature_extractor.cc
+++ b/deploy/lite_shitu/src/feature_extractor.cc
@@ -13,24 +13,30 @@
 // limitations under the License.
#include "include/feature_extractor.h" +#include +#include namespace PPShiTu { void FeatureExtract::RunRecModel(const cv::Mat &img, double &cost_time, std::vector &feature) { // Read img - cv::Mat resize_image = ResizeImage(img); - cv::Mat img_fp; - resize_image.convertTo(img_fp, CV_32FC3, scale); + this->resize_op_.Run_feature(img, img_fp, this->size, this->size); + this->normalize_op_.Run_feature(&img_fp, this->mean, this->std, this->scale); + std::vector input(1 * 3 * img_fp.rows * img_fp.cols, 0.0f); + this->permute_op_.Run_feature(&img_fp, input.data()); // Prepare input data from image std::unique_ptr input_tensor(std::move(this->predictor->GetInput(0))); - input_tensor->Resize({1, 3, img_fp.rows, img_fp.cols}); + input_tensor->Resize({1, 3, this->size, this->size}); auto *data0 = input_tensor->mutable_data(); - const float *dimg = reinterpret_cast(img_fp.data); - NeonMeanScale(dimg, data0, img_fp.rows * img_fp.cols); + // const float *dimg = reinterpret_cast(img_fp.data); + // NeonMeanScale(dimg, data0, img_fp.rows * img_fp.cols); + for(int i=0; i < input.size(); ++i){ + data0[i] = input[i]; + } auto start = std::chrono::system_clock::now(); // Run predictor @@ -55,62 +61,14 @@ void FeatureExtract::RunRecModel(const cv::Mat &img, output_tensor->CopyToCpu(feature.data()); //postprocess include sqrt or binarize. - //PostProcess(feature); + FeatureNorm(feature); return; } -// void FeatureExtract::PostProcess(std::vector &feature){ -// float feature_sqrt = std::sqrt(std::inner_product( -// feature.begin(), feature.end(), feature.begin(), 0.0f)); -// for (int i = 0; i < feature.size(); ++i) -// feature[i] /= feature_sqrt; -// } - -void FeatureExtract::NeonMeanScale(const float *din, float *dout, int size) { - - if (this->mean.size() != 3 || this->std.size() != 3) { - std::cerr << "[ERROR] mean or scale size must equal to 3\n"; - exit(1); - } - float32x4_t vmean0 = vdupq_n_f32(mean[0]); - float32x4_t vmean1 = vdupq_n_f32(mean[1]); - float32x4_t vmean2 = vdupq_n_f32(mean[2]); - float32x4_t vscale0 = vdupq_n_f32(std[0]); - float32x4_t vscale1 = vdupq_n_f32(std[1]); - float32x4_t vscale2 = vdupq_n_f32(std[2]); - - float *dout_c0 = dout; - float *dout_c1 = dout + size; - float *dout_c2 = dout + size * 2; - - int i = 0; - for (; i < size - 3; i += 4) { - float32x4x3_t vin3 = vld3q_f32(din); - float32x4_t vsub0 = vsubq_f32(vin3.val[0], vmean0); - float32x4_t vsub1 = vsubq_f32(vin3.val[1], vmean1); - float32x4_t vsub2 = vsubq_f32(vin3.val[2], vmean2); - float32x4_t vs0 = vmulq_f32(vsub0, vscale0); - float32x4_t vs1 = vmulq_f32(vsub1, vscale1); - float32x4_t vs2 = vmulq_f32(vsub2, vscale2); - vst1q_f32(dout_c0, vs0); - vst1q_f32(dout_c1, vs1); - vst1q_f32(dout_c2, vs2); - - din += 12; - dout_c0 += 4; - dout_c1 += 4; - dout_c2 += 4; - } - for (; i < size; i++) { - *(dout_c0++) = (*(din++) - this->mean[0]) * this->std[0]; - *(dout_c1++) = (*(din++) - this->mean[1]) * this->std[1]; - *(dout_c2++) = (*(din++) - this->mean[2]) * this->std[2]; - } -} - -cv::Mat FeatureExtract::ResizeImage(const cv::Mat &img) { - cv::Mat resize_img; - cv::resize(img, resize_img, cv::Size(this->size, this->size)); - return resize_img; +void FeatureExtract::FeatureNorm(std::vector &feature){ + float feature_sqrt = std::sqrt(std::inner_product( + feature.begin(), feature.end(), feature.begin(), 0.0f)); + for (int i = 0; i < feature.size(); ++i) + feature[i] /= feature_sqrt; } } diff --git a/deploy/lite_shitu/src/main.cc b/deploy/lite_shitu/src/main.cc index 
diff --git a/deploy/lite_shitu/src/main.cc b/deploy/lite_shitu/src/main.cc
index 3f278dc778701a7a7591e74336e0f86fe52105ea..16f3ac522cd3d607644cc643cfc69489191158c5 100644
--- a/deploy/lite_shitu/src/main.cc
+++ b/deploy/lite_shitu/src/main.cc
@@ -28,6 +28,7 @@
 #include "include/object_detector.h"
 #include "include/preprocess_op.h"
 #include "include/vector_search.h"
+#include "include/utils.h"
 #include "json/json.h"
 
 Json::Value RT_Config;
@@ -158,6 +159,11 @@ int main(int argc, char **argv) {
               << " [image_dir]>" << std::endl;
     return -1;
   }
+
+  float rec_nms_threshold = 0.05;
+  if (RT_Config["Global"]["rec_nms_thresold"].isDouble())
+    rec_nms_threshold = RT_Config["Global"]["rec_nms_thresold"].as<float>();
+
   // Load model and create a object detector
   PPShiTu::ObjectDetector det(
       RT_Config, RT_Config["Global"]["det_model_path"].as<std::string>(),
@@ -174,6 +180,7 @@ int main(int argc, char **argv) {
   // for vector search
   std::vector<float> feature;
   std::vector<float> features;
+  std::vector<int> indices;
   double rec_time;
   if (!RT_Config["Global"]["infer_imgs"].as<std::string>().empty() ||
       !img_dir.empty()) {
@@ -208,9 +215,9 @@ int main(int argc, char **argv) {
           RT_Config["Global"]["max_det_results"].as<int>(), false, &det);
 
       // add the whole image for recognition to improve recall
-//      PPShiTu::ObjectResult result_whole_img = {
-//          {0, 0, srcimg.cols, srcimg.rows}, 0, 1.0};
-//      det_result.push_back(result_whole_img);
+      PPShiTu::ObjectResult result_whole_img = {
+          {0, 0, srcimg.cols, srcimg.rows}, 0, 1.0};
+      det_result.push_back(result_whole_img);
 
       // get rec result
       PPShiTu::SearchResult search_result;
@@ -225,10 +232,18 @@ int main(int argc, char **argv) {
       // do vector search
       search_result = searcher.Search(features.data(), det_result.size());
+      for (int i = 0; i < det_result.size(); ++i) {
+        det_result[i].confidence = search_result.D[search_result.return_k * i];
+      }
+      NMSBoxes(det_result, searcher.GetThreshold(), rec_nms_threshold,
+               indices);
       PrintResult(img_path, det_result, searcher, search_result);
 
       batch_imgs.clear();
       det_result.clear();
+      features.clear();
+      feature.clear();
+      indices.clear();
+
     }
   }
   return 0;
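The main-loop changes act as one unit: the whole image is appended as an extra candidate to improve recall, each box's confidence is overwritten with its top-1 retrieval score (`search_result.D` holds `return_k` scores per query), and `NMSBoxes` computes which boxes survive overlap suppression. (The config key is spelled `rec_nms_thresold` throughout this PR, and as the diff stands `indices` is filled but not consumed before `PrintResult`, which is why the README's sample output still lists both the detected box and the whole-image box.) A minimal sketch of the `NMSBoxes` semantics using those very boxes, compiled against the project headers; the 0.5 score threshold merely stands in for `searcher.GetThreshold()`:

```cpp
#include <cstdio>
#include <vector>

#include "include/utils.h" // ObjectResult / NMSBoxes from this PR

int main() {
  // Two candidates mirroring the README sample output: a detected box and
  // the appended whole-image box, with top-1 retrieval scores as confidence.
  std::vector<PPShiTu::ObjectResult> dets(2);
  dets[0].rect = {344, 98, 527, 593};
  dets[0].confidence = 0.811656f;
  dets[1].rect = {0, 0, 600, 600};
  dets[1].confidence = 0.729664f;

  std::vector<int> keep;
  PPShiTu::NMSBoxes(dets, 0.5f, 0.05f, keep);
  // Box 0 sits inside box 1, so their IoU (~0.25) exceeds the 0.05 threshold
  // and the lower-scoring whole-image box is suppressed: only box 0 is kept.
  for (int idx : keep)
    std::printf("kept box %d\n", idx);
  return 0;
}
```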
diff --git a/deploy/lite_shitu/src/preprocess_op.cc b/deploy/lite_shitu/src/preprocess_op.cc
index 9c74d6ee7241c93b9fb206317f634e523425793e..dc560c7f45fb0b68b098b269d415f76a74b18d1d 100644
--- a/deploy/lite_shitu/src/preprocess_op.cc
+++ b/deploy/lite_shitu/src/preprocess_op.cc
@@ -20,7 +20,7 @@
 
 namespace PPShiTu {
 
-void InitInfo::Run(cv::Mat* im, ImageBlob* data) {
+void InitInfo::Run(cv::Mat *im, ImageBlob *data) {
   data->im_shape_ = {static_cast<float>(im->rows),
                      static_cast<float>(im->cols)};
   data->scale_factor_ = {1., 1.};
@@ -28,10 +28,10 @@ void InitInfo::Run(cv::Mat* im, ImageBlob* data) {
                          static_cast<float>(im->cols)};
 }
 
-void NormalizeImage::Run(cv::Mat* im, ImageBlob* data) {
+void NormalizeImage::Run(cv::Mat *im, ImageBlob *data) {
   double e = 1.0;
   if (is_scale_) {
-    e *= 1./255.0;
+    e *= 1. / 255.0;
   }
   (*im).convertTo(*im, CV_32FC3, e);
   for (int h = 0; h < im->rows; h++) {
@@ -46,35 +46,61 @@ void NormalizeImage::Run(cv::Mat* im, ImageBlob* data) {
   }
 }
 
-void Permute::Run(cv::Mat* im, ImageBlob* data) {
+void NormalizeImage::Run_feature(cv::Mat *im, const std::vector<float> &mean,
+                                 const std::vector<float> &std, float scale) {
+  (*im).convertTo(*im, CV_32FC3, scale);
+  for (int h = 0; h < im->rows; h++) {
+    for (int w = 0; w < im->cols; w++) {
+      im->at<cv::Vec3f>(h, w)[0] =
+          (im->at<cv::Vec3f>(h, w)[0] - mean[0]) / std[0];
+      im->at<cv::Vec3f>(h, w)[1] =
+          (im->at<cv::Vec3f>(h, w)[1] - mean[1]) / std[1];
+      im->at<cv::Vec3f>(h, w)[2] =
+          (im->at<cv::Vec3f>(h, w)[2] - mean[2]) / std[2];
+    }
+  }
+}
+
+void Permute::Run(cv::Mat *im, ImageBlob *data) {
   (*im).convertTo(*im, CV_32FC3);
   int rh = im->rows;
   int rw = im->cols;
   int rc = im->channels();
   (data->im_data_).resize(rc * rh * rw);
-  float* base = (data->im_data_).data();
+  float *base = (data->im_data_).data();
   for (int i = 0; i < rc; ++i) {
     cv::extractChannel(*im, cv::Mat(rh, rw, CV_32FC1, base + i * rh * rw), i);
   }
 }
 
-void Resize::Run(cv::Mat* im, ImageBlob* data) {
+void Permute::Run_feature(const cv::Mat *im, float *data) {
+  int rh = im->rows;
+  int rw = im->cols;
+  int rc = im->channels();
+  for (int i = 0; i < rc; ++i) {
+    cv::extractChannel(*im, cv::Mat(rh, rw, CV_32FC1, data + i * rh * rw), i);
+  }
+}
+
+void Resize::Run(cv::Mat *im, ImageBlob *data) {
   auto resize_scale = GenerateScale(*im);
   data->im_shape_ = {static_cast<float>(im->cols * resize_scale.first),
                      static_cast<float>(im->rows * resize_scale.second)};
   data->in_net_shape_ = {static_cast<float>(im->cols * resize_scale.first),
                          static_cast<float>(im->rows * resize_scale.second)};
-  cv::resize(
-      *im, *im, cv::Size(), resize_scale.first, resize_scale.second, interp_);
+  cv::resize(*im, *im, cv::Size(), resize_scale.first, resize_scale.second,
+             interp_);
   data->im_shape_ = {
-      static_cast<float>(im->rows), static_cast<float>(im->cols),
+      static_cast<float>(im->rows),
+      static_cast<float>(im->cols),
   };
   data->scale_factor_ = {
-      resize_scale.second, resize_scale.first,
+      resize_scale.second,
+      resize_scale.first,
   };
 }
 
-std::pair<float, float> Resize::GenerateScale(const cv::Mat& im) {
+std::pair<float, float> Resize::GenerateScale(const cv::Mat &im) {
   std::pair<float, float> resize_scale;
   int origin_w = im.cols;
   int origin_h = im.rows;
@@ -101,7 +127,30 @@ std::pair<float, float> Resize::GenerateScale(const cv::Mat& im) {
   return resize_scale;
 }
 
-void PadStride::Run(cv::Mat* im, ImageBlob* data) {
+void Resize::Run_feature(const cv::Mat &img, cv::Mat &resize_img,
+                         int resize_short_size, int size) {
+  int resize_h = 0;
+  int resize_w = 0;
+  if (size > 0) {
+    resize_h = size;
+    resize_w = size;
+  } else {
+    int w = img.cols;
+    int h = img.rows;
+
+    float ratio = 1.f;
+    if (h < w) {
+      ratio = float(resize_short_size) / float(h);
+    } else {
+      ratio = float(resize_short_size) / float(w);
+    }
+    resize_h = round(float(h) * ratio);
+    resize_w = round(float(w) * ratio);
+  }
+  cv::resize(img, resize_img, cv::Size(resize_w, resize_h));
+}
+
+void PadStride::Run(cv::Mat *im, ImageBlob *data) {
   if (stride_ <= 0) {
     return;
   }
@@ -110,48 +159,44 @@ void PadStride::Run(cv::Mat* im, ImageBlob* data) {
   int rw = im->cols;
   int nh = (rh / stride_) * stride_ + (rh % stride_ != 0) * stride_;
   int nw = (rw / stride_) * stride_ + (rw % stride_ != 0) * stride_;
-  cv::copyMakeBorder(
-      *im, *im, 0, nh - rh, 0, nw - rw, cv::BORDER_CONSTANT, cv::Scalar(0));
+  cv::copyMakeBorder(*im, *im, 0, nh - rh, 0, nw - rw, cv::BORDER_CONSTANT,
+                     cv::Scalar(0));
   data->in_net_shape_ = {
-      static_cast<float>(im->rows), static_cast<float>(im->cols),
+      static_cast<float>(im->rows),
+      static_cast<float>(im->cols),
   };
 }
 
-void TopDownEvalAffine::Run(cv::Mat* im, ImageBlob* data) {
+void TopDownEvalAffine::Run(cv::Mat *im, ImageBlob *data) {
   cv::resize(*im, *im, cv::Size(trainsize_[0], trainsize_[1]), 0, 0, interp_);
   // todo: Simd::ResizeBilinear();
   data->in_net_shape_ = {
-      static_cast<float>(trainsize_[1]), static_cast<float>(trainsize_[0]),
+      static_cast<float>(trainsize_[1]),
+      static_cast<float>(trainsize_[0]),
   };
 }
 
 // Preprocessor op running order
-const std::vector<std::string> Preprocessor::RUN_ORDER = {"InitInfo",
-                                                          "DetTopDownEvalAffine",
-                                                          "DetResize",
-                                                          "DetNormalizeImage",
-                                                          "DetPadStride",
-                                                          "DetPermute"};
-
-void Preprocessor::Run(cv::Mat* im, ImageBlob* data) {
-  for (const auto& name : RUN_ORDER) {
+const std::vector<std::string> Preprocessor::RUN_ORDER = {
+    "InitInfo",          "DetTopDownEvalAffine", "DetResize",
+    "DetNormalizeImage", "DetPadStride",         "DetPermute"};
+
+void Preprocessor::Run(cv::Mat *im, ImageBlob *data) {
+  for (const auto &name : RUN_ORDER) {
     if (ops_.find(name) != ops_.end()) {
       ops_[name]->Run(im, data);
     }
   }
 }
 
-void CropImg(cv::Mat& img,
-             cv::Mat& crop_img,
-             std::vector<int>& area,
-             std::vector<float>& center,
-             std::vector<float>& scale,
+void CropImg(cv::Mat &img, cv::Mat &crop_img, std::vector<int> &area,
+             std::vector<float> &center, std::vector<float> &scale,
              float expandratio) {
   int crop_x1 = std::max(0, area[0]);
   int crop_y1 = std::max(0, area[1]);
   int crop_x2 = std::min(img.cols - 1, area[2]);
   int crop_y2 = std::min(img.rows - 1, area[3]);
-  
+
   int center_x = (crop_x1 + crop_x2) / 2.;
   int center_y = (crop_y1 + crop_y2) / 2.;
   int half_h = (crop_y2 - crop_y1) / 2.;
@@ -182,4 +227,4 @@ void CropImg(cv::Mat& img,
   scale.emplace_back((crop_y2 - crop_y1));
 }
 
-}  // namespace PPShiTu
+} // namespace PPShiTu
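The three `Run_feature` overloads above are the scalar replacement for the deleted NEON path: a fixed square resize when `size > 0` (short-side resize when `size == 0`), mean/std normalization, then HWC-to-CHW. A sketch of how they chain, mirroring what `FeatureExtract::RunRecModel` now does, with the default values from `feature_extractor.h`:

```cpp
#include <opencv2/opencv.hpp>
#include <vector>

#include "include/preprocess_op.h" // Run_feature overloads from this PR

// Produces the CHW float input the rec model expects from a BGR image.
std::vector<float> PreprocessForRec(const cv::Mat &src) {
  PPShiTu::Resize resize_op;
  PPShiTu::NormalizeImage normalize_op;
  PPShiTu::Permute permute_op;

  cv::Mat img;
  // size > 0 selects the fixed 224x224 branch; passing size = 0 would
  // resize the short side to the first length instead.
  resize_op.Run_feature(src, img, 224, 224);

  std::vector<float> mean = {0.485f, 0.456f, 0.406f};
  std::vector<float> std = {0.229f, 0.224f, 0.225f};
  normalize_op.Run_feature(&img, mean, std, 1.0f / 255.0f); // ~0.00392157

  std::vector<float> input(3 * img.rows * img.cols);
  permute_op.Run_feature(&img, input.data()); // HWC -> CHW
  return input;
}
```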
diff --git a/deploy/lite_shitu/src/utils.cc b/deploy/lite_shitu/src/utils.cc
index 3bc461770e2d79e33e4de91a3f4cea8c131eb7ad..cf51789091d458860eae2a957171e8573ca074e5 100644
--- a/deploy/lite_shitu/src/utils.cc
+++ b/deploy/lite_shitu/src/utils.cc
@@ -54,4 +54,55 @@ void nms(std::vector<ObjectResult> &input_boxes, float nms_threshold,
   }
 }
 
+float RectOverlap(const ObjectResult &a,
+                  const ObjectResult &b) {
+  float Aa = (a.rect[2] - a.rect[0] + 1) * (a.rect[3] - a.rect[1] + 1);
+  float Ab = (b.rect[2] - b.rect[0] + 1) * (b.rect[3] - b.rect[1] + 1);
+
+  int iou_w = max(min(a.rect[2], b.rect[2]) - max(a.rect[0], b.rect[0]) + 1, 0);
+  int iou_h = max(min(a.rect[3], b.rect[3]) - max(a.rect[1], b.rect[1]) + 1, 0);
+  float Aab = iou_w * iou_h;
+  return Aab / (Aa + Ab - Aab);
+}
+
+inline void
+GetMaxScoreIndex(const std::vector<ObjectResult> &det_result,
+                 const float threshold,
+                 std::vector<std::pair<float, int>> &score_index_vec) {
+  // Generate index score pairs.
+  for (size_t i = 0; i < det_result.size(); ++i) {
+    if (det_result[i].confidence > threshold) {
+      score_index_vec.push_back(std::make_pair(det_result[i].confidence, i));
+    }
+  }
+
+  // Sort the score pair according to the scores in descending order
+  std::stable_sort(score_index_vec.begin(), score_index_vec.end(),
+                   SortScorePairDescend<int>);
+}
+
+void NMSBoxes(const std::vector<ObjectResult> det_result,
+              const float score_threshold, const float nms_threshold,
+              std::vector<int> &indices) {
+  // Get top_k scores (with corresponding indices).
+  std::vector<std::pair<float, int>> score_index_vec;
+  GetMaxScoreIndex(det_result, score_threshold, score_index_vec);
+
+  // Do nms
+  indices.clear();
+  for (size_t i = 0; i < score_index_vec.size(); ++i) {
+    const int idx = score_index_vec[i].second;
+    bool keep = true;
+    for (int k = 0; k < (int)indices.size() && keep; ++k) {
+      const int kept_idx = indices[k];
+      float overlap = RectOverlap(det_result[idx], det_result[kept_idx]);
+      keep = overlap <= nms_threshold;
+    }
+    if (keep)
+      indices.push_back(idx);
+  }
+}
+
 } // namespace PPShiTu
diff --git a/deploy/lite_shitu/src/vector_search.cc b/deploy/lite_shitu/src/vector_search.cc
index ea848959b651eb04effc25ad9efb7eb497ef2025..272c0855e56fb2ef3a905081975695b2086434c0 100644
--- a/deploy/lite_shitu/src/vector_search.cc
+++ b/deploy/lite_shitu/src/vector_search.cc
@@ -64,4 +64,4 @@ const SearchResult &VectorSearch::Search(float *feature, int query_number) {
 const std::string &VectorSearch::GetLabel(faiss::Index::idx_t ind) {
   return this->id_map.at(ind);
 }
-}
\ No newline at end of file
+}
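`RectOverlap` uses the inclusive-pixel convention (the `+1` terms), so a box spanning columns 0 through 9 is ten pixels wide. A small check of the IoU arithmetic, compiled against the project header:

```cpp
#include <cassert>

#include "include/utils.h" // RectOverlap from this PR

int main() {
  PPShiTu::ObjectResult a, b;
  a.rect = {0, 0, 9, 9};  // 10x10 box under the inclusive (+1) convention
  b.rect = {5, 0, 14, 9}; // 10x10 box shifted right by 5
  // Intersection: 5x10 = 50. Union: 100 + 100 - 50 = 150. IoU = 1/3.
  float iou = PPShiTu::RectOverlap(a, b);
  assert(iou > 0.333f && iou < 0.334f);
  return 0;
}
```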