From 13c7c572df75ffa8b6bb17d7f2f31292a4e42ca7 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 3 Nov 2021 15:45:42 +0800 Subject: [PATCH] fix(mgb): fix shape infer's condition in lite GitOrigin-RevId: 550eaff4cd2904b2bebb60e0fc3e32cb97295738 --- lite/src/mge/network_impl.cpp | 5 ++--- lite/test/test_network.cpp | 15 +++++++++++++++ src/core/impl/graph/static_infer_impl.cpp | 10 ++++++++++ src/core/impl/graph/static_infer_impl.h | 10 ++++++++++ src/core/impl/graph/var_node.cpp | 12 ++++++++++++ src/core/include/megbrain/graph/var_node.h | 11 +++++++---- 6 files changed, 56 insertions(+), 7 deletions(-) diff --git a/lite/src/mge/network_impl.cpp b/lite/src/mge/network_impl.cpp index 1cfde0fa2..ae106921f 100644 --- a/lite/src/mge/network_impl.cpp +++ b/lite/src/mge/network_impl.cpp @@ -454,9 +454,8 @@ void NetworkImplDft::set_io(const NetworkIO& network_io) { } void NetworkImplDft::try_infer_tensor_layout(std::shared_ptr tensor, Var var) { - auto&& static_infer_mgr = m_load_config.comp_graph->static_infer_manager(); - auto infer_trait = var.node()->get_static_infer_trait(); - if (std::get<0>(infer_trait)) { + if (var.node()->capable_shape_infer()) { + auto&& static_infer_mgr = m_load_config.comp_graph->static_infer_manager(); auto shape = static_infer_mgr.infer_shape_fallible(var.node()); if (!shape) { LITE_WARN( diff --git a/lite/test/test_network.cpp b/lite/test/test_network.cpp index f786b92c1..8418daf52 100644 --- a/lite/test/test_network.cpp +++ b/lite/test/test_network.cpp @@ -101,6 +101,21 @@ TEST(TestNetWork, GetAllName) { ASSERT_TRUE(output_names[0] == "TRUE_DIV(EXP[12065],reduce0[12067])[12077]"); } +TEST(TestNetWork, LoadFBSModel) { + Config config; + std::string model_path = "./ax.mge"; + std::shared_ptr network = std::make_shared(config); + network->load_model(model_path); + + auto output_tensor = network->get_output_tensor(0); + auto out_layout = output_tensor->get_layout(); + ASSERT_EQ(out_layout.ndim, 4); + ASSERT_EQ(out_layout.shapes[0], 1); + ASSERT_EQ(out_layout.shapes[1], 1); + ASSERT_EQ(out_layout.shapes[2], 40); + ASSERT_EQ(out_layout.shapes[3], 180); +} + TEST(TestNetWork, BasicInplaceAndSingleThreadAffinity) { Config config; auto lite_tensor = get_input_data("./input_data.npy"); diff --git a/src/core/impl/graph/static_infer_impl.cpp b/src/core/impl/graph/static_infer_impl.cpp index 04773cf9d..5aa84a0e5 100644 --- a/src/core/impl/graph/static_infer_impl.cpp +++ b/src/core/impl/graph/static_infer_impl.cpp @@ -892,6 +892,16 @@ StaticInferManagerImpl::TagHandler* StaticInferManagerImpl::get_tag_handler_for_ return c.value; } +bool StaticInferManagerImpl::has_shape_infer(Tag tag) const { + auto&& c = get_tag_trait_container(tag); + return c.shape != nullptr; +} + +bool StaticInferManagerImpl::has_value_infer(Tag tag) const { + auto&& c = get_tag_trait_container(tag); + return c.value != nullptr; +} + StaticInferManagerImpl::TagTraitBase* StaticInferManagerImpl::get_tag_trait_for_dep( const DepElement& dep) { TagHandler* ret; diff --git a/src/core/impl/graph/static_infer_impl.h b/src/core/impl/graph/static_infer_impl.h index 572017a02..e78543020 100644 --- a/src/core/impl/graph/static_infer_impl.h +++ b/src/core/impl/graph/static_infer_impl.h @@ -65,6 +65,16 @@ public: */ MGE_WIN_DECLSPEC_FUC TagHandler* get_tag_handler_for_value(Tag tag); + /*! + * \brief check if there is a registered shape infer func in tag + */ + bool has_shape_infer(Tag tag) const; + + /*! + * \brief check if there is a registered value infer func in tag + */ + bool has_value_infer(Tag tag) const; + /*! * \brief clear registered handler for a tag; this is only used in error * handling in opr creation diff --git a/src/core/impl/graph/var_node.cpp b/src/core/impl/graph/var_node.cpp index 833da401d..a9ceb3646 100644 --- a/src/core/impl/graph/var_node.cpp +++ b/src/core/impl/graph/var_node.cpp @@ -578,6 +578,18 @@ bool VarNode::is_graph_dest_varnode() { return ComputingGraphImpl::downcast(owner_graph())->var_receiver(this).size() == 0; } +bool VarNode::capable_shape_infer() { + auto&& mgr = + ComputingGraphImpl::downcast(owner_graph())->static_infer_manager_impl(); + return mgr.has_shape_infer(this); +} + +bool VarNode::capable_value_infer() { + auto&& mgr = + ComputingGraphImpl::downcast(owner_graph())->static_infer_manager_impl(); + return mgr.has_value_infer(this); +} + VarNode& VarNode::add_flag(Flag flag) { modify_flag(flag, m_flag | flag); return *this; diff --git a/src/core/include/megbrain/graph/var_node.h b/src/core/include/megbrain/graph/var_node.h index 74db600a2..b09b4157d 100644 --- a/src/core/include/megbrain/graph/var_node.h +++ b/src/core/include/megbrain/graph/var_node.h @@ -495,11 +495,14 @@ public: const DeviceTensorND* fixed_alloc = nullptr); /*! - * \brief get the shape and value infer trait + * \brief check infer shape capablity by check m_static_infer_trait's shape infer */ - const std::tuple& get_static_infer_trait() { - return m_static_infer_trait; - } + MGE_WIN_DECLSPEC_FUC bool capable_shape_infer(); + + /*! + * \brief check infer shape capablity by check m_static_infer_trait's value infer + */ + MGE_WIN_DECLSPEC_FUC bool capable_value_infer(); private: //! whether its memory should be allocated by mgb system during graph -- GitLab