Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
13c7c572
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
13c7c572
编写于
11月 03, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(mgb): fix shape infer's condition in lite
GitOrigin-RevId: 550eaff4cd2904b2bebb60e0fc3e32cb97295738
上级
8d825246
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
56 addition
and
7 deletion
+56
-7
lite/src/mge/network_impl.cpp
lite/src/mge/network_impl.cpp
+2
-3
lite/test/test_network.cpp
lite/test/test_network.cpp
+15
-0
src/core/impl/graph/static_infer_impl.cpp
src/core/impl/graph/static_infer_impl.cpp
+10
-0
src/core/impl/graph/static_infer_impl.h
src/core/impl/graph/static_infer_impl.h
+10
-0
src/core/impl/graph/var_node.cpp
src/core/impl/graph/var_node.cpp
+12
-0
src/core/include/megbrain/graph/var_node.h
src/core/include/megbrain/graph/var_node.h
+7
-4
未找到文件。
lite/src/mge/network_impl.cpp
浏览文件 @
13c7c572
...
...
@@ -454,9 +454,8 @@ void NetworkImplDft::set_io(const NetworkIO& network_io) {
}
void
NetworkImplDft
::
try_infer_tensor_layout
(
std
::
shared_ptr
<
Tensor
>
tensor
,
Var
var
)
{
auto
&&
static_infer_mgr
=
m_load_config
.
comp_graph
->
static_infer_manager
();
auto
infer_trait
=
var
.
node
()
->
get_static_infer_trait
();
if
(
std
::
get
<
0
>
(
infer_trait
))
{
if
(
var
.
node
()
->
capable_shape_infer
())
{
auto
&&
static_infer_mgr
=
m_load_config
.
comp_graph
->
static_infer_manager
();
auto
shape
=
static_infer_mgr
.
infer_shape_fallible
(
var
.
node
());
if
(
!
shape
)
{
LITE_WARN
(
...
...
lite/test/test_network.cpp
浏览文件 @
13c7c572
...
...
@@ -101,6 +101,21 @@ TEST(TestNetWork, GetAllName) {
ASSERT_TRUE
(
output_names
[
0
]
==
"TRUE_DIV(EXP[12065],reduce0[12067])[12077]"
);
}
TEST
(
TestNetWork
,
LoadFBSModel
)
{
Config
config
;
std
::
string
model_path
=
"./ax.mge"
;
std
::
shared_ptr
<
Network
>
network
=
std
::
make_shared
<
Network
>
(
config
);
network
->
load_model
(
model_path
);
auto
output_tensor
=
network
->
get_output_tensor
(
0
);
auto
out_layout
=
output_tensor
->
get_layout
();
ASSERT_EQ
(
out_layout
.
ndim
,
4
);
ASSERT_EQ
(
out_layout
.
shapes
[
0
],
1
);
ASSERT_EQ
(
out_layout
.
shapes
[
1
],
1
);
ASSERT_EQ
(
out_layout
.
shapes
[
2
],
40
);
ASSERT_EQ
(
out_layout
.
shapes
[
3
],
180
);
}
TEST
(
TestNetWork
,
BasicInplaceAndSingleThreadAffinity
)
{
Config
config
;
auto
lite_tensor
=
get_input_data
(
"./input_data.npy"
);
...
...
src/core/impl/graph/static_infer_impl.cpp
浏览文件 @
13c7c572
...
...
@@ -892,6 +892,16 @@ StaticInferManagerImpl::TagHandler* StaticInferManagerImpl::get_tag_handler_for_
return
c
.
value
;
}
bool
StaticInferManagerImpl
::
has_shape_infer
(
Tag
tag
)
const
{
auto
&&
c
=
get_tag_trait_container
(
tag
);
return
c
.
shape
!=
nullptr
;
}
bool
StaticInferManagerImpl
::
has_value_infer
(
Tag
tag
)
const
{
auto
&&
c
=
get_tag_trait_container
(
tag
);
return
c
.
value
!=
nullptr
;
}
StaticInferManagerImpl
::
TagTraitBase
*
StaticInferManagerImpl
::
get_tag_trait_for_dep
(
const
DepElement
&
dep
)
{
TagHandler
*
ret
;
...
...
src/core/impl/graph/static_infer_impl.h
浏览文件 @
13c7c572
...
...
@@ -65,6 +65,16 @@ public:
*/
MGE_WIN_DECLSPEC_FUC
TagHandler
*
get_tag_handler_for_value
(
Tag
tag
);
/*!
* \brief check if there is a registered shape infer func in tag
*/
bool
has_shape_infer
(
Tag
tag
)
const
;
/*!
* \brief check if there is a registered value infer func in tag
*/
bool
has_value_infer
(
Tag
tag
)
const
;
/*!
* \brief clear registered handler for a tag; this is only used in error
* handling in opr creation
...
...
src/core/impl/graph/var_node.cpp
浏览文件 @
13c7c572
...
...
@@ -578,6 +578,18 @@ bool VarNode::is_graph_dest_varnode() {
return
ComputingGraphImpl
::
downcast
(
owner_graph
())
->
var_receiver
(
this
).
size
()
==
0
;
}
bool
VarNode
::
capable_shape_infer
()
{
auto
&&
mgr
=
ComputingGraphImpl
::
downcast
(
owner_graph
())
->
static_infer_manager_impl
();
return
mgr
.
has_shape_infer
(
this
);
}
bool
VarNode
::
capable_value_infer
()
{
auto
&&
mgr
=
ComputingGraphImpl
::
downcast
(
owner_graph
())
->
static_infer_manager_impl
();
return
mgr
.
has_value_infer
(
this
);
}
VarNode
&
VarNode
::
add_flag
(
Flag
flag
)
{
modify_flag
(
flag
,
m_flag
|
flag
);
return
*
this
;
...
...
src/core/include/megbrain/graph/var_node.h
浏览文件 @
13c7c572
...
...
@@ -495,11 +495,14 @@ public:
const
DeviceTensorND
*
fixed_alloc
=
nullptr
);
/*!
* \brief
get the shape and value infer trait
* \brief
check infer shape capablity by check m_static_infer_trait's shape infer
*/
const
std
::
tuple
<
void
*
,
void
*>&
get_static_infer_trait
()
{
return
m_static_infer_trait
;
}
MGE_WIN_DECLSPEC_FUC
bool
capable_shape_infer
();
/*!
* \brief check infer shape capablity by check m_static_infer_trait's value infer
*/
MGE_WIN_DECLSPEC_FUC
bool
capable_value_infer
();
private:
//! whether its memory should be allocated by mgb system during graph
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录