提交 86e49817 编写于 作者: M MissPenguin

add batch infer for cpp rec

上级 529133fb
...@@ -44,7 +44,8 @@ public: ...@@ -44,7 +44,8 @@ public:
const int &gpu_id, const int &gpu_mem, const int &gpu_id, const int &gpu_mem,
const int &cpu_math_library_num_threads, const int &cpu_math_library_num_threads,
const bool &use_mkldnn, const string &label_path, const bool &use_mkldnn, const string &label_path,
const bool &use_tensorrt, const std::string &precision) { const bool &use_tensorrt, const std::string &precision,
const int &rec_batch_num) {
this->use_gpu_ = use_gpu; this->use_gpu_ = use_gpu;
this->gpu_id_ = gpu_id; this->gpu_id_ = gpu_id;
this->gpu_mem_ = gpu_mem; this->gpu_mem_ = gpu_mem;
...@@ -52,6 +53,7 @@ public: ...@@ -52,6 +53,7 @@ public:
this->use_mkldnn_ = use_mkldnn; this->use_mkldnn_ = use_mkldnn;
this->use_tensorrt_ = use_tensorrt; this->use_tensorrt_ = use_tensorrt;
this->precision_ = precision; this->precision_ = precision;
this->rec_batch_num_ = rec_batch_num;
this->label_list_ = Utility::ReadDict(label_path); this->label_list_ = Utility::ReadDict(label_path);
this->label_list_.insert(this->label_list_.begin(), this->label_list_.insert(this->label_list_.begin(),
...@@ -64,7 +66,8 @@ public: ...@@ -64,7 +66,8 @@ public:
// Load Paddle inference model // Load Paddle inference model
void LoadModel(const std::string &model_dir); void LoadModel(const std::string &model_dir);
void Run(cv::Mat &img, std::vector<double> *times); // void Run(cv::Mat &img, std::vector<double> *times);
void Run(std::vector<cv::Mat> img_list, std::vector<double> *times);
private: private:
std::shared_ptr<Predictor> predictor_; std::shared_ptr<Predictor> predictor_;
...@@ -82,10 +85,12 @@ private: ...@@ -82,10 +85,12 @@ private:
bool is_scale_ = true; bool is_scale_ = true;
bool use_tensorrt_ = false; bool use_tensorrt_ = false;
std::string precision_ = "fp32"; std::string precision_ = "fp32";
int rec_batch_num_ = 6;
// pre-process // pre-process
CrnnResizeImg resize_op_; CrnnResizeImg resize_op_;
Normalize normalize_op_; Normalize normalize_op_;
Permute permute_op_; Permute_batch permute_op_;
// post-process // post-process
PostProcessor post_processor_; PostProcessor post_processor_;
......
...@@ -44,6 +44,11 @@ public: ...@@ -44,6 +44,11 @@ public:
virtual void Run(const cv::Mat *im, float *data); virtual void Run(const cv::Mat *im, float *data);
}; };
class Permute_batch {
public:
virtual void Run(const std::vector<cv::Mat> imgs, float *data);
};
class ResizeImgType0 { class ResizeImgType0 {
public: public:
virtual void Run(const cv::Mat &img, cv::Mat &resize_img, int max_size_len, virtual void Run(const cv::Mat &img, cv::Mat &resize_img, int max_size_len,
......
...@@ -50,6 +50,10 @@ public: ...@@ -50,6 +50,10 @@ public:
static cv::Mat GetRotateCropImage(const cv::Mat &srcimage, static cv::Mat GetRotateCropImage(const cv::Mat &srcimage,
std::vector<std::vector<int>> box); std::vector<std::vector<int>> box);
// 实现argsort功能
static std::vector<int> argsort(const std::vector<float>& array);
}; };
} // namespace PaddleOCR } // namespace PaddleOCR
\ No newline at end of file
...@@ -61,7 +61,7 @@ DEFINE_string(cls_model_dir, "", "Path of cls inference model."); ...@@ -61,7 +61,7 @@ DEFINE_string(cls_model_dir, "", "Path of cls inference model.");
DEFINE_double(cls_thresh, 0.9, "Threshold of cls_thresh."); DEFINE_double(cls_thresh, 0.9, "Threshold of cls_thresh.");
// recognition related // recognition related
DEFINE_string(rec_model_dir, "", "Path of rec inference model."); DEFINE_string(rec_model_dir, "", "Path of rec inference model.");
DEFINE_int32(rec_batch_num, 1, "rec_batch_num."); DEFINE_int32(rec_batch_num, 6, "rec_batch_num.");
DEFINE_string(char_list_file, "../../ppocr/utils/ppocr_keys_v1.txt", "Path of dictionary."); DEFINE_string(char_list_file, "../../ppocr/utils/ppocr_keys_v1.txt", "Path of dictionary.");
...@@ -146,8 +146,9 @@ int main_rec(std::vector<cv::String> cv_all_img_names) { ...@@ -146,8 +146,9 @@ int main_rec(std::vector<cv::String> cv_all_img_names) {
CRNNRecognizer rec(FLAGS_rec_model_dir, FLAGS_use_gpu, FLAGS_gpu_id, CRNNRecognizer rec(FLAGS_rec_model_dir, FLAGS_use_gpu, FLAGS_gpu_id,
FLAGS_gpu_mem, FLAGS_cpu_threads, FLAGS_gpu_mem, FLAGS_cpu_threads,
FLAGS_enable_mkldnn, char_list_file, FLAGS_enable_mkldnn, char_list_file,
FLAGS_use_tensorrt, FLAGS_precision); FLAGS_use_tensorrt, FLAGS_precision, FLAGS_rec_batch_num);
std::vector<cv::Mat> img_list;
for (int i = 0; i < cv_all_img_names.size(); ++i) { for (int i = 0; i < cv_all_img_names.size(); ++i) {
LOG(INFO) << "The predict img: " << cv_all_img_names[i]; LOG(INFO) << "The predict img: " << cv_all_img_names[i];
...@@ -156,22 +157,21 @@ int main_rec(std::vector<cv::String> cv_all_img_names) { ...@@ -156,22 +157,21 @@ int main_rec(std::vector<cv::String> cv_all_img_names) {
std::cerr << "[ERROR] image read failed! image path: " << cv_all_img_names[i] << endl; std::cerr << "[ERROR] image read failed! image path: " << cv_all_img_names[i] << endl;
exit(1); exit(1);
} }
img_list.push_back(srcimg);
std::vector<double> rec_times;
rec.Run(srcimg, &rec_times);
time_info[0] += rec_times[0];
time_info[1] += rec_times[1];
time_info[2] += rec_times[2];
} }
std::vector<double> rec_times;
rec.Run(img_list, &rec_times);
time_info[0] += rec_times[0];
time_info[1] += rec_times[1];
time_info[2] += rec_times[2];
if (FLAGS_benchmark) { if (FLAGS_benchmark) {
AutoLogger autolog("ocr_rec", AutoLogger autolog("ocr_rec",
FLAGS_use_gpu, FLAGS_use_gpu,
FLAGS_use_tensorrt, FLAGS_use_tensorrt,
FLAGS_enable_mkldnn, FLAGS_enable_mkldnn,
FLAGS_cpu_threads, FLAGS_cpu_threads,
1, FLAGS_rec_batch_num,
"dynamic", "dynamic",
FLAGS_precision, FLAGS_precision,
time_info, time_info,
...@@ -209,7 +209,7 @@ int main_system(std::vector<cv::String> cv_all_img_names) { ...@@ -209,7 +209,7 @@ int main_system(std::vector<cv::String> cv_all_img_names) {
CRNNRecognizer rec(FLAGS_rec_model_dir, FLAGS_use_gpu, FLAGS_gpu_id, CRNNRecognizer rec(FLAGS_rec_model_dir, FLAGS_use_gpu, FLAGS_gpu_id,
FLAGS_gpu_mem, FLAGS_cpu_threads, FLAGS_gpu_mem, FLAGS_cpu_threads,
FLAGS_enable_mkldnn, char_list_file, FLAGS_enable_mkldnn, char_list_file,
FLAGS_use_tensorrt, FLAGS_precision); FLAGS_use_tensorrt, FLAGS_precision, FLAGS_rec_batch_num);
for (int i = 0; i < cv_all_img_names.size(); ++i) { for (int i = 0; i < cv_all_img_names.size(); ++i) {
LOG(INFO) << "The predict img: " << cv_all_img_names[i]; LOG(INFO) << "The predict img: " << cv_all_img_names[i];
...@@ -228,19 +228,22 @@ int main_system(std::vector<cv::String> cv_all_img_names) { ...@@ -228,19 +228,22 @@ int main_system(std::vector<cv::String> cv_all_img_names) {
time_info_det[1] += det_times[1]; time_info_det[1] += det_times[1];
time_info_det[2] += det_times[2]; time_info_det[2] += det_times[2];
cv::Mat crop_img; std::vector<cv::Mat> img_list;
for (int j = 0; j < boxes.size(); j++) { for (int j = 0; j < boxes.size(); j++) {
crop_img = Utility::GetRotateCropImage(srcimg, boxes[j]); cv::Mat crop_img;
crop_img = Utility::GetRotateCropImage(srcimg, boxes[j]);
if (cls != nullptr) { if (cls != nullptr) {
crop_img = cls->Run(crop_img); crop_img = cls->Run(crop_img);
} }
rec.Run(crop_img, &rec_times); img_list.push_back(crop_img);
time_info_rec[0] += rec_times[0];
time_info_rec[1] += rec_times[1];
time_info_rec[2] += rec_times[2];
} }
rec.Run(img_list, &rec_times);
time_info_rec[0] += rec_times[0];
time_info_rec[1] += rec_times[1];
time_info_rec[2] += rec_times[2];
} }
if (FLAGS_benchmark) { if (FLAGS_benchmark) {
AutoLogger autolog_det("ocr_det", AutoLogger autolog_det("ocr_det",
FLAGS_use_gpu, FLAGS_use_gpu,
...@@ -257,7 +260,7 @@ int main_system(std::vector<cv::String> cv_all_img_names) { ...@@ -257,7 +260,7 @@ int main_system(std::vector<cv::String> cv_all_img_names) {
FLAGS_use_tensorrt, FLAGS_use_tensorrt,
FLAGS_enable_mkldnn, FLAGS_enable_mkldnn,
FLAGS_cpu_threads, FLAGS_cpu_threads,
1, FLAGS_rec_batch_num,
"dynamic", "dynamic",
FLAGS_precision, FLAGS_precision,
time_info_rec, time_info_rec,
......
...@@ -15,83 +15,109 @@ ...@@ -15,83 +15,109 @@
#include <include/ocr_rec.h> #include <include/ocr_rec.h>
namespace PaddleOCR { namespace PaddleOCR {
void CRNNRecognizer::Run(cv::Mat &img, std::vector<double> *times) { void CRNNRecognizer::Run(std::vector<cv::Mat> img_list, std::vector<double> *times) {
cv::Mat srcimg; std::chrono::duration<float> preprocess_diff = std::chrono::steady_clock::now() - std::chrono::steady_clock::now();
img.copyTo(srcimg); std::chrono::duration<float> inference_diff = std::chrono::steady_clock::now() - std::chrono::steady_clock::now();
cv::Mat resize_img; std::chrono::duration<float> postprocess_diff = std::chrono::steady_clock::now() - std::chrono::steady_clock::now();
float wh_ratio = float(srcimg.cols) / float(srcimg.rows); int img_num = img_list.size();
auto preprocess_start = std::chrono::steady_clock::now(); std::vector<float> width_list;
this->resize_op_.Run(srcimg, resize_img, wh_ratio, this->use_tensorrt_); for (int i = 0; i < img_num; i++) {
width_list.push_back(float(img_list[i].cols) / img_list[i].rows);
this->normalize_op_.Run(&resize_img, this->mean_, this->scale_, }
this->is_scale_); std::vector<int> indices = Utility::argsort(width_list);
std::vector<float> input(1 * 3 * resize_img.rows * resize_img.cols, 0.0f); for (int beg_img_no = 0; beg_img_no < img_num; beg_img_no += this->rec_batch_num_) {
auto preprocess_start = std::chrono::steady_clock::now();
this->permute_op_.Run(&resize_img, input.data()); int end_img_no = min(img_num, beg_img_no + this->rec_batch_num_);
auto preprocess_end = std::chrono::steady_clock::now(); float max_wh_ratio = 0;
for (int ino = beg_img_no; ino < end_img_no; ino ++) {
// Inference. int h = img_list[indices[ino]].rows;
auto input_names = this->predictor_->GetInputNames(); int w = img_list[indices[ino]].cols;
auto input_t = this->predictor_->GetInputHandle(input_names[0]); float wh_ratio = w * 1.0 / h;
input_t->Reshape({1, 3, resize_img.rows, resize_img.cols}); max_wh_ratio = max(max_wh_ratio, wh_ratio);
auto inference_start = std::chrono::steady_clock::now(); }
input_t->CopyFromCpu(input.data()); // cout << "max_wh_ratio: " << max_wh_ratio << endl;
this->predictor_->Run(); std::vector<cv::Mat> norm_img_batch;
for (int ino = beg_img_no; ino < end_img_no; ino ++) {
std::vector<float> predict_batch; cv::Mat srcimg;
auto output_names = this->predictor_->GetOutputNames(); img_list[indices[ino]].copyTo(srcimg);
auto output_t = this->predictor_->GetOutputHandle(output_names[0]); cv::Mat resize_img;
auto predict_shape = output_t->shape(); this->resize_op_.Run(srcimg, resize_img, max_wh_ratio, this->use_tensorrt_);
this->normalize_op_.Run(&resize_img, this->mean_, this->scale_, this->is_scale_);
int out_num = std::accumulate(predict_shape.begin(), predict_shape.end(), 1, norm_img_batch.push_back(resize_img);
}
int batch_width = int(ceilf(32 * max_wh_ratio)) - 1;
std::vector<float> input(this->rec_batch_num_ * 3 * 32 * batch_width, 0.0f);
this->permute_op_.Run(norm_img_batch, input.data());
auto preprocess_end = std::chrono::steady_clock::now();
preprocess_diff += preprocess_end - preprocess_start;
// Inference.
auto input_names = this->predictor_->GetInputNames();
auto input_t = this->predictor_->GetInputHandle(input_names[0]);
input_t->Reshape({this->rec_batch_num_, 3, 32, batch_width});
auto inference_start = std::chrono::steady_clock::now();
input_t->CopyFromCpu(input.data());
this->predictor_->Run();
std::vector<float> predict_batch;
auto output_names = this->predictor_->GetOutputNames();
auto output_t = this->predictor_->GetOutputHandle(output_names[0]);
auto predict_shape = output_t->shape();
int out_num = std::accumulate(predict_shape.begin(), predict_shape.end(), 1,
std::multiplies<int>()); std::multiplies<int>());
predict_batch.resize(out_num); predict_batch.resize(out_num);
output_t->CopyToCpu(predict_batch.data()); output_t->CopyToCpu(predict_batch.data());
auto inference_end = std::chrono::steady_clock::now(); auto inference_end = std::chrono::steady_clock::now();
inference_diff += inference_end - inference_start;
// ctc decode
auto postprocess_start = std::chrono::steady_clock::now(); // ctc decode
std::vector<std::string> str_res; auto postprocess_start = std::chrono::steady_clock::now();
int argmax_idx; for (int m = 0; m < predict_shape[0]; m++) {
int last_index = 0; std::vector<std::string> str_res;
float score = 0.f; int argmax_idx;
int count = 0; int last_index = 0;
float max_value = 0.0f; float score = 0.f;
int count = 0;
for (int n = 0; n < predict_shape[1]; n++) { float max_value = 0.0f;
argmax_idx =
int(Utility::argmax(&predict_batch[n * predict_shape[2]], for (int n = 0; n < predict_shape[1]; n++) {
&predict_batch[(n + 1) * predict_shape[2]])); argmax_idx =
max_value = int(Utility::argmax(&predict_batch[(m * predict_shape[1] + n) * predict_shape[2]],
float(*std::max_element(&predict_batch[n * predict_shape[2]], &predict_batch[(m * predict_shape[1] + n + 1) * predict_shape[2]]));
&predict_batch[(n + 1) * predict_shape[2]])); max_value =
float(*std::max_element(&predict_batch[(m * predict_shape[1] + n) * predict_shape[2]],
if (argmax_idx > 0 && (!(n > 0 && argmax_idx == last_index))) { &predict_batch[(m * predict_shape[1] + n + 1) * predict_shape[2]]));
score += max_value;
count += 1; if (argmax_idx > 0 && (!(n > 0 && argmax_idx == last_index))) {
str_res.push_back(label_list_[argmax_idx]); score += max_value;
count += 1;
str_res.push_back(label_list_[argmax_idx]);
}
last_index = argmax_idx;
}
score /= count;
if (isnan(score))
continue;
for (int i = 0; i < str_res.size(); i++) {
std::cout << str_res[i];
}
std::cout << "\tscore: " << score << std::endl;
}
auto postprocess_end = std::chrono::steady_clock::now();
postprocess_diff += postprocess_end - postprocess_start;
} }
last_index = argmax_idx; times->push_back(double(preprocess_diff.count() * 1000));
} times->push_back(double(inference_diff.count() * 1000));
auto postprocess_end = std::chrono::steady_clock::now(); times->push_back(double(postprocess_diff.count() * 1000));
score /= count;
for (int i = 0; i < str_res.size(); i++) {
std::cout << str_res[i];
}
std::cout << "\tscore: " << score << std::endl;
std::chrono::duration<float> preprocess_diff = preprocess_end - preprocess_start;
times->push_back(double(preprocess_diff.count() * 1000));
std::chrono::duration<float> inference_diff = inference_end - inference_start;
times->push_back(double(inference_diff.count() * 1000));
std::chrono::duration<float> postprocess_diff = postprocess_end - postprocess_start;
times->push_back(double(postprocess_diff.count() * 1000));
} }
void CRNNRecognizer::LoadModel(const std::string &model_dir) { void CRNNRecognizer::LoadModel(const std::string &model_dir) {
// AnalysisConfig config; // AnalysisConfig config;
paddle_infer::Config config; paddle_infer::Config config;
......
...@@ -40,6 +40,17 @@ void Permute::Run(const cv::Mat *im, float *data) { ...@@ -40,6 +40,17 @@ void Permute::Run(const cv::Mat *im, float *data) {
} }
} }
void Permute_batch::Run(const std::vector<cv::Mat> imgs, float *data) {
for (int j = 0; j < imgs.size(); j ++){
int rh = imgs[j].rows;
int rw = imgs[j].cols;
int rc = imgs[j].channels();
for (int i = 0; i < rc; ++i) {
cv::extractChannel(imgs[j], cv::Mat(rh, rw, CV_32FC1, data + (j * rc + i) * rh * rw), i);
}
}
}
void Normalize::Run(cv::Mat *im, const std::vector<float> &mean, void Normalize::Run(cv::Mat *im, const std::vector<float> &mean,
const std::vector<float> &scale, const bool is_scale) { const std::vector<float> &scale, const bool is_scale) {
double e = 1.0; double e = 1.0;
...@@ -90,16 +101,17 @@ void CrnnResizeImg::Run(const cv::Mat &img, cv::Mat &resize_img, float wh_ratio, ...@@ -90,16 +101,17 @@ void CrnnResizeImg::Run(const cv::Mat &img, cv::Mat &resize_img, float wh_ratio,
imgC = rec_image_shape[0]; imgC = rec_image_shape[0];
imgH = rec_image_shape[1]; imgH = rec_image_shape[1];
imgW = rec_image_shape[2]; imgW = rec_image_shape[2];
imgW = int(32 * wh_ratio); imgW = int(32 * wh_ratio);
float ratio = float(img.cols) / float(img.rows); float ratio = float(img.cols) / float(img.rows);
int resize_w, resize_h; int resize_w, resize_h;
if (ceilf(imgH * ratio) > imgW) if (ceilf(imgH * ratio) > imgW)
resize_w = imgW; resize_w = imgW;
else else
resize_w = int(ceilf(imgH * ratio)); resize_w = int(ceilf(imgH * ratio));
cv::resize(img, resize_img, cv::Size(resize_w, imgH), 0.f, 0.f, cv::resize(img, resize_img, cv::Size(resize_w, imgH), 0.f, 0.f,
cv::INTER_LINEAR); cv::INTER_LINEAR);
cv::copyMakeBorder(resize_img, resize_img, 0, 0, 0, cv::copyMakeBorder(resize_img, resize_img, 0, 0, 0,
......
...@@ -147,4 +147,18 @@ cv::Mat Utility::GetRotateCropImage(const cv::Mat &srcimage, ...@@ -147,4 +147,18 @@ cv::Mat Utility::GetRotateCropImage(const cv::Mat &srcimage,
} }
} }
// 实现argsort功能
std::vector<int> Utility::argsort(const std::vector<float>& array)
{
const int array_len(array.size());
std::vector<int> array_index(array_len, 0);
for (int i = 0; i < array_len; ++i)
array_index[i] = i;
std::sort(array_index.begin(), array_index.end(),
[&array](int pos1, int pos2) {return (array[pos1] < array[pos2]); });
return array_index;
}
} // namespace PaddleOCR } // namespace PaddleOCR
\ No newline at end of file
...@@ -106,7 +106,6 @@ class TextRecognizer(object): ...@@ -106,7 +106,6 @@ class TextRecognizer(object):
return norm_img.astype(np.float32) / 128. - 1. return norm_img.astype(np.float32) / 128. - 1.
assert imgC == img.shape[2] assert imgC == img.shape[2]
max_wh_ratio = max(max_wh_ratio, imgW / imgH)
imgW = int((32 * max_wh_ratio)) imgW = int((32 * max_wh_ratio))
h, w = img.shape[:2] h, w = img.shape[:2]
ratio = w / float(h) ratio = w / float(h)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册