/*!
 * \file network_impl_base.h
 */
#pragma once

#include "lite/network.h"
#include "misc.h"
#include "tensor_impl_base.h"
#include "type_info.h"

#include <atomic>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

namespace lite {

13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33
/*!
 * \brief network reference count
 */
class NetworkRefCount : public Singleton<NetworkRefCount> {
public:
    NetworkRefCount() : count(0) {}

    NetworkRefCount& operator++(int) {
        ++count;
        return *this;
    }
    NetworkRefCount& operator--(int) {
        --count;
        return *this;
    }
    int refcount() { return count; }

private:
    std::atomic<int> count;
};

34 35 36 37 38 39 40 41 42 43 44
/*!
 * \brief the Inner IO data struct, add some inner data from IO
 */
class IOInner : public IO {
public:
    //! use to flag the corresponding lite_tensor is filled, when the
    //! value of lite_tensor is filled, the have_sync is true, other wise false,
    //! this is used in async mode
    bool have_sync = false;
    //! Real input and output data location
    std::shared_ptr<Tensor> lite_tensor = nullptr;
45 46 47
    //! If the input is consists of discrete multiple tensors, lite_tensors is real
    //! input data location
    std::vector<std::shared_ptr<Tensor>> lite_tensors;
48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70

    IOInner() = default;
    IOInner(const IO& io) {
        name = io.name;
        is_host = io.is_host;
        io_type = io.io_type;
        config_layout = io.config_layout;
    }
};

/*!
 * \brief the IO description actually in effect while the network runs: one
 * inner IO record per network input and per network output
 */
struct NetworkIOInner {
    std::vector<IOInner> inputs;   //!< inner records of the network inputs
    std::vector<IOInner> outputs;  //!< inner records of the network outputs
};

/*!
 * \brief the implementation interface behind Network; concrete
 * implementations contain the mgb related members
 *
 * Every instance is counted in NetworkRefCount: each constructor increments
 * the counter and the destructor decrements it.
 */
class Network::NetworkImplBase : public DynTypeObj {
public:
    //! decrement the live-implementation counter
    virtual ~NetworkImplBase() { NetworkRefCount::Instance()--; }
    //! increment the live-implementation counter
    NetworkImplBase() { NetworkRefCount::Instance()++; }
    //! a copy is a new live object too: delegate to the default constructor so
    //! the increment matches the destructor's decrement (the implicitly
    //! generated copy constructor would skip the increment)
    NetworkImplBase(const NetworkImplBase&) : NetworkImplBase() {}
    //! assignment does not create or destroy an object, so the counter is untouched
    NetworkImplBase& operator=(const NetworkImplBase&) = default;

    //! set the config of the network, include:
    //! the inference device
    //! the other inference options, such as record_level, weight_preprocess...
    virtual void set_config(const Config& config) = 0;

    //! set the special io information; if not set, the default io tensors are
    //! used. This is needed when an input/output is not a host tensor; by
    //! default the input/output tensors are host tensors
    virtual void set_io(const NetworkIO& network_io) = 0;

    //! only compute the output tensors configured by the user
    virtual void compute_only_configured_output() = 0;

    //! get the network input and output tensor, the layout of which is
    //! synced from the mge tensor
    virtual std::shared_ptr<Tensor> get_io_tensor(
            std::string io_name, LiteTensorPhase phase = LiteTensorPhase::LITE_IO) = 0;

    //! get the network input tensors when an input consists of multiple
    //! discrete tensors, layout (1, c, h, w). The default implementation
    //! returns an empty vector.
    //! NOTE: default arguments of virtual functions bind to the static type of
    //! the call; this default (LITE_INPUT) differs from get_io_tensor (LITE_IO).
    virtual std::vector<std::shared_ptr<Tensor>> get_io_tensors(
            std::string io_name, LiteTensorPhase phase = LiteTensorPhase::LITE_INPUT) {
        LITE_MARK_USED_VAR(io_name);
        LITE_MARK_USED_VAR(phase);
        return {};
    }

    //! get the input tensor by index in the load_result tensormap
    virtual std::shared_ptr<Tensor> get_input_tensor(size_t index) = 0;

    //! get the network input tensors, by index, when an input consists of
    //! multiple discrete tensors. The default implementation returns an empty
    //! vector.
    virtual std::vector<std::shared_ptr<Tensor>> get_input_tensors(size_t index) {
        LITE_MARK_USED_VAR(index);
        return {};
    }

    //! get the output tensor by index in the load_result output_var_list
    virtual std::shared_ptr<Tensor> get_output_tensor(size_t index) = 0;

    //! get all the input tensor names in the order returned by load
    virtual std::vector<const char*> get_all_input_name() const = 0;

    //! get all the output tensor names in the order returned by load
    virtual std::vector<const char*> get_all_output_name() const = 0;

    //! get the input tensor name in the order returned by load
    virtual const char* get_input_name(size_t index) const = 0;

    //! get the output tensor name in the order returned by load
    virtual const char* get_output_name(size_t index) const = 0;

    //! set the callback used in async mode
    virtual void set_async_callback(const AsyncCallback& callback) = 0;

    //! set the start callback which will execute before network forward
    virtual void set_start_callback(const StartCallback& callback) = 0;

    //! set the finish callback which will execute after network forward
    virtual void set_finish_callback(const FinishCallback& callback) = 0;

    //! load the model and get the m_load_result
    virtual void load_model(
            std::shared_ptr<void> model_mem, size_t size,
            std::unordered_map<std::string, LiteAny> separate_config_map = {}) = 0;

    //! forward the network with filled input data and fill the output data
    //! to the output tensor
    virtual void forward() = 0;

    //! wait until the inference finishes
    virtual void wait() = 0;

    //! set device id, default device id = 0
    virtual void set_device_id(int device_id) = 0;
    virtual int get_device_id() const = 0;
    virtual LiteBackend get_backend_type() const = 0;
    //! set stream id, default stream id = 0
    virtual void set_stream_id(int stream_id) = 0;
    virtual int get_stream_id() const = 0;

    virtual LiteDeviceType get_device_type() const = 0;

    //! enable profiling the network; a file will be generated
    virtual void enable_profile_performance(std::string profile_file_path) = 0;

    //! get static peak memory info shown by Graph visualization; the default
    //! implementation throws because it is not supported
    virtual void get_static_memory_alloc_info(const std::string& log_dir) const {
        LITE_MARK_USED_VAR(log_dir);
        LITE_THROW(
                "This nerworkimpl doesn't support get_static_memory_alloc_info() "
                "function.");
    }
};

/******************************** friend class *****************************/
/*!
 * \brief friend class of Network, for convenient accessing the Network members
 */
class NetworkHelper {
public:
    static bool loaded(const std::shared_ptr<Network> network) {
        LITE_ASSERT(network);
        return network->m_loaded;
    }
    static void loaded(const std::shared_ptr<Network> network, bool loaded) {
        LITE_ASSERT(network);
        network->m_loaded = loaded;
    }
    static Network::NetworkImplBase* implement(const Network* network) {
        LITE_ASSERT(network);
        return network->m_impl.get();
    }
M
Megvii Engine Team 已提交
184
    static Network::NetworkImplBase* implement(const std::shared_ptr<Network> network) {
185 186 187
        LITE_ASSERT(network);
        return network->m_impl.get();
    }
M
Megvii Engine Team 已提交
188 189 190
    static void implement(
            const std::shared_ptr<Network> network,
            std::unique_ptr<Network::NetworkImplBase> impl) {
191 192 193 194 195 196 197 198
        LITE_ASSERT(network);
        network->m_impl = std::move(impl);
    }
};

}  // namespace lite

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}