// ppredictor.cpp — committed by WenmuZhou.
#include "ppredictor.h"
#include "common.h"

namespace ppredictor {
PPredictor::PPredictor(int thread_num, int net_flag,
                       paddle::lite_api::PowerMode mode)
    : _thread_num(thread_num), _net_flag(net_flag), _mode(mode) {}

int PPredictor::init_nb(const std::string &model_content) {
  paddle::lite_api::MobileConfig config;
  config.set_model_from_buffer(model_content);
  return _init(config);
}

int PPredictor::init_from_file(const std::string &model_content) {
  paddle::lite_api::MobileConfig config;
  config.set_model_from_file(model_content);
  return _init(config);
}

template <typename ConfigT> int PPredictor::_init(ConfigT &config) {
  config.set_threads(_thread_num);
  config.set_power_mode(_mode);
  _predictor = paddle::lite_api::CreatePaddlePredictor(config);
  LOGI("paddle instance created");
  return RETURN_OK;
}

PredictorInput PPredictor::get_input(int index) {
  PredictorInput input{_predictor->GetInput(index), index, _net_flag};
  _is_input_get = true;
  return input;
}

std::vector<PredictorInput> PPredictor::get_inputs(int num) {
  std::vector<PredictorInput> results;
  for (int i = 0; i < num; i++) {
    results.emplace_back(get_input(i));
  }
  return results;
}

PredictorInput PPredictor::get_first_input() { return get_input(0); }

std::vector<PredictorOutput> PPredictor::infer() {
  LOGI("infer Run start %d", _net_flag);
  std::vector<PredictorOutput> results;
  if (!_is_input_get) {
    return results;
  }
  _predictor->Run();
  LOGI("infer Run end");

  for (int i = 0; i < _predictor->GetOutputNames().size(); i++) {
    std::unique_ptr<const paddle::lite_api::Tensor> output_tensor =
        _predictor->GetOutput(i);
    LOGI("output tensor[%d] size %ld", i, product(output_tensor->shape()));
    PredictorOutput result{std::move(output_tensor), i, _net_flag};
    results.emplace_back(std::move(result));
  }
  return results;
}

NET_TYPE PPredictor::get_net_flag() const { return (NET_TYPE)_net_flag; }
}