Commit 94cf213d authored by DesmonDay

fix reference count

Parent d7d490e4
@@ -12,6 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

 #include "paddle/fluid/framework/dlpack_tensor.h"
+#include "pybind11/pybind11.h"
 #include "paddle/fluid/framework/convert_utils.h"
 #include "paddle/fluid/framework/data_type.h"
@@ -134,8 +135,9 @@ struct DLDeviceVisitor
 };
 }  // namespace internal

-DLPackTensor::DLPackTensor(const phi::DenseTensor &tensor, LaneType lanes) {
+DLPackTensor::DLPackTensor(phi::DenseTensor &tensor, LaneType lanes) {
   // init data, data buffer
+  dt_ = &tensor;
   t_.data = const_cast<void *>(tensor.data());

   // init device, DLDevice type with device_type and device_id
@@ -188,12 +190,17 @@ DLPackTensor::DLPackTensor(const phi::DenseTensor &tensor, LaneType lanes) {
   tensor->dl_tensor = t_;

   tensor->deleter = [](DLManagedTensor *arg) {
+    phi::DenseTensor *tensor_ptr =
+        reinterpret_cast<phi::DenseTensor *>(arg->manager_ctx);
+    pybind11::handle tensor_handle = pybind11::cast(tensor_ptr);
+    tensor_handle.dec_ref();
+
     delete[] arg->dl_tensor.shape;
     delete[] arg->dl_tensor.strides;
     delete arg;
   };

-  tensor->manager_ctx = nullptr;
+  tensor->manager_ctx = dt_;

   return tensor;
}
......
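The deleter above only works because `ToDLManagedTensor` now parks the exported tensor's address in `manager_ctx`: the Python reference taken at export time can then be dropped exactly once, when the consumer invokes the deleter. Below is a minimal, self-contained C++ sketch of that ownership hand-off. It is not Paddle code; `FakeDLTensor`, `FakeDLManagedTensor`, `Owner`, and `Export` are illustrative stand-ins, and a `std::shared_ptr` plays the role of the Python reference count.

```cpp
#include <cstdint>
#include <cstdio>
#include <memory>

// Simplified stand-ins for DLTensor / DLManagedTensor; the real structs live
// in dlpack.h and carry more fields (device, dtype, ndim, ...).
struct FakeDLTensor {
  void* data = nullptr;
  int64_t* shape = nullptr;
  int64_t* strides = nullptr;
};

struct FakeDLManagedTensor {
  FakeDLTensor dl_tensor;
  void* manager_ctx = nullptr;                      // opaque owner handle
  void (*deleter)(FakeDLManagedTensor*) = nullptr;  // called once by consumer
};

// Stand-in for the exported DenseTensor whose storage must stay alive.
struct Owner {
  ~Owner() { std::puts("owner released"); }
};

FakeDLManagedTensor* Export(std::shared_ptr<Owner> owner) {
  auto* mt = new FakeDLManagedTensor;
  // Mirror of "inc_ref(); manager_ctx = dt_;": keep one strong reference on
  // the heap and park it in manager_ctx so the deleter can find it later.
  mt->manager_ctx = new std::shared_ptr<Owner>(std::move(owner));
  mt->deleter = [](FakeDLManagedTensor* arg) {
    // Mirror of "dec_ref()": drop the extra reference, then free bookkeeping.
    delete static_cast<std::shared_ptr<Owner>*>(arg->manager_ctx);
    delete arg;
  };
  return mt;
}

int main() {
  auto owner = std::make_shared<Owner>();
  FakeDLManagedTensor* mt = Export(owner);
  owner.reset();    // producer drops its reference; storage must survive
  mt->deleter(mt);  // consumer finishes: "owner released" prints only now
  return 0;
}
```

If `manager_ctx` stayed `nullptr`, as before this patch, the deleter would have nothing to release, and the tensor could be freed while a DLPack consumer still held its data pointer.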
@@ -28,7 +28,7 @@ class DLPackTensor {
       std::remove_reference<decltype(::DLTensor::shape[0])>::type;  // int64_t

   // lanes is only used in CPU to enable vectorization
-  explicit DLPackTensor(const phi::DenseTensor& tensor, LaneType lanes = 1);
+  explicit DLPackTensor(phi::DenseTensor& tensor, LaneType lanes = 1);

   inline operator const ::DLTensor&() const { return t_; }
@@ -42,6 +42,8 @@ class DLPackTensor {
   // The shape in DLTensor is defined as int64_t*
   // Add this member to make TVMTensor init without heap allocation
   ShapeType shape_[DDim::kMaxRank];
+
+  phi::DenseTensor* dt_;
 };

 }  // namespace framework
......
@@ -474,6 +474,8 @@ void BindTensor(pybind11::module &m) {  // NOLINT
       .def("_to_dlpack",
            [](phi::DenseTensor &self) {
              DLPackTensor dlpack_tensor(self, 1);
+             pybind11::handle tensor_handle = pybind11::cast(&self);
+             tensor_handle.inc_ref();
              DLManagedTensor *dmt = dlpack_tensor.ToDLManagedTensor();
              auto capsule = pybind11::capsule(
                  static_cast<void *>(dmt), "dltensor", [](PyObject *ptr) {
......
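The body of the capsule destructor is truncated above. For context, here is a hedged sketch of the standard DLPack capsule convention that producers of this kind typically follow; it is an assumption about intent, not the elided Paddle code, and `WrapAsCapsule` is a hypothetical helper. A consumer that takes ownership renames the capsule to "used_dltensor", so the producer's destructor runs the `DLManagedTensor` deleter only when the capsule was never consumed.

```cpp
#include <dlpack/dlpack.h>
#include <pybind11/pybind11.h>

namespace py = pybind11;

// Hypothetical helper showing the usual "dltensor" / "used_dltensor" dance.
py::capsule WrapAsCapsule(DLManagedTensor* dmt) {
  return py::capsule(static_cast<void*>(dmt), "dltensor", [](PyObject* ptr) {
    // If a consumer claimed the tensor, it renamed the capsule, so looking it
    // up under the producer name fails and there is nothing left to free.
    if (void* p = PyCapsule_GetPointer(ptr, "dltensor")) {
      auto* mt = static_cast<DLManagedTensor*>(p);
      if (mt->deleter != nullptr) {
        mt->deleter(mt);  // releases the reference taken by inc_ref() above
      }
    } else {
      PyErr_Clear();  // name mismatch: capsule was already consumed
    }
  });
}
```

Either way the deleter runs exactly once: immediately if the capsule is discarded unused, or later from the consuming framework once it is done with the shared storage, which is what balances the `inc_ref()` taken in `_to_dlpack`.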