You need to sign in or sign up before continuing.
提交 12e4be03 编写于 作者: 6 633WHU 提交者: Zeng Jinle

Dlpack support (#20039)

* support dlpack to tensor and implement python interface test=develop

* add unittest for _to_dlpack and from_dlpack test=develop
上级 5e65c753
...@@ -51,7 +51,7 @@ function(copy TARGET) ...@@ -51,7 +51,7 @@ function(copy TARGET)
endfunction() endfunction()
# third party # third party
set(third_party_deps eigen3 gflags glog boost xxhash zlib) set(third_party_deps eigen3 gflags glog boost xxhash zlib dlpack)
if(NOT PROTOBUF_FOUND OR WIN32) if(NOT PROTOBUF_FOUND OR WIN32)
list(APPEND third_party_deps extern_protobuf) list(APPEND third_party_deps extern_protobuf)
endif () endif ()
...@@ -86,6 +86,11 @@ copy(inference_lib_dist ...@@ -86,6 +86,11 @@ copy(inference_lib_dist
SRCS ${BOOST_INCLUDE_DIR}/boost SRCS ${BOOST_INCLUDE_DIR}/boost
DSTS ${dst_dir}) DSTS ${dst_dir})
set(dst_dir "${FLUID_INFERENCE_INSTALL_DIR}/third_party/dlpack")
copy(inference_lib_dist
SRCS ${DLPACK_INCLUDE_DIR}/dlpack
DSTS ${dst_dir})
if(WITH_MKLML) if(WITH_MKLML)
set(dst_dir "${FLUID_INFERENCE_INSTALL_DIR}/third_party/install/mklml") set(dst_dir "${FLUID_INFERENCE_INSTALL_DIR}/third_party/install/mklml")
if(WIN32) if(WIN32)
...@@ -258,4 +263,4 @@ function(version version_file) ...@@ -258,4 +263,4 @@ function(version version_file)
endif () endif ()
endfunction() endfunction()
version(${FLUID_INSTALL_DIR}/version.txt) version(${FLUID_INSTALL_DIR}/version.txt)
version(${FLUID_INFERENCE_INSTALL_DIR}/version.txt) version(${FLUID_INFERENCE_INSTALL_DIR}/version.txt)
\ No newline at end of file
...@@ -51,9 +51,9 @@ endif() ...@@ -51,9 +51,9 @@ endif()
cc_test(tensor_test SRCS tensor_test.cc DEPS tensor) cc_test(tensor_test SRCS tensor_test.cc DEPS tensor)
if(WITH_GPU) if(WITH_GPU)
nv_test(tensor_util_test SRCS tensor_util_test.cc tensor_util_test.cu DEPS tensor) nv_test(tensor_util_test SRCS tensor_util_test.cc tensor_util_test.cu DEPS tensor dlpack_tensor)
else() else()
cc_test(tensor_util_test SRCS tensor_util_test.cc DEPS tensor) cc_test(tensor_util_test SRCS tensor_util_test.cc DEPS tensor dlpack_tensor)
endif() endif()
cc_test(eigen_test SRCS eigen_test.cc DEPS tensor) cc_test(eigen_test SRCS eigen_test.cc DEPS tensor)
......
...@@ -11,9 +11,10 @@ ...@@ -11,9 +11,10 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include <unordered_map>
#include "paddle/fluid/framework/dlpack_tensor.h"
#include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/dlpack_tensor.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -120,5 +121,41 @@ DLPackTensor::DLPackTensor(const Tensor &tensor, LaneType lanes) { ...@@ -120,5 +121,41 @@ DLPackTensor::DLPackTensor(const Tensor &tensor, LaneType lanes) {
t_.byte_offset = 0; t_.byte_offset = 0;
} }
::DLManagedTensor *DLPackTensor::ToCudfCompatibleDLManagedTensor() {
// init shape, tensor dims
// for DLManagedTensor shape need to be compatible with ndim
// refer to cupy and cudf, we new int64[ndim]
auto shape = new int64_t[t_.ndim];
using DimType = decltype(t_.ndim); // int
for (DimType i = 0; i < t_.ndim; ++i) {
shape[i] = t_.shape[i];
}
t_.shape = shape;
// init strides, nullptr means the tensor is compact
// refer to cupy and cudf, the compact tensor first dim's strides need to be 1
// and second dim's strides need to be length of rows of cudf
// cudf now only support dim=2
PADDLE_ENFORCE_LE(t_.ndim, 2, "cudf now only support dim=2.");
if (t_.ndim > 1)
t_.strides = new int64_t[2]{1, t_.shape[1]};
else
t_.strides = new int64_t[1]{1};
auto tensor = new DLManagedTensor;
tensor->dl_tensor = t_;
tensor->deleter = [](DLManagedTensor *arg) {
delete[] arg->dl_tensor.shape;
delete[] arg->dl_tensor.strides;
delete arg;
};
tensor->manager_ctx = nullptr;
return tensor;
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -33,6 +33,8 @@ class DLPackTensor { ...@@ -33,6 +33,8 @@ class DLPackTensor {
inline operator ::DLTensor&() { return t_; } inline operator ::DLTensor&() { return t_; }
::DLManagedTensor* ToCudfCompatibleDLManagedTensor();
private: private:
::DLTensor t_; ::DLTensor t_;
......
...@@ -72,6 +72,30 @@ void TestMain(const platform::Place &place, uint16_t lanes) { ...@@ -72,6 +72,30 @@ void TestMain(const platform::Place &place, uint16_t lanes) {
CHECK_EQ(GetDLDataTypeCode<T>(), dl_tensor.dtype.code); CHECK_EQ(GetDLDataTypeCode<T>(), dl_tensor.dtype.code);
} }
template <typename T>
void TestToCudfCompatibleDLManagedTensor(const platform::Place &place,
uint16_t lanes) {
DDim dims{6, 7};
Tensor tensor;
tensor.Resize(dims);
tensor.mutable_data<T>(place);
DLPackTensor dlpack_tensor(tensor, lanes);
::DLManagedTensor *dl_managed_tensor =
dlpack_tensor.ToCudfCompatibleDLManagedTensor();
CHECK_EQ(dl_managed_tensor->manager_ctx == nullptr, true);
for (auto i = 0; i < dims.size(); ++i) {
CHECK_EQ(dims[i], dl_managed_tensor->dl_tensor.shape[i]);
}
CHECK_EQ(dl_managed_tensor->dl_tensor.strides[0] == 1, true);
dl_managed_tensor->deleter(dl_managed_tensor);
}
template <typename T> template <typename T>
void TestMainLoop() { void TestMainLoop() {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
...@@ -88,6 +112,7 @@ void TestMainLoop() { ...@@ -88,6 +112,7 @@ void TestMainLoop() {
for (auto &p : places) { for (auto &p : places) {
for (auto &l : lanes) { for (auto &l : lanes) {
TestMain<T>(p, l); TestMain<T>(p, l);
TestToCudfCompatibleDLManagedTensor<T>(p, l);
} }
} }
} }
......
...@@ -495,6 +495,82 @@ void TensorFromStream(std::istream& is, Tensor* tensor, ...@@ -495,6 +495,82 @@ void TensorFromStream(std::istream& is, Tensor* tensor,
} }
} }
// get tensor data point by DLDataType
void* GetDstPtrByDLDataType(DLDataType type, framework::Tensor* dst,
const platform::Place& dst_place) {
// vector types not currently supported
PADDLE_ENFORCE_LE(type.lanes, 1, "vector types not currently supported");
switch (type.bits) {
case 8:
if (type.code == kDLInt)
return static_cast<void*>(dst->mutable_data<int8_t>(dst_place));
if (type.code == kDLUInt)
return static_cast<void*>(dst->mutable_data<uint8_t>(dst_place));
PADDLE_THROW("There is no this type.code <%d> when type.bits is <%d>.",
type.code, type.bits);
case 16:
if (type.code == kDLInt)
return static_cast<void*>(dst->mutable_data<int16_t>(dst_place));
if (type.code == kDLFloat)
return static_cast<void*>(
dst->mutable_data<paddle::platform::float16>(dst_place));
PADDLE_THROW("There is no this type.code <%d> when type.bits is <%d>.",
type.code, type.bits);
case 32:
if (type.code == kDLInt)
return static_cast<void*>(dst->mutable_data<int32_t>(dst_place));
if (type.code == kDLFloat)
return static_cast<void*>(dst->mutable_data<float>(dst_place));
PADDLE_THROW("There is no this type.code <%d> when type.bits is <%d>.",
type.code, type.bits);
case 64:
if (type.code == kDLInt)
return static_cast<void*>(dst->mutable_data<int64_t>(dst_place));
if (type.code == kDLFloat)
return static_cast<void*>(dst->mutable_data<double>(dst_place));
PADDLE_THROW("There is no this type.code <%d> when type.bits is <%d>.",
type.code, type.bits);
default:
PADDLE_THROW("Unsupport type.bits %d", type.bits);
}
}
void TensorFromDLPack(const ::DLTensor& dl_tensor, framework::Tensor* dst) {
platform::CPUPlace dst_place = platform::CPUPlace();
platform::CPUPlace src_place = platform::CPUPlace();
std::vector<int64_t> vec;
std::copy(dl_tensor.shape, dl_tensor.shape + dl_tensor.ndim,
std::back_inserter(vec));
framework::DDim vddim = framework::make_ddim(vec);
dst->Resize(vddim);
::DLDataType type = dl_tensor.dtype;
void* dst_ptr = GetDstPtrByDLDataType(type, dst, dst_place);
auto src_ptr = static_cast<const void*>(dl_tensor.data);
auto size = paddle::framework::product(vddim) * type.bits / 8;
if (dl_tensor.ctx.device_type == kDLCPU) {
memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size);
}
#ifdef PADDLE_WITH_CUDA
if (dl_tensor.ctx.device_type == kDLGPU) {
platform::CUDAPlace dst_place =
platform::CUDAPlace(dl_tensor.ctx.device_id);
platform::CUDAPlace src_place =
platform::CUDAPlace(dl_tensor.ctx.device_id);
dst_ptr = GetDstPtrByDLDataType(type, dst, dst_place);
auto* ctx = platform::DeviceContextPool::Instance().GetByPlace(dst_place);
memory::Copy(
dst_place, dst_ptr, src_place, src_ptr, size,
reinterpret_cast<const platform::CUDADeviceContext&>(*ctx).stream());
}
#endif
}
template <typename T> template <typename T>
std::ostream& print_tensor(std::ostream& os, const framework::Tensor& tensor) { std::ostream& print_tensor(std::ostream& os, const framework::Tensor& tensor) {
auto inspect = tensor.data<T>(); auto inspect = tensor.data<T>();
......
...@@ -15,6 +15,7 @@ limitations under the License. */ ...@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once #pragma once
#include <vector> #include <vector>
#include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/dlpack_tensor.h"
#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/tensor.h"
...@@ -72,6 +73,9 @@ void TensorToStream(std::ostream& os, const Tensor& tensor, ...@@ -72,6 +73,9 @@ void TensorToStream(std::ostream& os, const Tensor& tensor,
void TensorFromStream(std::istream& is, Tensor* tensor, void TensorFromStream(std::istream& is, Tensor* tensor,
const platform::DeviceContext& dev_ctx); const platform::DeviceContext& dev_ctx);
// convert dlpack's DLTensor to tensor
void TensorFromDLPack(const ::DLTensor& dl_tensor, framework::Tensor* dst);
// //
// The implementation of template functions. // The implementation of template functions.
// //
......
...@@ -242,6 +242,69 @@ TEST(TensorToVector, Tensor) { ...@@ -242,6 +242,69 @@ TEST(TensorToVector, Tensor) {
#endif #endif
} }
TEST(TensorFromDLPack, Tensor) {
{
std::vector<int> src_vec = {1, 2, 3, 4, 5, 6, 7, 8, 9};
paddle::framework::Tensor cpu_tensor;
cpu_tensor.Resize(paddle::framework::make_ddim({3, 3}));
paddle::platform::CPUPlace cpu_place;
paddle::platform::CPUDeviceContext cpu_ctx(cpu_place);
paddle::framework::TensorFromVector<int>(src_vec, cpu_ctx, &cpu_tensor);
paddle::framework::DLPackTensor dlpack_tensor(cpu_tensor, 1);
paddle::framework::Tensor dst_tensor;
paddle::framework::TensorFromDLPack(dlpack_tensor, &dst_tensor);
auto cpu_ptr = cpu_tensor.data<int>();
auto src_ptr = dst_tensor.data<int>();
EXPECT_NE(src_ptr, cpu_ptr);
for (size_t i = 0; i < 9; ++i) {
EXPECT_EQ(src_ptr[i], cpu_ptr[i]);
}
}
#ifdef PADDLE_WITH_CUDA
{
std::vector<int> src_vec = {1, 2, 3, 4, 5, 6, 7, 8, 9};
paddle::framework::Tensor cpu_tensor;
paddle::framework::Tensor gpu_tensor;
paddle::framework::Tensor dst_tensor;
paddle::framework::Tensor gpu_tensor_from_dlpack;
// Copy to CPU Tensor
cpu_tensor.Resize(make_ddim({3, 3}));
paddle::platform::CPUPlace cpu_place;
paddle::platform::CPUDeviceContext cpu_ctx(cpu_place);
paddle::framework::TensorFromVector<int>(src_vec, cpu_ctx, &cpu_tensor);
// Copy to GPUTensor
gpu_tensor.Resize(paddle::framework::make_ddim({3, 3}));
paddle::platform::CUDAPlace gpu_place;
paddle::platform::CUDADeviceContext gpu_ctx(gpu_place);
paddle::framework::TensorFromVector<int>(src_vec, gpu_ctx, &gpu_tensor);
paddle::framework::DLPackTensor dlpack_tensor(gpu_tensor, 1);
paddle::framework::TensorFromDLPack(dlpack_tensor, &gpu_tensor_from_dlpack);
// Copy from GPU to CPU tensor for comparison
paddle::framework::TensorCopy(gpu_tensor_from_dlpack, cpu_place, gpu_ctx,
&dst_tensor);
// Sync before Compare Tensors
gpu_ctx.Wait();
const int* src_ptr = src_vec.data();
const int* cpu_ptr = cpu_tensor.data<int>();
const int* dst_ptr = dst_tensor.data<int>();
EXPECT_NE(src_ptr, cpu_ptr);
EXPECT_NE(src_ptr, dst_ptr);
for (size_t i = 0; i < 9; ++i) {
EXPECT_EQ(src_ptr[i], cpu_ptr[i]);
EXPECT_EQ(src_ptr[i], dst_ptr[i]);
}
}
#endif
}
TEST(TensorContainsNAN, CPU) { TEST(TensorContainsNAN, CPU) {
{ {
paddle::framework::Tensor src; paddle::framework::Tensor src;
......
...@@ -72,7 +72,7 @@ ENDIF() ...@@ -72,7 +72,7 @@ ENDIF()
# avoiding cycle dependencies # avoiding cycle dependencies
cc_library(device_context SRCS device_context.cc init.cc DEPS simple_threadpool malloc xxhash ${STREAM_CALLBACK_DEPS} cc_library(device_context SRCS device_context.cc init.cc DEPS simple_threadpool malloc xxhash ${STREAM_CALLBACK_DEPS}
place eigen3 stringpiece cpu_helper cpu_info framework_proto ${GPU_CTX_DEPS} ${MKLDNN_CTX_DEPS} place eigen3 stringpiece cpu_helper cpu_info framework_proto ${GPU_CTX_DEPS} ${MKLDNN_CTX_DEPS}
${dgc_deps}) ${dgc_deps} dlpack)
if (WITH_DISTRIBUTE) if (WITH_DISTRIBUTE)
cc_library(collective_helper SRCS collective_helper.cc DEPS framework_proto device_context enforce) cc_library(collective_helper SRCS collective_helper.cc DEPS framework_proto device_context enforce)
......
set(PYBIND_DEPS pybind python proto_desc memory executor fleet_wrapper box_wrapper nccl_wrapper prune set(PYBIND_DEPS pybind python proto_desc memory executor fleet_wrapper box_wrapper nccl_wrapper prune
feed_fetch_method pass_builder parallel_executor profiler layer tracer engine scope_pool feed_fetch_method pass_builder parallel_executor profiler layer tracer engine scope_pool
analysis_predictor imperative_profiler nccl_context imperative_flag save_load_util) analysis_predictor imperative_profiler nccl_context imperative_flag save_load_util dlpack_tensor)
if(WITH_PYTHON) if(WITH_PYTHON)
list(APPEND PYBIND_DEPS py_func_op) list(APPEND PYBIND_DEPS py_func_op)
......
...@@ -258,6 +258,24 @@ PYBIND11_MODULE(core_noavx, m) { ...@@ -258,6 +258,24 @@ PYBIND11_MODULE(core_noavx, m) {
m.def("set_num_threads", &platform::SetNumThreads); m.def("set_num_threads", &platform::SetNumThreads);
m.def("from_dlpack", [](py::capsule *dltensor) {
DLManagedTensor *dmt = reinterpret_cast<DLManagedTensor *>(
PyCapsule_GetPointer(dltensor->ptr(), "dltensor"));
PyCapsule_SetName(dltensor->ptr(), "used_dltensor");
DLTensor dl = dmt->dl_tensor;
Tensor tensor;
if (dl.ctx.device_type == kDLCPU) {
paddle::framework::TensorFromDLPack(dl, &tensor);
}
#ifdef PADDLE_WITH_CUDA
if (dl.ctx.device_type == kDLGPU) {
paddle::framework::TensorFromDLPack(dl, &tensor);
}
#endif
return tensor;
});
m.def("_save_static_dict", m.def("_save_static_dict",
[](const std::string &str_file_name, const py::handle &vec_var_list, [](const std::string &str_file_name, const py::handle &vec_var_list,
const Scope &scope) { const Scope &scope) {
...@@ -291,6 +309,7 @@ PYBIND11_MODULE(core_noavx, m) { ...@@ -291,6 +309,7 @@ PYBIND11_MODULE(core_noavx, m) {
return map_output; return map_output;
}); });
m.def("save_op_compatible_info", [](framework::ProgramDesc &desc) { m.def("save_op_compatible_info", [](framework::ProgramDesc &desc) {
framework::OpCompatibleMap op_compatible_map; framework::OpCompatibleMap op_compatible_map;
op_compatible_map.InitOpCompatibleMap(); op_compatible_map.InitOpCompatibleMap();
...@@ -467,6 +486,28 @@ PYBIND11_MODULE(core_noavx, m) { ...@@ -467,6 +486,28 @@ PYBIND11_MODULE(core_noavx, m) {
t.set(np.ndarray([5, 30]), fluid.CPUPlace()) t.set(np.ndarray([5, 30]), fluid.CPUPlace())
print(t.shape()) # [5, 30] print(t.shape()) # [5, 30]
)DOC") )DOC")
.def("_to_dlpack",
[](Tensor &self) {
DLPackTensor dlpack_tensor(self, 1);
DLManagedTensor *dmt =
dlpack_tensor.ToCudfCompatibleDLManagedTensor();
auto capsule = py::capsule(
static_cast<void *>(dmt), "dltensor", [](PyObject *ptr) {
if (ptr) {
auto dltensor = new DLManagedTensor;
try {
dltensor = reinterpret_cast<DLManagedTensor *>(
PyCapsule_GetPointer(ptr, "used_dltensor"));
return;
} catch (...) {
dltensor = reinterpret_cast<DLManagedTensor *>(
PyCapsule_GetPointer(ptr, "dltensor"));
}
dltensor->deleter(dltensor);
}
});
return capsule;
})
.def("_set_float_element", TensorSetElement<float>) .def("_set_float_element", TensorSetElement<float>)
.def("_get_float_element", TensorGetElement<float>) .def("_get_float_element", TensorGetElement<float>)
.def("_set_double_element", TensorSetElement<double>) .def("_set_double_element", TensorSetElement<double>)
......
...@@ -20,6 +20,7 @@ include_directories("${PADDLE_LIB}/third_party/install/zlib/include") ...@@ -20,6 +20,7 @@ include_directories("${PADDLE_LIB}/third_party/install/zlib/include")
include_directories("${PADDLE_LIB}/third_party/boost") include_directories("${PADDLE_LIB}/third_party/boost")
include_directories("${PADDLE_LIB}/third_party/eigen3") include_directories("${PADDLE_LIB}/third_party/eigen3")
include_directories("${PADDLE_LIB}/third_party/dlpack")
link_directories("${PADDLE_LIB}/third_party/install/protobuf/lib") link_directories("${PADDLE_LIB}/third_party/install/protobuf/lib")
link_directories("${PADDLE_LIB}/third_party/install/glog/lib") link_directories("${PADDLE_LIB}/third_party/install/glog/lib")
......
...@@ -125,6 +125,30 @@ class TestLoDTensor(unittest.TestCase): ...@@ -125,6 +125,30 @@ class TestLoDTensor(unittest.TestCase):
print(gtensor) print(gtensor)
self.assertTrue(isinstance(str(gtensor), str)) self.assertTrue(isinstance(str(gtensor), str))
def test_dlpack_support(self):
tensor = fluid.create_lod_tensor(
np.array([[1], [2], [3], [4]]).astype('int'), [[1, 3]],
fluid.CPUPlace())
dltensor = tensor._to_dlpack()
tensor_from_dlpack = fluid.core.from_dlpack(dltensor)
self.assertTrue(isinstance(tensor_from_dlpack, fluid.core.Tensor))
self.assertTrue(
np.array_equal(
np.array(tensor_from_dlpack),
np.array([[1], [2], [3], [4]]).astype('int')))
# when build with cuda
if core.is_compiled_with_cuda():
gtensor = fluid.create_lod_tensor(
np.array([[1], [2], [3], [4]]).astype('int'), [[1, 3]],
fluid.CUDAPlace(0))
gdltensor = gtensor._to_dlpack()
gtensor_from_dlpack = fluid.core.from_dlpack(gdltensor)
self.assertTrue(isinstance(gtensor_from_dlpack, fluid.core.Tensor))
self.assertTrue(
np.array_equal(
np.array(gtensor_from_dlpack),
np.array([[1], [2], [3], [4]]).astype('int')))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册