diff --git a/lite/lite-c/src/network.cpp b/lite/lite-c/src/network.cpp index 7419d7a81d30dd06ca714ff1fbeaa0183302e7ac..eb222a65969e5e2423f13e48a5ec12fac6305c23 100644 --- a/lite/lite-c/src/network.cpp +++ b/lite/lite-c/src/network.cpp @@ -242,7 +242,13 @@ int LITE_destroy_network(LiteNetwork network) { LITE_CAPI_BEGIN(); LITE_ASSERT(network, "The network pass to LITE api is null"); LITE_LOCK_GUARD(mtx_network); - get_gloabl_network_holder().erase(network); + auto& global_holder = get_gloabl_network_holder(); + if (global_holder.find(network) != global_holder.end()) { + global_holder.erase(network); + } else { + //! means the network has been destoryed + return -1; + } LITE_CAPI_END(); } diff --git a/lite/lite-c/src/tensor.cpp b/lite/lite-c/src/tensor.cpp index 6f6a674dc63650e5b78c40593b814cb919733eaf..20102fc2e74a6aabc54461f8289c0b9f984f6691 100644 --- a/lite/lite-c/src/tensor.cpp +++ b/lite/lite-c/src/tensor.cpp @@ -60,8 +60,10 @@ int LITE_make_tensor(const LiteTensorDesc tensor_describe, LiteTensor* tensor) { auto lite_tensor = std::make_shared( tensor_describe.device_id, tensor_describe.device_type, layout, tensor_describe.is_pinned_host); - LITE_LOCK_GUARD(mtx_tensor); - get_global_tensor_holder()[lite_tensor.get()] = lite_tensor; + { + LITE_LOCK_GUARD(mtx_tensor); + get_global_tensor_holder()[lite_tensor.get()] = lite_tensor; + } *tensor = lite_tensor.get(); LITE_CAPI_END(); } @@ -70,7 +72,13 @@ int LITE_destroy_tensor(LiteTensor tensor) { LITE_CAPI_BEGIN(); LITE_ASSERT(tensor, "The tensor pass to LITE c_api is null"); LITE_LOCK_GUARD(mtx_tensor); - get_global_tensor_holder().erase(tensor); + auto& global_holder = get_global_tensor_holder(); + if (global_holder.find(tensor) != global_holder.end()) { + global_holder.erase(tensor); + } else { + //! return -1, means the tensor has been destroyed. + return -1; + } LITE_CAPI_END(); } @@ -126,8 +134,10 @@ int LITE_tensor_slice( } } auto ret_tensor = static_cast(tensor)->slice(starts, ends, steps); - LITE_LOCK_GUARD(mtx_tensor); - get_global_tensor_holder()[ret_tensor.get()] = ret_tensor; + { + LITE_LOCK_GUARD(mtx_tensor); + get_global_tensor_holder()[ret_tensor.get()] = ret_tensor; + } *slice_tensor = ret_tensor.get(); LITE_CAPI_END(); } @@ -226,12 +236,16 @@ int LITE_tensor_concat( LiteTensor* tensors, int nr_tensor, int dim, LiteDeviceType dst_device, int device_id, LiteTensor* result_tensor) { LITE_CAPI_BEGIN(); + LITE_ASSERT(result_tensor, "The tensor pass to LITE c_api is null"); std::vector v_tensors; for (int i = 0; i < nr_tensor; i++) { v_tensors.push_back(*static_cast(tensors[i])); } auto tensor = lite::TensorUtils::concat(v_tensors, dim, dst_device, device_id); - get_global_tensor_holder()[tensor.get()] = tensor; + { + LITE_LOCK_GUARD(mtx_tensor); + get_global_tensor_holder()[tensor.get()] = tensor; + } *result_tensor = tensor.get(); LITE_CAPI_END() } diff --git a/lite/pylite/megenginelite/network.py b/lite/pylite/megenginelite/network.py index a66bb94d3896483b825daad09f7dc16fe9609eb0..57c984ab45bcc233c3daa191d0d7d032549cd6b3 100644 --- a/lite/pylite/megenginelite/network.py +++ b/lite/pylite/megenginelite/network.py @@ -476,7 +476,7 @@ def start_finish_callback(func): def wrapper(c_ios, c_tensors, size): ios = {} for i in range(size): - tensor = LiteTensor() + tensor = LiteTensor(physic_construct=False) tensor._tensor = c_void_p(c_tensors[i]) tensor.update() io = c_ios[i] @@ -729,7 +729,7 @@ class LiteNetwork(object): c_name = c_char_p(name.encode("utf-8")) else: c_name = c_char_p(name) - tensor = LiteTensor() + tensor = LiteTensor(physic_construct=False) self._api.LITE_get_io_tensor( self._network, c_name, phase, byref(tensor._tensor) ) diff --git a/lite/pylite/megenginelite/tensor.py b/lite/pylite/megenginelite/tensor.py index 15bb071a7105f3bd32eaea620d054f95cd00f439..77faff611e4a2a05760e4f15b3a9845d4be8386e 100644 --- a/lite/pylite/megenginelite/tensor.py +++ b/lite/pylite/megenginelite/tensor.py @@ -233,6 +233,7 @@ class LiteTensor(object): is_pinned_host=False, shapes=None, dtype=None, + physic_construct=True, ): self._tensor = _Ctensor() self._layout = LiteLayout() @@ -250,8 +251,10 @@ class LiteTensor(object): tensor_desc.device_type = device_type tensor_desc.device_id = device_id tensor_desc.is_pinned_host = is_pinned_host - self._api.LITE_make_tensor(tensor_desc, byref(self._tensor)) - self.update() + + if physic_construct: + self._api.LITE_make_tensor(tensor_desc, byref(self._tensor)) + self.update() def __del__(self): self._api.LITE_destroy_tensor(self._tensor) @@ -399,7 +402,7 @@ class LiteTensor(object): c_start = (c_size_t * length)(*start) c_end = (c_size_t * length)(*end) c_step = (c_size_t * length)(*step) - slice_tensor = LiteTensor() + slice_tensor = LiteTensor(physic_construct=False) self._api.LITE_tensor_slice( self._tensor, c_start, c_end, c_step, length, byref(slice_tensor._tensor), ) @@ -560,7 +563,7 @@ def LiteTensorConcat( length = len(tensors) c_tensors = [t._tensor for t in tensors] c_tensors = (_Ctensor * length)(*c_tensors) - result_tensor = LiteTensor() + result_tensor = LiteTensor(physic_construct=False) api.LITE_tensor_concat( cast(byref(c_tensors), POINTER(c_void_p)), length, diff --git a/lite/test/test_network_c.cpp b/lite/test/test_network_c.cpp index 0b76467f044daf1f4a06daa5cdb7da6c1785319f..40245de3ad4bb154efdbdab4439e187313fc401b 100644 --- a/lite/test/test_network_c.cpp +++ b/lite/test/test_network_c.cpp @@ -1022,6 +1022,20 @@ TEST(TestCapiNetWork, TestShareWeights) { LITE_CAPI_CHECK(LITE_destroy_network(c_network2)); } +TEST(TestCapiNetWork, GlobalHolder) { + std::string model_path = "./shufflenet.mge"; + LiteNetwork c_network; + LITE_CAPI_CHECK( + LITE_make_network(&c_network, *default_config(), *default_network_io())); + auto destroy_network = c_network; + LITE_CAPI_CHECK( + LITE_make_network(&c_network, *default_config(), *default_network_io())); + //! make sure destroy_network is destroyed by LITE_make_network + LITE_destroy_network(destroy_network); + ASSERT_EQ(LITE_destroy_network(destroy_network), -1); + LITE_CAPI_CHECK(LITE_destroy_network(c_network)); +} + #endif // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/lite/test/test_tensor_c.cpp b/lite/test/test_tensor_c.cpp index 6d71dbd13ab0ead89db289bcde6dd16daecf06ea..1444f099b8d5ea0cb6bc87f4a960199915b5f101 100644 --- a/lite/test/test_tensor_c.cpp +++ b/lite/test/test_tensor_c.cpp @@ -251,6 +251,7 @@ TEST(TestCapiTensor, Slice) { } } LITE_destroy_tensor(tensor); + LITE_destroy_tensor(slice_tensor); }; check(1, 8, 1, true); check(1, 8, 1, false); @@ -316,6 +317,21 @@ TEST(TestCapiTensor, ThreadLocalError) { thread2.join(); } +TEST(TestCapiTensor, GlobalHolder) { + LiteTensor c_tensor0; + LiteTensorDesc description = default_desc; + description.layout = LiteLayout{{20, 20}, 2, LiteDataType::LITE_FLOAT}; + + LITE_make_tensor(description, &c_tensor0); + auto destroy_tensor = c_tensor0; + + LITE_make_tensor(description, &c_tensor0); + //! make sure destroy_tensor is destroyed by LITE_make_tensor + LITE_destroy_tensor(destroy_tensor); + ASSERT_EQ(LITE_destroy_tensor(destroy_tensor), -1); + LITE_destroy_tensor(c_tensor0); +} + #endif // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}