network_impl_base.h 5.6 KB
Newer Older
1 2
/**
 * \file src/network_impl_base.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */

#pragma once

#include "lite/network.h"
#include "misc.h"
#include "tensor_impl_base.h"
#include "type_info.h"

#include <unordered_map>

namespace lite {

/*!
 * \brief the Inner IO data struct, add some inner data from IO
 */
class IOInner : public IO {
public:
    //! use to flag the corresponding lite_tensor is filled, when the
    //! value of lite_tensor is filled, the have_sync is true, other wise false,
    //! this is used in async mode
    bool have_sync = false;
    //! Real input and output data location
    std::shared_ptr<Tensor> lite_tensor = nullptr;

    IOInner() = default;
    IOInner(const IO& io) {
        name = io.name;
        is_host = io.is_host;
        io_type = io.io_type;
        config_layout = io.config_layout;
    }
};

/*!
 * \brief the real network IO info used when the network runs: the inner IO
 * entries for all network inputs and outputs
 */
struct NetworkIOInner {
    //! inner IO entries of the network inputs
    std::vector<IOInner> inputs;
    //! inner IO entries of the network outputs
    std::vector<IOInner> outputs;
};

/*!
 * \brief the implementation interface behind Network; derived classes contain
 * the mgb related members and implement every pure virtual method below
 */
class Network::NetworkImplBase : public DynTypeObj {
public:
    virtual ~NetworkImplBase() = default;

    //! set the config of the network, including:
    //! the inference device
    //! the other inference options, such as record_level, weight_preprocess...
    virtual void set_config(const Config& config) = 0;

    //! set the special IO information; if it is not set, the default IO tensors
    //! are used. This is meant for inputs/outputs that are not host tensors: by
    //! default the input/output tensors are host tensors
    virtual void set_io(const NetworkIO& network_io) = 0;

    //! only compute the output tensors configured by the user
    virtual void compute_only_configured_output() = 0;

    //! get the network input or output tensor named io_name; its layout is
    //! synced from the mge tensor. phase defaults to LiteTensorPhase::LITE_IO
    virtual std::shared_ptr<Tensor> get_io_tensor(
            std::string io_name,
            LiteTensorPhase phase = LiteTensorPhase::LITE_IO) = 0;

    //! get the input tensor by index in the load_result tensormap
    virtual std::shared_ptr<Tensor> get_input_tensor(size_t index) = 0;

    //! get the output tensor by index in the load_result output_var_list
    virtual std::shared_ptr<Tensor> get_output_tensor(size_t index) = 0;

    //! get all the input tensor names, in the order returned by load
    virtual std::vector<const char*> get_all_input_name() const = 0;

    //! get all the output tensor names, in the order returned by load
    virtual std::vector<const char*> get_all_output_name() const = 0;

    //! get the name of the input tensor at index, in the order returned by load
    virtual const char* get_input_name(size_t index) const = 0;

    //! get the name of the output tensor at index, in the order returned by
    //! load
    virtual const char* get_output_name(size_t index) const = 0;

    //! set the callback used in async mode
    virtual void set_async_callback(const AsyncCallback& callback) = 0;

    //! set the start callback, which is executed before the network forward
    virtual void set_start_callback(const StartCallback& callback) = 0;

    //! set the finish callback, which is executed after the network forward
    virtual void set_finish_callback(const FinishCallback& callback) = 0;

    //! load the model held in model_mem (whose size is given by size) and get
    //! the m_load_result; separate_config_map carries optional extra settings
    //! and defaults to an empty map.
    //! NOTE: default arguments of virtual functions are bound by the static
    //! type of the call, so calls through a NetworkImplBase pointer/reference
    //! always use this `{}` default, whatever an override declares
    virtual void load_model(std::shared_ptr<void> model_mem, size_t size,
                            std::unordered_map<std::string, LiteAny>
                                    separate_config_map = {}) = 0;

    //! forward the network with the filled input data and fill the output data
    //! into the output tensors
    virtual void forward() = 0;

    //! in sync mode, wait until the inference finishes
    //! NOTE(review): confirm whether this is meant to read "async mode"
    virtual void wait() = 0;

    //! set the device id; the default device id is 0
    virtual void set_device_id(int device_id) = 0;
    //! get the device id
    virtual int get_device_id() const = 0;
    //! get the backend type of this implementation
    virtual LiteBackend get_backend_type() const = 0;
    //! set the stream id; the default stream id is 0
    virtual void set_stream_id(int stream_id) = 0;
    //! get the stream id
    virtual int get_stream_id() const = 0;

    //! get the device type
    virtual LiteDeviceType get_device_type() const = 0;

    //! enable performance profiling of the network; a profile file will be
    //! generated (presumably at profile_file_path)
    virtual void enable_profile_performance(std::string profile_file_path) = 0;
};

/******************************** friend class *****************************/
/*!
 * \brief friend class of Network, giving internal code convenient access to
 * the private Network members m_loaded and m_impl
 *
 * Every accessor checks the given network with LITE_ASSERT before using it.
 * The shared_ptr overloads take the pointer by const reference: they only
 * access the pointee and never keep the pointer, so copying it (an atomic
 * reference-count increment and decrement per call) would be wasted work.
 */
class NetworkHelper {
public:
    //! return the loaded flag of the network
    static bool loaded(const std::shared_ptr<Network>& network) {
        LITE_ASSERT(network);
        return network->m_loaded;
    }
    //! set the loaded flag of the network
    static void loaded(const std::shared_ptr<Network>& network, bool loaded) {
        LITE_ASSERT(network);
        network->m_loaded = loaded;
    }
    //! return the raw (non-owning) implementation pointer of the network
    static Network::NetworkImplBase* implement(const Network* network) {
        LITE_ASSERT(network);
        return network->m_impl.get();
    }
    //! return the raw (non-owning) implementation pointer of the network
    static Network::NetworkImplBase* implement(
            const std::shared_ptr<Network>& network) {
        LITE_ASSERT(network);
        return network->m_impl.get();
    }
    //! replace the implementation of the network with impl, transferring
    //! ownership of impl into the network; the implementation previously held
    //! in m_impl is released
    static void implement(const std::shared_ptr<Network>& network,
                          std::unique_ptr<Network::NetworkImplBase> impl) {
        LITE_ASSERT(network);
        network->m_impl = std::move(impl);
    }
};

}  // namespace lite

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}