// interpreter_impl.h
#pragma once

#include <algorithm>
#include <condition_variable>
#include <deque>
#include <functional>
#include <future>
#include <limits>
#include <list>
#include <mutex>
#include <stack>
#include <thread>
#include <unordered_set>
#include <variant>

#include "megbrain/comp_node.h"
#include "megbrain/imperative/interpreter.h"
#include "megbrain/imperative/profiler.h"
#include "megbrain/utils/mempool.h"

#include "./commands.h"
#include "./option_manager.h"
#include "./stack_manager.h"
#include "./tensor_info.h"

#include "../profiler/events.h"
#include "megbrain/imperative/backtrace.h"

namespace mgb::imperative::interpreter::intl {

//! shorthand for the tensor handle type of the public Interpreter API
using Handle = Interpreter::Handle;

//! Interpreter implementation whose only job is to create channels
//! (presumably ChannelImpl instances -- confirm in create_channel's definition)
struct InterpreterImpl : Interpreter {
    std::unique_ptr<Channel> create_channel() override;
};

/*!
 * \brief implement Channel to execute the commands asynchronously,
 * almost commands are executed by the worker threads, commands are sent
 * by the interface
 */
36
struct ChannelImpl : Interpreter::Channel, NonCopyableObj, NonMoveableObj {
37 38 39 40
    ChannelImpl();
    ~ChannelImpl() override;

    Handle put(const HostTensorND& value, bool no_cache) override;
41
    Handle put(const DeviceTensorND& value, const HostTensorND& hvalue) override;
42 43 44 45 46

    void del(Handle) override;
    void drop(Handle) override;

    SmallVector<Handle> apply_op(
M
Megvii Engine Team 已提交
47
            std::shared_ptr<OpDef> op, const SmallVector<Handle>& inputs) override;
48 49 50 51 52 53 54 55

    HostTensorND get_value(Handle) override;
    TensorShape get_shape(Handle) override;
    DType get_dtype(Handle) override;
    CompNode get_device(Handle) override;

    DeviceTensorND get_dev_tensor(Handle) override;

56
    bool check_available() override;
57 58 59
    void sync() override;
    void close() override;

60 61
    size_t get_option(std::string name) override;
    void set_option(std::string name, size_t value) override;
62
    void clear_candidates() override;
63

64 65
    void start_profile() override;
    void stop_profile() override;
66
    void stop_step() override;
67 68 69

    void push_scope(std::string) override;
    void pop_scope(std::string) override;
M
Megvii Engine Team 已提交
70

71 72 73 74
    BackTraceInfoPtr& get_backtrace() override;
    void set_backtrace(BackTraceInfoPtr bt) override;
    void clear_backtrace() override;

75 76 77 78 79 80 81
    bool worker_started() const;
    void update_status_to_forked(void);
    void assert_available() const;

    static std::unordered_set<ChannelImpl*> m_all_active_channels;
    static MGB_MUTEX m_all_active_channels_mutex;

82
private:
83 84 85
    struct WorkQueue;
    struct State;

86
    TensorInfo* alloc();
87
    void init(TensorInfo*, LogicalTensorDesc&& desc);
88
    void free(TensorInfo*);
89 90 91
    void real_free(TensorInfo*);
    void recursive_free(TensorInfo*);
    void do_drop(TensorInfo*, bool);
92 93
    void detach_users(TensorInfo*);

94
    TensorInfo* put_impl(const HostTensorND& value, bool no_cache);
95
    TensorInfo* put_impl(const DeviceTensorND& value, const HostTensorND& hvalue);
96 97
    void del_impl(Handle);
    void sync_impl();
98
    SmallVector<Handle> apply_op_impl(
M
Megvii Engine Team 已提交
99
            std::shared_ptr<OpDef> op, const SmallVector<Handle>& inputs);
100 101 102
    TensorPtr wait_tensor(TensorInfo* info, profiler::TensorProp prop);
    void notify_tensor_unsafe(TensorInfo* info);

103
    void process_one_task(Command&);
104 105 106

    void check_worker_exc_unsafe();

107
    void produce_tensor(TensorInfo* dest, TensorPtr ptr);
108 109 110 111

    void release_tensor(TensorInfo* dest);

    void regenerate(TensorInfo* dest);
112
    void flush_apply_stack();
113
    void do_apply_op(const ApplyOp& cmd, std::string reason);
M
Megvii Engine Team 已提交
114

115
    void dispatch_default_cpu(
M
Megvii Engine Team 已提交
116 117 118
            std::shared_ptr<OpDef> op, const SmallVector<TensorInfo*>& input_infos,
            const SmallVector<LogicalTensorDesc>& input_descs,
            SmallVector<Handle>* outputs);
119
    void dispatch_kernel(
M
Megvii Engine Team 已提交
120 121 122
            std::shared_ptr<OpDef> op, const SmallVector<TensorInfo*>& input_infos,
            const SmallVector<LogicalTensorDesc>& input_descs,
            SmallVector<Handle>* outputs);
123

124 125 126
    void push_scope(std::string, State&);
    void pop_scope(std::string, State&);

127 128 129 130
    void assert_in_channel();
    void assert_in_worker();
    std::thread::id get_worker_tid();

131 132 133 134 135
    void sample_on_device(CompNode device, bool force);

    // valid => status != Deleted
    std::unordered_set<TensorInfo*> collect_valid_tensors();

136
    std::mutex m_mutex;
137
    Spinlock m_spin;
138 139 140 141
    std::condition_variable m_cv;
    MemPool<TensorInfo> m_pool;
    std::unordered_set<Handle> m_valid_handle;
    TensorInfo* m_waitee = nullptr;
142
    BackTraceInfoPtr m_bt = nullptr;
143 144
    Spinlock m_pool_spin;
    Spinlock m_info_spin;
145
    uint64_t m_waitee_id = 0;
146
    std::exception_ptr m_worker_exc;
147
    std::function<void(std::string, std::string)> m_profile_dump_callback;
148
    size_t m_storage_id = 0;
149 150
    // TODO: use explicit struct
    std::stack<std::tuple<ApplyOp, size_t, TensorInfo*, std::string>> m_apply_stack;
151
    bool m_applying = false;
152 153 154

    enum class ChannelRunningStatus { RUNING, CLOSED, FORKED };
    ChannelRunningStatus m_status = ChannelRunningStatus::RUNING;
155

156
    struct WorkQueue : AsyncQueueSC<Command, WorkQueue> {
157 158
        // set max_spin=0 to prevent Queue fetch task in busy wait manner.
        // this won't affect throughput when python interpreter is sending enough task,
M
Megvii Engine Team 已提交
159 160
        // but will significantly save CPU time when waiting for task, e.g. wait for
        // data input limit pending tasks to 10000
161
        WorkQueue(ChannelImpl* owner)
162
                : AsyncQueueSC<Command, WorkQueue>(0, 10000), m_owner(owner) {
163
            sys::set_thread_name("interpreter");
164 165
            if (const char* env_val = MGB_GETENV("MEGENGINE_ASYNC_QUEUE_SIZE")) {
                int len = strlen(env_val);
M
Megvii Engine Team 已提交
166 167 168 169
                for (int i = 0; i < len; i++) {
                    mgb_assert(
                            env_val[i] >= '0' && env_val[i] <= '9',
                            "async queue size should be an integer");
170 171 172 173 174
                }
                size_t val;
                sscanf(env_val, "%zu", &val);
                update_max_items(val);
            }
175
        }
M
Megvii Engine Team 已提交
176
        void process_one_task(Command& icmd) { m_owner->process_one_task(icmd); }
177
        void on_async_queue_worker_thread_start() override;
M
Megvii Engine Team 已提交
178

179 180 181 182
    private:
        ChannelImpl* m_owner;
    } m_worker;

183
    struct State {
184
        std::thread::id tid;
185
        OptionManager options;
186 187
    };

M
Megvii Engine Team 已提交
188
    struct ChannelState : State {
189
        StackManager stack_manager;
190 191
    };

M
Megvii Engine Team 已提交
192
    struct WorkerState : State {};
193

194 195
    ChannelState m_channel_state;
    WorkerState m_worker_state;
196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211

    /*!
     * \brief A framework of dynamic sublienar memory optimization
     *
     * Note: The main idea is that during the training process, if the memory
     * usage exceeds the threshold, select some tensors to evict until the
     * memory usage is below the threshold.
     */
    struct DynamicSublinear {
        /*!
         * \brief find an available tensor with the largest evaluation function
         *
         * Note: An available tensor must satisfy: (1) has computing path,
         * (2) is in memory, (3) is not pinned. Evaluation function refers to:
         * @see: TensorInfo::eval_func.
         *
M
Megvii Engine Team 已提交
212
         * \return the pointer of the best tensor; nullptr is returned if no
213 214
         * available tensor is found
         */
215
        TensorInfo* find_best_tensor(bool);
216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233

        /*!
         * \brief estimate the cost of recomputing tensor ptr
         *
         * Note: We define the cost as the sum of the costs of each evicted
         * components where all the neighbors of ptr are located.
         */
        double estimate_neighbor_cost(TensorInfo* ptr);

        /*!
         * \brief update the last used time of the tensor ptr
         */
        void update_used_time(TensorInfo* ptr);

        /*!
         * \brief merge the two specified sets (the set in which the element x
         * is located, and the set in which the element y is located)
         */
M
Megvii Engine Team 已提交
234
        void merge(std::shared_ptr<DsuNode>& x, std::shared_ptr<DsuNode>& y);
235 236 237 238 239

        /*!
         * \brief return the representative of the set that contains the
         * element x
         */
M
Megvii Engine Team 已提交
240
        std::shared_ptr<DsuNode> find_father(std::shared_ptr<DsuNode>& x);
241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267

        /*!
         * \brief update DSU after recomputing tensor ptr
         *
         * Delete ptr from the set where ptr is located. Since DSU does not
         * support this operation, instead, we reset the DSU father of ptr, and
         * subtract the recomputation cost of ptr from the cost of the original
         * set.
         */
        void update_dsu_after_recompute(TensorInfo* ptr);

        /*!
         * \brief update DSU after evicting tensor ptr
         *
         * Check the neighbors of x, that is, the input and output tensors, and
         * if they are evicted, merge their respective sets.
         */
        void update_dsu_after_evict(TensorInfo* ptr);

        /*!
         * \brief pin the tensors in vec
         */
        void pin(const SmallVector<TensorInfo*>& vec);

        /*!
         * \brief unpin the tensors in vec
         */
268 269
        void unpin(
                const SmallVector<TensorInfo*>& vec, size_t& dtr_evictee_minimum_size);
270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293

        /*!
         * \brief add the tensor to the candidate set
         *
         * If the size of the tensor does not exceed the minimum threshold,
         * it will do nothing.
         */
        void insert_candidate(TensorInfo* ptr);

        /*!
         * \brief erase the tensor from the candidate set
         *
         * If the size of the tensor does not exceed the minimum threshold,
         * it will do nothing.
         */
        void erase_candidate(TensorInfo* ptr);

        //! estimate the current time, in order to reduce the overhead of timer
        double estimate_timestamp = 0;

        //! the comp node where dynamic sublinear memory optimization works
        CompNode comp_node;

        //! store all tensors that may be evicted
294
        SmallVector<TensorInfo*> candidates;
295

296
        bool is_bad_op(std::string op_name) {
M
Megvii Engine Team 已提交
297 298
            return std::find(op_blacklist.begin(), op_blacklist.end(), op_name) !=
                   op_blacklist.end();
299 300
        }

301 302
        // operators that cannot be re-computed, including :
        // distributed operators, inplace operator, random generator operators
M
Megvii Engine Team 已提交
303 304 305 306
        std::vector<std::string> op_blacklist = {
                "CollectiveComm", "InplaceAdd", "ParamPackSplit", "ParamPackConcat",
                "GaussianRNG",    "UniformRNG", "GammaRNG",       "PermutationRNG",
                "PoissonRNG",     "BetaRNG"};
307 308 309
    } m_dtr;

    //! automatically evict an optimal tensor
310 311
    bool auto_evict(size_t);

312
    void alloc_tensor_with_evict(OwnedBlob*);
313 314 315 316

    // assert thread id when call get_xxx_state to avoid misuse
    ChannelState& get_channel_state();
    WorkerState& get_worker_state();
317 318
};

M
Megvii Engine Team 已提交
319
}  // namespace mgb::imperative::interpreter::intl

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}