/**
 * \file imperative/src/impl/interpreter_impl.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#pragma once

#include <condition_variable>
#include <exception>
#include <future>
#include <memory>
#include <mutex>
#include <unordered_map>
#include <unordered_set>
#include <variant>

#include "megbrain/utils/mempool.h"
#include "megbrain/imperative/interpreter.h"


namespace mgb::imperative::interpreter::intl {

//! Re-export Interpreter::Handle so code in this namespace can name it unqualified.
using Handle = Interpreter::Handle;

//! Concrete Interpreter implementation; the only thing it overrides is
//! create_channel() (defined in the .cpp).
struct InterpreterImpl : Interpreter {
    std::unique_ptr<Channel> create_channel() override;
};

//! Eviction state recorded on a TensorInfo (its `evict_type` field).
//! Kept as an unscoped enum: the enumerators are used unqualified.
enum EvictType {
    NONE = 0,  //!< not evicted
    SWAP,      //!< evicted via swap-out (value presumably kept on host -- confirm)
    DROP,      //!< evicted via drop (presumably regenerated on demand -- confirm)
};

struct TensorInfo;
//! Shared reference to a TensorInfo; ComputePath uses it to hold its inputs.
using TensorInfoPtr = std::shared_ptr<TensorInfo>;

//! Bookkeeping for one tensor handle managed by ChannelImpl
//! (ChannelImpl allocates these from its MemPool<TensorInfo>).
struct TensorInfo {
    //! Device-side value; presumably empty until produced or after eviction -- confirm.
    TensorPtr ptr;
    //! Logical description of the tensor (presumably known before `ptr` is ready).
    LogicalTensorDesc desc;
    //! Whether the value has been fetched to host -- confirm exact meaning in the .cpp.
    bool value_fetched = false;
    //! Marks this info as no longer usable -- confirm exact meaning in the .cpp.
    bool invalid = false;
    //! Set to true by the custom deleter that ChannelImpl::SharedTensorInfoMap
    //! attaches, i.e. once the last TensorInfoPtr sharing that control block is
    //! released. The deleter does not free memory; storage comes from the pool.
    bool allow_delete = false;

    //! Current eviction state of `ptr`.
    EvictType evict_type = NONE;

    //! Host-side copy of the value (presumably filled by swap-out/get_value -- confirm).
    HostTensorND h_value;
    //! Lock counter; presumably a non-zero value pins the tensor against eviction.
    size_t locked = 0;
    //! How many times this tensor has been recomputed
    //! (presumably bounded by ChannelImpl::m_max_recompute_time).
    size_t recompute_times = 0;

    //! Record of the op that produced this tensor, presumably kept so a dropped
    //! tensor can be regenerated. Inputs are owned (shared_ptr) while outputs are
    //! only observed (weak_ptr), which keeps producer/consumer links from forming
    //! ownership cycles.
    struct ComputePath {
        std::shared_ptr<OpDef> op;
        SmallVector<TensorInfoPtr> inputs;
        SmallVector<std::weak_ptr<TensorInfo>> outputs;
        SmallVector<std::weak_ptr<TensorInfo>> dep_outputs;
    } path;
};

//! Command: store the host tensor `value` into `dest`.
struct Put {
    TensorInfo* dest = nullptr;
    HostTensorND value;
};
//! Command: run `op` on `inputs`, filling `outputs`.
struct ApplyOp {
    std::shared_ptr<OpDef> op;
    SmallVector<TensorInfo*> inputs;
    SmallVector<TensorInfo*> outputs;
};
//! Command: release `dest`.
struct Del {
    TensorInfo* dest = nullptr;
};
//! Command: request the value of `dest`.
struct GetValue {
    TensorInfo* dest = nullptr;
};

//! Eviction commands, one per ChannelImpl::swap_in / swap_out / drop request.
//! The raw `dest` pointers default to nullptr so a default-constructed command
//! never carries an indeterminate pointer; brace-init `SwapIn{info}` still works.
struct SwapIn {
    TensorInfo* dest = nullptr;
};
struct SwapOut {
    TensorInfo* dest = nullptr;
};
struct Drop {
    TensorInfo* dest = nullptr;
};

using Command = std::variant<Put,
                             ApplyOp,
                             Del,
86 87 88 89
                             GetValue,
                             SwapIn,
                             SwapOut,
                             Drop>;
90 91 92 93 94 95

struct ChannelImpl : Interpreter::Channel {
    ChannelImpl() : m_worker(this) {}
    ~ChannelImpl() override;

    Handle put(const HostTensorND& value) override;
M
Megvii Engine Team 已提交
96
    Handle put(const DeviceTensorND& value) override;
97 98

    void del(Handle) override;
99 100 101
    void swap_in(Handle) override;
    void swap_out(Handle) override;
    void drop(Handle) override;
102 103 104 105 106 107 108 109 110 111 112 113 114 115

    SmallVector<Handle> apply_op(
            std::shared_ptr<OpDef> op,
            const SmallVector<Handle>& inputs) override;

    HostTensorND get_value(Handle) override;
    TensorShape get_shape(Handle) override;
    DType get_dtype(Handle) override;
    CompNode get_device(Handle) override;

    DeviceTensorND get_dev_tensor(Handle) override;

    void sync() override;
    void close() override;
116 117
    void set_swap_flag(bool) override;
    void set_drop_flag(bool) override;
118 119

    void config_async_level(int level) override;
120
    int get_async_level() override;
121 122 123 124

private:
    TensorInfo* alloc();
    void free(TensorInfo*);
125
    void remove_dep(TensorInfo*);
126 127 128 129 130

    void process_one_task(Command&);

    void check_worker_exc_unsafe();

131 132 133 134 135
    void produce_tensor(TensorInfo* dest, TensorPtr ptr, bool notice);
    void do_swap_out(TensorInfo* dest);
    void do_swap_in(TensorInfo* dest);
    void do_drop(TensorInfo* dest);
    void regenerate(TensorInfo* dest, bool must_drop);
136 137 138 139 140 141 142

    std::mutex m_mutex;
    std::condition_variable m_cv;
    MemPool<TensorInfo> m_pool;
    std::unordered_set<Handle> m_valid_handle;
    TensorInfo* m_waitee = nullptr;
    std::exception_ptr m_worker_exc;
143
    size_t m_enable_evict = 0;
144 145 146 147 148 149 150 151 152 153

    struct WorkQueue : AsyncQueueSC<Command, WorkQueue> {
        WorkQueue(ChannelImpl* owner) : m_owner(owner) {}
        void process_one_task(Command& cmd) {
            m_owner->process_one_task(cmd);
        }
    private:
        ChannelImpl* m_owner;
    } m_worker;

154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171
    struct SharedTensorInfoMap {
        void insert(TensorInfo* info) {
            MGB_LOCK_GUARD(mtx);
            tmap.emplace(info, TensorInfoPtr{info, [](TensorInfo* ptr){ ptr->allow_delete = true;}});
        }
        void erase(TensorInfo* info) {
            MGB_LOCK_GUARD(mtx);
            tmap.erase(info);
        }
        TensorInfoPtr at(TensorInfo* info) {
            MGB_LOCK_GUARD(mtx);
            return tmap.at(info);
        }
    private:
        std::mutex mtx;
        std::unordered_map<TensorInfo*, TensorInfoPtr> tmap;
    }m_st;
    
172 173 174 175
    //! config whether raise error exactly when invoking op.
    //! level 2: both device and user side errors are async;
    //! level 1: user side errors are sync;
    //! level 0: both sync.
176
    int m_async_level = 2;
177
    int m_max_recompute_time = 1;
178 179 180
};

} // namespace mgb::imperative::interpreter::intl