MegEngine · Commit c361b193
Authored on Oct 28, 2021 by Megvii Engine Team
feat(lite-c): add lite C callback with user_data API
GitOrigin-RevId: a54237488fb5f394ddf06bb5fed6a547a1d2e931
Parent: 7fa5f6f4
Showing 6 changed files with 375 additions and 3 deletions (+375, -3)
lite/example/cpp_example/example.h       +2    -1
lite/example/cpp_example/main.cpp        +2    -0
lite/example/cpp_example/mge/basic.cpp   +136  -0
lite/lite-c/include/lite-c/network_c.h   +50   -0
lite/lite-c/src/network.cpp              +72   -0
lite/test/test_network_c.cpp             +113  -2
lite/example/cpp_example/example.h

@@ -67,7 +67,8 @@ bool config_user_allocator(const Args& args);
bool register_cryption_method(const Args& args);
bool update_cryption_key(const Args& args);
bool async_forward(const Args& args);
bool set_input_callback(const Args& arg);
bool set_output_callback(const Args& arg);
#if LITE_WITH_CUDA
bool device_input(const Args& args);
bool device_input_output(const Args& args);
lite/example/cpp_example/main.cpp

@@ -160,6 +160,8 @@ REGIST_EXAMPLE("reset_input", reset_input);
REGIST_EXAMPLE("reset_input_output", reset_input_output);
REGIST_EXAMPLE("config_user_allocator", config_user_allocator);
REGIST_EXAMPLE("async_forward", async_forward);
REGIST_EXAMPLE("set_input_callback", set_input_callback);
REGIST_EXAMPLE("set_output_callback", set_output_callback);
REGIST_EXAMPLE("basic_c_interface", basic_c_interface);
REGIST_EXAMPLE("device_io_c_interface", device_io_c_interface);
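The REGIST_EXAMPLE macro itself is defined elsewhere in the example framework and is not part of this diff. As a rough illustration only, a registry macro of this kind is typically backed by a static map from example name to function; the sketch below uses entirely hypothetical names (ExampleRegistrar, REGIST_EXAMPLE_SKETCH) and should not be read as MegEngine's actual implementation.

```cpp
// Hypothetical sketch of a name -> function registry in the style of
// REGIST_EXAMPLE; none of these identifiers come from the MegEngine sources.
#include <functional>
#include <map>
#include <string>

struct Args;  // forward declaration, as used by the example functions

using ExampleFunc = std::function<bool(const Args&)>;

inline std::map<std::string, ExampleFunc>& example_registry() {
    static std::map<std::string, ExampleFunc> registry;
    return registry;
}

struct ExampleRegistrar {
    ExampleRegistrar(const std::string& name, ExampleFunc func) {
        example_registry()[name] = std::move(func);
    }
};

// A macro in this style lets each example self-register at static-init time.
#define REGIST_EXAMPLE_SKETCH(name, func) \
    static ExampleRegistrar registrar_##func(name, func)
```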
lite/example/cpp_example/mge/basic.cpp

@@ -365,6 +365,142 @@ bool lite::example::async_forward(const Args& args) {
    printf("max=%e, sum=%e\n", max, sum);
    return true;
}

bool lite::example::set_input_callback(const Args& args) {
    std::string network_path = args.model_path;
    std::string input_path = args.input_path;
    Config config;
    config.options.var_sanity_check_first_run = false;

    //! create and load the network
    std::shared_ptr<Network> network = std::make_shared<Network>(config);
    network->load_model(network_path);

    //! set input data to input tensor
    std::shared_ptr<Tensor> input_tensor = network->get_input_tensor(0);
    //! copy or forward data to network
    size_t length = input_tensor->get_tensor_total_size_in_byte();
    void* dst_ptr = input_tensor->get_memory_ptr();
    auto src_tensor = parse_npy(input_path);
    void* src = src_tensor->get_memory_ptr();
    memcpy(dst_ptr, src, length);

    //! set input callback
    volatile bool finished = false;
    network->set_start_callback(
            [&finished](const std::unordered_map<
                        std::string, std::pair<IO, std::shared_ptr<Tensor>>>& inputs) {
#if !__DEPLOY_ON_XP_SP2__
                std::cout << "worker thread_id:" << std::this_thread::get_id()
                          << std::endl;
#endif
                for (auto&& item : inputs) {
                    std::cout << "input name: " << item.first
                              << "input dim: " << item.second.second->get_layout().ndim
                              << std::endl;
                }
                finished = true;
            });
#if !__DEPLOY_ON_XP_SP2__
    std::cout << "out thread_id:" << std::this_thread::get_id() << std::endl;
#endif

    //! forward
    network->forward();
    size_t count = 0;
    while (finished == false) {
        count++;
    }
    printf("Forward finish, count is %zu\n", count);

    //! get the output data or read tensor set in network_in
    std::shared_ptr<Tensor> output_tensor = network->get_output_tensor(0);
    void* out_data = output_tensor->get_memory_ptr();
    size_t out_length = output_tensor->get_tensor_total_size_in_byte() /
                        output_tensor->get_layout().get_elem_size();
    printf("length=%zu\n", length);
    float max = -1.0f;
    float sum = 0.0f;
    for (size_t i = 0; i < out_length; i++) {
        float data = static_cast<float*>(out_data)[i];
        sum += data;
        if (max < data)
            max = data;
    }
    printf("max=%e, sum=%e\n", max, sum);
    return true;
}

bool lite::example::set_output_callback(const Args& args) {
    std::string network_path = args.model_path;
    std::string input_path = args.input_path;
    Config config;
    config.options.var_sanity_check_first_run = false;

    //! create and load the network
    std::shared_ptr<Network> network = std::make_shared<Network>(config);
    network->load_model(network_path);

    //! set input data to input tensor
    std::shared_ptr<Tensor> input_tensor = network->get_output_tensor(0);
    //! copy or forward data to network
    size_t length = input_tensor->get_tensor_total_size_in_byte();
    void* dst_ptr = input_tensor->get_memory_ptr();
    auto src_tensor = parse_npy(input_path);
    void* src = src_tensor->get_memory_ptr();
    memcpy(dst_ptr, src, length);

    //! set output callback
    volatile bool finished = false;
    network->set_finish_callback(
            [&finished](const std::unordered_map<
                        std::string, std::pair<IO, std::shared_ptr<Tensor>>>& outputs) {
#if !__DEPLOY_ON_XP_SP2__
                std::cout << "worker thread_id:" << std::this_thread::get_id()
                          << std::endl;
#endif
                for (auto&& item : outputs) {
                    std::cout << "output name: " << item.first
                              << "output dim: " << item.second.second->get_layout().ndim
                              << std::endl;
                }
                finished = true;
            });
#if !__DEPLOY_ON_XP_SP2__
    std::cout << "out thread_id:" << std::this_thread::get_id() << std::endl;
#endif

    //! forward
    network->forward();
    network->wait();
    size_t count = 0;
    while (finished == false) {
        count++;
    }
    printf("Forward finish, count is %zu\n", count);

    //! get the output data or read tensor set in network_in
    std::shared_ptr<Tensor> output_tensor = network->get_output_tensor(0);
    void* out_data = output_tensor->get_memory_ptr();
    size_t out_length = output_tensor->get_tensor_total_size_in_byte() /
                        output_tensor->get_layout().get_elem_size();
    printf("length=%zu\n", length);
    float max = -1.0f;
    float sum = 0.0f;
    for (size_t i = 0; i < out_length; i++) {
        float data = static_cast<float*>(out_data)[i];
        sum += data;
        if (max < data)
            max = data;
    }
    printf("max=%e, sum=%e\n", max, sum);
    return true;
}
#endif

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
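Both new examples spin on a volatile bool that the callback sets from the worker thread. In C++, volatile does not provide cross-thread visibility or ordering guarantees, so if you adapt this pattern outside the example, std::atomic<bool> is the portable way to express the same wait. A minimal sketch, not part of the commit:

```cpp
// Sketch only: the same "wait until the callback ran" logic written with
// std::atomic<bool>, which guarantees the flag update is seen across threads.
#include <atomic>
#include <thread>

std::atomic<bool> finished{false};

void worker_side() {
    // what the start/finish callback would do on the worker thread
    finished.store(true, std::memory_order_release);
}

void caller_side() {
    // replaces `while (finished == false) count++;`
    while (!finished.load(std::memory_order_acquire)) {
        std::this_thread::yield();  // avoid burning a full core while spinning
    }
}
```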
lite/lite-c/include/lite-c/network_c.h

@@ -184,6 +184,8 @@ typedef int (*LiteThreadAffinityCallback)(int thread_id);
typedef int (*LiteAsyncCallback)();

typedef int (*LiteAsyncCallbackWithData)(void* user_data);

/*!
 * \brief the start/finish callback function
 * \param unordered_map map from the io tensor name to the pair of which is the

@@ -193,9 +195,17 @@ typedef int (*LiteAsyncCallback)();
typedef int (*LiteStartCallback)(
        const LiteIO* inputs, const LiteTensor* input_tensors, size_t size);

typedef int (*LiteStartCallbackWithData)(
        const LiteIO* inputs, const LiteTensor* input_tensors, size_t size,
        void* user_data);

typedef int (*LiteFinishCallback)(
        const LiteIO* outputs, const LiteTensor* output_tensors, size_t size);

typedef int (*LiteFinishCallbackWithData)(
        const LiteIO* outputs, const LiteTensor* output_tensors, size_t size,
        void* user_data);

/*!
 * \brief The network is constructed from a model; it implements model load, init,
 * forward, and displays some model information

@@ -442,6 +452,19 @@ LITE_API int LITE_set_network_algo_workspace_limit(
LITE_API int LITE_set_async_callback(
        LiteNetwork network, const LiteAsyncCallback async_callback);

/**
 * \brief set the network forward in async mode and set the async callback
 * function
 * \param[in] network The loaded model
 * \param[in] async_callback when the network finishes forwarding, the callback
 * will be called
 * \param[in] user_data user-defined data the user wants available at the
 * forward-finish stage
 */
LITE_API int LITE_set_async_callback_with_userdata(
        LiteNetwork network, const LiteAsyncCallbackWithData async_callback,
        void* user_data);

/**
 * \brief set the start forward callback function, which will be executed before
 * forward; this can be used to check network input or dump model inputs

@@ -453,6 +476,20 @@ LITE_API int LITE_set_async_callback(
LITE_API int LITE_set_start_callback(
        LiteNetwork network, const LiteStartCallback start_callback);

/**
 * \brief set the start forward callback function, which will be executed before
 * forward; this can be used to check network input or dump model inputs
 * for debug
 * \param[in] network The loaded model
 * \param[in] start_callback when the network starts forwarding, the callback
 * will be called
 * \param[in] user_data user-defined data the user wants available at the
 * forward-start stage
 */
LITE_API int LITE_set_start_callback_with_userdata(
        LiteNetwork network, const LiteStartCallbackWithData start_callback,
        void* user_data);

/**
 * \brief set the finish forward callback function, which will be executed after
 * forward; this can be used to dump model outputs for debug

@@ -463,6 +500,19 @@ LITE_API int LITE_set_start_callback(
LITE_API int LITE_set_finish_callback(
        LiteNetwork network, const LiteFinishCallback finish_callback);

/**
 * \brief set the finish forward callback function, which will be executed after
 * forward; this can be used to dump model outputs for debug
 * \param[in] network The loaded model
 * \param[in] finish_callback when the network finishes forwarding, the callback
 * will be called
 * \param[in] user_data user-defined data the user wants available at the
 * finish stage
 */
LITE_API int LITE_set_finish_callback_with_userdata(
        LiteNetwork network, const LiteFinishCallbackWithData finish_callback,
        void* user_data);

/**
 * \brief set threads affinity callback
 * \param[in] network The loaded model
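Putting the new declarations together: a caller registers a plain C function plus an opaque pointer, and that pointer is handed back verbatim on every invocation. The sketch below is a minimal usage outline, not code from this commit. LITE_make_network, LITE_forward, LITE_destroy_network, default_config and default_network_io all appear in the tests of this diff; LITE_load_model_from_path and LITE_wait are assumed from the wider lite-c API; my_start_cb, run_once and frame_id are made-up names.

```cpp
// Minimal sketch of wiring the new *_with_userdata C API end to end.
// The model path is a placeholder; real code should check the return code of
// every LITE_* call (e.g. with LITE_CAPI_CHECK as the tests do).
#include <cstdio>
#include "lite-c/network_c.h"

static int my_start_cb(
        const LiteIO* inputs, const LiteTensor* input_tensors, size_t size,
        void* user_data) {
    // user_data arrives untouched, exactly as passed at registration time
    int frame_id = *static_cast<int*>(user_data);
    std::printf("start forward, frame %d, %zu input(s)\n", frame_id, size);
    (void)inputs;
    (void)input_tensors;
    return 0;
}

int run_once(const char* model_path) {
    LiteNetwork network;
    LITE_make_network(&network, *default_config(), *default_network_io());
    LITE_load_model_from_path(network, model_path);

    int frame_id = 42;  // any state the callback should see
    LITE_set_start_callback_with_userdata(network, my_start_cb, &frame_id);

    LITE_forward(network);
    LITE_wait(network);
    LITE_destroy_network(network);
    return 0;
}
```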
lite/lite-c/src/network.cpp

@@ -355,6 +355,22 @@ int LITE_set_async_callback(
    LITE_CAPI_END();
}

int LITE_set_async_callback_with_userdata(
        LiteNetwork network, LiteAsyncCallbackWithData async_callback,
        void* user_data) {
    LITE_CAPI_BEGIN();
    LITE_ASSERT(network, "The network pass to LITE api is null");
    LITE_ASSERT(async_callback, "The ptr pass to LITE api is null");
    auto lite_async_callback = [async_callback, user_data]() -> void {
        async_callback(user_data);
    };
    static_cast<lite::Network*>(network)->set_async_callback(
            std::move(lite_async_callback));
    LITE_CAPI_END();
}

int LITE_set_start_callback(
        LiteNetwork network, const LiteStartCallback start_callback) {
    LITE_CAPI_BEGIN();

@@ -381,6 +397,34 @@ int LITE_set_start_callback(
    LITE_CAPI_END();
}

int LITE_set_start_callback_with_userdata(
        LiteNetwork network, const LiteStartCallbackWithData start_callback,
        void* user_data) {
    LITE_CAPI_BEGIN();
    LITE_ASSERT(network, "The network pass to LITE api is null");
    auto lite_start_callback =
            [start_callback, user_data](
                    const std::unordered_map<
                            std::string,
                            std::pair<lite::IO, std::shared_ptr<lite::Tensor>>>&
                            inputs_map) -> void {
        std::vector<LiteIO> ios;
        std::vector<LiteTensor> io_tensors;
        size_t nr_io = 0;
        for (const auto& io : inputs_map) {
            nr_io++;
            auto&& lite_io = io.second.first;
            ios.push_back(
                    {lite_io.name.c_str(), lite_io.is_host, lite_io.io_type,
                     convert_to_clayout(lite_io.config_layout)});
            io_tensors.push_back(io.second.second.get());
        }
        start_callback(ios.data(), io_tensors.data(), nr_io, user_data);
    };
    static_cast<lite::Network*>(network)->set_start_callback(lite_start_callback);
    LITE_CAPI_END();
}

int LITE_set_finish_callback(
        LiteNetwork network, const LiteFinishCallback finish_callback) {
    LITE_CAPI_BEGIN();

@@ -407,6 +451,34 @@ int LITE_set_finish_callback(
    LITE_CAPI_END();
}

int LITE_set_finish_callback_with_userdata(
        LiteNetwork network, const LiteFinishCallbackWithData finish_callback,
        void* user_data) {
    LITE_CAPI_BEGIN();
    LITE_ASSERT(network, "The network pass to LITE api is null");
    auto lite_finish_callback =
            [finish_callback, user_data](
                    const std::unordered_map<
                            std::string,
                            std::pair<lite::IO, std::shared_ptr<lite::Tensor>>>&
                            outputs_map) -> void {
        std::vector<LiteIO> ios;
        std::vector<LiteTensor> io_tensors;
        size_t nr_io = 0;
        for (const auto& io : outputs_map) {
            nr_io++;
            auto&& lite_io = io.second.first;
            ios.push_back(
                    {lite_io.name.c_str(), lite_io.is_host, lite_io.io_type,
                     convert_to_clayout(lite_io.config_layout)});
            io_tensors.push_back(io.second.second.get());
        }
        finish_callback(ios.data(), io_tensors.data(), nr_io, user_data);
    };
    static_cast<lite::Network*>(network)->set_finish_callback(lite_finish_callback);
    LITE_CAPI_END();
}

int LITE_enable_profile_performance(
        LiteNetwork network, const char* profile_json_file_path) {
    LITE_CAPI_BEGIN();
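All three new entry points follow the same trampoline pattern: the C-style pair of function pointer and void* user_data is captured by value in a C++ lambda, which can then be stored wherever lite::Network expects a std::function-style callback. A stripped-down sketch of just that pattern (the names here are illustrative, not from the sources):

```cpp
// Generic sketch of the pattern used above: capture (function pointer,
// void* user_data) in a lambda so it fits a std::function callback slot.
#include <functional>

using c_callback_t = int (*)(void* user_data);  // C-facing signature
using cpp_callback_t = std::function<void()>;   // C++-facing storage

cpp_callback_t wrap_c_callback(c_callback_t cb, void* user_data) {
    // Both are captured by value; the pointee behind user_data must stay
    // valid for as long as the callback can still fire (the caller owns that).
    return [cb, user_data]() { cb(user_data); };
}
```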
lite/test/test_network_c.cpp

@@ -74,11 +74,21 @@ int multi_thread_affinity(int id) {
};

volatile bool finished = false;
-int finish_callback() {
+int async_callback() {
    finished = true;
    return 0;
}

volatile bool finished_with_data = false;
int async_callback_with_data(void* user_data) {
    if (user_data != NULL) {
        std::cout << "async_callback user_data addr=" << std::hex << user_data
                  << std::endl;
    }
    finished_with_data = true;
    return 0;
}

volatile bool start_checked = false;
int start_callback(const LiteIO* inputs, const LiteTensor* input_tensors, size_t size) {
    start_checked = true;

@@ -96,6 +106,29 @@ int start_callback(const LiteIO* inputs, const LiteTensor* input_tensors, size_t
    return 0;
}

volatile bool start_checked_with_data = false;
int start_callback_with_data(
        const LiteIO* inputs, const LiteTensor* input_tensors, size_t size,
        void* user_data) {
    start_checked_with_data = true;
    auto check_func = [&]() {
        if (user_data != NULL) {
            std::cout << "start_callback user_data addr=" << std::hex << user_data
                      << std::endl;
        }
        ASSERT_EQ(size, 1);
        ASSERT_EQ(std::string(inputs->name), "data");
        LiteLayout layout;
        LITE_get_tensor_layout(*input_tensors, &layout);
        ASSERT_EQ(layout.ndim, 4);
        ASSERT_EQ(layout.shapes[1], 3);
        ASSERT_EQ(layout.shapes[2], 224);
        ASSERT_EQ(layout.shapes[3], 224);
    };
    check_func();
    return 0;
}

volatile bool finish_checked = false;
int finish_callback(
        const LiteIO* outputs, const LiteTensor* output_tensors, size_t size) {

@@ -113,6 +146,28 @@ int finish_callback(
    return 0;
}

volatile bool finish_checked_with_data = false;
int finish_callback_with_data(
        const LiteIO* outputs, const LiteTensor* output_tensors, size_t size,
        void* user_data) {
    finish_checked_with_data = true;
    auto check_func = [&]() {
        if (user_data != NULL) {
            std::cout << "finish_callback user_data addr=" << std::hex << user_data
                      << std::endl;
        }
        ASSERT_EQ(size, 1);
        ASSERT_EQ(
                std::string(outputs->name),
                "TRUE_DIV(EXP[12065],reduce0[12067])[12077]");
        LiteLayout layout;
        LITE_get_tensor_layout(*output_tensors, &layout);
        ASSERT_EQ(layout.shapes[1], 1000);
    };
    check_func();
    return 0;
}
}  // namespace

#define LITE_CAPI_CHECK(_expr) \

@@ -671,6 +726,21 @@ TEST(TestCapiNetWork, StartCallBack) {
    LITE_CAPI_CHECK(LITE_destroy_network(c_network));
}

TEST(TestCapiNetWork, StartCallBackWithData) {
    ForwardMgb;
    MakeNetwork;
    LoadNetwork;
    size_t user_data = 1;
    LITE_CAPI_CHECK(LITE_set_start_callback_with_userdata(
            c_network, start_callback_with_data, &user_data));
    SetInput;
    ForwardNetwork;
    GetOutput;
    CompareResult;
    ASSERT_TRUE(start_checked_with_data);
    LITE_CAPI_CHECK(LITE_destroy_network(c_network));
}

TEST(TestCapiNetWork, FinishCallBack) {
    ForwardMgb;
    MakeNetwork;

@@ -684,6 +754,21 @@ TEST(TestCapiNetWork, FinishCallBack) {
    LITE_CAPI_CHECK(LITE_destroy_network(c_network));
}

TEST(TestCapiNetWork, FinishCallBackWtihData) {
    ForwardMgb;
    MakeNetwork;
    LoadNetwork;
    size_t user_data = 1;
    LITE_CAPI_CHECK(LITE_set_finish_callback_with_userdata(
            c_network, finish_callback_with_data, &user_data));
    SetInput;
    ForwardNetwork;
    GetOutput;
    CompareResult;
    ASSERT_TRUE(finish_checked_with_data);
    LITE_CAPI_CHECK(LITE_destroy_network(c_network));
}

TEST(TestCapiNetWork, BasicCryptAes) {
    ForwardMgb;

@@ -723,7 +808,7 @@ TEST(TestCapiNetWork, AsyncExec) {
    LiteConfig c_config = *default_config();
    c_config.options.var_sanity_check_first_run = false;
    LITE_CAPI_CHECK(LITE_make_network(&c_network, c_config, *default_network_io()));
-   LITE_CAPI_CHECK(LITE_set_async_callback(c_network, finish_callback));
+   LITE_CAPI_CHECK(LITE_set_async_callback(c_network, async_callback));
    LoadNetwork;
    SetInput;

@@ -740,6 +825,32 @@ TEST(TestCapiNetWork, AsyncExec) {
    LITE_CAPI_CHECK(LITE_destroy_network(c_network));
}

TEST(TestCapiNetWork, AsyncExecWithData) {
    finished = false;
    ForwardMgb;
    LiteNetwork c_network;
    LiteConfig c_config = *default_config();
    c_config.options.var_sanity_check_first_run = false;
    LITE_CAPI_CHECK(LITE_make_network(&c_network, c_config, *default_network_io()));
    size_t user_data = 1;
    LITE_CAPI_CHECK(LITE_set_async_callback_with_userdata(
            c_network, async_callback_with_data, &user_data));
    LoadNetwork;
    SetInput;
    LITE_forward(c_network);
    size_t count = 0;
    while (finished_with_data == false) {
        count++;
    }
    ASSERT_GT(count, 0);
    finished_with_data = false;
    GetOutput;
    CompareResult;
    LITE_CAPI_CHECK(LITE_destroy_network(c_network));
}

TEST(TestCapiNetWork, OutputShapeOnly) {
    ForwardMgb;
    LiteNetwork c_network;
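The tests pass a size_t that lives on the test's stack as user_data, and the callback only ever sees it as a void*, so the callback must cast it back to the original type and the pointee must stay alive until the callback can no longer run. A small self-contained sketch of that round trip (my_async_cb and example are illustrative names, not from the test file):

```cpp
// Sketch: recovering typed state from the void* handed back by the callback.
#include <cstddef>
#include <cstdio>

static int my_async_cb(void* user_data) {
    // cast back to the exact type that was passed at registration time
    size_t value = *static_cast<size_t*>(user_data);
    std::printf("callback saw user_data = %zu\n", value);
    return 0;
}

void example() {
    size_t user_data = 1;  // same shape as in the tests above
    // LITE_set_async_callback_with_userdata(c_network, my_async_cb, &user_data);
    my_async_cb(&user_data);  // direct call just to show the round trip
}
```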