ppredictor.h
#pragma once

#include "paddle_api.h"
#include "predictor_input.h"
#include "predictor_output.h"

namespace ppredictor {

/**
 * Common interface for the Paddle-Lite predictor.
 */
class PPredictor_Interface {
public:
    virtual ~PPredictor_Interface() {}

    virtual NET_TYPE get_net_flag() const = 0;

};

/**
 * General-purpose inference predictor.
 */
class PPredictor : public PPredictor_Interface {
public:
    PPredictor(int thread_num, int net_flag = 0,
               paddle::lite_api::PowerMode mode = paddle::lite_api::LITE_POWER_HIGH);

    virtual ~PPredictor() {}

    /**
     * Initialize from a Paddle-Lite optimized (opt) model in .nb format;
     * mutually exclusive with init_from_file.
     * @param model_content the serialized .nb model content
     * @return 0 for now (fixed value); other values may indicate failure later
     */
    virtual int init_nb(const std::string &model_content);

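    /**
     * Initialize from a model file on disk; mutually exclusive with init_nb.
     */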
    virtual int init_from_file(const std::string &model_content);

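    /**
     * Run inference with the inputs currently bound to the predictor and
     * return all output tensors.
     */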
    std::vector<PredictorOutput> infer();

    std::shared_ptr<paddle::lite_api::PaddlePredictor> get_predictor() {
        return _predictor;
    }

    virtual std::vector<PredictorInput> get_inputs(int num);

    virtual PredictorInput get_input(int index);

    virtual PredictorInput get_first_input();

    virtual NET_TYPE get_net_flag() const;

protected:
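    /**
     * Shared initialization helper: builds the underlying PaddlePredictor from
     * the given Paddle-Lite config (e.g. MobileConfig for .nb models or
     * CxxConfig for model files; the exact config types are an assumption here).
     */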
    template<typename ConfigT>
    int _init(ConfigT &config);

private:
    int _thread_num;
    paddle::lite_api::PowerMode _mode;
    std::shared_ptr<paddle::lite_api::PaddlePredictor> _predictor;
    bool _is_input_get = false;
    int _net_flag;

};


} // namespace ppredictor
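
// Usage sketch (illustration only, not part of the original header): a typical
// call sequence for PPredictor. The thread count, net flag, and the
// `model_buffer` name are assumptions made for this example; filling the input
// tensor via PredictorInput is elided.
//
//   ppredictor::PPredictor predictor(/*thread_num=*/4, /*net_flag=*/0);
//   predictor.init_nb(model_buffer);   // .nb model content, or init_from_file(path)
//   ppredictor::PredictorInput input = predictor.get_first_input();
//   // ... fill `input` with preprocessed data, then run:
//   std::vector<ppredictor::PredictorOutput> outputs = predictor.infer();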