提交 6c78e684 编写于 作者: M Megvii Engine Team

fix(lite): fix lite memory leak

GitOrigin-RevId: 075c686162e2c7b75a4c54245d39c6aba68c2503
上级 ff239c63
...@@ -242,7 +242,13 @@ int LITE_destroy_network(LiteNetwork network) { ...@@ -242,7 +242,13 @@ int LITE_destroy_network(LiteNetwork network) {
LITE_CAPI_BEGIN(); LITE_CAPI_BEGIN();
LITE_ASSERT(network, "The network pass to LITE api is null"); LITE_ASSERT(network, "The network pass to LITE api is null");
LITE_LOCK_GUARD(mtx_network); LITE_LOCK_GUARD(mtx_network);
get_gloabl_network_holder().erase(network); auto& global_holder = get_gloabl_network_holder();
if (global_holder.find(network) != global_holder.end()) {
global_holder.erase(network);
} else {
//! means the network has been destoryed
return -1;
}
LITE_CAPI_END(); LITE_CAPI_END();
} }
......
...@@ -60,8 +60,10 @@ int LITE_make_tensor(const LiteTensorDesc tensor_describe, LiteTensor* tensor) { ...@@ -60,8 +60,10 @@ int LITE_make_tensor(const LiteTensorDesc tensor_describe, LiteTensor* tensor) {
auto lite_tensor = std::make_shared<lite::Tensor>( auto lite_tensor = std::make_shared<lite::Tensor>(
tensor_describe.device_id, tensor_describe.device_type, layout, tensor_describe.device_id, tensor_describe.device_type, layout,
tensor_describe.is_pinned_host); tensor_describe.is_pinned_host);
LITE_LOCK_GUARD(mtx_tensor); {
get_global_tensor_holder()[lite_tensor.get()] = lite_tensor; LITE_LOCK_GUARD(mtx_tensor);
get_global_tensor_holder()[lite_tensor.get()] = lite_tensor;
}
*tensor = lite_tensor.get(); *tensor = lite_tensor.get();
LITE_CAPI_END(); LITE_CAPI_END();
} }
...@@ -70,7 +72,13 @@ int LITE_destroy_tensor(LiteTensor tensor) { ...@@ -70,7 +72,13 @@ int LITE_destroy_tensor(LiteTensor tensor) {
LITE_CAPI_BEGIN(); LITE_CAPI_BEGIN();
LITE_ASSERT(tensor, "The tensor pass to LITE c_api is null"); LITE_ASSERT(tensor, "The tensor pass to LITE c_api is null");
LITE_LOCK_GUARD(mtx_tensor); LITE_LOCK_GUARD(mtx_tensor);
get_global_tensor_holder().erase(tensor); auto& global_holder = get_global_tensor_holder();
if (global_holder.find(tensor) != global_holder.end()) {
global_holder.erase(tensor);
} else {
//! return -1, means the tensor has been destroyed.
return -1;
}
LITE_CAPI_END(); LITE_CAPI_END();
} }
...@@ -126,8 +134,10 @@ int LITE_tensor_slice( ...@@ -126,8 +134,10 @@ int LITE_tensor_slice(
} }
} }
auto ret_tensor = static_cast<lite::Tensor*>(tensor)->slice(starts, ends, steps); auto ret_tensor = static_cast<lite::Tensor*>(tensor)->slice(starts, ends, steps);
LITE_LOCK_GUARD(mtx_tensor); {
get_global_tensor_holder()[ret_tensor.get()] = ret_tensor; LITE_LOCK_GUARD(mtx_tensor);
get_global_tensor_holder()[ret_tensor.get()] = ret_tensor;
}
*slice_tensor = ret_tensor.get(); *slice_tensor = ret_tensor.get();
LITE_CAPI_END(); LITE_CAPI_END();
} }
...@@ -226,12 +236,16 @@ int LITE_tensor_concat( ...@@ -226,12 +236,16 @@ int LITE_tensor_concat(
LiteTensor* tensors, int nr_tensor, int dim, LiteDeviceType dst_device, LiteTensor* tensors, int nr_tensor, int dim, LiteDeviceType dst_device,
int device_id, LiteTensor* result_tensor) { int device_id, LiteTensor* result_tensor) {
LITE_CAPI_BEGIN(); LITE_CAPI_BEGIN();
LITE_ASSERT(result_tensor, "The tensor pass to LITE c_api is null");
std::vector<lite::Tensor> v_tensors; std::vector<lite::Tensor> v_tensors;
for (int i = 0; i < nr_tensor; i++) { for (int i = 0; i < nr_tensor; i++) {
v_tensors.push_back(*static_cast<lite::Tensor*>(tensors[i])); v_tensors.push_back(*static_cast<lite::Tensor*>(tensors[i]));
} }
auto tensor = lite::TensorUtils::concat(v_tensors, dim, dst_device, device_id); auto tensor = lite::TensorUtils::concat(v_tensors, dim, dst_device, device_id);
get_global_tensor_holder()[tensor.get()] = tensor; {
LITE_LOCK_GUARD(mtx_tensor);
get_global_tensor_holder()[tensor.get()] = tensor;
}
*result_tensor = tensor.get(); *result_tensor = tensor.get();
LITE_CAPI_END() LITE_CAPI_END()
} }
......
...@@ -476,7 +476,7 @@ def start_finish_callback(func): ...@@ -476,7 +476,7 @@ def start_finish_callback(func):
def wrapper(c_ios, c_tensors, size): def wrapper(c_ios, c_tensors, size):
ios = {} ios = {}
for i in range(size): for i in range(size):
tensor = LiteTensor() tensor = LiteTensor(physic_construct=False)
tensor._tensor = c_void_p(c_tensors[i]) tensor._tensor = c_void_p(c_tensors[i])
tensor.update() tensor.update()
io = c_ios[i] io = c_ios[i]
...@@ -729,7 +729,7 @@ class LiteNetwork(object): ...@@ -729,7 +729,7 @@ class LiteNetwork(object):
c_name = c_char_p(name.encode("utf-8")) c_name = c_char_p(name.encode("utf-8"))
else: else:
c_name = c_char_p(name) c_name = c_char_p(name)
tensor = LiteTensor() tensor = LiteTensor(physic_construct=False)
self._api.LITE_get_io_tensor( self._api.LITE_get_io_tensor(
self._network, c_name, phase, byref(tensor._tensor) self._network, c_name, phase, byref(tensor._tensor)
) )
......
...@@ -233,6 +233,7 @@ class LiteTensor(object): ...@@ -233,6 +233,7 @@ class LiteTensor(object):
is_pinned_host=False, is_pinned_host=False,
shapes=None, shapes=None,
dtype=None, dtype=None,
physic_construct=True,
): ):
self._tensor = _Ctensor() self._tensor = _Ctensor()
self._layout = LiteLayout() self._layout = LiteLayout()
...@@ -250,8 +251,10 @@ class LiteTensor(object): ...@@ -250,8 +251,10 @@ class LiteTensor(object):
tensor_desc.device_type = device_type tensor_desc.device_type = device_type
tensor_desc.device_id = device_id tensor_desc.device_id = device_id
tensor_desc.is_pinned_host = is_pinned_host tensor_desc.is_pinned_host = is_pinned_host
self._api.LITE_make_tensor(tensor_desc, byref(self._tensor))
self.update() if physic_construct:
self._api.LITE_make_tensor(tensor_desc, byref(self._tensor))
self.update()
def __del__(self): def __del__(self):
self._api.LITE_destroy_tensor(self._tensor) self._api.LITE_destroy_tensor(self._tensor)
...@@ -399,7 +402,7 @@ class LiteTensor(object): ...@@ -399,7 +402,7 @@ class LiteTensor(object):
c_start = (c_size_t * length)(*start) c_start = (c_size_t * length)(*start)
c_end = (c_size_t * length)(*end) c_end = (c_size_t * length)(*end)
c_step = (c_size_t * length)(*step) c_step = (c_size_t * length)(*step)
slice_tensor = LiteTensor() slice_tensor = LiteTensor(physic_construct=False)
self._api.LITE_tensor_slice( self._api.LITE_tensor_slice(
self._tensor, c_start, c_end, c_step, length, byref(slice_tensor._tensor), self._tensor, c_start, c_end, c_step, length, byref(slice_tensor._tensor),
) )
...@@ -560,7 +563,7 @@ def LiteTensorConcat( ...@@ -560,7 +563,7 @@ def LiteTensorConcat(
length = len(tensors) length = len(tensors)
c_tensors = [t._tensor for t in tensors] c_tensors = [t._tensor for t in tensors]
c_tensors = (_Ctensor * length)(*c_tensors) c_tensors = (_Ctensor * length)(*c_tensors)
result_tensor = LiteTensor() result_tensor = LiteTensor(physic_construct=False)
api.LITE_tensor_concat( api.LITE_tensor_concat(
cast(byref(c_tensors), POINTER(c_void_p)), cast(byref(c_tensors), POINTER(c_void_p)),
length, length,
......
...@@ -1022,6 +1022,20 @@ TEST(TestCapiNetWork, TestShareWeights) { ...@@ -1022,6 +1022,20 @@ TEST(TestCapiNetWork, TestShareWeights) {
LITE_CAPI_CHECK(LITE_destroy_network(c_network2)); LITE_CAPI_CHECK(LITE_destroy_network(c_network2));
} }
TEST(TestCapiNetWork, GlobalHolder) {
std::string model_path = "./shufflenet.mge";
LiteNetwork c_network;
LITE_CAPI_CHECK(
LITE_make_network(&c_network, *default_config(), *default_network_io()));
auto destroy_network = c_network;
LITE_CAPI_CHECK(
LITE_make_network(&c_network, *default_config(), *default_network_io()));
//! make sure destroy_network is destroyed by LITE_make_network
LITE_destroy_network(destroy_network);
ASSERT_EQ(LITE_destroy_network(destroy_network), -1);
LITE_CAPI_CHECK(LITE_destroy_network(c_network));
}
#endif #endif
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
...@@ -251,6 +251,7 @@ TEST(TestCapiTensor, Slice) { ...@@ -251,6 +251,7 @@ TEST(TestCapiTensor, Slice) {
} }
} }
LITE_destroy_tensor(tensor); LITE_destroy_tensor(tensor);
LITE_destroy_tensor(slice_tensor);
}; };
check(1, 8, 1, true); check(1, 8, 1, true);
check(1, 8, 1, false); check(1, 8, 1, false);
...@@ -316,6 +317,21 @@ TEST(TestCapiTensor, ThreadLocalError) { ...@@ -316,6 +317,21 @@ TEST(TestCapiTensor, ThreadLocalError) {
thread2.join(); thread2.join();
} }
TEST(TestCapiTensor, GlobalHolder) {
LiteTensor c_tensor0;
LiteTensorDesc description = default_desc;
description.layout = LiteLayout{{20, 20}, 2, LiteDataType::LITE_FLOAT};
LITE_make_tensor(description, &c_tensor0);
auto destroy_tensor = c_tensor0;
LITE_make_tensor(description, &c_tensor0);
//! make sure destroy_tensor is destroyed by LITE_make_tensor
LITE_destroy_tensor(destroy_tensor);
ASSERT_EQ(LITE_destroy_tensor(destroy_tensor), -1);
LITE_destroy_tensor(c_tensor0);
}
#endif #endif
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册