Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
26b52a61
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
410
Star
4707
Fork
583
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
26b52a61
编写于
1月 24, 2022
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(lite): add get model infomation before create network interface
GitOrigin-RevId: e499f3ebf8e03ccbe25e9b698c9e351fd19f0ed6
上级
5e17b3e4
变更
12
隐藏空白更改
内联
并排
Showing
12 changed file
with
379 addition
and
3 deletion
+379
-3
lite/include/lite/network.h
lite/include/lite/network.h
+8
-0
lite/lite-c/include/lite-c/network_c.h
lite/lite-c/include/lite-c/network_c.h
+22
-0
lite/lite-c/src/network.cpp
lite/lite-c/src/network.cpp
+80
-0
lite/pylite/megenginelite/network.py
lite/pylite/megenginelite/network.py
+32
-0
lite/pylite/test/test_utils.py
lite/pylite/test/test_utils.py
+18
-0
lite/src/function_base.h
lite/src/function_base.h
+3
-3
lite/src/mge/function_dft.h
lite/src/mge/function_dft.h
+20
-0
lite/src/mge/network_impl.cpp
lite/src/mge/network_impl.cpp
+70
-0
lite/src/mge/network_impl.h
lite/src/mge/network_impl.h
+7
-0
lite/src/network.cpp
lite/src/network.cpp
+22
-0
lite/test/test_network.cpp
lite/test/test_network.cpp
+48
-0
lite/test/test_network_c.cpp
lite/test/test_network_c.cpp
+49
-0
未找到文件。
lite/include/lite/network.h
浏览文件 @
26b52a61
...
...
@@ -373,6 +373,14 @@ public:
//! dump network after global layout transform optimization
static
void
dump_layout_transform_model
(
std
::
shared_ptr
<
Network
>
network
,
std
::
string
optimized_model_path
);
//! get the model io information before model loaded by model path.
static
NetworkIO
get_model_io_info
(
const
std
::
string
&
model_path
,
const
Config
&
config
=
{});
//! get the model io information before model loaded by model memory.
static
NetworkIO
get_model_io_info
(
const
void
*
model_mem
,
size_t
size
,
const
Config
&
config
=
{});
};
}
// namespace lite
...
...
lite/lite-c/include/lite-c/network_c.h
浏览文件 @
26b52a61
...
...
@@ -588,6 +588,28 @@ LITE_API int LITE_enable_global_layout_transform(LiteNetwork network);
LITE_API
int
LITE_dump_layout_transform_model
(
LiteNetwork
network
,
const
char
*
dump_file_path
);
/**! get the model io information before model loaded by model path.
* \param[in] model_path The model file path
* \param[in] config The model config for loading
* \param[out] ios The model io infermation
* \return int if the return is not zero, error happened, the error message
* can get by LITE_get_last_error
*/
LITE_API
int
LITE_get_model_io_info_by_path
(
const
char
*
model_path
,
const
LiteConfig
config
,
LiteNetworkIO
*
ios
);
/** get the model io information before model loaded by model memory.
* \param[in] model_mem The model memory ptr
* \param[in] size The model memory ptr length
* \param[in] config The model config for loading
* \param[out] ios The model io infermation
* \return int if the return is not zero, error happened, the error message
* can get by LITE_get_last_error
*/
LITE_API
int
LITE_get_model_io_info_by_memory
(
const
void
*
model_mem
,
size_t
size
,
const
LiteConfig
config
,
LiteNetworkIO
*
ios
);
#ifdef __cplusplus
}
#endif
...
...
lite/lite-c/src/network.cpp
浏览文件 @
26b52a61
...
...
@@ -167,6 +167,31 @@ lite::NetworkIO convert_to_lite_io(const LiteNetworkIO c_network_io) {
return
network_io
;
}
struct
InnerIO
{
std
::
vector
<
std
::
string
>
names
;
std
::
vector
<
LiteIO
>
inputs
;
std
::
vector
<
LiteIO
>
outputs
;
};
InnerIO
convert_to_inner_io
(
const
lite
::
NetworkIO
&
network_io
)
{
InnerIO
innner_io
;
for
(
size_t
i
=
0
;
i
<
network_io
.
inputs
.
size
();
i
++
)
{
lite
::
IO
io
=
network_io
.
inputs
[
i
];
innner_io
.
names
.
push_back
(
io
.
name
);
innner_io
.
inputs
.
push_back
(
{
innner_io
.
names
.
back
().
c_str
(),
io
.
is_host
,
io
.
io_type
,
convert_to_clayout
(
io
.
config_layout
)});
}
for
(
size_t
i
=
0
;
i
<
network_io
.
outputs
.
size
();
i
++
)
{
lite
::
IO
io
=
network_io
.
outputs
[
i
];
innner_io
.
names
.
push_back
(
io
.
name
);
innner_io
.
outputs
.
push_back
(
{
innner_io
.
names
.
back
().
c_str
(),
io
.
is_host
,
io
.
io_type
,
convert_to_clayout
(
io
.
config_layout
)});
}
return
innner_io
;
}
int
LITE_make_default_network
(
LiteNetwork
*
network
)
{
LITE_CAPI_BEGIN
();
LITE_ASSERT
(
network
,
"The network pass to LITE api is null"
);
...
...
@@ -665,4 +690,59 @@ int LITE_dump_layout_transform_model(LiteNetwork network, const char* dump_file_
lite
::
Runtime
::
dump_layout_transform_model
(
network_shared
,
dump_file_path
);
LITE_CAPI_END
();
}
namespace
{
static
LITE_MUTEX
mtx_io
;
static
std
::
unordered_map
<
const
void
*
,
InnerIO
>&
get_global_io_holder
()
{
static
std
::
unordered_map
<
const
void
*
,
InnerIO
>
global_holder
;
return
global_holder
;
}
int
write_ios_from_cpp_io
(
const
lite
::
NetworkIO
&
cpp_io
,
LiteNetworkIO
*
ios
,
const
void
*
key
)
{
LITE_CAPI_BEGIN
();
LITE_LOCK_GUARD
(
mtx_io
);
get_global_io_holder
()[
key
]
=
convert_to_inner_io
(
cpp_io
);
auto
&&
inner_io
=
get_global_io_holder
()[
key
];
ios
->
input_size
=
inner_io
.
inputs
.
size
();
ios
->
output_size
=
inner_io
.
outputs
.
size
();
ios
->
inputs
=
inner_io
.
inputs
.
data
();
ios
->
outputs
=
inner_io
.
outputs
.
data
();
size_t
i
=
0
;
for
(;
i
<
ios
->
input_size
;
i
++
)
{
auto
io_ptr
=
ios
->
inputs
+
i
;
io_ptr
->
name
=
inner_io
.
names
[
i
].
c_str
();
}
for
(;
i
<
ios
->
output_size
;
i
++
)
{
auto
io_ptr
=
ios
->
outputs
+
i
;
io_ptr
->
name
=
inner_io
.
names
[
i
].
c_str
();
}
LITE_CAPI_END
();
}
}
// namespace
int
LITE_get_model_io_info_by_path
(
const
char
*
model_path
,
const
LiteConfig
config
,
LiteNetworkIO
*
ios
)
{
LITE_CAPI_BEGIN
();
LITE_ASSERT
(
model_path
,
"The model_path pass to LITE api is null"
);
auto
&&
cpp_ios
=
lite
::
Runtime
::
get_model_io_info
(
std
::
string
{
model_path
},
convert_to_lite_config
(
config
));
return
write_ios_from_cpp_io
(
cpp_ios
,
ios
,
reinterpret_cast
<
const
void
*>
(
model_path
));
LITE_CAPI_END
();
}
int
LITE_get_model_io_info_by_memory
(
const
void
*
model_mem
,
size_t
size
,
const
LiteConfig
config
,
LiteNetworkIO
*
ios
)
{
LITE_CAPI_BEGIN
();
LITE_ASSERT
(
model_mem
,
"The model_mem pass to LITE api is null"
);
auto
&&
cpp_ios
=
lite
::
Runtime
::
get_model_io_info
(
model_mem
,
size
,
convert_to_lite_config
(
config
));
return
write_ios_from_cpp_io
(
cpp_ios
,
ios
,
reinterpret_cast
<
const
void
*>
(
model_mem
));
LITE_CAPI_END
();
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
lite/pylite/megenginelite/network.py
浏览文件 @
26b52a61
...
...
@@ -364,6 +364,14 @@ class _NetworkAPI(_LiteCObjBase):
(
"LITE_get_static_memory_alloc_info"
,
[
_Cnetwork
,
c_char_p
]),
(
"LITE_enable_global_layout_transform"
,
[
_Cnetwork
]),
(
"LITE_dump_layout_transform_model"
,
[
_Cnetwork
,
c_char_p
]),
(
"LITE_get_model_io_info_by_path"
,
[
c_char_p
,
LiteConfig
,
POINTER
(
_LiteNetworkIO
)],
),
(
"LITE_get_model_io_info_by_memory"
,
[
c_char_p
,
c_size_t
,
LiteConfig
,
POINTER
(
_LiteNetworkIO
)],
),
]
...
...
@@ -619,3 +627,27 @@ class LiteNetwork(object):
def
dump_layout_transform_model
(
self
,
model_file
):
c_file
=
model_file
.
encode
(
"utf-8"
)
self
.
_api
.
LITE_dump_layout_transform_model
(
self
.
_network
,
c_file
)
def
get_model_io_info
(
model_path
,
config
=
None
):
"""
get the model IO information before create the NetWork, this IO
information can be used to configuration the NetWork.
"""
api
=
_NetworkAPI
().
_lib
c_path
=
c_char_p
(
model_path
.
encode
(
"utf-8"
))
ios
=
_LiteNetworkIO
()
if
config
is
not
None
:
api
.
LITE_get_model_io_info_by_path
(
c_path
,
config
,
byref
(
ios
))
else
:
config
=
LiteConfig
()
api
.
LITE_get_model_io_info_by_path
(
c_path
,
config
,
byref
(
ios
))
ret_ios
=
LiteNetworkIO
()
for
i
in
range
(
ios
.
input_size
):
ret_ios
.
add_input
(
ios
.
inputs
[
i
])
for
i
in
range
(
ios
.
output_size
):
ret_ios
.
add_output
(
ios
.
outputs
[
i
])
return
ret_ios
lite/pylite/test/test_utils.py
浏览文件 @
26b52a61
...
...
@@ -8,6 +8,7 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import
functools
import
os
import
numpy
as
np
...
...
@@ -200,3 +201,20 @@ def test_tensor_collect_batch_device_numpy():
for
i
in
range
(
4
):
for
j
in
range
(
48
):
assert
data
[
i
][
j
//
8
][
j
%
8
]
==
i
+
1
def
test_get_model_io_ahead
():
source_dir
=
os
.
getenv
(
"LITE_TEST_RESOURCE"
)
model_path
=
os
.
path
.
join
(
source_dir
,
"shufflenet.mge"
)
ios
=
get_model_io_info
(
model_path
)
assert
len
(
ios
.
inputs
)
==
1
assert
ios
.
inputs
[
0
].
name
==
"data"
assert
ios
.
inputs
[
0
].
config_layout
.
shapes
[
1
]
==
3
assert
ios
.
inputs
[
0
].
config_layout
.
shapes
[
2
]
==
224
assert
ios
.
inputs
[
0
].
config_layout
.
shapes
[
3
]
==
224
assert
len
(
ios
.
outputs
)
==
1
assert
ios
.
outputs
[
0
].
name
==
"TRUE_DIV(EXP[12065],reduce0[12067])[12077]"
assert
ios
.
outputs
[
0
].
config_layout
.
shapes
[
0
]
==
1
assert
ios
.
outputs
[
0
].
config_layout
.
shapes
[
1
]
==
1000
lite/src/function_base.h
浏览文件 @
26b52a61
...
...
@@ -34,7 +34,7 @@ ADD_STATEMENT(NetworkImplDft, Dft);
}
// namespace
// if it can't find the function, ignore
template
<
typename
t
ensor_t
ype
,
typename
ret_type
,
typename
...
Args
>
template
<
typename
type
,
typename
ret_type
,
typename
...
Args
>
ret_type
try_call_func
(
std
::
string
func_name
,
Args
...
args
)
{
mark_used_variable
(
func_name
);
mark_used_variable
(
args
...);
...
...
@@ -42,10 +42,10 @@ ret_type try_call_func(std::string func_name, Args... args) {
}
// if it can't find the function, throw error
template
<
typename
t
ensor_t
ype
,
typename
ret_type
,
typename
...
Args
>
template
<
typename
type
,
typename
ret_type
,
typename
...
Args
>
ret_type
call_func
(
std
::
string
func_name
,
Args
...
args
)
{
mark_used_variable
(
args
...);
auto
backend_name
=
class_type_name
<
t
ensor_t
ype
>
()();
auto
backend_name
=
class_type_name
<
type
>
()();
auto
msg_info
=
func_name
+
" is not aviliable in "
+
backend_name
+
" backend."
;
LITE_THROW
(
msg_info
.
c_str
());
}
...
...
lite/src/mge/function_dft.h
浏览文件 @
26b52a61
...
...
@@ -206,6 +206,26 @@ inline void call_func<NetworkImplDft, void>(
THROW_FUNC_ERROR
(
func_name
);
}
}
template
<
>
inline
NetworkIO
call_func
<
NetworkImplDft
,
NetworkIO
>
(
std
::
string
func_name
,
std
::
string
model_path
,
Config
config
)
{
if
(
func_name
==
"get_model_io_info"
)
{
return
get_model_io_info_dft
(
model_path
,
config
);
}
else
{
THROW_FUNC_ERROR
(
func_name
);
}
}
template
<
>
inline
NetworkIO
call_func
<
NetworkImplDft
,
NetworkIO
>
(
std
::
string
func_name
,
const
void
*
model_mem
,
size_t
size
,
Config
config
)
{
if
(
func_name
==
"get_model_io_info"
)
{
return
get_model_io_info_dft
(
model_mem
,
size
,
config
);
}
else
{
THROW_FUNC_ERROR
(
func_name
);
}
}
#undef THROW_FUNC_ERROR
}
// namespace lite
...
...
lite/src/mge/network_impl.cpp
浏览文件 @
26b52a61
...
...
@@ -929,5 +929,75 @@ void NetworkImplDft::dump_layout_transform_model(std::string optimized_model_pat
"enable_global_layout_transform before"
));
}
}
NetworkIO
lite
::
get_model_io_info_dft
(
const
std
::
string
&
model_path
,
const
Config
&
config
)
{
FILE
*
fin
=
fopen
(
model_path
.
c_str
(),
"rb"
);
LITE_ASSERT
(
fin
,
"failed to open %s: %s"
,
model_path
.
c_str
(),
strerror
(
errno
));
fseek
(
fin
,
0
,
SEEK_END
);
size_t
size
=
ftell
(
fin
);
fseek
(
fin
,
0
,
SEEK_SET
);
void
*
ptr
=
malloc
(
size
);
std
::
shared_ptr
<
void
>
buf
{
ptr
,
::
free
};
auto
nr
=
fread
(
buf
.
get
(),
1
,
size
,
fin
);
LITE_ASSERT
(
nr
==
size
);
fclose
(
fin
);
return
get_model_io_info_dft
(
ptr
,
size
,
config
);
}
NetworkIO
lite
::
get_model_io_info_dft
(
const
void
*
model_mem
,
size_t
size
,
const
Config
&
config
)
{
std
::
shared_ptr
<
void
>
model
{
const_cast
<
void
*>
(
model_mem
),
[](
void
*
)
{}};
auto
input_file
=
mgb
::
serialization
::
InputFile
::
make_mem_proxy
(
model
,
size
,
false
);
auto
format
=
mgb
::
serialization
::
GraphLoader
::
identify_graph_dump_format
(
*
input_file
);
if
(
!
format
.
valid
())
{
LITE_THROW
(
"invalid model format"
);
}
auto
loader
=
mgb
::
serialization
::
GraphLoader
::
make
(
std
::
move
(
input_file
),
format
.
val
());
mgb
::
serialization
::
GraphLoadConfig
load_config
;
load_config
.
comp_graph
=
mgb
::
ComputingGraph
::
make
();
if
(
config
.
has_compression
)
{
load_config
.
tensor_value_loader
=
decompressed_tensor_value_loader
;
}
auto
compnode_locator
=
to_compnode_locator
(
config
.
device_type
);
load_config
.
comp_node_mapper
=
[
=
](
mgb
::
CompNode
::
Locator
&
loc
)
{
if
(
loc
.
type
==
mgb
::
CompNode
::
DeviceType
::
UNSPEC
)
{
loc
.
type
=
compnode_locator
.
type
;
}
loc
.
device
=
compnode_locator
.
device
;
};
auto
load_result
=
loader
->
load
(
load_config
,
true
);
NetworkIO
IOs
;
for
(
auto
&&
in_tensor_iter
:
load_result
.
tensor_map
)
{
IO
in_io
;
in_io
.
name
=
in_tensor_iter
.
first
;
in_io
.
config_layout
=
to_lite_layout
(
in_tensor_iter
.
second
->
layout
());
IOs
.
inputs
.
push_back
(
in_io
);
}
auto
infer_shape
=
[
=
](
mgb
::
cg
::
SymbolVar
var
)
->
const
megdnn
::
TensorShape
*
{
auto
&&
static_infer_mgr
=
load_config
.
comp_graph
->
static_infer_manager
();
using
InferType
=
mgb
::
cg
::
static_infer
::
InferType
;
if
(
static_infer_mgr
.
get_infer_type
(
var
.
node
()).
shape
&
(
InferType
::
CONST
|
InferType
::
RT_STATIC
))
{
return
static_infer_mgr
.
infer_shape_fallible
(
var
.
node
());
}
else
{
return
nullptr
;
}
};
for
(
auto
&&
out
:
load_result
.
output_var_list
)
{
IO
out_io
;
out_io
.
name
=
out
.
node
()
->
name
();
if
(
auto
shape
=
infer_shape
(
out
))
{
out_io
.
config_layout
=
to_lite_layout
(
TensorLayout
{
*
shape
,
out
.
dtype
()});
}
else
{
out_io
.
config_layout
=
to_lite_layout
(
TensorLayout
{{},
out
.
dtype
()});
}
IOs
.
outputs
.
push_back
(
out_io
);
}
return
IOs
;
}
#endif
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
lite/src/mge/network_impl.h
浏览文件 @
26b52a61
...
...
@@ -262,6 +262,13 @@ private:
#endif
std
::
unique_ptr
<
mgb
::
OprIODumpBase
>
m_iodump
;
};
//! get the model information before model loaded by Network
NetworkIO
get_model_io_info_dft
(
const
std
::
string
&
model_path
,
const
Config
&
config
);
//! get the model information before model loaded by Network by model memory and
//! size
NetworkIO
get_model_io_info_dft
(
const
void
*
model_mem
,
size_t
size
,
const
Config
&
config
);
}
// namespace lite
...
...
lite/src/network.cpp
浏览文件 @
26b52a61
...
...
@@ -534,4 +534,26 @@ void Runtime::dump_layout_transform_model(
LITE_THROW
(
"dump_layout_transform_model is not aviliable in the backend."
);
LITE_ERROR_HANDLER_END
}
NetworkIO
Runtime
::
get_model_io_info
(
const
std
::
string
&
model_path
,
const
Config
&
config
)
{
LITE_ERROR_HANDLER_BEGIN
if
(
config
.
backend
==
LiteBackend
::
LITE_DEFAULT
)
{
return
call_func
<
NetworkImplDft
,
NetworkIO
>
(
"get_model_io_info"
,
model_path
,
config
);
}
LITE_THROW
(
"get_model_io_info is not aviliable in the backend."
);
LITE_ERROR_HANDLER_END
}
NetworkIO
Runtime
::
get_model_io_info
(
const
void
*
model_mem
,
size_t
size
,
const
Config
&
config
)
{
LITE_ERROR_HANDLER_BEGIN
if
(
config
.
backend
==
LiteBackend
::
LITE_DEFAULT
)
{
return
call_func
<
NetworkImplDft
,
NetworkIO
>
(
"get_model_io_info"
,
model_mem
,
size
,
config
);
}
LITE_THROW
(
"get_model_io_info is not aviliable in the backend."
);
LITE_ERROR_HANDLER_END
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
lite/test/test_network.cpp
浏览文件 @
26b52a61
...
...
@@ -106,6 +106,54 @@ TEST(TestNetWork, GetAllName) {
ASSERT_TRUE
(
output_names
[
0
]
==
"TRUE_DIV(EXP[12065],reduce0[12067])[12077]"
);
}
TEST
(
TestNetWork
,
GetAllIoInfoAhead
)
{
Config
config
;
std
::
string
model_path
=
"./shufflenet.mge"
;
auto
ios
=
Runtime
::
get_model_io_info
(
model_path
);
FILE
*
fin
=
fopen
(
model_path
.
c_str
(),
"rb"
);
ASSERT_TRUE
(
fin
);
fseek
(
fin
,
0
,
SEEK_END
);
size_t
size
=
ftell
(
fin
);
fseek
(
fin
,
0
,
SEEK_SET
);
void
*
ptr
=
malloc
(
size
);
std
::
shared_ptr
<
void
>
buf
{
ptr
,
::
free
};
auto
nr
=
fread
(
buf
.
get
(),
1
,
size
,
fin
);
LITE_ASSERT
(
nr
==
size
);
fclose
(
fin
);
auto
ios_mem
=
Runtime
::
get_model_io_info
(
ptr
,
size
);
ASSERT_EQ
(
ios
.
inputs
.
size
(),
ios_mem
.
inputs
.
size
());
ASSERT_EQ
(
ios
.
inputs
.
size
(),
1
);
ASSERT_EQ
(
ios
.
outputs
.
size
(),
ios_mem
.
outputs
.
size
());
ASSERT_EQ
(
ios
.
outputs
.
size
(),
1
);
ASSERT_TRUE
(
ios
.
inputs
[
0
].
name
==
"data"
);
ASSERT_TRUE
(
ios
.
outputs
[
0
].
name
==
"TRUE_DIV(EXP[12065],reduce0[12067])[12077]"
);
ASSERT_TRUE
(
ios_mem
.
inputs
[
0
].
name
==
"data"
);
ASSERT_TRUE
(
ios_mem
.
outputs
[
0
].
name
==
"TRUE_DIV(EXP[12065],reduce0[12067])[12077]"
);
ASSERT_EQ
(
ios
.
inputs
[
0
].
config_layout
.
ndim
,
4
);
ASSERT_EQ
(
ios
.
inputs
[
0
].
config_layout
.
shapes
[
1
],
3
);
ASSERT_EQ
(
ios
.
inputs
[
0
].
config_layout
.
shapes
[
2
],
224
);
ASSERT_EQ
(
ios
.
outputs
[
0
].
config_layout
.
ndim
,
2
);
ASSERT_EQ
(
ios
.
outputs
[
0
].
config_layout
.
shapes
[
0
],
1
);
ASSERT_EQ
(
ios
.
outputs
[
0
].
config_layout
.
shapes
[
1
],
1000
);
ASSERT_EQ
(
ios_mem
.
inputs
[
0
].
config_layout
.
ndim
,
4
);
ASSERT_EQ
(
ios_mem
.
inputs
[
0
].
config_layout
.
shapes
[
1
],
3
);
ASSERT_EQ
(
ios_mem
.
inputs
[
0
].
config_layout
.
shapes
[
2
],
224
);
ASSERT_EQ
(
ios_mem
.
outputs
[
0
].
config_layout
.
ndim
,
2
);
ASSERT_EQ
(
ios_mem
.
outputs
[
0
].
config_layout
.
shapes
[
0
],
1
);
ASSERT_EQ
(
ios_mem
.
outputs
[
0
].
config_layout
.
shapes
[
1
],
1000
);
}
TEST
(
TestNetWork
,
LoadFBSModel
)
{
Config
config
;
std
::
string
model_path
=
"./ax.mge"
;
...
...
lite/test/test_network_c.cpp
浏览文件 @
26b52a61
...
...
@@ -252,6 +252,55 @@ TEST(TestCapiNetWork, GetAllName) {
LITE_destroy_network
(
c_network
);
}
TEST
(
TestCapiNetWork
,
GetAllNameAhead
)
{
std
::
string
model_path
=
"./shufflenet.mge"
;
LiteNetworkIO
ios
,
ios_mem
;
LITE_CAPI_CHECK
(
LITE_get_model_io_info_by_path
(
model_path
.
c_str
(),
*
default_config
(),
&
ios
));
FILE
*
fin
=
fopen
(
model_path
.
c_str
(),
"rb"
);
ASSERT_TRUE
(
fin
);
fseek
(
fin
,
0
,
SEEK_END
);
size_t
size
=
ftell
(
fin
);
fseek
(
fin
,
0
,
SEEK_SET
);
void
*
ptr
=
malloc
(
size
);
std
::
shared_ptr
<
void
>
buf
{
ptr
,
::
free
};
auto
nr
=
fread
(
buf
.
get
(),
1
,
size
,
fin
);
LITE_ASSERT
(
nr
==
size
);
fclose
(
fin
);
LITE_CAPI_CHECK
(
LITE_get_model_io_info_by_memory
(
ptr
,
size
,
*
default_config
(),
&
ios_mem
));
ASSERT_EQ
(
ios
.
input_size
,
1
);
ASSERT_EQ
(
ios
.
output_size
,
1
);
ASSERT_EQ
(
ios_mem
.
input_size
,
1
);
ASSERT_EQ
(
ios_mem
.
output_size
,
1
);
ASSERT_TRUE
(
std
::
string
(
ios
.
inputs
->
name
)
==
"data"
);
ASSERT_TRUE
(
ios
.
inputs
->
config_layout
.
ndim
==
4
);
ASSERT_TRUE
(
ios
.
inputs
->
config_layout
.
shapes
[
1
]
==
3
);
ASSERT_TRUE
(
ios
.
inputs
->
config_layout
.
shapes
[
2
]
==
224
);
ASSERT_TRUE
(
ios
.
inputs
->
config_layout
.
shapes
[
3
]
==
224
);
ASSERT_TRUE
(
std
::
string
(
ios
.
outputs
->
name
)
==
"TRUE_DIV(EXP[12065],reduce0[12067])[12077]"
);
ASSERT_TRUE
(
ios
.
outputs
->
config_layout
.
ndim
==
2
);
ASSERT_TRUE
(
ios
.
outputs
->
config_layout
.
shapes
[
0
]
==
1
);
ASSERT_TRUE
(
ios
.
outputs
->
config_layout
.
shapes
[
1
]
==
1000
);
ASSERT_TRUE
(
std
::
string
(
ios_mem
.
inputs
->
name
)
==
"data"
);
ASSERT_TRUE
(
ios_mem
.
inputs
->
config_layout
.
ndim
==
4
);
ASSERT_TRUE
(
ios_mem
.
inputs
->
config_layout
.
shapes
[
1
]
==
3
);
ASSERT_TRUE
(
ios_mem
.
inputs
->
config_layout
.
shapes
[
2
]
==
224
);
ASSERT_TRUE
(
ios_mem
.
inputs
->
config_layout
.
shapes
[
3
]
==
224
);
ASSERT_TRUE
(
std
::
string
(
ios_mem
.
outputs
->
name
)
==
"TRUE_DIV(EXP[12065],reduce0[12067])[12077]"
);
ASSERT_TRUE
(
ios_mem
.
outputs
->
config_layout
.
ndim
==
2
);
ASSERT_TRUE
(
ios_mem
.
outputs
->
config_layout
.
shapes
[
0
]
==
1
);
ASSERT_TRUE
(
ios_mem
.
outputs
->
config_layout
.
shapes
[
1
]
==
1000
);
}
#if LITE_BUILD_WITH_RKNPU
static
int
GetTop
(
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录