提交 fa87707d 编写于 作者: D dongshuilong

fix lite_shitu bugs

上级 3a28ee29
......@@ -92,9 +92,9 @@ PaddleClas 提供了转换并优化后的推理模型,可以直接参考下方
```shell
# 进入lite_ppshitu目录
cd $PaddleClas/deploy/lite_shitu
wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/lite/ppshitu_lite_models_v1.1.tar
tar -xf ppshitu_lite_models_v1.1.tar
rm -f ppshitu_lite_models_v1.1.tar
wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/lite/ppshitu_lite_models_v1.2.tar
tar -xf ppshitu_lite_models_v1.2.tar
rm -f ppshitu_lite_models_v1.2.tar
```
#### 2.1.2 使用其他模型
......@@ -162,7 +162,7 @@ git clone https://github.com/PaddlePaddle/PaddleDetection.git
# 进入PaddleDetection根目录
cd PaddleDetection
# 将预训练模型导出为inference模型
python tools/export_model.py -c configs/picodet/application/mainbody_detection/picodet_lcnet_x2_5_640_mainbody.yml -o weights=https://paddledet.bj.bcebos.com/models/picodet_lcnet_x2_5_640_mainbody.pdparams --output_dir=inference
python tools/export_model.py -c configs/picodet/application/mainbody_detection/picodet_lcnet_x2_5_640_mainbody.yml -o weights=https://paddledet.bj.bcebos.com/models/picodet_lcnet_x2_5_640_mainbody.pdparams export_post_process=False --output_dir=inference
# 将inference模型转化为Paddle-Lite优化模型
paddle_lite_opt --model_file=inference/picodet_lcnet_x2_5_640_mainbody/model.pdmodel --param_file=inference/picodet_lcnet_x2_5_640_mainbody/model.pdiparams --optimize_out=inference/picodet_lcnet_x2_5_640_mainbody/mainbody_det
# 将转好的模型复制到lite_shitu目录下
......@@ -183,24 +183,56 @@ cd deploy/lite_shitu
**注意**`--optimize_out` 参数为优化后模型的保存路径,无需加后缀`.nb``--model_file` 参数为模型结构信息文件的路径,`--param_file` 参数为模型权重信息文件的路径,请注意文件名。
### 2.2 将yaml文件转换成json文件
### 2.2 生成新的检索库
由于lite 版本的检索库用的是`faiss1.5.3`版本,与新版本不兼容,因此需要重新生成index库
#### 2.2.1 数据及环境配置
```shell
# 进入上级目录
cd ..
# 下载瓶装饮料数据集
wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/data/drink_dataset_v1.0.tar && tar -xf drink_dataset_v1.0.tar
rm -rf drink_dataset_v1.0.tar
rm -rf drink_dataset_v1.0/index
# 安装1.5.3版本的faiss
pip install faiss-cpu==1.5.3
# 下载通用识别模型,可替换成自己的inference model
wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/general_PPLCNet_x2_5_lite_v1.0_infer.tar
tar -xf general_PPLCNet_x2_5_lite_v1.0_infer.tar
rm -rf general_PPLCNet_x2_5_lite_v1.0_infer.tar
```
#### 2.2.2 生成新的index文件
```shell
# 生成新的index库,注意指定好识别模型的路径,同时将index_mothod修改成Flat,HNSW32和IVF在此版本中可能存在bug,请慎重使用。
# 如果使用自己的识别模型,对应的修改inference model的目录
python python/build_gallery.py -c configs/inference_drink.yaml -o Global.rec_inference_model_dir=general_PPLCNet_x2_5_lite_v1.0_infer -o IndexProcess.index_method=Flat
# 进入到lite_shitu目录
cd lite_shitu
mv ../drink_dataset_v1.0 .
```
### 2.3 将yaml文件转换成json文件
```shell
# 如果测试单张图像
python generate_json_config.py --det_model_path ppshitu_lite_models_v1.1/mainbody_PPLCNet_x2_5_640_quant_v1.1_lite.nb --rec_model_path ppshitu_lite_models_v1.1/general_PPLCNet_x2_5_lite_v1.1_infer.nb --img_path images/demo.jpg
python generate_json_config.py --det_model_path ppshitu_lite_models_v1.2/mainbody_PPLCNet_x2_5_640_v1.2_lite.nb --rec_model_path ppshitu_lite_models_v1.2/general_PPLCNet_x2_5_lite_v1.2_infer.nb --img_path images/demo.jpeg
# or
# 如果测试多张图像
python generate_json_config.py --det_model_path ppshitu_lite_models_v1.1/mainbody_PPLCNet_x2_5_640_quant_v1.1_lite.nb --rec_model_path ppshitu_lite_models_v1.1/general_PPLCNet_x2_5_lite_v1.1_infer.nb --img_dir images
python generate_json_config.py --det_model_path ppshitu_lite_models_v1.2/mainbody_PPLCNet_x2_5_640_v1.2_lite.nb --rec_model_path ppshitu_lite_models_v1.2/general_PPLCNet_x2_5_lite_v1.2_infer.nb --img_dir images
# 执行完成后,会在lit_shitu下生成shitu_config.json配置文件
```
### 2.3 index字典转换
### 2.4 index字典转换
由于python的检索库字典,使用`pickle`进行的序列化存储,导致C++不方便读取,因此需要进行转换
```shell
# 下载瓶装饮料数据集
wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/data/drink_dataset_v1.0.tar && tar -xf drink_dataset_v1.0.tar
rm -rf drink_dataset_v1.0.tar
# 转化id_map.pkl为id_map.txt
python transform_id_map.py -c ../configs/inference_drink.yaml
......@@ -208,7 +240,7 @@ python transform_id_map.py -c ../configs/inference_drink.yaml
转换成功后,会在`IndexProcess.index_dir`目录下生成`id_map.txt`
### 2.4 与手机联调
### 2.5 与手机联调
首先需要进行一些准备工作。
1. 准备一台arm8的安卓手机,如果编译的预测库是armv7,则需要arm7的手机,并修改Makefile中`ARM_ABI=arm7`
......@@ -308,8 +340,9 @@ chmod 777 pp_shitu
运行效果如下:
```
images/demo.jpg:
result0: bbox[253, 275, 1146, 872], score: 0.974196, label: 伊藤园_果蔬汁
images/demo.jpeg:
result0: bbox[344, 98, 527, 593], score: 0.811656, label: 红牛-强化型
result1: bbox[0, 0, 600, 600], score: 0.729664, label: 红牛-强化型
```
## FAQ
......
......@@ -24,6 +24,7 @@
#include <stdlib.h>
#include <sys/time.h>
#include <vector>
#include <include/preprocess_op.h>
using namespace paddle::lite_api; // NOLINT
using namespace std;
......@@ -48,10 +49,6 @@ public:
config_file["Global"]["rec_model_path"].as<std::string>());
this->predictor = CreatePaddlePredictor<MobileConfig>(config);
if (config_file["Global"]["rec_label_path"].as<std::string>().empty()) {
std::cout << "Please set [rec_label_path] in config file" << std::endl;
exit(-1);
}
SetPreProcessParam(config_file["RecPreProcess"]["transform_ops"]);
printf("feature extract model create!\n");
}
......@@ -68,7 +65,7 @@ public:
this->mean.emplace_back(tmp.as<float>());
}
for (auto tmp : item["std"]) {
this->std.emplace_back(1 / tmp.as<float>());
this->std.emplace_back(tmp.as<float>());
}
this->scale = item["scale"].as<double>();
}
......@@ -77,15 +74,19 @@ public:
void RunRecModel(const cv::Mat &img, double &cost_time, std::vector<float> &feature);
//void PostProcess(std::vector<float> &feature);
cv::Mat ResizeImage(const cv::Mat &img);
void NeonMeanScale(const float *din, float *dout, int size);
void FeatureNorm(std::vector<float> &featuer);
private:
std::shared_ptr<PaddlePredictor> predictor;
//std::vector<std::string> label_list;
std::vector<float> mean = {0.485f, 0.456f, 0.406f};
std::vector<float> std = {1 / 0.229f, 1 / 0.224f, 1 / 0.225f};
std::vector<float> std = {0.229f, 0.224f, 0.225f};
double scale = 0.00392157;
float size = 224;
int size = 224;
// pre-process
Resize resize_op_;
NormalizeImage normalize_op_;
Permute permute_op_;
};
} // namespace PPShiTu
......@@ -71,6 +71,8 @@ class NormalizeImage : public PreprocessOp {
}
virtual void Run(cv::Mat* im, ImageBlob* data);
void Run_feature(cv::Mat *im, const std::vector<float> &mean,
const std::vector<float> &std, float scale);
private:
// CHW or HWC
......@@ -83,6 +85,7 @@ class Permute : public PreprocessOp {
public:
virtual void Init(const Json::Value& item) {}
virtual void Run(cv::Mat* im, ImageBlob* data);
void Run_feature(const cv::Mat *im, float *data);
};
class Resize : public PreprocessOp {
......@@ -101,6 +104,7 @@ class Resize : public PreprocessOp {
std::pair<float, float> GenerateScale(const cv::Mat& im);
virtual void Run(cv::Mat* im, ImageBlob* data);
void Run_feature(const cv::Mat &img, cv::Mat &resize_img, int max_size_len, int size=0);
private:
int interp_;
......
......@@ -38,6 +38,24 @@ struct ObjectResult {
std::vector<RESULT> rec_result;
};
void nms(std::vector<ObjectResult> &input_boxes, float nms_threshold, bool rec_nms=false);
void nms(std::vector<ObjectResult> &input_boxes, float nms_threshold,
bool rec_nms = false);
template <typename T>
static inline bool SortScorePairDescend(const std::pair<float, T> &pair1,
const std::pair<float, T> &pair2){
return pair1.first > pair2.first;
}
float RectOverlap(const ObjectResult &a,
const ObjectResult &b);
inline void
GetMaxScoreIndex(const std::vector<ObjectResult> &det_result,
const float threshold,
std::vector<std::pair<float, int>> &score_index_vec);
void NMSBoxes(const std::vector<ObjectResult> det_result,
const float score_threshold, const float nms_threshold,
std::vector<int> &indices);
} // namespace PPShiTu
......@@ -13,24 +13,30 @@
// limitations under the License.
#include "include/feature_extractor.h"
#include <cmath>
#include <numeric>
namespace PPShiTu {
void FeatureExtract::RunRecModel(const cv::Mat &img,
double &cost_time,
std::vector<float> &feature) {
// Read img
cv::Mat resize_image = ResizeImage(img);
cv::Mat img_fp;
resize_image.convertTo(img_fp, CV_32FC3, scale);
this->resize_op_.Run_feature(img, img_fp, this->size, this->size);
this->normalize_op_.Run_feature(&img_fp, this->mean, this->std, this->scale);
std::vector<float> input(1 * 3 * img_fp.rows * img_fp.cols, 0.0f);
this->permute_op_.Run_feature(&img_fp, input.data());
// Prepare input data from image
std::unique_ptr<Tensor> input_tensor(std::move(this->predictor->GetInput(0)));
input_tensor->Resize({1, 3, img_fp.rows, img_fp.cols});
input_tensor->Resize({1, 3, this->size, this->size});
auto *data0 = input_tensor->mutable_data<float>();
const float *dimg = reinterpret_cast<const float *>(img_fp.data);
NeonMeanScale(dimg, data0, img_fp.rows * img_fp.cols);
// const float *dimg = reinterpret_cast<const float *>(img_fp.data);
// NeonMeanScale(dimg, data0, img_fp.rows * img_fp.cols);
for(int i=0; i < input.size(); ++i){
data0[i] = input[i];
}
auto start = std::chrono::system_clock::now();
// Run predictor
......@@ -55,62 +61,14 @@ void FeatureExtract::RunRecModel(const cv::Mat &img,
output_tensor->CopyToCpu(feature.data());
//postprocess include sqrt or binarize.
//PostProcess(feature);
FeatureNorm(feature);
return;
}
// void FeatureExtract::PostProcess(std::vector<float> &feature){
// float feature_sqrt = std::sqrt(std::inner_product(
// feature.begin(), feature.end(), feature.begin(), 0.0f));
// for (int i = 0; i < feature.size(); ++i)
// feature[i] /= feature_sqrt;
// }
void FeatureExtract::NeonMeanScale(const float *din, float *dout, int size) {
if (this->mean.size() != 3 || this->std.size() != 3) {
std::cerr << "[ERROR] mean or scale size must equal to 3\n";
exit(1);
}
float32x4_t vmean0 = vdupq_n_f32(mean[0]);
float32x4_t vmean1 = vdupq_n_f32(mean[1]);
float32x4_t vmean2 = vdupq_n_f32(mean[2]);
float32x4_t vscale0 = vdupq_n_f32(std[0]);
float32x4_t vscale1 = vdupq_n_f32(std[1]);
float32x4_t vscale2 = vdupq_n_f32(std[2]);
float *dout_c0 = dout;
float *dout_c1 = dout + size;
float *dout_c2 = dout + size * 2;
int i = 0;
for (; i < size - 3; i += 4) {
float32x4x3_t vin3 = vld3q_f32(din);
float32x4_t vsub0 = vsubq_f32(vin3.val[0], vmean0);
float32x4_t vsub1 = vsubq_f32(vin3.val[1], vmean1);
float32x4_t vsub2 = vsubq_f32(vin3.val[2], vmean2);
float32x4_t vs0 = vmulq_f32(vsub0, vscale0);
float32x4_t vs1 = vmulq_f32(vsub1, vscale1);
float32x4_t vs2 = vmulq_f32(vsub2, vscale2);
vst1q_f32(dout_c0, vs0);
vst1q_f32(dout_c1, vs1);
vst1q_f32(dout_c2, vs2);
din += 12;
dout_c0 += 4;
dout_c1 += 4;
dout_c2 += 4;
}
for (; i < size; i++) {
*(dout_c0++) = (*(din++) - this->mean[0]) * this->std[0];
*(dout_c1++) = (*(din++) - this->mean[1]) * this->std[1];
*(dout_c2++) = (*(din++) - this->mean[2]) * this->std[2];
}
}
cv::Mat FeatureExtract::ResizeImage(const cv::Mat &img) {
cv::Mat resize_img;
cv::resize(img, resize_img, cv::Size(this->size, this->size));
return resize_img;
void FeatureExtract::FeatureNorm(std::vector<float> &feature){
float feature_sqrt = std::sqrt(std::inner_product(
feature.begin(), feature.end(), feature.begin(), 0.0f));
for (int i = 0; i < feature.size(); ++i)
feature[i] /= feature_sqrt;
}
}
......@@ -28,6 +28,7 @@
#include "include/object_detector.h"
#include "include/preprocess_op.h"
#include "include/vector_search.h"
#include "include/utils.h"
#include "json/json.h"
Json::Value RT_Config;
......@@ -158,6 +159,11 @@ int main(int argc, char **argv) {
<< " [image_dir]>" << std::endl;
return -1;
}
float rec_nms_threshold = 0.05;
if (RT_Config["Global"]["rec_nms_thresold"].isDouble())
rec_nms_threshold = RT_Config["Global"]["rec_nms_thresold"].as<float>();
// Load model and create a object detector
PPShiTu::ObjectDetector det(
RT_Config, RT_Config["Global"]["det_model_path"].as<std::string>(),
......@@ -174,6 +180,7 @@ int main(int argc, char **argv) {
// for vector search
std::vector<float> feature;
std::vector<float> features;
std::vector<int> indeices;
double rec_time;
if (!RT_Config["Global"]["infer_imgs"].as<std::string>().empty() ||
!img_dir.empty()) {
......@@ -208,9 +215,9 @@ int main(int argc, char **argv) {
RT_Config["Global"]["max_det_results"].as<int>(), false, &det);
// add the whole image for recognition to improve recall
// PPShiTu::ObjectResult result_whole_img = {
// {0, 0, srcimg.cols, srcimg.rows}, 0, 1.0};
// det_result.push_back(result_whole_img);
PPShiTu::ObjectResult result_whole_img = {
{0, 0, srcimg.cols, srcimg.rows}, 0, 1.0};
det_result.push_back(result_whole_img);
// get rec result
PPShiTu::SearchResult search_result;
......@@ -225,10 +232,18 @@ int main(int argc, char **argv) {
// do vectore search
search_result = searcher.Search(features.data(), det_result.size());
for (int i = 0; i < det_result.size(); ++i) {
det_result[i].confidence = search_result.D[search_result.return_k * i];
}
NMSBoxes(det_result, searcher.GetThreshold(), rec_nms_threshold, indeices);
PrintResult(img_path, det_result, searcher, search_result);
batch_imgs.clear();
det_result.clear();
features.clear();
feature.clear();
indeices.clear();
}
}
return 0;
......
......@@ -20,7 +20,7 @@
namespace PPShiTu {
void InitInfo::Run(cv::Mat* im, ImageBlob* data) {
void InitInfo::Run(cv::Mat *im, ImageBlob *data) {
data->im_shape_ = {static_cast<float>(im->rows),
static_cast<float>(im->cols)};
data->scale_factor_ = {1., 1.};
......@@ -28,10 +28,10 @@ void InitInfo::Run(cv::Mat* im, ImageBlob* data) {
static_cast<float>(im->cols)};
}
void NormalizeImage::Run(cv::Mat* im, ImageBlob* data) {
void NormalizeImage::Run(cv::Mat *im, ImageBlob *data) {
double e = 1.0;
if (is_scale_) {
e *= 1./255.0;
e *= 1. / 255.0;
}
(*im).convertTo(*im, CV_32FC3, e);
for (int h = 0; h < im->rows; h++) {
......@@ -46,35 +46,61 @@ void NormalizeImage::Run(cv::Mat* im, ImageBlob* data) {
}
}
void Permute::Run(cv::Mat* im, ImageBlob* data) {
void NormalizeImage::Run_feature(cv::Mat *im, const std::vector<float> &mean,
const std::vector<float> &std, float scale) {
(*im).convertTo(*im, CV_32FC3, scale);
for (int h = 0; h < im->rows; h++) {
for (int w = 0; w < im->cols; w++) {
im->at<cv::Vec3f>(h, w)[0] =
(im->at<cv::Vec3f>(h, w)[0] - mean[0]) / std[0];
im->at<cv::Vec3f>(h, w)[1] =
(im->at<cv::Vec3f>(h, w)[1] - mean[1]) / std[1];
im->at<cv::Vec3f>(h, w)[2] =
(im->at<cv::Vec3f>(h, w)[2] - mean[2]) / std[2];
}
}
}
void Permute::Run(cv::Mat *im, ImageBlob *data) {
(*im).convertTo(*im, CV_32FC3);
int rh = im->rows;
int rw = im->cols;
int rc = im->channels();
(data->im_data_).resize(rc * rh * rw);
float* base = (data->im_data_).data();
float *base = (data->im_data_).data();
for (int i = 0; i < rc; ++i) {
cv::extractChannel(*im, cv::Mat(rh, rw, CV_32FC1, base + i * rh * rw), i);
}
}
void Resize::Run(cv::Mat* im, ImageBlob* data) {
void Permute::Run_feature(const cv::Mat *im, float *data) {
int rh = im->rows;
int rw = im->cols;
int rc = im->channels();
for (int i = 0; i < rc; ++i) {
cv::extractChannel(*im, cv::Mat(rh, rw, CV_32FC1, data + i * rh * rw), i);
}
}
void Resize::Run(cv::Mat *im, ImageBlob *data) {
auto resize_scale = GenerateScale(*im);
data->im_shape_ = {static_cast<float>(im->cols * resize_scale.first),
static_cast<float>(im->rows * resize_scale.second)};
data->in_net_shape_ = {static_cast<float>(im->cols * resize_scale.first),
static_cast<float>(im->rows * resize_scale.second)};
cv::resize(
*im, *im, cv::Size(), resize_scale.first, resize_scale.second, interp_);
cv::resize(*im, *im, cv::Size(), resize_scale.first, resize_scale.second,
interp_);
data->im_shape_ = {
static_cast<float>(im->rows), static_cast<float>(im->cols),
static_cast<float>(im->rows),
static_cast<float>(im->cols),
};
data->scale_factor_ = {
resize_scale.second, resize_scale.first,
resize_scale.second,
resize_scale.first,
};
}
std::pair<float, float> Resize::GenerateScale(const cv::Mat& im) {
std::pair<float, float> Resize::GenerateScale(const cv::Mat &im) {
std::pair<float, float> resize_scale;
int origin_w = im.cols;
int origin_h = im.rows;
......@@ -101,7 +127,30 @@ std::pair<float, float> Resize::GenerateScale(const cv::Mat& im) {
return resize_scale;
}
void PadStride::Run(cv::Mat* im, ImageBlob* data) {
void Resize::Run_feature(const cv::Mat &img, cv::Mat &resize_img, int resize_short_size,
int size) {
int resize_h = 0;
int resize_w = 0;
if (size > 0) {
resize_h = size;
resize_w = size;
} else {
int w = img.cols;
int h = img.rows;
float ratio = 1.f;
if (h < w) {
ratio = float(resize_short_size) / float(h);
} else {
ratio = float(resize_short_size) / float(w);
}
resize_h = round(float(h) * ratio);
resize_w = round(float(w) * ratio);
}
cv::resize(img, resize_img, cv::Size(resize_w, resize_h));
}
void PadStride::Run(cv::Mat *im, ImageBlob *data) {
if (stride_ <= 0) {
return;
}
......@@ -110,42 +159,38 @@ void PadStride::Run(cv::Mat* im, ImageBlob* data) {
int rw = im->cols;
int nh = (rh / stride_) * stride_ + (rh % stride_ != 0) * stride_;
int nw = (rw / stride_) * stride_ + (rw % stride_ != 0) * stride_;
cv::copyMakeBorder(
*im, *im, 0, nh - rh, 0, nw - rw, cv::BORDER_CONSTANT, cv::Scalar(0));
cv::copyMakeBorder(*im, *im, 0, nh - rh, 0, nw - rw, cv::BORDER_CONSTANT,
cv::Scalar(0));
data->in_net_shape_ = {
static_cast<float>(im->rows), static_cast<float>(im->cols),
static_cast<float>(im->rows),
static_cast<float>(im->cols),
};
}
void TopDownEvalAffine::Run(cv::Mat* im, ImageBlob* data) {
void TopDownEvalAffine::Run(cv::Mat *im, ImageBlob *data) {
cv::resize(*im, *im, cv::Size(trainsize_[0], trainsize_[1]), 0, 0, interp_);
// todo: Simd::ResizeBilinear();
data->in_net_shape_ = {
static_cast<float>(trainsize_[1]), static_cast<float>(trainsize_[0]),
static_cast<float>(trainsize_[1]),
static_cast<float>(trainsize_[0]),
};
}
// Preprocessor op running order
const std::vector<std::string> Preprocessor::RUN_ORDER = {"InitInfo",
"DetTopDownEvalAffine",
"DetResize",
"DetNormalizeImage",
"DetPadStride",
"DetPermute"};
void Preprocessor::Run(cv::Mat* im, ImageBlob* data) {
for (const auto& name : RUN_ORDER) {
const std::vector<std::string> Preprocessor::RUN_ORDER = {
"InitInfo", "DetTopDownEvalAffine", "DetResize",
"DetNormalizeImage", "DetPadStride", "DetPermute"};
void Preprocessor::Run(cv::Mat *im, ImageBlob *data) {
for (const auto &name : RUN_ORDER) {
if (ops_.find(name) != ops_.end()) {
ops_[name]->Run(im, data);
}
}
}
void CropImg(cv::Mat& img,
cv::Mat& crop_img,
std::vector<int>& area,
std::vector<float>& center,
std::vector<float>& scale,
void CropImg(cv::Mat &img, cv::Mat &crop_img, std::vector<int> &area,
std::vector<float> &center, std::vector<float> &scale,
float expandratio) {
int crop_x1 = std::max(0, area[0]);
int crop_y1 = std::max(0, area[1]);
......
......@@ -54,4 +54,55 @@ void nms(std::vector<ObjectResult> &input_boxes, float nms_threshold,
}
}
float RectOverlap(const ObjectResult &a,
const ObjectResult &b) {
float Aa = (a.rect[2] - a.rect[0] + 1) * (a.rect[3] - a.rect[1] + 1);
float Ab = (b.rect[2] - b.rect[0] + 1) * (b.rect[3] - b.rect[1] + 1);
int iou_w = max(min(a.rect[2], b.rect[2]) - max(a.rect[0], b.rect[0]) + 1, 0);
int iou_h = max(min(a.rect[3], b.rect[3]) - max(a.rect[1], b.rect[1]) + 1, 0);
float Aab = iou_w * iou_h;
return Aab / (Aa + Ab - Aab);
}
inline void
GetMaxScoreIndex(const std::vector<ObjectResult> &det_result,
const float threshold,
std::vector<std::pair<float, int>> &score_index_vec) {
// Generate index score pairs.
for (size_t i = 0; i < det_result.size(); ++i) {
if (det_result[i].confidence > threshold) {
score_index_vec.push_back(std::make_pair(det_result[i].confidence, i));
}
}
// Sort the score pair according to the scores in descending order
std::stable_sort(score_index_vec.begin(), score_index_vec.end(),
SortScorePairDescend<int>);
}
void NMSBoxes(const std::vector<ObjectResult> det_result,
const float score_threshold, const float nms_threshold,
std::vector<int> &indices) {
int a = 1;
// Get top_k scores (with corresponding indices).
std::vector<std::pair<float, int>> score_index_vec;
GetMaxScoreIndex(det_result, score_threshold, score_index_vec);
// Do nms
indices.clear();
for (size_t i = 0; i < score_index_vec.size(); ++i) {
const int idx = score_index_vec[i].second;
bool keep = true;
for (int k = 0; k < (int)indices.size() && keep; ++k) {
const int kept_idx = indices[k];
float overlap = RectOverlap(det_result[idx], det_result[kept_idx]);
keep = overlap <= nms_threshold;
}
if (keep)
indices.push_back(idx);
}
}
} // namespace PPShiTu
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册