From c634a8480addf2e3cbbd271853f4c8aa4b10832b Mon Sep 17 00:00:00 2001 From: qijun Date: Thu, 28 Sep 2017 17:53:50 -0700 Subject: [PATCH] add SetConstant method in math_function.h --- paddle/operators/math/CMakeLists.txt | 3 ++- paddle/operators/math/math_function.h | 8 ++++++++ paddle/operators/math/math_function_test.cc | 21 +++++++++++++++++++++ 3 files changed, 31 insertions(+), 1 deletion(-) diff --git a/paddle/operators/math/CMakeLists.txt b/paddle/operators/math/CMakeLists.txt index 91ae3d49f1..6bea9817f1 100644 --- a/paddle/operators/math/CMakeLists.txt +++ b/paddle/operators/math/CMakeLists.txt @@ -1,6 +1,7 @@ if(WITH_GPU) nv_library(math_function SRCS math_function.cc math_function.cu im2col.cc im2col.cu DEPS cblas device_context operator) + nv_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor) nv_library(softmax_function SRCS softmax.cc softmax.cu DEPS operator) nv_library(cross_entropy_function SRCS cross_entropy.cc cross_entropy.cu @@ -8,9 +9,9 @@ if(WITH_GPU) else() cc_library(math_function SRCS math_function.cc im2col.cc DEPS cblas device_context operator) + cc_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor) cc_library(softmax_function SRCS softmax.cc DEPS operator) cc_library(cross_entropy_function SRCS cross_entropy.cc DEPS operator) endif() -nv_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor) cc_test(im2col_test SRCS im2col_test.cc DEPS math_function tensor) diff --git a/paddle/operators/math/math_function.h b/paddle/operators/math/math_function.h index 43306fca73..473eff4d19 100644 --- a/paddle/operators/math/math_function.h +++ b/paddle/operators/math/math_function.h @@ -52,6 +52,7 @@ int LAPACKE_dgetri(int matrix_layout, int n, double* a, int lda, #include +#include "paddle/framework/eigen.h" #include "paddle/framework/tensor.h" #include "paddle/platform/device_context.h" #include "paddle/platform/enforce.h" @@ -84,6 +85,13 @@ void matmul(const platform::DeviceContext& context, const framework::Tensor& matrix_b, bool trans_b, T alpha, framework::Tensor* matrix_out, T beta); +template +void SetConstant(const platform::DeviceContext& context, + framework::Tensor* tensor, T num) { + auto t = framework::EigenVector::Flatten(*tensor); + t.device(*context.GetEigenDevice()) = t.constant(static_cast(num)); +} + } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/operators/math/math_function_test.cc b/paddle/operators/math/math_function_test.cc index f272f7e513..22468a0c4a 100644 --- a/paddle/operators/math/math_function_test.cc +++ b/paddle/operators/math/math_function_test.cc @@ -243,3 +243,24 @@ TEST(math_function, gemm_trans_clbas) { EXPECT_EQ(input3_ptr[6], 86); EXPECT_EQ(input3_ptr[7], 99); } + +TEST(math_function, zero) { + paddle::framework::Tensor tensor; + auto* cpu_place = new paddle::platform::CPUPlace(); + float* t = tensor.mutable_data({2, 2}, *cpu_place); + paddle::platform::CPUDeviceContext context(*cpu_place); + paddle::operators::math::SetConstant( + context, &tensor, 0); + EXPECT_EQ(t[0], 0); + EXPECT_EQ(t[1], 0); + EXPECT_EQ(t[2], 0); + EXPECT_EQ(t[3], 0); + + paddle::operators::math::SetConstant( + context, &tensor, 1); + + EXPECT_EQ(t[0], 1); + EXPECT_EQ(t[1], 1); + EXPECT_EQ(t[2], 1); + EXPECT_EQ(t[3], 1); +} -- GitLab