/**
 * \file src/core/impl/graph/cg_impl.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#pragma once

#include "./eager_eval.h"
#include "./grad_manager.h"
#include "./graph_opt.h"
#include "./seq_comp_node_opt_impl.h"
#include "./seq_sublinear_memory.h"
#include "./static_infer_impl.h"
#include "./swap/memory_swap.h"
#include "./topo_sort.h"
#include "./var_node_mem_mgr.h"

#include "megbrain/utils/mempool.h"

namespace mgb {
namespace cg {

class ComputingGraphImpl final : public ComputingGraph {
    class CallbackCaller;
    class RecordedComputingSequence;
    class MegDNNDtorCheck;
    class MultiPartCompiler;
    friend class GradManager;

    //! temporary state in compiling
    struct CompileState {
        //! extra info that must be set in the ComputingSequence
        CompSeqExtraInfo extra_info;
        const OprNodeArray* opr_seq = nullptr;
    };

    /*!
     * Components for implementing algorithms on a computing graph.
     *
     * They are put in a separate struct because they need to be destructed
     * in on_comp_node_finalize(), before ~ComputingGraphImpl().
     */
    struct Components : public NonCopyableObj {
        TopoSorter topo_sorter;
        VarNodeMemManager var_node_mem_manager;
        SeqCompNodeOptimizerImpl seq_comp_node_opt;
        static_infer::StaticInferManagerImpl static_infer_manager;
        static_infer::CompSeqManager static_infer_comp_seq_manager;
        GradManager grad_manager;
        GraphOptimizer graph_optimizer;
#if MGB_ENABLE_SUBLINEAR
        SeqModifierForSublinearMemory seq_modifier_for_sublinear_memory;
#endif
#if MGB_ENABLE_MEMORY_SWAP
        swap::MemorySwap memory_swap_support;
#endif
        EagerEvalManager eager_eval_manager;

        explicit Components(ComputingGraphImpl* owner);
    };

    //! valid if graph has been compiled with comp_node_seq_record_level == 2;
    //! must be placed first so it can be destructed last
    std::unique_ptr<MegDNNDtorCheck> m_recorded_seq_level2_dtor_chk;

    MemPool<VarNode> m_var_node_pool;

    //! if not null, this graph is set as subgraph of it by set_as_subgraph()
    ComputingGraphImpl* m_parent_graph = nullptr;
    std::vector<ComputingGraphImpl*> m_subgraphs;

    AsyncExecutable* m_current_comp_seq = nullptr;

    std::shared_ptr<size_t> m_node_id_counter = std::make_shared<size_t>();

    std::vector<std::unique_ptr<OperatorNodeBase>> m_opr_refkeeper;

    /*!
     * list of operator nodes that take some var as one of the inputs; each
     * output var would be in this map, even with an empty operator set
     */
    ThinHashMap<VarNode*, OprNodeArray> m_var_receiver;

    std::aligned_storage_t<sizeof(Components), alignof(Components)>
            m_components_storage;

    /*!
     * \brief get dest vars and add extra_vardeps from OutputSpec
     * \param[out] sopr_stat stat of special oprs, e.g. whether there are
     *      VirtualGrad oprs that need to be expanded
     */
    VarNodeArray get_dest_vars_from_out_spec(
            const OutputSpec& spec, SpecialOprStat& sopr_stat);

    void cleanup();

    std::shared_ptr<void> on_comp_node_finalize() override;

    Components& components() {
        return reinterpret_cast<Components&>(m_components_storage);
    }

    const Components& components() const {
        return reinterpret_cast<const Components&>(m_components_storage);
    }

    //! prepare computing sequence and initialize opr sequence
    CompileState compile_prepare(const OutputSpec& out_spec);
    //! finalize the computing sequence for compiling
    std::unique_ptr<AsyncExecutable> compile_commit(CompileState state);

public:
    class ComputingSequence;

    ComputingGraphImpl();
    ~ComputingGraphImpl();

    template <typename T>
    static ComputingGraphImpl* downcast(T* ptr) = delete;

    inline static ComputingGraphImpl* downcast(ComputingGraph* graph) {
        mgb_assert(!graph->options().imperative_proxy_graph);
        return static_cast<ComputingGraphImpl*>(graph);
    }

    friend struct ComputingGraph::Options;

    std::unique_ptr<AsyncExecutable> compile(const OutputSpec& out_spec) override;

    SmallVector<std::unique_ptr<AsyncExecutable>> compile_multi_part(
            const SmallVector<OutputSpec>& out_specs) override;

    OperatorNodeBase* insert_opr(std::unique_ptr<OperatorNodeBase> opr) override;

    void* alloc_varnode_storage() override;

    void free_varnode_storage(void* ptr) override;

    const VarReceiverInfo& var_receiver_in_current_comp_seq(
            const VarNode* var) const override;

    /*!
     * \brief get the nodes in opr_set that directly depend on var (i.e.
     *      those oprs that take var as one input)
     *
     * Guaranteed no duplication in this list, and the order is stable.
     */
    const OprNodeArray& var_receiver(VarNode* var) const {
        return m_var_receiver.at(var);
    }

    std::string get_mem_allocation_info() const override;

    VarNode* find_var_by_id(size_t id) const override;

    TopoSorter& topo_sorter() { return components().topo_sorter; }

    size_t next_node_id() override { return (*m_node_id_counter)++; }

    VarNodeMemManager& var_node_mem_manager() {
        return components().var_node_mem_manager;
    }

    SeqCompNodeOptimizer& seq_comp_node_optimizer() override {
        return components().seq_comp_node_opt;
    }

    static_infer::StaticInferManager& static_infer_manager() override {
        return components().static_infer_manager;
    }

    static_infer::StaticInferManagerImpl& static_infer_manager_impl() {
        return components().static_infer_manager;
    }

    static_infer::CompSeqManager& static_infer_comp_seq_manager() {
        return components().static_infer_comp_seq_manager;
    }

    GraphOptimizer& graph_optimizer() { return components().graph_optimizer; }

    EagerEvalManager& eager_eval_manager() {
        return components().eager_eval_manager;
    }

#if MGB_ENABLE_SUBLINEAR
    SeqModifierForSublinearMemory& seq_modifier_for_sublinear_memory();
#endif

    void share_device_memory_with(ComputingGraph& other) override;

    void set_device_memory_allocator(
            std::shared_ptr<DeviceMemoryAllocator> allocator) override;

    size_t get_device_memory_size(CompNode cn) override;

    size_t clear_device_memory() override;

    void set_as_subgraph(ComputingGraph& par_graph) override;

    void record_async_error(std::unique_ptr<MegBrainError> async_exc) override;

    /*!
     * \brief latest computing sequence from this graph
     *
     * Since a new computing sequence would invalidate memory layouts of
     * older ones and operators may have class-level states, only the last
     * computing sequence can be used.
     *
     * \return current comp seq, or nullptr if no alive comp seq
     */
    AsyncExecutable* current_comp_seq() override {
        return static_cast<AsyncExecutable*>(m_current_comp_seq);
    }

    /*!
     * \brief get current ExecEnv
     * \return ExecEnv if there is a compiled computing sequence, or nullptr
     *      if not compiled yet
     */
    GraphExecutable::ExecEnv* current_exec_env();

    /*!
     * \brief get step number of an operator in current computing sequence
     * \return step number; None if opr not in seq
     */
    Maybe<size_t> opr_step_num_in_cur_comp_seq(OperatorNodeBase* opr);

    /*!
     * \brief get extra info for current computing sequence
     */
    const CompSeqExtraInfo& current_comp_seq_extra_info();

    /*!
     * \brief get associated grad manager
     */
    GradManager& grad_manager() { return components().grad_manager; }
    //! get all the operators in this graph
    auto&& all_oprs() const { return m_opr_refkeeper; }

    size_t nr_oprs_in_graph() const override { return m_opr_refkeeper.size(); }

    //! memory pool for the var nodes; used by OperatorNodeBase
    auto&& var_node_pool() { return m_var_node_pool; }
};

}  // namespace cg
}  // namespace mgb

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
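/*
 * Usage note (editorial sketch, not part of the original header): clients
 * normally go through the public ComputingGraph interface rather than this
 * implementation class. A minimal, hedged example, assuming an output var
 * `y` has already been built on the graph and using a hypothetical host
 * tensor `host_y` to receive the result:
 *
 *     auto graph = ComputingGraph::make();
 *     // ... insert oprs producing SymbolVar y on *graph ...
 *     HostTensorND host_y;
 *     auto func = graph->compile(
 *             {{y, [&](DeviceTensorND& dev) { host_y.copy_from(dev); }}});
 *     func->execute().wait();
 *
 * Only the most recently compiled sequence of a graph may be executed; see
 * the doc of current_comp_seq() above.
 */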