diff --git a/paddle/fluid/framework/dlpack_tensor.cc b/paddle/fluid/framework/dlpack_tensor.cc index 32c6e17143fa2d5e03adcb8eaa36d6d41238fb62..9bf348506655b281804f113de014d9c113496f99 100644 --- a/paddle/fluid/framework/dlpack_tensor.cc +++ b/paddle/fluid/framework/dlpack_tensor.cc @@ -134,6 +134,58 @@ struct DLDeviceVisitor }; } // namespace internal +struct PaddleDLMTensor { + phi::DenseTensor handle; + DLManagedTensor tensor; +}; + +void deleter(DLManagedTensor *arg) { + delete[] arg->dl_tensor.shape; + delete[] arg->dl_tensor.strides; + delete static_cast(arg->manager_ctx); +} + +DLManagedTensor *toDLPack(const phi::DenseTensor &src) { + PaddleDLMTensor *pdDLMTensor(new PaddleDLMTensor); + pdDLMTensor->handle = const_cast(src); + pdDLMTensor->tensor.manager_ctx = pdDLMTensor; + pdDLMTensor->tensor.deleter = &deleter; + pdDLMTensor->tensor.dl_tensor.data = const_cast(src.data()); + + // init ndim + using DimType = decltype(pdDLMTensor->tensor.dl_tensor.ndim); // int + pdDLMTensor->tensor.dl_tensor.ndim = static_cast(src.dims().size()); + DimType ndim = pdDLMTensor->tensor.dl_tensor.ndim; + + // init shape + auto shape = new int64_t[ndim]; + for (DimType i = 0; i < ndim; ++i) { + shape[i] = src.dims()[i]; + } + pdDLMTensor->tensor.dl_tensor.shape = shape; + + // init stride + auto strides = new int64_t[ndim]; + for (DimType i = 0; i < ndim; ++i) { + strides[i] = 1; + } + for (DimType i = ndim - 2; i >= 0; --i) { + strides[i] = shape[i + 1] * strides[i + 1]; + } + pdDLMTensor->tensor.dl_tensor.strides = strides; + + // init device, DLDevice type with device_type and device_id + auto place = src.place(); + pdDLMTensor->tensor.dl_tensor.device = + paddle::platform::VisitPlace(place, internal::DLDeviceVisitor()); + + pdDLMTensor->tensor.dl_tensor.dtype = internal::GetDLDataTypeFromTypeIndex( + framework::TransToProtoVarType(src.dtype())); + + pdDLMTensor->tensor.dl_tensor.byte_offset = 0; + return &(pdDLMTensor->tensor); +} + DLPackTensor::DLPackTensor(const 
// NOTE(review): reconstructed from a mangled patch (lost newlines and all
// `<...>` template-argument lists). Three hunks from three files follow.

// --- paddle/fluid/framework/dlpack_tensor.h: export the new helper ---
DLManagedTensor* toDLPack(const phi::DenseTensor& src);

// --- paddle/fluid/pybind/tensor.cc: rewritten `_to_dlpack` binding ---
// Replaces the old DLPackTensor path (which leaked when the capsule was
// never consumed) with framework::toDLPack plus a capsule destructor that
// frees the DLManagedTensor when, and only when, it was never consumed.
.def("_to_dlpack",
     [](phi::DenseTensor &self) {
       DLManagedTensor *dmt = framework::toDLPack(self);
       auto capsule = pybind11::capsule(
           static_cast<void *>(dmt), "dltensor", [](PyObject *ptr) {
             // Per the DLPack protocol, a consumer renames the capsule to
             // "used_dltensor" when it takes ownership; the consumer then
             // calls the deleter, so we must not.
             if (!PyCapsule_IsValid(ptr, "dltensor")) {
               return;
             }
             DLManagedTensor *dmt = static_cast<DLManagedTensor *>(
                 PyCapsule_GetPointer(ptr, "dltensor"));
             dmt->deleter(dmt);
           });
       return capsule;
     })

# --- python/paddle/tests/test_dlpack.py: regression test (Python) ---
    def test_to_dlpack_for_loop(self):
        # See Paddle issue 50120: repeatedly exporting unconsumed capsules
        # must not leak the underlying tensors.
        for _ in range(10):
            x = paddle.rand([3, 5])
            dlpack = paddle.utils.dlpack.to_dlpack(x)