From d42cf4cd65f3c5c7121be1279d2f68cd52fa302a Mon Sep 17 00:00:00 2001
From: Megvii Engine Team <megengine@megvii.com>
Date: Sat, 11 Jul 2020 18:28:11 +0800
Subject: [PATCH] refactor(mgb): replace static_cast with a checked version

GitOrigin-RevId: d05b114668f84f9b4045abe55af11bce0ff3bd4a
---
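Notes: every hunk below applies the same transformation. A bare
`static_cast<ComputingGraphImpl*>(...)` (or the reference form) on a
ComputingGraph is routed through the new ComputingGraphImpl::downcast()
helper added in cg_impl.h, which can additionally assert that the graph is
not an imperative proxy graph before casting. The snippet below is a
minimal standalone sketch of the idiom, not MegEngine code: Graph,
GraphImpl and is_proxy() are made-up stand-ins for ComputingGraph,
ComputingGraphImpl and options().imperative_proxy_graph.

    // checked_downcast_sketch.cpp (illustrative only)
    #include <cassert>

    struct Graph {
        virtual ~Graph() = default;
        // stand-in for options().imperative_proxy_graph
        virtual bool is_proxy() const { return false; }
    };

    struct GraphImpl : Graph {
        // Deleted catch-all: any argument that is not exactly Graph*
        // (e.g. an already-downcast GraphImpl*) resolves to this template
        // overload and fails at compile time.
        template <typename T>
        static GraphImpl* downcast(T* ptr) = delete;

        static GraphImpl* downcast(Graph* graph) {
            // runtime guard, analogous to the mgb_assert in cg_impl.h
            assert(!graph->is_proxy());
            return static_cast<GraphImpl*>(graph);
        }
    };

    int main() {
        GraphImpl impl;
        Graph* base = &impl;
        GraphImpl* p = GraphImpl::downcast(base);   // checked downcast
        // GraphImpl::downcast(&impl);   // rejected: deleted overload
        return p == &impl ? 0 : 1;
    }

Compared with a bare static_cast, the deleted template overload rejects
redundant or mistyped downcasts at compile time, and the assert catches
proxy graphs at runtime in builds that define MGB_ENABLE_IMPERATIVE_RUNTIME.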
 CMakeLists.txt                             |  7 +++---
 src/core/impl/graph/cg_impl.cpp            |  4 ++--
 src/core/impl/graph/cg_impl.h              |  9 ++++++++
 src/core/impl/graph/cg_impl_seq.h          |  2 +-
 src/core/impl/graph/eager_eval.cpp         | 10 ++++-----
 src/core/impl/graph/grad_manager.cpp       | 10 ++++-----
 src/core/impl/graph/helper.cpp             |  6 ++---
 src/core/impl/graph/operator_node.cpp      | 12 +++++-----
 src/core/impl/graph/static_infer_impl.cpp  |  6 ++---
 src/core/impl/graph/symbol_var.cpp         |  2 +-
 src/core/impl/graph/topo_sort.cpp          |  2 +-
 src/core/impl/graph/var_node.cpp           | 22 +++++++++----------
 src/core/impl/graph/var_node_mem_mgr.cpp   |  2 +-
 .../impl/graph/var_node_mem_mgr/defrag.cpp |  2 +-
 14 files changed, 53 insertions(+), 43 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 81b85956a..4038954ed 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -129,7 +129,7 @@ else()
     if(ANDROID)
         set(CMAKE_CXX_FLAGS_RELEASE "-Ofast -DNDEBUG")
         set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "-Ofast -DNDEBUG -g")
-        
+
     else()
         set(CMAKE_CXX_FLAGS_RELEASE "-O3 -DNDEBUG")
         set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "-O3 -DNDEBUG -g")
@@ -224,6 +224,7 @@ endif()
 option(MGE_WITH_DISTRIBUTED "Build with distributed support" ON)
 option(MGE_BUILD_XXX "Build _xxx.so instead of mgb.so " OFF)
 if(MGE_BUILD_XXX)
+    add_compile_definitions(MGB_ENABLE_IMPERATIVE_RUNTIME)
     set(CMAKE_CXX_STANDARD 17)
 endif()
 
@@ -662,14 +663,14 @@ endif()
 
 configure_file(cmake/megengine.pc.in ${CMAKE_CURRENT_BINARY_DIR}/megengine.pc @ONLY)
 
-install(FILES ${CMAKE_CURRENT_BINARY_DIR}/megengine.pc 
+install(FILES ${CMAKE_CURRENT_BINARY_DIR}/megengine.pc
         DESTINATION ${CMAKE_INSTALL_LIBDIR}/pkgconfig)
 
 # Do not export targets if MGE_WITH_DISTRIBUTED is on. MegRay is not ready.
 if (NOT MGE_WITH_DISTRIBUTED)
   include(CMakePackageConfigHelpers)
   set (MGE_INSTALL_CMAKEDIR ${CMAKE_INSTALL_LIBDIR}/cmake/MegEngine)
-  configure_package_config_file(cmake/MegEngineConfig.cmake.in 
+  configure_package_config_file(cmake/MegEngineConfig.cmake.in
     ${CMAKE_CURRENT_BINARY_DIR}/MegEngineConfig.cmake
     INSTALL_DESTINATION ${MGE_INSTALL_CMAKEDIR}
   )

diff --git a/src/core/impl/graph/cg_impl.cpp b/src/core/impl/graph/cg_impl.cpp
index 9839b6099..308dde9c7 100644
--- a/src/core/impl/graph/cg_impl.cpp
+++ b/src/core/impl/graph/cg_impl.cpp
@@ -674,7 +674,7 @@ void ComputingGraphImpl::share_device_memory_with(ComputingGraph& other) {
     mgb_assert(
             !m_current_comp_seq,
             "share_device_memory_with must be called before compiling graph");
-    auto&& oimpl = static_cast<ComputingGraphImpl&>(other);
+    auto&& oimpl = *ComputingGraphImpl::downcast(&other);
     var_node_mem_manager().static_device_memory_manager(
             oimpl.var_node_mem_manager().static_device_memory_manager());
 }
@@ -707,7 +707,7 @@ size_t ComputingGraphImpl::clear_device_memory() {
 }
 
 void ComputingGraphImpl::set_as_subgraph(ComputingGraph& par_graph) {
-    m_parent_graph = static_cast<ComputingGraphImpl*>(&par_graph);
+    m_parent_graph = ComputingGraphImpl::downcast(&par_graph);
     m_parent_graph->m_subgraphs.emplace_back(this);
     m_node_id_counter = m_parent_graph->m_node_id_counter;
     options().var_sanity_check_first_run =

diff --git a/src/core/impl/graph/cg_impl.h b/src/core/impl/graph/cg_impl.h
index ec314fbce..c859cd6b5 100644
--- a/src/core/impl/graph/cg_impl.h
+++ b/src/core/impl/graph/cg_impl.h
@@ -122,6 +122,15 @@ public:
     ComputingGraphImpl();
     ~ComputingGraphImpl();
 
+    template <typename T>
+    static ComputingGraphImpl* downcast(T* ptr) = delete;
+
+    inline static ComputingGraphImpl* downcast(ComputingGraph* graph) {
+#ifdef MGB_ENABLE_IMPERATIVE_RUNTIME
+        mgb_assert(!graph->options().imperative_proxy_graph);
+#endif
+        return static_cast<ComputingGraphImpl*>(graph);
+    }
+
     friend struct ComputingGraph::Options;
 
     std::unique_ptr<AsyncExecutable> compile(

diff --git a/src/core/impl/graph/cg_impl_seq.h b/src/core/impl/graph/cg_impl_seq.h
index 1b0da032f..0eee81ab2 100644
--- a/src/core/impl/graph/cg_impl_seq.h
+++ b/src/core/impl/graph/cg_impl_seq.h
@@ -100,7 +100,7 @@ class ComputingGraphImpl::ComputingSequence final : public AsyncExecutable {
 public:
     ComputingSequence(const std::shared_ptr<ComputingGraph>& graph)
             : m_owner_graph_refkeep{graph},
-              m_owner_graph{static_cast<ComputingGraphImpl*>(graph.get())},
+              m_owner_graph{ComputingGraphImpl::downcast(graph.get())},
               m_have_parent_graph{m_owner_graph->m_parent_graph} {}
 
     GraphExecutable::ExecEnv& exec_env() { return m_exec_env; }

diff --git a/src/core/impl/graph/eager_eval.cpp b/src/core/impl/graph/eager_eval.cpp
index c0d70f383..6c13057c9 100644
--- a/src/core/impl/graph/eager_eval.cpp
+++ b/src/core/impl/graph/eager_eval.cpp
@@ -76,7 +76,7 @@ EagerEvalManager::~EagerEvalManager() noexcept {
     if (m_first_opr_enable_status == 1) {
         m_var_sync_mgr_pool.disable_freelist();
         for (auto&& i :
-             static_cast<ComputingGraphImpl*>(m_owner_graph)->all_oprs()) {
+             ComputingGraphImpl::downcast(m_owner_graph)->all_oprs()) {
             for (auto var : i->output()) {
                 auto mgr = VarNodeMemManager::var_node_cn_sync_manager(var);
                 if (mgr) {
@@ -223,7 +223,7 @@ void EagerEvalManager::prepare_for_exec(OperatorNodeBase* opr) {
 }
 
 void EagerEvalManager::update_static_infer_result(OperatorNodeBase* opr) {
-    auto&& mgr = static_cast<ComputingGraphImpl*>(m_owner_graph)
+    auto&& mgr = ComputingGraphImpl::downcast(m_owner_graph)
                          ->static_infer_manager_impl();
     auto sync_missing_trait =
             [&](static_infer::StaticInferManagerImpl::TagHandler* handler) {
@@ -260,7 +260,7 @@ void EagerEvalManager::update_static_infer_result(OperatorNodeBase* opr) {
 }
 
 void EagerEvalManager::ensure_input_layout(VarNode* var) {
-    auto&& mem_mgr = static_cast<ComputingGraphImpl*>(var->owner_graph())
+    auto&& mem_mgr = ComputingGraphImpl::downcast(var->owner_graph())
                              ->var_node_mem_manager();
     auto trait = mem_mgr.get_var_node_mem_trait_nullable(var);
 
@@ -287,7 +287,7 @@ void EagerEvalManager::alloc_output_mem(OperatorNodeBase* opr) {
         }
     }
 
-    auto&& mgr = static_cast<ComputingGraphImpl*>(m_owner_graph)
+    auto&& mgr = ComputingGraphImpl::downcast(m_owner_graph)
                          ->var_node_mem_manager();
 
     OprNodeArray opr_seq{opr};
@@ -348,7 +348,7 @@ void EagerEvalManager::do_on_opr_insert(OperatorNodeBase* opr) {
     if (status) {
         update_static_infer_result(opr);
         alloc_output_mem(opr);
-        auto&& mgr = static_cast<ComputingGraphImpl*>(m_owner_graph)
+        auto&& mgr = ComputingGraphImpl::downcast(m_owner_graph)
                              ->var_node_mem_manager();
         mgr.on_graph_compile_finished();
         opr->execute(*m_exec_env);

diff --git a/src/core/impl/graph/grad_manager.cpp b/src/core/impl/graph/grad_manager.cpp
index 448897a45..493fb57c3 100644
--- a/src/core/impl/graph/grad_manager.cpp
+++ b/src/core/impl/graph/grad_manager.cpp
@@ -40,7 +40,7 @@ class GradShapeChecker {
 
     void do_on_var_shape(VarNode *var) {
         MGB_MARK_USED_VAR(m_opr);
-        auto graph = static_cast<ComputingGraphImpl*>(var->owner_graph());
+        auto graph = ComputingGraphImpl::downcast(var->owner_graph());
         auto seq = graph->current_comp_seq();
 
         if (seq) {
@@ -90,7 +90,7 @@ class GradShapeChecker {
     }
 
     static void make(OperatorNodeBase *opr, VarNode *wrt, VarNode *grad) {
-        if (static_cast<ComputingGraphImpl*>(wrt->owner_graph())
+        if (ComputingGraphImpl::downcast(wrt->owner_graph())
                 ->eager_eval_manager().enabled())
             return;
         using namespace std::placeholders;
@@ -650,13 +650,13 @@ void GradManager::add_var_virtual_receiver(
 }
 
 void cg::add_grad_transformer(VarNode *var, const GradTransformer &cb) {
-    static_cast<ComputingGraphImpl*>(var->owner_graph())->
+    ComputingGraphImpl::downcast(var->owner_graph())->
         grad_manager().
         add_grad_transformer(var, cb);
 }
 
 void cg::add_extra_dep_for_grad(VarNode *inp, VarNode *out) {
-    static_cast<ComputingGraphImpl*>(inp->owner_graph())->grad_manager().
+    ComputingGraphImpl::downcast(inp->owner_graph())->grad_manager().
         add_extra_dep_for_grad(inp, out);
 }
 
@@ -667,7 +667,7 @@ void cg::add_var_virtual_receiver(
     desc->inputs = inputs;
     desc->outputs = outputs;
     desc->grad = grad;
-    static_cast<ComputingGraphImpl*>(inputs.at(0)->owner_graph())->
+    ComputingGraphImpl::downcast(inputs.at(0)->owner_graph())->
         grad_manager().
         add_var_virtual_receiver(desc);
 }

diff --git a/src/core/impl/graph/helper.cpp b/src/core/impl/graph/helper.cpp
index 63dd883a2..7e221250e 100644
--- a/src/core/impl/graph/helper.cpp
+++ b/src/core/impl/graph/helper.cpp
@@ -99,8 +99,8 @@ SymbolVarArray cg::grad(SymbolVar target_, SymbolVarArray wrts_, bool warn_mid_w
     grads.reserve(wrts_.size());
     VarNodeArray dest_vars;
     auto&& graph = target->owner_graph();
-    auto&& eager_mgr = static_cast<ComputingGraphImpl*>(graph)->eager_eval_manager();
-    auto&& grad_mgr = static_cast<ComputingGraphImpl*>(graph)->grad_manager();
+    auto&& eager_mgr = ComputingGraphImpl::downcast(graph)->eager_eval_manager();
+    auto&& grad_mgr = ComputingGraphImpl::downcast(graph)->grad_manager();
     bool already_recorded = eager_mgr.enter_record_mode();
     for (auto&& wrt_ : wrts_) {
         auto wrt = wrt_.node();
@@ -139,7 +139,7 @@ SymbolVarArray cg::grad(SymbolVar target_, SymbolVarArray wrts_, bool warn_mid_w
 
 SymbolVar cg::current_grad_target(ComputingGraph &graph) {
 #if MGB_ENABLE_GRAD
-    auto var = static_cast<ComputingGraphImpl&>(graph).grad_manager(
+    auto var = ComputingGraphImpl::downcast(&graph)->grad_manager(
             ).current_grad_target();
     mgb_throw_if(!var, GraphError, "current_grad_target() called outside "
             "grad computing environment");

diff --git a/src/core/impl/graph/operator_node.cpp b/src/core/impl/graph/operator_node.cpp
index bef29dbd9..2bf06a1c6 100644
--- a/src/core/impl/graph/operator_node.cpp
+++ b/src/core/impl/graph/operator_node.cpp
@@ -93,7 +93,7 @@ OperatorNodeBase::OperatorNodeBase(ComputingGraph *owner,
 }
 
 OperatorNodeBase::~OperatorNodeBase() noexcept {
-    auto &&pool = static_cast<ComputingGraphImpl*>(
+    auto &&pool = ComputingGraphImpl::cast(
             owner_graph())->var_node_pool();
     for (auto i: m_output) {
         pool.free(i);
@@ -124,7 +124,7 @@ void OperatorNodeBase::execute(ExecEnv &env) {
     }
 
     // allocate output with dynamic storage
-    static_cast<ComputingGraphImpl*>(owner_graph())
+    ComputingGraphImpl::downcast(owner_graph())
             ->var_node_mem_manager()
             .alloc_var_node_mem_dynamic(env, this);
 
@@ -135,11 +135,11 @@ void OperatorNodeBase::execute(ExecEnv &env) {
     // static_infer_manager so the value would be up-to-date; however for shape
     // deps, oprs would access the shape directly, so we need to insert some
     // code here to ensure it is up-to-date.
-    if (!static_cast<ComputingGraphImpl*>(owner_graph())
+    if (!ComputingGraphImpl::downcast(owner_graph())
                  ->eager_eval_manager()
                  .enabled()) {
         VarNodeArray vars_to_set;
-        auto cg = static_cast<ComputingGraphImpl*>(owner_graph());
+        auto cg = ComputingGraphImpl::downcast(owner_graph());
         auto step_cur = cg->opr_step_num_in_cur_comp_seq(this).val();
         mgb_assert(step_cur < std::numeric_limits<size_t>::max());
         using DT = NodeProp::DepType;
@@ -264,7 +264,7 @@ VarNode* OperatorNodeBase::add_output(const Maybe<std::string> &name) {
 
     mgb_assert(!m_inserted_in_graph && !m_node_prop.valid(),
             "add output on opr after it has been inserted into graph");
-    auto ptr = static_cast<ComputingGraphImpl*>(
+    auto ptr = ComputingGraphImpl::cast(
             owner_graph())->var_node_pool().alloc(
this->name() + ":" + name.val() : name, this); m_output.push_back(ptr); @@ -676,7 +676,7 @@ void mixin::IOSameShapeOperatorNode::get_output_var_shape( void PostExecActions::add(VarNode* var) { mgb_assert(m_comp_node == var->comp_node()); - auto graph = static_cast(var->owner_graph()); + auto graph = ComputingGraphImpl::downcast(var->owner_graph()); auto&& infer_mgr = graph->static_infer_manager_impl(); auto&& extra_info = graph->current_comp_seq_extra_info(); diff --git a/src/core/impl/graph/static_infer_impl.cpp b/src/core/impl/graph/static_infer_impl.cpp index 7d0db6053..ed3ffbc6c 100644 --- a/src/core/impl/graph/static_infer_impl.cpp +++ b/src/core/impl/graph/static_infer_impl.cpp @@ -813,7 +813,7 @@ StaticInferManagerImpl::~StaticInferManagerImpl() noexcept { m_mem_pool_value_trait.disable_freelist(); for (auto &&i: m_dtor_callbacks) i.second(); - for (auto &&i: static_cast( + for (auto &&i: ComputingGraphImpl::downcast( m_owner_graph)->all_oprs()) { for (auto j: i->output()) { clear_tag_handler(j); @@ -1212,7 +1212,7 @@ class StaticInferManagerImpl::SubgraphStaticInferHelperImpl final: void check_graph_par(VarNode *var) { if (mgb_unlikely(!m_par_graph)) { - m_par_graph = static_cast(var->owner_graph()); + m_par_graph = ComputingGraphImpl::downcast(var->owner_graph()); mgb_assert(m_par_graph != m_sub_graph); auto cb = [this]() { @@ -1230,7 +1230,7 @@ class StaticInferManagerImpl::SubgraphStaticInferHelperImpl final: void check_graph_sub(VarNode *var) { if (mgb_unlikely(!m_sub_graph)) { - m_sub_graph = static_cast(var->owner_graph()); + m_sub_graph = ComputingGraphImpl::downcast(var->owner_graph()); mgb_assert(m_sub_graph != m_par_graph); } else { mgb_assert(m_sub_graph == var->owner_graph()); diff --git a/src/core/impl/graph/symbol_var.cpp b/src/core/impl/graph/symbol_var.cpp index 323de296b..034b13219 100644 --- a/src/core/impl/graph/symbol_var.cpp +++ b/src/core/impl/graph/symbol_var.cpp @@ -132,7 +132,7 @@ const DeviceTensorND& SymbolVar::eager_eval_get_value() const { #if MGB_BUILD_SLIM_SERVING mgb_throw(MegBrainError, "eager eval disabled at compile time"); #else - auto og = static_cast(node()->owner_graph()); + auto og = ComputingGraphImpl::downcast(node()->owner_graph()); mgb_assert(og->options().eager_evaluation); return node()->dev_tensor(); #endif diff --git a/src/core/impl/graph/topo_sort.cpp b/src/core/impl/graph/topo_sort.cpp index 9678e3396..4f5f6c2d7 100644 --- a/src/core/impl/graph/topo_sort.cpp +++ b/src/core/impl/graph/topo_sort.cpp @@ -260,7 +260,7 @@ void TopoSorter::DFSDepDiscover::proc_add_dep_comp_order1() { void TopoSorter::DFSDepDiscover::proc_find_missing_inp() { auto frame = m_cur_frame; auto opr = frame->opr; - auto&& mgr = static_cast(opr->owner_graph()) + auto&& mgr = ComputingGraphImpl::downcast(opr->owner_graph()) ->static_infer_manager_impl(); auto&& missing_inp = frame->missing_inputs; diff --git a/src/core/impl/graph/var_node.cpp b/src/core/impl/graph/var_node.cpp index 1e00f8f08..1a136d9bb 100644 --- a/src/core/impl/graph/var_node.cpp +++ b/src/core/impl/graph/var_node.cpp @@ -233,12 +233,12 @@ bool VarNode::set_fwd_in2out_readonly( if (owner_graph()->options().imperative_proxy_graph) { return false; } - return static_cast(owner_graph()) + return ComputingGraphImpl::downcast(owner_graph()) ->var_node_mem_manager().fwd_in2out_readonly(input, sub, this); } VarNode& VarNode::set_fwd_in2out_writable(VarNode *input) { - static_cast(owner_graph()) + ComputingGraphImpl::downcast(owner_graph()) ->var_node_mem_manager().fwd_in2out_writable(input, this); 
     return *this;
 }
 
@@ -246,20 +246,20 @@ VarNode& VarNode::set_fwd_in2out_writable(VarNode *input) {
 
 VarNode& VarNode::set_fwd_in2out_writable_force(VarNode *input) {
     mgb_assert(!owner_graph()->options().imperative_proxy_graph);
-    static_cast<ComputingGraphImpl*>(owner_graph())
+    ComputingGraphImpl::downcast(owner_graph())
         ->var_node_mem_manager().fwd_in2out_writable_force(input, this);
     return *this;
 }
 
 VarNode& VarNode::add_layout_constraint(LayoutConstraintCallback callback) {
-    static_cast<ComputingGraphImpl*>(owner_graph())
+    ComputingGraphImpl::downcast(owner_graph())
         ->var_node_mem_manager().add_layout_constraint(
                 this, std::move(callback));
     return *this;
 }
 
 VarNode& VarNode::add_layout_constraint_contiguous() {
-    static_cast<ComputingGraphImpl*>(owner_graph())
+    ComputingGraphImpl::downcast(owner_graph())
             ->var_node_mem_manager()
             .add_layout_constraint_level(
                     this, VarNodeMemManager::LayoutConstraintLevel::CONTIG);
@@ -267,7 +267,7 @@ VarNode& VarNode::add_layout_constraint_contiguous() {
 }
 
 VarNode& VarNode::add_layout_constraint_monotone() {
-    static_cast<ComputingGraphImpl*>(owner_graph())
+    ComputingGraphImpl::downcast(owner_graph())
             ->var_node_mem_manager()
             .add_layout_constraint_level(
                     this, VarNodeMemManager::LayoutConstraintLevel::MONOTONE);
@@ -315,7 +315,7 @@ VarNode& VarNode::shape_alloc(const TensorShape &shape) {
             "shape_alloc() could only be used for vars with"
             " NO_SYS_MEM_ALLOC flag; actual var: %s",
             cg::dump_var_info({this}).c_str());
-    static_cast<ComputingGraphImpl*>(owner_graph())
+    ComputingGraphImpl::downcast(owner_graph())
         ->var_node_mem_manager().var_alloc_with_shape(this, shape);
     return *this;
 }
@@ -330,7 +330,7 @@ bool VarNode::reset_dev_tensor_from_other_var(VarNode* src_var) {
             "dynamic storage on src is required for dynamic readonly "
             "forwarding: vars=%s",
             dump_var_info({src_var, this}).c_str());
-    auto&& trait = static_cast<ComputingGraphImpl*>(owner_graph())
+    auto&& trait = ComputingGraphImpl::downcast(owner_graph())
                            ->var_node_mem_manager()
                            .get_var_node_mem_trait_at(src_var);
     if (trait.seq_force_update_dest ||
@@ -403,7 +403,7 @@ std::shared_ptr<json::Value> VarNode::to_json() const {
         return json::Null::make();
     };
 
-    auto &&trait = static_cast<ComputingGraphImpl*>(owner_graph()
+    auto &&trait = ComputingGraphImpl::downcast(owner_graph()
         )->var_node_mem_manager().get_var_node_mem_trait(this);
     auto flag = json::Array::make();
     {
@@ -459,7 +459,7 @@ std::shared_ptr<json::Value> VarNode::to_json() const {
 #endif
 
 MemAllocPlan& VarNode::init_mem_plan(const DeviceTensorND* fixed_alloc) {
-    static_cast<ComputingGraphImpl*>(owner_graph())
+    ComputingGraphImpl::downcast(owner_graph())
             ->var_node_mem_manager()
             .init_single_var_mem_plan(this, fixed_alloc);
     return m_mem_plan;
@@ -477,7 +477,7 @@ void VarNode::modify_flag(Flag delta, Flag new_flag) {
                      Flag::NO_SYS_STATIC_MEM_ALLOC |
                      Flag::RT_FORCE_DYNAMIC_MEM_ALLOC)) == delta);
 
-        mgb_assert(!static_cast<ComputingGraphImpl*>(owner_graph())->
+        mgb_assert(!ComputingGraphImpl::downcast(owner_graph())->
                 var_node_mem_manager().optimize_started(),
                 "could not modify var flags after optimization started");
     }

diff --git a/src/core/impl/graph/var_node_mem_mgr.cpp b/src/core/impl/graph/var_node_mem_mgr.cpp
index bea4b1d07..381436b08 100644
--- a/src/core/impl/graph/var_node_mem_mgr.cpp
+++ b/src/core/impl/graph/var_node_mem_mgr.cpp
@@ -340,7 +340,7 @@ VarNodeMemManager::DynamicAllocOprInfo::DynamicAllocOprInfo(
         prev_dev_val_input.clear();
         static_infer_inp.clear();
         dev_val_input.clear();
-        auto &&mgr = static_cast<ComputingGraphImpl*>(opr->owner_graph())->
+        auto &&mgr = ComputingGraphImpl::downcast(opr->owner_graph())->
             static_infer_manager_impl();
 
         CompNode single_cn;
diff --git a/src/core/impl/graph/var_node_mem_mgr/defrag.cpp b/src/core/impl/graph/var_node_mem_mgr/defrag.cpp
index 5053ce10a..a1fe0b11c 100644
--- a/src/core/impl/graph/var_node_mem_mgr/defrag.cpp
+++ b/src/core/impl/graph/var_node_mem_mgr/defrag.cpp
@@ -73,7 +73,7 @@ void VarDevMemDefragmenter::defrag(VarNode* req_var,
                                    const CompNodeInfo& cn_info,
                                    size_t extra_size) {
     // pause all other comp nodes before calling defrag_impl()
-    auto exec_env = static_cast<ComputingGraphImpl*>(req_var->owner_graph())
+    auto exec_env = ComputingGraphImpl::downcast(req_var->owner_graph())
                             ->current_exec_env();
     mgb_assert(exec_env);
     exec_env->pause_exec();
-- 
GitLab