// cls.cpp
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <include/cls.h>

#include <algorithm>
#include <functional>
#include <iomanip>
#include <iostream>
#include <numeric>
#include <string>
#include <vector>

namespace PaddleClas {

// Classifier method definitions.
void Classifier::LoadModel(const std::string &model_path,
                           const std::string &params_path) {
L
littletomatodonkey 已提交
21
  paddle_infer::Config config;
22
  config.SetModel(model_path, params_path);
23 24 25

  if (this->use_gpu_) {
    config.EnableUseGpu(this->gpu_mem_, this->gpu_id_);
L
littletomatodonkey 已提交
26 27 28 29 30 31 32
    if (this->use_tensorrt_) {
      config.EnableTensorRtEngine(
          1 << 20, 1, 3,
          this->use_fp16_ ? paddle_infer::Config::Precision::kHalf
                          : paddle_infer::Config::Precision::kFloat32,
          false, false);
    }
33 34 35 36 37 38 39 40 41 42
  } else {
    config.DisableGpu();
    if (this->use_mkldnn_) {
      config.EnableMKLDNN();
      // cache 10 different shapes for mkldnn to avoid memory leak
      config.SetMkldnnCacheCapacity(10);
    }
    config.SetCpuMathLibraryNumThreads(this->cpu_math_library_num_threads_);
  }

L
littletomatodonkey 已提交
43
  config.SwitchUseFeedFetchOps(false);
44 45 46 47 48 49 50 51
  // true for multiple input
  config.SwitchSpecifyInputNames(true);

  config.SwitchIrOptim(true);

  config.EnableMemoryOptim();
  config.DisableGlogInfo();

L
littletomatodonkey 已提交
52
  this->predictor_ = CreatePredictor(config);
53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68
}

// Classifies a single image and prints the best-scoring class to stdout.
//
// img: input image; it is handed to resize_op_ and the result is then
//      cropped, normalized and permuted into a flat float buffer.
//
// The buffer is fed to the predictor's first input as a
// {1, 3, rows, cols} tensor, the first output is copied back, and the index
// and value of its largest element are printed as "class id" and "score".
//
// On a missing predictor, missing input/output names, or an empty output the
// function reports the problem on stderr and returns without printing a
// result.
//
// Note: std::fixed and std::setprecision(10) stay set on std::cout after this
// call, exactly as before.
void Classifier::Run(cv::Mat &img) {
  if (!this->predictor_) {
    std::cerr << "Classifier::Run: predictor is not initialized, "
              << "call LoadModel first" << std::endl;
    return;
  }

  // Preprocessing: resize -> crop -> normalize, all ending up in resize_img.
  cv::Mat resize_img;
  this->resize_op_.Run(img, resize_img, this->resize_short_size_);

  this->crop_op_.Run(resize_img, this->crop_size_);

  this->normalize_op_.Run(&resize_img, this->mean_, this->scale_,
                          this->is_scale_);
  std::vector<float> input(1 * 3 * resize_img.rows * resize_img.cols, 0.0f);
  this->permute_op_.Run(&resize_img, input.data());

  // Feed the first (and only expected) input tensor.
  auto input_names = this->predictor_->GetInputNames();
  if (input_names.empty()) {
    std::cerr << "Classifier::Run: model exposes no input tensor" << std::endl;
    return;
  }
  auto input_t = this->predictor_->GetInputHandle(input_names[0]);
  input_t->Reshape({1, 3, resize_img.rows, resize_img.cols});
  input_t->CopyFromCpu(input.data());
  this->predictor_->Run();

  // Fetch the first output tensor.
  auto output_names = this->predictor_->GetOutputNames();
  if (output_names.empty()) {
    std::cerr << "Classifier::Run: model exposes no output tensor"
              << std::endl;
    return;
  }
  auto output_t = this->predictor_->GetOutputHandle(output_names[0]);
  std::vector<int> output_shape = output_t->shape();
  const int out_num = std::accumulate(output_shape.begin(), output_shape.end(),
                                      1, std::multiplies<int>());
  // A zero or negative element count would make the resize/index below
  // invalid, so bail out instead of reading past the buffer.
  if (out_num <= 0) {
    std::cerr << "Classifier::Run: output tensor is empty" << std::endl;
    return;
  }

  std::vector<float> out_data(out_num);
  output_t->CopyToCpu(out_data.data());

  // Top-1: position of the largest score.
  const int max_position = static_cast<int>(
      std::max_element(out_data.begin(), out_data.end()) - out_data.begin());
  std::cout << "result: " << std::endl;
  std::cout << "\tclass id: " << max_position << std::endl;
  std::cout << std::fixed << std::setprecision(10)
            << "\tscore: " << static_cast<double>(out_data[max_position])
            << std::endl;
}

} // namespace PaddleClas