// main.cc (last committed by dongshuilong)
//   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <math.h>
#include <stdarg.h>
#include <sys/stat.h>
#include <sys/types.h>

#include <algorithm>
#include <cerrno>
#include <cmath>
#include <iostream>
#include <numeric>
#include <stdexcept>
#include <string>
#include <vector>

#include "include/config_parser.h"
#include "include/object_detector.h"
#include "include/preprocess_op.h"
#include "include/recognition.h"
#include "json/json.h"

// Runtime configuration loaded from the JSON file named on the command line.
// main() fills it via PPShiTu::load_jsonf and passes it to the detector and
// recognizer constructors.
Json::Value RT_Config;

// Returns everything in `filepath` before its last OS_PATH_SEP separator.
// Returns an empty string when `filepath` contains no separator at all.
static std::string DirName(const std::string &filepath) {
  const std::size_t sep_pos = filepath.rfind(OS_PATH_SEP);
  return (sep_pos == std::string::npos) ? std::string()
                                         : filepath.substr(0, sep_pos);
}

// Returns true when stat(2) can resolve `path`, whether it is a file, a
// directory or another kind of entry. Returns false otherwise.
static bool PathExists(const std::string &path) {
  struct stat info;
  const int rc = stat(path.c_str(), &info);
  return rc == 0;
}

// Creates the single directory `path` with mode 0755. The parent directory
// must already exist. Does nothing if `path` already exists.
//
// Throws std::runtime_error("<path> mkdir failed!") if mkdir fails. The one
// exception is EEXIST: another process may have created the path after the
// PathExists() check, and that case is treated as success, just as if the
// path had existed up front.
static void MkDir(const std::string &path) {
  if (PathExists(path))
    return;
  if (mkdir(path.c_str(), 0755) != 0 && errno != EEXIST) {
    std::string path_error(path);
    path_error += " mkdir failed!";
    throw std::runtime_error(path_error);
  }
}

// Creates `path` together with any missing ancestor directories, working
// parent-first (in the spirit of `mkdir -p`). An empty path or one that
// already exists is a no-op. Errors from MkDir propagate unchanged.
static void MkDirs(const std::string &path) {
  if (path.empty() || PathExists(path)) {
    return;
  }
  const std::string parent = DirName(path);
  MkDirs(parent);
  MkDir(path);
}

// Runs `det` over `batch_imgs` in chunks of `batch_size_det` images and
// appends the kept detections of every image to `im_result`.
//
// batch_imgs      Input images, processed in order.
// im_result       Output. Detections are appended; the vector is not cleared.
// batch_size_det  Number of images handed to one det->Predict() call. Values
//                 below 1 are treated as 1, which avoids a division by zero.
// max_det_num     Only the first `max_det_num` boxes of each image are
//                 inspected. Among those, boxes with class_id == -1 are
//                 skipped.
// run_benchmark   If true, Predict is called with (50, 50) instead of (0, 1).
//                 These are presumably warmup/repeat counts; TODO confirm
//                 against ObjectDetector::Predict.
// det             Detector used for inference. Not owned; null means no-op.
void DetPredictImage(const std::vector<cv::Mat> &batch_imgs,
                     std::vector<PPShiTu::ObjectResult> &im_result,
                     const int batch_size_det, const int max_det_num,
                     const bool run_benchmark, PPShiTu::ObjectDetector *det) {
  if (det == nullptr || batch_imgs.empty()) {
    return;
  }
  const std::size_t chunk =
      static_cast<std::size_t>(std::max(batch_size_det, 1));
  const int warmup = run_benchmark ? 50 : 0;
  const int repeats = run_benchmark ? 50 : 1;

  for (std::size_t start = 0; start < batch_imgs.size(); start += chunk) {
    const std::size_t end = std::min(start + chunk, batch_imgs.size());
    // Hand only the current chunk to the detector. Passing the whole vector
    // on every step would re-detect earlier images and misattribute boxes.
    const std::vector<cv::Mat> chunk_imgs(
        batch_imgs.begin() + static_cast<std::ptrdiff_t>(start),
        batch_imgs.begin() + static_cast<std::ptrdiff_t>(end));

    std::vector<PPShiTu::ObjectResult> result;
    std::vector<int> bbox_num;
    std::vector<double> det_times;
    det->Predict(chunk_imgs, warmup, repeats, &result, &bbox_num, &det_times);

    // `result` stores the boxes of all images in the chunk back to back.
    // bbox_num[i] is the number of boxes that belong to image i.
    std::size_t item_start_idx = 0;
    const std::size_t n_imgs = std::min(chunk_imgs.size(), bbox_num.size());
    for (std::size_t i = 0; i < n_imgs; ++i) {
      const int limit = std::min(bbox_num[i], max_det_num);
      for (int j = 0; j < limit; ++j) {
        const std::size_t k = item_start_idx + static_cast<std::size_t>(j);
        if (k >= result.size()) {
          break;  // Predict reported more boxes than it returned.
        }
        const PPShiTu::ObjectResult &item = result[k];
        if (item.class_id == -1) {
          continue;
        }
        im_result.push_back(item);
      }
      item_start_idx += static_cast<std::size_t>(std::max(bbox_num[i], 0));
    }
  }
}

void PrintResult(const std::string &image_path,
115
                 std::vector<PPShiTu::ObjectResult> &det_result) {
D
dongshuilong 已提交
116
  printf("%s:\n", image_path.c_str());
D
dongshuilong 已提交
117 118 119
  for (int i = 0; i < det_result.size(); ++i) {
    printf("\tresult%d: bbox[%d, %d, %d, %d], score: %f, label: %s\n", i,
           det_result[i].rect[0], det_result[i].rect[1], det_result[i].rect[2],
120 121
           det_result[i].rect[3], det_result[i].rec_result[0].score,
           det_result[i].rec_result[0].class_name.c_str());
D
dongshuilong 已提交
122 123 124 125 126 127 128 129 130 131 132
  }
}

int main(int argc, char **argv) {
  std::cout << "Usage: " << argv[0]
            << " [config_path](option) [image_dir](option)\n";
  if (argc < 2) {
    std::cout << "Usage: ./main det_runtime_config.json" << std::endl;
    return -1;
  }
  std::string config_path = argv[1];
133
  std::string img_dir = "";
D
dongshuilong 已提交
134 135

  if (argc >= 3) {
136
    img_dir = argv[2];
D
dongshuilong 已提交
137 138 139
  }
  // Parsing command-line
  PPShiTu::load_jsonf(config_path, RT_Config);
140 141
  if (RT_Config["Global"]["det_model_path"].as<std::string>().empty()) {
    std::cout << "Please set [det_model_path] in " << config_path << std::endl;
D
dongshuilong 已提交
142 143 144
    return -1;
  }
  if (RT_Config["Global"]["infer_imgs"].as<std::string>().empty() &&
145
      img_dir.empty()) {
D
dongshuilong 已提交
146 147 148 149 150
    std::cout << "Please set [infer_imgs] in " << config_path
              << " Or use command: <" << argv[0] << " [shitu_config]"
              << " [image_dir]>" << std::endl;
    return -1;
  }
151
  if (!img_dir.empty()) {
D
dongshuilong 已提交
152 153
    std::cout << "Use image_dir in command line overide the path in config file"
              << std::endl;
154
    RT_Config["Global"]["infer_imgs_dir"] = img_dir;
D
dongshuilong 已提交
155 156 157 158
    RT_Config["Global"]["infer_imgs"] = "";
  }
  // Load model and create a object detector
  PPShiTu::ObjectDetector det(
159
      RT_Config, RT_Config["Global"]["det_model_path"].as<std::string>(),
D
dongshuilong 已提交
160 161 162 163 164 165 166 167 168 169 170 171
      RT_Config["Global"]["cpu_num_threads"].as<int>(),
      RT_Config["Global"]["batch_size"].as<int>());
  // create rec model
  PPShiTu::Recognition rec(RT_Config);
  // Do inference on input image

  std::vector<PPShiTu::ObjectResult> det_result;
  std::vector<cv::Mat> batch_imgs;
  double rec_time;
  if (!RT_Config["Global"]["infer_imgs"].as<std::string>().empty() ||
      !RT_Config["Global"]["infer_imgs_dir"].as<std::string>().empty()) {
    std::vector<std::string> all_img_paths;
D
dongshuilong 已提交
172
    std::vector<cv::String> cv_all_img_paths;
D
dongshuilong 已提交
173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188
    if (!RT_Config["Global"]["infer_imgs"].as<std::string>().empty()) {
      all_img_paths.push_back(
          RT_Config["Global"]["infer_imgs"].as<std::string>());
      if (RT_Config["Global"]["batch_size"].as<int>() > 1) {
        std::cout << "batch_size_det should be 1, when set `image_file`."
                  << std::endl;
        return -1;
      }
    } else {
      cv::glob(RT_Config["Global"]["infer_imgs_dir"].as<std::string>(),
               cv_all_img_paths);
      for (const auto &img_path : cv_all_img_paths) {
        all_img_paths.push_back(img_path);
      }
    }
    for (int i = 0; i < all_img_paths.size(); ++i) {
D
dongshuilong 已提交
189
      std::string img_path = all_img_paths[i];
D
dongshuilong 已提交
190 191 192 193 194 195 196 197 198 199 200 201 202 203
      cv::Mat srcimg = cv::imread(img_path, cv::IMREAD_COLOR);
      if (!srcimg.data) {
        std::cerr << "[ERROR] image read failed! image path: " << img_path
                  << "\n";
        exit(-1);
      }
      cv::cvtColor(srcimg, srcimg, cv::COLOR_BGR2RGB);
      batch_imgs.push_back(srcimg);
      DetPredictImage(
          batch_imgs, det_result, RT_Config["Global"]["batch_size"].as<int>(),
          RT_Config["Global"]["max_det_results"].as<int>(), false, &det);

      // add the whole image for recognition to improve recall
      PPShiTu::ObjectResult result_whole_img = {
204
          {0, 0, srcimg.cols, srcimg.rows}, 0, 1.0};
D
dongshuilong 已提交
205 206 207 208 209 210 211 212 213 214
      det_result.push_back(result_whole_img);

      // get rec result
      for (int j = 0; j < det_result.size(); ++j) {
        int w = det_result[j].rect[2] - det_result[j].rect[0];
        int h = det_result[j].rect[3] - det_result[j].rect[1];
        cv::Rect rect(det_result[j].rect[0], det_result[j].rect[1], w, h);
        cv::Mat crop_img = srcimg(rect);
        std::vector<PPShiTu::RESULT> result =
            rec.RunRecModel(crop_img, rec_time);
215
        det_result[j].rec_result.assign(result.begin(), result.end());
D
dongshuilong 已提交
216
      }
217 218 219 220
      // rec nms
      PPShiTu::nms(det_result,
                   RT_Config["Global"]["rec_nms_thresold"].as<float>(), true);
      PrintResult(img_path, det_result);
D
dongshuilong 已提交
221 222 223 224 225 226
      batch_imgs.clear();
      det_result.clear();
    }
  }
  return 0;
}