From 5da8db6116a3767377cc475d108ab5fbe023f7b2 Mon Sep 17 00:00:00 2001
From: 633WHU
Date: Mon, 14 Oct 2019 19:36:43 +0800
Subject: [PATCH] support convert tensor to cudf depends on dlpack
 test=release/1.6 (#20611)

---
 cmake/inference_lib.cmake                    |  9 ++-
 paddle/fluid/framework/CMakeLists.txt        |  4 +-
 paddle/fluid/framework/dlpack_tensor.cc      | 40 ++++++++++-
 paddle/fluid/framework/dlpack_tensor.h       |  2 +
 paddle/fluid/framework/dlpack_tensor_test.cc | 25 +++++++
 paddle/fluid/framework/tensor_util.cc        | 76 ++++++++++++++++++++
 paddle/fluid/framework/tensor_util.h         |  4 ++
 paddle/fluid/framework/tensor_util_test.cc   | 63 ++++++++++++++++
 paddle/fluid/platform/CMakeLists.txt         |  2 +-
 paddle/fluid/pybind/CMakeLists.txt           |  2 +-
 paddle/fluid/pybind/pybind.cc                | 40 +++++++++++
 paddle/fluid/train/demo/CMakeLists.txt       |  1 +
 python/paddle/fluid/tests/test_lod_tensor.py | 24 +++++++
 13 files changed, 285 insertions(+), 7 deletions(-)

diff --git a/cmake/inference_lib.cmake b/cmake/inference_lib.cmake
index a1a9dbbbd8..9dd65438d5 100644
--- a/cmake/inference_lib.cmake
+++ b/cmake/inference_lib.cmake
@@ -51,7 +51,7 @@ function(copy TARGET)
 endfunction()
 
 # third party
-set(third_party_deps eigen3 gflags glog boost xxhash zlib)
+set(third_party_deps eigen3 gflags glog boost xxhash zlib dlpack)
 if(NOT PROTOBUF_FOUND OR WIN32)
     list(APPEND third_party_deps extern_protobuf)
 endif ()
@@ -86,6 +86,11 @@ copy(inference_lib_dist
         SRCS ${BOOST_INCLUDE_DIR}/boost
         DSTS ${dst_dir})
 
+set(dst_dir "${FLUID_INFERENCE_INSTALL_DIR}/third_party/dlpack")
+copy(inference_lib_dist
+        SRCS ${DLPACK_INCLUDE_DIR}/dlpack
+        DSTS ${dst_dir})
+
 if(WITH_MKLML)
     set(dst_dir "${FLUID_INFERENCE_INSTALL_DIR}/third_party/install/mklml")
     if(WIN32)
@@ -258,4 +263,4 @@ function(version version_file)
     endif ()
 endfunction()
 version(${FLUID_INSTALL_DIR}/version.txt)
-version(${FLUID_INFERENCE_INSTALL_DIR}/version.txt)
\ No newline at end of file
+version(${FLUID_INFERENCE_INSTALL_DIR}/version.txt)
diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt
index c7fa68c26a..49be7f0e23 100644
--- a/paddle/fluid/framework/CMakeLists.txt
+++ b/paddle/fluid/framework/CMakeLists.txt
@@ -51,9 +51,9 @@ endif()
 
 cc_test(tensor_test SRCS tensor_test.cc DEPS tensor)
 if(WITH_GPU)
-  nv_test(tensor_util_test SRCS tensor_util_test.cc tensor_util_test.cu DEPS tensor)
+  nv_test(tensor_util_test SRCS tensor_util_test.cc tensor_util_test.cu DEPS tensor dlpack_tensor)
 else()
-  cc_test(tensor_util_test SRCS tensor_util_test.cc DEPS tensor)
+  cc_test(tensor_util_test SRCS tensor_util_test.cc DEPS tensor dlpack_tensor)
 endif()
 
 cc_test(eigen_test SRCS eigen_test.cc DEPS tensor)
diff --git a/paddle/fluid/framework/dlpack_tensor.cc b/paddle/fluid/framework/dlpack_tensor.cc
index 39652706c4..e15baa7c65 100644
--- a/paddle/fluid/framework/dlpack_tensor.cc
+++ b/paddle/fluid/framework/dlpack_tensor.cc
@@ -12,8 +12,10 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "paddle/fluid/framework/dlpack_tensor.h"
+#include
+
 #include "paddle/fluid/framework/data_type.h"
+#include "paddle/fluid/framework/dlpack_tensor.h"
 
 namespace paddle {
 namespace framework {
@@ -120,5 +122,41 @@ DLPackTensor::DLPackTensor(const Tensor &tensor, LaneType lanes) {
   t_.byte_offset = 0;
 }
 
+::DLManagedTensor *DLPackTensor::ToCudfCompatibleDLManagedTensor() {
+  // init shape, tensor dims
+  // for DLManagedTensor shape need to be compatible with ndim
+  // refer to cupy and cudf, we new int64[ndim]
+  auto shape = new int64_t[t_.ndim];
+  using DimType = decltype(t_.ndim);  // int
+  for (DimType i = 0; i < t_.ndim; ++i) {
+    shape[i] = t_.shape[i];
+  }
+  t_.shape = shape;
+
+  // init strides, nullptr means the tensor is compact
+  // refer to cupy and cudf, the compact tensor first dim's strides need to be 1
+  // and second dim's strides need to be length of rows of cudf
+  // cudf now only support dim=2
+  PADDLE_ENFORCE_LE(t_.ndim, 2, "cudf now only support dim=2.");
+
+  if (t_.ndim > 1)
+    t_.strides = new int64_t[2]{1, t_.shape[1]};
+  else
+    t_.strides = new int64_t[1]{1};
+
+  auto tensor = new DLManagedTensor;
+  tensor->dl_tensor = t_;
+
+  tensor->deleter = [](DLManagedTensor *arg) {
+    delete[] arg->dl_tensor.shape;
+    delete[] arg->dl_tensor.strides;
+    delete arg;
+  };
+
+  tensor->manager_ctx = nullptr;
+
+  return tensor;
+}
+
 }  // namespace framework
 }  // namespace paddle
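Note: two DLPack details are worth calling out for the hunk above. Strides are counted in elements, not bytes, and a null strides pointer would mean a compact row-major tensor; the explicit strides {1, shape[1]} written here make each column contiguous, which appears to be the column-oriented layout cudf expects. A minimal sketch of how a consumer would locate element (i, j) from this metadata (ElementOffset is a hypothetical helper, not part of the patch):

    #include <dlpack/dlpack.h>

    // Offset (in elements, not bytes) of element (i, j) of a 2-D DLTensor,
    // using the strides written by ToCudfCompatibleDLManagedTensor.
    inline int64_t ElementOffset(const DLTensor &t, int64_t i, int64_t j) {
      return i * t.strides[0] + j * t.strides[1];
    }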
-#include "paddle/fluid/framework/dlpack_tensor.h" +#include + #include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/framework/dlpack_tensor.h" namespace paddle { namespace framework { @@ -120,5 +122,41 @@ DLPackTensor::DLPackTensor(const Tensor &tensor, LaneType lanes) { t_.byte_offset = 0; } +::DLManagedTensor *DLPackTensor::ToCudfCompatibleDLManagedTensor() { + // init shape, tensor dims + // for DLManagedTensor shape need to be compatible with ndim + // refer to cupy and cudf, we new int64[ndim] + auto shape = new int64_t[t_.ndim]; + using DimType = decltype(t_.ndim); // int + for (DimType i = 0; i < t_.ndim; ++i) { + shape[i] = t_.shape[i]; + } + t_.shape = shape; + + // init strides, nullptr means the tensor is compact + // refer to cupy and cudf, the compact tensor first dim's strides need to be 1 + // and second dim's strides need to be length of rows of cudf + // cudf now only support dim=2 + PADDLE_ENFORCE_LE(t_.ndim, 2, "cudf now only support dim=2."); + + if (t_.ndim > 1) + t_.strides = new int64_t[2]{1, t_.shape[1]}; + else + t_.strides = new int64_t[1]{1}; + + auto tensor = new DLManagedTensor; + tensor->dl_tensor = t_; + + tensor->deleter = [](DLManagedTensor *arg) { + delete[] arg->dl_tensor.shape; + delete[] arg->dl_tensor.strides; + delete arg; + }; + + tensor->manager_ctx = nullptr; + + return tensor; +} + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/dlpack_tensor.h b/paddle/fluid/framework/dlpack_tensor.h index e48b0d5c88..5346ba6289 100644 --- a/paddle/fluid/framework/dlpack_tensor.h +++ b/paddle/fluid/framework/dlpack_tensor.h @@ -33,6 +33,8 @@ class DLPackTensor { inline operator ::DLTensor&() { return t_; } + ::DLManagedTensor* ToCudfCompatibleDLManagedTensor(); + private: ::DLTensor t_; diff --git a/paddle/fluid/framework/dlpack_tensor_test.cc b/paddle/fluid/framework/dlpack_tensor_test.cc index c0a8e1bcdf..35afc7e5f4 100644 --- a/paddle/fluid/framework/dlpack_tensor_test.cc +++ b/paddle/fluid/framework/dlpack_tensor_test.cc @@ -72,6 +72,30 @@ void TestMain(const platform::Place &place, uint16_t lanes) { CHECK_EQ(GetDLDataTypeCode(), dl_tensor.dtype.code); } +template +void TestToCudfCompatibleDLManagedTensor(const platform::Place &place, + uint16_t lanes) { + DDim dims{6, 7}; + Tensor tensor; + tensor.Resize(dims); + tensor.mutable_data(place); + + DLPackTensor dlpack_tensor(tensor, lanes); + + ::DLManagedTensor *dl_managed_tensor = + dlpack_tensor.ToCudfCompatibleDLManagedTensor(); + + CHECK_EQ(dl_managed_tensor->manager_ctx == nullptr, true); + + for (auto i = 0; i < dims.size(); ++i) { + CHECK_EQ(dims[i], dl_managed_tensor->dl_tensor.shape[i]); + } + + CHECK_EQ(dl_managed_tensor->dl_tensor.strides[0] == 1, true); + + dl_managed_tensor->deleter(dl_managed_tensor); +} + template void TestMainLoop() { #ifdef PADDLE_WITH_CUDA @@ -88,6 +112,7 @@ void TestMainLoop() { for (auto &p : places) { for (auto &l : lanes) { TestMain(p, l); + TestToCudfCompatibleDLManagedTensor(p, l); } } } diff --git a/paddle/fluid/framework/tensor_util.cc b/paddle/fluid/framework/tensor_util.cc index fb6cc1f210..812612580c 100644 --- a/paddle/fluid/framework/tensor_util.cc +++ b/paddle/fluid/framework/tensor_util.cc @@ -495,6 +495,82 @@ void TensorFromStream(std::istream& is, Tensor* tensor, } } +// get tensor data point by DLDataType +void* GetDstPtrByDLDataType(DLDataType type, framework::Tensor* dst, + const platform::Place& dst_place) { + // vector types not currently supported + PADDLE_ENFORCE_LE(type.lanes, 1, "vector types not 
diff --git a/paddle/fluid/framework/tensor_util.cc b/paddle/fluid/framework/tensor_util.cc
index fb6cc1f210..812612580c 100644
--- a/paddle/fluid/framework/tensor_util.cc
+++ b/paddle/fluid/framework/tensor_util.cc
@@ -495,6 +495,82 @@ void TensorFromStream(std::istream& is, Tensor* tensor,
   }
 }
 
+// get tensor data point by DLDataType
+void* GetDstPtrByDLDataType(DLDataType type, framework::Tensor* dst,
+                            const platform::Place& dst_place) {
+  // vector types not currently supported
+  PADDLE_ENFORCE_LE(type.lanes, 1, "vector types not currently supported");
+
+  switch (type.bits) {
+    case 8:
+      if (type.code == kDLInt)
+        return static_cast<void*>(dst->mutable_data<int8_t>(dst_place));
+      if (type.code == kDLUInt)
+        return static_cast<void*>(dst->mutable_data<uint8_t>(dst_place));
+      PADDLE_THROW("There is no this type.code <%d> when type.bits is <%d>.",
+                   type.code, type.bits);
+    case 16:
+      if (type.code == kDLInt)
+        return static_cast<void*>(dst->mutable_data<int16_t>(dst_place));
+      if (type.code == kDLFloat)
+        return static_cast<void*>(
+            dst->mutable_data<paddle::platform::float16>(dst_place));
+      PADDLE_THROW("There is no this type.code <%d> when type.bits is <%d>.",
+                   type.code, type.bits);
+    case 32:
+      if (type.code == kDLInt)
+        return static_cast<void*>(dst->mutable_data<int32_t>(dst_place));
+      if (type.code == kDLFloat)
+        return static_cast<void*>(dst->mutable_data<float>(dst_place));
+      PADDLE_THROW("There is no this type.code <%d> when type.bits is <%d>.",
+                   type.code, type.bits);
+    case 64:
+      if (type.code == kDLInt)
+        return static_cast<void*>(dst->mutable_data<int64_t>(dst_place));
+      if (type.code == kDLFloat)
+        return static_cast<void*>(dst->mutable_data<double>(dst_place));
+      PADDLE_THROW("There is no this type.code <%d> when type.bits is <%d>.",
+                   type.code, type.bits);
+    default:
+      PADDLE_THROW("Unsupport type.bits %d", type.bits);
+  }
+}
+
+void TensorFromDLPack(const ::DLTensor& dl_tensor, framework::Tensor* dst) {
+  platform::CPUPlace dst_place = platform::CPUPlace();
+  platform::CPUPlace src_place = platform::CPUPlace();
+
+  std::vector<int64_t> vec;
+  std::copy(dl_tensor.shape, dl_tensor.shape + dl_tensor.ndim,
+            std::back_inserter(vec));
+
+  framework::DDim vddim = framework::make_ddim(vec);
+
+  dst->Resize(vddim);
+  ::DLDataType type = dl_tensor.dtype;
+  void* dst_ptr = GetDstPtrByDLDataType(type, dst, dst_place);
+
+  auto src_ptr = static_cast<const void*>(dl_tensor.data);
+  auto size = paddle::framework::product(vddim) * type.bits / 8;
+
+  if (dl_tensor.ctx.device_type == kDLCPU) {
+    memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size);
+  }
+#ifdef PADDLE_WITH_CUDA
+  if (dl_tensor.ctx.device_type == kDLGPU) {
+    platform::CUDAPlace dst_place =
+        platform::CUDAPlace(dl_tensor.ctx.device_id);
+    platform::CUDAPlace src_place =
+        platform::CUDAPlace(dl_tensor.ctx.device_id);
+    dst_ptr = GetDstPtrByDLDataType(type, dst, dst_place);
+    auto* ctx = platform::DeviceContextPool::Instance().GetByPlace(dst_place);
+    memory::Copy(
+        dst_place, dst_ptr, src_place, src_ptr, size,
+        reinterpret_cast<const platform::CUDADeviceContext&>(*ctx).stream());
+  }
+#endif
+}
+
 template <typename T>
 std::ostream& print_tensor(std::ostream& os, const framework::Tensor& tensor) {
   auto inspect = tensor.data<T>();
diff --git a/paddle/fluid/framework/tensor_util.h b/paddle/fluid/framework/tensor_util.h
index cab72e294f..dd535dfb6b 100644
--- a/paddle/fluid/framework/tensor_util.h
+++ b/paddle/fluid/framework/tensor_util.h
@@ -15,6 +15,7 @@ limitations under the License. */
 #pragma once
 #include <vector>
 #include "paddle/fluid/framework/data_type.h"
+#include "paddle/fluid/framework/dlpack_tensor.h"
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/framework.pb.h"
 #include "paddle/fluid/framework/tensor.h"
@@ -72,6 +73,9 @@ void TensorToStream(std::ostream& os, const Tensor& tensor,
 void TensorFromStream(std::istream& is, Tensor* tensor,
                       const platform::DeviceContext& dev_ctx);
 
+// convert dlpack's DLTensor to tensor
+void TensorFromDLPack(const ::DLTensor& dl_tensor, framework::Tensor* dst);
+
 //
 // The implementation of template functions.
 //
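Note: GetDstPtrByDLDataType dispatches on the (type.code, type.bits) pair -- kDLInt maps to int8_t/int16_t/int32_t/int64_t, kDLUInt to uint8_t, and kDLFloat to platform::float16/float/double -- and TensorFromDLPack then copies product(shape) * bits / 8 bytes. A self-contained sketch of that size computation, assuming lanes == 1 as the enforce above guarantees:

    #include <dlpack/dlpack.h>

    // Bytes copied by TensorFromDLPack for a given DLTensor; mirrors
    // "product(vddim) * type.bits / 8" in the patch (lanes assumed to be 1).
    inline int64_t DLTensorBytes(const DLTensor &t) {
      int64_t numel = 1;
      for (int i = 0; i < t.ndim; ++i) numel *= t.shape[i];
      return numel * t.dtype.bits / 8;
    }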
diff --git a/paddle/fluid/framework/tensor_util_test.cc b/paddle/fluid/framework/tensor_util_test.cc
index 17c5537817..bf9dabcd23 100644
--- a/paddle/fluid/framework/tensor_util_test.cc
+++ b/paddle/fluid/framework/tensor_util_test.cc
@@ -242,6 +242,69 @@ TEST(TensorToVector, Tensor) {
 #endif
 }
 
+TEST(TensorFromDLPack, Tensor) {
+  {
+    std::vector<int> src_vec = {1, 2, 3, 4, 5, 6, 7, 8, 9};
+    paddle::framework::Tensor cpu_tensor;
+
+    cpu_tensor.Resize(paddle::framework::make_ddim({3, 3}));
+    paddle::platform::CPUPlace cpu_place;
+    paddle::platform::CPUDeviceContext cpu_ctx(cpu_place);
+    paddle::framework::TensorFromVector<int>(src_vec, cpu_ctx, &cpu_tensor);
+
+    paddle::framework::DLPackTensor dlpack_tensor(cpu_tensor, 1);
+
+    paddle::framework::Tensor dst_tensor;
+    paddle::framework::TensorFromDLPack(dlpack_tensor, &dst_tensor);
+
+    auto cpu_ptr = cpu_tensor.data<int>();
+    auto src_ptr = dst_tensor.data<int>();
+    EXPECT_NE(src_ptr, cpu_ptr);
+    for (size_t i = 0; i < 9; ++i) {
+      EXPECT_EQ(src_ptr[i], cpu_ptr[i]);
+    }
+  }
+
+#ifdef PADDLE_WITH_CUDA
+  {
+    std::vector<int> src_vec = {1, 2, 3, 4, 5, 6, 7, 8, 9};
+    paddle::framework::Tensor cpu_tensor;
+    paddle::framework::Tensor gpu_tensor;
+    paddle::framework::Tensor dst_tensor;
+    paddle::framework::Tensor gpu_tensor_from_dlpack;
+
+    // Copy to CPU Tensor
+    cpu_tensor.Resize(make_ddim({3, 3}));
+    paddle::platform::CPUPlace cpu_place;
+    paddle::platform::CPUDeviceContext cpu_ctx(cpu_place);
+    paddle::framework::TensorFromVector<int>(src_vec, cpu_ctx, &cpu_tensor);
+
+    // Copy to GPUTensor
+    gpu_tensor.Resize(paddle::framework::make_ddim({3, 3}));
+    paddle::platform::CUDAPlace gpu_place;
+    paddle::platform::CUDADeviceContext gpu_ctx(gpu_place);
+    paddle::framework::TensorFromVector<int>(src_vec, gpu_ctx, &gpu_tensor);
+
+    paddle::framework::DLPackTensor dlpack_tensor(gpu_tensor, 1);
+    paddle::framework::TensorFromDLPack(dlpack_tensor, &gpu_tensor_from_dlpack);
+    // Copy from GPU to CPU tensor for comparison
+    paddle::framework::TensorCopy(gpu_tensor_from_dlpack, cpu_place, gpu_ctx,
+                                  &dst_tensor);
+    // Sync before Compare Tensors
+    gpu_ctx.Wait();
+    const int* src_ptr = src_vec.data();
+    const int* cpu_ptr = cpu_tensor.data<int>();
+    const int* dst_ptr = dst_tensor.data<int>();
+    EXPECT_NE(src_ptr, cpu_ptr);
+    EXPECT_NE(src_ptr, dst_ptr);
+    for (size_t i = 0; i < 9; ++i) {
+      EXPECT_EQ(src_ptr[i], cpu_ptr[i]);
+      EXPECT_EQ(src_ptr[i], dst_ptr[i]);
+    }
+  }
+#endif
+}
+
 TEST(TensorContainsNAN, CPU) {
   {
     paddle::framework::Tensor src;
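Note: taken together, the tests above pin down the round trip this patch enables: the destination tensor is freshly allocated and filled by copy (hence the EXPECT_NE on the data pointers), never aliased to the DLPack buffer. A condensed sketch of that flow outside gtest, assuming a CPU build (names are illustrative):

    #include "paddle/fluid/framework/dlpack_tensor.h"
    #include "paddle/fluid/framework/tensor_util.h"

    using paddle::framework::DLPackTensor;
    using paddle::framework::Tensor;
    using paddle::framework::TensorFromDLPack;

    Tensor RoundTrip(const Tensor &src) {
      // Wrap the Paddle tensor as DLPack metadata with lanes = 1.
      DLPackTensor dlpack_tensor(src, 1);
      // DLPackTensor converts implicitly to ::DLTensor&; the data is
      // copied into dst, not aliased.
      Tensor dst;
      TensorFromDLPack(dlpack_tensor, &dst);
      return dst;
    }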
diff --git a/paddle/fluid/platform/CMakeLists.txt b/paddle/fluid/platform/CMakeLists.txt
index a84f521f58..001fcfea42 100644
--- a/paddle/fluid/platform/CMakeLists.txt
+++ b/paddle/fluid/platform/CMakeLists.txt
@@ -72,7 +72,7 @@ ENDIF()
 
 # avoiding cycle dependencies
 cc_library(device_context SRCS device_context.cc init.cc DEPS simple_threadpool malloc xxhash ${STREAM_CALLBACK_DEPS}
-    place eigen3 stringpiece cpu_helper cpu_info framework_proto ${GPU_CTX_DEPS} ${MKLDNN_CTX_DEPS} ${dgc_deps})
+    place eigen3 stringpiece cpu_helper cpu_info framework_proto ${GPU_CTX_DEPS} ${MKLDNN_CTX_DEPS} ${dgc_deps} dlpack)
 
 if (WITH_DISTRIBUTE)
   cc_library(collective_helper SRCS collective_helper.cc DEPS framework_proto device_context enforce)
diff --git a/paddle/fluid/pybind/CMakeLists.txt b/paddle/fluid/pybind/CMakeLists.txt
index bd78027f4b..9563e7b6fe 100644
--- a/paddle/fluid/pybind/CMakeLists.txt
+++ b/paddle/fluid/pybind/CMakeLists.txt
@@ -1,6 +1,6 @@
 set(PYBIND_DEPS pybind python proto_desc memory executor fleet_wrapper box_wrapper nccl_wrapper prune
   feed_fetch_method pass_builder parallel_executor profiler layer tracer engine scope_pool
-  analysis_predictor imperative_profiler nccl_context imperative_flag save_load_util)
+  analysis_predictor imperative_profiler nccl_context imperative_flag save_load_util dlpack_tensor)
 
 if(WITH_PYTHON)
   list(APPEND PYBIND_DEPS py_func_op)
diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc
index 9858b6c8a1..cb6a77f29f 100644
--- a/paddle/fluid/pybind/pybind.cc
+++ b/paddle/fluid/pybind/pybind.cc
@@ -256,6 +256,24 @@ PYBIND11_MODULE(core_noavx, m) {
 
   BindException(&m);
 
+  m.def("from_dlpack", [](py::capsule *dltensor) {
+    DLManagedTensor *dmt = reinterpret_cast<DLManagedTensor *>(
+        PyCapsule_GetPointer(dltensor->ptr(), "dltensor"));
+    PyCapsule_SetName(dltensor->ptr(), "used_dltensor");
+    DLTensor dl = dmt->dl_tensor;
+    Tensor tensor;
+
+    if (dl.ctx.device_type == kDLCPU) {
+      paddle::framework::TensorFromDLPack(dl, &tensor);
+    }
+#ifdef PADDLE_WITH_CUDA
+    if (dl.ctx.device_type == kDLGPU) {
+      paddle::framework::TensorFromDLPack(dl, &tensor);
+    }
+#endif
+    return tensor;
+  });
+
   m.def("set_num_threads", &platform::SetNumThreads);
 
   m.def("_save_static_dict",
@@ -467,6 +485,28 @@
           t.set(np.ndarray([5, 30]), fluid.CPUPlace())
           print(t.shape())  # [5, 30]
        )DOC")
+      .def("_to_dlpack",
+           [](Tensor &self) {
+             DLPackTensor dlpack_tensor(self, 1);
+             DLManagedTensor *dmt =
+                 dlpack_tensor.ToCudfCompatibleDLManagedTensor();
+             auto capsule = py::capsule(
+                 static_cast<void *>(dmt), "dltensor", [](PyObject *ptr) {
+                   if (ptr) {
+                     auto dltensor = new DLManagedTensor;
+                     try {
+                       dltensor = reinterpret_cast<DLManagedTensor *>(
+                           PyCapsule_GetPointer(ptr, "used_dltensor"));
+                       return;
+                     } catch (...) {
+                       dltensor = reinterpret_cast<DLManagedTensor *>(
+                           PyCapsule_GetPointer(ptr, "dltensor"));
+                     }
+                     dltensor->deleter(dltensor);
+                   }
+                 });
+             return capsule;
+           })
       .def("_set_float_element", TensorSetElement<float>)
      .def("_get_float_element", TensorGetElement<float>)
      .def("_set_double_element", TensorSetElement<double>)
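Note: both bindings follow the usual DLPack Python capsule convention, the same one NumPy and CuPy use: the producer names the capsule "dltensor", and a consumer renames it to "used_dltensor" so the tensor is handed over at most once and the capsule destructor knows whether it still owns the deleter call. A condensed sketch of the consuming side (TakeDLPackCapsule is a hypothetical helper; PyCapsule_GetPointer returns NULL and sets a Python error when the name no longer matches):

    #include <Python.h>
    #include <dlpack/dlpack.h>

    // Take ownership of a DLPack capsule; nullptr means it was already
    // consumed (its name is no longer "dltensor").
    DLManagedTensor *TakeDLPackCapsule(PyObject *capsule) {
      void *p = PyCapsule_GetPointer(capsule, "dltensor");
      if (p == nullptr) return nullptr;
      PyCapsule_SetName(capsule, "used_dltensor");  // mark as consumed
      return static_cast<DLManagedTensor *>(p);
    }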
diff --git a/paddle/fluid/train/demo/CMakeLists.txt b/paddle/fluid/train/demo/CMakeLists.txt
index 5a370b813f..a15ddc9273 100644
--- a/paddle/fluid/train/demo/CMakeLists.txt
+++ b/paddle/fluid/train/demo/CMakeLists.txt
@@ -20,6 +20,7 @@ include_directories("${PADDLE_LIB}/third_party/install/zlib/include")
 
 include_directories("${PADDLE_LIB}/third_party/boost")
 include_directories("${PADDLE_LIB}/third_party/eigen3")
+include_directories("${PADDLE_LIB}/third_party/dlpack")
 
 link_directories("${PADDLE_LIB}/third_party/install/protobuf/lib")
 link_directories("${PADDLE_LIB}/third_party/install/glog/lib")
diff --git a/python/paddle/fluid/tests/test_lod_tensor.py b/python/paddle/fluid/tests/test_lod_tensor.py
index a3eae5a3c8..00bfb84602 100644
--- a/python/paddle/fluid/tests/test_lod_tensor.py
+++ b/python/paddle/fluid/tests/test_lod_tensor.py
@@ -125,6 +125,30 @@ class TestLoDTensor(unittest.TestCase):
         print(gtensor)
         self.assertTrue(isinstance(str(gtensor), str))
 
+    def test_dlpack_support(self):
+        tensor = fluid.create_lod_tensor(
+            np.array([[1], [2], [3], [4]]).astype('int'), [[1, 3]],
+            fluid.CPUPlace())
+        dltensor = tensor._to_dlpack()
+        tensor_from_dlpack = fluid.core.from_dlpack(dltensor)
+        self.assertTrue(isinstance(tensor_from_dlpack, fluid.core.Tensor))
+        self.assertTrue(
+            np.array_equal(
+                np.array(tensor_from_dlpack),
+                np.array([[1], [2], [3], [4]]).astype('int')))
+        # when build with cuda
+        if core.is_compiled_with_cuda():
+            gtensor = fluid.create_lod_tensor(
+                np.array([[1], [2], [3], [4]]).astype('int'), [[1, 3]],
+                fluid.CUDAPlace(0))
+            gdltensor = gtensor._to_dlpack()
+            gtensor_from_dlpack = fluid.core.from_dlpack(gdltensor)
+            self.assertTrue(isinstance(gtensor_from_dlpack, fluid.core.Tensor))
+            self.assertTrue(
+                np.array_equal(
+                    np.array(gtensor_from_dlpack),
+                    np.array([[1], [2], [3], [4]]).astype('int')))
+
 
 if __name__ == '__main__':
     unittest.main()
-- 
GitLab