未验证 提交 5dfe2ab9 编写于 作者: Z Zeng Jinle 提交者: GitHub

Fix mem leak when converting Tensor to numpy array (#17182)

* fix mem leak when converting Tensor to numpy array
test=develop

* remove unused unittest,test=develop

* follow comments, test=develop

* fix dygraph bug,test=develop
上级 e4a53324
......@@ -26,5 +26,4 @@ if(WITH_PYTHON)
get_property (os_dependency_modules GLOBAL PROPERTY OS_DEPENDENCY_MODULES)
target_link_libraries(paddle_pybind ${os_dependency_modules})
cc_test(tensor_py_test SRCS tensor_py_test.cc DEPS python pybind)
endif(WITH_PYTHON)
......@@ -302,8 +302,7 @@ PYBIND11_MODULE(core, m) {
BindImperative(&m);
py::class_<Tensor>(m, "Tensor", py::buffer_protocol())
.def_buffer(
[](Tensor &self) -> py::buffer_info { return CastToPyBuffer(self); })
.def("__array__", [](Tensor &self) { return TensorToPyArray(self); })
.def("_is_initialized",
[](const Tensor &self) { return self.IsInitialized(); })
.def("_get_dims",
......@@ -419,8 +418,7 @@ PYBIND11_MODULE(core, m) {
Users should be careful about it.
)DOC")
.def_buffer(
[](Tensor &self) -> py::buffer_info { return CastToPyBuffer(self); })
.def("__array__", [](Tensor &self) { return TensorToPyArray(self); })
.def("__init__",
[](LoDTensor &instance, const std::vector<std::vector<size_t>>
&recursive_sequence_lengths) {
......
......@@ -32,109 +32,6 @@ namespace py = pybind11;
namespace paddle {
namespace pybind {
namespace details {
template <bool less, size_t I, typename... ARGS>
struct CastToPyBufferImpl;
template <size_t I, typename... ARGS>
struct CastToPyBufferImpl<false, I, ARGS...> {
pybind11::buffer_info operator()(const framework::Tensor &tensor) {
PADDLE_THROW("This type of tensor cannot be expose to Python");
return pybind11::buffer_info();
}
};
template <size_t I, typename... ARGS>
struct CastToPyBufferImpl<true, I, ARGS...> {
using CUR_TYPE = typename std::tuple_element<I, std::tuple<ARGS...>>::type;
pybind11::buffer_info operator()(const framework::Tensor &tensor) {
if (framework::DataTypeTrait<CUR_TYPE>::DataType == tensor.type()) {
auto dim_vec = framework::vectorize(tensor.dims());
std::vector<size_t> dims_outside;
std::vector<size_t> strides;
dims_outside.resize(dim_vec.size());
strides.resize(dim_vec.size());
size_t prod = 1;
for (size_t i = dim_vec.size(); i != 0; --i) {
dims_outside[i - 1] = (size_t)dim_vec[i - 1];
strides[i - 1] = sizeof(CUR_TYPE) * prod;
prod *= dims_outside[i - 1];
}
framework::Tensor dst_tensor;
bool is_gpu = paddle::platform::is_gpu_place(tensor.place());
if (is_gpu) {
#ifdef PADDLE_WITH_CUDA
auto *src_ptr = static_cast<const void *>(tensor.data<CUR_TYPE>());
auto *dst_ptr = static_cast<void *>(dst_tensor.mutable_data<CUR_TYPE>(
tensor.dims(), platform::CPUPlace()));
paddle::platform::GpuMemcpySync(dst_ptr, src_ptr,
sizeof(CUR_TYPE) * tensor.numel(),
cudaMemcpyDeviceToHost);
#else
PADDLE_THROW("'CUDAPlace' is not supported in CPU only device.");
#endif
} else if (paddle::platform::is_cpu_place(tensor.place())) {
dst_tensor = tensor;
}
std::string dtype = std::type_index(typeid(CUR_TYPE)) ==
std::type_index(typeid(platform::float16))
? std::string("e") // np.dtype('e') == np.float16
: pybind11::format_descriptor<CUR_TYPE>::format();
if (is_gpu) {
// manually construct a py_buffer if is_gpu since gpu data is copied
// into CPU.
// TODO(yy): Is these following code memleak?
Py_buffer *py_buffer =
reinterpret_cast<Py_buffer *>(malloc(sizeof(Py_buffer)));
py_buffer->format = strdup(dtype.c_str());
py_buffer->itemsize = sizeof(CUR_TYPE);
py_buffer->ndim = framework::arity(dst_tensor.dims());
py_buffer->len = tensor.numel();
py_buffer->strides = reinterpret_cast<Py_ssize_t *>(
malloc(sizeof(Py_ssize_t) * strides.size()));
for (size_t i = 0; i < strides.size(); ++i) {
py_buffer->strides[i] = strides[i];
}
py_buffer->shape = reinterpret_cast<Py_ssize_t *>(
malloc(sizeof(Py_ssize_t) * tensor.dims().size()));
for (int i = 0; i < tensor.dims().size(); ++i) {
py_buffer->shape[i] = tensor.dims()[i];
}
py_buffer->readonly = false;
py_buffer->suboffsets = nullptr;
py_buffer->obj = nullptr;
py_buffer->buf =
malloc(static_cast<size_t>(py_buffer->len * py_buffer->itemsize));
memcpy(py_buffer->buf, dst_tensor.data<CUR_TYPE>(),
static_cast<size_t>(py_buffer->len * py_buffer->itemsize));
return pybind11::buffer_info(py_buffer, true);
} else {
return pybind11::buffer_info(
dst_tensor.data<CUR_TYPE>(), sizeof(CUR_TYPE), dtype,
(size_t)framework::arity(dst_tensor.dims()), dims_outside, strides);
}
} else {
constexpr bool less = I + 1 < std::tuple_size<std::tuple<ARGS...>>::value;
return CastToPyBufferImpl<less, I + 1, ARGS...>()(tensor);
}
}
};
} // namespace details
inline pybind11::buffer_info CastToPyBuffer(const framework::Tensor &tensor) {
auto buffer_info =
details::CastToPyBufferImpl<true, 0, float, int, double, int64_t, bool,
uint8_t, int8_t, platform::float16>()(tensor);
return buffer_info;
}
template <typename T>
T TensorGetElement(const framework::Tensor &self, size_t offset) {
......@@ -531,5 +428,88 @@ inline void PyCUDAPinnedTensorSetFromArray(
}
#endif
namespace details {
template <typename T>
constexpr bool IsValidDTypeToPyArray() {
return false;
}
#define DECLARE_VALID_DTYPE_TO_PY_ARRAY(type) \
template <> \
constexpr bool IsValidDTypeToPyArray<type>() { \
return true; \
}
DECLARE_VALID_DTYPE_TO_PY_ARRAY(platform::float16);
DECLARE_VALID_DTYPE_TO_PY_ARRAY(float);
DECLARE_VALID_DTYPE_TO_PY_ARRAY(double);
DECLARE_VALID_DTYPE_TO_PY_ARRAY(bool);
DECLARE_VALID_DTYPE_TO_PY_ARRAY(int8_t);
DECLARE_VALID_DTYPE_TO_PY_ARRAY(uint8_t);
DECLARE_VALID_DTYPE_TO_PY_ARRAY(int);
DECLARE_VALID_DTYPE_TO_PY_ARRAY(int64_t);
inline std::string TensorDTypeToPyDTypeStr(
framework::proto::VarType::Type type) {
#define TENSOR_DTYPE_TO_PY_DTYPE(T, proto_type) \
if (type == proto_type) { \
if (std::is_same<T, platform::float16>::value) { \
return "e"; \
} else { \
PADDLE_ENFORCE(IsValidDTypeToPyArray<T>, \
"This type of tensor cannot be expose to Python"); \
return py::format_descriptor<T>::format(); \
} \
}
_ForEachDataType_(TENSOR_DTYPE_TO_PY_DTYPE);
#undef TENSOR_DTYPE_TO_PY_DTYPE
PADDLE_THROW("Unsupported data type %d", static_cast<int>(type));
}
} // namespace details
inline py::array TensorToPyArray(const framework::Tensor &tensor) {
bool is_gpu_tensor = platform::is_gpu_place(tensor.place());
const auto &tensor_dims = tensor.dims();
auto tensor_dtype = tensor.type();
size_t sizeof_dtype = framework::SizeOfType(tensor_dtype);
std::vector<size_t> py_dims(tensor_dims.size());
std::vector<size_t> py_strides(tensor_dims.size());
size_t numel = 1;
for (int i = tensor_dims.size() - 1; i >= 0; --i) {
py_dims[i] = (size_t)tensor_dims[i];
py_strides[i] = sizeof_dtype * numel;
numel *= py_dims[i];
}
const void *tensor_buf_ptr = tensor.data<void>();
std::string py_dtype_str = details::TensorDTypeToPyDTypeStr(tensor.type());
if (!is_gpu_tensor) {
return py::array(py::buffer_info(
const_cast<void *>(tensor_buf_ptr), sizeof_dtype, py_dtype_str,
static_cast<size_t>(tensor.dims().size()), py_dims, py_strides));
}
#ifdef PADDLE_WITH_CUDA
py::array py_arr(py::dtype(py_dtype_str.c_str()), py_dims, py_strides);
PADDLE_ENFORCE(py_arr.writeable() && py_arr.owndata(),
"PyArray must be writable and own data, otherwise memory leak "
"or double free would occur");
size_t copy_bytes = sizeof_dtype * numel;
paddle::platform::GpuMemcpySync(py_arr.mutable_data(), tensor_buf_ptr,
copy_bytes, cudaMemcpyDeviceToHost);
return py_arr;
#else
PADDLE_THROW("CUDAPlace is not supported when not compiled with CUDA");
#endif
}
} // namespace pybind
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/pybind/tensor_py.h"
#include <iostream>
#include "gtest/gtest.h"
#include "paddle/fluid/framework/tensor.h"
TEST(TensorPy, CastToPyBufferImpl) {
typedef int ElemType;
paddle::framework::Tensor t;
auto d = paddle::framework::make_ddim({1, 2, 3});
int* p = t.mutable_data<ElemType>(d, paddle::platform::CPUPlace());
for (int i = 0; i < paddle::framework::product(d); ++i) {
p[i] = i;
}
pybind11::buffer_info bi = paddle::pybind::CastToPyBuffer(t);
EXPECT_EQ(bi.itemsize, static_cast<size_t>(sizeof(ElemType)));
EXPECT_EQ(bi.size, static_cast<size_t>(paddle::framework::product(d)));
EXPECT_EQ(bi.ndim, static_cast<size_t>(3)); // 3-dimensional as d.
EXPECT_EQ(bi.shape.size(), 3U); // as Dim d.
EXPECT_EQ(bi.shape[0], static_cast<size_t>(1));
EXPECT_EQ(bi.shape[1], static_cast<size_t>(2));
EXPECT_EQ(bi.shape[2], static_cast<size_t>(3));
EXPECT_EQ(bi.strides.size(), 3U); // 3-dimensional as d.
EXPECT_EQ(bi.strides[2], static_cast<size_t>(sizeof(ElemType)));
EXPECT_EQ(bi.strides[1], static_cast<size_t>(sizeof(ElemType) * 3));
EXPECT_EQ(bi.strides[0], static_cast<size_t>(sizeof(ElemType) * 2 * 3));
}
......@@ -291,6 +291,9 @@ class PyLayer(core.PyLayer):
inputs = [inputs]
ret = []
for inp in inputs:
if isinstance(inp, core.LoDTensor):
ret.append(inp)
else:
tensor = core.LoDTensor()
tensor.set(inp, core.CPUPlace())
ret.append(tensor)
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
import unittest
import numpy as np
import six
class TensorToNumpyTest(unittest.TestCase):
def setUp(self):
self.shape = [11, 25, 32, 43]
def test_main(self):
dtypes = [
'float32', 'float64', 'int32', 'int64', 'uint8', 'int8', 'bool'
]
places = [fluid.CPUPlace()]
if fluid.core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
places.append(fluid.CUDAPinnedPlace())
for p in places:
for dtype in dtypes:
np_arr = np.reshape(
np.array(six.moves.range(np.prod(self.shape))).astype(
dtype), self.shape)
t = fluid.LoDTensor()
t.set(np_arr, p)
ret_np_arr = np.array(t)
self.assertEqual(np_arr.shape, ret_np_arr.shape)
self.assertEqual(np_arr.dtype, ret_np_arr.dtype)
all_equal = np.all(np_arr == ret_np_arr)
self.assertTrue(all_equal)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册