#pragma once

#include "lite_build_config.h"
#include "megbrain/graph.h"

#if LITE_BUILD_WITH_MGE
#include "lite/network.h"
#include "network_impl_base.h"
#include "tensor_impl.h"

#include <memory>
#include <unordered_map>
#include "megbrain/gopt/inference.h"
#include "megbrain/graph/bases.h"
#include "megbrain/plugin/opr_io_dump.h"
#include "megbrain/plugin/profiler.h"
#include "megbrain/serialization/extern_c_opr.h"
#include "megbrain/serialization/file.h"
#include "megbrain/serialization/load_dump_config.h"
#include "megbrain/serialization/serializer.h"
#include "megbrain/utils/thin/hash_table.h"

namespace lite {

/*!
 * \brief implement the Network, contain the mgb related member
 */
class NetworkImplDft final : public Network::NetworkImplBase {
    LITE_DYN_TYPE_OBJ_FINAL_DECL;

public:
    NetworkImplDft() {
        m_load_config.comp_graph = mgb::ComputingGraph::make();
        m_user_config = std::make_unique<Config>();
        m_network_io = std::make_unique<NetworkIOInner>();
    }
    using S = megdnn::param::ExecutionPolicy::Strategy;
    using Var = mgb::cg::SymbolVar;
    //! set the config of the network, including:
    //! the inference device
    //! other inference options, such as record_level, weight_preprocess...
    void set_config(const Config& config) override;

    //! set the special io information; if not set, the default io tensors will be
    //! used. This is for the case where the input/output are not host tensors; by
    //! default the input/output tensors are host tensors
    void set_io(const NetworkIO& network_io) override;
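    // Caller-side usage sketch for special IO: a hypothetical example assuming the
    // public lite::IO / lite::NetworkIO structs and the lite::Network constructor
    // taking (Config, NetworkIO) declared in lite/network.h; the field names below
    // come from that public header, not from this file.
    //
    //     lite::Config config;
    //     lite::NetworkIO io;
    //     lite::IO device_input;
    //     device_input.name = "data";    // must match the model's input name
    //     device_input.is_host = false;  // input lives in device memory, not host
    //     io.inputs.push_back(device_input);
    //     lite::Network network(config, io);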

    //! only compute the output tensors configured by the user
    void compute_only_configured_output() override {
        m_compute_configured_output_only = true;
    }

    //! get the network input and output tensor, the layout of which is
    //! synced from the mge tensor
    std::shared_ptr<Tensor> get_io_tensor(
            std::string io_name,
            LiteTensorPhase phase = LiteTensorPhase::LITE_IO) override;

    //! get the network input tensors when the input consists of multiple discrete
    //! tensors, each with layout (1, c, h, w)
    std::vector<std::shared_ptr<Tensor>> get_discrete_tensors(
            std::string io_name,
            LiteTensorPhase phase = LiteTensorPhase::LITE_INPUT) override;

    //! get the input tensor by index in the load_result tensormap
    std::shared_ptr<Tensor> get_input_tensor(size_t index) override;

    //! get, by index, the network input tensors when the input consists of multiple
    //! discrete tensors
    std::vector<std::shared_ptr<Tensor>> get_input_tensors(size_t index) override;

    //! get the output tensor by index in the load_result output_var_list
    std::shared_ptr<Tensor> get_output_tensor(size_t index) override;

    //! get all the input tensor names in the order returned by load
    std::vector<const char*> get_all_input_name() const override;

    //! get all the output tensor names in the order returned by load
    std::vector<const char*> get_all_output_name() const override;

    //! get the input tensor name in the order returned by load
    const char* get_input_name(size_t index) const override;

    //! get the output tensor name in the order returned by load
    const char* get_output_name(size_t index) const override;

    //! set the callback in async mode
    void set_async_callback(const AsyncCallback& callback) override;
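    // Sketch of registering the callback: a hypothetical example that assumes
    // AsyncCallback is the std::function<void(void)> alias from lite/network.h
    // and that the public lite::Network wrapper forwards to this method.
    //
    //     network.set_async_callback([]() {
    //         // inference finished: it is now safe to read the output tensors
    //     });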

    //! set the start callback which will execute before network forward
    void set_start_callback(const StartCallback& callback) override {
        m_start_callback = std::move(callback);
    }

    //! set the finish callback which will execute after network forward
    void set_finish_callback(const FinishCallback& callback) override {
        m_finish_callback = std::move(callback);
    }

    //! load the model and get the m_load_result
    void load_model(
            std::shared_ptr<void> model_mem, size_t size,
            std::unordered_map<std::string, LiteAny> separate_config_map = {}) override;

    //! forward the network with filled input data and fill the output data
    //! to the output tensor
    void forward() override;

    //! in sync mode, wait until the inference finishes
    void wait() override;
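    // Typical synchronous inference flow, as a caller-side sketch through the
    // public lite::Network wrapper assumed from lite/network.h; "model.mge" and
    // src are hypothetical, and the tensor accessors are assumed from lite/tensor.h.
    //
    //     network.load_model("model.mge");
    //     auto input = network.get_io_tensor("data");
    //     memcpy(input->get_memory_ptr(), src,
    //            input->get_tensor_total_size_in_byte());
    //     network.forward();
    //     network.wait();
    //     auto output = network.get_io_tensor(network.get_output_name(0));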

    virtual LiteDeviceType get_device_type() const override {
        return m_user_config->device_type;
    }

    //! Set CPU inplace mode when the device is CPU; on some low-computation
    //! or single-core devices, this mode gives good performance
    void set_cpu_inplace_mode();
    bool is_cpu_inplace_mode() const { return m_is_cpu_inplace_mode; }

    //! When the device is CPU, this interface makes the model to be loaded
    //! run in multi-thread mode with the given number of threads.
    void set_cpu_threads_number(size_t nr_threads);
    size_t get_cpu_threads_number() const { return m_nr_threads; }
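    // Caller-side sketch: an assumption based on the static lite::Runtime helpers
    // declared in lite/network.h. Both calls must happen before the model is
    // loaded, and inplace mode and multi-threading are alternatives, not a
    // combination.
    //
    //     lite::Runtime::set_cpu_inplace_mode(network);       // single-core device
    //     // or:
    //     lite::Runtime::set_cpu_threads_number(network, 4);  // multi-core device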

    //! set device id, default device id = 0
    void set_device_id(int device_id) override;
    int get_device_id() const override { return m_compnode_locator.device; };

    LiteBackend get_backend_type() const override { return LiteBackend::LITE_DEFAULT; }
    //! set stream id, default stream id = 0
    void set_stream_id(int stream_id) override;
    int get_stream_id() const override { return m_compnode_locator.stream; };

    //! enable tensorrt
    void use_tensorrt();

    //! enable profiling of the network; a JSON format file will be generated
    void enable_profile_performance(std::string profile_json_file_path) override;

    /********************** mge special function ************************/
    //! load a new network which will share weights with src network
    void shared_weight_with(const NetworkImplBase* src_network);

    //! share the runtime memory with another network; the weights are not shared
    void share_runtime_memory_with(NetworkImplBase* network);

    //! set thread affinity callback
    void set_runtime_thread_affinity(
            const ThreadAffinityCallback& thread_affinity_callback);

    //! set the network memory allocator; the allocator is defined by the user
    void set_memory_allocator(std::shared_ptr<Allocator> user_allocator);

    //! set opr algorithm selection strategy in the network
    void set_network_algo_policy(
            LiteAlgoSelectStrategy strategy, uint32_t shared_batch_size,
            bool binary_equal_between_batch);
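    // Example of combining strategy flags: a sketch that calls the method declared
    // above; the enumerator names and the bit-flag combination are assumptions
    // based on the LiteAlgoSelectStrategy values in the public headers.
    //
    //     impl->set_network_algo_policy(
    //             static_cast<LiteAlgoSelectStrategy>(
    //                     LITE_ALGO_PROFILE | LITE_ALGO_REPRODUCIBLE),
    //             /*shared_batch_size=*/0, /*binary_equal_between_batch=*/false);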

    //! set workspace_limit for oprs with multiple algorithms; setting a
    //! workspace limit can save memory but may affect performance
    void set_network_algo_workspace_limit(size_t workspace_limit);

    //! Dump input/output values of all internal variables to output file,
    //! in text format
    void enable_io_txt_dump(std::string io_txt_out_file);

    //! Dump input/output values of all internal variables to output
    //! directory, in binary format
    void enable_io_bin_dump(std::string io_bin_out_dir);

    //! get static peak memory info shown by Graph visualization
    void get_static_memory_alloc_info(
            const std::string& log_dir = "logs/test") const override;

    //! set global layout transform optimization for network
    void enable_global_layout_transform();

    //! dump network after global layout transform optimization
    void dump_layout_transform_model(std::string optimized_model_path);
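    // A possible call order, as a sketch only: it assumes the transform must be
    // enabled before the model is loaded, and model_mem, size and the output path
    // are hypothetical.
    //
    //     impl->enable_global_layout_transform();  // before the model is loaded
    //     impl->load_model(model_mem, size);
    //     impl->dump_layout_transform_model("optimized.mge");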

    mgb::serialization::GraphLoader::LoadResult get_load_result() {
        return m_load_result;
    }

private:
    //! construct the output spec according to m_network_io, and set the
    //! callback in the output spec
    void make_output_spec();

    //! do layout transform for the given platform target, either the global
    //! layout optimization or heuristically choosing the best layout according to
    //! the device information
    void layout_transform_optimization();

    //! modify the execution policy
    void modify_exection_policy();

    //! if the input is a dev tensor, the pass will replace the H2D Opr with a
    //! VolatileSharedDeviceTensor Opr
    void replace_dev_input_pass();

    //! if the input to the network is a list of tensors, this pass will replace
    //! the opr with the corresponding version that supports a list of tensors as
    //! input; currently WarpPerspective is supported
    void replace_src_discrete_input_opr_pass();

    //! check whether the model is cross compnode
    void cross_compnode_model_detect();

    //! after the model has been loaded, update the IO; if NetworkIO was not set,
    //! update it with the IO of the loaded model
    void update_io();

    void update_input();
    void update_output();
    //! initialize lite_tensors when the input is composed of multiple discrete tensors
    void update_input_lite_tensors();

    //! after the model info has been loaded, update the config according to the
    //! model info, and finally use it in the compute graph
    void application_config();

    //! after forwarding the network finishes, output the plugin results to file
    void output_plugin_result() const;

    //! called when forwarding the network finishes
    void finish() const;

    //! called before forwarding the network
    void start() const;

    //! compile the graph to get the execute function
    void compile_graph();

    //! try to infer output tensor layout
    void try_infer_tensor_layout(std::shared_ptr<Tensor> tensor, Var var);

    //! optimize the output tensor copy
    void output_tensor_copy_optimize(Var var, std::shared_ptr<Tensor> tensor);

    //! configure and optimize the network after it is loaded
    void configure_after_loaded();

private:
    bool m_async = false;
    bool m_is_cpu_inplace_mode = false;
    int m_nr_device_type = 0;
    size_t m_nr_threads = 1;
    bool m_compute_configured_output_only = false;
    bool m_set_layout_transform = false;
    mgb::CompNode::Locator m_compnode_locator;

    AsyncCallback m_async_callback = nullptr;
    std::unique_ptr<NetworkIOInner> m_network_io;
    std::unique_ptr<Config> m_user_config;
    std::unique_ptr<mgb::cg::AsyncExecutable> m_execute_func;

    //! The model load related data
    S m_execution_policy = static_cast<S>(0);
    std::unique_ptr<mgb::serialization::InputFile> m_input_file;
    mgb::Maybe<mgb::serialization::GraphDumpFormat> m_format;
    mgb::gopt::GraphTuningOptions::Target m_layout_transform_target;

    mgb::serialization::GraphLoadConfig m_load_config;
    mgb::serialization::GraphLoader::LoadResult m_load_result;
    mgb::ComputingGraph::OutputSpec m_output_spec;
    std::shared_ptr<mgb::serialization::GraphLoader> m_loader;

    //! start and finish callback
    StartCallback m_start_callback = nullptr;
    FinishCallback m_finish_callback = nullptr;

    //! profile and io dump related data
#if MGB_ENABLE_JSON
    std::unique_ptr<mgb::GraphProfiler> m_profiler;
    std::string m_profiler_output_file;
#endif
    std::unique_ptr<mgb::OprIODumpBase> m_iodump;
};
//! get the model information before the model is loaded by Network
NetworkIO get_model_io_info_dft(const std::string& model_path, const Config& config);

//! get the model information from the model memory and size before the model is
//! loaded by Network
NetworkIO get_model_io_info_dft(
        const void* model_mem, size_t size, const Config& config);
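// Usage sketch for the path-based overload above; "model.mge" is a hypothetical
// path, and the IO field names are assumed from the public lite::NetworkIO /
// lite::IO structs in lite/network.h.
//
//     lite::Config config;
//     lite::NetworkIO io = lite::get_model_io_info_dft("model.mge", config);
//     for (auto& input : io.inputs) {
//         printf("input name: %s\n", input.name.c_str());
//     }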

}  // namespace lite

#endif

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}