提交 94cf213d 编写于 作者: D DesmonDay

fix reference count

上级 d7d490e4
......@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/dlpack_tensor.h"
#include "pybind11/pybind11.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/data_type.h"
......@@ -134,8 +135,9 @@ struct DLDeviceVisitor
};
} // namespace internal
DLPackTensor::DLPackTensor(const phi::DenseTensor &tensor, LaneType lanes) {
DLPackTensor::DLPackTensor(phi::DenseTensor &tensor, LaneType lanes) {
// init data, data buffer
dt_ = &tensor;
t_.data = const_cast<void *>(tensor.data());
// init device, DLDevice type with device_type and device_id
......@@ -188,12 +190,17 @@ DLPackTensor::DLPackTensor(const phi::DenseTensor &tensor, LaneType lanes) {
tensor->dl_tensor = t_;
tensor->deleter = [](DLManagedTensor *arg) {
phi::DenseTensor *tensor_ptr =
reinterpret_cast<phi::DenseTensor *>(arg->manager_ctx);
pybind11::handle tensor_handle = pybind11::cast(tensor_ptr);
tensor_handle.dec_ref();
delete[] arg->dl_tensor.shape;
delete[] arg->dl_tensor.strides;
delete arg;
};
tensor->manager_ctx = nullptr;
tensor->manager_ctx = dt_;
return tensor;
}
......
......@@ -28,7 +28,7 @@ class DLPackTensor {
std::remove_reference<decltype(::DLTensor::shape[0])>::type; // int64_t
// lanes is only used in CPU to enable vectorization
explicit DLPackTensor(const phi::DenseTensor& tensor, LaneType lanes = 1);
explicit DLPackTensor(phi::DenseTensor& tensor, LaneType lanes = 1);
inline operator const ::DLTensor&() const { return t_; }
......@@ -42,6 +42,8 @@ class DLPackTensor {
// The shape in DLTensor is defined as int64_t*
// Add this member to make TVMTensor init without heap allocation
ShapeType shape_[DDim::kMaxRank];
phi::DenseTensor* dt_;
};
} // namespace framework
......
......@@ -474,6 +474,8 @@ void BindTensor(pybind11::module &m) { // NOLINT
.def("_to_dlpack",
[](phi::DenseTensor &self) {
DLPackTensor dlpack_tensor(self, 1);
pybind11::handle tensor_handle = pybind11::cast(&self);
tensor_handle.inc_ref();
DLManagedTensor *dmt = dlpack_tensor.ToDLManagedTensor();
auto capsule = pybind11::capsule(
static_cast<void *>(dmt), "dltensor", [](PyObject *ptr) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册