提交 b84d2893 编写于 作者: M Megvii Engine Team

fix(lite): fix lite c thread local

GitOrigin-RevId: 36d2da7d68a71253904851b32daff61ed03738ce
上级 936bb237
...@@ -17,8 +17,10 @@ ...@@ -17,8 +17,10 @@
#include "lite-c/tensor_c.h" #include "lite-c/tensor_c.h"
#include "lite/network.h" #include "lite/network.h"
#if LITE_ENABLE_EXCEPTION
#include <exception> #include <exception>
#include <stdexcept> #include <stdexcept>
#endif
//! convert c Layout to lite::Layout //! convert c Layout to lite::Layout
lite::Layout convert_to_layout(const LiteLayout& layout); lite::Layout convert_to_layout(const LiteLayout& layout);
......
...@@ -13,11 +13,7 @@ ...@@ -13,11 +13,7 @@
#include "common.h" #include "common.h"
#include "lite-c/global_c.h" #include "lite-c/global_c.h"
#include <exception>
#include <mutex>
namespace { namespace {
class ErrorMsg { class ErrorMsg {
public: public:
std::string& get_error_msg() { return error_msg; } std::string& get_error_msg() { return error_msg; }
...@@ -26,18 +22,22 @@ public: ...@@ -26,18 +22,22 @@ public:
private: private:
std::string error_msg; std::string error_msg;
}; };
static LITE_MUTEX mtx_error;
ErrorMsg& get_global_error() { ErrorMsg& get_global_error() {
static thread_local ErrorMsg error_msg; static ErrorMsg error_msg;
return error_msg; return error_msg;
} }
} // namespace } // namespace
int LiteHandleException(const std::exception& e) { int LiteHandleException(const std::exception& e) {
LITE_LOCK_GUARD(mtx_error);
get_global_error().set_error_msg(e.what()); get_global_error().set_error_msg(e.what());
return -1; return -1;
} }
const char* LITE_get_last_error() { const char* LITE_get_last_error() {
LITE_LOCK_GUARD(mtx_error);
return get_global_error().get_error_msg().c_str(); return get_global_error().get_error_msg().c_str();
} }
......
...@@ -72,9 +72,9 @@ LiteNetworkIO* default_network_io() { ...@@ -72,9 +72,9 @@ LiteNetworkIO* default_network_io() {
} }
namespace { namespace {
static LITE_MUTEX mtx_network;
std::unordered_map<void*, std::shared_ptr<lite::Network>>& get_gloabl_network_holder() { std::unordered_map<void*, std::shared_ptr<lite::Network>>& get_gloabl_network_holder() {
static thread_local std::unordered_map<void*, std::shared_ptr<lite::Network>> static std::unordered_map<void*, std::shared_ptr<lite::Network>> network_holder;
network_holder;
return network_holder; return network_holder;
} }
...@@ -168,6 +168,7 @@ int LITE_make_default_network(LiteNetwork* network) { ...@@ -168,6 +168,7 @@ int LITE_make_default_network(LiteNetwork* network) {
LITE_CAPI_BEGIN(); LITE_CAPI_BEGIN();
LITE_ASSERT(network, "The network pass to LITE api is null"); LITE_ASSERT(network, "The network pass to LITE api is null");
auto lite_network = std::make_shared<lite::Network>(); auto lite_network = std::make_shared<lite::Network>();
LITE_LOCK_GUARD(mtx_network);
get_gloabl_network_holder()[lite_network.get()] = lite_network; get_gloabl_network_holder()[lite_network.get()] = lite_network;
*network = lite_network.get(); *network = lite_network.get();
LITE_CAPI_END(); LITE_CAPI_END();
...@@ -179,6 +180,7 @@ int LITE_make_network( ...@@ -179,6 +180,7 @@ int LITE_make_network(
LITE_ASSERT(network, "The network pass to LITE api is null"); LITE_ASSERT(network, "The network pass to LITE api is null");
auto lite_network = std::make_shared<lite::Network>( auto lite_network = std::make_shared<lite::Network>(
convert_to_lite_config(config), convert_to_lite_io(network_io)); convert_to_lite_config(config), convert_to_lite_io(network_io));
LITE_LOCK_GUARD(mtx_network);
get_gloabl_network_holder()[lite_network.get()] = lite_network; get_gloabl_network_holder()[lite_network.get()] = lite_network;
*network = lite_network.get(); *network = lite_network.get();
LITE_CAPI_END(); LITE_CAPI_END();
...@@ -188,6 +190,7 @@ int LITE_make_network_config(LiteNetwork* network, const LiteConfig config) { ...@@ -188,6 +190,7 @@ int LITE_make_network_config(LiteNetwork* network, const LiteConfig config) {
LITE_CAPI_BEGIN(); LITE_CAPI_BEGIN();
LITE_ASSERT(network, "The network pass to LITE api is null"); LITE_ASSERT(network, "The network pass to LITE api is null");
auto lite_network = std::make_shared<lite::Network>(convert_to_lite_config(config)); auto lite_network = std::make_shared<lite::Network>(convert_to_lite_config(config));
LITE_LOCK_GUARD(mtx_network);
get_gloabl_network_holder()[lite_network.get()] = lite_network; get_gloabl_network_holder()[lite_network.get()] = lite_network;
*network = lite_network.get(); *network = lite_network.get();
LITE_CAPI_END(); LITE_CAPI_END();
...@@ -212,6 +215,7 @@ int LITE_load_model_from_path(LiteNetwork network, const char* model_path) { ...@@ -212,6 +215,7 @@ int LITE_load_model_from_path(LiteNetwork network, const char* model_path) {
int LITE_destroy_network(LiteNetwork network) { int LITE_destroy_network(LiteNetwork network) {
LITE_CAPI_BEGIN(); LITE_CAPI_BEGIN();
LITE_ASSERT(network, "The network pass to LITE api is null"); LITE_ASSERT(network, "The network pass to LITE api is null");
LITE_LOCK_GUARD(mtx_network);
get_gloabl_network_holder().erase(network); get_gloabl_network_holder().erase(network);
LITE_CAPI_END(); LITE_CAPI_END();
} }
......
...@@ -26,13 +26,16 @@ const LiteTensorDesc default_desc = { ...@@ -26,13 +26,16 @@ const LiteTensorDesc default_desc = {
.device_type = LiteDeviceType::LITE_CPU, .device_type = LiteDeviceType::LITE_CPU,
.device_id = 0}; .device_id = 0};
namespace { namespace {
static LITE_MUTEX mtx_tensor;
std::unordered_map<void*, std::shared_ptr<lite::Tensor>>& get_global_tensor_holder() { std::unordered_map<void*, std::shared_ptr<lite::Tensor>>& get_global_tensor_holder() {
static thread_local std::unordered_map<void*, std::shared_ptr<lite::Tensor>> static std::unordered_map<void*, std::shared_ptr<lite::Tensor>> global_holder;
global_holder;
return global_holder; return global_holder;
} }
static LITE_MUTEX mtx_attr;
std::unordered_map<std::string, lite::LiteAny>& get_global_tensor_attr_holder() { std::unordered_map<std::string, lite::LiteAny>& get_global_tensor_attr_holder() {
static thread_local std::unordered_map<std::string, lite::LiteAny> global_holder; static std::unordered_map<std::string, lite::LiteAny> global_holder;
return global_holder; return global_holder;
} }
} // namespace } // namespace
...@@ -68,6 +71,7 @@ int LITE_make_tensor(const LiteTensorDesc tensor_describe, LiteTensor* tensor) { ...@@ -68,6 +71,7 @@ int LITE_make_tensor(const LiteTensorDesc tensor_describe, LiteTensor* tensor) {
auto lite_tensor = std::make_shared<lite::Tensor>( auto lite_tensor = std::make_shared<lite::Tensor>(
tensor_describe.device_id, tensor_describe.device_type, layout, tensor_describe.device_id, tensor_describe.device_type, layout,
tensor_describe.is_pinned_host); tensor_describe.is_pinned_host);
LITE_LOCK_GUARD(mtx_tensor);
get_global_tensor_holder()[lite_tensor.get()] = lite_tensor; get_global_tensor_holder()[lite_tensor.get()] = lite_tensor;
*tensor = lite_tensor.get(); *tensor = lite_tensor.get();
LITE_CAPI_END(); LITE_CAPI_END();
...@@ -76,6 +80,7 @@ int LITE_make_tensor(const LiteTensorDesc tensor_describe, LiteTensor* tensor) { ...@@ -76,6 +80,7 @@ int LITE_make_tensor(const LiteTensorDesc tensor_describe, LiteTensor* tensor) {
int LITE_destroy_tensor(LiteTensor tensor) { int LITE_destroy_tensor(LiteTensor tensor) {
LITE_CAPI_BEGIN(); LITE_CAPI_BEGIN();
LITE_ASSERT(tensor, "The tensor pass to LITE c_api is null"); LITE_ASSERT(tensor, "The tensor pass to LITE c_api is null");
LITE_LOCK_GUARD(mtx_tensor);
get_global_tensor_holder().erase(tensor); get_global_tensor_holder().erase(tensor);
LITE_CAPI_END(); LITE_CAPI_END();
} }
...@@ -132,6 +137,7 @@ int LITE_tensor_slice( ...@@ -132,6 +137,7 @@ int LITE_tensor_slice(
} }
} }
auto ret_tensor = static_cast<lite::Tensor*>(tensor)->slice(starts, ends, steps); auto ret_tensor = static_cast<lite::Tensor*>(tensor)->slice(starts, ends, steps);
LITE_LOCK_GUARD(mtx_tensor);
get_global_tensor_holder()[ret_tensor.get()] = ret_tensor; get_global_tensor_holder()[ret_tensor.get()] = ret_tensor;
*slice_tensor = ret_tensor.get(); *slice_tensor = ret_tensor.get();
LITE_CAPI_END(); LITE_CAPI_END();
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <memory> #include <memory>
#include <thread>
TEST(TestCapiTensor, Basic) { TEST(TestCapiTensor, Basic) {
LiteTensor c_tensor0, c_tensor1; LiteTensor c_tensor0, c_tensor1;
...@@ -305,6 +306,23 @@ TEST(TestCapiTensor, GetMemoryByIndex) { ...@@ -305,6 +306,23 @@ TEST(TestCapiTensor, GetMemoryByIndex) {
LITE_destroy_tensor(c_tensor0); LITE_destroy_tensor(c_tensor0);
} }
TEST(TestCapiTensor, ThreadLocalError) {
LiteTensor c_tensor0;
LiteTensorDesc description = default_desc;
description.layout = LiteLayout{{20, 20}, 2, LiteDataType::LITE_FLOAT};
void *ptr0, *ptr1;
std::thread thread1([&]() {
LITE_make_tensor(description, &c_tensor0);
LITE_get_tensor_memory(c_tensor0, &ptr0);
});
thread1.join();
std::thread thread2([&]() {
LITE_get_tensor_memory(c_tensor0, &ptr1);
LITE_destroy_tensor(c_tensor0);
});
thread2.join();
}
#endif #endif
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册