ppredictor.h 1.3 KB
Newer Older
W
WenmuZhou 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63
#pragma once

#include <memory>
#include <string>
#include <vector>

#include "paddle_api.h"
#include "predictor_input.h"
#include "predictor_output.h"

namespace ppredictor {

/**
 * Abstract base shared by all Paddle-Lite predictor wrappers.
 *
 * Declares a single pure-virtual query, get_net_flag(), which concrete
 * predictors implement to report their network type.
 */
class PPredictor_Interface {
public:
  // Virtual so that destroying a concrete predictor through a pointer to
  // this base runs the derived destructor.
  virtual ~PPredictor_Interface() = default;

  /// @return the network-type flag of the concrete predictor.
  virtual NET_TYPE get_net_flag() const = 0;
};

/**
 * Common Predictor
 */
class PPredictor : public PPredictor_Interface {
public:
  PPredictor(
      int thread_num, int net_flag = 0,
      paddle::lite_api::PowerMode mode = paddle::lite_api::LITE_POWER_HIGH);

  virtual ~PPredictor() {}

  /**
   * init paddlitelite opt model,nb format ,or use ini_paddle
   * @param model_content
   * @return 0
   */
  virtual int init_nb(const std::string &model_content);

  virtual int init_from_file(const std::string &model_content);

  std::vector<PredictorOutput> infer();

  std::shared_ptr<paddle::lite_api::PaddlePredictor> get_predictor() {
    return _predictor;
  }

  virtual std::vector<PredictorInput> get_inputs(int num);

  virtual PredictorInput get_input(int index);

  virtual PredictorInput get_first_input();

  virtual NET_TYPE get_net_flag() const;

protected:
  template <typename ConfigT> int _init(ConfigT &config);

private:
  int _thread_num;
  paddle::lite_api::PowerMode _mode;
  std::shared_ptr<paddle::lite_api::PaddlePredictor> _predictor;
  bool _is_input_get = false;
  int _net_flag;
};
}