main.cpp 12.8 KB
Newer Older
M
MissPenguin 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <sys/stat.h>

#include <chrono>
#include <cstring>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <memory>
#include <numeric>
#include <ostream>
#include <string>
#include <vector>

#include "omp.h"
#include "opencv2/core.hpp"
#include "opencv2/imgcodecs.hpp"
#include "opencv2/imgproc.hpp"
#include <gflags/gflags.h>

#include "auto_log/autolog.h"
#include <include/ocr_cls.h>
#include <include/ocr_det.h>
#include <include/ocr_rec.h>
#include <include/utility.h>
M
MissPenguin 已提交
38 39 40 41

// Command-line flags (gflags). Each DEFINE_* creates a FLAGS_<name> variable
// that the entry points below read after google::ParseCommandLineFlags runs.

// common / runtime related
DEFINE_bool(use_gpu, false, "Infering with GPU or CPU.");
DEFINE_int32(gpu_id, 0, "Device id of GPU to execute.");
// NOTE(review): the unit of gpu_mem is not visible in this file (presumably
// MB) -- confirm against the predictor configuration code.
DEFINE_int32(gpu_mem, 4000, "GPU memory size used when inferring with GPU.");
DEFINE_int32(cpu_threads, 10, "Num of threads with CPU.");
DEFINE_bool(enable_mkldnn, false, "Whether use mkldnn with CPU.");
DEFINE_bool(use_tensorrt, false, "Whether use tensorrt.");
DEFINE_string(precision, "fp32", "Precision be one of fp32/fp16/int8");
DEFINE_bool(benchmark, false, "Whether use benchmark.");
DEFINE_string(output, "./output/", "Save benchmark log path.");
// detection related
DEFINE_string(image_dir, "", "Dir of input image.");
DEFINE_string(det_model_dir, "", "Path of det inference model.");
DEFINE_int32(max_side_len, 960, "max_side_len of input image.");
DEFINE_double(det_db_thresh, 0.3, "Threshold of det_db_thresh.");
DEFINE_double(det_db_box_thresh, 0.6, "Threshold of det_db_box_thresh.");
DEFINE_double(det_db_unclip_ratio, 1.5, "Threshold of det_db_unclip_ratio.");
DEFINE_bool(use_polygon_score, false, "Whether use polygon score.");
DEFINE_bool(use_dilation, false, "Whether use the dilation on output map.");
DEFINE_bool(visualize, true, "Whether show the detection results.");
// classification related
DEFINE_bool(use_angle_cls, false, "Whether use use_angle_cls.");
DEFINE_string(cls_model_dir, "", "Path of cls inference model.");
DEFINE_double(cls_thresh, 0.9, "Threshold of cls_thresh.");
// recognition related
DEFINE_string(rec_model_dir, "", "Path of rec inference model.");
DEFINE_int32(rec_batch_num, 6, "rec_batch_num.");
DEFINE_string(rec_char_dict_path, "../../ppocr/utils/ppocr_keys_v1.txt",
              "Path of dictionary.");
M
MissPenguin 已提交
67 68 69 70 71

using namespace std;
using namespace cv;
using namespace PaddleOCR;

文幕地方's avatar
文幕地方 已提交
72
// Returns true when the platform stat call succeeds for `path`, i.e. when a
// filesystem entry with that name can be queried; returns false otherwise.
static bool PathExists(const std::string &path) {
#ifdef _WIN32
  struct _stat info;
  const int rc = _stat(path.c_str(), &info);
#else
  struct stat info;
  const int rc = stat(path.c_str(), &info);
#endif // !_WIN32
  return rc == 0;
}

M
MissPenguin 已提交
82
// Runs text detection only on every image listed in `cv_all_img_names`.
//
// For each image the detector is run; when --visualize is set the detected
// boxes are rendered to <FLAGS_output>/<image basename>; when --benchmark is
// set the box corner coordinates are printed and an AutoLogger report is
// emitted at the end.
//
// Returns 0 on completion. Terminates the process with exit(1) if an image
// cannot be read.
int main_det(std::vector<cv::String> cv_all_img_names) {
  std::vector<double> time_info = {0, 0, 0};
  DBDetector det(FLAGS_det_model_dir, FLAGS_use_gpu, FLAGS_gpu_id,
                 FLAGS_gpu_mem, FLAGS_cpu_threads, FLAGS_enable_mkldnn,
                 FLAGS_max_side_len, FLAGS_det_db_thresh,
                 FLAGS_det_db_box_thresh, FLAGS_det_db_unclip_ratio,
                 FLAGS_use_polygon_score, FLAGS_use_dilation,
                 FLAGS_use_tensorrt, FLAGS_precision);

  // Creating the output directory stays best effort, but a failure is now
  // reported instead of being silently ignored.
  // NOTE(review): mkdir(path, mode) is the POSIX form; PathExists has a
  // _WIN32 branch, so confirm how this is expected to build on Windows.
  if (!PathExists(FLAGS_output) && mkdir(FLAGS_output.c_str(), 0777) != 0) {
    std::cerr << "[WARNING] failed to create output dir: " << FLAGS_output
              << std::endl;
  }

  for (const cv::String &img_name : cv_all_img_names) {
    if (!FLAGS_benchmark) {
      std::cout << "The predict img: " << img_name << std::endl;
    }

    cv::Mat srcimg = cv::imread(img_name, cv::IMREAD_COLOR);
    if (!srcimg.data) {
      std::cerr << "[ERROR] image read failed! image path: " << img_name
                << std::endl;
      exit(1);
    }
    std::vector<std::vector<std::vector<int>>> boxes;
    std::vector<double> det_times;

    det.Run(srcimg, boxes, &det_times);

    // visualization
    if (FLAGS_visualize) {
      std::string file_name = Utility::basename(img_name);
      Utility::VisualizeBboxes(srcimg, boxes, FLAGS_output + "/" + file_name);
    }

    // Accumulate per-stage timings; bounded by both sizes so a detector that
    // reports fewer than three entries cannot cause an out-of-range read.
    for (size_t k = 0; k < time_info.size() && k < det_times.size(); ++k) {
      time_info[k] += det_times[k];
    }

    if (FLAGS_benchmark) {
      // One line per image: path, then the first two coordinates of every
      // point of every box.
      std::cout << img_name << '\t';
      for (const auto &box : boxes) {
        for (const auto &point : box) {
          std::cout << point[0] << ' ' << point[1] << ' ';
        }
      }
      std::cout << std::endl;
    }
  }

  if (FLAGS_benchmark) {
    AutoLogger autolog("ocr_det", FLAGS_use_gpu, FLAGS_use_tensorrt,
                       FLAGS_enable_mkldnn, FLAGS_cpu_threads, 1, "dynamic",
                       FLAGS_precision, time_info, cv_all_img_names.size());
    autolog.report();
  }
  return 0;
}
M
MissPenguin 已提交
138

M
MissPenguin 已提交
139
// Runs text recognition only, treating every image in `cv_all_img_names` as
// an already-cropped text line.
//
// All images are loaded first, recognised in one rec.Run call, and the text
// and score for each image are printed. With --benchmark an AutoLogger report
// is emitted at the end.
//
// Returns 0 on completion. Terminates the process with exit(1) if an image
// cannot be read.
int main_rec(std::vector<cv::String> cv_all_img_names) {
  std::vector<double> time_info = {0, 0, 0};

  // In benchmark mode the leading "../../" of the dictionary path is dropped
  // (the default value starts with it). The prefix is only removed when it is
  // actually present: a blind substr(6) throws std::out_of_range for paths
  // shorter than six characters and mangles paths with a different prefix.
  std::string rec_char_dict_path = FLAGS_rec_char_dict_path;
  const std::string parent_prefix = "../../";
  if (FLAGS_benchmark &&
      rec_char_dict_path.compare(0, parent_prefix.size(), parent_prefix) == 0) {
    rec_char_dict_path.erase(0, parent_prefix.size());
  }
  std::cout << "label file: " << rec_char_dict_path << std::endl;

  CRNNRecognizer rec(FLAGS_rec_model_dir, FLAGS_use_gpu, FLAGS_gpu_id,
                     FLAGS_gpu_mem, FLAGS_cpu_threads, FLAGS_enable_mkldnn,
                     rec_char_dict_path, FLAGS_use_tensorrt, FLAGS_precision,
                     FLAGS_rec_batch_num);

  std::vector<cv::Mat> img_list;
  img_list.reserve(cv_all_img_names.size());
  for (const cv::String &img_name : cv_all_img_names) {
    cv::Mat srcimg = cv::imread(img_name, cv::IMREAD_COLOR);
    if (!srcimg.data) {
      std::cerr << "[ERROR] image read failed! image path: " << img_name
                << std::endl;
      exit(1);
    }
    img_list.push_back(srcimg);
  }

  std::vector<std::string> rec_texts(img_list.size(), "");
  std::vector<float> rec_text_scores(img_list.size(), 0);
  std::vector<double> rec_times;
  rec.Run(img_list, rec_texts, rec_text_scores, &rec_times);

  // output rec results (rec_texts was sized from img_list, which has one
  // entry per input image, so the index is valid for cv_all_img_names too)
  for (size_t i = 0; i < rec_texts.size(); ++i) {
    std::cout << "The predict img: " << cv_all_img_names[i] << "\t"
              << rec_texts[i] << "\t" << rec_text_scores[i] << std::endl;
  }

  // Bounded by both sizes so a recognizer that reports fewer than three
  // timing entries cannot cause an out-of-range read.
  for (size_t k = 0; k < time_info.size() && k < rec_times.size(); ++k) {
    time_info[k] += rec_times[k];
  }

  if (FLAGS_benchmark) {
    AutoLogger autolog("ocr_rec", FLAGS_use_gpu, FLAGS_use_tensorrt,
                       FLAGS_enable_mkldnn, FLAGS_cpu_threads,
                       FLAGS_rec_batch_num, "dynamic", FLAGS_precision,
                       time_info, cv_all_img_names.size());
    autolog.report();
  }
  return 0;
}
M
MissPenguin 已提交
184

M
MissPenguin 已提交
185
// Runs the full pipeline on every image in `cv_all_img_names`: detection,
// optional angle classification of each cropped box (when --use_angle_cls is
// set), then recognition.
//
// Detected boxes are rendered to <FLAGS_output>/<image basename> when
// --visualize is set, and the recognised text and score of every box are
// printed. With --benchmark, separate AutoLogger reports are emitted for the
// detection and recognition stages.
//
// Returns 0 on completion. Terminates the process with exit(1) if an image
// cannot be read.
int main_system(std::vector<cv::String> cv_all_img_names) {
  std::vector<double> time_info_det = {0, 0, 0};
  std::vector<double> time_info_rec = {0, 0, 0};

  // Creating the output directory stays best effort, but a failure is now
  // reported instead of being silently ignored.
  // NOTE(review): mkdir(path, mode) is the POSIX form; PathExists has a
  // _WIN32 branch, so confirm how this is expected to build on Windows.
  if (!PathExists(FLAGS_output) && mkdir(FLAGS_output.c_str(), 0777) != 0) {
    std::cerr << "[WARNING] failed to create output dir: " << FLAGS_output
              << std::endl;
  }

  DBDetector det(FLAGS_det_model_dir, FLAGS_use_gpu, FLAGS_gpu_id,
                 FLAGS_gpu_mem, FLAGS_cpu_threads, FLAGS_enable_mkldnn,
                 FLAGS_max_side_len, FLAGS_det_db_thresh,
                 FLAGS_det_db_box_thresh, FLAGS_det_db_unclip_ratio,
                 FLAGS_use_polygon_score, FLAGS_use_dilation,
                 FLAGS_use_tensorrt, FLAGS_precision);

  // Owned through unique_ptr so the classifier is released on return (the
  // raw pointer used previously was never deleted). reset(new ...) is used
  // rather than std::make_unique to stay buildable as C++11.
  std::unique_ptr<Classifier> cls;
  if (FLAGS_use_angle_cls) {
    cls.reset(new Classifier(FLAGS_cls_model_dir, FLAGS_use_gpu, FLAGS_gpu_id,
                             FLAGS_gpu_mem, FLAGS_cpu_threads,
                             FLAGS_enable_mkldnn, FLAGS_cls_thresh,
                             FLAGS_use_tensorrt, FLAGS_precision));
  }

  // In benchmark mode the leading "../../" of the dictionary path is dropped
  // (the default value starts with it). The prefix is only removed when it is
  // actually present: a blind substr(6) throws std::out_of_range for paths
  // shorter than six characters and mangles paths with a different prefix.
  std::string rec_char_dict_path = FLAGS_rec_char_dict_path;
  const std::string parent_prefix = "../../";
  if (FLAGS_benchmark &&
      rec_char_dict_path.compare(0, parent_prefix.size(), parent_prefix) == 0) {
    rec_char_dict_path.erase(0, parent_prefix.size());
  }
  std::cout << "label file: " << rec_char_dict_path << std::endl;

  CRNNRecognizer rec(FLAGS_rec_model_dir, FLAGS_use_gpu, FLAGS_gpu_id,
                     FLAGS_gpu_mem, FLAGS_cpu_threads, FLAGS_enable_mkldnn,
                     rec_char_dict_path, FLAGS_use_tensorrt, FLAGS_precision,
                     FLAGS_rec_batch_num);

  for (const cv::String &img_name : cv_all_img_names) {
    std::cout << "The predict img: " << img_name << std::endl;

    cv::Mat srcimg = cv::imread(img_name, cv::IMREAD_COLOR);
    if (!srcimg.data) {
      std::cerr << "[ERROR] image read failed! image path: " << img_name
                << std::endl;
      exit(1);
    }

    // det
    std::vector<std::vector<std::vector<int>>> boxes;
    std::vector<double> det_times;
    std::vector<double> rec_times;

    det.Run(srcimg, boxes, &det_times);
    if (FLAGS_visualize) {
      std::string file_name = Utility::basename(img_name);
      Utility::VisualizeBboxes(srcimg, boxes, FLAGS_output + "/" + file_name);
    }
    // Bounded by both sizes so fewer than three reported timings cannot
    // cause an out-of-range read.
    for (size_t k = 0; k < time_info_det.size() && k < det_times.size(); ++k) {
      time_info_det[k] += det_times[k];
    }

    // rec: crop every detected box, optionally pass it through the angle
    // classifier, and recognise all crops of this image in one call.
    std::vector<cv::Mat> img_list;
    img_list.reserve(boxes.size());
    for (auto &box : boxes) {
      cv::Mat crop_img = Utility::GetRotateCropImage(srcimg, box);
      if (cls) {
        crop_img = cls->Run(crop_img);
      }
      img_list.push_back(crop_img);
    }
    std::vector<std::string> rec_texts(img_list.size(), "");
    std::vector<float> rec_text_scores(img_list.size(), 0);
    rec.Run(img_list, rec_texts, rec_text_scores, &rec_times);

    // output rec results: box index, text, score
    for (size_t box_idx = 0; box_idx < rec_texts.size(); ++box_idx) {
      std::cout << box_idx << "\t" << rec_texts[box_idx] << "\t"
                << rec_text_scores[box_idx] << std::endl;
    }
    for (size_t k = 0; k < time_info_rec.size() && k < rec_times.size(); ++k) {
      time_info_rec[k] += rec_times[k];
    }
  }

  if (FLAGS_benchmark) {
    AutoLogger autolog_det("ocr_det", FLAGS_use_gpu, FLAGS_use_tensorrt,
                           FLAGS_enable_mkldnn, FLAGS_cpu_threads, 1, "dynamic",
                           FLAGS_precision, time_info_det,
                           cv_all_img_names.size());
    AutoLogger autolog_rec("ocr_rec", FLAGS_use_gpu, FLAGS_use_tensorrt,
                           FLAGS_enable_mkldnn, FLAGS_cpu_threads,
                           FLAGS_rec_batch_num, "dynamic", FLAGS_precision,
                           time_info_rec, cv_all_img_names.size());
    autolog_det.report();
    std::cout << std::endl;
    autolog_rec.report();
  }
  return 0;
}

// Validates the parsed command-line flags for the selected `mode`
// ("det", "rec" or "system").
//
// When a flag required by the mode is empty, the matching usage text is
// printed to stdout and the process terminates with exit(1). Independently of
// the mode, --precision must be one of fp32 / fp16 / int8, otherwise a message
// is printed and the process terminates with exit(1).
void check_params(char *mode) {
  if (strcmp(mode, "det") == 0) {
    if (FLAGS_det_model_dir.empty() || FLAGS_image_dir.empty()) {
      std::cout << "Usage[det]: ./ppocr "
                   "--det_model_dir=/PATH/TO/DET_INFERENCE_MODEL/ "
                << "--image_dir=/PATH/TO/INPUT/IMAGE/" << std::endl;
      exit(1);
    }
  }
  if (strcmp(mode, "rec") == 0) {
    if (FLAGS_rec_model_dir.empty() || FLAGS_image_dir.empty()) {
      std::cout << "Usage[rec]: ./ppocr "
                   "--rec_model_dir=/PATH/TO/REC_INFERENCE_MODEL/ "
                << "--image_dir=/PATH/TO/INPUT/IMAGE/" << std::endl;
      exit(1);
    }
  }
  if (strcmp(mode, "system") == 0) {
    // The classifier model directory is only required when angle
    // classification is switched on.
    if ((FLAGS_det_model_dir.empty() || FLAGS_rec_model_dir.empty() ||
         FLAGS_image_dir.empty()) ||
        (FLAGS_use_angle_cls && FLAGS_cls_model_dir.empty())) {
      std::cout << "Usage[system without angle cls]: ./ppocr "
                   "--det_model_dir=/PATH/TO/DET_INFERENCE_MODEL/ "
                << "--rec_model_dir=/PATH/TO/REC_INFERENCE_MODEL/ "
                << "--image_dir=/PATH/TO/INPUT/IMAGE/" << std::endl;
      std::cout << "Usage[system with angle cls]: ./ppocr "
                   "--det_model_dir=/PATH/TO/DET_INFERENCE_MODEL/ "
                << "--use_angle_cls=true "
                << "--cls_model_dir=/PATH/TO/CLS_INFERENCE_MODEL/ "
                << "--rec_model_dir=/PATH/TO/REC_INFERENCE_MODEL/ "
                << "--image_dir=/PATH/TO/INPUT/IMAGE/" << std::endl;
      exit(1);
    }
  }
  if (FLAGS_precision != "fp32" && FLAGS_precision != "fp16" &&
      FLAGS_precision != "int8") {
    // Message previously misspelled the flag name as "precison".
    std::cout << "precision should be 'fp32'(default), 'fp16' or 'int8'. "
              << std::endl;
    exit(1);
  }
}

M
MissPenguin 已提交
320
int main(int argc, char **argv) {
文幕地方's avatar
文幕地方 已提交
321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341
  if (argc <= 1 ||
      (strcmp(argv[1], "det") != 0 && strcmp(argv[1], "rec") != 0 &&
       strcmp(argv[1], "system") != 0)) {
    std::cout << "Please choose one mode of [det, rec, system] !" << std::endl;
    return -1;
  }
  std::cout << "mode: " << argv[1] << endl;

  // Parsing command-line
  google::ParseCommandLineFlags(&argc, &argv, true);
  check_params(argv[1]);

  if (!PathExists(FLAGS_image_dir)) {
    std::cerr << "[ERROR] image path not exist! image_dir: " << FLAGS_image_dir
              << endl;
    exit(1);
  }

  std::vector<cv::String> cv_all_img_names;
  cv::glob(FLAGS_image_dir, cv_all_img_names);
  std::cout << "total images num: " << cv_all_img_names.size() << endl;
M
MissPenguin 已提交
342

文幕地方's avatar
文幕地方 已提交
343 344 345 346 347 348 349 350 351
  if (strcmp(argv[1], "det") == 0) {
    return main_det(cv_all_img_names);
  }
  if (strcmp(argv[1], "rec") == 0) {
    return main_rec(cv_all_img_names);
  }
  if (strcmp(argv[1], "system") == 0) {
    return main_system(cv_all_img_names);
  }
M
MissPenguin 已提交
352
}