object_detector.cc 15.3 KB
Newer Older
Q
qingqing01 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
//   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <sstream>
// for setprecision
Z
zlsh80826 已提交
16
#include <chrono>
17
#include <iomanip>
Q
qingqing01 已提交
18

19
#include "include/object_detector.h"
Q
qingqing01 已提交
20 21 22 23

namespace PaddleDetection {

// Load Model and create model predictor
24
void ObjectDetector::LoadModel(const std::string &model_dir,
Q
qingqing01 已提交
25
                               const int batch_size,
26
                               const std::string &run_mode) {
Q
qingqing01 已提交
27 28 29 30
  paddle_infer::Config config;
  std::string prog_file = model_dir + OS_PATH_SEP + "model.pdmodel";
  std::string params_file = model_dir + OS_PATH_SEP + "model.pdiparams";
  config.SetModel(prog_file, params_file);
G
Guanghua Yu 已提交
31
  if (this->device_ == "GPU") {
G
Guanghua Yu 已提交
32
    config.EnableUseGpu(200, this->gpu_id_);
Q
qingqing01 已提交
33
    config.SwitchIrOptim(true);
34
    // use tensorrt
35
    if (run_mode != "paddle") {
Q
qingqing01 已提交
36
      auto precision = paddle_infer::Config::Precision::kFloat32;
37 38
      if (run_mode == "trt_fp32") {
        precision = paddle_infer::Config::Precision::kFloat32;
39
      } else if (run_mode == "trt_fp16") {
Q
qingqing01 已提交
40
        precision = paddle_infer::Config::Precision::kHalf;
41
      } else if (run_mode == "trt_int8") {
42
        precision = paddle_infer::Config::Precision::kInt8;
Q
qingqing01 已提交
43
      } else {
44 45
        printf("run_mode should be 'paddle', 'trt_fp32', 'trt_fp16' or "
               "'trt_int8'");
Q
qingqing01 已提交
46
      }
47
      // set tensorrt
48 49
      config.EnableTensorRtEngine(1 << 30, batch_size, this->min_subgraph_size_,
                                  precision, false, this->trt_calib_mode_);
50 51

      // set use dynamic shape
G
Guanghua Yu 已提交
52
      if (this->use_dynamic_shape_) {
53
        // set DynamicShape for image tensor
54
        const std::vector<int> min_input_shape = {
55
            batch_size, 3, this->trt_min_shape_, this->trt_min_shape_};
56
        const std::vector<int> max_input_shape = {
57
            batch_size, 3, this->trt_max_shape_, this->trt_max_shape_};
58
        const std::vector<int> opt_input_shape = {
59
            batch_size, 3, this->trt_opt_shape_, this->trt_opt_shape_};
60 61 62 63 64 65
        const std::map<std::string, std::vector<int>> map_min_input_shape = {
            {"image", min_input_shape}};
        const std::map<std::string, std::vector<int>> map_max_input_shape = {
            {"image", max_input_shape}};
        const std::map<std::string, std::vector<int>> map_opt_input_shape = {
            {"image", opt_input_shape}};
66

67 68
        config.SetTRTDynamicShapeInfo(map_min_input_shape, map_max_input_shape,
                                      map_opt_input_shape);
69 70 71 72
        std::cout << "TensorRT dynamic shape enabled" << std::endl;
      }
    }

73 74
  } else if (this->device_ == "XPU") {
    config.EnableXpu(10 * 1024 * 1024);
Q
qingqing01 已提交
75 76
  } else {
    config.DisableGpu();
G
Guanghua Yu 已提交
77 78 79 80 81 82
    if (this->use_mkldnn_) {
      config.EnableMKLDNN();
      // cache 10 different shapes for mkldnn to avoid memory leak
      config.SetMkldnnCacheCapacity(10);
    }
    config.SetCpuMathLibraryNumThreads(this->cpu_math_library_num_threads_);
Q
qingqing01 已提交
83 84
  }
  config.SwitchUseFeedFetchOps(false);
G
Guanghua Yu 已提交
85
  config.SwitchIrOptim(true);
Q
qingqing01 已提交
86 87 88 89 90 91 92
  config.DisableGlogInfo();
  // Memory optimization
  config.EnableMemoryOptim();
  predictor_ = std::move(CreatePredictor(config));
}

// Visualiztion MaskDetector results
93 94 95 96 97
cv::Mat
VisualizeResult(const cv::Mat &img,
                const std::vector<PaddleDetection::ObjectResult> &results,
                const std::vector<std::string> &lables,
                const std::vector<int> &colormap, const bool is_rbox = false) {
Q
qingqing01 已提交
98 99 100 101 102
  cv::Mat vis_img = img.clone();
  for (int i = 0; i < results.size(); ++i) {
    // Configure color and text size
    std::ostringstream oss;
    oss << std::setiosflags(std::ios::fixed) << std::setprecision(4);
C
cnn 已提交
103
    oss << lables[results[i].class_id] << " ";
Q
qingqing01 已提交
104 105 106 107 108 109 110 111 112
    oss << results[i].confidence;
    std::string text = oss.str();
    int c1 = colormap[3 * results[i].class_id + 0];
    int c2 = colormap[3 * results[i].class_id + 1];
    int c3 = colormap[3 * results[i].class_id + 2];
    cv::Scalar roi_color = cv::Scalar(c1, c2, c3);
    int font_face = cv::FONT_HERSHEY_COMPLEX_SMALL;
    double font_scale = 0.5f;
    float thickness = 0.5;
113 114
    cv::Size text_size =
        cv::getTextSize(text, font_face, font_scale, thickness, nullptr);
Q
qingqing01 已提交
115
    cv::Point origin;
C
cnn 已提交
116

117 118 119 120 121 122 123 124 125 126 127 128 129 130 131
    if (is_rbox) {
      // Draw object, text, and background
      for (int k = 0; k < 4; k++) {
        cv::Point pt1 = cv::Point(results[i].rect[(k * 2) % 8],
                                  results[i].rect[(k * 2 + 1) % 8]);
        cv::Point pt2 = cv::Point(results[i].rect[(k * 2 + 2) % 8],
                                  results[i].rect[(k * 2 + 3) % 8]);
        cv::line(vis_img, pt1, pt2, roi_color, 2);
      }
    } else {
      int w = results[i].rect[2] - results[i].rect[0];
      int h = results[i].rect[3] - results[i].rect[1];
      cv::Rect roi = cv::Rect(results[i].rect[0], results[i].rect[1], w, h);
      // Draw roi object, text, and background
      cv::rectangle(vis_img, roi, roi_color, 2);
C
cnn 已提交
132 133 134 135
    }

    origin.x = results[i].rect[0];
    origin.y = results[i].rect[1];
Q
qingqing01 已提交
136 137

    // Configure text background
138 139 140
    cv::Rect text_back =
        cv::Rect(results[i].rect[0], results[i].rect[1] - text_size.height,
                 text_size.width, text_size.height);
C
cnn 已提交
141
    // Draw text, and background
Q
qingqing01 已提交
142
    cv::rectangle(vis_img, text_back, roi_color, -1);
143 144
    cv::putText(vis_img, text, origin, font_face, font_scale,
                cv::Scalar(255, 255, 255), thickness);
Q
qingqing01 已提交
145 146 147 148
  }
  return vis_img;
}

149
void ObjectDetector::Preprocess(const cv::Mat &ori_im) {
  // Convert BGR -> RGB into a separate Mat so the caller's image is never
  // modified, then hand the RGB copy to the preprocessing pipeline, which
  // fills inputs_.
  cv::Mat rgb_im;
  cv::cvtColor(ori_im, rgb_im, cv::COLOR_BGR2RGB);
  preprocessor_.Run(&rgb_im, &inputs_);
}

void ObjectDetector::Postprocess(
C
cnn 已提交
157
    const std::vector<cv::Mat> mats,
158 159
    std::vector<PaddleDetection::ObjectResult> *result,
    std::vector<int> bbox_num, std::vector<float> output_data_,
160
    bool is_rbox = false) {
Q
qingqing01 已提交
161
  result->clear();
C
cnn 已提交
162
  int start_idx = 0;
163
  for (int im_id = 0; im_id < mats.size(); im_id++) {
C
cnn 已提交
164
    cv::Mat raw_mat = mats[im_id];
165 166 167 168 169 170
    int rh = 1;
    int rw = 1;
    if (config_.arch_ == "Face") {
      rh = raw_mat.rows;
      rw = raw_mat.cols;
    }
171
    for (int j = start_idx; j < start_idx + bbox_num[im_id]; j++) {
C
cnn 已提交
172
      if (is_rbox) {
173 174 175 176 177 178 179 180 181 182 183 184
        // Class id
        int class_id = static_cast<int>(round(output_data_[0 + j * 10]));
        // Confidence score
        float score = output_data_[1 + j * 10];
        int x1 = (output_data_[2 + j * 10] * rw);
        int y1 = (output_data_[3 + j * 10] * rh);
        int x2 = (output_data_[4 + j * 10] * rw);
        int y2 = (output_data_[5 + j * 10] * rh);
        int x3 = (output_data_[6 + j * 10] * rw);
        int y3 = (output_data_[7 + j * 10] * rh);
        int x4 = (output_data_[8 + j * 10] * rw);
        int y4 = (output_data_[9 + j * 10] * rh);
185

186
        PaddleDetection::ObjectResult result_item;
187 188 189 190
        result_item.rect = {x1, y1, x2, y2, x3, y3, x4, y4};
        result_item.class_id = class_id;
        result_item.confidence = score;
        result->push_back(result_item);
191
      } else {
192 193 194 195 196 197 198 199 200 201
        // Class id
        int class_id = static_cast<int>(round(output_data_[0 + j * 6]));
        // Confidence score
        float score = output_data_[1 + j * 6];
        int xmin = (output_data_[2 + j * 6] * rw);
        int ymin = (output_data_[3 + j * 6] * rh);
        int xmax = (output_data_[4 + j * 6] * rw);
        int ymax = (output_data_[5 + j * 6] * rh);
        int wd = xmax - xmin;
        int hd = ymax - ymin;
202

203
        PaddleDetection::ObjectResult result_item;
204 205 206 207
        result_item.rect = {xmin, ymin, xmax, ymax};
        result_item.class_id = class_id;
        result_item.confidence = score;
        result->push_back(result_item);
C
cnn 已提交
208
      }
Q
qingqing01 已提交
209
    }
C
cnn 已提交
210
    start_idx += bbox_num[im_id];
Q
qingqing01 已提交
211 212 213
  }
}

C
cnn 已提交
214
void ObjectDetector::Predict(const std::vector<cv::Mat> imgs,
215
                             const double threshold, const int warmup,
216
                             const int repeats,
217 218 219
                             std::vector<PaddleDetection::ObjectResult> *result,
                             std::vector<int> *bbox_num,
                             std::vector<double> *times) {
G
Guanghua Yu 已提交
220
  auto preprocess_start = std::chrono::steady_clock::now();
C
cnn 已提交
221 222 223 224 225 226
  int batch_size = imgs.size();

  // in_data_batch
  std::vector<float> in_data_all;
  std::vector<float> im_shape_all(batch_size * 2);
  std::vector<float> scale_factor_all(batch_size * 2);
227
  std::vector<const float *> output_data_list_;
228
  std::vector<int> out_bbox_num_data_;
229 230 231 232

  // in_net img for each batch
  std::vector<cv::Mat> in_net_img_all(batch_size);

Q
qingqing01 已提交
233
  // Preprocess image
C
cnn 已提交
234 235 236 237 238 239 240 241 242
  for (int bs_idx = 0; bs_idx < batch_size; bs_idx++) {
    cv::Mat im = imgs.at(bs_idx);
    Preprocess(im);
    im_shape_all[bs_idx * 2] = inputs_.im_shape_[0];
    im_shape_all[bs_idx * 2 + 1] = inputs_.im_shape_[1];

    scale_factor_all[bs_idx * 2] = inputs_.scale_factor_[0];
    scale_factor_all[bs_idx * 2 + 1] = inputs_.scale_factor_[1];

243 244
    in_data_all.insert(in_data_all.end(), inputs_.im_data_.begin(),
                       inputs_.im_data_.end());
245 246 247

    // collect in_net img
    in_net_img_all[bs_idx] = inputs_.in_net_im_;
C
cnn 已提交
248
  }
249 250 251 252 253 254 255 256 257 258 259 260 261 262

  // Pad Batch if batch size > 1
  if (batch_size > 1 && CheckDynamicInput(in_net_img_all)) {
    in_data_all.clear();
    std::vector<cv::Mat> pad_img_all = PadBatch(in_net_img_all);
    int rh = pad_img_all[0].rows;
    int rw = pad_img_all[0].cols;
    int rc = pad_img_all[0].channels();

    for (int bs_idx = 0; bs_idx < batch_size; bs_idx++) {
      cv::Mat pad_img = pad_img_all[bs_idx];
      pad_img.convertTo(pad_img, CV_32FC3);
      std::vector<float> pad_data;
      pad_data.resize(rc * rh * rw);
263
      float *base = pad_data.data();
264
      for (int i = 0; i < rc; ++i) {
265 266
        cv::extractChannel(pad_img,
                           cv::Mat(rh, rw, CV_32FC1, base + i * rh * rw), i);
267 268 269 270 271 272 273
      }
      in_data_all.insert(in_data_all.end(), pad_data.begin(), pad_data.end());
    }
    // update in_net_shape
    inputs_.in_net_shape_ = {static_cast<float>(rh), static_cast<float>(rw)};
  }

274
  auto preprocess_end = std::chrono::steady_clock::now();
275
  // Prepare input tensor
Q
qingqing01 已提交
276
  auto input_names = predictor_->GetInputNames();
277
  for (const auto &tensor_name : input_names) {
Q
qingqing01 已提交
278 279
    auto in_tensor = predictor_->GetInputHandle(tensor_name);
    if (tensor_name == "image") {
280 281
      int rh = inputs_.in_net_shape_[0];
      int rw = inputs_.in_net_shape_[1];
C
cnn 已提交
282 283
      in_tensor->Reshape({batch_size, 3, rh, rw});
      in_tensor->CopyFromCpu(in_data_all.data());
Q
qingqing01 已提交
284
    } else if (tensor_name == "im_shape") {
C
cnn 已提交
285 286
      in_tensor->Reshape({batch_size, 2});
      in_tensor->CopyFromCpu(im_shape_all.data());
Q
qingqing01 已提交
287
    } else if (tensor_name == "scale_factor") {
C
cnn 已提交
288 289
      in_tensor->Reshape({batch_size, 2});
      in_tensor->CopyFromCpu(scale_factor_all.data());
Q
qingqing01 已提交
290 291
    }
  }
292

Q
qingqing01 已提交
293
  // Run predictor
294 295 296 297 298
  std::vector<std::vector<float>> out_tensor_list;
  std::vector<std::vector<int>> output_shape_list;
  bool is_rbox = false;
  int reg_max = 7;
  int num_class = 80;
299
  // warmup
300
  for (int i = 0; i < warmup; i++) {
Q
qingqing01 已提交
301 302 303
    predictor_->Run();
    // Get output tensor
    auto output_names = predictor_->GetOutputNames();
304 305 306
    for (int j = 0; j < output_names.size(); j++) {
      auto output_tensor = predictor_->GetOutputHandle(output_names[j]);
      std::vector<int> output_shape = output_tensor->shape();
307 308
      int out_num = std::accumulate(output_shape.begin(), output_shape.end(), 1,
                                    std::multiplies<int>());
309 310 311 312 313 314 315 316 317
      if (output_tensor->type() == paddle_infer::DataType::INT32) {
        out_bbox_num_data_.resize(out_num);
        output_tensor->CopyToCpu(out_bbox_num_data_.data());
      } else {
        std::vector<float> out_data;
        out_data.resize(out_num);
        output_tensor->CopyToCpu(out_data.data());
        out_tensor_list.push_back(out_data);
      }
318 319
    }
  }
320

G
Guanghua Yu 已提交
321
  auto inference_start = std::chrono::steady_clock::now();
322
  for (int i = 0; i < repeats; i++) {
Q
qingqing01 已提交
323 324
    predictor_->Run();
    // Get output tensor
325 326
    out_tensor_list.clear();
    output_shape_list.clear();
Q
qingqing01 已提交
327
    auto output_names = predictor_->GetOutputNames();
328 329 330
    for (int j = 0; j < output_names.size(); j++) {
      auto output_tensor = predictor_->GetOutputHandle(output_names[j]);
      std::vector<int> output_shape = output_tensor->shape();
331 332
      int out_num = std::accumulate(output_shape.begin(), output_shape.end(), 1,
                                    std::multiplies<int>());
333 334 335 336 337 338 339 340 341 342
      output_shape_list.push_back(output_shape);
      if (output_tensor->type() == paddle_infer::DataType::INT32) {
        out_bbox_num_data_.resize(out_num);
        output_tensor->CopyToCpu(out_bbox_num_data_.data());
      } else {
        std::vector<float> out_data;
        out_data.resize(out_num);
        output_tensor->CopyToCpu(out_data.data());
        out_tensor_list.push_back(out_data);
      }
C
cnn 已提交
343
    }
Q
qingqing01 已提交
344
  }
G
Guanghua Yu 已提交
345 346
  auto inference_end = std::chrono::steady_clock::now();
  auto postprocess_start = std::chrono::steady_clock::now();
Q
qingqing01 已提交
347
  // Postprocessing result
348
  result->clear();
C
cnn 已提交
349
  bbox_num->clear();
350 351 352 353 354 355 356 357
  if (config_.arch_ == "PicoDet") {
    for (int i = 0; i < out_tensor_list.size(); i++) {
      if (i == 0) {
        num_class = output_shape_list[i][2];
      }
      if (i == config_.fpn_stride_.size()) {
        reg_max = output_shape_list[i][2] / 4 - 1;
      }
358 359
      float *buffer = new float[out_tensor_list[i].size()];
      memcpy(buffer, &out_tensor_list[i][0],
360
             out_tensor_list[i].size() * sizeof(float));
361 362 363
      output_data_list_.push_back(buffer);
    }
    PaddleDetection::PicoDetPostProcess(
364 365 366
        result, output_data_list_, config_.fpn_stride_, inputs_.im_shape_,
        inputs_.scale_factor_, config_.nms_info_["score_threshold"].as<float>(),
        config_.nms_info_["nms_threshold"].as<float>(), num_class, reg_max);
367 368
    bbox_num->push_back(result->size());
  } else {
369
    is_rbox = output_shape_list[0][output_shape_list[0].size() - 1] % 10 == 0;
370
    Postprocess(imgs, result, out_bbox_num_data_, out_tensor_list[0], is_rbox);
371
    for (int k = 0; k < out_bbox_num_data_.size(); k++) {
372 373 374
      int tmp = out_bbox_num_data_[k];
      bbox_num->push_back(tmp);
    }
C
cnn 已提交
375
  }
376

G
Guanghua Yu 已提交
377 378
  auto postprocess_end = std::chrono::steady_clock::now();

379 380 381
  std::chrono::duration<float> preprocess_diff =
      preprocess_end - preprocess_start;
  times->push_back(static_cast<double>(preprocess_diff.count() * 1000));
G
Guanghua Yu 已提交
382
  std::chrono::duration<float> inference_diff = inference_end - inference_start;
383 384 385 386 387
  times->push_back(
      static_cast<double>(inference_diff.count() / repeats * 1000));
  std::chrono::duration<float> postprocess_diff =
      postprocess_end - postprocess_start;
  times->push_back(static_cast<double>(postprocess_diff.count() * 1000));
Q
qingqing01 已提交
388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406
}

// Build a palette with one color triple per class id.
//
// Returns a flat vector of 3 * num_class ints in [0, 255]; entry 3 * i + c is
// channel c of class i. The bits of the class id are consumed three at a
// time. Bit 0 of each group is placed into channel 0, bit 1 into channel 1,
// and bit 2 into channel 2. The first group lands on bit 7 (value 128) of
// each channel, and each later group lands one bit lower.
// num_class <= 0 yields an empty vector.
std::vector<int> GenerateColorMap(int num_class) {
  if (num_class <= 0) {
    return {};
  }
  std::vector<int> colormap(3 * static_cast<size_t>(num_class), 0);
  for (int i = 0; i < num_class; ++i) {
    const size_t base = 3 * static_cast<size_t>(i);
    int lab = i;
    // A channel has only 8 bits, so at most 8 groups fit. Stopping there
    // avoids a negative shift count (undefined behavior) for ids >= 2^24.
    for (int j = 0; lab != 0 && j < 8; ++j, lab >>= 3) {
      colormap[base] |= (((lab >> 0) & 1) << (7 - j));
      colormap[base + 1] |= (((lab >> 1) & 1) << (7 - j));
      colormap[base + 2] |= (((lab >> 2) & 1) << (7 - j));
    }
  }
  return colormap;
}

}  // namespace PaddleDetection