//   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <glog/logging.h>

#include <math.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <algorithm>
#include <cctype>
#include <iostream>
#include <numeric>
#include <stdexcept>
#include <string>
#include <vector>

#ifdef _WIN32
#include <direct.h>
#include <io.h>
#elif LINUX
#include <stdarg.h>
#include <sys/stat.h>
#endif

#include <gflags/gflags.h>
#include "include/object_detector.h"
Q
qingqing01 已提交
36

37
// ---------------------------------------------------------------------------
// Command-line flags (gflags). Each flag is readable as FLAGS_<name>.
// ---------------------------------------------------------------------------

// Model and input selection.
DEFINE_string(model_dir, "", "Path of inference model");
DEFINE_string(image_file, "", "Path of input image");
DEFINE_string(image_dir,
              "",
              "Dir of input image, `image_file` has a higher priority.");
DEFINE_int32(batch_size, 1, "batch_size");
DEFINE_string(video_file,
              "",
              "Path of input video, `video_file` or `camera_id` has a highest "
              "priority.");
DEFINE_int32(camera_id, -1, "Device id of camera to predict");

// Device selection.
DEFINE_bool(use_gpu,
            false,
            "Deprecated, please use `--device` to set the device you want to "
            "run.");
DEFINE_string(device,
              "CPU",
              "Choose the device you want to run, it can be: CPU/GPU/XPU, "
              "default is CPU.");

// Detection output.
DEFINE_double(threshold, 0.5, "Threshold of score.");
DEFINE_string(output_dir, "output", "Directory of output visualization files.");

// Runtime / backend configuration.
DEFINE_string(run_mode,
              "paddle",
              "Mode of running(paddle/trt_fp32/trt_fp16/trt_int8)");
DEFINE_int32(gpu_id, 0, "Device id of GPU to execute");
DEFINE_bool(run_benchmark,
            false,
            "Whether to predict a image_file repeatedly for benchmark");
DEFINE_bool(use_mkldnn, false, "Whether use mkldnn with CPU");
DEFINE_int32(cpu_threads, 1, "Num of threads with CPU");

// TensorRT shape and calibration options.
DEFINE_int32(trt_min_shape, 1, "Min shape of TRT DynamicShapeI");
DEFINE_int32(trt_max_shape, 1280, "Max shape of TRT DynamicShapeI");
DEFINE_int32(trt_opt_shape, 640, "Opt shape of TRT DynamicShapeI");
DEFINE_bool(trt_calib_mode,
            false,
            "If the model is produced by TRT offline quantitative "
            "calibration, trt_calib_mode need to set True");
G
Guanghua Yu 已提交
74

75
void PrintBenchmarkLog(std::vector<double> det_time, int img_num) {
G
Guanghua Yu 已提交
76
  LOG(INFO) << "----------------------- Config info -----------------------";
G
Guanghua Yu 已提交
77
  LOG(INFO) << "runtime_device: " << FLAGS_device;
78 79 80 81
  LOG(INFO) << "ir_optim: "
            << "True";
  LOG(INFO) << "enable_memory_optim: "
            << "True";
G
Guanghua Yu 已提交
82 83
  int has_trt = FLAGS_run_mode.find("trt");
  if (has_trt >= 0) {
84 85
    LOG(INFO) << "enable_tensorrt: "
              << "True";
G
Guanghua Yu 已提交
86 87 88
    std::string precision = FLAGS_run_mode.substr(4, 8);
    LOG(INFO) << "precision: " << precision;
  } else {
89 90 91 92
    LOG(INFO) << "enable_tensorrt: "
              << "False";
    LOG(INFO) << "precision: "
              << "fp32";
G
Guanghua Yu 已提交
93
  }
94
  LOG(INFO) << "enable_mkldnn: " << (FLAGS_use_mkldnn ? "True" : "False");
G
Guanghua Yu 已提交
95 96
  LOG(INFO) << "cpu_math_library_num_threads: " << FLAGS_cpu_threads;
  LOG(INFO) << "----------------------- Data info -----------------------";
97
  LOG(INFO) << "batch_size: " << FLAGS_batch_size;
98 99
  LOG(INFO) << "input_shape: "
            << "dynamic shape";
G
Guanghua Yu 已提交
100
  LOG(INFO) << "----------------------- Model info -----------------------";
101
  FLAGS_model_dir.erase(FLAGS_model_dir.find_last_not_of("/") + 1);
102 103
  LOG(INFO) << "model_name: "
            << FLAGS_model_dir.substr(FLAGS_model_dir.find_last_of('/') + 1);
G
Guanghua Yu 已提交
104 105
  LOG(INFO) << "----------------------- Perf info ------------------------";
  LOG(INFO) << "Total number of predicted data: " << img_num
G
Guanghua Yu 已提交
106
            << " and total time spent(ms): "
G
Guanghua Yu 已提交
107 108 109
            << std::accumulate(det_time.begin(), det_time.end(), 0);
  LOG(INFO) << "preproce_time(ms): " << det_time[0] / img_num
            << ", inference_time(ms): " << det_time[1] / img_num
110
            << ", postprocess_time(ms): " << det_time[2] / img_num;
G
Guanghua Yu 已提交
111
}
Q
qingqing01 已提交
112

113
static std::string DirName(const std::string& filepath) {
Q
qingqing01 已提交
114 115 116 117 118 119 120
  auto pos = filepath.rfind(OS_PATH_SEP);
  if (pos == std::string::npos) {
    return "";
  }
  return filepath.substr(0, pos);
}

121
// Reports whether `path` can be stat'ed, i.e. whether a file system entry
// with that name exists.
static bool PathExists(const std::string& path) {
  const char* raw_path = path.c_str();
#ifdef _WIN32
  struct _stat info;
  const int rc = _stat(raw_path, &info);
#else
  struct stat info;
  const int rc = stat(raw_path, &info);
#endif  // !_WIN32
  return rc == 0;
}

// Creates the single directory `path` unless something with that name already
// exists. Parent directories are not created here.
//
// Throws std::runtime_error when the directory cannot be created.
static void MkDir(const std::string& path) {
  if (!PathExists(path)) {
#ifdef _WIN32
    const int status = _mkdir(path.c_str());
#else
    const int status = mkdir(path.c_str(), 0755);
#endif  // !_WIN32
    if (status != 0) {
      throw std::runtime_error(path + " mkdir failed!");
    }
  }
}

// Recursively creates `path` together with any missing parent directories.
// An empty path or an already existing path is a no-op.
static void MkDirs(const std::string& path) {
  const bool nothing_to_create = path.empty() || PathExists(path);
  if (nothing_to_create) {
    return;
  }

  // Parents first, then the leaf directory itself.
  const std::string parent = DirName(path);
  MkDirs(parent);
  MkDir(path);
}

void PredictVideo(const std::string& video_path,
155 156
                  PaddleDetection::ObjectDetector* det,
                  const std::string& output_dir = "output") {
Q
qingqing01 已提交
157 158
  // Open video
  cv::VideoCapture capture;
159
  std::string video_out_name = "output.mp4";
160
  if (FLAGS_camera_id != -1) {
Q
qingqing01 已提交
161
    capture.open(FLAGS_camera_id);
162
  } else {
Q
qingqing01 已提交
163
    capture.open(video_path.c_str());
164 165
    video_out_name =
        video_path.substr(video_path.find_last_of(OS_PATH_SEP) + 1);
Q
qingqing01 已提交
166 167 168 169 170 171
  }
  if (!capture.isOpened()) {
    printf("can not open video : %s\n", video_path.c_str());
    return;
  }

172
  // Get Video info : resolution, fps, frame count
Q
qingqing01 已提交
173 174 175
  int video_width = static_cast<int>(capture.get(CV_CAP_PROP_FRAME_WIDTH));
  int video_height = static_cast<int>(capture.get(CV_CAP_PROP_FRAME_HEIGHT));
  int video_fps = static_cast<int>(capture.get(CV_CAP_PROP_FPS));
176 177
  int video_frame_count =
      static_cast<int>(capture.get(CV_CAP_PROP_FRAME_COUNT));
178
  printf("fps: %d, frame_count: %d\n", video_fps, video_frame_count);
Q
qingqing01 已提交
179 180 181

  // Create VideoWriter for output
  cv::VideoWriter video_out;
182 183 184 185 186
  std::string video_out_path(output_dir);
  if (output_dir.rfind(OS_PATH_SEP) != output_dir.size() - 1) {
    video_out_path += OS_PATH_SEP;
  }
  video_out_path += video_out_name;
Q
qingqing01 已提交
187 188 189 190 191 192 193 194 195 196 197
  video_out.open(video_out_path.c_str(),
                 0x00000021,
                 video_fps,
                 cv::Size(video_width, video_height),
                 true);
  if (!video_out.isOpened()) {
    printf("create video writer failed!\n");
    return;
  }

  std::vector<PaddleDetection::ObjectResult> result;
C
cnn 已提交
198
  std::vector<int> bbox_num;
G
Guanghua Yu 已提交
199
  std::vector<double> det_times;
Q
qingqing01 已提交
200 201 202 203
  auto labels = det->GetLabelList();
  auto colormap = PaddleDetection::GenerateColorMap(labels.size());
  // Capture all frames and do inference
  cv::Mat frame;
204
  int frame_id = 1;
C
cnn 已提交
205
  bool is_rbox = false;
Q
qingqing01 已提交
206 207 208 209
  while (capture.read(frame)) {
    if (frame.empty()) {
      break;
    }
C
cnn 已提交
210 211
    std::vector<cv::Mat> imgs;
    imgs.push_back(frame);
212 213 214
    printf("detect frame: %d\n", frame_id);
    det->Predict(imgs, FLAGS_threshold, 0, 1, &result, &bbox_num, &det_times);
    std::vector<PaddleDetection::ObjectResult> out_result;
Q
qingqing01 已提交
215
    for (const auto& item : result) {
216
      if (item.confidence < FLAGS_threshold || item.class_id == -1) {
217
        continue;
218 219
      }
      out_result.push_back(item);
220 221 222 223 224 225 226 227 228 229 230 231 232 233
      if (item.rect.size() > 6) {
        is_rbox = true;
        printf("class=%d confidence=%.4f rect=[%d %d %d %d %d %d %d %d]\n",
               item.class_id,
               item.confidence,
               item.rect[0],
               item.rect[1],
               item.rect[2],
               item.rect[3],
               item.rect[4],
               item.rect[5],
               item.rect[6],
               item.rect[7]);
      } else {
C
cnn 已提交
234
        printf("class=%d confidence=%.4f rect=[%d %d %d %d]\n",
235 236 237 238 239 240
               item.class_id,
               item.confidence,
               item.rect[0],
               item.rect[1],
               item.rect[2],
               item.rect[3]);
C
cnn 已提交
241
      }
242
    }
C
cnn 已提交
243

244
    cv::Mat out_im = PaddleDetection::VisualizeResult(
245
        frame, out_result, labels, colormap, is_rbox);
C
cnn 已提交
246

247
    video_out.write(out_im);
Q
qingqing01 已提交
248 249 250 251 252 253
    frame_id += 1;
  }
  capture.release();
  video_out.release();
}

C
cnn 已提交
254
void PredictImage(const std::vector<std::string> all_img_paths,
255 256
                  const int batch_size,
                  const double threshold,
Q
qingqing01 已提交
257 258 259
                  const bool run_benchmark,
                  PaddleDetection::ObjectDetector* det,
                  const std::string& output_dir = "output") {
G
Guanghua Yu 已提交
260
  std::vector<double> det_t = {0, 0, 0};
261 262
  int steps = ceil(float(all_img_paths.size()) / batch_size);
  printf("total images = %d, batch_size = %d, total steps = %d\n",
263 264 265
         all_img_paths.size(),
         batch_size,
         steps);
C
cnn 已提交
266 267
  for (int idx = 0; idx < steps; idx++) {
    std::vector<cv::Mat> batch_imgs;
268 269 270
    int left_image_cnt = all_img_paths.size() - idx * batch_size;
    if (left_image_cnt > batch_size) {
      left_image_cnt = batch_size;
C
cnn 已提交
271 272
    }
    for (int bs = 0; bs < left_image_cnt; bs++) {
273
      std::string image_file_path = all_img_paths.at(idx * batch_size + bs);
C
cnn 已提交
274 275 276
      cv::Mat im = cv::imread(image_file_path, 1);
      batch_imgs.insert(batch_imgs.end(), im);
    }
277

G
Guanghua Yu 已提交
278 279
    // Store all detected result
    std::vector<PaddleDetection::ObjectResult> result;
C
cnn 已提交
280
    std::vector<int> bbox_num;
G
Guanghua Yu 已提交
281
    std::vector<double> det_times;
C
cnn 已提交
282
    bool is_rbox = false;
G
Guanghua Yu 已提交
283
    if (run_benchmark) {
284 285
      det->Predict(
          batch_imgs, threshold, 10, 10, &result, &bbox_num, &det_times);
G
Guanghua Yu 已提交
286
    } else {
287
      det->Predict(batch_imgs, threshold, 0, 1, &result, &bbox_num, &det_times);
288 289 290
      // get labels and colormap
      auto labels = det->GetLabelList();
      auto colormap = PaddleDetection::GenerateColorMap(labels.size());
291

292 293 294 295 296
      int item_start_idx = 0;
      for (int i = 0; i < left_image_cnt; i++) {
        cv::Mat im = batch_imgs[i];
        std::vector<PaddleDetection::ObjectResult> im_result;
        int detect_num = 0;
297

298 299 300 301 302 303 304
        for (int j = 0; j < bbox_num[i]; j++) {
          PaddleDetection::ObjectResult item = result[item_start_idx + j];
          if (item.confidence < threshold || item.class_id == -1) {
            continue;
          }
          detect_num += 1;
          im_result.push_back(item);
305
          if (item.rect.size() > 6) {
306 307
            is_rbox = true;
            printf("class=%d confidence=%.4f rect=[%d %d %d %d %d %d %d %d]\n",
308 309 310 311 312 313 314 315 316 317 318
                   item.class_id,
                   item.confidence,
                   item.rect[0],
                   item.rect[1],
                   item.rect[2],
                   item.rect[3],
                   item.rect[4],
                   item.rect[5],
                   item.rect[6],
                   item.rect[7]);
          } else {
319
            printf("class=%d confidence=%.4f rect=[%d %d %d %d]\n",
320 321 322 323 324 325
                   item.class_id,
                   item.confidence,
                   item.rect[0],
                   item.rect[1],
                   item.rect[2],
                   item.rect[3]);
C
cnn 已提交
326
          }
C
cnn 已提交
327
        }
328 329
        std::cout << all_img_paths.at(idx * batch_size + i)
                  << " The number of detected box: " << detect_num << std::endl;
330
        item_start_idx = item_start_idx + bbox_num[i];
W
wangguanzhong 已提交
331
        // Visualization result
C
cnn 已提交
332 333
        cv::Mat vis_img = PaddleDetection::VisualizeResult(
            im, im_result, labels, colormap, is_rbox);
334 335 336 337 338 339 340 341
        std::vector<int> compression_params;
        compression_params.push_back(CV_IMWRITE_JPEG_QUALITY);
        compression_params.push_back(95);
        std::string output_path(output_dir);
        if (output_dir.rfind(OS_PATH_SEP) != output_dir.size() - 1) {
          output_path += OS_PATH_SEP;
        }
        std::string image_file_path = all_img_paths.at(idx * batch_size + i);
342 343
        output_path +=
            image_file_path.substr(image_file_path.find_last_of('/') + 1);
344
        cv::imwrite(output_path, vis_img, compression_params);
345
        printf("Visualized output saved as %s\n", output_path.c_str());
G
Guanghua Yu 已提交
346
      }
Q
qingqing01 已提交
347
    }
G
Guanghua Yu 已提交
348 349 350
    det_t[0] += det_times[0];
    det_t[1] += det_times[1];
    det_t[2] += det_times[2];
351
    det_times.clear();
Q
qingqing01 已提交
352
  }
C
cnn 已提交
353
  PrintBenchmarkLog(det_t, all_img_paths.size());
Q
qingqing01 已提交
354 355 356 357 358
}

int main(int argc, char** argv) {
  // Parsing command-line
  google::ParseCommandLineFlags(&argc, &argv, true);
359 360 361
  if (FLAGS_model_dir.empty() ||
      (FLAGS_image_file.empty() && FLAGS_image_dir.empty() &&
       FLAGS_video_file.empty())) {
362
    std::cout << "Usage: ./main --model_dir=/PATH/TO/INFERENCE_MODEL/ "
363
              << "--image_file=/PATH/TO/INPUT/IMAGE/" << std::endl;
Q
qingqing01 已提交
364 365
    return -1;
  }
366 367 368 369
  if (!(FLAGS_run_mode == "paddle" || FLAGS_run_mode == "trt_fp32" ||
        FLAGS_run_mode == "trt_fp16" || FLAGS_run_mode == "trt_int8")) {
    std::cout
        << "run_mode should be 'paddle', 'trt_fp32', 'trt_fp16' or 'trt_int8'.";
Q
qingqing01 已提交
370 371
    return -1;
  }
372 373 374 375 376 377
  transform(FLAGS_device.begin(),
            FLAGS_device.end(),
            FLAGS_device.begin(),
            ::toupper);
  if (!(FLAGS_device == "CPU" || FLAGS_device == "GPU" ||
        FLAGS_device == "XPU")) {
G
Guanghua Yu 已提交
378 379 380 381
    std::cout << "device should be 'CPU', 'GPU' or 'XPU'.";
    return -1;
  }
  if (FLAGS_use_gpu) {
382 383
    std::cout << "Deprecated, please use `--device` to set the device you want "
                 "to run.";
G
Guanghua Yu 已提交
384 385
    return -1;
  }
Q
qingqing01 已提交
386
  // Load model and create a object detector
387 388
  PaddleDetection::ObjectDetector det(FLAGS_model_dir,
                                      FLAGS_device,
389
                                      FLAGS_use_mkldnn,
390 391 392 393 394 395 396 397
                                      FLAGS_cpu_threads,
                                      FLAGS_run_mode,
                                      FLAGS_batch_size,
                                      FLAGS_gpu_id,
                                      FLAGS_trt_min_shape,
                                      FLAGS_trt_max_shape,
                                      FLAGS_trt_opt_shape,
                                      FLAGS_trt_calib_mode);
Q
qingqing01 已提交
398
  // Do inference on input video or image
399
  if (!PathExists(FLAGS_output_dir)) {
400
    MkDirs(FLAGS_output_dir);
401
  }
G
Guanghua Yu 已提交
402
  if (!FLAGS_video_file.empty() || FLAGS_camera_id != -1) {
403
    PredictVideo(FLAGS_video_file, &det, FLAGS_output_dir);
G
Guanghua Yu 已提交
404
  } else if (!FLAGS_image_file.empty() || !FLAGS_image_dir.empty()) {
C
cnn 已提交
405 406
    std::vector<std::string> all_img_paths;
    std::vector<cv::String> cv_all_img_paths;
G
Guanghua Yu 已提交
407
    if (!FLAGS_image_file.empty()) {
C
cnn 已提交
408
      all_img_paths.push_back(FLAGS_image_file);
409
      if (FLAGS_batch_size > 1) {
410 411 412
        std::cout << "batch_size should be 1, when set `image_file`."
                  << std::endl;
        return -1;
C
cnn 已提交
413
      }
G
Guanghua Yu 已提交
414
    } else {
415 416 417 418
      cv::glob(FLAGS_image_dir, cv_all_img_paths);
      for (const auto& img_path : cv_all_img_paths) {
        all_img_paths.push_back(img_path);
      }
G
Guanghua Yu 已提交
419
    }
420 421 422 423 424 425
    PredictImage(all_img_paths,
                 FLAGS_batch_size,
                 FLAGS_threshold,
                 FLAGS_run_benchmark,
                 &det,
                 FLAGS_output_dir);
Q
qingqing01 已提交
426 427 428
  }
  return 0;
}