From 5dfe2ab9e883a9d2ea1f227730a26dc3d1a42cd2 Mon Sep 17 00:00:00 2001 From: Zeng Jinle <32832641+sneaxiy@users.noreply.github.com> Date: Tue, 30 Apr 2019 06:50:03 -0500 Subject: [PATCH] Fix mem leak when converting Tensor to numpy array (#17182) * fix mem leak when converting Tensor to numpy array test=develop * remove unused unittest,test=develop * follow comments, test=develop * fix dygraph bug,test=develop --- paddle/fluid/pybind/CMakeLists.txt | 1 - paddle/fluid/pybind/pybind.cc | 6 +- paddle/fluid/pybind/tensor_py.h | 186 ++++++++---------- paddle/fluid/pybind/tensor_py_test.cc | 44 ----- python/paddle/fluid/dygraph/layers.py | 9 +- .../tests/unittests/test_tensor_to_numpy.py | 53 +++++ 6 files changed, 144 insertions(+), 155 deletions(-) delete mode 100644 paddle/fluid/pybind/tensor_py_test.cc create mode 100644 python/paddle/fluid/tests/unittests/test_tensor_to_numpy.py diff --git a/paddle/fluid/pybind/CMakeLists.txt b/paddle/fluid/pybind/CMakeLists.txt index 900c1a0ca69..d709508a6d5 100644 --- a/paddle/fluid/pybind/CMakeLists.txt +++ b/paddle/fluid/pybind/CMakeLists.txt @@ -26,5 +26,4 @@ if(WITH_PYTHON) get_property (os_dependency_modules GLOBAL PROPERTY OS_DEPENDENCY_MODULES) target_link_libraries(paddle_pybind ${os_dependency_modules}) - cc_test(tensor_py_test SRCS tensor_py_test.cc DEPS python pybind) endif(WITH_PYTHON) diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 0827ef9f456..8545b14e71c 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -302,8 +302,7 @@ PYBIND11_MODULE(core, m) { BindImperative(&m); py::class_(m, "Tensor", py::buffer_protocol()) - .def_buffer( - [](Tensor &self) -> py::buffer_info { return CastToPyBuffer(self); }) + .def("__array__", [](Tensor &self) { return TensorToPyArray(self); }) .def("_is_initialized", [](const Tensor &self) { return self.IsInitialized(); }) .def("_get_dims", @@ -419,8 +418,7 @@ PYBIND11_MODULE(core, m) { Users should be careful about it. )DOC") - .def_buffer( - [](Tensor &self) -> py::buffer_info { return CastToPyBuffer(self); }) + .def("__array__", [](Tensor &self) { return TensorToPyArray(self); }) .def("__init__", [](LoDTensor &instance, const std::vector> &recursive_sequence_lengths) { diff --git a/paddle/fluid/pybind/tensor_py.h b/paddle/fluid/pybind/tensor_py.h index a30c7a723df..fd48f26f411 100644 --- a/paddle/fluid/pybind/tensor_py.h +++ b/paddle/fluid/pybind/tensor_py.h @@ -32,109 +32,6 @@ namespace py = pybind11; namespace paddle { namespace pybind { -namespace details { - -template -struct CastToPyBufferImpl; - -template -struct CastToPyBufferImpl { - pybind11::buffer_info operator()(const framework::Tensor &tensor) { - PADDLE_THROW("This type of tensor cannot be expose to Python"); - return pybind11::buffer_info(); - } -}; - -template -struct CastToPyBufferImpl { - using CUR_TYPE = typename std::tuple_element>::type; - pybind11::buffer_info operator()(const framework::Tensor &tensor) { - if (framework::DataTypeTrait::DataType == tensor.type()) { - auto dim_vec = framework::vectorize(tensor.dims()); - std::vector dims_outside; - std::vector strides; - dims_outside.resize(dim_vec.size()); - strides.resize(dim_vec.size()); - - size_t prod = 1; - for (size_t i = dim_vec.size(); i != 0; --i) { - dims_outside[i - 1] = (size_t)dim_vec[i - 1]; - strides[i - 1] = sizeof(CUR_TYPE) * prod; - prod *= dims_outside[i - 1]; - } - framework::Tensor dst_tensor; - bool is_gpu = paddle::platform::is_gpu_place(tensor.place()); - if (is_gpu) { -#ifdef PADDLE_WITH_CUDA - auto *src_ptr = static_cast(tensor.data()); - auto *dst_ptr = static_cast(dst_tensor.mutable_data( - tensor.dims(), platform::CPUPlace())); - - paddle::platform::GpuMemcpySync(dst_ptr, src_ptr, - sizeof(CUR_TYPE) * tensor.numel(), - cudaMemcpyDeviceToHost); -#else - PADDLE_THROW("'CUDAPlace' is not supported in CPU only device."); -#endif - } else if (paddle::platform::is_cpu_place(tensor.place())) { - dst_tensor = tensor; - } - - std::string dtype = std::type_index(typeid(CUR_TYPE)) == - std::type_index(typeid(platform::float16)) - ? std::string("e") // np.dtype('e') == np.float16 - : pybind11::format_descriptor::format(); - - if (is_gpu) { - // manually construct a py_buffer if is_gpu since gpu data is copied - // into CPU. - // TODO(yy): Is these following code memleak? - Py_buffer *py_buffer = - reinterpret_cast(malloc(sizeof(Py_buffer))); - py_buffer->format = strdup(dtype.c_str()); - py_buffer->itemsize = sizeof(CUR_TYPE); - py_buffer->ndim = framework::arity(dst_tensor.dims()); - py_buffer->len = tensor.numel(); - py_buffer->strides = reinterpret_cast( - malloc(sizeof(Py_ssize_t) * strides.size())); - for (size_t i = 0; i < strides.size(); ++i) { - py_buffer->strides[i] = strides[i]; - } - - py_buffer->shape = reinterpret_cast( - malloc(sizeof(Py_ssize_t) * tensor.dims().size())); - for (int i = 0; i < tensor.dims().size(); ++i) { - py_buffer->shape[i] = tensor.dims()[i]; - } - - py_buffer->readonly = false; - py_buffer->suboffsets = nullptr; - py_buffer->obj = nullptr; - py_buffer->buf = - malloc(static_cast(py_buffer->len * py_buffer->itemsize)); - memcpy(py_buffer->buf, dst_tensor.data(), - static_cast(py_buffer->len * py_buffer->itemsize)); - return pybind11::buffer_info(py_buffer, true); - } else { - return pybind11::buffer_info( - dst_tensor.data(), sizeof(CUR_TYPE), dtype, - (size_t)framework::arity(dst_tensor.dims()), dims_outside, strides); - } - } else { - constexpr bool less = I + 1 < std::tuple_size>::value; - return CastToPyBufferImpl()(tensor); - } - } -}; - -} // namespace details - -inline pybind11::buffer_info CastToPyBuffer(const framework::Tensor &tensor) { - auto buffer_info = - details::CastToPyBufferImpl()(tensor); - return buffer_info; -} template T TensorGetElement(const framework::Tensor &self, size_t offset) { @@ -531,5 +428,88 @@ inline void PyCUDAPinnedTensorSetFromArray( } #endif +namespace details { + +template +constexpr bool IsValidDTypeToPyArray() { + return false; +} + +#define DECLARE_VALID_DTYPE_TO_PY_ARRAY(type) \ + template <> \ + constexpr bool IsValidDTypeToPyArray() { \ + return true; \ + } + +DECLARE_VALID_DTYPE_TO_PY_ARRAY(platform::float16); +DECLARE_VALID_DTYPE_TO_PY_ARRAY(float); +DECLARE_VALID_DTYPE_TO_PY_ARRAY(double); +DECLARE_VALID_DTYPE_TO_PY_ARRAY(bool); +DECLARE_VALID_DTYPE_TO_PY_ARRAY(int8_t); +DECLARE_VALID_DTYPE_TO_PY_ARRAY(uint8_t); +DECLARE_VALID_DTYPE_TO_PY_ARRAY(int); +DECLARE_VALID_DTYPE_TO_PY_ARRAY(int64_t); + +inline std::string TensorDTypeToPyDTypeStr( + framework::proto::VarType::Type type) { +#define TENSOR_DTYPE_TO_PY_DTYPE(T, proto_type) \ + if (type == proto_type) { \ + if (std::is_same::value) { \ + return "e"; \ + } else { \ + PADDLE_ENFORCE(IsValidDTypeToPyArray, \ + "This type of tensor cannot be expose to Python"); \ + return py::format_descriptor::format(); \ + } \ + } + + _ForEachDataType_(TENSOR_DTYPE_TO_PY_DTYPE); +#undef TENSOR_DTYPE_TO_PY_DTYPE + PADDLE_THROW("Unsupported data type %d", static_cast(type)); +} + +} // namespace details + +inline py::array TensorToPyArray(const framework::Tensor &tensor) { + bool is_gpu_tensor = platform::is_gpu_place(tensor.place()); + const auto &tensor_dims = tensor.dims(); + auto tensor_dtype = tensor.type(); + size_t sizeof_dtype = framework::SizeOfType(tensor_dtype); + + std::vector py_dims(tensor_dims.size()); + std::vector py_strides(tensor_dims.size()); + + size_t numel = 1; + for (int i = tensor_dims.size() - 1; i >= 0; --i) { + py_dims[i] = (size_t)tensor_dims[i]; + py_strides[i] = sizeof_dtype * numel; + numel *= py_dims[i]; + } + + const void *tensor_buf_ptr = tensor.data(); + + std::string py_dtype_str = details::TensorDTypeToPyDTypeStr(tensor.type()); + + if (!is_gpu_tensor) { + return py::array(py::buffer_info( + const_cast(tensor_buf_ptr), sizeof_dtype, py_dtype_str, + static_cast(tensor.dims().size()), py_dims, py_strides)); + } + +#ifdef PADDLE_WITH_CUDA + py::array py_arr(py::dtype(py_dtype_str.c_str()), py_dims, py_strides); + PADDLE_ENFORCE(py_arr.writeable() && py_arr.owndata(), + "PyArray must be writable and own data, otherwise memory leak " + "or double free would occur"); + + size_t copy_bytes = sizeof_dtype * numel; + paddle::platform::GpuMemcpySync(py_arr.mutable_data(), tensor_buf_ptr, + copy_bytes, cudaMemcpyDeviceToHost); + return py_arr; +#else + PADDLE_THROW("CUDAPlace is not supported when not compiled with CUDA"); +#endif +} + } // namespace pybind } // namespace paddle diff --git a/paddle/fluid/pybind/tensor_py_test.cc b/paddle/fluid/pybind/tensor_py_test.cc deleted file mode 100644 index 1a0ae1d6583..00000000000 --- a/paddle/fluid/pybind/tensor_py_test.cc +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/pybind/tensor_py.h" - -#include - -#include "gtest/gtest.h" -#include "paddle/fluid/framework/tensor.h" - -TEST(TensorPy, CastToPyBufferImpl) { - typedef int ElemType; - - paddle::framework::Tensor t; - auto d = paddle::framework::make_ddim({1, 2, 3}); - int* p = t.mutable_data(d, paddle::platform::CPUPlace()); - for (int i = 0; i < paddle::framework::product(d); ++i) { - p[i] = i; - } - - pybind11::buffer_info bi = paddle::pybind::CastToPyBuffer(t); - EXPECT_EQ(bi.itemsize, static_cast(sizeof(ElemType))); - EXPECT_EQ(bi.size, static_cast(paddle::framework::product(d))); - EXPECT_EQ(bi.ndim, static_cast(3)); // 3-dimensional as d. - EXPECT_EQ(bi.shape.size(), 3U); // as Dim d. - EXPECT_EQ(bi.shape[0], static_cast(1)); - EXPECT_EQ(bi.shape[1], static_cast(2)); - EXPECT_EQ(bi.shape[2], static_cast(3)); - EXPECT_EQ(bi.strides.size(), 3U); // 3-dimensional as d. - EXPECT_EQ(bi.strides[2], static_cast(sizeof(ElemType))); - EXPECT_EQ(bi.strides[1], static_cast(sizeof(ElemType) * 3)); - EXPECT_EQ(bi.strides[0], static_cast(sizeof(ElemType) * 2 * 3)); -} diff --git a/python/paddle/fluid/dygraph/layers.py b/python/paddle/fluid/dygraph/layers.py index f856bdfa7e1..7ddf94146c7 100644 --- a/python/paddle/fluid/dygraph/layers.py +++ b/python/paddle/fluid/dygraph/layers.py @@ -291,9 +291,12 @@ class PyLayer(core.PyLayer): inputs = [inputs] ret = [] for inp in inputs: - tensor = core.LoDTensor() - tensor.set(inp, core.CPUPlace()) - ret.append(tensor) + if isinstance(inp, core.LoDTensor): + ret.append(inp) + else: + tensor = core.LoDTensor() + tensor.set(inp, core.CPUPlace()) + ret.append(tensor) return tuple(ret) @staticmethod diff --git a/python/paddle/fluid/tests/unittests/test_tensor_to_numpy.py b/python/paddle/fluid/tests/unittests/test_tensor_to_numpy.py new file mode 100644 index 00000000000..003f27652ef --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_tensor_to_numpy.py @@ -0,0 +1,53 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle.fluid as fluid +import unittest +import numpy as np +import six + + +class TensorToNumpyTest(unittest.TestCase): + def setUp(self): + self.shape = [11, 25, 32, 43] + + def test_main(self): + dtypes = [ + 'float32', 'float64', 'int32', 'int64', 'uint8', 'int8', 'bool' + ] + + places = [fluid.CPUPlace()] + if fluid.core.is_compiled_with_cuda(): + places.append(fluid.CUDAPlace(0)) + places.append(fluid.CUDAPinnedPlace()) + + for p in places: + for dtype in dtypes: + np_arr = np.reshape( + np.array(six.moves.range(np.prod(self.shape))).astype( + dtype), self.shape) + + t = fluid.LoDTensor() + t.set(np_arr, p) + + ret_np_arr = np.array(t) + self.assertEqual(np_arr.shape, ret_np_arr.shape) + self.assertEqual(np_arr.dtype, ret_np_arr.dtype) + + all_equal = np.all(np_arr == ret_np_arr) + self.assertTrue(all_equal) + + +if __name__ == '__main__': + unittest.main() -- GitLab