// ppredictor.cpp
#include "ppredictor.h"
#include "common.h"

namespace ppredictor {
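
// PPredictor wraps a single paddle::lite_api predictor. A minimal usage
// sketch (the model path and net_flag value here are hypothetical; the
// PowerMode enum value comes from the Paddle Lite API):
//
//   ppredictor::PPredictor predictor(/*use_opencl=*/0, /*thread_num=*/4,
//                                    /*net_flag=*/0,
//                                    paddle::lite_api::LITE_POWER_HIGH);
//   predictor.init_from_file("/data/local/tmp/model.nb");
//   ppredictor::PredictorInput input = predictor.get_first_input();
//   // ... fill the input tensor ...
//   std::vector<ppredictor::PredictorOutput> outputs = predictor.infer();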
PPredictor::PPredictor(int use_opencl, int thread_num, int net_flag,
                       paddle::lite_api::PowerMode mode)
    : _use_opencl(use_opencl), _thread_num(thread_num), _net_flag(net_flag),
      _mode(mode) {}

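// Load an optimized Paddle Lite model (.nb) from an in-memory buffer, e.g.
// bytes read out of the app's assets, then run the shared setup in _init().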
int PPredictor::init_nb(const std::string &model_content) {
  paddle::lite_api::MobileConfig config;
  config.set_model_from_buffer(model_content);
  return _init(config);
}

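// Load an optimized .nb model from a path on disk.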
int PPredictor::init_from_file(const std::string &model_content) {
  paddle::lite_api::MobileConfig config;
  config.set_model_from_file(model_content);
  return _init(config);
}

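// Shared setup: run on the OpenCL (GPU) backend when the device supports it
// and the caller requested it via use_opencl; otherwise run on CPU. Finally,
// apply the thread count and power mode and create the predictor.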
template <typename ConfigT> int PPredictor::_init(ConfigT &config) {
  bool is_opencl_backend_valid =
      paddle::lite_api::IsOpenCLBackendValid(/*check_fp16_valid = false*/);
  if (is_opencl_backend_valid) {
    if (_use_opencl != 0) {
      // Make sure you have write permission for the binary path.
      // We strongly recommend that each model use a unique binary name.
      const std::string bin_path = "/data/local/tmp/";
      const std::string bin_name = "lite_opencl_kernel.bin";
      config.set_opencl_binary_path_name(bin_path, bin_name);

      // opencl tune option
      // CL_TUNE_NONE: 0
      // CL_TUNE_RAPID: 1
      // CL_TUNE_NORMAL: 2
      // CL_TUNE_EXHAUSTIVE: 3
      const std::string tuned_path = "/data/local/tmp/";
      const std::string tuned_name = "lite_opencl_tuned.bin";
      config.set_opencl_tune(paddle::lite_api::CL_TUNE_NORMAL, tuned_path, tuned_name);

      // opencl precision option
      // CL_PRECISION_AUTO: 0, prefer fp16 when valid (default)
      // CL_PRECISION_FP32: 1, force fp32
      // CL_PRECISION_FP16: 2, force fp16
      config.set_opencl_precision(paddle::lite_api::CL_PRECISION_FP32);
      LOGI("ocr cpp device: running on gpu.");
    }
  } else {
    LOGI("ocr cpp device: running on cpu.");
    // A backup CPU .nb model can be supplied here instead, e.g.:
    // config.set_model_from_file(cpu_nb_model_dir);
  }
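  // Settings applied regardless of the chosen backend.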
  config.set_threads(_thread_num);
  config.set_power_mode(_mode);
  _predictor = paddle::lite_api::CreatePaddlePredictor(config);
  LOGI("ocr cpp paddle instance created");
  return RETURN_OK;
}

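// Fetch the input tensor at `index` and remember that an input was obtained;
// infer() returns early (empty result) until this has happened at least once.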
PredictorInput PPredictor::get_input(int index) {
  PredictorInput input{_predictor->GetInput(index), index, _net_flag};
  _is_input_get = true;
  return input;
}

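// Convenience helper: fetch the first `num` input tensors in order.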
std::vector<PredictorInput> PPredictor::get_inputs(int num) {
  std::vector<PredictorInput> results;
  for (int i = 0; i < num; i++) {
    results.emplace_back(get_input(i));
  }
  return results;
}

PredictorInput PPredictor::get_first_input() { return get_input(0); }

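// Run the network and collect all output tensors. Returns an empty vector if
// no input tensor was ever fetched.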
std::vector<PredictorOutput> PPredictor::infer() {
  LOGI("ocr cpp infer Run start %d", _net_flag);
  std::vector<PredictorOutput> results;
  if (!_is_input_get) {
    return results;
  }
  _predictor->Run();
  LOGI("ocr cpp infer Run end");

  const int output_num =
      static_cast<int>(_predictor->GetOutputNames().size());
  for (int i = 0; i < output_num; i++) {
    std::unique_ptr<const paddle::lite_api::Tensor> output_tensor =
        _predictor->GetOutput(i);
    LOGI("ocr cpp output tensor[%d] size %ld", i, product(output_tensor->shape()));
    PredictorOutput result{std::move(output_tensor), i, _net_flag};
    results.emplace_back(std::move(result));
  }
  return results;
}

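// Report the net type flag this predictor was constructed with.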
NET_TYPE PPredictor::get_net_flag() const { return (NET_TYPE)_net_flag; }
} // namespace ppredictor