From 59fec5d600d37a8b4411d180452fab66ff36d32c Mon Sep 17 00:00:00 2001 From: Siming Dai <908660116@qq.com> Date: Tue, 7 Feb 2023 19:51:35 +0800 Subject: [PATCH] [cherry-pick 2.4] Fix to_dlpack (#50138) (#50250) * Fix to_dlpack (#50138) * fix to_dlpack for loop * fix reference count * fix conflicts --- paddle/fluid/framework/dlpack_tensor.cc | 54 ++++++++++++++++++++++++- paddle/fluid/framework/dlpack_tensor.h | 4 +- paddle/fluid/pybind/tensor.cc | 23 ++++------- python/paddle/tests/test_dlpack.py | 42 ++++++++++++------- 4 files changed, 92 insertions(+), 31 deletions(-) diff --git a/paddle/fluid/framework/dlpack_tensor.cc b/paddle/fluid/framework/dlpack_tensor.cc index b7bca733b8..9bf3485066 100644 --- a/paddle/fluid/framework/dlpack_tensor.cc +++ b/paddle/fluid/framework/dlpack_tensor.cc @@ -134,7 +134,59 @@ struct DLDeviceVisitor }; } // namespace internal -DLPackTensor::DLPackTensor(const Tensor &tensor, LaneType lanes) { +struct PaddleDLMTensor { + phi::DenseTensor handle; + DLManagedTensor tensor; +}; + +void deleter(DLManagedTensor *arg) { + delete[] arg->dl_tensor.shape; + delete[] arg->dl_tensor.strides; + delete static_cast(arg->manager_ctx); +} + +DLManagedTensor *toDLPack(const phi::DenseTensor &src) { + PaddleDLMTensor *pdDLMTensor(new PaddleDLMTensor); + pdDLMTensor->handle = const_cast(src); + pdDLMTensor->tensor.manager_ctx = pdDLMTensor; + pdDLMTensor->tensor.deleter = &deleter; + pdDLMTensor->tensor.dl_tensor.data = const_cast(src.data()); + + // init ndim + using DimType = decltype(pdDLMTensor->tensor.dl_tensor.ndim); // int + pdDLMTensor->tensor.dl_tensor.ndim = static_cast(src.dims().size()); + DimType ndim = pdDLMTensor->tensor.dl_tensor.ndim; + + // init shape + auto shape = new int64_t[ndim]; + for (DimType i = 0; i < ndim; ++i) { + shape[i] = src.dims()[i]; + } + pdDLMTensor->tensor.dl_tensor.shape = shape; + + // init stride + auto strides = new int64_t[ndim]; + for (DimType i = 0; i < ndim; ++i) { + strides[i] = 1; + } + for (DimType i = ndim - 2; i >= 0; --i) { + strides[i] = shape[i + 1] * strides[i + 1]; + } + pdDLMTensor->tensor.dl_tensor.strides = strides; + + // init device, DLDevice type with device_type and device_id + auto place = src.place(); + pdDLMTensor->tensor.dl_tensor.device = + paddle::platform::VisitPlace(place, internal::DLDeviceVisitor()); + + pdDLMTensor->tensor.dl_tensor.dtype = internal::GetDLDataTypeFromTypeIndex( + framework::TransToProtoVarType(src.dtype())); + + pdDLMTensor->tensor.dl_tensor.byte_offset = 0; + return &(pdDLMTensor->tensor); +} + +DLPackTensor::DLPackTensor(const phi::DenseTensor &tensor, LaneType lanes) { // init data, data buffer t_.data = const_cast(tensor.data()); diff --git a/paddle/fluid/framework/dlpack_tensor.h b/paddle/fluid/framework/dlpack_tensor.h index ff4cf23da6..4cd6d97a0c 100644 --- a/paddle/fluid/framework/dlpack_tensor.h +++ b/paddle/fluid/framework/dlpack_tensor.h @@ -28,7 +28,7 @@ class DLPackTensor { std::remove_reference::type; // int64_t // lanes is only used in CPU to enable vectorization - explicit DLPackTensor(const Tensor& tensor, LaneType lanes = 1); + explicit DLPackTensor(const phi::DenseTensor& tensor, LaneType lanes = 1); inline operator const ::DLTensor&() const { return t_; } @@ -44,5 +44,7 @@ class DLPackTensor { ShapeType shape_[DDim::kMaxRank]; }; +DLManagedTensor* toDLPack(const phi::DenseTensor& src); + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/pybind/tensor.cc b/paddle/fluid/pybind/tensor.cc index 6cc18bf5e2..6b7baef771 100644 --- a/paddle/fluid/pybind/tensor.cc +++ b/paddle/fluid/pybind/tensor.cc @@ -472,23 +472,16 @@ void BindTensor(pybind11::module &m) { // NOLINT print(t.shape()) # [5, 30] )DOC") .def("_to_dlpack", - [](framework::Tensor &self) { - DLPackTensor dlpack_tensor(self, 1); - DLManagedTensor *dmt = dlpack_tensor.ToDLManagedTensor(); - auto capsule = py::capsule( + [](phi::DenseTensor &self) { + DLManagedTensor *dmt = framework::toDLPack(self); + auto capsule = pybind11::capsule( static_cast(dmt), "dltensor", [](PyObject *ptr) { - if (ptr) { - auto dltensor = new DLManagedTensor; - try { - dltensor = reinterpret_cast( - PyCapsule_GetPointer(ptr, "used_dltensor")); - return; - } catch (...) { - dltensor = reinterpret_cast( - PyCapsule_GetPointer(ptr, "dltensor")); - } - dltensor->deleter(dltensor); + if (!PyCapsule_IsValid(ptr, "dltensor")) { + return; } + DLManagedTensor *dmt = static_cast( + PyCapsule_GetPointer(ptr, "dltensor")); + dmt->deleter(dmt); }); return capsule; }) diff --git a/python/paddle/tests/test_dlpack.py b/python/paddle/tests/test_dlpack.py index 353dc7ebfe..61a3f327dd 100644 --- a/python/paddle/tests/test_dlpack.py +++ b/python/paddle/tests/test_dlpack.py @@ -22,7 +22,6 @@ from paddle.fluid.framework import _test_eager_guard, in_dygraph_mode class TestDLPack(unittest.TestCase): - def func_test_dlpack_dygraph(self): paddle.disable_static() tensor = paddle.to_tensor(np.array([1, 2, 3, 4]).astype('int')) @@ -30,11 +29,13 @@ class TestDLPack(unittest.TestCase): out_from_dlpack = paddle.utils.dlpack.from_dlpack(dlpack) if paddle.fluid.framework.in_dygraph_mode(): self.assertTrue( - isinstance(out_from_dlpack, paddle.fluid.core.eager.Tensor)) + isinstance(out_from_dlpack, paddle.fluid.core.eager.Tensor) + ) else: self.assertTrue(isinstance(out_from_dlpack, paddle.Tensor)) - np.testing.assert_array_equal(np.array(out_from_dlpack), - np.array([1, 2, 3, 4]).astype('int')) + np.testing.assert_array_equal( + np.array(out_from_dlpack), np.array([1, 2, 3, 4]).astype('int') + ) def test_dlpack_dygraph(self): with _test_eager_guard(): @@ -58,26 +59,32 @@ class TestDLPack(unittest.TestCase): def test_dlpack_static(self): paddle.enable_static() tensor = fluid.create_lod_tensor( - np.array([[1], [2], [3], [4]]).astype('int'), [[1, 3]], - fluid.CPUPlace()) + np.array([[1], [2], [3], [4]]).astype('int'), + [[1, 3]], + fluid.CPUPlace(), + ) dlpack = paddle.utils.dlpack.to_dlpack(tensor) out_from_dlpack = paddle.utils.dlpack.from_dlpack(dlpack) self.assertTrue(isinstance(out_from_dlpack, fluid.core.Tensor)) np.testing.assert_array_equal( np.array(out_from_dlpack), - np.array([[1], [2], [3], [4]]).astype('int')) + np.array([[1], [2], [3], [4]]).astype('int'), + ) # when build with cuda if core.is_compiled_with_cuda(): gtensor = fluid.create_lod_tensor( - np.array([[1], [2], [3], [4]]).astype('int'), [[1, 3]], - fluid.CUDAPlace(0)) + np.array([[1], [2], [3], [4]]).astype('int'), + [[1, 3]], + fluid.CUDAPlace(0), + ) gdlpack = paddle.utils.dlpack.to_dlpack(gtensor) gout_from_dlpack = paddle.utils.dlpack.from_dlpack(gdlpack) self.assertTrue(isinstance(gout_from_dlpack, fluid.core.Tensor)) np.testing.assert_array_equal( np.array(gout_from_dlpack), - np.array([[1], [2], [3], [4]]).astype('int')) + np.array([[1], [2], [3], [4]]).astype('int'), + ) def func_test_dlpack_dtype_conversion(self): paddle.disable_static() @@ -104,7 +111,8 @@ class TestDLPack(unittest.TestCase): for dtype in complex_dtypes: x = paddle.to_tensor( [[1 + 6j, 2 + 5j, 3 + 4j], [4 + 3j, 5 + 2j, 6 + 1j]], - dtype=dtype) + dtype=dtype, + ) dlpack = paddle.utils.dlpack.to_dlpack(x) o = paddle.utils.dlpack.from_dlpack(dlpack) self.assertEqual(x.dtype, o.dtype) @@ -115,12 +123,18 @@ class TestDLPack(unittest.TestCase): self.func_test_dlpack_dtype_conversion() self.func_test_dlpack_dtype_conversion() + def test_to_dlpack_for_loop(self): + # See Paddle issue 50120 + for i in range(10): + x = paddle.rand([3, 5]) + dlpack = paddle.utils.dlpack.to_dlpack(x) -class TestRaiseError(unittest.TestCase): +class TestRaiseError(unittest.TestCase): def func_test_from_dlpack_raise_type_error(self): - self.assertRaises(TypeError, paddle.utils.dlpack.from_dlpack, - np.zeros(5)) + self.assertRaises( + TypeError, paddle.utils.dlpack.from_dlpack, np.zeros(5) + ) def test_from_dlpack_raise_type_error(self): with _test_eager_guard(): -- GitLab