ocr_rec.cpp 6.8 KB
Newer Older
littletomatodonkey's avatar
littletomatodonkey 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <include/ocr_rec.h>

#include <algorithm>
#include <chrono>
#include <cmath>
#include <cstddef>
#include <functional>
#include <map>
#include <numeric>
#include <string>
#include <utility>
#include <vector>

namespace PaddleOCR {
文幕地方's avatar
文幕地方 已提交
18 19

void CRNNRecognizer::Run(std::vector<cv::Mat> img_list,
文幕地方's avatar
文幕地方 已提交
20 21
                         std::vector<std::string> &rec_texts,
                         std::vector<float> &rec_text_scores,
22
                         std::vector<double> &times) {
文幕地方's avatar
文幕地方 已提交
23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40
  std::chrono::duration<float> preprocess_diff =
      std::chrono::steady_clock::now() - std::chrono::steady_clock::now();
  std::chrono::duration<float> inference_diff =
      std::chrono::steady_clock::now() - std::chrono::steady_clock::now();
  std::chrono::duration<float> postprocess_diff =
      std::chrono::steady_clock::now() - std::chrono::steady_clock::now();

  int img_num = img_list.size();
  std::vector<float> width_list;
  for (int i = 0; i < img_num; i++) {
    width_list.push_back(float(img_list[i].cols) / img_list[i].rows);
  }
  std::vector<int> indices = Utility::argsort(width_list);

  for (int beg_img_no = 0; beg_img_no < img_num;
       beg_img_no += this->rec_batch_num_) {
    auto preprocess_start = std::chrono::steady_clock::now();
    int end_img_no = min(img_num, beg_img_no + this->rec_batch_num_);
41
    int batch_num = end_img_no - beg_img_no;
文幕地方's avatar
文幕地方 已提交
42 43 44 45 46 47
    float max_wh_ratio = 0;
    for (int ino = beg_img_no; ino < end_img_no; ino++) {
      int h = img_list[indices[ino]].rows;
      int w = img_list[indices[ino]].cols;
      float wh_ratio = w * 1.0 / h;
      max_wh_ratio = max(max_wh_ratio, wh_ratio);
M
MissPenguin 已提交
48
    }
49

文幕地方's avatar
文幕地方 已提交
50 51 52 53 54 55 56 57 58 59 60 61 62 63
    int batch_width = 0;
    std::vector<cv::Mat> norm_img_batch;
    for (int ino = beg_img_no; ino < end_img_no; ino++) {
      cv::Mat srcimg;
      img_list[indices[ino]].copyTo(srcimg);
      cv::Mat resize_img;
      this->resize_op_.Run(srcimg, resize_img, max_wh_ratio,
                           this->use_tensorrt_);
      this->normalize_op_.Run(&resize_img, this->mean_, this->scale_,
                              this->is_scale_);
      norm_img_batch.push_back(resize_img);
      batch_width = max(resize_img.cols, batch_width);
    }

64
    std::vector<float> input(batch_num * 3 * 32 * batch_width, 0.0f);
文幕地方's avatar
文幕地方 已提交
65 66 67 68 69 70
    this->permute_op_.Run(norm_img_batch, input.data());
    auto preprocess_end = std::chrono::steady_clock::now();
    preprocess_diff += preprocess_end - preprocess_start;
    // Inference.
    auto input_names = this->predictor_->GetInputNames();
    auto input_t = this->predictor_->GetInputHandle(input_names[0]);
71
    input_t->Reshape({batch_num, 3, 32, batch_width});
文幕地方's avatar
文幕地方 已提交
72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90
    auto inference_start = std::chrono::steady_clock::now();
    input_t->CopyFromCpu(input.data());
    this->predictor_->Run();

    std::vector<float> predict_batch;
    auto output_names = this->predictor_->GetOutputNames();
    auto output_t = this->predictor_->GetOutputHandle(output_names[0]);
    auto predict_shape = output_t->shape();

    int out_num = std::accumulate(predict_shape.begin(), predict_shape.end(), 1,
                                  std::multiplies<int>());
    predict_batch.resize(out_num);

    output_t->CopyToCpu(predict_batch.data());
    auto inference_end = std::chrono::steady_clock::now();
    inference_diff += inference_end - inference_start;
    // ctc decode
    auto postprocess_start = std::chrono::steady_clock::now();
    for (int m = 0; m < predict_shape[0]; m++) {
文幕地方's avatar
文幕地方 已提交
91
      std::string str_res;
文幕地方's avatar
文幕地方 已提交
92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108
      int argmax_idx;
      int last_index = 0;
      float score = 0.f;
      int count = 0;
      float max_value = 0.0f;

      for (int n = 0; n < predict_shape[1]; n++) {
        argmax_idx = int(Utility::argmax(
            &predict_batch[(m * predict_shape[1] + n) * predict_shape[2]],
            &predict_batch[(m * predict_shape[1] + n + 1) * predict_shape[2]]));
        max_value = float(*std::max_element(
            &predict_batch[(m * predict_shape[1] + n) * predict_shape[2]],
            &predict_batch[(m * predict_shape[1] + n + 1) * predict_shape[2]]));

        if (argmax_idx > 0 && (!(n > 0 && argmax_idx == last_index))) {
          score += max_value;
          count += 1;
文幕地方's avatar
文幕地方 已提交
109
          str_res += label_list_[argmax_idx];
M
MissPenguin 已提交
110
        }
文幕地方's avatar
文幕地方 已提交
111 112 113
        last_index = argmax_idx;
      }
      score /= count;
文幕地方's avatar
文幕地方 已提交
114
      if (isnan(score)) {
文幕地方's avatar
文幕地方 已提交
115 116
        continue;
      }
文幕地方's avatar
文幕地方 已提交
117 118
      rec_texts[indices[beg_img_no + m]] = str_res;
      rec_text_scores[indices[beg_img_no + m]] = score;
W
WenmuZhou 已提交
119
    }
文幕地方's avatar
文幕地方 已提交
120 121 122
    auto postprocess_end = std::chrono::steady_clock::now();
    postprocess_diff += postprocess_end - postprocess_start;
  }
123 124 125
  times.push_back(double(preprocess_diff.count() * 1000));
  times.push_back(double(inference_diff.count() * 1000));
  times.push_back(double(postprocess_diff.count() * 1000));
littletomatodonkey's avatar
littletomatodonkey 已提交
126 127
}

littletomatodonkey's avatar
littletomatodonkey 已提交
128
void CRNNRecognizer::LoadModel(const std::string &model_dir) {
L
LDOUBLEV 已提交
129 130
  //   AnalysisConfig config;
  paddle_infer::Config config;
文幕地方's avatar
文幕地方 已提交
131 132
  config.SetModel(model_dir + "/inference.pdmodel",
                  model_dir + "/inference.pdiparams");
littletomatodonkey's avatar
littletomatodonkey 已提交
133

littletomatodonkey's avatar
littletomatodonkey 已提交
134 135
  if (this->use_gpu_) {
    config.EnableUseGpu(this->gpu_mem_, this->gpu_id_);
L
LDOUBLEV 已提交
136
    if (this->use_tensorrt_) {
M
MissPenguin 已提交
137 138 139 140
      auto precision = paddle_infer::Config::Precision::kFloat32;
      if (this->precision_ == "fp16") {
        precision = paddle_infer::Config::Precision::kHalf;
      }
文幕地方's avatar
文幕地方 已提交
141
      if (this->precision_ == "int8") {
M
MissPenguin 已提交
142
        precision = paddle_infer::Config::Precision::kInt8;
文幕地方's avatar
文幕地方 已提交
143 144
      }
      config.EnableTensorRtEngine(1 << 20, 10, 3, precision, false, false);
M
MissPenguin 已提交
145

L
LDOUBLEV 已提交
146
      std::map<std::string, std::vector<int>> min_input_shape = {
文幕地方's avatar
文幕地方 已提交
147
          {"x", {1, 3, 32, 10}}, {"lstm_0.tmp_0", {10, 1, 96}}};
L
LDOUBLEV 已提交
148
      std::map<std::string, std::vector<int>> max_input_shape = {
文幕地方's avatar
文幕地方 已提交
149
          {"x", {1, 3, 32, 2000}}, {"lstm_0.tmp_0", {1000, 1, 96}}};
L
LDOUBLEV 已提交
150
      std::map<std::string, std::vector<int>> opt_input_shape = {
文幕地方's avatar
文幕地方 已提交
151
          {"x", {1, 3, 32, 320}}, {"lstm_0.tmp_0", {25, 1, 96}}};
L
LDOUBLEV 已提交
152 153 154

      config.SetTRTDynamicShapeInfo(min_input_shape, max_input_shape,
                                    opt_input_shape);
L
LDOUBLEV 已提交
155
    }
littletomatodonkey's avatar
littletomatodonkey 已提交
156 157
  } else {
    config.DisableGpu();
littletomatodonkey's avatar
littletomatodonkey 已提交
158 159
    if (this->use_mkldnn_) {
      config.EnableMKLDNN();
W
WenmuZhou 已提交
160 161
      // cache 10 different shapes for mkldnn to avoid memory leak
      config.SetMkldnnCacheCapacity(10);
littletomatodonkey's avatar
littletomatodonkey 已提交
162
    }
littletomatodonkey's avatar
littletomatodonkey 已提交
163 164
    config.SetCpuMathLibraryNumThreads(this->cpu_math_library_num_threads_);
  }
littletomatodonkey's avatar
littletomatodonkey 已提交
165

L
LDOUBLEV 已提交
166
  config.SwitchUseFeedFetchOps(false);
littletomatodonkey's avatar
littletomatodonkey 已提交
167
  // true for multiple input
littletomatodonkey's avatar
littletomatodonkey 已提交
168
  config.SwitchSpecifyInputNames(true);
littletomatodonkey's avatar
littletomatodonkey 已提交
169 170 171 172

  config.SwitchIrOptim(true);

  config.EnableMemoryOptim();
文幕地方's avatar
文幕地方 已提交
173
  //   config.DisableGlogInfo();
littletomatodonkey's avatar
littletomatodonkey 已提交
174

L
LDOUBLEV 已提交
175
  this->predictor_ = CreatePredictor(config);
littletomatodonkey's avatar
littletomatodonkey 已提交
176 177
}

L
littletomatodonkey 已提交
178
} // namespace PaddleOCR