提交 c361b193 编写于 作者: M Megvii Engine Team

feat(lite-c): add lite C callback with user_data API

GitOrigin-RevId: a54237488fb5f394ddf06bb5fed6a547a1d2e931
上级 7fa5f6f4
...@@ -67,7 +67,8 @@ bool config_user_allocator(const Args& args); ...@@ -67,7 +67,8 @@ bool config_user_allocator(const Args& args);
bool register_cryption_method(const Args& args); bool register_cryption_method(const Args& args);
bool update_cryption_key(const Args& args); bool update_cryption_key(const Args& args);
bool async_forward(const Args& args); bool async_forward(const Args& args);
//! example: run a forward pass with a start (input) callback registered
bool set_input_callback(const Args& args);
//! example: run a forward pass with a finish (output) callback registered
bool set_output_callback(const Args& args);
#if LITE_WITH_CUDA #if LITE_WITH_CUDA
bool device_input(const Args& args); bool device_input(const Args& args);
bool device_input_output(const Args& args); bool device_input_output(const Args& args);
......
...@@ -160,6 +160,8 @@ REGIST_EXAMPLE("reset_input", reset_input); ...@@ -160,6 +160,8 @@ REGIST_EXAMPLE("reset_input", reset_input);
REGIST_EXAMPLE("reset_input_output", reset_input_output); REGIST_EXAMPLE("reset_input_output", reset_input_output);
REGIST_EXAMPLE("config_user_allocator", config_user_allocator); REGIST_EXAMPLE("config_user_allocator", config_user_allocator);
REGIST_EXAMPLE("async_forward", async_forward); REGIST_EXAMPLE("async_forward", async_forward);
//! callback examples; each one is registered under a name equal to its
//! function name
REGIST_EXAMPLE("set_input_callback", set_input_callback);
REGIST_EXAMPLE("set_output_callback", set_output_callback);
REGIST_EXAMPLE("basic_c_interface", basic_c_interface); REGIST_EXAMPLE("basic_c_interface", basic_c_interface);
REGIST_EXAMPLE("device_io_c_interface", device_io_c_interface); REGIST_EXAMPLE("device_io_c_interface", device_io_c_interface);
......
...@@ -365,6 +365,142 @@ bool lite::example::async_forward(const Args& args) { ...@@ -365,6 +365,142 @@ bool lite::example::async_forward(const Args& args) {
printf("max=%e, sum=%e\n", max, sum); printf("max=%e, sum=%e\n", max, sum);
return true; return true;
} }
/*!
 * \brief Example: register a start callback, run one forward pass and print
 *        the max/sum of the first output tensor.
 *
 * The model is loaded from args.model_path and the first input tensor is
 * filled from the npy file at args.input_path. The start callback prints the
 * name and number of dimensions of every input it receives and then raises
 * the local `finished` flag.
 *
 * \param args example arguments; only model_path and input_path are read
 * \return always true
 */
bool lite::example::set_input_callback(const Args& args) {
    std::string network_path = args.model_path;
    std::string input_path = args.input_path;

    Config config;
    config.options.var_sanity_check_first_run = false;

    //! create and load the network
    std::shared_ptr<Network> network = std::make_shared<Network>(config);
    network->load_model(network_path);

    //! set input data to input tensor
    std::shared_ptr<Tensor> input_tensor = network->get_input_tensor(0);
    //! copy or forward data to network
    size_t length = input_tensor->get_tensor_total_size_in_byte();
    void* dst_ptr = input_tensor->get_memory_ptr();
    auto src_tensor = parse_npy(input_path);
    void* src = src_tensor->get_memory_ptr();
    //! NOTE(review): assumes the npy payload holds at least `length` bytes
    memcpy(dst_ptr, src, length);

    //! set input callback
    volatile bool finished = false;
    network->set_start_callback(
            [&finished](const std::unordered_map<
                        std::string, std::pair<IO, std::shared_ptr<Tensor>>>& inputs) {
#if !__DEPLOY_ON_XP_SP2__
                std::cout << "worker thread_id:" << std::this_thread::get_id()
                          << std::endl;
#endif
                for (auto&& item : inputs) {
                    std::cout << "input name: " << item.first
                              << "input dim: " << item.second.second->get_layout().ndim
                              << std::endl;
                }
                finished = true;
            });

#if !__DEPLOY_ON_XP_SP2__
    std::cout << "out thread_id:" << std::this_thread::get_id() << std::endl;
#endif

    //! forward
    network->forward();
    //! block until forwarding has completed before the output tensor is read
    network->wait();

    //! count how many polls were needed until the callback reported in
    size_t count = 0;
    while (finished == false) {
        count++;
    }
    printf("Forward finish, count is %zu\n", count);

    //! get the output data or read tensor set in network_in
    std::shared_ptr<Tensor> output_tensor = network->get_output_tensor(0);
    void* out_data = output_tensor->get_memory_ptr();
    //! number of output elements = total bytes / bytes per element
    size_t out_length = output_tensor->get_tensor_total_size_in_byte() /
                        output_tensor->get_layout().get_elem_size();
    //! note: this prints the input byte size, not out_length
    printf("length=%zu\n", length);
    float max = -1.0f;
    float sum = 0.0f;
    //! the output buffer is interpreted as float
    for (size_t i = 0; i < out_length; i++) {
        float data = static_cast<float*>(out_data)[i];
        sum += data;
        if (max < data)
            max = data;
    }
    printf("max=%e, sum=%e\n", max, sum);
    return true;
}
/*!
 * \brief Example: register a finish callback, run one forward pass and print
 *        the max/sum of the first output tensor.
 *
 * The model is loaded from args.model_path and the first input tensor is
 * filled from the npy file at args.input_path. The finish callback prints the
 * name and number of dimensions of every output it receives and then raises
 * the local `finished` flag.
 *
 * \param args example arguments; only model_path and input_path are read
 * \return always true
 */
bool lite::example::set_output_callback(const Args& args) {
    std::string network_path = args.model_path;
    std::string input_path = args.input_path;

    Config config;
    config.options.var_sanity_check_first_run = false;

    //! create and load the network
    std::shared_ptr<Network> network = std::make_shared<Network>(config);
    network->load_model(network_path);

    //! set input data to input tensor
    //! (the input tensor must be used here: writing the npy payload into the
    //! output tensor would leave the network input unset)
    std::shared_ptr<Tensor> input_tensor = network->get_input_tensor(0);
    //! copy or forward data to network
    size_t length = input_tensor->get_tensor_total_size_in_byte();
    void* dst_ptr = input_tensor->get_memory_ptr();
    auto src_tensor = parse_npy(input_path);
    void* src = src_tensor->get_memory_ptr();
    //! NOTE(review): assumes the npy payload holds at least `length` bytes
    memcpy(dst_ptr, src, length);

    //! set output callback
    volatile bool finished = false;
    network->set_finish_callback(
            [&finished](const std::unordered_map<
                        std::string, std::pair<IO, std::shared_ptr<Tensor>>>& outputs) {
#if !__DEPLOY_ON_XP_SP2__
                std::cout << "worker thread_id:" << std::this_thread::get_id()
                          << std::endl;
#endif
                for (auto&& item : outputs) {
                    std::cout << "output name: " << item.first
                              << "output dim: " << item.second.second->get_layout().ndim
                              << std::endl;
                }
                finished = true;
            });

#if !__DEPLOY_ON_XP_SP2__
    std::cout << "out thread_id:" << std::this_thread::get_id() << std::endl;
#endif

    //! forward
    network->forward();
    network->wait();

    //! count how many polls were needed until the callback reported in
    size_t count = 0;
    while (finished == false) {
        count++;
    }
    printf("Forward finish, count is %zu\n", count);

    //! get the output data or read tensor set in network_in
    std::shared_ptr<Tensor> output_tensor = network->get_output_tensor(0);
    void* out_data = output_tensor->get_memory_ptr();
    //! number of output elements = total bytes / bytes per element
    size_t out_length = output_tensor->get_tensor_total_size_in_byte() /
                        output_tensor->get_layout().get_elem_size();
    //! note: this prints the input byte size, not out_length
    printf("length=%zu\n", length);
    float max = -1.0f;
    float sum = 0.0f;
    //! the output buffer is interpreted as float
    for (size_t i = 0; i < out_length; i++) {
        float data = static_cast<float*>(out_data)[i];
        sum += data;
        if (max < data)
            max = data;
    }
    printf("max=%e, sum=%e\n", max, sum);
    return true;
}
#endif #endif
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
...@@ -184,6 +184,8 @@ typedef int (*LiteThreadAffinityCallback)(int thread_id); ...@@ -184,6 +184,8 @@ typedef int (*LiteThreadAffinityCallback)(int thread_id);
typedef int (*LiteAsyncCallback)(); typedef int (*LiteAsyncCallback)();
//! async callback variant that also receives the user_data pointer given to
//! LITE_set_async_callback_with_userdata
typedef int (*LiteAsyncCallbackWithData)(void* user_data);
/*! /*!
* \brief the start/finish callback function * \brief the start/finish callback function
* \param unordered_map map from the io tensor name to the pair of which is the * \param unordered_map map from the io tensor name to the pair of which is the
...@@ -193,9 +195,17 @@ typedef int (*LiteAsyncCallback)(); ...@@ -193,9 +195,17 @@ typedef int (*LiteAsyncCallback)();
typedef int (*LiteStartCallback)( typedef int (*LiteStartCallback)(
const LiteIO* inputs, const LiteTensor* input_tensors, size_t size); const LiteIO* inputs, const LiteTensor* input_tensors, size_t size);
//! start callback variant: same arguments as LiteStartCallback plus the
//! user_data pointer given to LITE_set_start_callback_with_userdata
typedef int (*LiteStartCallbackWithData)(
const LiteIO* inputs, const LiteTensor* input_tensors, size_t size,
void* user_data);
typedef int (*LiteFinishCallback)( typedef int (*LiteFinishCallback)(
const LiteIO* outputs, const LiteTensor* output_tensors, size_t size); const LiteIO* outputs, const LiteTensor* output_tensors, size_t size);
//! finish callback variant: same arguments as LiteFinishCallback plus the
//! user_data pointer given to LITE_set_finish_callback_with_userdata
typedef int (*LiteFinishCallbackWithData)(
const LiteIO* outputs, const LiteTensor* output_tensors, size_t size,
void* user_data);
/*! /*!
* \brief The network is construct form a model, implement model load, init, * \brief The network is construct form a model, implement model load, init,
* forward, and display some model information * forward, and display some model information
...@@ -442,6 +452,19 @@ LITE_API int LITE_set_network_algo_workspace_limit( ...@@ -442,6 +452,19 @@ LITE_API int LITE_set_network_algo_workspace_limit(
LITE_API int LITE_set_async_callback( LITE_API int LITE_set_async_callback(
LiteNetwork network, const LiteAsyncCallback async_callback); LiteNetwork network, const LiteAsyncCallback async_callback);
/**
 * \brief set the network forward in async mode and set the async callback
 * function, passing a user-defined pointer through to the callback
 * \param[in] network The loaded model
 * \param[in] async_callback when the network finishes forwarding, the
 * callback will be called; must not be null
 * \param[in] user_data user-defined pointer that is handed back unchanged as
 * the argument of async_callback; it is only stored, so it must stay valid
 * for as long as the callback may be invoked
 */
LITE_API int LITE_set_async_callback_with_userdata(
LiteNetwork network, const LiteAsyncCallbackWithData async_callback,
void* user_data);
/** /**
* \brief set the start forward callback function, which will be execute beform * \brief set the start forward callback function, which will be execute beform
* forward, this can be used to check network input or dump model inputs * forward, this can be used to check network input or dump model inputs
...@@ -453,6 +476,20 @@ LITE_API int LITE_set_async_callback( ...@@ -453,6 +476,20 @@ LITE_API int LITE_set_async_callback(
LITE_API int LITE_set_start_callback( LITE_API int LITE_set_start_callback(
LiteNetwork network, const LiteStartCallback start_callback); LiteNetwork network, const LiteStartCallback start_callback);
/**
 * \brief set the start forward callback function, which will be executed
 * before forward; this can be used to check network input or dump model
 * inputs for debug
 * \param[in] network The loaded model
 * \param[in] start_callback when the network starts forwarding, the callback
 * will be called
 * \param[in] user_data user-defined pointer that is handed back unchanged as
 * the last argument of start_callback; it is only stored, so it must stay
 * valid for as long as the callback may be invoked
 */
LITE_API int LITE_set_start_callback_with_userdata(
LiteNetwork network, const LiteStartCallbackWithData start_callback,
void* user_data);
/** /**
* \brief set the finish forward callback function, which will be execute after * \brief set the finish forward callback function, which will be execute after
* forward, this can be used to dump model outputs for debug * forward, this can be used to dump model outputs for debug
...@@ -463,6 +500,19 @@ LITE_API int LITE_set_start_callback( ...@@ -463,6 +500,19 @@ LITE_API int LITE_set_start_callback(
LITE_API int LITE_set_finish_callback( LITE_API int LITE_set_finish_callback(
LiteNetwork network, const LiteFinishCallback finish_callback); LiteNetwork network, const LiteFinishCallback finish_callback);
/**
 * \brief set the finish forward callback function, which will be executed
 * after forward; this can be used to dump model outputs for debug
 * \param[in] network The loaded model
 * \param[in] finish_callback when the network finishes forwarding, the
 * callback will be called
 * \param[in] user_data user-defined pointer that is handed back unchanged as
 * the last argument of finish_callback; it is only stored, so it must stay
 * valid for as long as the callback may be invoked
 */
LITE_API int LITE_set_finish_callback_with_userdata(
LiteNetwork network, const LiteFinishCallbackWithData finish_callback,
void* user_data);
/** /**
* \brief set threads affinity callback * \brief set threads affinity callback
* \param[in] network The loaded model * \param[in] network The loaded model
......
...@@ -355,6 +355,22 @@ int LITE_set_async_callback( ...@@ -355,6 +355,22 @@ int LITE_set_async_callback(
LITE_CAPI_END(); LITE_CAPI_END();
} }
/*!
 * \brief C API: install an async callback that is invoked with user_data.
 *
 * Both network and async_callback are asserted to be non-null. The C function
 * pointer and user_data are captured by value in a closure, which is handed to
 * lite::Network::set_async_callback. The int returned by async_callback is
 * discarded.
 */
int LITE_set_async_callback_with_userdata(
        LiteNetwork network, LiteAsyncCallbackWithData async_callback,
        void* user_data) {
    LITE_CAPI_BEGIN();
    LITE_ASSERT(network, "The network pass to LITE api is null");
    LITE_ASSERT(async_callback, "The ptr pass to LITE api is null");
    auto* lite_network = static_cast<lite::Network*>(network);
    //! adapt the C function pointer + user_data pair to a void() callable
    lite_network->set_async_callback(
            [async_callback, user_data]() -> void { async_callback(user_data); });
    LITE_CAPI_END();
}
int LITE_set_start_callback( int LITE_set_start_callback(
LiteNetwork network, const LiteStartCallback start_callback) { LiteNetwork network, const LiteStartCallback start_callback) {
LITE_CAPI_BEGIN(); LITE_CAPI_BEGIN();
...@@ -381,6 +397,34 @@ int LITE_set_start_callback( ...@@ -381,6 +397,34 @@ int LITE_set_start_callback(
LITE_CAPI_END(); LITE_CAPI_END();
} }
/*!
 * \brief C API: install a start callback that is invoked with user_data.
 *
 * Both network and start_callback are asserted to be non-null. When the
 * closure is invoked it flattens the name -> (IO, tensor) map into two
 * parallel arrays (LiteIO descriptors and LiteTensor handles) and passes them,
 * with their count and user_data, to start_callback. The arrays and the name
 * pointers inside them are only valid for the duration of that call. The int
 * returned by start_callback is discarded.
 */
int LITE_set_start_callback_with_userdata(
        LiteNetwork network, const LiteStartCallbackWithData start_callback,
        void* user_data) {
    LITE_CAPI_BEGIN();
    LITE_ASSERT(network, "The network pass to LITE api is null");
    //! reject a null function pointer here instead of crashing at forward time
    LITE_ASSERT(start_callback, "The ptr pass to LITE api is null");
    auto lite_start_callback =
            [start_callback,
             user_data](const std::unordered_map<
                        std::string,
                        std::pair<lite::IO, std::shared_ptr<lite::Tensor>>>& inputs_map)
            -> void {
        std::vector<LiteIO> ios;
        std::vector<LiteTensor> io_tensors;
        ios.reserve(inputs_map.size());
        io_tensors.reserve(inputs_map.size());
        for (const auto& io : inputs_map) {
            auto&& lite_io = io.second.first;
            ios.push_back(
                    {lite_io.name.c_str(), lite_io.is_host, lite_io.io_type,
                     convert_to_clayout(lite_io.config_layout)});
            io_tensors.push_back(io.second.second.get());
        }
        start_callback(ios.data(), io_tensors.data(), ios.size(), user_data);
    };
    static_cast<lite::Network*>(network)->set_start_callback(lite_start_callback);
    LITE_CAPI_END();
}
int LITE_set_finish_callback( int LITE_set_finish_callback(
LiteNetwork network, const LiteFinishCallback finish_callback) { LiteNetwork network, const LiteFinishCallback finish_callback) {
LITE_CAPI_BEGIN(); LITE_CAPI_BEGIN();
...@@ -407,6 +451,34 @@ int LITE_set_finish_callback( ...@@ -407,6 +451,34 @@ int LITE_set_finish_callback(
LITE_CAPI_END(); LITE_CAPI_END();
} }
/*!
 * \brief C API: install a finish callback that is invoked with user_data.
 *
 * Both network and finish_callback are asserted to be non-null. When the
 * closure is invoked it flattens the name -> (IO, tensor) map into two
 * parallel arrays (LiteIO descriptors and LiteTensor handles) and passes them,
 * with their count and user_data, to finish_callback. The arrays and the name
 * pointers inside them are only valid for the duration of that call. The int
 * returned by finish_callback is discarded.
 */
int LITE_set_finish_callback_with_userdata(
        LiteNetwork network, const LiteFinishCallbackWithData finish_callback,
        void* user_data) {
    LITE_CAPI_BEGIN();
    LITE_ASSERT(network, "The network pass to LITE api is null");
    //! reject a null function pointer here instead of crashing after forward
    LITE_ASSERT(finish_callback, "The ptr pass to LITE api is null");
    auto lite_finish_callback =
            [finish_callback,
             user_data](const std::unordered_map<
                        std::string,
                        std::pair<lite::IO, std::shared_ptr<lite::Tensor>>>&
                                outputs_map) -> void {
        std::vector<LiteIO> ios;
        std::vector<LiteTensor> io_tensors;
        ios.reserve(outputs_map.size());
        io_tensors.reserve(outputs_map.size());
        for (const auto& io : outputs_map) {
            auto&& lite_io = io.second.first;
            ios.push_back(
                    {lite_io.name.c_str(), lite_io.is_host, lite_io.io_type,
                     convert_to_clayout(lite_io.config_layout)});
            io_tensors.push_back(io.second.second.get());
        }
        finish_callback(ios.data(), io_tensors.data(), ios.size(), user_data);
    };
    static_cast<lite::Network*>(network)->set_finish_callback(lite_finish_callback);
    LITE_CAPI_END();
}
int LITE_enable_profile_performance( int LITE_enable_profile_performance(
LiteNetwork network, const char* profile_json_file_path) { LiteNetwork network, const char* profile_json_file_path) {
LITE_CAPI_BEGIN(); LITE_CAPI_BEGIN();
......
...@@ -74,11 +74,21 @@ int multi_thread_affinity(int id) { ...@@ -74,11 +74,21 @@ int multi_thread_affinity(int id) {
}; };
volatile bool finished = false; volatile bool finished = false;
int finish_callback() { int async_callback() {
finished = true; finished = true;
return 0; return 0;
} }
//! raised by async_callback_with_data; polled by the test that registers it
volatile bool finished_with_data = false;
/*!
 * \brief async callback used by the user_data tests
 * \param user_data pointer supplied at registration; its address is printed
 *        when it is non-null, it is never dereferenced
 * \return always 0
 */
int async_callback_with_data(void* user_data) {
    if (user_data != nullptr) {
        //! no std::hex here: the manipulator is sticky and would leave
        //! std::cout in hex mode for everything printed afterwards, while
        //! pointer output does not depend on it
        std::cout << "async_callback user_data addr=" << user_data << std::endl;
    }
    finished_with_data = true;
    return 0;
}
volatile bool start_checked = false; volatile bool start_checked = false;
int start_callback(const LiteIO* inputs, const LiteTensor* input_tensors, size_t size) { int start_callback(const LiteIO* inputs, const LiteTensor* input_tensors, size_t size) {
start_checked = true; start_checked = true;
...@@ -96,6 +106,29 @@ int start_callback(const LiteIO* inputs, const LiteTensor* input_tensors, size_t ...@@ -96,6 +106,29 @@ int start_callback(const LiteIO* inputs, const LiteTensor* input_tensors, size_t
return 0; return 0;
} }
//! raised by start_callback_with_data; asserted by the test that registers it
volatile bool start_checked_with_data = false;
/*!
 * \brief start callback used by the user_data tests
 *
 * Checks that exactly one input named "data" is reported and that its layout
 * has 4 dimensions with shapes[1..3] == {3, 224, 224}.
 *
 * \param inputs        array of input descriptors
 * \param input_tensors array of input tensor handles, parallel to inputs
 * \param size          number of entries in both arrays
 * \param user_data     pointer supplied at registration; only its address is
 *                      printed, it is never dereferenced
 * \return always 0
 */
int start_callback_with_data(
        const LiteIO* inputs, const LiteTensor* input_tensors, size_t size,
        void* user_data) {
    start_checked_with_data = true;
    //! gtest ASSERT_* macros need a void-returning scope, hence the lambda
    auto check_func = [&]() {
        if (user_data != nullptr) {
            //! no std::hex here: the manipulator is sticky and would leave
            //! std::cout in hex mode for everything printed afterwards
            std::cout << "start_callback user_data addr=" << user_data
                      << std::endl;
        }
        ASSERT_EQ(size, 1);
        ASSERT_EQ(std::string(inputs->name), "data");
        LiteLayout layout;
        LITE_get_tensor_layout(*input_tensors, &layout);
        ASSERT_EQ(layout.ndim, 4);
        ASSERT_EQ(layout.shapes[1], 3);
        ASSERT_EQ(layout.shapes[2], 224);
        ASSERT_EQ(layout.shapes[3], 224);
    };
    check_func();
    return 0;
}
volatile bool finish_checked = false; volatile bool finish_checked = false;
int finish_callback( int finish_callback(
const LiteIO* outputs, const LiteTensor* output_tensors, size_t size) { const LiteIO* outputs, const LiteTensor* output_tensors, size_t size) {
...@@ -113,6 +146,28 @@ int finish_callback( ...@@ -113,6 +146,28 @@ int finish_callback(
return 0; return 0;
} }
//! raised by finish_callback_with_data; asserted by the test that registers it
volatile bool finish_checked_with_data = false;
/*!
 * \brief finish callback used by the user_data tests
 *
 * Checks that exactly one output with the expected name is reported and that
 * shapes[1] of its layout is 1000.
 *
 * \param outputs        array of output descriptors
 * \param output_tensors array of output tensor handles, parallel to outputs
 * \param size           number of entries in both arrays
 * \param user_data      pointer supplied at registration; only its address is
 *                       printed, it is never dereferenced
 * \return always 0
 */
int finish_callback_with_data(
        const LiteIO* outputs, const LiteTensor* output_tensors, size_t size,
        void* user_data) {
    finish_checked_with_data = true;
    //! gtest ASSERT_* macros need a void-returning scope, hence the lambda
    auto check_func = [&]() {
        if (user_data != nullptr) {
            //! no std::hex here: the manipulator is sticky and would leave
            //! std::cout in hex mode for everything printed afterwards
            std::cout << "finish_callback user_data addr=" << user_data
                      << std::endl;
        }
        ASSERT_EQ(size, 1);
        ASSERT_EQ(
                std::string(outputs->name),
                "TRUE_DIV(EXP[12065],reduce0[12067])[12077]");
        LiteLayout layout;
        LITE_get_tensor_layout(*output_tensors, &layout);
        ASSERT_EQ(layout.shapes[1], 1000);
    };
    check_func();
    return 0;
}
} // namespace } // namespace
#define LITE_CAPI_CHECK(_expr) \ #define LITE_CAPI_CHECK(_expr) \
...@@ -671,6 +726,21 @@ TEST(TestCapiNetWork, StartCallBack) { ...@@ -671,6 +726,21 @@ TEST(TestCapiNetWork, StartCallBack) {
LITE_CAPI_CHECK(LITE_destroy_network(c_network)); LITE_CAPI_CHECK(LITE_destroy_network(c_network));
} }
//! the start callback registered with user_data must be invoked during forward
TEST(TestCapiNetWork, StartCallBackWithData) {
    //! clear the flag so the final assertion reflects this run only
    start_checked_with_data = false;
    ForwardMgb;
    MakeNetwork;
    LoadNetwork;
    //! only the address is handed to the callback; it stays alive until the
    //! network is destroyed at the end of this test
    size_t user_data = 1;
    LITE_CAPI_CHECK(LITE_set_start_callback_with_userdata(
            c_network, start_callback_with_data, &user_data));
    SetInput;
    ForwardNetwork;
    GetOutput;
    CompareResult;
    ASSERT_TRUE(start_checked_with_data);
    LITE_CAPI_CHECK(LITE_destroy_network(c_network));
}
TEST(TestCapiNetWork, FinishCallBack) { TEST(TestCapiNetWork, FinishCallBack) {
ForwardMgb; ForwardMgb;
MakeNetwork; MakeNetwork;
...@@ -684,6 +754,21 @@ TEST(TestCapiNetWork, FinishCallBack) { ...@@ -684,6 +754,21 @@ TEST(TestCapiNetWork, FinishCallBack) {
LITE_CAPI_CHECK(LITE_destroy_network(c_network)); LITE_CAPI_CHECK(LITE_destroy_network(c_network));
} }
//! the finish callback registered with user_data must be invoked after forward
//! (the "Wtih" spelling of the test name is kept so existing test filters
//! keep matching)
TEST(TestCapiNetWork, FinishCallBackWtihData) {
    //! clear the flag so the final assertion reflects this run only
    finish_checked_with_data = false;
    ForwardMgb;
    MakeNetwork;
    LoadNetwork;
    //! only the address is handed to the callback; it stays alive until the
    //! network is destroyed at the end of this test
    size_t user_data = 1;
    LITE_CAPI_CHECK(LITE_set_finish_callback_with_userdata(
            c_network, finish_callback_with_data, &user_data));
    SetInput;
    ForwardNetwork;
    GetOutput;
    CompareResult;
    ASSERT_TRUE(finish_checked_with_data);
    LITE_CAPI_CHECK(LITE_destroy_network(c_network));
}
TEST(TestCapiNetWork, BasicCryptAes) { TEST(TestCapiNetWork, BasicCryptAes) {
ForwardMgb; ForwardMgb;
...@@ -723,7 +808,7 @@ TEST(TestCapiNetWork, AsyncExec) { ...@@ -723,7 +808,7 @@ TEST(TestCapiNetWork, AsyncExec) {
LiteConfig c_config = *default_config(); LiteConfig c_config = *default_config();
c_config.options.var_sanity_check_first_run = false; c_config.options.var_sanity_check_first_run = false;
LITE_CAPI_CHECK(LITE_make_network(&c_network, c_config, *default_network_io())); LITE_CAPI_CHECK(LITE_make_network(&c_network, c_config, *default_network_io()));
LITE_CAPI_CHECK(LITE_set_async_callback(c_network, finish_callback)); LITE_CAPI_CHECK(LITE_set_async_callback(c_network, async_callback));
LoadNetwork; LoadNetwork;
SetInput; SetInput;
...@@ -740,6 +825,32 @@ TEST(TestCapiNetWork, AsyncExec) { ...@@ -740,6 +825,32 @@ TEST(TestCapiNetWork, AsyncExec) {
LITE_CAPI_CHECK(LITE_destroy_network(c_network)); LITE_CAPI_CHECK(LITE_destroy_network(c_network));
} }
//! async forward with a user_data callback: poll until the callback fires,
//! then compare the output
TEST(TestCapiNetWork, AsyncExecWithData) {
    //! reset the flag this test actually polls (not the one of AsyncExec)
    finished_with_data = false;
    ForwardMgb;
    LiteNetwork c_network;
    LiteConfig c_config = *default_config();
    c_config.options.var_sanity_check_first_run = false;
    LITE_CAPI_CHECK(LITE_make_network(&c_network, c_config, *default_network_io()));
    //! only the address is handed to the callback; it stays alive until the
    //! network is destroyed at the end of this test
    size_t user_data = 1;
    LITE_CAPI_CHECK(LITE_set_async_callback_with_userdata(
            c_network, async_callback_with_data, &user_data));
    LoadNetwork;
    SetInput;

    //! check the return code: presumably a failed forward never invokes the
    //! callback, in which case the loop below would spin forever
    LITE_CAPI_CHECK(LITE_forward(c_network));
    size_t count = 0;
    while (finished_with_data == false) {
        count++;
    }
    //! expects the callback not to have fired before the first poll
    ASSERT_GT(count, 0);
    finished_with_data = false;

    GetOutput;
    CompareResult;
    LITE_CAPI_CHECK(LITE_destroy_network(c_network));
}
TEST(TestCapiNetWork, OutputShapeOnly) { TEST(TestCapiNetWork, OutputShapeOnly) {
ForwardMgb; ForwardMgb;
LiteNetwork c_network; LiteNetwork c_network;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册