提交 6c78e684 编写于 作者: M Megvii Engine Team

fix(lite): fix lite memory leak

GitOrigin-RevId: 075c686162e2c7b75a4c54245d39c6aba68c2503
上级 ff239c63
......@@ -242,7 +242,13 @@ int LITE_destroy_network(LiteNetwork network) {
LITE_CAPI_BEGIN();
LITE_ASSERT(network, "The network pass to LITE api is null");
LITE_LOCK_GUARD(mtx_network);
get_gloabl_network_holder().erase(network);
auto& global_holder = get_gloabl_network_holder();
if (global_holder.find(network) != global_holder.end()) {
global_holder.erase(network);
} else {
//! means the network has been destoryed
return -1;
}
LITE_CAPI_END();
}
......
......@@ -60,8 +60,10 @@ int LITE_make_tensor(const LiteTensorDesc tensor_describe, LiteTensor* tensor) {
auto lite_tensor = std::make_shared<lite::Tensor>(
tensor_describe.device_id, tensor_describe.device_type, layout,
tensor_describe.is_pinned_host);
LITE_LOCK_GUARD(mtx_tensor);
get_global_tensor_holder()[lite_tensor.get()] = lite_tensor;
{
LITE_LOCK_GUARD(mtx_tensor);
get_global_tensor_holder()[lite_tensor.get()] = lite_tensor;
}
*tensor = lite_tensor.get();
LITE_CAPI_END();
}
......@@ -70,7 +72,13 @@ int LITE_destroy_tensor(LiteTensor tensor) {
LITE_CAPI_BEGIN();
LITE_ASSERT(tensor, "The tensor pass to LITE c_api is null");
LITE_LOCK_GUARD(mtx_tensor);
get_global_tensor_holder().erase(tensor);
auto& global_holder = get_global_tensor_holder();
if (global_holder.find(tensor) != global_holder.end()) {
global_holder.erase(tensor);
} else {
//! return -1, means the tensor has been destroyed.
return -1;
}
LITE_CAPI_END();
}
......@@ -126,8 +134,10 @@ int LITE_tensor_slice(
}
}
auto ret_tensor = static_cast<lite::Tensor*>(tensor)->slice(starts, ends, steps);
LITE_LOCK_GUARD(mtx_tensor);
get_global_tensor_holder()[ret_tensor.get()] = ret_tensor;
{
LITE_LOCK_GUARD(mtx_tensor);
get_global_tensor_holder()[ret_tensor.get()] = ret_tensor;
}
*slice_tensor = ret_tensor.get();
LITE_CAPI_END();
}
......@@ -226,12 +236,16 @@ int LITE_tensor_concat(
LiteTensor* tensors, int nr_tensor, int dim, LiteDeviceType dst_device,
int device_id, LiteTensor* result_tensor) {
LITE_CAPI_BEGIN();
LITE_ASSERT(result_tensor, "The tensor pass to LITE c_api is null");
std::vector<lite::Tensor> v_tensors;
for (int i = 0; i < nr_tensor; i++) {
v_tensors.push_back(*static_cast<lite::Tensor*>(tensors[i]));
}
auto tensor = lite::TensorUtils::concat(v_tensors, dim, dst_device, device_id);
get_global_tensor_holder()[tensor.get()] = tensor;
{
LITE_LOCK_GUARD(mtx_tensor);
get_global_tensor_holder()[tensor.get()] = tensor;
}
*result_tensor = tensor.get();
LITE_CAPI_END()
}
......
......@@ -476,7 +476,7 @@ def start_finish_callback(func):
def wrapper(c_ios, c_tensors, size):
ios = {}
for i in range(size):
tensor = LiteTensor()
tensor = LiteTensor(physic_construct=False)
tensor._tensor = c_void_p(c_tensors[i])
tensor.update()
io = c_ios[i]
......@@ -729,7 +729,7 @@ class LiteNetwork(object):
c_name = c_char_p(name.encode("utf-8"))
else:
c_name = c_char_p(name)
tensor = LiteTensor()
tensor = LiteTensor(physic_construct=False)
self._api.LITE_get_io_tensor(
self._network, c_name, phase, byref(tensor._tensor)
)
......
......@@ -233,6 +233,7 @@ class LiteTensor(object):
is_pinned_host=False,
shapes=None,
dtype=None,
physic_construct=True,
):
self._tensor = _Ctensor()
self._layout = LiteLayout()
......@@ -250,8 +251,10 @@ class LiteTensor(object):
tensor_desc.device_type = device_type
tensor_desc.device_id = device_id
tensor_desc.is_pinned_host = is_pinned_host
self._api.LITE_make_tensor(tensor_desc, byref(self._tensor))
self.update()
if physic_construct:
self._api.LITE_make_tensor(tensor_desc, byref(self._tensor))
self.update()
def __del__(self):
self._api.LITE_destroy_tensor(self._tensor)
......@@ -399,7 +402,7 @@ class LiteTensor(object):
c_start = (c_size_t * length)(*start)
c_end = (c_size_t * length)(*end)
c_step = (c_size_t * length)(*step)
slice_tensor = LiteTensor()
slice_tensor = LiteTensor(physic_construct=False)
self._api.LITE_tensor_slice(
self._tensor, c_start, c_end, c_step, length, byref(slice_tensor._tensor),
)
......@@ -560,7 +563,7 @@ def LiteTensorConcat(
length = len(tensors)
c_tensors = [t._tensor for t in tensors]
c_tensors = (_Ctensor * length)(*c_tensors)
result_tensor = LiteTensor()
result_tensor = LiteTensor(physic_construct=False)
api.LITE_tensor_concat(
cast(byref(c_tensors), POINTER(c_void_p)),
length,
......
......@@ -1022,6 +1022,20 @@ TEST(TestCapiNetWork, TestShareWeights) {
LITE_CAPI_CHECK(LITE_destroy_network(c_network2));
}
TEST(TestCapiNetWork, GlobalHolder) {
std::string model_path = "./shufflenet.mge";
LiteNetwork c_network;
LITE_CAPI_CHECK(
LITE_make_network(&c_network, *default_config(), *default_network_io()));
auto destroy_network = c_network;
LITE_CAPI_CHECK(
LITE_make_network(&c_network, *default_config(), *default_network_io()));
//! make sure destroy_network is destroyed by LITE_make_network
LITE_destroy_network(destroy_network);
ASSERT_EQ(LITE_destroy_network(destroy_network), -1);
LITE_CAPI_CHECK(LITE_destroy_network(c_network));
}
#endif
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
......@@ -251,6 +251,7 @@ TEST(TestCapiTensor, Slice) {
}
}
LITE_destroy_tensor(tensor);
LITE_destroy_tensor(slice_tensor);
};
check(1, 8, 1, true);
check(1, 8, 1, false);
......@@ -316,6 +317,21 @@ TEST(TestCapiTensor, ThreadLocalError) {
thread2.join();
}
TEST(TestCapiTensor, GlobalHolder) {
LiteTensor c_tensor0;
LiteTensorDesc description = default_desc;
description.layout = LiteLayout{{20, 20}, 2, LiteDataType::LITE_FLOAT};
LITE_make_tensor(description, &c_tensor0);
auto destroy_tensor = c_tensor0;
LITE_make_tensor(description, &c_tensor0);
//! make sure destroy_tensor is destroyed by LITE_make_tensor
LITE_destroy_tensor(destroy_tensor);
ASSERT_EQ(LITE_destroy_tensor(destroy_tensor), -1);
LITE_destroy_tensor(c_tensor0);
}
#endif
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册