/**
 * \file src/core/impl/graph/cg_impl.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#pragma once

#include "./eager_eval.h"
#include "./grad_manager.h"
#include "./graph_opt.h"
#include "./seq_comp_node_opt_impl.h"
#include "./seq_sublinear_memory.h"
#include "./static_infer_impl.h"
#include "./swap/memory_swap.h"
#include "./topo_sort.h"
#include "./var_node_mem_mgr.h"

#include "megbrain/utils/mempool.h"

namespace mgb {
namespace cg {

class ComputingGraphImpl final : public ComputingGraph {
    class CallbackCaller;
    class RecordedComputingSequence;
    class MegDNNDtorCheck;
    class MultiPartCompiler;
    friend class GradManager;

    //! temporary state in compiling
    struct CompileState {
        //! extra info that must be set in the ComputingSequence
        CompSeqExtraInfo extra_info;
        const OprNodeArray* opr_seq = nullptr;
    };

    /*!
     * Components for implementing algorithms on a computing graph.
     *
     * They are put in a separate struct because they need to be destructed in
     * on_comp_node_finalize(), before ~ComputingGraphImpl().
     */
    struct Components : public NonCopyableObj {
        TopoSorter topo_sorter;
        VarNodeMemManager var_node_mem_manager;
        SeqCompNodeOptimizerImpl seq_comp_node_opt;
        static_infer::StaticInferManagerImpl static_infer_manager;
        static_infer::CompSeqManager static_infer_comp_seq_manager;
        GradManager grad_manager;
        GraphOptimizer graph_optimizer;
#if MGB_ENABLE_SUBLINEAR
        SeqModifierForSublinearMemory seq_modifier_for_sublinear_memory;
#endif
#if MGB_ENABLE_MEMORY_SWAP
        swap::MemorySwap memory_swap_support;
#endif
        EagerEvalManager eager_eval_manager;

        explicit Components(ComputingGraphImpl* owner);
    };

    //! valid if graph has been compiled with comp_node_seq_record_level == 2
    //! must be placed first so it can be destructed last
    std::unique_ptr<MegDNNDtorCheck> m_recorded_seq_level2_dtor_chk;

    MemPool<VarNode> m_var_node_pool;

    //! if not null, this graph is set as subgraph of it by set_as_subgraph()
    ComputingGraphImpl* m_parent_graph = nullptr;
    std::vector<ComputingGraphImpl*> m_subgraphs;

    AsyncExecutable* m_current_comp_seq = nullptr;

    std::shared_ptr<size_t> m_node_id_counter = std::make_shared<size_t>();

    std::vector<std::unique_ptr<OperatorNodeBase>> m_opr_refkeeper;

    /*!
     * list of operator nodes that take some var as one of the inputs; each
     * output var would be in this map, even with an empty operator set
     */
    ThinHashMap<VarNode*, OprNodeArray> m_var_receiver;

    std::aligned_storage_t<sizeof(Components), alignof(Components)>
            m_components_storage;

    /*!
     * \brief get dest vars and add extra_vardeps from OutputSpec
     * \param[out] has_virtual_grad whether there are VirtualGrad oprs that
     *      need to be expanded
     */
    VarNodeArray get_dest_vars_from_out_spec(const OutputSpec& spec,
                                             SpecialOprStat& sopr_stat);

    void cleanup();

    std::shared_ptr<void> on_comp_node_finalize() override;

    Components& components() {
        return reinterpret_cast<Components&>(m_components_storage);
    }

    const Components& components() const {
        return reinterpret_cast<const Components&>(m_components_storage);
    }

    //! prepare computing sequence and initialize opr sequence
    CompileState compile_prepare(const OutputSpec& out_spec);

    //! finalize the computing sequence for compiling
    std::unique_ptr<AsyncExecutable> compile_commit(CompileState state);

public:
    class ComputingSequence;

    ComputingGraphImpl();
    ~ComputingGraphImpl();

125 126 127 128 129 130 131
    template<typename T> static ComputingGraphImpl* downcast(T* ptr) = delete;

    inline static ComputingGraphImpl* downcast(ComputingGraph* graph) {
        mgb_assert(!graph->options().imperative_proxy_graph);
        return static_cast<ComputingGraphImpl*>(graph);
    }

132 133 134 135 136 137 138 139 140 141 142
    friend struct ComputingGraph::Options;

    std::unique_ptr<AsyncExecutable> compile(
            const OutputSpec& out_spec) override;

    SmallVector<std::unique_ptr<AsyncExecutable>> compile_multi_part(
            const SmallVector<OutputSpec>& out_specs) override;

    OperatorNodeBase* insert_opr(
            std::unique_ptr<OperatorNodeBase> opr) override;

143 144 145 146
    void* alloc_varnode_storage() override;

    void free_varnode_storage(void *ptr) override;

147 148 149 150 151 152 153 154 155 156 157 158
    const VarReceiverInfo& var_receiver_in_current_comp_seq(
            const VarNode* var) const override;

    /*!
     * \brief get the nodes in opr_set that directly depend on var (i.e.
     *      those oprs that take var as one input)
     *
     * Guaranteed no duplication in this list, and the order is stable
     */
    const OprNodeArray& var_receiver(VarNode* var) const {
        return m_var_receiver.at(var);
    }
159

160
    std::string get_mem_allocation_info() const override;
161 162 163 164 165

    VarNode* find_var_by_id(size_t id) const override;

    TopoSorter& topo_sorter() { return components().topo_sorter; }

166
    size_t next_node_id() override { return (*m_node_id_counter)++; }
167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259

    VarNodeMemManager& var_node_mem_manager() {
        return components().var_node_mem_manager;
    }

    SeqCompNodeOptimizer& seq_comp_node_optimizer() override {
        return components().seq_comp_node_opt;
    }

    static_infer::StaticInferManager& static_infer_manager() override {
        return components().static_infer_manager;
    }

    static_infer::StaticInferManagerImpl& static_infer_manager_impl() {
        return components().static_infer_manager;
    }

    static_infer::CompSeqManager& static_infer_comp_seq_manager() {
        return components().static_infer_comp_seq_manager;
    }

    GraphOptimizer& graph_optimizer() { return components().graph_optimizer; }

    EagerEvalManager& eager_eval_manager() {
        return components().eager_eval_manager;
    }

#if MGB_ENABLE_SUBLINEAR
    SeqModifierForSublinearMemory& seq_modifier_for_sublinear_memory();
#endif

    void share_device_memory_with(ComputingGraph& other) override;

    void set_device_memory_allocator(
            std::shared_ptr<DeviceMemoryAllocator> allocator) override;

    size_t get_device_memory_size(CompNode cn) override;

    size_t clear_device_memory() override;

    void set_as_subgraph(ComputingGraph& par_graph) override;

    void record_async_error(std::unique_ptr<MegBrainError> async_exc) override;

    /*!
     * \brief latest computing sequence from this graph
     *
     * Since new computing sequence would invalidate memory layout of older
     * ones and operators may have class-level states, only the last
     * computing sequence could be used
     *
     * \return current comp seq, or nullptr if no alive comp seq
     */
    AsyncExecutable* current_comp_seq() override {
        return static_cast<AsyncExecutable*>(m_current_comp_seq);
    }

    /*!
     * \brief get current ExecEnv
     * \return ExecEnv if there is a compiled computing sequence, or
     *      nullptr if not compiled yet
     */
    GraphExecutable::ExecEnv* current_exec_env();

    /*!
     * \brief get step number of an operator in current computing sequence
     * \return step number; None if opr not in seq
     */
    Maybe<size_t> opr_step_num_in_cur_comp_seq(OperatorNodeBase* opr);

    /*!
     * \brief get extra info for current computing sequence
     */
    const CompSeqExtraInfo& current_comp_seq_extra_info();

    /*!
     * \brief get associated grad manager
     */
    GradManager& grad_manager() { return components().grad_manager; }

    //! get all the operators in this graph
    auto&& all_oprs() const { return m_opr_refkeeper; }

    size_t nr_oprs_in_graph() const override { return m_opr_refkeeper.size(); }

    //! memory pool for the var nodes; used by OperatorNodeBase
    auto&& var_node_pool() { return m_var_node_pool; }
};

}  // namespace cg
}  // namespace mgb

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}