// model.h
#pragma once
#include <gflags/gflags.h>

#include <cstdint>
#include <memory>
#include <string>
#include <vector>

#include "helpers/common.h"
#include "megbrain/utils/json.h"
//! gflags boolean flag `lite`: declared here only, the matching definition
//! lives in another translation unit
DECLARE_bool(lite);
//! gflags boolean flag `mdl`: declared here only, the matching definition
//! lives in another translation unit
DECLARE_bool(mdl);

namespace lar {
/*!
 * \brief abstract base class of every model used by the load and run strategy
 *
 * Concrete model types implement the pure virtual interface below; instances
 * are obtained through the static factory create_model().
 */
class ModelBase {
public:
    //! get model type by the magic number in dump file
    //! \param model_path path of the dumped model file
    static ModelType get_model_type(std::string model_path);

    //! create model by different model type
    //! \param model_path path of the dumped model file
    //! \return shared pointer to the created model
    static std::shared_ptr<ModelBase> create_model(std::string model_path);

    //! type of the model
    virtual ModelType type() = 0;

    //! set model load state
    virtual void set_shared_mem(bool state) = 0;

    //! optional network creation hook; the default implementation does nothing,
    //! so derived classes only override it when they need it
    virtual void create_network() {}

    //! load model interface for load and run strategy
    virtual void load_model() = 0;

    //! run model interface for load and run strategy
    virtual void run_model() = 0;

    //! wait asynchronous function interface for load and run strategy
    virtual void wait() = 0;

    //! virtual so that derived models can be destroyed through a base pointer
    virtual ~ModelBase() = default;

    //! path of the model, returned by const reference
    virtual const std::string& get_model_path() const = 0;

    //! model data as a byte vector
    //! NOTE(review): presumably the raw content of the model file -- confirm
    //! against the concrete implementations
    virtual std::vector<uint8_t> get_model_data() = 0;
#if MGB_ENABLE_JSON
    //! get model io information
    virtual std::shared_ptr<mgb::json::Object> get_io_info() = 0;
#endif
};
}  // namespace lar

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}