ocr_det.cpp 5.2 KB
Newer Older
littletomatodonkey's avatar
littletomatodonkey 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <include/ocr_det.h>
M
MissPenguin 已提交
16

littletomatodonkey's avatar
littletomatodonkey 已提交
17 18
namespace PaddleOCR {

littletomatodonkey's avatar
littletomatodonkey 已提交
19
void DBDetector::LoadModel(const std::string &model_dir) {
L
LDOUBLEV 已提交
20 21
  //   AnalysisConfig config;
  paddle_infer::Config config;
文幕地方's avatar
文幕地方 已提交
22 23
  config.SetModel(model_dir + "/inference.pdmodel",
                  model_dir + "/inference.pdiparams");
littletomatodonkey's avatar
littletomatodonkey 已提交
24

littletomatodonkey's avatar
littletomatodonkey 已提交
25 26
  if (this->use_gpu_) {
    config.EnableUseGpu(this->gpu_mem_, this->gpu_id_);
L
LDOUBLEV 已提交
27
    if (this->use_tensorrt_) {
M
MissPenguin 已提交
28 29 30 31
      auto precision = paddle_infer::Config::Precision::kFloat32;
      if (this->precision_ == "fp16") {
        precision = paddle_infer::Config::Precision::kHalf;
      }
文幕地方's avatar
文幕地方 已提交
32
      if (this->precision_ == "int8") {
M
MissPenguin 已提交
33
        precision = paddle_infer::Config::Precision::kInt8;
文幕地方's avatar
文幕地方 已提交
34
      }
35 36 37 38 39 40
      config.EnableTensorRtEngine(1 << 30, 1, 20, precision, false, false);
      if (!Utility::PathExists("./trt_det_shape.txt")){
        config.CollectShapeRangeInfo("./trt_det_shape.txt");
      } else { 
        config.EnableTunedTensorRtDynamicShape("./trt_det_shape.txt", true);
      }
L
LDOUBLEV 已提交
41
    }
littletomatodonkey's avatar
littletomatodonkey 已提交
42 43
  } else {
    config.DisableGpu();
littletomatodonkey's avatar
littletomatodonkey 已提交
44 45
    if (this->use_mkldnn_) {
      config.EnableMKLDNN();
W
WenmuZhou 已提交
46 47
      // cache 10 different shapes for mkldnn to avoid memory leak
      config.SetMkldnnCacheCapacity(10);
littletomatodonkey's avatar
littletomatodonkey 已提交
48
    }
littletomatodonkey's avatar
littletomatodonkey 已提交
49 50
    config.SetCpuMathLibraryNumThreads(this->cpu_math_library_num_threads_);
  }
L
LDOUBLEV 已提交
51 52
  // use zero_copy_run as default
  config.SwitchUseFeedFetchOps(false);
littletomatodonkey's avatar
littletomatodonkey 已提交
53
  // true for multiple input
littletomatodonkey's avatar
littletomatodonkey 已提交
54
  config.SwitchSpecifyInputNames(true);
littletomatodonkey's avatar
littletomatodonkey 已提交
55 56 57 58

  config.SwitchIrOptim(true);

  config.EnableMemoryOptim();
L
LDOUBLEV 已提交
59
  // config.DisableGlogInfo();
littletomatodonkey's avatar
littletomatodonkey 已提交
60

L
LDOUBLEV 已提交
61
  this->predictor_ = CreatePredictor(config);
littletomatodonkey's avatar
littletomatodonkey 已提交
62 63 64
}

void DBDetector::Run(cv::Mat &img,
M
MissPenguin 已提交
65
                     std::vector<std::vector<std::vector<int>>> &boxes,
66
                     std::vector<double> &times) {
littletomatodonkey's avatar
littletomatodonkey 已提交
67 68 69 70 71 72
  float ratio_h{};
  float ratio_w{};

  cv::Mat srcimg;
  cv::Mat resize_img;
  img.copyTo(srcimg);
文幕地方's avatar
文幕地方 已提交
73

M
MissPenguin 已提交
74
  auto preprocess_start = std::chrono::steady_clock::now();
文幕地方's avatar
文幕地方 已提交
75 76
  this->resize_op_.Run(img, resize_img, this->limit_type_,
                       this->limit_side_len_, ratio_h, ratio_w,
R
root 已提交
77
                       this->use_tensorrt_);
littletomatodonkey's avatar
littletomatodonkey 已提交
78 79 80 81

  this->normalize_op_.Run(&resize_img, this->mean_, this->scale_,
                          this->is_scale_);

littletomatodonkey's avatar
littletomatodonkey 已提交
82 83
  std::vector<float> input(1 * 3 * resize_img.rows * resize_img.cols, 0.0f);
  this->permute_op_.Run(&resize_img, input.data());
M
MissPenguin 已提交
84
  auto preprocess_end = std::chrono::steady_clock::now();
文幕地方's avatar
文幕地方 已提交
85

86
  // Inference.
L
LDOUBLEV 已提交
87 88 89
  auto input_names = this->predictor_->GetInputNames();
  auto input_t = this->predictor_->GetInputHandle(input_names[0]);
  input_t->Reshape({1, 3, resize_img.rows, resize_img.cols});
M
update  
MissPenguin 已提交
90
  auto inference_start = std::chrono::steady_clock::now();
L
LDOUBLEV 已提交
91
  input_t->CopyFromCpu(input.data());
文幕地方's avatar
文幕地方 已提交
92

L
LDOUBLEV 已提交
93
  this->predictor_->Run();
文幕地方's avatar
文幕地方 已提交
94

littletomatodonkey's avatar
littletomatodonkey 已提交
95 96
  std::vector<float> out_data;
  auto output_names = this->predictor_->GetOutputNames();
L
LDOUBLEV 已提交
97
  auto output_t = this->predictor_->GetOutputHandle(output_names[0]);
littletomatodonkey's avatar
littletomatodonkey 已提交
98 99 100 101 102
  std::vector<int> output_shape = output_t->shape();
  int out_num = std::accumulate(output_shape.begin(), output_shape.end(), 1,
                                std::multiplies<int>());

  out_data.resize(out_num);
L
LDOUBLEV 已提交
103
  output_t->CopyToCpu(out_data.data());
M
MissPenguin 已提交
104
  auto inference_end = std::chrono::steady_clock::now();
文幕地方's avatar
文幕地方 已提交
105

M
MissPenguin 已提交
106
  auto postprocess_start = std::chrono::steady_clock::now();
littletomatodonkey's avatar
littletomatodonkey 已提交
107 108 109 110
  int n2 = output_shape[2];
  int n3 = output_shape[3];
  int n = n2 * n3;

littletomatodonkey's avatar
littletomatodonkey 已提交
111 112
  std::vector<float> pred(n, 0.0);
  std::vector<unsigned char> cbuf(n, ' ');
littletomatodonkey's avatar
littletomatodonkey 已提交
113 114 115 116 117 118

  for (int i = 0; i < n; i++) {
    pred[i] = float(out_data[i]);
    cbuf[i] = (unsigned char)((out_data[i]) * 255);
  }

littletomatodonkey's avatar
littletomatodonkey 已提交
119 120
  cv::Mat cbuf_map(n2, n3, CV_8UC1, (unsigned char *)cbuf.data());
  cv::Mat pred_map(n2, n3, CV_32F, (float *)pred.data());
littletomatodonkey's avatar
littletomatodonkey 已提交
121

littletomatodonkey's avatar
littletomatodonkey 已提交
122
  const double threshold = this->det_db_thresh_ * 255;
littletomatodonkey's avatar
littletomatodonkey 已提交
123 124 125
  const double maxvalue = 255;
  cv::Mat bit_map;
  cv::threshold(cbuf_map, bit_map, threshold, maxvalue, cv::THRESH_BINARY);
文幕地方's avatar
文幕地方 已提交
126 127 128 129 130 131
  if (this->use_dilation_) {
    cv::Mat dila_ele =
        cv::getStructuringElement(cv::MORPH_RECT, cv::Size(2, 2));
    cv::dilate(bit_map, bit_map, dila_ele);
  }

132
  boxes = post_processor_.BoxesFromBitmap(
文幕地方's avatar
文幕地方 已提交
133
      pred_map, bit_map, this->det_db_box_thresh_, this->det_db_unclip_ratio_,
134
      this->det_db_score_mode_);
littletomatodonkey's avatar
littletomatodonkey 已提交
135

littletomatodonkey's avatar
littletomatodonkey 已提交
136
  boxes = post_processor_.FilterTagDetRes(boxes, ratio_h, ratio_w, srcimg);
M
MissPenguin 已提交
137 138
  auto postprocess_end = std::chrono::steady_clock::now();

文幕地方's avatar
文幕地方 已提交
139 140
  std::chrono::duration<float> preprocess_diff =
      preprocess_end - preprocess_start;
141
  times.push_back(double(preprocess_diff.count() * 1000));
M
MissPenguin 已提交
142
  std::chrono::duration<float> inference_diff = inference_end - inference_start;
143
  times.push_back(double(inference_diff.count() * 1000));
文幕地方's avatar
文幕地方 已提交
144 145
  std::chrono::duration<float> postprocess_diff =
      postprocess_end - postprocess_start;
146
  times.push_back(double(postprocess_diff.count() * 1000));
littletomatodonkey's avatar
littletomatodonkey 已提交
147 148
}

littletomatodonkey's avatar
littletomatodonkey 已提交
149
} // namespace PaddleOCR