From 551cad4955eb55f96a1974808c00c0dcfa3c07e1 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 5 Jul 2022 14:35:35 +0800 Subject: [PATCH] refactor(megbrain): refactor try infer tensor layout in lite avoiding using megbrain interface GitOrigin-RevId: 9799e671022002be719a4ee71ccca89d2ec318b8 --- lite/src/mge/network_impl.cpp | 6 ++++-- src/core/impl/graph/var_node.cpp | 12 ------------ src/core/include/megbrain/graph/var_node.h | 10 ---------- 3 files changed, 4 insertions(+), 24 deletions(-) diff --git a/lite/src/mge/network_impl.cpp b/lite/src/mge/network_impl.cpp index a77195206..58eb463e9 100644 --- a/lite/src/mge/network_impl.cpp +++ b/lite/src/mge/network_impl.cpp @@ -660,8 +660,10 @@ void NetworkImplDft::set_io(const NetworkIO& network_io) { } void NetworkImplDft::try_infer_tensor_layout(std::shared_ptr tensor, Var var) { - if (var.node()->capable_shape_infer()) { - auto&& static_infer_mgr = m_load_config.comp_graph->static_infer_manager(); + using InferType = mgb::cg::static_infer::InferType; + auto&& static_infer_mgr = m_load_config.comp_graph->static_infer_manager(); + if (static_infer_mgr.get_infer_type(var.node()).shape & + (InferType::CONST | InferType::RT_STATIC)) { auto shape = static_infer_mgr.infer_shape_fallible(var.node()); if (!shape) { LITE_WARN( diff --git a/src/core/impl/graph/var_node.cpp b/src/core/impl/graph/var_node.cpp index 863354ee2..1380acef7 100644 --- a/src/core/impl/graph/var_node.cpp +++ b/src/core/impl/graph/var_node.cpp @@ -596,18 +596,6 @@ bool VarNode::is_graph_dest_varnode() { return ComputingGraphImpl::downcast(owner_graph())->var_receiver(this).size() == 0; } -bool VarNode::capable_shape_infer() { - auto&& mgr = - ComputingGraphImpl::downcast(owner_graph())->static_infer_manager_impl(); - return mgr.has_shape_infer(this); -} - -bool VarNode::capable_value_infer() { - auto&& mgr = - ComputingGraphImpl::downcast(owner_graph())->static_infer_manager_impl(); - return mgr.has_value_infer(this); -} - VarNode& VarNode::add_flag(Flag flag) { modify_flag(flag, m_flag | flag); return *this; diff --git a/src/core/include/megbrain/graph/var_node.h b/src/core/include/megbrain/graph/var_node.h index e5778456c..3b84c378e 100644 --- a/src/core/include/megbrain/graph/var_node.h +++ b/src/core/include/megbrain/graph/var_node.h @@ -488,16 +488,6 @@ public: MGE_WIN_DECLSPEC_FUC MemAllocPlan& init_mem_plan( const DeviceTensorND* fixed_alloc = nullptr); - /*! - * \brief check infer shape capablity by check m_static_infer_trait's shape infer - */ - MGE_WIN_DECLSPEC_FUC bool capable_shape_infer(); - - /*! - * \brief check infer shape capablity by check m_static_infer_trait's value infer - */ - MGE_WIN_DECLSPEC_FUC bool capable_value_infer(); - //! whether the var is graph output, if it is output, the Flag of //! NO_SYS_MEM_ALLOC can be modified. MGE_WIN_DECLSPEC_FUC bool is_graph_dest_varnode(); -- GitLab