Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
565466c2
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
565466c2
编写于
10月 11, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(lite): auto deduce output tensor shape before model forward
GitOrigin-RevId: 78e00dab5da3fcc91bb53d8588c06b5b25295e19
上级
29f9935d
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
59 addition
and
2 deletion
+59
-2
lite/pylite/megenginelite/struct.py
lite/pylite/megenginelite/struct.py
+1
-0
lite/pylite/megenginelite/tensor.py
lite/pylite/megenginelite/tensor.py
+4
-2
lite/src/mge/common.cpp
lite/src/mge/common.cpp
+6
-0
lite/src/mge/network_impl.cpp
lite/src/mge/network_impl.cpp
+28
-0
lite/src/mge/network_impl.h
lite/src/mge/network_impl.h
+4
-0
lite/src/misc.cpp
lite/src/misc.cpp
+4
-0
lite/test/test_network.cpp
lite/test/test_network.cpp
+5
-0
src/core/include/megbrain/graph/var_node.h
src/core/include/megbrain/graph/var_node.h
+7
-0
未找到文件。
lite/pylite/megenginelite/struct.py
浏览文件 @
565466c2
...
...
@@ -31,6 +31,7 @@ class LiteDataType(IntEnum):
LITE_INT16
=
3
LITE_INT8
=
4
LITE_UINT8
=
5
LITE_UINT16
=
6
class
LiteTensorPhase
(
IntEnum
):
...
...
lite/pylite/megenginelite/tensor.py
浏览文件 @
565466c2
...
...
@@ -22,6 +22,7 @@ _lite_type_to_nptypes = {
LiteDataType
.
LITE_UINT8
:
np
.
uint8
,
LiteDataType
.
LITE_INT8
:
np
.
int8
,
LiteDataType
.
LITE_INT16
:
np
.
int16
,
LiteDataType
.
LITE_UINT16
:
np
.
uint16
,
LiteDataType
.
LITE_HALF
:
np
.
float16
,
}
...
...
@@ -33,6 +34,7 @@ _str_nptypes_to_lite_nptypes = {
np
.
dtype
(
"uint8"
):
LiteDataType
.
LITE_UINT8
,
np
.
dtype
(
"int8"
):
LiteDataType
.
LITE_INT8
,
np
.
dtype
(
"int16"
):
LiteDataType
.
LITE_INT16
,
np
.
dtype
(
"uint16"
):
LiteDataType
.
LITE_UINT16
,
np
.
dtype
(
"float16"
):
LiteDataType
.
LITE_HALF
,
}
...
...
@@ -43,7 +45,7 @@ ctype_to_lite_dtypes = {
c_ubyte
:
LiteDataType
.
LITE_UINT8
,
c_byte
:
LiteDataType
.
LITE_INT8
,
c_short
:
LiteDataType
.
LITE_INT16
,
c_ushort
:
LiteDataType
.
LITE_INT16
,
c_ushort
:
LiteDataType
.
LITE_
U
INT16
,
}
...
...
@@ -83,7 +85,7 @@ class LiteLayout(Structure):
def
__repr__
(
self
):
data
=
{
"shapes"
:
list
(
self
.
shapes
),
"shapes"
:
list
(
self
.
shapes
)
[
0
:
self
.
ndim
]
,
"ndim"
:
self
.
ndim
,
"data_type"
:
_lite_type_to_nptypes
[
LiteDataType
(
self
.
data_type
)],
}
...
...
lite/src/mge/common.cpp
浏览文件 @
565466c2
...
...
@@ -100,6 +100,9 @@ LTensorLayout lite::to_impl_layout(const Layout& layout) {
case
LiteDataType
::
LITE_INT16
:
mge_layout
.
dtype
=
mgb
::
dtype
::
Int16
();
break
;
case
LiteDataType
::
LITE_UINT16
:
mge_layout
.
dtype
=
mgb
::
dtype
::
Uint16
();
break
;
default:
LITE_THROW
(
mgb
::
ssprintf
(
"unsupport dtype in lite enum id is %d."
,
...
...
@@ -133,6 +136,9 @@ Layout lite::to_lite_layout(const LTensorLayout& mge_layout) {
case
mgb
::
DTypeEnum
::
Int16
:
layout
.
data_type
=
LiteDataType
::
LITE_INT16
;
break
;
case
mgb
::
DTypeEnum
::
Uint16
:
layout
.
data_type
=
LiteDataType
::
LITE_UINT16
;
break
;
case
mgb
::
DTypeEnum
::
Int8
:
layout
.
data_type
=
LiteDataType
::
LITE_INT8
;
break
;
...
...
lite/src/mge/network_impl.cpp
浏览文件 @
565466c2
...
...
@@ -442,6 +442,24 @@ void NetworkImplDft::set_io(const NetworkIO& network_io) {
}
}
void
NetworkImplDft
::
try_infer_tensor_layout
(
std
::
shared_ptr
<
Tensor
>
tensor
,
mgb
::
cg
::
SymbolVar
var
)
{
auto
&&
static_infer_mgr
=
m_load_config
.
comp_graph
->
static_infer_manager
();
auto
infer_trait
=
var
.
node
()
->
get_static_infer_trait
();
if
(
std
::
get
<
0
>
(
infer_trait
))
{
auto
shape
=
static_infer_mgr
.
infer_shape_fallible
(
var
.
node
());
if
(
!
shape
)
{
LITE_WARN
(
"Lite infer output shape failed, maybe the model is "
"dynamic "
"shape.
\n
"
);
return
;
}
Layout
layout
=
to_lite_layout
(
mgb
::
TensorLayout
{
*
shape
,
var
.
dtype
()});
tensor
->
set_layout
(
layout
);
}
}
void
NetworkImplDft
::
update_io
()
{
update_input
();
update_output
();
...
...
@@ -564,6 +582,14 @@ void NetworkImplDft::update_output() {
out_it
->
lite_tensor
=
std
::
make_shared
<
Tensor
>
(
device_id
,
stream_id
,
device_type
);
}
mgb
::
SymbolVar
var
;
for
(
auto
&&
out_var
:
m_load_result
.
output_var_list
)
{
if
(
out_var
.
node
()
->
name
()
==
out_it
->
name
)
{
var
=
out_var
;
break
;
}
}
try_infer_tensor_layout
(
out_it
->
lite_tensor
,
var
);
}
//! user not set, use default output
}
else
{
...
...
@@ -579,12 +605,14 @@ void NetworkImplDft::update_output() {
it
->
lite_tensor
=
std
::
make_shared
<
Tensor
>
(
device_id
,
stream_id
,
device_type
);
}
try_infer_tensor_layout
(
it
->
lite_tensor
,
out
);
}
else
{
IOInner
output
;
output
.
name
=
out
.
node
()
->
name
();
output
.
lite_tensor
=
std
::
make_shared
<
Tensor
>
(
device_id
,
stream_id
,
device_type
,
true
);
m_network_io
->
outputs
.
push_back
({
output
});
try_infer_tensor_layout
(
output
.
lite_tensor
,
out
);
}
}
}
...
...
lite/src/mge/network_impl.h
浏览文件 @
565466c2
...
...
@@ -201,6 +201,10 @@ private:
//! compile the graph to get the execute function
void
compile_graph
();
//! try to infer output tensor layout
void
try_infer_tensor_layout
(
std
::
shared_ptr
<
Tensor
>
tensor
,
mgb
::
cg
::
SymbolVar
var
);
private:
bool
m_async
=
false
;
bool
m_is_cpu_inplace_mode
=
false
;
...
...
lite/src/misc.cpp
浏览文件 @
565466c2
...
...
@@ -102,6 +102,8 @@ LiteLogLevel lite::get_log_level() {
}
std
::
string
lite
::
ssprintf
(
const
char
*
format
,
...)
{
if
(
!
format
)
return
""
;
va_list
ap
;
va_start
(
ap
,
format
);
auto
ret
=
svsprintf
(
format
,
ap
);
...
...
@@ -110,6 +112,8 @@ std::string lite::ssprintf(const char* format, ...) {
}
void
lite
::
print_log
(
LiteLogLevel
level
,
const
char
*
format
,
...)
{
if
(
!
format
)
return
;
if
(
static_cast
<
uint32_t
>
(
level
)
<
static_cast
<
uint32_t
>
(
get_log_level
()))
{
return
;
}
...
...
lite/test/test_network.cpp
浏览文件 @
565466c2
...
...
@@ -90,6 +90,11 @@ TEST(TestNetWork, GetAllName) {
auto
input_names
=
network
->
get_all_input_name
();
auto
output_names
=
network
->
get_all_output_name
();
auto
output_tensor
=
network
->
get_output_tensor
(
0
);
auto
out_layout
=
output_tensor
->
get_layout
();
ASSERT_EQ
(
out_layout
.
ndim
,
2
);
ASSERT_EQ
(
out_layout
.
shapes
[
0
],
1
);
ASSERT_EQ
(
out_layout
.
shapes
[
1
],
1000
);
ASSERT_EQ
(
input_names
.
size
(),
1
);
ASSERT_EQ
(
output_names
.
size
(),
1
);
ASSERT_TRUE
(
input_names
[
0
]
==
"data"
);
...
...
src/core/include/megbrain/graph/var_node.h
浏览文件 @
565466c2
...
...
@@ -488,6 +488,13 @@ public:
*/
MemAllocPlan
&
init_mem_plan
(
const
DeviceTensorND
*
fixed_alloc
=
nullptr
);
/*!
* \brief get the shape and value infer trait
*/
const
std
::
tuple
<
void
*
,
void
*>&
get_static_infer_trait
()
{
return
m_static_infer_trait
;
}
private:
//! whether its memory should be allocated by mgb system during graph
//! execution; initialized in VarNodeMemManager::reset_opr_seq()
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录