diff --git a/paddle/operators/math/CMakeLists.txt b/paddle/operators/math/CMakeLists.txt index 91ae3d49f1df51d9524547f7765285bff9dbb5c5..6bea9817f1b6c76e68e2a3023bb9eac591aa894f 100644 --- a/paddle/operators/math/CMakeLists.txt +++ b/paddle/operators/math/CMakeLists.txt @@ -1,6 +1,7 @@ if(WITH_GPU) nv_library(math_function SRCS math_function.cc math_function.cu im2col.cc im2col.cu DEPS cblas device_context operator) + nv_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor) nv_library(softmax_function SRCS softmax.cc softmax.cu DEPS operator) nv_library(cross_entropy_function SRCS cross_entropy.cc cross_entropy.cu @@ -8,9 +9,9 @@ if(WITH_GPU) else() cc_library(math_function SRCS math_function.cc im2col.cc DEPS cblas device_context operator) + cc_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor) cc_library(softmax_function SRCS softmax.cc DEPS operator) cc_library(cross_entropy_function SRCS cross_entropy.cc DEPS operator) endif() -nv_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor) cc_test(im2col_test SRCS im2col_test.cc DEPS math_function tensor) diff --git a/paddle/operators/math/math_function.h b/paddle/operators/math/math_function.h index 43306fca73387b7b212f556a2b187df113a1b327..473eff4d198ca9b17b6af8eebd6dfe39d49d138d 100644 --- a/paddle/operators/math/math_function.h +++ b/paddle/operators/math/math_function.h @@ -52,6 +52,7 @@ int LAPACKE_dgetri(int matrix_layout, int n, double* a, int lda, #include +#include "paddle/framework/eigen.h" #include "paddle/framework/tensor.h" #include "paddle/platform/device_context.h" #include "paddle/platform/enforce.h" @@ -84,6 +85,13 @@ void matmul(const platform::DeviceContext& context, const framework::Tensor& matrix_b, bool trans_b, T alpha, framework::Tensor* matrix_out, T beta); +template +void SetConstant(const platform::DeviceContext& context, + framework::Tensor* tensor, T num) { + auto t = framework::EigenVector::Flatten(*tensor); + t.device(*context.GetEigenDevice()) = t.constant(static_cast(num)); +} + } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/operators/math/math_function_test.cc b/paddle/operators/math/math_function_test.cc index f272f7e5135e7092618b8c94ee55faf1cfd8e8a5..22468a0c4a4b0aca343fe766c8c9d63393a338eb 100644 --- a/paddle/operators/math/math_function_test.cc +++ b/paddle/operators/math/math_function_test.cc @@ -243,3 +243,24 @@ TEST(math_function, gemm_trans_clbas) { EXPECT_EQ(input3_ptr[6], 86); EXPECT_EQ(input3_ptr[7], 99); } + +TEST(math_function, zero) { + paddle::framework::Tensor tensor; + auto* cpu_place = new paddle::platform::CPUPlace(); + float* t = tensor.mutable_data({2, 2}, *cpu_place); + paddle::platform::CPUDeviceContext context(*cpu_place); + paddle::operators::math::SetConstant( + context, &tensor, 0); + EXPECT_EQ(t[0], 0); + EXPECT_EQ(t[1], 0); + EXPECT_EQ(t[2], 0); + EXPECT_EQ(t[3], 0); + + paddle::operators::math::SetConstant( + context, &tensor, 1); + + EXPECT_EQ(t[0], 1); + EXPECT_EQ(t[1], 1); + EXPECT_EQ(t[2], 1); + EXPECT_EQ(t[3], 1); +}