diff --git a/cmake/external/dlpack.cmake b/cmake/external/dlpack.cmake index 87db181d953afb5bfb17d3167f1e5efac3353b79..43ffde75992266c432f602e54bad8cbc70c17f86 100644 --- a/cmake/external/dlpack.cmake +++ b/cmake/external/dlpack.cmake @@ -18,7 +18,7 @@ set(DLPACK_PREFIX_DIR ${THIRD_PARTY_PATH}/dlpack) set(DLPACK_SOURCE_DIR ${THIRD_PARTY_PATH}/dlpack/src/extern_dlpack) set(DLPACK_REPOSITORY ${GIT_URL}/dmlc/dlpack.git) -set(DLPACK_TAG v0.2) +set(DLPACK_TAG v0.4) cache_third_party(extern_dlpack REPOSITORY ${DLPACK_REPOSITORY} diff --git a/paddle/fluid/framework/dlpack_tensor.cc b/paddle/fluid/framework/dlpack_tensor.cc index f1f5ba7789ea6137800e7fcfe2d404ca2d87845b..71b53b8a51882fbb3a130737e5b80a5460bad2cb 100644 --- a/paddle/fluid/framework/dlpack_tensor.cc +++ b/paddle/fluid/framework/dlpack_tensor.cc @@ -30,14 +30,10 @@ static ::DLDataType GetDLDataTypeCode() { ::DLDataType dtype; if (std::is_same>::value || std::is_same>::value) { - // The current dlpack library version is v0.2, and does not define - // kDLComplex value. But kDLComplex is defined by 5U in v0.4, so we set - // dtype.code to 5U directly here. After the dlpack library version being - // upgraded to v0.4, it should be written as follow. - // dtype.code = kDLComplex; - dtype.code = 5U; + dtype.code = kDLComplex; + } else if (std::is_same::value) { + dtype.code = kDLBfloat; } else if (std::is_same::value || - std::is_same::value || std::is_floating_point::value) { dtype.code = kDLFloat; } else if (std::is_unsigned::value) { @@ -77,47 +73,47 @@ static DLDataType GetDLDataTypeFromTypeIndex(proto::VarType::Type type) { #undef REG_DL_DATA_TYPE } -struct DLContextVisitor : public boost::static_visitor<::DLContext> { - inline ::DLContext operator()(const platform::CPUPlace &place) const { - ::DLContext ctx; - ctx.device_type = kDLCPU; - ctx.device_id = 0; - return ctx; +struct DLDeviceVisitor : public boost::static_visitor<::DLDevice> { + inline ::DLDevice operator()(const platform::CPUPlace &place) const { + ::DLDevice device; + device.device_type = kDLCPU; + device.device_id = 0; + return device; } - inline ::DLContext operator()(const platform::XPUPlace &place) const { + inline ::DLDevice operator()(const platform::XPUPlace &place) const { PADDLE_THROW( platform::errors::Unimplemented("platform::XPUPlace is not supported")); } - inline ::DLContext operator()(const platform::NPUPlace &place) const { + inline ::DLDevice operator()(const platform::NPUPlace &place) const { PADDLE_THROW( platform::errors::Unimplemented("platform::NPUPlace is not supported")); } - inline ::DLContext operator()(const platform::NPUPinnedPlace &place) const { + inline ::DLDevice operator()(const platform::NPUPinnedPlace &place) const { PADDLE_THROW(platform::errors::Unimplemented( "platform::NPUPinnedPlace is not supported")); } - inline ::DLContext operator()(const platform::CUDAPlace &place) const { + inline ::DLDevice operator()(const platform::CUDAPlace &place) const { #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - ::DLContext ctx; - ctx.device_type = kDLGPU; - ctx.device_id = place.device; - return ctx; + ::DLDevice device; + device.device_type = kDLGPU; + device.device_id = place.device; + return device; #else PADDLE_THROW(platform::errors::Unavailable( "platform::CUDAPlace is not supported in CPU only version.")); #endif } - inline ::DLContext operator()(const platform::CUDAPinnedPlace &place) const { + inline ::DLDevice operator()(const platform::CUDAPinnedPlace &place) const { #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - ::DLContext ctx; - ctx.device_type = kDLCPUPinned; - ctx.device_id = 0; - return ctx; + ::DLDevice device; + device.device_type = kDLCPUPinned; + device.device_id = 0; + return device; #else PADDLE_THROW(platform::errors::Unavailable( "platform::CUDAPinnedPlace is not supported in CPU only version.")); @@ -130,9 +126,9 @@ DLPackTensor::DLPackTensor(const Tensor &tensor, LaneType lanes) { // init data, data buffer t_.data = const_cast(tensor.data()); - // init ctx, DLContext type with device_type and device_id + // init device, DLDevice type with device_type and device_id auto place = tensor.place(); - t_.ctx = boost::apply_visitor(internal::DLContextVisitor(), place); + t_.device = boost::apply_visitor(internal::DLDeviceVisitor(), place); // init dtype t_.dtype = internal::GetDLDataTypeFromTypeIndex(tensor.type()); @@ -156,10 +152,8 @@ DLPackTensor::DLPackTensor(const Tensor &tensor, LaneType lanes) { t_.byte_offset = 0; } -::DLManagedTensor *DLPackTensor::ToCudfCompatibleDLManagedTensor() { - // init shape, tensor dims - // for DLManagedTensor shape need to be compatible with ndim - // refer to cupy and cudf, we new int64[ndim] +::DLManagedTensor *DLPackTensor::ToDLManagedTensor() { + // init shape auto shape = new int64_t[t_.ndim]; using DimType = decltype(t_.ndim); // int for (DimType i = 0; i < t_.ndim; ++i) { @@ -167,19 +161,15 @@ DLPackTensor::DLPackTensor(const Tensor &tensor, LaneType lanes) { } t_.shape = shape; - // init strides, nullptr means the tensor is compact - // refer to cupy and cudf, the compact tensor first dim's strides need to be 1 - // and second dim's strides need to be length of rows of cudf - // cudf now only support dim=2 - PADDLE_ENFORCE_LE(t_.ndim, 2, platform::errors::InvalidArgument( - "cudf now only supports dimension is 2, " - "but received dimension is %d.", - t_.ndim)); - - if (t_.ndim > 1) - t_.strides = new int64_t[2]{1, t_.shape[1]}; - else - t_.strides = new int64_t[1]{1}; + // init strides + auto strides = new int64_t[t_.ndim]; + for (DimType i = 0; i < t_.ndim; ++i) { + strides[i] = 1; + } + for (DimType i = t_.ndim - 2; i >= 0; --i) { + strides[i] = t_.shape[i + 1] * strides[i + 1]; + } + t_.strides = strides; auto tensor = new DLManagedTensor; tensor->dl_tensor = t_; diff --git a/paddle/fluid/framework/dlpack_tensor.h b/paddle/fluid/framework/dlpack_tensor.h index e342523718b34b3e32e54d0ffd14128a43df34f7..03ed8884925ce4e39912cf916ae16466a8334062 100644 --- a/paddle/fluid/framework/dlpack_tensor.h +++ b/paddle/fluid/framework/dlpack_tensor.h @@ -36,7 +36,7 @@ class DLPackTensor { inline operator ::DLTensor&() { return t_; } - ::DLManagedTensor* ToCudfCompatibleDLManagedTensor(); + ::DLManagedTensor* ToDLManagedTensor(); private: ::DLTensor t_; diff --git a/paddle/fluid/framework/dlpack_tensor_test.cc b/paddle/fluid/framework/dlpack_tensor_test.cc index 8265d105accae0b8a009b1798a6c36053b51ab25..4e2d7bb979b617b9ef088b419a6ab48ed3c79f1d 100644 --- a/paddle/fluid/framework/dlpack_tensor_test.cc +++ b/paddle/fluid/framework/dlpack_tensor_test.cc @@ -30,7 +30,11 @@ template constexpr uint8_t GetDLDataTypeCode() { if (std::is_same>::value || std::is_same>::value) { - return static_cast(5); + return static_cast(kDLComplex); + } + + if (std::is_same::value) { + return static_cast(kDLBfloat); } return std::is_same::value || @@ -55,15 +59,15 @@ void TestMain(const platform::Place &place, uint16_t lanes) { CHECK_EQ(p, dl_tensor.data); if (platform::is_cpu_place(place)) { - CHECK_EQ(kDLCPU, dl_tensor.ctx.device_type); - CHECK_EQ(0, dl_tensor.ctx.device_id); + CHECK_EQ(kDLCPU, dl_tensor.device.device_type); + CHECK_EQ(0, dl_tensor.device.device_id); } else if (platform::is_gpu_place(place)) { - CHECK_EQ(kDLGPU, dl_tensor.ctx.device_type); + CHECK_EQ(kDLGPU, dl_tensor.device.device_type); CHECK_EQ(BOOST_GET_CONST(platform::CUDAPlace, place).device, - dl_tensor.ctx.device_id); + dl_tensor.device.device_id); } else if (platform::is_cuda_pinned_place(place)) { - CHECK_EQ(kDLCPUPinned, dl_tensor.ctx.device_type); - CHECK_EQ(0, dl_tensor.ctx.device_id); + CHECK_EQ(kDLCPUPinned, dl_tensor.device.device_type); + CHECK_EQ(0, dl_tensor.device.device_id); } else { CHECK_EQ(false, true); } @@ -83,8 +87,7 @@ void TestMain(const platform::Place &place, uint16_t lanes) { } template -void TestToCudfCompatibleDLManagedTensor(const platform::Place &place, - uint16_t lanes) { +void TestToDLManagedTensor(const platform::Place &place, uint16_t lanes) { DDim dims{6, 7}; Tensor tensor; tensor.Resize(dims); @@ -92,8 +95,7 @@ void TestToCudfCompatibleDLManagedTensor(const platform::Place &place, DLPackTensor dlpack_tensor(tensor, lanes); - ::DLManagedTensor *dl_managed_tensor = - dlpack_tensor.ToCudfCompatibleDLManagedTensor(); + ::DLManagedTensor *dl_managed_tensor = dlpack_tensor.ToDLManagedTensor(); CHECK_EQ(dl_managed_tensor->manager_ctx == nullptr, true); @@ -101,7 +103,8 @@ void TestToCudfCompatibleDLManagedTensor(const platform::Place &place, CHECK_EQ(dims[i], dl_managed_tensor->dl_tensor.shape[i]); } - CHECK_EQ(dl_managed_tensor->dl_tensor.strides[0] == 1, true); + CHECK_EQ(dl_managed_tensor->dl_tensor.strides[0] == 7, true); + CHECK_EQ(dl_managed_tensor->dl_tensor.strides[1] == 1, true); dl_managed_tensor->deleter(dl_managed_tensor); } @@ -122,7 +125,7 @@ void TestMainLoop() { for (auto &p : places) { for (auto &l : lanes) { TestMain(p, l); - TestToCudfCompatibleDLManagedTensor(p, l); + TestToDLManagedTensor(p, l); } } } diff --git a/paddle/fluid/framework/tensor_util.cc b/paddle/fluid/framework/tensor_util.cc index 15021b6267b65604e73abefbd7d8f683942218e7..ee30a82aff6ef050faf309c925ea1d8b8809140c 100644 --- a/paddle/fluid/framework/tensor_util.cc +++ b/paddle/fluid/framework/tensor_util.cc @@ -1065,6 +1065,9 @@ void* GetDstPtrByDLDataType(DLDataType type, framework::Tensor* dst, if (type.code == kDLFloat) return static_cast( dst->mutable_data(dst_place)); + if (type.code == kDLBfloat) + return static_cast( + dst->mutable_data(dst_place)); PADDLE_THROW(platform::errors::Unimplemented( "DLDataType code <%d> is illegal when DLDataType.bits is <%d>.", type.code, type.bits)); @@ -1081,6 +1084,16 @@ void* GetDstPtrByDLDataType(DLDataType type, framework::Tensor* dst, return static_cast(dst->mutable_data(dst_place)); if (type.code == kDLFloat) return static_cast(dst->mutable_data(dst_place)); + if (type.code == kDLComplex) + return static_cast( + dst->mutable_data>(dst_place)); + PADDLE_THROW(platform::errors::Unimplemented( + "DLDataType code <%d> is illegal when DLDataType.bits is <%d>.", + type.code, type.bits)); + case 128: + if (type.code == kDLComplex) + return static_cast( + dst->mutable_data>(dst_place)); PADDLE_THROW(platform::errors::Unimplemented( "DLDataType code <%d> is illegal when DLDataType.bits is <%d>.", type.code, type.bits)); @@ -1107,15 +1120,15 @@ void TensorFromDLPack(const ::DLTensor& dl_tensor, framework::Tensor* dst) { auto src_ptr = static_cast(dl_tensor.data); auto size = paddle::framework::product(vddim) * type.bits / 8; - if (dl_tensor.ctx.device_type == kDLCPU) { + if (dl_tensor.device.device_type == kDLCPU) { memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size); } #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - if (dl_tensor.ctx.device_type == kDLGPU) { + if (dl_tensor.device.device_type == kDLGPU) { platform::CUDAPlace dst_place = - platform::CUDAPlace(dl_tensor.ctx.device_id); + platform::CUDAPlace(dl_tensor.device.device_id); platform::CUDAPlace src_place = - platform::CUDAPlace(dl_tensor.ctx.device_id); + platform::CUDAPlace(dl_tensor.device.device_id); dst_ptr = GetDstPtrByDLDataType(type, dst, dst_place); auto* ctx = platform::DeviceContextPool::Instance().GetByPlace(dst_place); memory::Copy( diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index c00f529f61793f4c06fec7f0e6ee41bd5aec7733..16e42885c52fb7e61fbe8703d8064a6937587c6c 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -537,11 +537,11 @@ PYBIND11_MODULE(core_noavx, m) { DLTensor dl = dmt->dl_tensor; framework::Tensor tensor; - if (dl.ctx.device_type == kDLCPU) { + if (dl.device.device_type == kDLCPU) { paddle::framework::TensorFromDLPack(dl, &tensor); } #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - if (dl.ctx.device_type == kDLGPU) { + if (dl.device.device_type == kDLGPU) { paddle::framework::TensorFromDLPack(dl, &tensor); } #endif @@ -776,8 +776,7 @@ PYBIND11_MODULE(core_noavx, m) { .def("_to_dlpack", [](framework::Tensor &self) { DLPackTensor dlpack_tensor(self, 1); - DLManagedTensor *dmt = - dlpack_tensor.ToCudfCompatibleDLManagedTensor(); + DLManagedTensor *dmt = dlpack_tensor.ToDLManagedTensor(); auto capsule = py::capsule( static_cast(dmt), "dltensor", [](PyObject *ptr) { if (ptr) { diff --git a/python/paddle/tests/test_dlpack.py b/python/paddle/tests/test_dlpack.py index 2880901d1ad16103bb9e0cbd21e1438812e1a03e..3a3f748bb991e78fa579c8c94bb80cb190e25e02 100644 --- a/python/paddle/tests/test_dlpack.py +++ b/python/paddle/tests/test_dlpack.py @@ -22,6 +22,7 @@ import paddle.fluid.core as core class TestDLPack(unittest.TestCase): def test_dlpack_dygraph(self): + paddle.disable_static() tensor = paddle.to_tensor(np.array([1, 2, 3, 4]).astype('int')) dlpack = paddle.utils.dlpack.to_dlpack(tensor) out_from_dlpack = paddle.utils.dlpack.from_dlpack(dlpack) @@ -31,6 +32,15 @@ class TestDLPack(unittest.TestCase): np.array(out_from_dlpack), np.array([1, 2, 3, 4]).astype( 'int'))) + def test_dlpack_tensor_larger_than_2dim(self): + paddle.disable_static() + numpy_data = np.random.randn(4, 5, 6) + t = paddle.to_tensor(numpy_data) + # TODO: There may be a reference count problem of to_dlpack. + dlpack = paddle.utils.dlpack.to_dlpack(t) + out = paddle.utils.dlpack.from_dlpack(dlpack) + self.assertTrue(np.allclose(numpy_data, out.numpy())) + def test_dlpack_static(self): paddle.enable_static() tensor = fluid.create_lod_tensor( @@ -57,6 +67,37 @@ class TestDLPack(unittest.TestCase): np.array(gout_from_dlpack), np.array([[1], [2], [3], [4]]).astype('int'))) + def test_dlpack_dtype_conversion(self): + paddle.disable_static() + # DLpack does not explicitly support bool data type. + dtypes = [ + "float16", + "float32", + "float64", + "int8", + "int16", + "int32", + "int64", + "uint8", + ] + data = np.ones((2, 3, 4)) + for dtype in dtypes: + x = paddle.to_tensor(data, dtype=dtype) + dlpack = paddle.utils.dlpack.to_dlpack(x) + o = paddle.utils.dlpack.from_dlpack(dlpack) + self.assertEqual(x.dtype, o.dtype) + self.assertTrue(np.allclose(x.numpy(), o.numpy())) + + complex_dtypes = ["complex64", "complex128"] + for dtype in complex_dtypes: + x = paddle.to_tensor( + [[1 + 6j, 2 + 5j, 3 + 4j], [4 + 3j, 5 + 2j, 6 + 1j]], + dtype=dtype) + dlpack = paddle.utils.dlpack.to_dlpack(x) + o = paddle.utils.dlpack.from_dlpack(dlpack) + self.assertEqual(x.dtype, o.dtype) + self.assertTrue(np.allclose(x.numpy(), o.numpy())) + class TestRaiseError(unittest.TestCase): def test_from_dlpack_raise_type_error(self): diff --git a/python/paddle/utils/dlpack.py b/python/paddle/utils/dlpack.py index ca2a1ae0e19ec56c1e4d64022aed6992996d819f..01611be3ea56f1f18ba1e83bad3c69ed42a6e78b 100644 --- a/python/paddle/utils/dlpack.py +++ b/python/paddle/utils/dlpack.py @@ -28,7 +28,9 @@ def to_dlpack(x): Encodes a tensor to DLPack. Args: - x (Tensor): A tensor, and the data type is bool, float32, float64, int32, int64. + x (Tensor): The input tensor, and the data type can be `bool`, `float16`, `float32`, + `float64`, `int8`, `int16`, `int32`, `int64`, `uint8`, `complex64`, + `complex128`. Returns: dltensor, and the data type is PyCapsule. @@ -51,19 +53,9 @@ def to_dlpack(x): "The type of 'x' in to_dlpack must be paddle.Tensor," " but received {}.".format(type(x))) - dtype = convert_dtype(x.dtype) - - if dtype not in ['bool', 'int32', 'int64', 'float32', 'float64']: - raise TypeError( - "the dtype of 'x' in to_dlpack must be any of [bool, int32, int64, " - "float32, float64], but received {}.".format(dtype)) - return x.value().get_tensor()._to_dlpack() check_type(x, 'x', (LoDTensor), 'to_dlpack') - check_dtype(x._dtype(), 'x', - ['bool', 'int32', 'int64', 'float32', 'float64'], 'to_dlpack') - return x._to_dlpack() @@ -75,7 +67,9 @@ def from_dlpack(dlpack): dlpack (PyCapsule): a PyCapsule object with the dltensor. Returns: - out (Tensor): a tensor decoded from DLPack. + out (Tensor): a tensor decoded from DLPack. One thing to be noted, if we get + an input dltensor with data type as `bool`, we return the decoded + tensor as `uint8`. Examples: .. code-block:: python