diff --git a/paddle/fluid/framework/dlpack_tensor.cc b/paddle/fluid/framework/dlpack_tensor.cc index 32c6e17143fa2d5e03adcb8eaa36d6d41238fb62..419eb0e7fe849c1890774c075fa9fcd0a55e4fab 100644 --- a/paddle/fluid/framework/dlpack_tensor.cc +++ b/paddle/fluid/framework/dlpack_tensor.cc @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. #include "paddle/fluid/framework/dlpack_tensor.h" +#include "pybind11/pybind11.h" #include "paddle/fluid/framework/convert_utils.h" #include "paddle/fluid/framework/data_type.h" @@ -134,8 +135,9 @@ struct DLDeviceVisitor }; } // namespace internal -DLPackTensor::DLPackTensor(const phi::DenseTensor &tensor, LaneType lanes) { +DLPackTensor::DLPackTensor(phi::DenseTensor &tensor, LaneType lanes) { // init data, data buffer + dt_ = &tensor; t_.data = const_cast(tensor.data()); // init device, DLDevice type with device_type and device_id @@ -188,12 +190,17 @@ DLPackTensor::DLPackTensor(const phi::DenseTensor &tensor, LaneType lanes) { tensor->dl_tensor = t_; tensor->deleter = [](DLManagedTensor *arg) { + phi::DenseTensor *tensor_ptr = + reinterpret_cast(arg->manager_ctx); + pybind11::handle tensor_handle = pybind11::cast(tensor_ptr); + tensor_handle.dec_ref(); + delete[] arg->dl_tensor.shape; delete[] arg->dl_tensor.strides; delete arg; }; - tensor->manager_ctx = nullptr; + tensor->manager_ctx = dt_; return tensor; } diff --git a/paddle/fluid/framework/dlpack_tensor.h b/paddle/fluid/framework/dlpack_tensor.h index c6fca6707fad29f48d4cc7656036ab90e0274512..2fcd2511afb8b5122ec81fc5ac0cf4d7f1867b9a 100644 --- a/paddle/fluid/framework/dlpack_tensor.h +++ b/paddle/fluid/framework/dlpack_tensor.h @@ -28,7 +28,7 @@ class DLPackTensor { std::remove_reference::type; // int64_t // lanes is only used in CPU to enable vectorization - explicit DLPackTensor(const phi::DenseTensor& tensor, LaneType lanes = 1); + explicit DLPackTensor(phi::DenseTensor& tensor, LaneType lanes = 1); inline operator const ::DLTensor&() const { return t_; } @@ -42,6 +42,8 @@ class DLPackTensor { // The shape in DLTensor is defined as int64_t* // Add this member to make TVMTensor init without heap allocation ShapeType shape_[DDim::kMaxRank]; + + phi::DenseTensor* dt_; }; } // namespace framework diff --git a/paddle/fluid/pybind/tensor.cc b/paddle/fluid/pybind/tensor.cc index 29a17d25d9dc24c29cecf4c57f27fd45c8d646c2..b58e697e69532b0083cad77b3aedba0236449e32 100644 --- a/paddle/fluid/pybind/tensor.cc +++ b/paddle/fluid/pybind/tensor.cc @@ -474,6 +474,8 @@ void BindTensor(pybind11::module &m) { // NOLINT .def("_to_dlpack", [](phi::DenseTensor &self) { DLPackTensor dlpack_tensor(self, 1); + pybind11::handle tensor_handle = pybind11::cast(&self); + tensor_handle.inc_ref(); DLManagedTensor *dmt = dlpack_tensor.ToDLManagedTensor(); auto capsule = pybind11::capsule( static_cast(dmt), "dltensor", [](PyObject *ptr) {