提交 6bd41ecc 编写于 作者: D DesmonDay

fix dlpack

上级 5cedad40
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/dlpack_tensor.h" #include "paddle/fluid/framework/dlpack_tensor.h"
#include "pybind11/pybind11.h"
#include "paddle/fluid/framework/convert_utils.h" #include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/data_type.h"
...@@ -135,9 +134,60 @@ struct DLDeviceVisitor ...@@ -135,9 +134,60 @@ struct DLDeviceVisitor
}; };
} // namespace internal } // namespace internal
DLPackTensor::DLPackTensor(phi::DenseTensor &tensor, LaneType lanes) { struct PaddleDLMTensor {
// Bundles an exported DLManagedTensor with a DenseTensor copy so that both
// are released together by a single `delete` of this struct.
// Copy of the source tensor taken at export time; presumably this keeps the
// underlying buffer alive while a consumer holds the DLPack capsule
// (NOTE(review): confirm DenseTensor copies share the allocation).
phi::DenseTensor handle;
// The DLPack view handed to consumers; toDLPack() points its manager_ctx
// back at the enclosing PaddleDLMTensor.
DLManagedTensor tensor;
};
// DLPack deleter callback installed by toDLPack().
//
// Frees the heap-allocated shape and strides arrays of `arg`, then destroys
// the PaddleDLMTensor stored in `arg->manager_ctx`. toDLPack() makes `arg`
// a member of that PaddleDLMTensor, so `arg` itself must not be touched
// after the final delete.
void deleter(DLManagedTensor *arg) {
  // Fetch the owning wrapper up front; it is destroyed last.
  auto *owner = static_cast<PaddleDLMTensor *>(arg->manager_ctx);
  DLTensor &view = arg->dl_tensor;
  delete[] view.shape;
  delete[] view.strides;
  delete owner;
}
// Exports `src` as a heap-allocated DLManagedTensor.
//
// The returned pointer refers to the `tensor` member of a freshly allocated
// PaddleDLMTensor whose `handle` member is a copy of `src`. The shape and
// strides arrays are allocated with new[]; strides are filled in as for a
// densely packed row-major layout (innermost stride 1). byte_offset is 0.
//
// Ownership: the caller (or the DLPack consumer) must eventually invoke
// `result->deleter(result)`, which frees the arrays and the wrapper.
//
// @param src  tensor to export; it is only read.
// @return     pointer to the managed tensor; never null.
DLManagedTensor *toDLPack(const phi::DenseTensor &src) {
  // Run every call into project helpers before allocating anything, so a
  // failure in one of them cannot leak the wrapper or the arrays.
  // NOTE(review): these helpers are assumed to be able to throw (e.g. for an
  // uninitialized tensor or an unsupported place / dtype) -- confirm.
  void *data = const_cast<void *>(src.data());
  auto place = src.place();
  const auto device =
      paddle::platform::VisitPlace(place, internal::DLDeviceVisitor());
  const auto dtype = internal::GetDLDataTypeFromTypeIndex(
      framework::TransToProtoVarType(src.dtype()));

  using DimType = decltype(::DLTensor::ndim);  // int
  const auto &dims = src.dims();
  const DimType ndim = static_cast<DimType>(dims.size());

  // Wrapper that owns both the tensor copy and the DLPack view.
  PaddleDLMTensor *pdDLMTensor = new PaddleDLMTensor;
  pdDLMTensor->handle = src;
  pdDLMTensor->tensor.manager_ctx = pdDLMTensor;
  pdDLMTensor->tensor.deleter = &deleter;

  ::DLTensor &dl = pdDLMTensor->tensor.dl_tensor;
  dl.data = data;
  dl.device = device;
  dl.dtype = dtype;
  dl.ndim = ndim;
  dl.byte_offset = 0;

  // init shape
  auto shape = new int64_t[ndim];
  for (DimType i = 0; i < ndim; ++i) {
    shape[i] = dims[i];
  }
  dl.shape = shape;

  // init strides: innermost dimension has stride 1, each outer stride is the
  // product of all inner extents. Nothing is written when ndim == 0.
  auto strides = new int64_t[ndim];
  if (ndim > 0) {
    strides[ndim - 1] = 1;
  }
  for (DimType i = ndim - 2; i >= 0; --i) {
    strides[i] = shape[i + 1] * strides[i + 1];
  }
  dl.strides = strides;

  return &(pdDLMTensor->tensor);
}
DLPackTensor::DLPackTensor(const phi::DenseTensor &tensor, LaneType lanes) {
// init data, data buffer // init data, data buffer
dt_ = &tensor;
t_.data = const_cast<void *>(tensor.data()); t_.data = const_cast<void *>(tensor.data());
// init device, DLDevice type with device_type and device_id // init device, DLDevice type with device_type and device_id
...@@ -188,19 +238,15 @@ DLPackTensor::DLPackTensor(phi::DenseTensor &tensor, LaneType lanes) { ...@@ -188,19 +238,15 @@ DLPackTensor::DLPackTensor(phi::DenseTensor &tensor, LaneType lanes) {
auto tensor = new DLManagedTensor; auto tensor = new DLManagedTensor;
tensor->dl_tensor = t_; tensor->dl_tensor = t_;
tensor->manager_ctx = dt_;
tensor->deleter = [](DLManagedTensor *arg) { tensor->deleter = [](DLManagedTensor *arg) {
phi::DenseTensor *tensor_ptr =
reinterpret_cast<phi::DenseTensor *>(arg->manager_ctx);
pybind11::handle tensor_handle = pybind11::cast(tensor_ptr);
tensor_handle.dec_ref();
delete[] arg->dl_tensor.shape; delete[] arg->dl_tensor.shape;
delete[] arg->dl_tensor.strides; delete[] arg->dl_tensor.strides;
delete arg; delete arg;
}; };
tensor->manager_ctx = nullptr;
return tensor; return tensor;
} }
......
...@@ -28,7 +28,7 @@ class DLPackTensor { ...@@ -28,7 +28,7 @@ class DLPackTensor {
std::remove_reference<decltype(::DLTensor::shape[0])>::type; // int64_t std::remove_reference<decltype(::DLTensor::shape[0])>::type; // int64_t
// lanes is only used in CPU to enable vectorization // lanes is only used in CPU to enable vectorization
explicit DLPackTensor(phi::DenseTensor& tensor, LaneType lanes = 1); explicit DLPackTensor(const phi::DenseTensor& tensor, LaneType lanes = 1);
inline operator const ::DLTensor&() const { return t_; } inline operator const ::DLTensor&() const { return t_; }
...@@ -42,8 +42,9 @@ class DLPackTensor { ...@@ -42,8 +42,9 @@ class DLPackTensor {
// The shape in DLTensor is defined as int64_t* // The shape in DLTensor is defined as int64_t*
// Add this member to make TVMTensor init without heap allocation // Add this member to make TVMTensor init without heap allocation
ShapeType shape_[DDim::kMaxRank]; ShapeType shape_[DDim::kMaxRank];
phi::DenseTensor* dt_;
}; };
DLManagedTensor* toDLPack(const phi::DenseTensor& src);
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -473,10 +473,7 @@ void BindTensor(pybind11::module &m) { // NOLINT ...@@ -473,10 +473,7 @@ void BindTensor(pybind11::module &m) { // NOLINT
)DOC") )DOC")
.def("_to_dlpack", .def("_to_dlpack",
[](phi::DenseTensor &self) { [](phi::DenseTensor &self) {
pybind11::handle tensor_handle = pybind11::cast(&self); DLManagedTensor *dmt = framework::toDLPack(self);
tensor_handle.inc_ref();
DLPackTensor dlpack_tensor(self, 1);
DLManagedTensor *dmt = dlpack_tensor.ToDLManagedTensor();
auto capsule = pybind11::capsule( auto capsule = pybind11::capsule(
static_cast<void *>(dmt), "dltensor", [](PyObject *ptr) { static_cast<void *>(dmt), "dltensor", [](PyObject *ptr) {
if (!PyCapsule_IsValid(ptr, "dltensor")) { if (!PyCapsule_IsValid(ptr, "dltensor")) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册