diff --git a/deploy/cpp_infer/include/config.h b/deploy/cpp_infer/include/config.h
index 27539ea7934dc192e86bca3ea6bfd7999ee229a3..4bee4f4e5cd707633ee5e6cf6bdc1a2bd7953c0f 100644
--- a/deploy/cpp_infer/include/config.h
+++ b/deploy/cpp_infer/include/config.h
@@ -63,6 +63,12 @@ public:
 
     this->cls_thresh = stod(config_map_["cls_thresh"]);
 
+    // Config files written before this option existed have no
+    // rec_batch_num entry, and stoi("") throws; keep the default then.
+    if (config_map_.count("rec_batch_num") > 0) {
+      this->rec_batch_num = stoi(config_map_["rec_batch_num"]);
+    }
+
     this->visualize = bool(stoi(config_map_["visualize"]));
   }
 
@@ -86,6 +92,8 @@ public:
 
   double det_db_unclip_ratio = 2.0;
 
+  int rec_batch_num = 30;
+
   std::string det_model_dir;
 
   std::string rec_model_dir;
diff --git a/deploy/cpp_infer/include/ocr_rec.h b/deploy/cpp_infer/include/ocr_rec.h
index a8b99a5960ac3e6238dfea2285ec51c9e80e1749..73b0f906f1aed41520d28795addfe89e0641eb48 100644
--- a/deploy/cpp_infer/include/ocr_rec.h
+++ b/deploy/cpp_infer/include/ocr_rec.h
@@ -40,13 +40,14 @@ public:
                           const int &gpu_id, const int &gpu_mem,
                           const int &cpu_math_library_num_threads,
                           const bool &use_mkldnn, const bool &use_zero_copy_run,
-                          const string &label_path) {
+                          const string &label_path, const int &rec_batch_num) {
     this->use_gpu_ = use_gpu;
     this->gpu_id_ = gpu_id;
     this->gpu_mem_ = gpu_mem;
     this->cpu_math_library_num_threads_ = cpu_math_library_num_threads;
     this->use_mkldnn_ = use_mkldnn;
     this->use_zero_copy_run_ = use_zero_copy_run;
+    this->rec_batch_num_ = rec_batch_num;
 
     this->label_list_ = Utility::ReadDict(label_path);
     this->label_list_.push_back(" ");
@@ -69,6 +70,7 @@ private:
   int cpu_math_library_num_threads_ = 4;
   bool use_mkldnn_ = false;
   bool use_zero_copy_run_ = false;
+  int rec_batch_num_ = 30;
 
   std::vector<std::string> label_list_;
 
diff --git a/deploy/cpp_infer/src/main.cpp b/deploy/cpp_infer/src/main.cpp
index 4b84dbd0903b4290c8d3b2b4feeb95c3bd234fe8..147933414aae51650626a709c8dee8a683482ab3 100644
--- a/deploy/cpp_infer/src/main.cpp
+++ b/deploy/cpp_infer/src/main.cpp
@@ -67,7 +67,7 @@ int main(int argc, char **argv) {
   CRNNRecognizer rec(config.rec_model_dir, config.use_gpu, config.gpu_id,
                      config.gpu_mem, config.cpu_math_library_num_threads,
                      config.use_mkldnn, config.use_zero_copy_run,
-                     config.char_list_file);
+                     config.char_list_file, config.rec_batch_num);
 
 #ifdef USE_MKL
 #pragma omp parallel
@@ -91,11 +91,11 @@ int main(int argc, char **argv) {
   auto end = std::chrono::system_clock::now();
   auto duration =
       std::chrono::duration_cast<std::chrono::microseconds>(end - start);
-  std::cout << "花费了"
+  std::cout << "cost"
             << double(duration.count()) *
                    std::chrono::microseconds::period::num /
                    std::chrono::microseconds::period::den
-            << "秒" << std::endl;
+            << "s" << std::endl;
 
   return 0;
 }
diff --git a/deploy/cpp_infer/src/ocr_rec.cpp b/deploy/cpp_infer/src/ocr_rec.cpp
index 7f88adc54636b4ecc61d257b7cb9159ebcdb82af..50910d7c09601dfc3f6f4eb918717abf1b208fcd 100644
--- a/deploy/cpp_infer/src/ocr_rec.cpp
+++ b/deploy/cpp_infer/src/ocr_rec.cpp
@@ -14,6 +14,21 @@
 
 #include <include/ocr_rec.h>
 
+// Returns the indices that sort `array` in ascending order; `array` itself is
+// left untouched.
+template <typename T>
+std::vector<int> argsort(const std::vector<T> &array) {
+  const int array_len(array.size());
+  std::vector<int> array_index(array_len, 0);
+  for (int i = 0; i < array_len; ++i)
+    array_index[i] = i;
+
+  std::sort(array_index.begin(), array_index.end(),
+            [&array](int pos1, int pos2) { return array[pos1] < array[pos2]; });
+
+  return array_index;
+}
+
 namespace PaddleOCR {
 
 void CRNNRecognizer::Run(std::vector<std::vector<std::vector<int>>> boxes,
@@ -22,100 +37,146 @@ void CRNNRecognizer::Run(std::vector<std::vector<std::vector<int>>> boxes,
   img.copyTo(srcimg);
   cv::Mat crop_img;
   cv::Mat resize_img;
+  // Crop every box first (running it through the optional classifier `cls`)
+  // so the crops can be ordered by aspect ratio and recognised in batches.
+  std::vector<float> width_list;
+  std::vector<cv::Mat> img_list;
+  for (int i = boxes.size() - 1; i >= 0; i--) {
+    crop_img = GetRotateCropImage(srcimg, boxes[i]);
+    if (cls != nullptr) {
+      crop_img = cls->Run(crop_img);
+    }
+    img_list.push_back(crop_img);
+    float wh_ratio = float(crop_img.cols) / float(crop_img.rows);
+    width_list.push_back(wh_ratio);
+  }
+  // Crop indices ordered by ascending width/height ratio.
+  // NOTE(review): results are therefore printed in this order, not box order.
+  std::vector<int> sort_index = argsort(width_list);
+  const int img_num = static_cast<int>(img_list.size());
+  // Clamp to 1: a non-positive batch size would never advance the loop below.
+  const int batch_num1 = std::max(1, this->rec_batch_num_);
 
   std::cout << "The predicted text is :" << std::endl;
   int index = 0;
-  for (int i = boxes.size() - 1; i >= 0; i--) {
-    crop_img = GetRotateCropImage(srcimg, boxes[i]);
-    if (cls != nullptr) {
-      crop_img = cls->Run(crop_img);
+  for (int beg_img_no = 0; beg_img_no < img_num; beg_img_no += batch_num1) {
+    const int end_img_no = std::min(img_num, beg_img_no + batch_num1);
+    const int batch_num = end_img_no - beg_img_no;
+    // Ascending order means the last crop of the batch has the largest ratio.
+    const float max_wh_ratio = width_list[sort_index[end_img_no - 1]];
+    // Batch width for a height of 32, rounded up to a multiple of 4.
+    int imgW = int(32 * max_wh_ratio);
+    if (imgW % 4 > 0) {
+      imgW = 4 * (imgW / 4 + 1);
     }
-
-    float wh_ratio = float(crop_img.cols) / float(crop_img.rows);
-
-    this->resize_op_.Run(crop_img, resize_img, wh_ratio);
-
-    this->normalize_op_.Run(&resize_img, this->mean_, this->scale_,
-                            this->is_scale_);
-
-    std::vector<float> input(1 * 3 * resize_img.rows * resize_img.cols, 0.0f);
-
-    this->permute_op_.Run(&resize_img, input.data());
-
-    // Inference.
-    if (this->use_zero_copy_run_) {
-      auto input_names = this->predictor_->GetInputNames();
-      auto input_t = this->predictor_->GetInputTensor(input_names[0]);
-      input_t->Reshape({1, 3, resize_img.rows, resize_img.cols});
-      input_t->copy_from_cpu(input.data());
-      this->predictor_->ZeroCopyRun();
-    } else {
-      paddle::PaddleTensor input_t;
-      input_t.shape = {1, 3, resize_img.rows, resize_img.cols};
-      input_t.data =
-          paddle::PaddleBuf(input.data(), input.size() * sizeof(float));
-      input_t.dtype = PaddleDType::FLOAT32;
-      std::vector<paddle::PaddleTensor> outputs;
-      this->predictor_->Run({input_t}, &outputs, 1);
+    // NCHW input buffer for the whole batch; the hard-coded 32 assumes
+    // resize_op_ produces 32-pixel-high images -- TODO confirm.
+    std::vector<float> input(batch_num * 3 * 32 * imgW, 0.0f);
+    for (int i = beg_img_no; i < end_img_no; i++) {
+      crop_img = img_list[sort_index[i]];
+      this->resize_op_.Run(crop_img, resize_img, max_wh_ratio);
+      this->normalize_op_.Run(&resize_img, this->mean_, this->scale_,
+                              this->is_scale_);
+
+      // Zero-pad on the right so every image in the batch is imgW wide.
+      cv::Mat padding_im;
+      cv::copyMakeBorder(resize_img, padding_im, 0, 0, 0,
+                         int(imgW - resize_img.cols), cv::BORDER_CONSTANT,
+                         {0, 0, 0});
+
+      this->permute_op_.Run(&padding_im,
+                            input.data() + (i - beg_img_no) * 3 *
+                                               padding_im.rows *
+                                               padding_im.cols);
+    }
+
+    // Inference. Both predictor paths are kept so that use_zero_copy_run_
+    // is still honoured, as it was before batching.
+    if (this->use_zero_copy_run_) {
+      auto input_names = this->predictor_->GetInputNames();
+      auto input_t = this->predictor_->GetInputTensor(input_names[0]);
+      input_t->Reshape({batch_num, 3, 32, imgW});
+      input_t->copy_from_cpu(input.data());
+      this->predictor_->ZeroCopyRun();
+    } else {
+      paddle::PaddleTensor input_t;
+      input_t.shape = {batch_num, 3, 32, imgW};
+      input_t.data =
+          paddle::PaddleBuf(input.data(), input.size() * sizeof(float));
+      input_t.dtype = PaddleDType::FLOAT32;
+      std::vector<paddle::PaddleTensor> outputs;
+      this->predictor_->Run({input_t}, &outputs, batch_num);
     }
 
     std::vector<int64_t> rec_idx;
     auto output_names = this->predictor_->GetOutputNames();
     auto output_t = this->predictor_->GetOutputTensor(output_names[0]);
-    auto rec_idx_lod = output_t->lod();
-    auto shape_out = output_t->shape();
-
-    int out_num = std::accumulate(shape_out.begin(), shape_out.end(), 1,
-                                  std::multiplies<int>());
-
-    rec_idx.resize(out_num);
-    output_t->copy_to_cpu(rec_idx.data());
-
-    std::vector<int> pred_idx;
-    for (int n = int(rec_idx_lod[0][0]); n < int(rec_idx_lod[0][1]); n++) {
-      pred_idx.push_back(int(rec_idx[n]));
-    }
+    // Level-0 LoD offsets: entries j and j + 1 bound the j-th sequence.
+    auto rec_idx_lod = output_t->lod()[0];
 
-    if (pred_idx.size() < 1e-3)
-      continue;
-
-    index += 1;
-    std::cout << index << "\t";
-    for (int n = 0; n < pred_idx.size(); n++) {
-      std::cout << label_list_[pred_idx[n]];
+    std::vector<int> output_shape = output_t->shape();
+    int out_num = 1;
+    for (int i = 0; i < output_shape.size(); ++i) {
+      out_num *= output_shape[i];
     }
+    rec_idx.resize(out_num);
+    output_t->copy_to_cpu(rec_idx.data());
 
     std::vector<float> predict_batch;
     auto output_t_1 = this->predictor_->GetOutputTensor(output_names[1]);
 
-    auto predict_lod = output_t_1->lod();
+    auto predict_lod = output_t_1->lod()[0];
     auto predict_shape = output_t_1->shape();
-    int out_num_1 = std::accumulate(predict_shape.begin(), predict_shape.end(),
-                                    1, std::multiplies<int>());
+
+    int out_num_1 = 1;
+    for (int i = 0; i < predict_shape.size(); ++i) {
+      out_num_1 *= predict_shape[i];
+    }
 
     predict_batch.resize(out_num_1);
     output_t_1->copy_to_cpu(predict_batch.data());
 
     int argmax_idx;
     int blank = predict_shape[1];
-    float score = 0.f;
-    int count = 0;
-    float max_value = 0.0f;
-
-    for (int n = predict_lod[0][0]; n < predict_lod[0][1] - 1; n++) {
-      argmax_idx =
-          int(Utility::argmax(&predict_batch[n * predict_shape[1]],
-                              &predict_batch[(n + 1) * predict_shape[1]]));
-      max_value =
-          float(*std::max_element(&predict_batch[n * predict_shape[1]],
-                                  &predict_batch[(n + 1) * predict_shape[1]]));
-      if (blank - 1 - argmax_idx > 1e-5) {
-        score += max_value;
-        count += 1;
-      }
+
+    // One LoD segment per image in the batch. `j + 1 < size()` rather than
+    // `j < size() - 1`, because size() is unsigned and an empty LoD would
+    // make the subtraction wrap around.
+    for (size_t j = 0; j + 1 < rec_idx_lod.size(); j++) {
+      std::vector<int> pred_idx;
+      float score = 0.f;
+      int count = 0;
+      float max_value = 0.0f;
+      for (int n = int(rec_idx_lod[j]); n < int(rec_idx_lod[j + 1]); n++) {
+        pred_idx.push_back(int(rec_idx[n]));
+      }
+      if (pred_idx.size() < 1e-3)
+        continue;
+
+      index += 1;
+      std::cout << index << "\t";
+      for (int n = 0; n < pred_idx.size(); n++) {
+        std::cout << label_list_[pred_idx[n]];
+      }
+
+      for (int n = predict_lod[j]; n < predict_lod[j + 1] - 1; n++) {
+        argmax_idx =
+            int(Utility::argmax(&predict_batch[n * predict_shape[1]],
+                                &predict_batch[(n + 1) * predict_shape[1]]));
+
+        max_value = predict_batch[n * predict_shape[1] + argmax_idx];
+        if (blank - 1 - argmax_idx > 1e-5) {
+          score += max_value;
+          count += 1;
+        }
+      }
+      // Average only when something was accumulated; count stays 0 when no
+      // step passed the check above, and 0 / 0 would print nan.
+      if (count > 0) {
+        score /= count;
+      }
+      std::cout << "\tscore: " << score << std::endl;
     }
-    score /= count;
-    std::cout << "\tscore: " << score << std::endl;
   }
 }
 
diff --git a/deploy/cpp_infer/tools/config.txt b/deploy/cpp_infer/tools/config.txt
index 7e03b9d13af9b65239dc257059ef0fa94106e880..16ec8e1e679c8b2d738da2eb60dd9ec2541efa5f 100644
--- a/deploy/cpp_infer/tools/config.txt
+++ b/deploy/cpp_infer/tools/config.txt
@@ -21,6 +21,7 @@ cls_thresh 0.9
 # rec config
 rec_model_dir ./inference/rec_crnn
 char_list_file ../../ppocr/utils/ppocr_keys_v1.txt
+rec_batch_num 30
 
 # show the detection results
 visualize 1