未验证 提交 8907cdca 编写于 作者: H HongyuJia 提交者: GitHub

[Tensor data()] Tensor support `void* data()` function (#50262)

* Tensor support void* data() function

* add unittest

* add selectedRows unittest

* polish unittest

* polish unittest

* polish unittest

* polish unittest
上级 1281b612
......@@ -347,6 +347,24 @@ class PADDLE_API Tensor final {
template <typename T>
T* data();
/**
* @brief Get the const memory pointer directly.
* It's usually used to get the output data pointer.
*
* @tparam T
* @return T*
*/
const void* data() const;
/**
* @brief Get the memory pointer directly.
* It's usually used to get the mutable output data pointer.
*
* @tparam T
* @return T*
*/
void* data();
/**
* @brief Return a sub-tensor of the given tensor.
* It is usually used to extract a sub-tensor (which supports
......
......@@ -298,6 +298,26 @@ template PADDLE_API phi::dtype::complex<float>
template PADDLE_API phi::dtype::complex<double>
*Tensor::data<phi::dtype::complex<double>>();
const void *Tensor::data() const {
if (is_dense_tensor()) {
return static_cast<phi::DenseTensor *>(impl_.get())->data();
} else if (is_selected_rows()) {
return static_cast<phi::SelectedRows *>(impl_.get())->value().data();
}
return nullptr;
}
void *Tensor::data() {
if (is_dense_tensor()) {
return static_cast<phi::DenseTensor *>(impl_.get())->data();
} else if (is_selected_rows()) {
return static_cast<phi::SelectedRows *>(impl_.get())
->mutable_value()
->data();
}
return nullptr;
}
// TODO(chenweihang): replace slice impl by API
Tensor Tensor::slice(int64_t begin_idx, int64_t end_idx) const {
if (is_dense_tensor()) {
......
......@@ -2,17 +2,17 @@ if(WITH_GPU)
nv_test(
test_phi_tensor
SRCS test_phi_tensor.cc
DEPS phi_tensor glog)
DEPS phi_tensor glog selected_rows)
elseif(WITH_ROCM)
hip_test(
test_phi_tensor
SRCS test_phi_tensor.cc
DEPS phi_tensor glog)
DEPS phi_tensor glog selected_rows)
else()
cc_test(
test_phi_tensor
SRCS test_phi_tensor.cc
DEPS phi_tensor glog)
DEPS phi_tensor glog selected_rows)
endif()
cc_test(
......
......@@ -16,6 +16,7 @@
#include "gtest/gtest.h"
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/selected_rows.h"
namespace paddle {
namespace tests {
......@@ -208,6 +209,30 @@ void TestInitilized() {
}
}
void TestDataInterface() {
// Test DenseTensor
experimental::Tensor test_tensor(paddle::PlaceType::kCPU, {1, 1});
CHECK(test_tensor.is_initialized() == true);
void* tensor_ptr = test_tensor.data();
CHECK(tensor_ptr != nullptr);
const void* const_tensor_ptr = test_tensor.data();
CHECK(const_tensor_ptr != nullptr);
// Test SelectedRows
std::vector<int64_t> rows = {0};
std::shared_ptr<phi::SelectedRows> selected_rows =
std::make_shared<phi::SelectedRows>(rows, 1);
selected_rows->mutable_value()->Resize(phi::make_ddim({1, 1}));
selected_rows->mutable_value()->mutable_data<float>(phi::CPUPlace())[0] =
static_cast<float>(10.0f);
paddle::experimental::Tensor sr_tensor =
paddle::experimental::Tensor(selected_rows);
CHECK(sr_tensor.is_initialized() == true);
tensor_ptr = sr_tensor.data();
CHECK(tensor_ptr != nullptr);
const_tensor_ptr = sr_tensor.data();
CHECK(const_tensor_ptr != nullptr);
}
void TestJudgeTensorType() {
experimental::Tensor test_tensor(paddle::PlaceType::kCPU, {1, 1});
CHECK(test_tensor.is_dense_tensor() == true);
......@@ -228,6 +253,8 @@ TEST(PhiTensor, All) {
GroupTestCast();
VLOG(2) << "TestInitilized";
TestInitilized();
VLOG(2) << "TestDataInterface";
TestDataInterface();
VLOG(2) << "TestJudgeTensorType";
TestJudgeTensorType();
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册