diff --git a/deploy/cpp_infer/include/ocr_rec.h b/deploy/cpp_infer/include/ocr_rec.h index d585112b051daff7c03060836a4c065ba6e3949c..32c0c9f77d557e69b07b48861bebc6e6552e3d01 100644 --- a/deploy/cpp_infer/include/ocr_rec.h +++ b/deploy/cpp_infer/include/ocr_rec.h @@ -44,7 +44,8 @@ public: const int &gpu_id, const int &gpu_mem, const int &cpu_math_library_num_threads, const bool &use_mkldnn, const string &label_path, - const bool &use_tensorrt, const std::string &precision) { + const bool &use_tensorrt, const std::string &precision, + const int &rec_batch_num) { this->use_gpu_ = use_gpu; this->gpu_id_ = gpu_id; this->gpu_mem_ = gpu_mem; @@ -52,6 +53,7 @@ public: this->use_mkldnn_ = use_mkldnn; this->use_tensorrt_ = use_tensorrt; this->precision_ = precision; + this->rec_batch_num_ = rec_batch_num; this->label_list_ = Utility::ReadDict(label_path); this->label_list_.insert(this->label_list_.begin(), @@ -64,7 +66,8 @@ public: // Load Paddle inference model void LoadModel(const std::string &model_dir); - void Run(cv::Mat &img, std::vector *times); +// void Run(cv::Mat &img, std::vector *times); + void Run(std::vector img_list, std::vector *times); private: std::shared_ptr predictor_; @@ -82,10 +85,12 @@ private: bool is_scale_ = true; bool use_tensorrt_ = false; std::string precision_ = "fp32"; + int rec_batch_num_ = 6; + // pre-process CrnnResizeImg resize_op_; Normalize normalize_op_; - Permute permute_op_; + Permute_batch permute_op_; // post-process PostProcessor post_processor_; diff --git a/deploy/cpp_infer/include/preprocess_op.h b/deploy/cpp_infer/include/preprocess_op.h index ab4c140059fdcaed9872d2d99b4aea57c7e5208f..c93473300b8859d33c1020b20f2fee05980ca00b 100644 --- a/deploy/cpp_infer/include/preprocess_op.h +++ b/deploy/cpp_infer/include/preprocess_op.h @@ -44,6 +44,11 @@ public: virtual void Run(const cv::Mat *im, float *data); }; +class Permute_batch { +public: + virtual void Run(const std::vector imgs, float *data); +}; + class ResizeImgType0 { public: virtual void Run(const cv::Mat &img, cv::Mat &resize_img, int max_size_len, diff --git a/deploy/cpp_infer/include/utility.h b/deploy/cpp_infer/include/utility.h index 678187d3fabfb1c91584226950155b3c47b5f93f..b6f68aee013d424c8ab58e5e3768159769dfa275 100644 --- a/deploy/cpp_infer/include/utility.h +++ b/deploy/cpp_infer/include/utility.h @@ -50,6 +50,10 @@ public: static cv::Mat GetRotateCropImage(const cv::Mat &srcimage, std::vector> box); + + // 实现argsort功能 + static std::vector argsort(const std::vector& array); + }; } // namespace PaddleOCR \ No newline at end of file diff --git a/deploy/cpp_infer/src/main.cpp b/deploy/cpp_infer/src/main.cpp index 82a248416f086dd2b90e891a23774c294ed50ae3..b7a199b548beca881e4ab69491adcc9351f52c0f 100644 --- a/deploy/cpp_infer/src/main.cpp +++ b/deploy/cpp_infer/src/main.cpp @@ -61,7 +61,7 @@ DEFINE_string(cls_model_dir, "", "Path of cls inference model."); DEFINE_double(cls_thresh, 0.9, "Threshold of cls_thresh."); // recognition related DEFINE_string(rec_model_dir, "", "Path of rec inference model."); -DEFINE_int32(rec_batch_num, 1, "rec_batch_num."); +DEFINE_int32(rec_batch_num, 6, "rec_batch_num."); DEFINE_string(char_list_file, "../../ppocr/utils/ppocr_keys_v1.txt", "Path of dictionary."); @@ -146,8 +146,9 @@ int main_rec(std::vector cv_all_img_names) { CRNNRecognizer rec(FLAGS_rec_model_dir, FLAGS_use_gpu, FLAGS_gpu_id, FLAGS_gpu_mem, FLAGS_cpu_threads, FLAGS_enable_mkldnn, char_list_file, - FLAGS_use_tensorrt, FLAGS_precision); + FLAGS_use_tensorrt, FLAGS_precision, FLAGS_rec_batch_num); + std::vector img_list; for (int i = 0; i < cv_all_img_names.size(); ++i) { LOG(INFO) << "The predict img: " << cv_all_img_names[i]; @@ -156,22 +157,21 @@ int main_rec(std::vector cv_all_img_names) { std::cerr << "[ERROR] image read failed! image path: " << cv_all_img_names[i] << endl; exit(1); } - - std::vector rec_times; - rec.Run(srcimg, &rec_times); - - time_info[0] += rec_times[0]; - time_info[1] += rec_times[1]; - time_info[2] += rec_times[2]; + img_list.push_back(srcimg); } - + std::vector rec_times; + rec.Run(img_list, &rec_times); + time_info[0] += rec_times[0]; + time_info[1] += rec_times[1]; + time_info[2] += rec_times[2]; + if (FLAGS_benchmark) { AutoLogger autolog("ocr_rec", FLAGS_use_gpu, FLAGS_use_tensorrt, FLAGS_enable_mkldnn, FLAGS_cpu_threads, - 1, + FLAGS_rec_batch_num, "dynamic", FLAGS_precision, time_info, @@ -209,7 +209,7 @@ int main_system(std::vector cv_all_img_names) { CRNNRecognizer rec(FLAGS_rec_model_dir, FLAGS_use_gpu, FLAGS_gpu_id, FLAGS_gpu_mem, FLAGS_cpu_threads, FLAGS_enable_mkldnn, char_list_file, - FLAGS_use_tensorrt, FLAGS_precision); + FLAGS_use_tensorrt, FLAGS_precision, FLAGS_rec_batch_num); for (int i = 0; i < cv_all_img_names.size(); ++i) { LOG(INFO) << "The predict img: " << cv_all_img_names[i]; @@ -228,19 +228,22 @@ int main_system(std::vector cv_all_img_names) { time_info_det[1] += det_times[1]; time_info_det[2] += det_times[2]; - cv::Mat crop_img; + std::vector img_list; for (int j = 0; j < boxes.size(); j++) { - crop_img = Utility::GetRotateCropImage(srcimg, boxes[j]); - - if (cls != nullptr) { - crop_img = cls->Run(crop_img); - } - rec.Run(crop_img, &rec_times); - time_info_rec[0] += rec_times[0]; - time_info_rec[1] += rec_times[1]; - time_info_rec[2] += rec_times[2]; + cv::Mat crop_img; + crop_img = Utility::GetRotateCropImage(srcimg, boxes[j]); + if (cls != nullptr) { + crop_img = cls->Run(crop_img); + } + img_list.push_back(crop_img); } + + rec.Run(img_list, &rec_times); + time_info_rec[0] += rec_times[0]; + time_info_rec[1] += rec_times[1]; + time_info_rec[2] += rec_times[2]; } + if (FLAGS_benchmark) { AutoLogger autolog_det("ocr_det", FLAGS_use_gpu, @@ -257,7 +260,7 @@ int main_system(std::vector cv_all_img_names) { FLAGS_use_tensorrt, FLAGS_enable_mkldnn, FLAGS_cpu_threads, - 1, + FLAGS_rec_batch_num, "dynamic", FLAGS_precision, time_info_rec, diff --git a/deploy/cpp_infer/src/ocr_rec.cpp b/deploy/cpp_infer/src/ocr_rec.cpp index 3739a66ad802fd108df16bbbbe8c8695963b7693..981c17a98adb81067fb98b7c00b3af90a39b6367 100644 --- a/deploy/cpp_infer/src/ocr_rec.cpp +++ b/deploy/cpp_infer/src/ocr_rec.cpp @@ -15,83 +15,109 @@ #include namespace PaddleOCR { - -void CRNNRecognizer::Run(cv::Mat &img, std::vector *times) { - cv::Mat srcimg; - img.copyTo(srcimg); - cv::Mat resize_img; - - float wh_ratio = float(srcimg.cols) / float(srcimg.rows); - auto preprocess_start = std::chrono::steady_clock::now(); - this->resize_op_.Run(srcimg, resize_img, wh_ratio, this->use_tensorrt_); - - this->normalize_op_.Run(&resize_img, this->mean_, this->scale_, - this->is_scale_); - - std::vector input(1 * 3 * resize_img.rows * resize_img.cols, 0.0f); - - this->permute_op_.Run(&resize_img, input.data()); - auto preprocess_end = std::chrono::steady_clock::now(); - - // Inference. - auto input_names = this->predictor_->GetInputNames(); - auto input_t = this->predictor_->GetInputHandle(input_names[0]); - input_t->Reshape({1, 3, resize_img.rows, resize_img.cols}); - auto inference_start = std::chrono::steady_clock::now(); - input_t->CopyFromCpu(input.data()); - this->predictor_->Run(); - - std::vector predict_batch; - auto output_names = this->predictor_->GetOutputNames(); - auto output_t = this->predictor_->GetOutputHandle(output_names[0]); - auto predict_shape = output_t->shape(); - - int out_num = std::accumulate(predict_shape.begin(), predict_shape.end(), 1, + +void CRNNRecognizer::Run(std::vector img_list, std::vector *times) { + std::chrono::duration preprocess_diff = std::chrono::steady_clock::now() - std::chrono::steady_clock::now(); + std::chrono::duration inference_diff = std::chrono::steady_clock::now() - std::chrono::steady_clock::now(); + std::chrono::duration postprocess_diff = std::chrono::steady_clock::now() - std::chrono::steady_clock::now(); + + int img_num = img_list.size(); + std::vector width_list; + for (int i = 0; i < img_num; i++) { + width_list.push_back(float(img_list[i].cols) / img_list[i].rows); + } + std::vector indices = Utility::argsort(width_list); + + for (int beg_img_no = 0; beg_img_no < img_num; beg_img_no += this->rec_batch_num_) { + auto preprocess_start = std::chrono::steady_clock::now(); + int end_img_no = min(img_num, beg_img_no + this->rec_batch_num_); + float max_wh_ratio = 0; + for (int ino = beg_img_no; ino < end_img_no; ino ++) { + int h = img_list[indices[ino]].rows; + int w = img_list[indices[ino]].cols; + float wh_ratio = w * 1.0 / h; + max_wh_ratio = max(max_wh_ratio, wh_ratio); + } +// cout << "max_wh_ratio: " << max_wh_ratio << endl; + std::vector norm_img_batch; + for (int ino = beg_img_no; ino < end_img_no; ino ++) { + cv::Mat srcimg; + img_list[indices[ino]].copyTo(srcimg); + cv::Mat resize_img; + this->resize_op_.Run(srcimg, resize_img, max_wh_ratio, this->use_tensorrt_); + this->normalize_op_.Run(&resize_img, this->mean_, this->scale_, this->is_scale_); + norm_img_batch.push_back(resize_img); + } + + int batch_width = int(ceilf(32 * max_wh_ratio)) - 1; + std::vector input(this->rec_batch_num_ * 3 * 32 * batch_width, 0.0f); + this->permute_op_.Run(norm_img_batch, input.data()); + auto preprocess_end = std::chrono::steady_clock::now(); + preprocess_diff += preprocess_end - preprocess_start; + + // Inference. + auto input_names = this->predictor_->GetInputNames(); + auto input_t = this->predictor_->GetInputHandle(input_names[0]); + input_t->Reshape({this->rec_batch_num_, 3, 32, batch_width}); + auto inference_start = std::chrono::steady_clock::now(); + input_t->CopyFromCpu(input.data()); + this->predictor_->Run(); + + std::vector predict_batch; + auto output_names = this->predictor_->GetOutputNames(); + auto output_t = this->predictor_->GetOutputHandle(output_names[0]); + auto predict_shape = output_t->shape(); + + int out_num = std::accumulate(predict_shape.begin(), predict_shape.end(), 1, std::multiplies()); - predict_batch.resize(out_num); - - output_t->CopyToCpu(predict_batch.data()); - auto inference_end = std::chrono::steady_clock::now(); - - // ctc decode - auto postprocess_start = std::chrono::steady_clock::now(); - std::vector str_res; - int argmax_idx; - int last_index = 0; - float score = 0.f; - int count = 0; - float max_value = 0.0f; - - for (int n = 0; n < predict_shape[1]; n++) { - argmax_idx = - int(Utility::argmax(&predict_batch[n * predict_shape[2]], - &predict_batch[(n + 1) * predict_shape[2]])); - max_value = - float(*std::max_element(&predict_batch[n * predict_shape[2]], - &predict_batch[(n + 1) * predict_shape[2]])); - - if (argmax_idx > 0 && (!(n > 0 && argmax_idx == last_index))) { - score += max_value; - count += 1; - str_res.push_back(label_list_[argmax_idx]); + predict_batch.resize(out_num); + + output_t->CopyToCpu(predict_batch.data()); + auto inference_end = std::chrono::steady_clock::now(); + inference_diff += inference_end - inference_start; + + // ctc decode + auto postprocess_start = std::chrono::steady_clock::now(); + for (int m = 0; m < predict_shape[0]; m++) { + std::vector str_res; + int argmax_idx; + int last_index = 0; + float score = 0.f; + int count = 0; + float max_value = 0.0f; + + for (int n = 0; n < predict_shape[1]; n++) { + argmax_idx = + int(Utility::argmax(&predict_batch[(m * predict_shape[1] + n) * predict_shape[2]], + &predict_batch[(m * predict_shape[1] + n + 1) * predict_shape[2]])); + max_value = + float(*std::max_element(&predict_batch[(m * predict_shape[1] + n) * predict_shape[2]], + &predict_batch[(m * predict_shape[1] + n + 1) * predict_shape[2]])); + + if (argmax_idx > 0 && (!(n > 0 && argmax_idx == last_index))) { + score += max_value; + count += 1; + str_res.push_back(label_list_[argmax_idx]); + } + last_index = argmax_idx; + } + score /= count; + if (isnan(score)) + continue; + for (int i = 0; i < str_res.size(); i++) { + std::cout << str_res[i]; + } + std::cout << "\tscore: " << score << std::endl; + } + auto postprocess_end = std::chrono::steady_clock::now(); + postprocess_diff += postprocess_end - postprocess_start; } - last_index = argmax_idx; - } - auto postprocess_end = std::chrono::steady_clock::now(); - score /= count; - for (int i = 0; i < str_res.size(); i++) { - std::cout << str_res[i]; - } - std::cout << "\tscore: " << score << std::endl; - - std::chrono::duration preprocess_diff = preprocess_end - preprocess_start; - times->push_back(double(preprocess_diff.count() * 1000)); - std::chrono::duration inference_diff = inference_end - inference_start; - times->push_back(double(inference_diff.count() * 1000)); - std::chrono::duration postprocess_diff = postprocess_end - postprocess_start; - times->push_back(double(postprocess_diff.count() * 1000)); + times->push_back(double(preprocess_diff.count() * 1000)); + times->push_back(double(inference_diff.count() * 1000)); + times->push_back(double(postprocess_diff.count() * 1000)); } + void CRNNRecognizer::LoadModel(const std::string &model_dir) { // AnalysisConfig config; paddle_infer::Config config; diff --git a/deploy/cpp_infer/src/preprocess_op.cpp b/deploy/cpp_infer/src/preprocess_op.cpp index 23c51c2008dc7280ce4d6c232ed766dbf2a53226..4c2c2f7a54177777e7ec8a6e508abbc79fd66b57 100644 --- a/deploy/cpp_infer/src/preprocess_op.cpp +++ b/deploy/cpp_infer/src/preprocess_op.cpp @@ -40,6 +40,17 @@ void Permute::Run(const cv::Mat *im, float *data) { } } +void Permute_batch::Run(const std::vector imgs, float *data) { + for (int j = 0; j < imgs.size(); j ++){ + int rh = imgs[j].rows; + int rw = imgs[j].cols; + int rc = imgs[j].channels(); + for (int i = 0; i < rc; ++i) { + cv::extractChannel(imgs[j], cv::Mat(rh, rw, CV_32FC1, data + (j * rc + i) * rh * rw), i); + } + } +} + void Normalize::Run(cv::Mat *im, const std::vector &mean, const std::vector &scale, const bool is_scale) { double e = 1.0; @@ -90,16 +101,17 @@ void CrnnResizeImg::Run(const cv::Mat &img, cv::Mat &resize_img, float wh_ratio, imgC = rec_image_shape[0]; imgH = rec_image_shape[1]; imgW = rec_image_shape[2]; - + imgW = int(32 * wh_ratio); float ratio = float(img.cols) / float(img.rows); int resize_w, resize_h; + if (ceilf(imgH * ratio) > imgW) resize_w = imgW; else resize_w = int(ceilf(imgH * ratio)); - + cv::resize(img, resize_img, cv::Size(resize_w, imgH), 0.f, 0.f, cv::INTER_LINEAR); cv::copyMakeBorder(resize_img, resize_img, 0, 0, 0, diff --git a/deploy/cpp_infer/src/utility.cpp b/deploy/cpp_infer/src/utility.cpp index dba445b747ff3f3c0d2db91061650c369977c4dd..49d3cb81ef7125ca8e40060cd39029b8e06a70a0 100644 --- a/deploy/cpp_infer/src/utility.cpp +++ b/deploy/cpp_infer/src/utility.cpp @@ -147,4 +147,18 @@ cv::Mat Utility::GetRotateCropImage(const cv::Mat &srcimage, } } +// 实现argsort功能 +std::vector Utility::argsort(const std::vector& array) +{ + const int array_len(array.size()); + std::vector array_index(array_len, 0); + for (int i = 0; i < array_len; ++i) + array_index[i] = i; + + std::sort(array_index.begin(), array_index.end(), + [&array](int pos1, int pos2) {return (array[pos1] < array[pos2]); }); + + return array_index; +} + } // namespace PaddleOCR \ No newline at end of file diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py index 936994a215d10d543537b29cb41bfa42b42590c7..6cc91b56a31708986ffbe35b649a67c4385c22b2 100755 --- a/tools/infer/predict_rec.py +++ b/tools/infer/predict_rec.py @@ -106,7 +106,6 @@ class TextRecognizer(object): return norm_img.astype(np.float32) / 128. - 1. assert imgC == img.shape[2] - max_wh_ratio = max(max_wh_ratio, imgW / imgH) imgW = int((32 * max_wh_ratio)) h, w = img.shape[:2] ratio = w / float(h)