未验证 提交 6fa29c55 编写于 作者: H HongyuJia 提交者: GitHub

[Polist Unittest] Polish test_phi_tensor (#50440)

* fix py::array_t calling bug

* polish test_phi_tensor

* stop fix inference bug in this PR

* polish unittest

* change int->int32_t

* fix unittest

* fix compile error

* modify cmake

* remove redundancy codes

* fix selectedRow unittest

* fix cmake relay

* declare kernel
上级 b5809912
set(COMMON_API_TEST_DEPS phi_tensor phi_api phi_api_utils)
if(WITH_GPU) if(WITH_GPU)
nv_test( nv_test(
test_phi_tensor test_phi_tensor
SRCS test_phi_tensor.cc SRCS test_phi_tensor.cc
DEPS phi_tensor glog selected_rows) DEPS glog selected_rows ${COMMON_API_TEST_DEPS})
elseif(WITH_ROCM) elseif(WITH_ROCM)
hip_test( hip_test(
test_phi_tensor test_phi_tensor
SRCS test_phi_tensor.cc SRCS test_phi_tensor.cc
DEPS phi_tensor glog selected_rows) DEPS glog selected_rows ${COMMON_API_TEST_DEPS})
else() else()
cc_test( cc_test(
test_phi_tensor test_phi_tensor
SRCS test_phi_tensor.cc SRCS test_phi_tensor.cc
DEPS phi_tensor glog selected_rows) DEPS glog selected_rows ${COMMON_API_TEST_DEPS})
endif() endif()
cc_test( cc_test(
...@@ -20,7 +22,6 @@ cc_test( ...@@ -20,7 +22,6 @@ cc_test(
SRCS test_phi_exception.cc SRCS test_phi_exception.cc
DEPS gtest) DEPS gtest)
set(COMMON_API_TEST_DEPS phi_tensor phi_api phi_api_utils)
cc_test( cc_test(
test_to_api test_to_api
SRCS test_to_api.cc SRCS test_to_api.cc
......
...@@ -14,18 +14,30 @@ ...@@ -14,18 +14,30 @@
#include "glog/logging.h" #include "glog/logging.h"
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/phi/api/include/api.h"
#include "paddle/phi/api/include/tensor.h" #include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/selected_rows.h" #include "paddle/phi/core/selected_rows.h"
PD_DECLARE_KERNEL(empty, CPU, ALL_LAYOUT);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_DECLARE_KERNEL(empty, GPU, ALL_LAYOUT);
#endif
namespace paddle { namespace paddle {
namespace tests { namespace tests {
using Tensor = paddle::experimental::Tensor;
using DataType = paddle::experimental::DataType;
template <typename T> template <typename T>
experimental::Tensor InitCPUTensorForTest() { Tensor InitCPUTensorForTest() {
std::vector<int64_t> tensor_shape{5, 5}; std::vector<int64_t> tensor_shape{5, 5};
auto t1 = experimental::Tensor(paddle::PlaceType::kCPU, tensor_shape); DataType dtype = paddle::experimental::CppTypeToDataType<T>::Type();
auto* p_data_ptr = t1.mutable_data<T>(paddle::PlaceType::kCPU); Tensor t1 = paddle::experimental::empty(tensor_shape, dtype, phi::CPUPlace());
auto* p_data_ptr = t1.data<T>();
for (int64_t i = 0; i < t1.size(); i++) { for (int64_t i = 0; i < t1.size(); i++) {
p_data_ptr[i] = T(5); p_data_ptr[i] = T(5);
} }
...@@ -35,22 +47,22 @@ experimental::Tensor InitCPUTensorForTest() { ...@@ -35,22 +47,22 @@ experimental::Tensor InitCPUTensorForTest() {
template <typename T> template <typename T>
void TestCopyTensor() { void TestCopyTensor() {
auto t1 = InitCPUTensorForTest<T>(); auto t1 = InitCPUTensorForTest<T>();
auto t1_cpu_cp = t1.template copy_to<T>(paddle::PlaceType::kCPU); auto t1_cpu_cp = t1.copy_to(phi::CPUPlace(), /*blocking=*/false);
CHECK((paddle::PlaceType::kCPU == t1_cpu_cp.place())); CHECK((phi::CPUPlace() == t1_cpu_cp.place()));
for (int64_t i = 0; i < t1.size(); i++) { for (int64_t i = 0; i < t1.size(); i++) {
CHECK_EQ(t1_cpu_cp.template mutable_data<T>()[i], T(5)); CHECK_EQ(t1_cpu_cp.template data<T>()[i], T(5));
} }
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
VLOG(2) << "Do GPU copy test"; VLOG(2) << "Do GPU copy test";
auto t1_gpu_cp = t1_cpu_cp.template copy_to<T>(paddle::PlaceType::kGPU); auto t1_gpu_cp = t1_cpu_cp.copy_to(phi::GPUPlace(), /*blocking=*/false);
CHECK((paddle::PlaceType::kGPU == t1_gpu_cp.place())); CHECK((phi::GPUPlace() == t1_gpu_cp.place()));
auto t1_gpu_cp_cp = t1_gpu_cp.template copy_to<T>(paddle::PlaceType::kGPU); auto t1_gpu_cp_cp = t1_gpu_cp.copy_to(phi::GPUPlace(), /*blocking=*/false);
CHECK((paddle::PlaceType::kGPU == t1_gpu_cp_cp.place())); CHECK((phi::GPUPlace() == t1_gpu_cp_cp.place()));
auto t1_gpu_cp_cp_cpu = auto t1_gpu_cp_cp_cpu =
t1_gpu_cp_cp.template copy_to<T>(paddle::PlaceType::kCPU); t1_gpu_cp_cp.copy_to(phi::CPUPlace(), /*blocking=*/false);
CHECK((paddle::PlaceType::kCPU == t1_gpu_cp_cp_cpu.place())); CHECK((phi::CPUPlace() == t1_gpu_cp_cp_cpu.place()));
for (int64_t i = 0; i < t1.size(); i++) { for (int64_t i = 0; i < t1.size(); i++) {
CHECK_EQ(t1_gpu_cp_cp_cpu.template mutable_data<T>()[i], T(5)); CHECK_EQ(t1_gpu_cp_cp_cpu.template data<T>()[i], T(5));
} }
#endif #endif
} }
...@@ -58,18 +70,18 @@ void TestCopyTensor() { ...@@ -58,18 +70,18 @@ void TestCopyTensor() {
void TestAPIPlace() { void TestAPIPlace() {
std::vector<int64_t> tensor_shape = {5, 5}; std::vector<int64_t> tensor_shape = {5, 5};
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
auto t1 = experimental::Tensor(paddle::PlaceType::kGPU, tensor_shape); auto t1 = paddle::experimental::empty(
t1.mutable_data<float>(paddle::PlaceType::kGPU); tensor_shape, DataType::FLOAT32, phi::GPUPlace());
CHECK((paddle::PlaceType::kGPU == t1.place())); CHECK((phi::GPUPlace() == t1.place()));
#endif #endif
auto t2 = experimental::Tensor(paddle::PlaceType::kCPU, tensor_shape); auto t2 = paddle::experimental::empty(
t2.mutable_data<float>(paddle::PlaceType::kCPU); tensor_shape, DataType::FLOAT32, phi::CPUPlace());
CHECK((paddle::PlaceType::kCPU == t2.place())); CHECK((phi::CPUPlace() == t2.place()));
} }
void TestAPISizeAndShape() { void TestAPISizeAndShape() {
std::vector<int64_t> tensor_shape = {5, 5}; std::vector<int64_t> tensor_shape = {5, 5};
auto t1 = experimental::Tensor(paddle::PlaceType::kCPU, tensor_shape); auto t1 = paddle::experimental::empty(tensor_shape);
CHECK_EQ(t1.size(), 25); CHECK_EQ(t1.size(), 25);
CHECK(t1.shape() == tensor_shape); CHECK(t1.shape() == tensor_shape);
} }
...@@ -80,31 +92,30 @@ void TestAPISlice() { ...@@ -80,31 +92,30 @@ void TestAPISlice() {
std::vector<int64_t> tensor_shape_origin2 = {5, 5, 5}; std::vector<int64_t> tensor_shape_origin2 = {5, 5, 5};
std::vector<int64_t> tensor_shape_sub2 = {1, 5, 5}; std::vector<int64_t> tensor_shape_sub2 = {1, 5, 5};
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
auto t1 = experimental::Tensor(paddle::PlaceType::kGPU, tensor_shape_origin1); auto t1 = paddle::experimental::empty(
t1.mutable_data<float>(paddle::PlaceType::kGPU); tensor_shape_origin1, DataType::FLOAT32, phi::GPUPlace());
CHECK(t1.slice(0, 5).shape() == tensor_shape_origin1); CHECK(t1.slice(0, 5).shape() == tensor_shape_origin1);
CHECK(t1.slice(0, 3).shape() == tensor_shape_sub1); CHECK(t1.slice(0, 3).shape() == tensor_shape_sub1);
auto t2 = experimental::Tensor(paddle::PlaceType::kGPU, tensor_shape_origin2); auto t2 = paddle::experimental::empty(
t2.mutable_data<float>(paddle::PlaceType::kGPU); tensor_shape_origin2, DataType::FLOAT32, phi::GPUPlace());
CHECK(t2.slice(4, 5).shape() == tensor_shape_sub2); CHECK(t2.slice(4, 5).shape() == tensor_shape_sub2);
#endif #endif
auto t3 = experimental::Tensor(paddle::PlaceType::kCPU, tensor_shape_origin1); auto t3 = paddle::experimental::empty(
t3.mutable_data<float>(paddle::PlaceType::kCPU); tensor_shape_origin1, DataType::FLOAT32, phi::CPUPlace());
CHECK(t3.slice(0, 5).shape() == tensor_shape_origin1); CHECK(t3.slice(0, 5).shape() == tensor_shape_origin1);
CHECK(t3.slice(0, 3).shape() == tensor_shape_sub1); CHECK(t3.slice(0, 3).shape() == tensor_shape_sub1);
auto t4 = experimental::Tensor(paddle::PlaceType::kCPU, tensor_shape_origin2); auto t4 = paddle::experimental::empty(
t4.mutable_data<float>(paddle::PlaceType::kCPU); tensor_shape_origin2, DataType::FLOAT32, phi::CPUPlace());
CHECK(t4.slice(4, 5).shape() == tensor_shape_sub2); CHECK(t4.slice(4, 5).shape() == tensor_shape_sub2);
// Test writing function for sliced tensor // Test writing function for sliced tensor
auto t = InitCPUTensorForTest<float>(); auto t = InitCPUTensorForTest<float>();
auto t_sliced = t.slice(0, 1); auto t_sliced = t.slice(0, 1);
auto* t_sliced_data_ptr = auto* t_sliced_data_ptr = t_sliced.data<float>();
t_sliced.mutable_data<float>(paddle::PlaceType::kCPU);
for (int64_t i = 0; i < t_sliced.size(); i++) { for (int64_t i = 0; i < t_sliced.size(); i++) {
t_sliced_data_ptr[i] += static_cast<float>(5); t_sliced_data_ptr[i] += static_cast<float>(5);
} }
auto* t_data_ptr = t.mutable_data<float>(paddle::PlaceType::kCPU); auto* t_data_ptr = t.data<float>();
for (int64_t i = 0; i < t_sliced.size(); i++) { for (int64_t i = 0; i < t_sliced.size(); i++) {
CHECK_EQ(t_data_ptr[i], static_cast<float>(10)); CHECK_EQ(t_data_ptr[i], static_cast<float>(10));
} }
...@@ -113,22 +124,20 @@ void TestAPISlice() { ...@@ -113,22 +124,20 @@ void TestAPISlice() {
template <typename T> template <typename T>
paddle::DataType TestDtype() { paddle::DataType TestDtype() {
std::vector<int64_t> tensor_shape = {5, 5}; std::vector<int64_t> tensor_shape = {5, 5};
auto t1 = experimental::Tensor(paddle::PlaceType::kCPU, tensor_shape); DataType dtype = paddle::experimental::CppTypeToDataType<T>::Type();
t1.template mutable_data<T>(paddle::PlaceType::kCPU); auto t1 = paddle::experimental::empty(tensor_shape, dtype, phi::CPUPlace());
return t1.type(); return t1.type();
} }
template <typename T> template <typename T>
void TestCast(paddle::DataType data_type) { void TestCast(paddle::DataType data_type) {
std::vector<int64_t> tensor_shape = {5, 5}; std::vector<int64_t> tensor_shape = {5, 5};
auto t1 = experimental::Tensor(paddle::PlaceType::kCPU, tensor_shape); DataType dtype = paddle::experimental::CppTypeToDataType<T>::Type();
t1.template mutable_data<T>(paddle::PlaceType::kCPU); auto t1 = paddle::experimental::empty(tensor_shape, dtype, phi::CPUPlace());
auto t2 = t1.cast(data_type); auto t2 = t1.cast(data_type);
CHECK(t2.type() == data_type); CHECK(t2.type() == data_type);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
auto tg1 = experimental::Tensor(paddle::PlaceType::kGPU); auto tg1 = paddle::experimental::empty(tensor_shape, dtype, phi::GPUPlace());
tg1.reshape(tensor_shape);
tg1.template mutable_data<T>(paddle::PlaceType::kGPU);
auto tg2 = tg1.cast(data_type); auto tg2 = tg1.cast(data_type);
CHECK(tg2.type() == data_type); CHECK(tg2.type() == data_type);
#endif #endif
...@@ -140,7 +149,7 @@ void GroupTestCopy() { ...@@ -140,7 +149,7 @@ void GroupTestCopy() {
VLOG(2) << "Double cpu-cpu-gpu-gpu-cpu"; VLOG(2) << "Double cpu-cpu-gpu-gpu-cpu";
TestCopyTensor<double>(); TestCopyTensor<double>();
VLOG(2) << "int cpu-cpu-gpu-gpu-cpu"; VLOG(2) << "int cpu-cpu-gpu-gpu-cpu";
TestCopyTensor<int>(); TestCopyTensor<int32_t>();
VLOG(2) << "int64 cpu-cpu-gpu-gpu-cpu"; VLOG(2) << "int64 cpu-cpu-gpu-gpu-cpu";
TestCopyTensor<int64_t>(); TestCopyTensor<int64_t>();
VLOG(2) << "int16 cpu-cpu-gpu-gpu-cpu"; VLOG(2) << "int16 cpu-cpu-gpu-gpu-cpu";
...@@ -158,8 +167,8 @@ void GroupTestCopy() { ...@@ -158,8 +167,8 @@ void GroupTestCopy() {
} }
void GroupTestCast() { void GroupTestCast() {
VLOG(2) << "int cast"; VLOG(2) << "int16_t cast";
TestCast<int>(paddle::DataType::FLOAT32); TestCast<int16_t>(paddle::DataType::FLOAT32);
VLOG(2) << "int32 cast"; VLOG(2) << "int32 cast";
TestCast<int32_t>(paddle::DataType::FLOAT32); TestCast<int32_t>(paddle::DataType::FLOAT32);
VLOG(2) << "int64 cast"; VLOG(2) << "int64 cast";
...@@ -185,7 +194,6 @@ void GroupTestDtype() { ...@@ -185,7 +194,6 @@ void GroupTestDtype() {
CHECK(TestDtype<int8_t>() == paddle::DataType::INT8); CHECK(TestDtype<int8_t>() == paddle::DataType::INT8);
CHECK(TestDtype<uint8_t>() == paddle::DataType::UINT8); CHECK(TestDtype<uint8_t>() == paddle::DataType::UINT8);
CHECK(TestDtype<int16_t>() == paddle::DataType::INT16); CHECK(TestDtype<int16_t>() == paddle::DataType::INT16);
CHECK(TestDtype<int>() == paddle::DataType::INT32);
CHECK(TestDtype<int32_t>() == paddle::DataType::INT32); CHECK(TestDtype<int32_t>() == paddle::DataType::INT32);
CHECK(TestDtype<int64_t>() == paddle::DataType::INT64); CHECK(TestDtype<int64_t>() == paddle::DataType::INT64);
CHECK(TestDtype<paddle::float16>() == paddle::DataType::FLOAT16); CHECK(TestDtype<paddle::float16>() == paddle::DataType::FLOAT16);
...@@ -196,11 +204,9 @@ void GroupTestDtype() { ...@@ -196,11 +204,9 @@ void GroupTestDtype() {
} }
void TestInitilized() { void TestInitilized() {
experimental::Tensor test_tensor(paddle::PlaceType::kCPU, {1, 1}); auto test_tensor = paddle::experimental::empty({1, 1});
CHECK(test_tensor.is_initialized() == true);
test_tensor.mutable_data<float>(paddle::PlaceType::kCPU);
CHECK(test_tensor.is_initialized() == true); CHECK(test_tensor.is_initialized() == true);
float* tensor_data = test_tensor.mutable_data<float>(); float* tensor_data = test_tensor.data<float>();
for (int i = 0; i < test_tensor.size(); i++) { for (int i = 0; i < test_tensor.size(); i++) {
tensor_data[i] = 0.5; tensor_data[i] = 0.5;
} }
...@@ -211,7 +217,7 @@ void TestInitilized() { ...@@ -211,7 +217,7 @@ void TestInitilized() {
void TestDataInterface() { void TestDataInterface() {
// Test DenseTensor // Test DenseTensor
experimental::Tensor test_tensor(paddle::PlaceType::kCPU, {1, 1}); auto test_tensor = paddle::experimental::empty({1, 1});
CHECK(test_tensor.is_initialized() == true); CHECK(test_tensor.is_initialized() == true);
void* tensor_ptr = test_tensor.data(); void* tensor_ptr = test_tensor.data();
CHECK(tensor_ptr != nullptr); CHECK(tensor_ptr != nullptr);
...@@ -234,7 +240,7 @@ void TestDataInterface() { ...@@ -234,7 +240,7 @@ void TestDataInterface() {
} }
void TestJudgeTensorType() { void TestJudgeTensorType() {
experimental::Tensor test_tensor(paddle::PlaceType::kCPU, {1, 1}); experimental::Tensor test_tensor(phi::CPUPlace(), {1, 1});
CHECK(test_tensor.is_dense_tensor() == true); CHECK(test_tensor.is_dense_tensor() == true);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册