From b84d289338a544f8ebc0461a92b08d790c36e0c5 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 13 Oct 2021 11:11:12 +0800 Subject: [PATCH] fix(lite): fix lite c thread local GitOrigin-RevId: 36d2da7d68a71253904851b32daff61ed03738ce --- lite/lite-c/src/common.h | 2 ++ lite/lite-c/src/global.cpp | 10 +++++----- lite/lite-c/src/network.cpp | 8 ++++++-- lite/lite-c/src/tensor.cpp | 12 +++++++++--- lite/test/test_tensor_c.cpp | 18 ++++++++++++++++++ 5 files changed, 40 insertions(+), 10 deletions(-) diff --git a/lite/lite-c/src/common.h b/lite/lite-c/src/common.h index a17e6b87f..7201834d2 100644 --- a/lite/lite-c/src/common.h +++ b/lite/lite-c/src/common.h @@ -17,8 +17,10 @@ #include "lite-c/tensor_c.h" #include "lite/network.h" +#if LITE_ENABLE_EXCEPTION #include #include +#endif //! convert c Layout to lite::Layout lite::Layout convert_to_layout(const LiteLayout& layout); diff --git a/lite/lite-c/src/global.cpp b/lite/lite-c/src/global.cpp index 48703a446..c686b1f3f 100644 --- a/lite/lite-c/src/global.cpp +++ b/lite/lite-c/src/global.cpp @@ -13,11 +13,7 @@ #include "common.h" #include "lite-c/global_c.h" -#include -#include - namespace { - class ErrorMsg { public: std::string& get_error_msg() { return error_msg; } @@ -26,18 +22,22 @@ public: private: std::string error_msg; }; + +static LITE_MUTEX mtx_error; ErrorMsg& get_global_error() { - static thread_local ErrorMsg error_msg; + static ErrorMsg error_msg; return error_msg; } } // namespace int LiteHandleException(const std::exception& e) { + LITE_LOCK_GUARD(mtx_error); get_global_error().set_error_msg(e.what()); return -1; } const char* LITE_get_last_error() { + LITE_LOCK_GUARD(mtx_error); return get_global_error().get_error_msg().c_str(); } diff --git a/lite/lite-c/src/network.cpp b/lite/lite-c/src/network.cpp index b7e9477fd..c4aa631fb 100644 --- a/lite/lite-c/src/network.cpp +++ b/lite/lite-c/src/network.cpp @@ -72,9 +72,9 @@ LiteNetworkIO* default_network_io() { } namespace { +static LITE_MUTEX mtx_network; std::unordered_map>& get_gloabl_network_holder() { - static thread_local std::unordered_map> - network_holder; + static std::unordered_map> network_holder; return network_holder; } @@ -168,6 +168,7 @@ int LITE_make_default_network(LiteNetwork* network) { LITE_CAPI_BEGIN(); LITE_ASSERT(network, "The network pass to LITE api is null"); auto lite_network = std::make_shared(); + LITE_LOCK_GUARD(mtx_network); get_gloabl_network_holder()[lite_network.get()] = lite_network; *network = lite_network.get(); LITE_CAPI_END(); @@ -179,6 +180,7 @@ int LITE_make_network( LITE_ASSERT(network, "The network pass to LITE api is null"); auto lite_network = std::make_shared( convert_to_lite_config(config), convert_to_lite_io(network_io)); + LITE_LOCK_GUARD(mtx_network); get_gloabl_network_holder()[lite_network.get()] = lite_network; *network = lite_network.get(); LITE_CAPI_END(); @@ -188,6 +190,7 @@ int LITE_make_network_config(LiteNetwork* network, const LiteConfig config) { LITE_CAPI_BEGIN(); LITE_ASSERT(network, "The network pass to LITE api is null"); auto lite_network = std::make_shared(convert_to_lite_config(config)); + LITE_LOCK_GUARD(mtx_network); get_gloabl_network_holder()[lite_network.get()] = lite_network; *network = lite_network.get(); LITE_CAPI_END(); @@ -212,6 +215,7 @@ int LITE_load_model_from_path(LiteNetwork network, const char* model_path) { int LITE_destroy_network(LiteNetwork network) { LITE_CAPI_BEGIN(); LITE_ASSERT(network, "The network pass to LITE api is null"); + LITE_LOCK_GUARD(mtx_network); get_gloabl_network_holder().erase(network); LITE_CAPI_END(); } diff --git a/lite/lite-c/src/tensor.cpp b/lite/lite-c/src/tensor.cpp index 362e3da4a..1f53fc641 100644 --- a/lite/lite-c/src/tensor.cpp +++ b/lite/lite-c/src/tensor.cpp @@ -26,13 +26,16 @@ const LiteTensorDesc default_desc = { .device_type = LiteDeviceType::LITE_CPU, .device_id = 0}; namespace { + +static LITE_MUTEX mtx_tensor; std::unordered_map>& get_global_tensor_holder() { - static thread_local std::unordered_map> - global_holder; + static std::unordered_map> global_holder; return global_holder; } + +static LITE_MUTEX mtx_attr; std::unordered_map& get_global_tensor_attr_holder() { - static thread_local std::unordered_map global_holder; + static std::unordered_map global_holder; return global_holder; } } // namespace @@ -68,6 +71,7 @@ int LITE_make_tensor(const LiteTensorDesc tensor_describe, LiteTensor* tensor) { auto lite_tensor = std::make_shared( tensor_describe.device_id, tensor_describe.device_type, layout, tensor_describe.is_pinned_host); + LITE_LOCK_GUARD(mtx_tensor); get_global_tensor_holder()[lite_tensor.get()] = lite_tensor; *tensor = lite_tensor.get(); LITE_CAPI_END(); @@ -76,6 +80,7 @@ int LITE_make_tensor(const LiteTensorDesc tensor_describe, LiteTensor* tensor) { int LITE_destroy_tensor(LiteTensor tensor) { LITE_CAPI_BEGIN(); LITE_ASSERT(tensor, "The tensor pass to LITE c_api is null"); + LITE_LOCK_GUARD(mtx_tensor); get_global_tensor_holder().erase(tensor); LITE_CAPI_END(); } @@ -132,6 +137,7 @@ int LITE_tensor_slice( } } auto ret_tensor = static_cast(tensor)->slice(starts, ends, steps); + LITE_LOCK_GUARD(mtx_tensor); get_global_tensor_holder()[ret_tensor.get()] = ret_tensor; *slice_tensor = ret_tensor.get(); LITE_CAPI_END(); diff --git a/lite/test/test_tensor_c.cpp b/lite/test/test_tensor_c.cpp index 7a8e35f15..725dd592d 100644 --- a/lite/test/test_tensor_c.cpp +++ b/lite/test/test_tensor_c.cpp @@ -18,6 +18,7 @@ #include #include +#include TEST(TestCapiTensor, Basic) { LiteTensor c_tensor0, c_tensor1; @@ -305,6 +306,23 @@ TEST(TestCapiTensor, GetMemoryByIndex) { LITE_destroy_tensor(c_tensor0); } +TEST(TestCapiTensor, ThreadLocalError) { + LiteTensor c_tensor0; + LiteTensorDesc description = default_desc; + description.layout = LiteLayout{{20, 20}, 2, LiteDataType::LITE_FLOAT}; + void *ptr0, *ptr1; + std::thread thread1([&]() { + LITE_make_tensor(description, &c_tensor0); + LITE_get_tensor_memory(c_tensor0, &ptr0); + }); + thread1.join(); + std::thread thread2([&]() { + LITE_get_tensor_memory(c_tensor0, &ptr1); + LITE_destroy_tensor(c_tensor0); + }); + thread2.join(); +} + #endif // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} -- GitLab