paddlex.cpp 31.0 KB
Newer Older
C
Channingss 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
//   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
S
syyxsxx 已提交
14 15

#include <math.h>
J
jack 已提交
16
#include <omp.h>
J
jack 已提交
17
#include <algorithm>
J
jack 已提交
18
#include <fstream>
J
jack 已提交
19
#include <cstring>
J
jack 已提交
20
#include "include/paddlex/paddlex.h"
S
syyxsxx 已提交
21 22 23 24 25

#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/imgproc/imgproc.hpp>

C
Channingss 已提交
26 27 28 29
namespace PaddleX {

void Model::create_predictor(const std::string& model_dir,
                             bool use_gpu,
C
Channingss 已提交
30
                             bool use_trt,
S
syyxsxx 已提交
31
                             bool use_mkl,
S
syyxsxx 已提交
32
                             int mkl_thread_num,
C
Channingss 已提交
33
                             int gpu_id,
J
jack 已提交
34 35
                             std::string key,
                             bool use_ir_optim) {
C
Channingss 已提交
36 37 38
  paddle::AnalysisConfig config;
  std::string model_file = model_dir + OS_PATH_SEP + "__model__";
  std::string params_file = model_dir + OS_PATH_SEP + "__params__";
J
jack 已提交
39
  std::string yaml_file = model_dir + OS_PATH_SEP + "model.yml";
J
jack 已提交
40
  std::string yaml_input = "";
C
Channingss 已提交
41
#ifdef WITH_ENCRYPTION
J
jack 已提交
42
  if (key != "") {
F
FlyingQianMM 已提交
43 44
    model_file = model_dir + OS_PATH_SEP + "__model__.encrypted";
    params_file = model_dir + OS_PATH_SEP + "__params__.encrypted";
J
jack 已提交
45
    yaml_file = model_dir + OS_PATH_SEP + "model.yml.encrypted";
J
jack 已提交
46 47
    paddle_security_load_model(
        &config, key.c_str(), model_file.c_str(), params_file.c_str());
J
jack 已提交
48
    yaml_input = decrypt_file(yaml_file.c_str(), key.c_str());
C
Channingss 已提交
49 50
  }
#endif
J
jack 已提交
51
  if (yaml_input == "") {
S
syyxsxx 已提交
52
    // read yaml file
J
jack 已提交
53 54 55 56 57 58 59
    std::ifstream yaml_fin(yaml_file);
    yaml_fin.seekg(0, std::ios::end);
    size_t yaml_file_size = yaml_fin.tellg();
    yaml_input.assign(yaml_file_size, ' ');
    yaml_fin.seekg(0);
    yaml_fin.read(&yaml_input[0], yaml_file_size);
  }
S
syyxsxx 已提交
60
  // load yaml file
J
jack 已提交
61
  if (!load_config(yaml_input)) {
J
jack 已提交
62 63 64 65
    std::cerr << "Parse file 'model.yml' failed!" << std::endl;
    exit(-1);
  }

J
jack 已提交
66
  if (key == "") {
C
Channingss 已提交
67 68
    config.SetModel(model_file, params_file);
  }
F
FlyingQianMM 已提交
69 70 71 72 73 74 75 76
  if (use_mkl) {
    if (name != "HRNet" && name != "DeepLabv3p" && name != "PPYOLO") {
        config.EnableMKLDNN();
        config.SetCpuMathLibraryNumThreads(mkl_thread_num);
    } else {
        std::cerr << "HRNet/DeepLabv3p/PPYOLO are not supported "
                  << "for the use of mkldnn" << std::endl;
    }
S
syyxsxx 已提交
77
  }
C
Channingss 已提交
78 79 80 81 82 83 84
  if (use_gpu) {
    config.EnableUseGpu(100, gpu_id);
  } else {
    config.DisableGpu();
  }
  config.SwitchUseFeedFetchOps(false);
  config.SwitchSpecifyInputNames(true);
S
syyxsxx 已提交
85
  // enable graph Optim
F
FlyingQianMM 已提交
86 87 88
#if defined(__arm__) || defined(__aarch64__)
  config.SwitchIrOptim(false);
#else
J
jack 已提交
89
  config.SwitchIrOptim(use_ir_optim);
F
FlyingQianMM 已提交
90
#endif
S
syyxsxx 已提交
91
  // enable Memory Optim
C
Channingss 已提交
92
  config.EnableMemoryOptim();
F
FlyingQianMM 已提交
93
  if (use_trt && use_gpu) {
C
Channingss 已提交
94 95 96 97 98 99 100
    config.EnableTensorRtEngine(
        1 << 20 /* workspace_size*/,
        32 /* max_batch_size*/,
        20 /* min_subgraph_size*/,
        paddle::AnalysisConfig::Precision::kFloat32 /* precision*/,
        true /* use_static*/,
        false /* use_calib_mode*/);
C
Channingss 已提交
101
  }
C
Channingss 已提交
102 103 104
  predictor_ = std::move(CreatePaddlePredictor(config));
}

J
jack 已提交
105 106
bool Model::load_config(const std::string& yaml_input) {
  YAML::Node config = YAML::Load(yaml_input);
C
Channingss 已提交
107 108
  type = config["_Attributes"]["model_type"].as<std::string>();
  name = config["Model"].as<std::string>();
F
FlyingQianMM 已提交
109 110
  std::string version = config["version"].as<std::string>();
  if (version[0] == '0') {
J
jack 已提交
111 112 113 114 115
    std::cerr << "[Init] Version of the loaded model is lower than 1.0.0, "
              << "deployment cannot be done, please refer to "
              << "https://github.com/PaddlePaddle/PaddleX/blob/develop/docs"
              << "/tutorials/deploy/upgrade_version.md "
              << "to transfer version." << std::endl;
F
FlyingQianMM 已提交
116 117
    return false;
  }
C
Channingss 已提交
118 119 120 121 122 123 124 125 126 127 128
  bool to_rgb = true;
  if (config["TransformsMode"].IsDefined()) {
    std::string mode = config["TransformsMode"].as<std::string>();
    if (mode == "BGR") {
      to_rgb = false;
    } else if (mode != "RGB") {
      std::cerr << "[Init] Only 'RGB' or 'BGR' is supported for TransformsMode"
                << std::endl;
      return false;
    }
  }
S
syyxsxx 已提交
129
  // build data preprocess stream
C
Channingss 已提交
130
  transforms_.Init(config["Transforms"], to_rgb);
S
syyxsxx 已提交
131
  // read label list
C
Channingss 已提交
132 133 134 135 136
  labels.clear();
  for (const auto& item : config["_Attributes"]["labels"]) {
    int index = labels.size();
    labels[index] = item.as<std::string>();
  }
F
FlyingQianMM 已提交
137
  if (config["_init_params"]["input_channel"].IsDefined()) {
138
    input_channel_ = config["_init_params"]["input_channel"].as<int>();
F
FlyingQianMM 已提交
139
  } else {
140
    input_channel_ = 3;
F
FlyingQianMM 已提交
141
  }
C
Channingss 已提交
142 143 144 145 146
  return true;
}

// Runs the preprocessing pipeline on a single image.
//
// The input is cloned first so `input_im` itself is never modified; the
// outcome of transforms_.Run() is written to `blob`.
// Returns true when the pipeline succeeds, false otherwise.
bool Model::preprocess(const cv::Mat& input_im, ImageBlob* blob) {
  cv::Mat working_copy = input_im.clone();
  const bool pipeline_ok = transforms_.Run(&working_copy, blob);
  return pipeline_ok;
}

J
jack 已提交
153
// use openmp
J
jack 已提交
154 155 156
// Runs the preprocessing pipeline on every image of a batch, in parallel
// through OpenMP.
//
//   input_im_batch  images to preprocess; each one is cloned before use
//   blob_batch      receives one ImageBlob per image (grown if too short)
//   thread_num      requested thread count; clamped to [1, batch size]
//
// Returns true only if every image was preprocessed successfully. An empty
// batch is trivially successful.
bool Model::preprocess(const std::vector<cv::Mat>& input_im_batch,
                       std::vector<ImageBlob>* blob_batch,
                       int thread_num) {
  const int batch_size = static_cast<int>(input_im_batch.size());
  if (batch_size == 0) {
    return true;
  }
  // Make sure (*blob_batch)[i] is valid for every image.
  if (blob_batch->size() < input_im_batch.size()) {
    blob_batch->resize(input_im_batch.size());
  }
  // num_threads() requires a positive value.
  thread_num = std::max(1, std::min(thread_num, batch_size));
  // Failures are counted through a reduction so that no two threads ever
  // write to the same variable.
  int failed = 0;
  #pragma omp parallel for num_threads(thread_num) reduction(+ : failed)
  for (int i = 0; i < batch_size; ++i) {
    cv::Mat im = input_im_batch[i].clone();
    if (!transforms_.Run(&im, &(*blob_batch)[i])) {
      ++failed;
    }
  }
  return failed == 0;
}

C
Channingss 已提交
170 171 172 173 174
// Classifies a single image.
//
// Preprocesses `im`, feeds it to the "image" input of the predictor, reads
// the first output tensor and stores the arg-max entry in `result`
// (category_id, score and the matching label).
//
// Returns false when the loaded model is a detector or segmenter, when
// preprocessing fails, or when the predictor produces an empty output.
bool Model::predict(const cv::Mat& im, ClsResult* result) {
  inputs_.clear();
  if (type == "detector") {
    std::cerr << "Loading model is a 'detector', DetResult should be passed to "
                 "function predict()!" << std::endl;
    return false;
  } else if (type == "segmenter") {
    std::cerr << "Loading model is a 'segmenter', SegResult should be passed "
                 "to function predict()!" << std::endl;
    return false;
  }
  // im preprocess
  if (!preprocess(im, &inputs_)) {
    std::cerr << "Preprocess failed!" << std::endl;
    return false;
  }
  // predict
  auto in_tensor = predictor_->GetInputTensor("image");
  int h = inputs_.new_im_size_[0];
  int w = inputs_.new_im_size_[1];
  in_tensor->Reshape({1, input_channel_, h, w});
  in_tensor->copy_from_cpu(inputs_.im_data_.data());
  predictor_->ZeroCopyRun();
  // get result
  auto output_names = predictor_->GetOutputNames();
  auto output_tensor = predictor_->GetOutputTensor(output_names[0]);
  std::vector<int> output_shape = output_tensor->shape();
  int size = 1;
  for (const auto& i : output_shape) {
    size *= i;
  }
  // An empty output has no maximum element to dereference.
  if (size <= 0) {
    std::cerr << "Predictor returned an empty output!" << std::endl;
    return false;
  }
  outputs_.resize(size);
  output_tensor->copy_to_cpu(outputs_.data());
  // postprocess: pick the highest score
  auto ptr = std::max_element(std::begin(outputs_), std::end(outputs_));
  result->category_id = std::distance(std::begin(outputs_), ptr);
  result->score = *ptr;
  result->category = labels[result->category_id];
  return true;
}

J
jack 已提交
208 209 210 211
// Classifies a batch of images in a single predictor run.
//
//   im_batch    images to classify
//   results     replaced by one ClsResult per image, in input order
//   thread_num  number of threads used for preprocessing
//
// The tensor height/width are taken from the first preprocessed image.
// Returns false when the loaded model is a detector or segmenter, when
// preprocessing fails, when an image does not fit the batch tensor, or when
// the predictor output is empty. An empty batch yields empty `results` and
// returns true.
bool Model::predict(const std::vector<cv::Mat>& im_batch,
                    std::vector<ClsResult>* results,
                    int thread_num) {
  for (auto& inputs : inputs_batch_) {
    inputs.clear();
  }
  if (type == "detector") {
    std::cerr << "Loading model is a 'detector', DetResult should be passed to "
                 "function predict()!" << std::endl;
    return false;
  } else if (type == "segmenter") {
    std::cerr << "Loading model is a 'segmenter', SegResult should be passed "
                 "to function predict()!" << std::endl;
    return false;
  }
  const int batch_size = static_cast<int>(im_batch.size());
  if (batch_size == 0) {
    // Nothing to run; avoids indexing inputs_batch_[0] and dividing by zero.
    results->clear();
    return true;
  }
  inputs_batch_.assign(im_batch.size(), ImageBlob());
  // preprocess
  if (!preprocess(im_batch, &inputs_batch_, thread_num)) {
    std::cerr << "Preprocess failed!" << std::endl;
    return false;
  }
  // predict
  auto in_tensor = predictor_->GetInputTensor("image");
  const int h = inputs_batch_[0].new_im_size_[0];
  const int w = inputs_batch_[0].new_im_size_[1];
  in_tensor->Reshape({batch_size, input_channel_, h, w});
  const size_t image_len = static_cast<size_t>(input_channel_) * h * w;
  std::vector<float> inputs_data(batch_size * image_len);
  for (int i = 0; i < batch_size; ++i) {
    const auto& im_data = inputs_batch_[i].im_data_;
    // Each image owns a fixed slot; a larger image would write past it.
    if (im_data.size() > image_len) {
      std::cerr << "Preprocessed images in a batch must share the same shape!"
                << std::endl;
      return false;
    }
    std::copy(im_data.begin(),
              im_data.end(),
              inputs_data.data() + i * image_len);
  }
  in_tensor->copy_from_cpu(inputs_data.data());
  predictor_->ZeroCopyRun();
  // get result
  auto output_names = predictor_->GetOutputNames();
  auto output_tensor = predictor_->GetOutputTensor(output_names[0]);
  std::vector<int> output_shape = output_tensor->shape();
  int size = 1;
  for (const auto& i : output_shape) {
    size *= i;
  }
  // Scores per image; must be positive for max_element to be meaningful.
  const int single_batch_size = size > 0 ? size / batch_size : 0;
  if (single_batch_size <= 0) {
    std::cerr << "Predictor returned an empty output!" << std::endl;
    return false;
  }
  outputs_.resize(size);
  output_tensor->copy_to_cpu(outputs_.data());
  // postprocess: arg-max over each image's slice of the output
  results->clear();
  results->resize(batch_size);
  for (int i = 0; i < batch_size; ++i) {
    auto start_ptr = std::begin(outputs_);
    auto end_ptr = std::begin(outputs_);
    std::advance(start_ptr, i * single_batch_size);
    std::advance(end_ptr, (i + 1) * single_batch_size);
    auto ptr = std::max_element(start_ptr, end_ptr);
    (*results)[i].category_id = std::distance(start_ptr, ptr);
    (*results)[i].score = *ptr;
    (*results)[i].category = labels[(*results)[i].category_id];
  }
  return true;
}

C
Channingss 已提交
271
// Runs object detection on a single image.
//
// `result` is cleared first. After preprocessing, the image is fed to the
// "image" input; YOLOv3/PPYOLO additionally receive "im_size", while
// FasterRCNN/MaskRCNN receive "im_info" and "im_shape". The first output is
// read as rows of six floats (category id, score, xmin, ymin, xmax, ymax)
// and each row becomes a Box whose coordinate is {xmin, ymin, w, h}. For
// MaskRCNN the second output is resized to each box and thresholded at 0.5.
//
// Returns false for a classifier/segmenter model or when preprocessing
// fails; returns true otherwise, including when nothing was detected.
bool Model::predict(const cv::Mat& im, DetResult* result) {
  inputs_.clear();
  result->clear();
  if (type == "classifier") {
    std::cerr << "Loading model is a 'classifier', ClsResult should be passed "
                 "to function predict()!" << std::endl;
    return false;
  }
  if (type == "segmenter") {
    std::cerr << "Loading model is a 'segmenter', SegResult should be passed "
                 "to function predict()!" << std::endl;
    return false;
  }

  // preprocess
  const bool preprocessed = preprocess(im, &inputs_);
  if (!preprocessed) {
    std::cerr << "Preprocess failed!" << std::endl;
    return false;
  }

  int h = inputs_.new_im_size_[0];
  int w = inputs_.new_im_size_[1];
  auto im_tensor = predictor_->GetInputTensor("image");
  im_tensor->Reshape({1, input_channel_, h, w});
  im_tensor->copy_from_cpu(inputs_.im_data_.data());

  // Model-specific auxiliary inputs.
  const bool is_yolo = (name == "YOLOv3" || name == "PPYOLO");
  const bool is_rcnn = !is_yolo && (name == "FasterRCNN" || name == "MaskRCNN");
  if (is_yolo) {
    auto im_size_tensor = predictor_->GetInputTensor("im_size");
    im_size_tensor->Reshape({1, 2});
    im_size_tensor->copy_from_cpu(inputs_.ori_im_size_.data());
  } else if (is_rcnn) {
    auto im_info_tensor = predictor_->GetInputTensor("im_info");
    auto im_shape_tensor = predictor_->GetInputTensor("im_shape");
    im_info_tensor->Reshape({1, 3});
    im_shape_tensor->Reshape({1, 3});
    // im_info: resized height/width plus scale; im_shape: original size.
    float im_info[] = {static_cast<float>(inputs_.new_im_size_[0]),
                       static_cast<float>(inputs_.new_im_size_[1]),
                       inputs_.scale};
    float im_shape[] = {static_cast<float>(inputs_.ori_im_size_[0]),
                        static_cast<float>(inputs_.ori_im_size_[1]),
                        1.0f};
    im_info_tensor->copy_from_cpu(im_info);
    im_shape_tensor->copy_from_cpu(im_shape);
  }
  // predict
  predictor_->ZeroCopyRun();

  auto output_names = predictor_->GetOutputNames();
  auto output_box_tensor = predictor_->GetOutputTensor(output_names[0]);
  std::vector<int> output_box_shape = output_box_tensor->shape();
  int size = 1;
  for (const auto& dim : output_box_shape) {
    size *= dim;
  }
  std::vector<float> output_box;
  output_box.resize(size);
  output_box_tensor->copy_to_cpu(output_box.data());
  if (size < 6) {
    std::cerr << "[WARNING] There's no object detected." << std::endl;
    return true;
  }
  const int num_boxes = size / 6;
  // box postprocess: six consecutive floats describe one detection
  for (int k = 0; k < num_boxes; ++k) {
    const float* det = &output_box[k * 6];
    Box box;
    box.category_id = static_cast<int>(round(det[0]));
    box.category = labels[box.category_id];
    box.score = det[1];
    const float xmin = det[2];
    const float ymin = det[3];
    const float box_w = det[4] - xmin + 1;
    const float box_h = det[5] - ymin + 1;
    box.coordinate = {xmin, ymin, box_w, box_h};
    result->boxes.push_back(std::move(box));
  }
  // mask postprocess
  if (name != "MaskRCNN") {
    return true;
  }
  auto output_mask_tensor = predictor_->GetOutputTensor(output_names[1]);
  std::vector<int> output_mask_shape = output_mask_tensor->shape();
  int masks_size = 1;
  for (const auto& dim : output_mask_shape) {
    masks_size *= dim;
  }
  const int mask_pixels = output_mask_shape[2] * output_mask_shape[3];
  const int classes = output_mask_shape[1];
  std::vector<float> output_mask;
  output_mask.resize(masks_size);
  output_mask_tensor->copy_to_cpu(output_mask.data());
  result->mask_resolution = output_mask_shape[2];
  int box_index = 0;
  for (Box& box : result->boxes) {
    box.mask.shape = {static_cast<int>(box.coordinate[2]),
                      static_cast<int>(box.coordinate[3])};
    // Select the mask plane of this box's predicted category.
    float* plane =
        output_mask.data() + (box_index * classes + box.category_id) * mask_pixels;
    cv::Mat bin_mask(result->mask_resolution,
                     result->mask_resolution,
                     CV_32FC1,
                     plane);
    const cv::Size box_size(box.mask.shape[0], box.mask.shape[1]);
    cv::resize(bin_mask, bin_mask, box_size);
    cv::threshold(bin_mask, bin_mask, 0.5, 1, cv::THRESH_BINARY);
    float* first = reinterpret_cast<float*>(bin_mask.data);
    float* last = first + box.mask.shape[0] * box.mask.shape[1];
    box.mask.data.assign(first, last);
    ++box_index;
  }
  return true;
}

J
jack 已提交
384
bool Model::predict(const std::vector<cv::Mat>& im_batch,
385
                    std::vector<DetResult>* results,
J
jack 已提交
386 387
                    int thread_num) {
  for (auto& inputs : inputs_batch_) {
J
jack 已提交
388 389
    inputs.clear();
  }
J
jack 已提交
390 391
  if (type == "classifier") {
    std::cerr << "Loading model is a 'classifier', ClsResult should be passed "
J
jack 已提交
392
                 "to function predict()!" << std::endl;
J
jack 已提交
393 394 395
    return false;
  } else if (type == "segmenter") {
    std::cerr << "Loading model is a 'segmenter', SegResult should be passed "
J
jack 已提交
396
                 "to function predict()!" << std::endl;
J
jack 已提交
397 398 399
    return false;
  }

J
jack 已提交
400
  inputs_batch_.assign(im_batch.size(), ImageBlob());
J
jack 已提交
401
  int batch_size = im_batch.size();
S
syyxsxx 已提交
402
  // preprocess
J
jack 已提交
403
  if (!preprocess(im_batch, &inputs_batch_, thread_num)) {
J
jack 已提交
404 405 406
    std::cerr << "Preprocess failed!" << std::endl;
    return false;
  }
S
syyxsxx 已提交
407
  // RCNN model padding
J
jack 已提交
408 409 410 411
  if (batch_size > 1) {
    if (name == "FasterRCNN" || name == "MaskRCNN") {
      int max_h = -1;
      int max_w = -1;
J
jack 已提交
412
      for (int i = 0; i < batch_size; ++i) {
J
jack 已提交
413 414
        max_h = std::max(max_h, inputs_batch_[i].new_im_size_[0]);
        max_w = std::max(max_w, inputs_batch_[i].new_im_size_[1]);
J
jack 已提交
415 416
        // std::cout << "(" << inputs_batch_[i].new_im_size_[0]
        //          << ", " << inputs_batch_[i].new_im_size_[1]
J
jack 已提交
417
        //          <<  ")" << std::endl;
J
jack 已提交
418
      }
J
jack 已提交
419 420
      thread_num = std::min(thread_num, batch_size);
      #pragma omp parallel for num_threads(thread_num)
J
jack 已提交
421
      for (int i = 0; i < batch_size; ++i) {
J
jack 已提交
422 423 424
        int h = inputs_batch_[i].new_im_size_[0];
        int w = inputs_batch_[i].new_im_size_[1];
        int c = im_batch[i].channels();
J
jack 已提交
425
        if (max_h != h || max_w != w) {
J
jack 已提交
426
          std::vector<float> temp_buffer(c * max_h * max_w);
J
jack 已提交
427 428 429
          float* temp_ptr = temp_buffer.data();
          float* ptr = inputs_batch_[i].im_data_.data();
          for (int cur_channel = c - 1; cur_channel >= 0; --cur_channel) {
J
jack 已提交
430 431
            int ori_pos = cur_channel * h * w + (h - 1) * w;
            int des_pos = cur_channel * max_h * max_w + (h - 1) * max_w;
J
jack 已提交
432 433 434
            int last_pos = cur_channel * h * w;
            for (; ori_pos >= last_pos; ori_pos -= w, des_pos -= max_w) {
              memcpy(temp_ptr + des_pos, ptr + ori_pos, w * sizeof(float));
J
jack 已提交
435 436 437 438
            }
          }
          inputs_batch_[i].im_data_.swap(temp_buffer);
          inputs_batch_[i].new_im_size_[0] = max_h;
J
jack 已提交
439
          inputs_batch_[i].new_im_size_[1] = max_w;
J
jack 已提交
440 441 442 443
        }
      }
    }
  }
J
jack 已提交
444 445 446
  int h = inputs_batch_[0].new_im_size_[0];
  int w = inputs_batch_[0].new_im_size_[1];
  auto im_tensor = predictor_->GetInputTensor("image");
447 448
  im_tensor->Reshape({batch_size, input_channel_, h, w});
  std::vector<float> inputs_data(batch_size * input_channel_ * h * w);
J
jack 已提交
449 450 451
  for (int i = 0; i < batch_size; ++i) {
    std::copy(inputs_batch_[i].im_data_.begin(),
              inputs_batch_[i].im_data_.end(),
452
              inputs_data.begin() + i * input_channel_ * h * w);
J
jack 已提交
453 454
  }
  im_tensor->copy_from_cpu(inputs_data.data());
F
FlyingQianMM 已提交
455
  if (name == "YOLOv3" || name == "PPYOLO") {
J
jack 已提交
456 457
    auto im_size_tensor = predictor_->GetInputTensor("im_size");
    im_size_tensor->Reshape({batch_size, 2});
J
jack 已提交
458 459 460 461 462
    std::vector<int> inputs_data_size(batch_size * 2);
    for (int i = 0; i < batch_size; ++i) {
      std::copy(inputs_batch_[i].ori_im_size_.begin(),
                inputs_batch_[i].ori_im_size_.end(),
                inputs_data_size.begin() + 2 * i);
J
jack 已提交
463 464 465 466 467 468 469
    }
    im_size_tensor->copy_from_cpu(inputs_data_size.data());
  } else if (name == "FasterRCNN" || name == "MaskRCNN") {
    auto im_info_tensor = predictor_->GetInputTensor("im_info");
    auto im_shape_tensor = predictor_->GetInputTensor("im_shape");
    im_info_tensor->Reshape({batch_size, 3});
    im_shape_tensor->Reshape({batch_size, 3});
J
jack 已提交
470

J
jack 已提交
471 472
    std::vector<float> im_info(3 * batch_size);
    std::vector<float> im_shape(3 * batch_size);
J
jack 已提交
473
    for (int i = 0; i < batch_size; ++i) {
J
jack 已提交
474 475 476 477 478 479 480 481 482 483 484 485 486 487
      float ori_h = static_cast<float>(inputs_batch_[i].ori_im_size_[0]);
      float ori_w = static_cast<float>(inputs_batch_[i].ori_im_size_[1]);
      float new_h = static_cast<float>(inputs_batch_[i].new_im_size_[0]);
      float new_w = static_cast<float>(inputs_batch_[i].new_im_size_[1]);
      im_info[i * 3] = new_h;
      im_info[i * 3 + 1] = new_w;
      im_info[i * 3 + 2] = inputs_batch_[i].scale;
      im_shape[i * 3] = ori_h;
      im_shape[i * 3 + 1] = ori_w;
      im_shape[i * 3 + 2] = 1.0;
    }
    im_info_tensor->copy_from_cpu(im_info.data());
    im_shape_tensor->copy_from_cpu(im_shape.data());
  }
S
syyxsxx 已提交
488
  // predict
J
jack 已提交
489 490
  predictor_->ZeroCopyRun();

S
syyxsxx 已提交
491
  // get all box
J
jack 已提交
492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507
  std::vector<float> output_box;
  auto output_names = predictor_->GetOutputNames();
  auto output_box_tensor = predictor_->GetOutputTensor(output_names[0]);
  std::vector<int> output_box_shape = output_box_tensor->shape();
  int size = 1;
  for (const auto& i : output_box_shape) {
    size *= i;
  }
  output_box.resize(size);
  output_box_tensor->copy_to_cpu(output_box.data());
  if (size < 6) {
    std::cerr << "[WARNING] There's no object detected." << std::endl;
    return true;
  }
  auto lod_vector = output_box_tensor->lod();
  int num_boxes = size / 6;
S
syyxsxx 已提交
508
  // box postprocess
509 510
  (*results).clear();
  (*results).resize(batch_size);
J
jack 已提交
511
  for (int i = 0; i < lod_vector[0].size() - 1; ++i) {
J
jack 已提交
512
    for (int j = lod_vector[0][i]; j < lod_vector[0][i + 1]; ++j) {
J
jack 已提交
513
      Box box;
J
jack 已提交
514
      box.category_id = static_cast<int>(round(output_box[j * 6]));
J
jack 已提交
515 516 517 518 519 520 521 522 523
      box.category = labels[box.category_id];
      box.score = output_box[j * 6 + 1];
      float xmin = output_box[j * 6 + 2];
      float ymin = output_box[j * 6 + 3];
      float xmax = output_box[j * 6 + 4];
      float ymax = output_box[j * 6 + 5];
      float w = xmax - xmin + 1;
      float h = ymax - ymin + 1;
      box.coordinate = {xmin, ymin, w, h};
524
      (*results)[i].boxes.push_back(std::move(box));
J
jack 已提交
525 526 527
    }
  }

S
syyxsxx 已提交
528
  // mask postprocess
J
jack 已提交
529 530 531 532 533 534 535 536 537 538 539 540 541
  if (name == "MaskRCNN") {
    std::vector<float> output_mask;
    auto output_mask_tensor = predictor_->GetOutputTensor(output_names[1]);
    std::vector<int> output_mask_shape = output_mask_tensor->shape();
    int masks_size = 1;
    for (const auto& i : output_mask_shape) {
      masks_size *= i;
    }
    int mask_pixels = output_mask_shape[2] * output_mask_shape[3];
    int classes = output_mask_shape[1];
    output_mask.resize(masks_size);
    output_mask_tensor->copy_to_cpu(output_mask.data());
    int mask_idx = 0;
J
jack 已提交
542
    for (int i = 0; i < lod_vector[0].size() - 1; ++i) {
543 544
      (*results)[i].mask_resolution = output_mask_shape[2];
      for (int j = 0; j < (*results)[i].boxes.size(); ++j) {
S
fix  
syyxsxx 已提交
545
        Box* box = &(*results)[i].boxes[i];
J
jack 已提交
546
        int category_id = box->category_id;
S
syyxsxx 已提交
547 548
        box->mask.shape = {static_cast<int>(box->coordinate[2]),
                          static_cast<int>(box->coordinate[3])};
S
syyxsxx 已提交
549
        auto begin_mask =
S
fix  
syyxsxx 已提交
550
          output_mask.data() + (i * classes + box->category_id) * mask_pixels;
S
fix  
syyxsxx 已提交
551 552
        cv::Mat bin_mask(output_mask_shape[2],
                      output_mask_shape[2],
S
syyxsxx 已提交
553 554 555 556 557 558
                      CV_32FC1,
                      begin_mask);
        cv::resize(bin_mask,
                bin_mask,
                cv::Size(box->mask.shape[0], box->mask.shape[1]));
        cv::threshold(bin_mask, bin_mask, 0.5, 1, cv::THRESH_BINARY);
S
syyxsxx 已提交
559
        auto mask_int_begin = reinterpret_cast<float*>(bin_mask.data);
S
syyxsxx 已提交
560 561 562
        auto mask_int_end =
          mask_int_begin + box->mask.shape[0] * box->mask.shape[1];
        box->mask.data.assign(mask_int_begin, mask_int_end);
J
jack 已提交
563 564 565 566
        mask_idx++;
      }
    }
  }
J
jack 已提交
567
  return true;
J
jack 已提交
568 569
}

C
Channingss 已提交
570 571 572 573 574
// Runs semantic segmentation on a single image.
//
// `result` is cleared first. After preprocessing, the image is fed to the
// "image" input. The first output is stored as the label map and the second
// as the score map. Both maps are then taken back through the recorded
// "padding"/"resize" steps in reverse order, using the sizes saved in
// inputs_.im_size_before_resize_ (which is consumed in the process), and
// written back to `result` with shape {rows, cols}.
//
// Returns false for a classifier/detector model or when preprocessing
// fails; returns true otherwise.
bool Model::predict(const cv::Mat& im, SegResult* result) {
  result->clear();
  inputs_.clear();
  if (type == "classifier") {
    std::cerr << "Loading model is a 'classifier', ClsResult should be passed "
                 "to function predict()!" << std::endl;
    return false;
  }
  if (type == "detector") {
    std::cerr << "Loading model is a 'detector', DetResult should be passed to "
                 "function predict()!" << std::endl;
    return false;
  }

  // preprocess
  const bool preprocessed = preprocess(im, &inputs_);
  if (!preprocessed) {
    std::cerr << "Preprocess failed!" << std::endl;
    return false;
  }

  int h = inputs_.new_im_size_[0];
  int w = inputs_.new_im_size_[1];
  auto im_tensor = predictor_->GetInputTensor("image");
  im_tensor->Reshape({1, input_channel_, h, w});
  im_tensor->copy_from_cpu(inputs_.im_data_.data());

  // predict
  predictor_->ZeroCopyRun();

  auto output_names = predictor_->GetOutputNames();

  // get labelmap: record the tensor shape and copy its contents
  auto output_label_tensor = predictor_->GetOutputTensor(output_names[0]);
  std::vector<int> output_label_shape = output_label_tensor->shape();
  int label_count = 1;
  for (const auto& dim : output_label_shape) {
    label_count *= dim;
    result->label_map.shape.push_back(dim);
  }
  result->label_map.data.resize(label_count);
  output_label_tensor->copy_to_cpu(result->label_map.data.data());

  // get scoremap: same procedure for the second output
  auto output_score_tensor = predictor_->GetOutputTensor(output_names[1]);
  std::vector<int> output_score_shape = output_score_tensor->shape();
  int score_count = 1;
  for (const auto& dim : output_score_shape) {
    score_count *= dim;
    result->score_map.shape.push_back(dim);
  }
  result->score_map.data.resize(score_count);
  output_score_tensor->copy_to_cpu(result->score_map.data.data());

  // get origin image result
  // The labels are narrowed to bytes so they can back a CV_8UC1 matrix.
  std::vector<uint8_t> label_bytes(result->label_map.data.begin(),
                                   result->label_map.data.end());
  cv::Mat mask_label(result->label_map.shape[1],
                     result->label_map.shape[2],
                     CV_8UC1,
                     label_bytes.data());
  cv::Mat mask_score(result->score_map.shape[2],
                     result->score_map.shape[3],
                     CV_32FC1,
                     result->score_map.data.data());

  // Undo the recorded reshape steps, last one first. `step` advances on
  // every entry, while a saved size is popped only for padding/resize.
  const int recorded_sizes = inputs_.im_size_before_resize_.size();
  int step = 1;
  for (auto op = inputs_.reshape_order_.rbegin();
       op != inputs_.reshape_order_.rend();
       ++op, ++step) {
    const bool is_padding = (*op == "padding");
    const bool is_resize = !is_padding && (*op == "resize");
    if (!is_padding && !is_resize) {
      continue;
    }
    // Copy the saved size before it is removed from the list.
    const auto prior_size =
        inputs_.im_size_before_resize_[recorded_sizes - step];
    inputs_.im_size_before_resize_.pop_back();
    const auto prior_rows = prior_size[0];
    const auto prior_cols = prior_size[1];
    if (is_padding) {
      // Crop the top-left region that existed before padding.
      mask_label = mask_label(cv::Rect(0, 0, prior_cols, prior_rows));
      mask_score = mask_score(cv::Rect(0, 0, prior_cols, prior_rows));
    } else {
      // Labels use nearest-neighbour, scores use linear interpolation.
      cv::resize(mask_label,
                 mask_label,
                 cv::Size(prior_cols, prior_rows),
                 0,
                 0,
                 cv::INTER_NEAREST);
      cv::resize(mask_score,
                 mask_score,
                 cv::Size(prior_cols, prior_rows),
                 0,
                 0,
                 cv::INTER_LINEAR);
    }
  }
  result->label_map.data.assign(mask_label.begin<uint8_t>(),
                                mask_label.end<uint8_t>());
  result->label_map.shape = {mask_label.rows, mask_label.cols};
  result->score_map.data.assign(mask_score.begin<float>(),
                                mask_score.end<float>());
  result->score_map.shape = {mask_score.rows, mask_score.cols};
  return true;
}

J
jack 已提交
677
bool Model::predict(const std::vector<cv::Mat>& im_batch,
678
                    std::vector<SegResult>* results,
J
jack 已提交
679 680
                    int thread_num) {
  for (auto& inputs : inputs_batch_) {
J
jack 已提交
681 682 683 684
    inputs.clear();
  }
  if (type == "classifier") {
    std::cerr << "Loading model is a 'classifier', ClsResult should be passed "
J
jack 已提交
685
                 "to function predict()!" << std::endl;
J
jack 已提交
686 687 688
    return false;
  } else if (type == "detector") {
    std::cerr << "Loading model is a 'detector', DetResult should be passed to "
J
jack 已提交
689
                 "function predict()!" << std::endl;
J
jack 已提交
690 691 692
    return false;
  }

S
syyxsxx 已提交
693
  // preprocess
J
jack 已提交
694
  inputs_batch_.assign(im_batch.size(), ImageBlob());
J
jack 已提交
695
  if (!preprocess(im_batch, &inputs_batch_, thread_num)) {
J
jack 已提交
696 697 698 699 700
    std::cerr << "Preprocess failed!" << std::endl;
    return false;
  }

  int batch_size = im_batch.size();
701 702
  (*results).clear();
  (*results).resize(batch_size);
J
jack 已提交
703 704 705
  int h = inputs_batch_[0].new_im_size_[0];
  int w = inputs_batch_[0].new_im_size_[1];
  auto im_tensor = predictor_->GetInputTensor("image");
706 707
  im_tensor->Reshape({batch_size, input_channel_, h, w});
  std::vector<float> inputs_data(batch_size * input_channel_ * h * w);
J
jack 已提交
708 709 710
  for (int i = 0; i < batch_size; ++i) {
    std::copy(inputs_batch_[i].im_data_.begin(),
              inputs_batch_[i].im_data_.end(),
711
              inputs_data.begin() + i * input_channel_ * h * w);
J
jack 已提交
712 713
  }
  im_tensor->copy_from_cpu(inputs_data.data());
J
jack 已提交
714
  // im_tensor->copy_from_cpu(inputs_.im_data_.data());
J
jack 已提交
715

S
syyxsxx 已提交
716
  // predict
J
jack 已提交
717 718
  predictor_->ZeroCopyRun();

S
syyxsxx 已提交
719
  // get labelmap
J
jack 已提交
720 721 722 723 724 725 726 727 728 729 730 731 732
  auto output_names = predictor_->GetOutputNames();
  auto output_label_tensor = predictor_->GetOutputTensor(output_names[0]);
  std::vector<int> output_label_shape = output_label_tensor->shape();
  int size = 1;
  for (const auto& i : output_label_shape) {
    size *= i;
  }

  std::vector<int64_t> output_labels(size, 0);
  output_label_tensor->copy_to_cpu(output_labels.data());
  auto output_labels_iter = output_labels.begin();

  int single_batch_size = size / batch_size;
J
jack 已提交
733
  for (int i = 0; i < batch_size; ++i) {
734 735
    (*results)[i].label_map.data.resize(single_batch_size);
    (*results)[i].label_map.shape.push_back(1);
J
jack 已提交
736
    for (int j = 1; j < output_label_shape.size(); ++j) {
737
      (*results)[i].label_map.shape.push_back(output_label_shape[j]);
J
jack 已提交
738
    }
J
jack 已提交
739 740
    std::copy(output_labels_iter + i * single_batch_size,
              output_labels_iter + (i + 1) * single_batch_size,
741
              (*results)[i].label_map.data.data());
J
jack 已提交
742 743
  }

S
syyxsxx 已提交
744
  // get scoremap
J
jack 已提交
745 746 747 748 749 750 751 752 753 754 755 756
  auto output_score_tensor = predictor_->GetOutputTensor(output_names[1]);
  std::vector<int> output_score_shape = output_score_tensor->shape();
  size = 1;
  for (const auto& i : output_score_shape) {
    size *= i;
  }

  std::vector<float> output_scores(size, 0);
  output_score_tensor->copy_to_cpu(output_scores.data());
  auto output_scores_iter = output_scores.begin();

  int single_batch_score_size = size / batch_size;
J
jack 已提交
757
  for (int i = 0; i < batch_size; ++i) {
758 759
    (*results)[i].score_map.data.resize(single_batch_score_size);
    (*results)[i].score_map.shape.push_back(1);
J
jack 已提交
760
    for (int j = 1; j < output_score_shape.size(); ++j) {
761
      (*results)[i].score_map.shape.push_back(output_score_shape[j]);
J
jack 已提交
762
    }
J
jack 已提交
763 764
    std::copy(output_scores_iter + i * single_batch_score_size,
              output_scores_iter + (i + 1) * single_batch_score_size,
765
              (*results)[i].score_map.data.data());
J
jack 已提交
766 767
  }

S
syyxsxx 已提交
768
  // get origin image result
J
jack 已提交
769
  for (int i = 0; i < batch_size; ++i) {
770 771 772 773
    std::vector<uint8_t> label_map((*results)[i].label_map.data.begin(),
                                   (*results)[i].label_map.data.end());
    cv::Mat mask_label((*results)[i].label_map.shape[1],
                       (*results)[i].label_map.shape[2],
J
jack 已提交
774 775
                       CV_8UC1,
                       label_map.data());
J
jack 已提交
776

777 778
    cv::Mat mask_score((*results)[i].score_map.shape[2],
                       (*results)[i].score_map.shape[3],
J
jack 已提交
779
                       CV_32FC1,
780
                       (*results)[i].score_map.data.data());
J
jack 已提交
781 782 783 784 785 786 787
    int idx = 1;
    int len_postprocess = inputs_batch_[i].im_size_before_resize_.size();
    for (std::vector<std::string>::reverse_iterator iter =
             inputs_batch_[i].reshape_order_.rbegin();
         iter != inputs_batch_[i].reshape_order_.rend();
         ++iter) {
      if (*iter == "padding") {
J
jack 已提交
788 789
        auto before_shape =
            inputs_batch_[i].im_size_before_resize_[len_postprocess - idx];
J
jack 已提交
790 791 792 793 794 795
        inputs_batch_[i].im_size_before_resize_.pop_back();
        auto padding_w = before_shape[0];
        auto padding_h = before_shape[1];
        mask_label = mask_label(cv::Rect(0, 0, padding_h, padding_w));
        mask_score = mask_score(cv::Rect(0, 0, padding_h, padding_w));
      } else if (*iter == "resize") {
J
jack 已提交
796 797
        auto before_shape =
            inputs_batch_[i].im_size_before_resize_[len_postprocess - idx];
J
jack 已提交
798 799 800 801 802 803 804 805 806 807 808 809 810 811
        inputs_batch_[i].im_size_before_resize_.pop_back();
        auto resize_w = before_shape[0];
        auto resize_h = before_shape[1];
        cv::resize(mask_label,
                   mask_label,
                   cv::Size(resize_h, resize_w),
                   0,
                   0,
                   cv::INTER_NEAREST);
        cv::resize(mask_score,
                   mask_score,
                   cv::Size(resize_h, resize_w),
                   0,
                   0,
J
jack 已提交
812
                   cv::INTER_LINEAR);
J
jack 已提交
813 814 815
      }
      ++idx;
    }
816
    (*results)[i].label_map.data.assign(mask_label.begin<uint8_t>(),
J
jack 已提交
817
                                       mask_label.end<uint8_t>());
818 819
    (*results)[i].label_map.shape = {mask_label.rows, mask_label.cols};
    (*results)[i].score_map.data.assign(mask_score.begin<float>(),
J
jack 已提交
820
                                       mask_score.end<float>());
821
    (*results)[i].score_map.shape = {mask_score.rows, mask_score.cols};
J
jack 已提交
822 823
  }
  return true;
C
Channingss 已提交
824 825
}

J
jack 已提交
826
}  // namespace PaddleX