interpreter_impl.h 9.1 KB
Newer Older
M
Megvii Engine Team 已提交
1 2 3 4 5 6 7 8 9 10 11
/**
 * \file imperative/src/impl/interpreter_impl.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

12
#include <array>
#include <condition_variable>
#include <deque>
#include <exception>
#include <future>
#include <list>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <variant>

#include "megbrain/utils/mempool.h"
#include "megbrain/imperative/interpreter.h"

namespace mgb::imperative::interpreter::intl {

//! Tensor handle type, re-exported from the public Interpreter interface.
using Handle = Interpreter::Handle;

//! Concrete Interpreter; its only member here is the channel factory.
struct InterpreterImpl : public Interpreter {
    //! Creates a new channel (defined out of line).
    std::unique_ptr<Channel> create_channel() override;
};

29 30 31 32 33 34 35 36 37
//! Eviction state stored in TensorInfo::evict_type.
//! Enumerator values are 0, 1 and 2 respectively; presumably SWAP pairs with
//! swap_out/swap_in and DROP with drop (TODO confirm in the .cpp).
enum EvictType { NONE, SWAP, DROP };

// Declared ahead of its definition so TensorInfo::ComputePath can refer to
// other TensorInfo objects through shared/weak pointers.
struct TensorInfo;
//! Shared-ownership pointer to a TensorInfo.
using TensorInfoPtr = std::shared_ptr<TensorInfo>;

38 39 40 41
struct TensorInfo {
    TensorPtr ptr;
    LogicalTensorDesc desc;
    bool value_fetched = false;
42
    bool invalid = false;
43 44 45 46 47 48 49
    bool allow_delete = false;

    EvictType evict_type = NONE;

    HostTensorND h_value;
    size_t locked = 0;
    size_t recompute_times = 0;
50

51 52 53 54 55 56
    struct ComputePath {
        std::shared_ptr<OpDef> op;
        SmallVector<TensorInfoPtr> inputs;
        SmallVector<std::weak_ptr<TensorInfo>> outputs;
        SmallVector<std::weak_ptr<TensorInfo>> dep_outputs;
    } path;
57 58 59 60 61
};

struct Put {
    TensorInfo* dest;
    HostTensorND value;
62
    bool no_cache = false;
63 64

    std::string to_string() const { return ssprintf("Command: Put %p", dest); }
65 66 67 68 69
};
struct ApplyOp {
    std::shared_ptr<OpDef> op;
    SmallVector<TensorInfo*> inputs;
    SmallVector<TensorInfo*> outputs;
70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88
    SmallVector<TensorInfo*> dels;

    std::string to_string() const {
        std::string builder{"Command: ApplyOp {"};
        builder += "inputs [";
        for (auto* input : inputs) {
            builder += ssprintf("%p, ", input);
        }
        builder += "], outputs [";
        for (auto* output : outputs) {
            builder += ssprintf("%p, ", output);
        }
        builder += "], dels [";
        for (auto* del : dels) {
            builder += ssprintf("%p, ", del);
        }
        builder += "]";
        return builder;
    }
89 90 91
};
struct Del {
    TensorInfo* dest;
92 93

    std::string to_string() const { return ssprintf("Command: Del %p", dest); }
94 95 96
};
struct GetValue {
    TensorInfo* dest;
97

98 99 100 101
    std::string to_string() const {
        return ssprintf("Command: GetValue %p", dest);
    }
};
102 103
struct SwapIn {
    TensorInfo* dest;
104 105 106 107

    std::string to_string() const {
        return ssprintf("Command: SwapIn %p", dest);
    }
108 109 110
};
struct SwapOut {
    TensorInfo* dest;
111 112 113 114

    std::string to_string() const {
        return ssprintf("Command: SwapOut %p", dest);
    }
115 116 117
};
struct Drop {
    TensorInfo* dest;
118 119 120 121 122 123 124 125 126 127 128 129 130 131

    std::string to_string() const {
        return ssprintf("Command: Drop %p", dest);
    }
};
//! Command: move from tensor `src` to tensor `dest` (semantics defined in the
//! .cpp). Unlike the other commands, to_string dereferences both pointers to
//! print their layouts, so both must be non-null when it is called.
struct Move {
    TensorInfo* src;
    TensorInfo* dest;

    std::string to_string() const {
        return ssprintf("Command: Move %s to %s",
                        src->desc.layout.to_string().c_str(),
                        dest->desc.layout.to_string().c_str());
    }
};
133 134
struct Flush {
    TensorInfo* dest = nullptr;
135

136 137 138 139 140 141 142
    std::string to_string() const {
        return ssprintf("Command: Flush %p", dest);
    }
};
//! Command that carries no payload and performs no work.
struct Nop {
    std::string to_string() const {
        return std::string{"Command: Nop"};
    }
};
143 144 145
using Command = std::variant<Put,
                             ApplyOp,
                             Del,
146 147 148
                             GetValue,
                             SwapIn,
                             SwapOut,
149 150 151 152
                             Drop,
                             Move,
                             Flush,
                             Nop>;
153 154

struct ChannelImpl : Interpreter::Channel {
155
    ChannelImpl() : m_worker(this), m_buffer(this) {}
156 157
    ~ChannelImpl() override;

158
    Handle put(const HostTensorND& value, bool no_cache) override;
M
Megvii Engine Team 已提交
159
    Handle put(const DeviceTensorND& value) override;
160 161

    void del(Handle) override;
162 163 164
    void swap_in(Handle) override;
    void swap_out(Handle) override;
    void drop(Handle) override;
165 166 167 168 169 170 171 172 173 174 175 176 177 178

    SmallVector<Handle> apply_op(
            std::shared_ptr<OpDef> op,
            const SmallVector<Handle>& inputs) override;

    HostTensorND get_value(Handle) override;
    TensorShape get_shape(Handle) override;
    DType get_dtype(Handle) override;
    CompNode get_device(Handle) override;

    DeviceTensorND get_dev_tensor(Handle) override;

    void sync() override;
    void close() override;
179 180
    void set_swap_flag(bool) override;
    void set_drop_flag(bool) override;
181
    void set_buffer_length(int) override;
182 183

    void config_async_level(int level) override;
184
    int get_async_level() override;
185 186 187 188

private:
    TensorInfo* alloc();
    void free(TensorInfo*);
189
    void remove_dep(TensorInfo*);
190 191 192 193 194

    void process_one_task(Command&);

    void check_worker_exc_unsafe();

195 196 197 198 199
    void produce_tensor(TensorInfo* dest, TensorPtr ptr, bool notice);
    void do_swap_out(TensorInfo* dest);
    void do_swap_in(TensorInfo* dest);
    void do_drop(TensorInfo* dest);
    void regenerate(TensorInfo* dest, bool must_drop);
200 201 202 203 204 205 206

    std::mutex m_mutex;
    std::condition_variable m_cv;
    MemPool<TensorInfo> m_pool;
    std::unordered_set<Handle> m_valid_handle;
    TensorInfo* m_waitee = nullptr;
    std::exception_ptr m_worker_exc;
207
    size_t m_enable_evict = 0;
208 209

    struct WorkQueue : AsyncQueueSC<Command, WorkQueue> {
210 211 212 213 214
        // set max_spin=0 to prevent Queue fetch task in busy wait manner.
        // this won't affect throughput when python interpreter is sending enough task,
        // but will significantly save CPU time when waiting for task, e.g. wait for data input
        WorkQueue(ChannelImpl* owner)
                : AsyncQueueSC<Command, WorkQueue>(0), m_owner(owner) {
215 216
            sys::set_thread_name("interpreter");
        }
217 218 219
        void process_one_task(Command& cmd) {
            m_owner->process_one_task(cmd);
        }
220 221 222
        void on_async_queue_worker_thread_start() override {
               sys::set_thread_name("worker");
        }
223 224 225 226
    private:
        ChannelImpl* m_owner;
    } m_worker;

227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243
    struct SharedTensorInfoMap {
        void insert(TensorInfo* info) {
            MGB_LOCK_GUARD(mtx);
            tmap.emplace(info, TensorInfoPtr{info, [](TensorInfo* ptr){ ptr->allow_delete = true;}});
        }
        void erase(TensorInfo* info) {
            MGB_LOCK_GUARD(mtx);
            tmap.erase(info);
        }
        TensorInfoPtr at(TensorInfo* info) {
            MGB_LOCK_GUARD(mtx);
            return tmap.at(info);
        }
    private:
        std::mutex mtx;
        std::unordered_map<TensorInfo*, TensorInfoPtr> tmap;
    }m_st;
244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293

    /**
     * Buf a command window for following fuse
     * example:
     *     ---------------------------------------------------------------------
     *     | ..., Apply{in: (i0, i1), out: (o0, o1)}, ... + Del{i0} + Del{i1}  |
     *     ---------------------------------------------------------------------
     *     | ..., Apply{in: (i0, i1), out: (o0, o1), del: (i0)}, ... + Del{i1} |
     *     ---------------------------------------------------------------------
     *     | ..., Apply{in: (i0, i1), out: (o0, o1), del: (i0, i1)}, ...       |
     *     ---------------------------------------------------------------------
     *     Then the fused Apply may be invoked inplace. see: ChannelImpl::process_one_task
     */
    struct CommandBuffer {
        CommandBuffer(ChannelImpl* owner) : m_owner(owner) {
            int capacity = 3;
            if(const char* capacity_str = MGB_GETENV("MEGENGINE_COMMAND_BUFFER_LENGTH")) {
                capacity = atoi(capacity_str);
            }
            set_capacity(capacity);
        }
        void enqueue(Command cmd);
        bool empty() const {
            return m_commands.empty();
        }
        void set_capacity(int capacity) {
            mgb_assert(capacity >= 0 && capacity < 100, "invalid command buffer length");
            m_capacity = capacity;
        }
    private:
        ChannelImpl* m_owner;
        size_t m_capacity;
        std::deque<Command> m_commands;

        using Handle = decltype(m_commands)::iterator;
        // [begin, end)
        using Range = std::array<Handle, 2>;

        // Launch commands in range [m_commands.begin(), pos)
        void flush(Handle pos);
        // Select flush position for incoming cmd
        Handle flush_pos_for(const Command& cmd);
        // Fuse del command into suitable ApplyOp
        bool fuse_del(const Del& cmd);
        // Returns the last handle that dest is used within range. If dest is not used, returns range[1]
        Handle find_last_usage(TensorInfo* dest, Range range);
        // Returns the produce position of dest. If not found, returns range[1]
        Handle find_produce(TensorInfo* dest, Range range);
    } m_buffer;

294 295 296 297
    //! config whether raise error exactly when invoking op.
    //! level 2: both device and user side errors are async;
    //! level 1: user side errors are sync;
    //! level 0: both sync.
298
    int m_async_level = 2;
299
    int m_max_recompute_time = 1;
300 301 302
};

} // namespace mgb::imperative::interpreter::intl