From 49dedfad17a9cb80d98247fdbfddda50d33e2381 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Wed, 2 May 2018 11:46:08 +0800 Subject: [PATCH] Polish code and tests --- paddle/fluid/operators/math/blas_impl.cu.h | 17 ++++- .../operators/math/math_function_test.cc | 17 +++-- .../operators/math/math_function_test.cu | 62 ++++++++++--------- 3 files changed, 59 insertions(+), 37 deletions(-) diff --git a/paddle/fluid/operators/math/blas_impl.cu.h b/paddle/fluid/operators/math/blas_impl.cu.h index b7bd8f1d04..86e4946991 100644 --- a/paddle/fluid/operators/math/blas_impl.cu.h +++ b/paddle/fluid/operators/math/blas_impl.cu.h @@ -42,9 +42,20 @@ struct CUBlas { template <> struct CUBlas { - template - static void GEMM(ARGS... args) { - PADDLE_ENFORCE(platform::dynload::cublasHgemm(args...)); + using float16 = platform::float16; + + static void GEMM(cublasHandle_t handle, cublasOperation_t transa, + cublasOperation_t transb, int m, int n, int k, + const float16 *alpha, const float16 *A, int lda, + const float16 *B, int ldb, const float16 *beta, float16 *C, + int ldc) { + PADDLE_ENFORCE( + platform::dynload::cublasHgemm(handle, transa, transb, m, n, k, + reinterpret_cast(alpha), + reinterpret_cast(A), lda, + reinterpret_cast(B), ldb, + reinterpret_cast(beta), + reinterpret_cast<__half *>(C), ldc)); } }; diff --git a/paddle/fluid/operators/math/math_function_test.cc b/paddle/fluid/operators/math/math_function_test.cc index 25a9d0111e..6d11dc8c76 100644 --- a/paddle/fluid/operators/math/math_function_test.cc +++ b/paddle/fluid/operators/math/math_function_test.cc @@ -14,6 +14,13 @@ #include "paddle/fluid/operators/math/math_function.h" #include "gtest/gtest.h" +template +inline paddle::operators::math::BlasT +GetBlas(const paddle::platform::CPUDeviceContext& context) { + return paddle::operators::math::GetBlas(context); +} + TEST(math_function, gemm_notrans_cblas) { paddle::framework::Tensor input1; paddle::framework::Tensor input2; @@ -34,9 +41,8 @@ TEST(math_function, gemm_notrans_cblas) { memcpy(input3_ptr, arr3, 8 * sizeof(float)); paddle::platform::CPUDeviceContext context(*cpu_place); - paddle::operators::math::gemm( - context, false, false, m, n, k, 1, input1_ptr, 3, input2_ptr + 1, 4, 1, - input3_ptr + 1, 4); + GetBlas(context).GEMM(false, false, m, n, k, 1, input1_ptr, 3, + input2_ptr + 1, 4, 1, input3_ptr + 1, 4); EXPECT_EQ(input3_ptr[0], 0); EXPECT_EQ(input3_ptr[1], 24); @@ -68,9 +74,8 @@ TEST(math_function, gemm_trans_clbas) { memcpy(input3_ptr, arr3, 8 * sizeof(float)); paddle::platform::CPUDeviceContext context(*cpu_place); - paddle::operators::math::gemm( - context, false, true, m, n, k, 1, input1_ptr, 3, input2_ptr + 3, 3, 1, - input3_ptr + 1, 4); + GetBlas(context).GEMM(false, true, m, n, k, 1, input1_ptr, 3, + input2_ptr + 3, 3, 1, input3_ptr + 1, 4); EXPECT_EQ(input3_ptr[0], 0); EXPECT_EQ(input3_ptr[1], 24); diff --git a/paddle/fluid/operators/math/math_function_test.cu b/paddle/fluid/operators/math/math_function_test.cu index 7986326e96..22484e1c1a 100644 --- a/paddle/fluid/operators/math/math_function_test.cu +++ b/paddle/fluid/operators/math/math_function_test.cu @@ -13,6 +13,7 @@ // limitations under the License. #include "gtest/gtest.h" #include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/platform/device_context.h" void fill_fp16_data(paddle::platform::float16* in_ptr, size_t size, const std::vector& data) { @@ -23,8 +24,8 @@ void fill_fp16_data(paddle::platform::float16* in_ptr, size_t size, } TEST(math_function, notrans_mul_trans_fp32) { - using namespace paddle::framework; - using namespace paddle::platform; + using namespace paddle::framework; // NOLINT + using namespace paddle::platform; // NOLINT Tensor input1; Tensor input1_gpu; @@ -59,8 +60,8 @@ TEST(math_function, notrans_mul_trans_fp32) { } TEST(math_function, notrans_mul_trans_fp16) { - using namespace paddle::framework; - using namespace paddle::platform; + using namespace paddle::framework; // NOLINT + using namespace paddle::platform; // NOLINT Tensor input1; Tensor input1_gpu; @@ -100,8 +101,8 @@ TEST(math_function, notrans_mul_trans_fp16) { } TEST(math_function, trans_mul_notrans_fp32) { - using namespace paddle::framework; - using namespace paddle::platform; + using namespace paddle::framework; // NOLINT + using namespace paddle::platform; // NOLINT Tensor input1; Tensor input1_gpu; @@ -141,8 +142,8 @@ TEST(math_function, trans_mul_notrans_fp32) { } TEST(math_function, trans_mul_notrans_fp16) { - using namespace paddle::framework; - using namespace paddle::platform; + using namespace paddle::framework; // NOLINT + using namespace paddle::platform; // NOLINT Tensor input1; Tensor input1_gpu; @@ -186,9 +187,16 @@ TEST(math_function, trans_mul_notrans_fp16) { EXPECT_EQ(static_cast(out_ptr[8]), 29); } +template +inline paddle::operators::math::BlasT +GetBlas(const paddle::platform::CUDADeviceContext& context) { + return paddle::operators::math::GetBlas(context); +} + TEST(math_function, gemm_notrans_cublas_fp32) { - using namespace paddle::framework; - using namespace paddle::platform; + using namespace paddle::framework; // NOLINT + using namespace paddle::platform; // NOLINT Tensor input1; Tensor input2; @@ -221,8 +229,8 @@ TEST(math_function, gemm_notrans_cublas_fp32) { float* b = input2_gpu.data(); float* c = input3_gpu.mutable_data(gpu_place); - paddle::operators::math::gemm( - context, false, false, m, n, k, 1, a, 3, b + 1, 4, 1, c + 1, 4); + GetBlas(context).GEMM(false, false, m, n, k, 1, a, 3, b + 1, 4, 1, + c + 1, 4); TensorCopySync(input3_gpu, cpu_place, &input3); @@ -244,8 +252,8 @@ TEST(math_function, gemm_notrans_cublas_fp32) { } TEST(math_function, gemm_notrans_cublas_fp16) { - using namespace paddle::framework; - using namespace paddle::platform; + using namespace paddle::framework; // NOLINT + using namespace paddle::platform; // NOLINT Tensor input1; Tensor input2; @@ -281,9 +289,8 @@ TEST(math_function, gemm_notrans_cublas_fp16) { float16* b = input2_gpu.data(); float16* c = input3_gpu.mutable_data(gpu_place); - paddle::operators::math::gemm( - context, false, false, m, n, k, float16(1), a, 3, b + 1, 4, float16(1), - c + 1, 4); + GetBlas(context).GEMM(false, false, m, n, k, float16(1), a, 3, b + 1, + 4, float16(1), c + 1, 4); TensorCopySync(input3_gpu, cpu_place, &input3); @@ -305,8 +312,8 @@ TEST(math_function, gemm_notrans_cublas_fp16) { } TEST(math_function, gemm_trans_cublas_fp32) { - using namespace paddle::framework; - using namespace paddle::platform; + using namespace paddle::framework; // NOLINT + using namespace paddle::platform; // NOLINT Tensor input1; Tensor input2; @@ -339,8 +346,8 @@ TEST(math_function, gemm_trans_cublas_fp32) { float* b = input2_gpu.data(); float* c = input3_gpu.mutable_data(gpu_place); - paddle::operators::math::gemm( - context, false, true, m, n, k, 1, a, 3, b + 3, 3, 1, c + 1, 4); + GetBlas(context).GEMM(false, true, m, n, k, 1, a, 3, b + 3, 3, 1, + c + 1, 4); TensorCopySync(input3_gpu, cpu_place, &input3); @@ -356,8 +363,8 @@ TEST(math_function, gemm_trans_cublas_fp32) { } TEST(math_function, gemm_trans_cublas_fp16) { - using namespace paddle::framework; - using namespace paddle::platform; + using namespace paddle::framework; // NOLINT + using namespace paddle::platform; // NOLINT Tensor input1; Tensor input2; @@ -393,9 +400,8 @@ TEST(math_function, gemm_trans_cublas_fp16) { float16* b = input2_gpu.data(); float16* c = input3_gpu.mutable_data(gpu_place); - paddle::operators::math::gemm( - context, false, true, m, n, k, float16(1), a, 3, b + 3, 3, float16(1), - c + 1, 4); + GetBlas(context).GEMM(false, true, m, n, k, float16(1), a, 3, b + 3, + 3, float16(1), c + 1, 4); TensorCopySync(input3_gpu, cpu_place, &input3); @@ -412,8 +418,8 @@ TEST(math_function, gemm_trans_cublas_fp16) { template void GemvTest(int m, int n, bool trans) { - using namespace paddle::framework; - using namespace paddle::platform; + using namespace paddle::framework; // NOLINT + using namespace paddle::platform; // NOLINT Tensor mat_a; Tensor vec_b; -- GitLab