提交 13c7c572 编写于 作者: M Megvii Engine Team

fix(mgb): fix shape infer's condition in lite

GitOrigin-RevId: 550eaff4cd2904b2bebb60e0fc3e32cb97295738
上级 8d825246
...@@ -454,9 +454,8 @@ void NetworkImplDft::set_io(const NetworkIO& network_io) { ...@@ -454,9 +454,8 @@ void NetworkImplDft::set_io(const NetworkIO& network_io) {
} }
void NetworkImplDft::try_infer_tensor_layout(std::shared_ptr<Tensor> tensor, Var var) { void NetworkImplDft::try_infer_tensor_layout(std::shared_ptr<Tensor> tensor, Var var) {
auto&& static_infer_mgr = m_load_config.comp_graph->static_infer_manager(); if (var.node()->capable_shape_infer()) {
auto infer_trait = var.node()->get_static_infer_trait(); auto&& static_infer_mgr = m_load_config.comp_graph->static_infer_manager();
if (std::get<0>(infer_trait)) {
auto shape = static_infer_mgr.infer_shape_fallible(var.node()); auto shape = static_infer_mgr.infer_shape_fallible(var.node());
if (!shape) { if (!shape) {
LITE_WARN( LITE_WARN(
......
...@@ -101,6 +101,21 @@ TEST(TestNetWork, GetAllName) { ...@@ -101,6 +101,21 @@ TEST(TestNetWork, GetAllName) {
ASSERT_TRUE(output_names[0] == "TRUE_DIV(EXP[12065],reduce0[12067])[12077]"); ASSERT_TRUE(output_names[0] == "TRUE_DIV(EXP[12065],reduce0[12067])[12077]");
} }
TEST(TestNetWork, LoadFBSModel) {
Config config;
std::string model_path = "./ax.mge";
std::shared_ptr<Network> network = std::make_shared<Network>(config);
network->load_model(model_path);
auto output_tensor = network->get_output_tensor(0);
auto out_layout = output_tensor->get_layout();
ASSERT_EQ(out_layout.ndim, 4);
ASSERT_EQ(out_layout.shapes[0], 1);
ASSERT_EQ(out_layout.shapes[1], 1);
ASSERT_EQ(out_layout.shapes[2], 40);
ASSERT_EQ(out_layout.shapes[3], 180);
}
TEST(TestNetWork, BasicInplaceAndSingleThreadAffinity) { TEST(TestNetWork, BasicInplaceAndSingleThreadAffinity) {
Config config; Config config;
auto lite_tensor = get_input_data("./input_data.npy"); auto lite_tensor = get_input_data("./input_data.npy");
......
...@@ -892,6 +892,16 @@ StaticInferManagerImpl::TagHandler* StaticInferManagerImpl::get_tag_handler_for_ ...@@ -892,6 +892,16 @@ StaticInferManagerImpl::TagHandler* StaticInferManagerImpl::get_tag_handler_for_
return c.value; return c.value;
} }
bool StaticInferManagerImpl::has_shape_infer(Tag tag) const {
auto&& c = get_tag_trait_container(tag);
return c.shape != nullptr;
}
bool StaticInferManagerImpl::has_value_infer(Tag tag) const {
auto&& c = get_tag_trait_container(tag);
return c.value != nullptr;
}
StaticInferManagerImpl::TagTraitBase* StaticInferManagerImpl::get_tag_trait_for_dep( StaticInferManagerImpl::TagTraitBase* StaticInferManagerImpl::get_tag_trait_for_dep(
const DepElement& dep) { const DepElement& dep) {
TagHandler* ret; TagHandler* ret;
......
...@@ -65,6 +65,16 @@ public: ...@@ -65,6 +65,16 @@ public:
*/ */
MGE_WIN_DECLSPEC_FUC TagHandler* get_tag_handler_for_value(Tag tag); MGE_WIN_DECLSPEC_FUC TagHandler* get_tag_handler_for_value(Tag tag);
/*!
* \brief check if there is a registered shape infer func in tag
*/
bool has_shape_infer(Tag tag) const;
/*!
* \brief check if there is a registered value infer func in tag
*/
bool has_value_infer(Tag tag) const;
/*! /*!
* \brief clear registered handler for a tag; this is only used in error * \brief clear registered handler for a tag; this is only used in error
* handling in opr creation * handling in opr creation
......
...@@ -578,6 +578,18 @@ bool VarNode::is_graph_dest_varnode() { ...@@ -578,6 +578,18 @@ bool VarNode::is_graph_dest_varnode() {
return ComputingGraphImpl::downcast(owner_graph())->var_receiver(this).size() == 0; return ComputingGraphImpl::downcast(owner_graph())->var_receiver(this).size() == 0;
} }
bool VarNode::capable_shape_infer() {
auto&& mgr =
ComputingGraphImpl::downcast(owner_graph())->static_infer_manager_impl();
return mgr.has_shape_infer(this);
}
bool VarNode::capable_value_infer() {
auto&& mgr =
ComputingGraphImpl::downcast(owner_graph())->static_infer_manager_impl();
return mgr.has_value_infer(this);
}
VarNode& VarNode::add_flag(Flag flag) { VarNode& VarNode::add_flag(Flag flag) {
modify_flag(flag, m_flag | flag); modify_flag(flag, m_flag | flag);
return *this; return *this;
......
...@@ -495,11 +495,14 @@ public: ...@@ -495,11 +495,14 @@ public:
const DeviceTensorND* fixed_alloc = nullptr); const DeviceTensorND* fixed_alloc = nullptr);
/*! /*!
* \brief get the shape and value infer trait * \brief check infer shape capablity by check m_static_infer_trait's shape infer
*/ */
const std::tuple<void*, void*>& get_static_infer_trait() { MGE_WIN_DECLSPEC_FUC bool capable_shape_infer();
return m_static_infer_trait;
} /*!
* \brief check infer shape capablity by check m_static_infer_trait's value infer
*/
MGE_WIN_DECLSPEC_FUC bool capable_value_infer();
private: private:
//! whether its memory should be allocated by mgb system during graph //! whether its memory should be allocated by mgb system during graph
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册