提交 13c7c572 编写于 作者: M Megvii Engine Team

fix(mgb): fix shape infer's condition in lite

GitOrigin-RevId: 550eaff4cd2904b2bebb60e0fc3e32cb97295738
上级 8d825246
......@@ -454,9 +454,8 @@ void NetworkImplDft::set_io(const NetworkIO& network_io) {
}
void NetworkImplDft::try_infer_tensor_layout(std::shared_ptr<Tensor> tensor, Var var) {
auto&& static_infer_mgr = m_load_config.comp_graph->static_infer_manager();
auto infer_trait = var.node()->get_static_infer_trait();
if (std::get<0>(infer_trait)) {
if (var.node()->capable_shape_infer()) {
auto&& static_infer_mgr = m_load_config.comp_graph->static_infer_manager();
auto shape = static_infer_mgr.infer_shape_fallible(var.node());
if (!shape) {
LITE_WARN(
......
......@@ -101,6 +101,21 @@ TEST(TestNetWork, GetAllName) {
ASSERT_TRUE(output_names[0] == "TRUE_DIV(EXP[12065],reduce0[12067])[12077]");
}
TEST(TestNetWork, LoadFBSModel) {
    // Regression check: loading a flatbuffer-serialized (.mge) model must
    // produce a fully statically-inferred 4-D output layout of (1, 1, 40, 180).
    Config cfg;
    auto net = std::make_shared<Network>(cfg);
    net->load_model(std::string("./ax.mge"));

    auto layout = net->get_output_tensor(0)->get_layout();
    ASSERT_EQ(layout.ndim, 4);
    const size_t expected_shape[4] = {1, 1, 40, 180};
    for (size_t axis = 0; axis < 4; ++axis) {
        ASSERT_EQ(layout.shapes[axis], expected_shape[axis]);
    }
}
TEST(TestNetWork, BasicInplaceAndSingleThreadAffinity) {
Config config;
auto lite_tensor = get_input_data("./input_data.npy");
......
......@@ -892,6 +892,16 @@ StaticInferManagerImpl::TagHandler* StaticInferManagerImpl::get_tag_handler_for_
return c.value;
}
bool StaticInferManagerImpl::has_shape_infer(Tag tag) const {
    //! a shape-infer func is registered iff the trait container holds one
    return get_tag_trait_container(tag).shape != nullptr;
}
bool StaticInferManagerImpl::has_value_infer(Tag tag) const {
    //! a value-infer func is registered iff the trait container holds one
    return get_tag_trait_container(tag).value != nullptr;
}
StaticInferManagerImpl::TagTraitBase* StaticInferManagerImpl::get_tag_trait_for_dep(
const DepElement& dep) {
TagHandler* ret;
......
......@@ -65,6 +65,16 @@ public:
*/
MGE_WIN_DECLSPEC_FUC TagHandler* get_tag_handler_for_value(Tag tag);
/*!
 * \brief check whether a shape infer func is registered for the given tag
*/
bool has_shape_infer(Tag tag) const;
/*!
 * \brief check whether a value infer func is registered for the given tag
*/
bool has_value_infer(Tag tag) const;
/*!
* \brief clear registered handler for a tag; this is only used in error
* handling in opr creation
......
......@@ -578,6 +578,18 @@ bool VarNode::is_graph_dest_varnode() {
return ComputingGraphImpl::downcast(owner_graph())->var_receiver(this).size() == 0;
}
bool VarNode::capable_shape_infer() {
    // Delegate to the graph's static-infer manager: this var's shape is
    // statically inferable iff a shape-infer func was registered for it.
    return ComputingGraphImpl::downcast(owner_graph())
            ->static_infer_manager_impl()
            .has_shape_infer(this);
}
bool VarNode::capable_value_infer() {
    // Delegate to the graph's static-infer manager: this var's value is
    // statically inferable iff a value-infer func was registered for it.
    return ComputingGraphImpl::downcast(owner_graph())
            ->static_infer_manager_impl()
            .has_value_infer(this);
}
VarNode& VarNode::add_flag(Flag flag) {
modify_flag(flag, m_flag | flag);
return *this;
......
......@@ -495,11 +495,14 @@ public:
const DeviceTensorND* fixed_alloc = nullptr);
/*!
* \brief get the shape and value infer trait
 * \brief check shape-infer capability by querying the registered shape infer func
*/
const std::tuple<void*, void*>& get_static_infer_trait() {
return m_static_infer_trait;
}
MGE_WIN_DECLSPEC_FUC bool capable_shape_infer();
/*!
 * \brief check value-infer capability by querying the registered value infer func
*/
MGE_WIN_DECLSPEC_FUC bool capable_value_infer();
private:
//! whether its memory should be allocated by mgb system during graph
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册