From 8907cdca933f4ff735d715fc56ba51e4c220b728 Mon Sep 17 00:00:00 2001 From: HongyuJia Date: Mon, 13 Feb 2023 15:30:52 +0800 Subject: [PATCH] [Tensor data()] Tensor support `void* data()` function (#50262) * Tensor support void* data() function * add unittest * add selectedRows unittest * polish unittest * polish unittest * polish unittest * polish unittest --- paddle/phi/api/include/tensor.h | 18 +++++++++++++++++ paddle/phi/api/lib/tensor.cc | 20 ++++++++++++++++++ paddle/phi/tests/api/CMakeLists.txt | 6 +++--- paddle/phi/tests/api/test_phi_tensor.cc | 27 +++++++++++++++++++++++++ 4 files changed, 68 insertions(+), 3 deletions(-) diff --git a/paddle/phi/api/include/tensor.h b/paddle/phi/api/include/tensor.h index 35f7e889432..e412c1a7f05 100644 --- a/paddle/phi/api/include/tensor.h +++ b/paddle/phi/api/include/tensor.h @@ -347,6 +347,24 @@ class PADDLE_API Tensor final { template T* data(); + /** + * @brief Get the const memory pointer directly. + * It's usually used to get the output data pointer. + * + * @tparam T + * @return T* + */ + const void* data() const; + + /** + * @brief Get the memory pointer directly. + * It's usually used to get the mutable output data pointer. + * + * @tparam T + * @return T* + */ + void* data(); + /** * @brief Return a sub-tensor of the given tensor. * It is usually used to extract a sub-tensor (which supports diff --git a/paddle/phi/api/lib/tensor.cc b/paddle/phi/api/lib/tensor.cc index 118be826549..d1883b5b87a 100644 --- a/paddle/phi/api/lib/tensor.cc +++ b/paddle/phi/api/lib/tensor.cc @@ -298,6 +298,26 @@ template PADDLE_API phi::dtype::complex template PADDLE_API phi::dtype::complex *Tensor::data>(); +const void *Tensor::data() const { + if (is_dense_tensor()) { + return static_cast(impl_.get())->data(); + } else if (is_selected_rows()) { + return static_cast(impl_.get())->value().data(); + } + return nullptr; +} + +void *Tensor::data() { + if (is_dense_tensor()) { + return static_cast(impl_.get())->data(); + } else if (is_selected_rows()) { + return static_cast(impl_.get()) + ->mutable_value() + ->data(); + } + return nullptr; +} + // TODO(chenweihang): replace slice impl by API Tensor Tensor::slice(int64_t begin_idx, int64_t end_idx) const { if (is_dense_tensor()) { diff --git a/paddle/phi/tests/api/CMakeLists.txt b/paddle/phi/tests/api/CMakeLists.txt index 7e258379754..431e8a33eb7 100644 --- a/paddle/phi/tests/api/CMakeLists.txt +++ b/paddle/phi/tests/api/CMakeLists.txt @@ -2,17 +2,17 @@ if(WITH_GPU) nv_test( test_phi_tensor SRCS test_phi_tensor.cc - DEPS phi_tensor glog) + DEPS phi_tensor glog selected_rows) elseif(WITH_ROCM) hip_test( test_phi_tensor SRCS test_phi_tensor.cc - DEPS phi_tensor glog) + DEPS phi_tensor glog selected_rows) else() cc_test( test_phi_tensor SRCS test_phi_tensor.cc - DEPS phi_tensor glog) + DEPS phi_tensor glog selected_rows) endif() cc_test( diff --git a/paddle/phi/tests/api/test_phi_tensor.cc b/paddle/phi/tests/api/test_phi_tensor.cc index bea36bb0eaf..e21980294cc 100644 --- a/paddle/phi/tests/api/test_phi_tensor.cc +++ b/paddle/phi/tests/api/test_phi_tensor.cc @@ -16,6 +16,7 @@ #include "gtest/gtest.h" #include "paddle/phi/api/include/tensor.h" #include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/selected_rows.h" namespace paddle { namespace tests { @@ -208,6 +209,30 @@ void TestInitilized() { } } +void TestDataInterface() { + // Test DenseTensor + experimental::Tensor test_tensor(paddle::PlaceType::kCPU, {1, 1}); + CHECK(test_tensor.is_initialized() == true); + void* tensor_ptr = test_tensor.data(); + CHECK(tensor_ptr != nullptr); + const void* const_tensor_ptr = test_tensor.data(); + CHECK(const_tensor_ptr != nullptr); + // Test SelectedRows + std::vector rows = {0}; + std::shared_ptr selected_rows = + std::make_shared(rows, 1); + selected_rows->mutable_value()->Resize(phi::make_ddim({1, 1})); + selected_rows->mutable_value()->mutable_data(phi::CPUPlace())[0] = + static_cast(10.0f); + paddle::experimental::Tensor sr_tensor = + paddle::experimental::Tensor(selected_rows); + CHECK(sr_tensor.is_initialized() == true); + tensor_ptr = sr_tensor.data(); + CHECK(tensor_ptr != nullptr); + const_tensor_ptr = sr_tensor.data(); + CHECK(const_tensor_ptr != nullptr); +} + void TestJudgeTensorType() { experimental::Tensor test_tensor(paddle::PlaceType::kCPU, {1, 1}); CHECK(test_tensor.is_dense_tensor() == true); @@ -228,6 +253,8 @@ TEST(PhiTensor, All) { GroupTestCast(); VLOG(2) << "TestInitilized"; TestInitilized(); + VLOG(2) << "TestDataInterface"; + TestDataInterface(); VLOG(2) << "TestJudgeTensorType"; TestJudgeTensorType(); } -- GitLab