/**
 * \file imperative/src/impl/interpreter/interpreter_impl.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#pragma once

#include <deque>
#include <future>
#include <limits>
#include <list>
#include <stack>
#include <thread>
#include <unordered_set>
#include <variant>

#include "megbrain/comp_node.h"
#include "megbrain/imperative/interpreter.h"
#include "megbrain/imperative/profiler.h"
#include "megbrain/utils/mempool.h"

#include "./commands.h"
#include "./option_manager.h"
#include "./stack_manager.h"
#include "./tensor_info.h"

#include "../profiler/events.h"

namespace mgb::imperative::interpreter::intl {

//! Local shorthand for the handle type declared by Interpreter.
using Handle = Interpreter::Handle;

//! Interpreter implementation; its only visible customization point is
//! channel creation (defined out of line).
struct InterpreterImpl : public Interpreter {
public:
    //! Create a new interpreter channel owned by the caller.
    std::unique_ptr<Channel> create_channel() override;
};

struct ChannelImpl : Interpreter::Channel {
    ChannelImpl();
    ~ChannelImpl() override;

    Handle put(const HostTensorND& value, bool no_cache) override;
46
    Handle put(const DeviceTensorND& value, const HostTensorND& hvalue) override;
47 48 49 50 51

    void del(Handle) override;
    void drop(Handle) override;

    SmallVector<Handle> apply_op(
M
Megvii Engine Team 已提交
52
            std::shared_ptr<OpDef> op, const SmallVector<Handle>& inputs) override;
53 54 55 56 57 58 59 60

    HostTensorND get_value(Handle) override;
    TensorShape get_shape(Handle) override;
    DType get_dtype(Handle) override;
    CompNode get_device(Handle) override;

    DeviceTensorND get_dev_tensor(Handle) override;

61
    bool check_available() override;
62 63 64
    void sync() override;
    void close() override;

65 66
    size_t get_option(std::string name) override;
    void set_option(std::string name, size_t value) override;
67
    void clear_candidates() override;
68

69 70
    void start_profile() override;
    void stop_profile() override;
71 72 73

    void push_scope(std::string) override;
    void pop_scope(std::string) override;
M
Megvii Engine Team 已提交
74

75
private:
76 77 78
    struct WorkQueue;
    struct State;

79
    TensorInfo* alloc();
80
    void init(TensorInfo*, LogicalTensorDesc desc);
81
    void free(TensorInfo*);
82 83 84
    void real_free(TensorInfo*);
    void recursive_free(TensorInfo*);
    void do_drop(TensorInfo*, bool);
85 86
    void detach_users(TensorInfo*);

87
    TensorInfo* put_impl(const HostTensorND& value, bool no_cache);
88
    TensorInfo* put_impl(const DeviceTensorND& value, const HostTensorND& hvalue);
89 90
    void del_impl(Handle);
    void sync_impl();
91
    SmallVector<Handle> apply_op_impl(
M
Megvii Engine Team 已提交
92
            std::shared_ptr<OpDef> op, const SmallVector<Handle>& inputs);
93 94 95
    TensorPtr wait_tensor(TensorInfo* info, profiler::TensorProp prop);
    void notify_tensor_unsafe(TensorInfo* info);

96
    void process_one_task(Command&);
97 98 99

    void check_worker_exc_unsafe();

100
    void produce_tensor(TensorInfo* dest, TensorPtr ptr);
101 102 103 104

    void release_tensor(TensorInfo* dest);

    void regenerate(TensorInfo* dest);
105
    void flush_apply_stack();
106
    void do_apply_op(const ApplyOp& cmd, std::string reason);
M
Megvii Engine Team 已提交
107 108 109 110 111

    std::tuple<SmallVector<MemoryDesc>, SmallVector<TensorPtr>, SmallVector<TensorPtr>>
    init_output_and_workspace(
            const OpDef& def, SmallVector<TensorPtr> inputs,
            SmallVector<MemoryDesc> inputs_mem_desc);
112 113

    void dispatch_default_cpu(
M
Megvii Engine Team 已提交
114 115 116
            std::shared_ptr<OpDef> op, const SmallVector<TensorInfo*>& input_infos,
            const SmallVector<LogicalTensorDesc>& input_descs,
            SmallVector<Handle>* outputs);
117
    void dispatch_kernel(
M
Megvii Engine Team 已提交
118 119 120
            std::shared_ptr<OpDef> op, const SmallVector<TensorInfo*>& input_infos,
            const SmallVector<LogicalTensorDesc>& input_descs,
            SmallVector<Handle>* outputs);
121

122 123 124
    void push_scope(std::string, State&);
    void pop_scope(std::string, State&);

125 126 127 128
    void assert_in_channel();
    void assert_in_worker();
    std::thread::id get_worker_tid();

129 130 131 132 133
    void sample_on_device(CompNode device, bool force);

    // valid => status != Deleted
    std::unordered_set<TensorInfo*> collect_valid_tensors();

134
    std::mutex m_mutex;
135
    Spinlock m_spin;
136 137 138 139
    std::condition_variable m_cv;
    MemPool<TensorInfo> m_pool;
    std::unordered_set<Handle> m_valid_handle;
    TensorInfo* m_waitee = nullptr;
140
    uint64_t m_waitee_id = 0;
141
    std::exception_ptr m_worker_exc;
142
    std::function<void(std::string, std::string)> m_profile_dump_callback;
143
    size_t m_storage_id = 0;
144 145
    // TODO: use explicit struct
    std::stack<std::tuple<ApplyOp, size_t, TensorInfo*, std::string>> m_apply_stack;
146
    bool m_applying = false;
147 148
    bool m_closed = false;

149
    struct WorkQueue : AsyncQueueSC<Command, WorkQueue> {
150 151
        // set max_spin=0 to prevent Queue fetch task in busy wait manner.
        // this won't affect throughput when python interpreter is sending enough task,
M
Megvii Engine Team 已提交
152 153
        // but will significantly save CPU time when waiting for task, e.g. wait for
        // data input limit pending tasks to 10000
154
        WorkQueue(ChannelImpl* owner)
155
                : AsyncQueueSC<Command, WorkQueue>(0, 10000), m_owner(owner) {
156
            sys::set_thread_name("interpreter");
157 158
            if (const char* env_val = MGB_GETENV("MEGENGINE_ASYNC_QUEUE_SIZE")) {
                int len = strlen(env_val);
M
Megvii Engine Team 已提交
159 160 161 162
                for (int i = 0; i < len; i++) {
                    mgb_assert(
                            env_val[i] >= '0' && env_val[i] <= '9',
                            "async queue size should be an integer");
163 164 165 166 167
                }
                size_t val;
                sscanf(env_val, "%zu", &val);
                update_max_items(val);
            }
168
        }
M
Megvii Engine Team 已提交
169
        void process_one_task(Command& icmd) { m_owner->process_one_task(icmd); }
170
        void on_async_queue_worker_thread_start() override;
M
Megvii Engine Team 已提交
171

172 173 174 175 176 177 178 179 180 181
    private:
        ChannelImpl* m_owner;
    } m_worker;

    //! config whether raise error exactly when invoking op.
    //! level 2: both device and user side errors are async;
    //! level 1: user side errors are sync;
    //! level 0: both sync.
    int m_async_level = 2;

182
    struct State {
183
        std::thread::id tid;
184
        OptionManager options;
185 186
    };

M
Megvii Engine Team 已提交
187
    struct ChannelState : State {
188
        StackManager stack_manager;
189 190
    };

M
Megvii Engine Team 已提交
191
    struct WorkerState : State {};
192

193 194
    ChannelState m_channel_state;
    WorkerState m_worker_state;
195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210

    /*!
     * \brief A framework of dynamic sublienar memory optimization
     *
     * Note: The main idea is that during the training process, if the memory
     * usage exceeds the threshold, select some tensors to evict until the
     * memory usage is below the threshold.
     */
    struct DynamicSublinear {
        /*!
         * \brief find an available tensor with the largest evaluation function
         *
         * Note: An available tensor must satisfy: (1) has computing path,
         * (2) is in memory, (3) is not pinned. Evaluation function refers to:
         * @see: TensorInfo::eval_func.
         *
M
Megvii Engine Team 已提交
211
         * \return the pointer of the best tensor; nullptr is returned if no
212 213
         * available tensor is found
         */
214
        TensorInfo* find_best_tensor(bool);
215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232

        /*!
         * \brief estimate the cost of recomputing tensor ptr
         *
         * Note: We define the cost as the sum of the costs of each evicted
         * components where all the neighbors of ptr are located.
         */
        double estimate_neighbor_cost(TensorInfo* ptr);

        /*!
         * \brief update the last used time of the tensor ptr
         */
        void update_used_time(TensorInfo* ptr);

        /*!
         * \brief merge the two specified sets (the set in which the element x
         * is located, and the set in which the element y is located)
         */
M
Megvii Engine Team 已提交
233
        void merge(std::shared_ptr<DsuNode>& x, std::shared_ptr<DsuNode>& y);
234 235 236 237 238

        /*!
         * \brief return the representative of the set that contains the
         * element x
         */
M
Megvii Engine Team 已提交
239
        std::shared_ptr<DsuNode> find_father(std::shared_ptr<DsuNode>& x);
240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266

        /*!
         * \brief update DSU after recomputing tensor ptr
         *
         * Delete ptr from the set where ptr is located. Since DSU does not
         * support this operation, instead, we reset the DSU father of ptr, and
         * subtract the recomputation cost of ptr from the cost of the original
         * set.
         */
        void update_dsu_after_recompute(TensorInfo* ptr);

        /*!
         * \brief update DSU after evicting tensor ptr
         *
         * Check the neighbors of x, that is, the input and output tensors, and
         * if they are evicted, merge their respective sets.
         */
        void update_dsu_after_evict(TensorInfo* ptr);

        /*!
         * \brief pin the tensors in vec
         */
        void pin(const SmallVector<TensorInfo*>& vec);

        /*!
         * \brief unpin the tensors in vec
         */
267
        void unpin(const SmallVector<TensorInfo*>& vec, WorkerState& state);
268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291

        /*!
         * \brief add the tensor to the candidate set
         *
         * If the size of the tensor does not exceed the minimum threshold,
         * it will do nothing.
         */
        void insert_candidate(TensorInfo* ptr);

        /*!
         * \brief erase the tensor from the candidate set
         *
         * If the size of the tensor does not exceed the minimum threshold,
         * it will do nothing.
         */
        void erase_candidate(TensorInfo* ptr);

        //! estimate the current time, in order to reduce the overhead of timer
        double estimate_timestamp = 0;

        //! the comp node where dynamic sublinear memory optimization works
        CompNode comp_node;

        //! store all tensors that may be evicted
292
        SmallVector<TensorInfo*> candidates;
293

294
        bool is_bad_op(std::string op_name) {
M
Megvii Engine Team 已提交
295 296
            return std::find(op_blacklist.begin(), op_blacklist.end(), op_name) !=
                   op_blacklist.end();
297 298
        }

M
Megvii Engine Team 已提交
299 300 301 302
        std::vector<std::string> op_blacklist = {
                "CollectiveComm", "InplaceAdd", "ParamPackSplit", "ParamPackConcat",
                "GaussianRNG",    "UniformRNG", "GammaRNG",       "PermutationRNG",
                "PoissonRNG",     "BetaRNG"};
303 304 305
    } m_dtr;

    //! automatically evict an optimal tensor
306 307
    bool auto_evict(size_t);

308
    void alloc_tensor_with_evict(Blob*);
309 310 311 312

    // assert thread id when call get_xxx_state to avoid misuse
    ChannelState& get_channel_state();
    WorkerState& get_worker_state();
313 314
};

M
Megvii Engine Team 已提交
315
}  // namespace mgb::imperative::interpreter::intl