main.cpp 10.8 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <algorithm>
#include <chrono>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <numeric>
#include <ostream>
#include <string>
#include <vector>

#include "opencv2/core.hpp"
#include "opencv2/imgcodecs.hpp"
#include "opencv2/imgproc.hpp"
#include <opencv2/core/utils/filesystem.hpp>

#include <auto_log/autolog.h>
#include <gflags/gflags.h>

#include <include/feature_extracter.h>
#include <include/nms.h>
#include <include/object_detector.h>
#include <include/vector_search.h>
#include <include/yaml_config.h>

using namespace std;
using namespace cv;
39

D
dongshuilong 已提交
40 41 42
// Two spellings of the same option: the YAML config path. main() requires at
// least one of them and uses --config when both are given.
DEFINE_string(config, "", "Path of yaml file");
DEFINE_string(c, "", "Path of yaml file");

43 44
void DetPredictImage(const std::vector<cv::Mat> &batch_imgs,
                     const std::vector<std::string> &all_img_paths,
D
dongshuilong 已提交
45 46
                     const int batch_size, Detection::ObjectDetector *det,
                     std::vector<Detection::ObjectResult> &im_result,
47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65
                     std::vector<int> &im_bbox_num, std::vector<double> &det_t,
                     const bool visual_det = false,
                     const bool run_benchmark = false,
                     const std::string &output_dir = "output") {
  int steps = ceil(float(all_img_paths.size()) / batch_size);
  //   printf("total images = %d, batch_size = %d, total steps = %d\n",
  //                 all_img_paths.size(), batch_size, steps);
  for (int idx = 0; idx < steps; idx++) {
    int left_image_cnt = all_img_paths.size() - idx * batch_size;
    if (left_image_cnt > batch_size) {
      left_image_cnt = batch_size;
    }
    // for (int bs = 0; bs < left_image_cnt; bs++) {
    // std::string image_file_path = all_img_paths.at(idx * batch_size+bs);
    // cv::Mat im = cv::imread(image_file_path, 1);
    // batch_imgs.insert(batch_imgs.end(), im);
    // }

    // Store all detected result
D
dongshuilong 已提交
66
    std::vector<Detection::ObjectResult> result;
67 68 69 70 71 72 73 74 75
    std::vector<int> bbox_num;
    std::vector<double> det_times;
    bool is_rbox = false;
    if (run_benchmark) {
      det->Predict(batch_imgs, 10, 10, &result, &bbox_num, &det_times);
    } else {
      det->Predict(batch_imgs, 0, 1, &result, &bbox_num, &det_times);
      // get labels and colormap
      auto labels = det->GetLabelList();
D
dongshuilong 已提交
76
      auto colormap = Detection::GenerateColorMap(labels.size());
77 78 79 80 81 82 83

      int item_start_idx = 0;
      for (int i = 0; i < left_image_cnt; i++) {
        cv::Mat im = batch_imgs[i];
        int detect_num = 0;

        for (int j = 0; j < bbox_num[i]; j++) {
D
dongshuilong 已提交
84
          Detection::ObjectResult item = result[item_start_idx + j];
85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112
          if (item.confidence < det->GetThreshold() || item.class_id == -1) {
            continue;
          }
          detect_num += 1;
          im_result.push_back(item);
          if (visual_det) {
            if (item.rect.size() > 6) {
              is_rbox = true;
              printf(
                  "class=%d confidence=%.4f rect=[%d %d %d %d %d %d %d %d]\n",
                  item.class_id, item.confidence, item.rect[0], item.rect[1],
                  item.rect[2], item.rect[3], item.rect[4], item.rect[5],
                  item.rect[6], item.rect[7]);
            } else {
              printf("class=%d confidence=%.4f rect=[%d %d %d %d]\n",
                     item.class_id, item.confidence, item.rect[0], item.rect[1],
                     item.rect[2], item.rect[3]);
            }
          }
        }
        im_bbox_num.push_back(detect_num);
        item_start_idx = item_start_idx + bbox_num[i];

        // Visualization result
        if (visual_det) {
          std::cout << all_img_paths.at(idx * batch_size + i)
                    << " The number of detected box: " << detect_num
                    << std::endl;
D
dongshuilong 已提交
113 114
          cv::Mat vis_img = Detection::VisualizeResult(im, im_result, labels,
                                                       colormap, is_rbox);
115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134
          std::vector<int> compression_params;
          compression_params.push_back(CV_IMWRITE_JPEG_QUALITY);
          compression_params.push_back(95);
          std::string output_path(output_dir);
          if (output_dir.rfind(OS_PATH_SEP) != output_dir.size() - 1) {
            output_path += OS_PATH_SEP;
          }
          std::string image_file_path = all_img_paths.at(idx * batch_size + i);
          output_path +=
              image_file_path.substr(image_file_path.find_last_of('/') + 1);
          cv::imwrite(output_path, vis_img, compression_params);
          printf("Visualized output saved as %s\n", output_path.c_str());
        }
      }
    }
    det_t[0] += det_times[0];
    det_t[1] += det_times[1];
    det_t[2] += det_times[2];
  }
}
135

D
dongshuilong 已提交
136
void PrintResult(std::string &img_path,
D
dongshuilong 已提交
137
                 std::vector<Detection::ObjectResult> &det_result,
D
dongshuilong 已提交
138 139 140 141 142 143 144 145 146 147 148 149 150
                 std::vector<int> &indeices, VectorSearch &vector_search,
                 SearchResult &search_result) {
  printf("%s:\n", img_path.c_str());
  for (int i = 0; i < indeices.size(); ++i) {
    int t = indeices[i];
    printf("\tresult%d: bbox[%d, %d, %d, %d], score: %f, label: %s\n", i,
           det_result[t].rect[0], det_result[t].rect[1], det_result[t].rect[2],
           det_result[t].rect[3], det_result[t].confidence,
           vector_search.GetLabel(search_result.I[search_result.return_k * t])
               .c_str());
  }
}

151
int main(int argc, char **argv) {
D
dongshuilong 已提交
152 153 154 155 156 157 158
  google::ParseCommandLineFlags(&argc, &argv, true);
  std::string yaml_path = "";
  if (FLAGS_config == "" && FLAGS_c == "") {
    std::cerr << "[ERROR] usage: " << std::endl
              << argv[0] << " -c $yaml_path" << std::endl
              << "or:" << std::endl
              << argv[0] << " -config $yaml_path" << std::endl;
159
    exit(1);
D
dongshuilong 已提交
160 161 162 163
  } else if (FLAGS_config != "") {
    yaml_path = FLAGS_config;
  } else {
    yaml_path = FLAGS_c;
164 165
  }

D
dongshuilong 已提交
166
  YamlConfig config(yaml_path);
167 168
  config.PrintConfigInfo();

D
dongshuilong 已提交
169
  // initialize detector, rec_Model, vector_search
D
dongshuilong 已提交
170 171
  Feature::FeatureExtracter feature_extracter(config.config_file);
  Detection::ObjectDetector detector(config.config_file);
D
dongshuilong 已提交
172 173
  VectorSearch searcher(config.config_file);

174 175 176 177 178 179 180 181 182 183 184 185 186 187
  // config
  const int batch_size = config.config_file["Global"]["batch_size"].as<int>();
  bool visual_det = false;
  if (config.config_file["Global"]["visual_det"].IsDefined()) {
    visual_det = config.config_file["Global"]["visual_det"].as<bool>();
  }
  bool run_benchmark = false;
  if (config.config_file["Global"]["benchmark"].IsDefined()) {
    run_benchmark = config.config_file["Global"]["benchmark"].as<bool>();
  }
  int max_det_results = 5;
  if (config.config_file["Global"]["max_det_results"].IsDefined()) {
    max_det_results = config.config_file["Global"]["max_det_results"].as<int>();
  }
D
dongshuilong 已提交
188 189 190 191 192
  float rec_nms_thresold = 0.05;
  if (config.config_file["Global"]["rec_nms_thresold"].IsDefined()) {
    rec_nms_thresold =
        config.config_file["Global"]["rec_nms_thresold"].as<float>();
  }
193

D
dongshuilong 已提交
194
  // load image_file_path
195 196 197 198 199 200 201 202 203 204 205 206 207
  std::string path =
      config.config_file["Global"]["infer_imgs"].as<std::string>();
  std::vector<std::string> img_files_list;
  if (cv::utils::fs::isDirectory(path)) {
    std::vector<cv::String> filenames;
    cv::glob(path, filenames);
    for (auto f : filenames) {
      img_files_list.push_back(f);
    }
  } else {
    img_files_list.push_back(path);
  }
  std::cout << "img_file_list length: " << img_files_list.size() << std::endl;
D
dongshuilong 已提交
208
  // for time log
209 210
  std::vector<double> cls_times = {0, 0, 0};
  std::vector<double> det_times = {0, 0, 0};
D
dongshuilong 已提交
211
  // for read images
212 213
  std::vector<cv::Mat> batch_imgs;
  std::vector<std::string> img_paths;
D
dongshuilong 已提交
214
  // for detection
D
dongshuilong 已提交
215
  std::vector<Detection::ObjectResult> det_result;
216
  std::vector<int> det_bbox_num;
D
dongshuilong 已提交
217
  // for vector search
D
dongshuilong 已提交
218 219
  std::vector<float> features;
  std::vector<float> feature;
D
dongshuilong 已提交
220 221
  // for nms
  std::vector<int> indeices;
222

223 224 225 226 227 228 229 230 231 232 233
  int warmup_iter = img_files_list.size() > 5 ? 5 : 0;
  for (int idx = 0; idx < img_files_list.size(); ++idx) {
    std::string img_path = img_files_list[idx];
    cv::Mat srcimg = cv::imread(img_path, cv::IMREAD_COLOR);
    if (!srcimg.data) {
      std::cerr << "[ERROR] image read failed! image path: " << img_path
                << "\n";
      exit(-1);
    }
    cv::cvtColor(srcimg, srcimg, cv::COLOR_BGR2RGB);

234 235 236 237 238 239 240 241
    batch_imgs.push_back(srcimg);
    img_paths.push_back(img_path);

    // step1: get all detection results
    DetPredictImage(batch_imgs, img_paths, batch_size, &detector, det_result,
                    det_bbox_num, det_times, visual_det, run_benchmark);

    // select max_det_results bbox
D
dongshuilong 已提交
242 243
    if (det_result.size() > max_det_results) {
      det_result.resize(max_det_results);
244 245
    }
    // step2: add the whole image for recognition to improve recall
D
dongshuilong 已提交
246
    Detection::ObjectResult result_whole_img = {
247 248 249 250
        {0, 0, srcimg.cols - 1, srcimg.rows - 1}, 0, 1.0};
    det_result.push_back(result_whole_img);
    det_bbox_num[0] = det_result.size() + 1;

D
dongshuilong 已提交
251 252
    // step3: extract feature for all boxes in an inmage
    SearchResult search_result;
253 254 255 256 257
    for (int j = 0; j < det_result.size(); ++j) {
      int w = det_result[j].rect[2] - det_result[j].rect[0];
      int h = det_result[j].rect[3] - det_result[j].rect[1];
      cv::Rect rect(det_result[j].rect[0], det_result[j].rect[1], w, h);
      cv::Mat crop_img = srcimg(rect);
D
dongshuilong 已提交
258
      feature_extracter.Run(crop_img, feature, cls_times);
D
dongshuilong 已提交
259
      features.insert(features.end(), feature.begin(), feature.end());
260
    }
D
dongshuilong 已提交
261 262 263 264 265

    // step4: get search result
    search_result = searcher.Search(features.data(), det_result.size());

    // nms for search result
D
dongshuilong 已提交
266 267 268
    for (int i = 0; i < det_result.size(); ++i) {
      det_result[i].confidence = search_result.D[search_result.return_k * i];
    }
D
dongshuilong 已提交
269
    NMSBoxes(det_result, searcher.GetThreshold(), rec_nms_thresold, indeices);
D
dongshuilong 已提交
270 271 272

    // print result
    PrintResult(img_path, det_result, indeices, searcher, search_result);
D
dongshuilong 已提交
273 274

    // for postprocess
275 276 277 278
    batch_imgs.clear();
    img_paths.clear();
    det_bbox_num.clear();
    det_result.clear();
D
dongshuilong 已提交
279 280
    feature.clear();
    features.clear();
D
dongshuilong 已提交
281
    indeices.clear();
282 283 284 285 286 287 288 289 290 291 292 293 294 295 296
  }

  std::string presion = "fp32";

  // if (config.use_fp16)
  //   presion = "fp16";
  // if (config.benchmark) {
  //   AutoLogger autolog("Classification", config.use_gpu, config.use_tensorrt,
  //                      config.use_mkldnn, config.cpu_threads, 1,
  //                      "1, 3, 224, 224", presion, cls_times,
  //                      img_files_list.size());
  //   autolog.report();
  // }
  return 0;
}