diff --git a/paddle/operators/math/math_function.cc b/paddle/operators/math/math_function.cc
index def4b01da098fc960ce7c0e497732fbcc2579945..ba653afa2cb175ae2e5e21088b6dc7ba76a6018f 100644
--- a/paddle/operators/math/math_function.cc
+++ b/paddle/operators/math/math_function.cc
@@ -48,6 +48,34 @@ void gemm<platform::CPUPlace, double>(const platform::DeviceContext& context,
               beta, C, ldc);
 }
 
+// gemm with explicit leading dimensions (lda/ldb/ldc), row-major, via cblas.
+template <>
+void gemm<platform::CPUPlace, float>(const platform::DeviceContext& context,
+                                     const bool transA, const bool transB,
+                                     const int M, const int N, const int K,
+                                     const float alpha, const float* A,
+                                     const int lda, const float* B,
+                                     const int ldb, const float beta, float* C,
+                                     const int ldc) {
+  cblas_sgemm(CblasRowMajor, transA == false ? CblasNoTrans : CblasTrans,
+              transB == false ? CblasNoTrans : CblasTrans, M, N, K, alpha, A,
+              lda, B, ldb, beta, C, ldc);
+}
+
+// gemm with explicit leading dimensions (lda/ldb/ldc), row-major, via cblas.
+template <>
+void gemm<platform::CPUPlace, double>(const platform::DeviceContext& context,
+                                      const bool transA, const bool transB,
+                                      const int M, const int N, const int K,
+                                      const double alpha, const double* A,
+                                      const int lda, const double* B,
+                                      const int ldb, const double beta,
+                                      double* C, const int ldc) {
+  cblas_dgemm(CblasRowMajor, transA == false ? CblasNoTrans : CblasTrans,
+              transB == false ? CblasNoTrans : CblasTrans, M, N, K, alpha, A,
+              lda, B, ldb, beta, C, ldc);
+}
+
 template <>
 void matmul<platform::CPUPlace, float>(
     const platform::DeviceContext& context, const framework::Tensor& matrix_a,
diff --git a/paddle/operators/math/math_function.cu b/paddle/operators/math/math_function.cu
index 71563b77b4b262c3f1e17ae7c4381da56ba780a3..649f1f352c2a4a5ebaa0cb00ffb2e4de8aa4961a 100644
--- a/paddle/operators/math/math_function.cu
+++ b/paddle/operators/math/math_function.cu
@@ -63,6 +63,44 @@ void gemm<platform::GPUPlace, double>(const platform::DeviceContext& context,
       cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda, &beta, C, N));
 }
 
+// gemm with explicit leading dimensions (lda/ldb/ldc), row-major, via cublas.
+template <>
+void gemm<platform::GPUPlace, float>(const platform::DeviceContext& context,
+                                     const bool transA, const bool transB,
+                                     const int M, const int N, const int K,
+                                     const float alpha, const float* A,
+                                     const int lda, const float* B,
+                                     const int ldb, const float beta, float* C,
+                                     const int ldc) {
+  // Note that cublas follows fortran order, so the order is different from
+  // the cblas convention.
+  cublasOperation_t cuTransA = transA == false ? CUBLAS_OP_N : CUBLAS_OP_T;
+  cublasOperation_t cuTransB = transB == false ? CUBLAS_OP_N : CUBLAS_OP_T;
+  PADDLE_ENFORCE(platform::dynload::cublasSgemm(
+      reinterpret_cast<const platform::CUDADeviceContext&>(context)
+          .cublas_handle(),
+      cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda, &beta, C, ldc));
+}
+
+// gemm with explicit leading dimensions (lda/ldb/ldc), row-major, via cublas.
+template <>
+void gemm<platform::GPUPlace, double>(const platform::DeviceContext& context,
+                                      const bool transA, const bool transB,
+                                      const int M, const int N, const int K,
+                                      const double alpha, const double* A,
+                                      const int lda, const double* B,
+                                      const int ldb, const double beta,
+                                      double* C, const int ldc) {
+  // Note that cublas follows fortran order, so the order is different from
+  // the cblas convention.
+  cublasOperation_t cuTransA = transA == false ? CUBLAS_OP_N : CUBLAS_OP_T;
+  cublasOperation_t cuTransB = transB == false ? CUBLAS_OP_N : CUBLAS_OP_T;
+  PADDLE_ENFORCE(platform::dynload::cublasDgemm(
+      reinterpret_cast<const platform::CUDADeviceContext&>(context)
+          .cublas_handle(),
+      cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda, &beta, C, ldc));
+}
+
 template <>
 void matmul<platform::GPUPlace, float>(
     const platform::DeviceContext& context, const framework::Tensor& matrix_a,
diff --git a/paddle/operators/math/math_function.h b/paddle/operators/math/math_function.h
index d8518e77fa7b4abdbcf08b7983013c24806e14ca..43306fca73387b7b212f556a2b187df113a1b327 100644
--- a/paddle/operators/math/math_function.h
+++ b/paddle/operators/math/math_function.h
@@ -70,6 +70,17 @@ void gemm(const platform::DeviceContext& context, const CBLAS_TRANSPOSE transA,
           const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
           const T alpha, const T* A, const T* B, const T beta, T* C);
 
+// gemm wrapper taking explicit leading dimensions, for matrices that are not
+// contiguous in memory (e.g. a column slice of a larger row-major matrix).
+// Computes C = alpha * op(A) * op(B) + beta * C, where op(X) is X when the
+// corresponding trans flag is false and X^T otherwise; op(A) is M x K, op(B)
+// is K x N and C is M x N. lda, ldb and ldc are the row strides of A, B and C.
+template <typename Place, typename T>
+void gemm(const platform::DeviceContext& context, const bool transA,
+          const bool transB, const int M, const int N, const int K,
+          const T alpha, const T* A, const int lda, const T* B, const int ldb,
+          const T beta, T* C, const int ldc);
+
 // matrix multiply with continuous memory
 template <typename Place, typename T>
 void matmul(const platform::DeviceContext& context,
diff --git a/paddle/operators/math/math_function_test.cc b/paddle/operators/math/math_function_test.cc
index 7e339457f7f08ff16162f399064a4b4dca594d7f..f272f7e5135e7092618b8c94ee55faf1cfd8e8a5 100644
--- a/paddle/operators/math/math_function_test.cc
+++ b/paddle/operators/math/math_function_test.cc
@@ -72,4 +72,182 @@ TEST(math_function, trans_mul_notrans) {
   EXPECT_EQ(out_ptr[8], 29);
   delete gpu_place;
 }
+
+TEST(math_function, gemm_notrans_cublas) {
+  paddle::framework::Tensor input1;
+  paddle::framework::Tensor input2;
+  paddle::framework::Tensor input3;
+  paddle::framework::Tensor input1_gpu;
+  paddle::framework::Tensor input2_gpu;
+  paddle::framework::Tensor input3_gpu;
+
+  int m = 2;
+  int n = 3;
+  int k = 3;
+  auto* cpu_place = new paddle::platform::CPUPlace();
+  float* input1_ptr = input1.mutable_data<float>({2, 3}, *cpu_place);
+  float arr1[6] = {0, 1, 2, 3, 4, 5};
+  memcpy(input1_ptr, arr1, 6 * sizeof(float));
+  float* input2_ptr = input2.mutable_data<float>({3, 4}, *cpu_place);
+  float arr2[12] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
+  memcpy(input2_ptr, arr2, 12 * sizeof(float));
+  float* input3_ptr = input3.mutable_data<float>({2, 4}, *cpu_place);
+  float arr3[8] = {0, 1, 2, 3, 4, 5, 6, 7};
+  memcpy(input3_ptr, arr3, 8 * sizeof(float));
+
+  auto* gpu_place = new paddle::platform::GPUPlace(0);
+  paddle::platform::CUDADeviceContext context(*gpu_place);
+
+  input1_gpu.CopyFrom(input1, *gpu_place);
+  input2_gpu.CopyFrom(input2, *gpu_place);
+  input3_gpu.CopyFrom(input3, *gpu_place);
+  float* a = input1_gpu.data<float>();
+  float* b = input2_gpu.data<float>();
+  float* c = input3_gpu.mutable_data<float>(*gpu_place);
+
+  paddle::operators::math::gemm<paddle::platform::GPUPlace, float>(
+      context, false, false, m, n, k, 1, a, 3, b + 1, 4, 1, c + 1, 4);
+
+  input3.CopyFrom(input3_gpu, *cpu_place);
+
+  // numpy code:
+  // a = np.arange(6).reshape(2, 3)
+  // b = np.arange(12).reshape(3, 4)[:, 1:]
+  // c = np.arange(8).reshape(2, 4)[:, 1:]
+  // out = np.arange(8).reshape(2, 4)
+  // out[:, 1:] = np.dot(a, b) + c
+  EXPECT_EQ(input3_ptr[0], 0);
+  EXPECT_EQ(input3_ptr[1], 24);
+  EXPECT_EQ(input3_ptr[2], 28);
+  EXPECT_EQ(input3_ptr[3], 32);
+  EXPECT_EQ(input3_ptr[4], 4);
+  EXPECT_EQ(input3_ptr[5], 73);
+  EXPECT_EQ(input3_ptr[6], 86);
+  EXPECT_EQ(input3_ptr[7], 99);
+  delete cpu_place;
+  delete gpu_place;
+}
+
+TEST(math_function, gemm_trans_cublas) {
+  paddle::framework::Tensor input1;
+  paddle::framework::Tensor input2;
+  paddle::framework::Tensor input3;
+  paddle::framework::Tensor input1_gpu;
+  paddle::framework::Tensor input2_gpu;
+  paddle::framework::Tensor input3_gpu;
+
+  int m = 2;
+  int n = 3;
+  int k = 3;
+  auto* cpu_place = new paddle::platform::CPUPlace();
+  float* input1_ptr = input1.mutable_data<float>({2, 3}, *cpu_place);
+  float arr1[6] = {0, 1, 2, 3, 4, 5};
+  memcpy(input1_ptr, arr1, 6 * sizeof(float));
+  float* input2_ptr = input2.mutable_data<float>({4, 3}, *cpu_place);
+  float arr2[12] = {0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11};
+  memcpy(input2_ptr, arr2, 12 * sizeof(float));
+  float* input3_ptr = input3.mutable_data<float>({2, 4}, *cpu_place);
+  float arr3[8] = {0, 1, 2, 3, 4, 5, 6, 7};
+  memcpy(input3_ptr, arr3, 8 * sizeof(float));
+
+  auto* gpu_place = new paddle::platform::GPUPlace(0);
+  paddle::platform::CUDADeviceContext context(*gpu_place);
+
+  input1_gpu.CopyFrom(input1, *gpu_place);
+  input2_gpu.CopyFrom(input2, *gpu_place);
+  input3_gpu.CopyFrom(input3, *gpu_place);
+  float* a = input1_gpu.data<float>();
+  float* b = input2_gpu.data<float>();
+  float* c = input3_gpu.mutable_data<float>(*gpu_place);
+
+  // input2 stores the 3 x 4 matrix of the notrans case transposed (4 x 3);
+  // offsetting by one row (3 elements) drops column 0 of the logical matrix.
+  paddle::operators::math::gemm<paddle::platform::GPUPlace, float>(
+      context, false, true, m, n, k, 1, a, 3, b + 3, 3, 1, c + 1, 4);
+
+  input3.CopyFrom(input3_gpu, *cpu_place);
+
+  EXPECT_EQ(input3_ptr[0], 0);
+  EXPECT_EQ(input3_ptr[1], 24);
+  EXPECT_EQ(input3_ptr[2], 28);
+  EXPECT_EQ(input3_ptr[3], 32);
+  EXPECT_EQ(input3_ptr[4], 4);
+  EXPECT_EQ(input3_ptr[5], 73);
+  EXPECT_EQ(input3_ptr[6], 86);
+  EXPECT_EQ(input3_ptr[7], 99);
+  delete cpu_place;
+  delete gpu_place;
+}
 #endif
+
+TEST(math_function, gemm_notrans_cblas) {
+  paddle::framework::Tensor input1;
+  paddle::framework::Tensor input2;
+  paddle::framework::Tensor input3;
+
+  int m = 2;
+  int n = 3;
+  int k = 3;
+  auto* cpu_place = new paddle::platform::CPUPlace();
+  float* input1_ptr = input1.mutable_data<float>({2, 3}, *cpu_place);
+  float arr1[6] = {0, 1, 2, 3, 4, 5};
+  memcpy(input1_ptr, arr1, 6 * sizeof(float));
+  float* input2_ptr = input2.mutable_data<float>({3, 4}, *cpu_place);
+  float arr2[12] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
+  memcpy(input2_ptr, arr2, 12 * sizeof(float));
+  float* input3_ptr = input3.mutable_data<float>({2, 4}, *cpu_place);
+  float arr3[8] = {0, 1, 2, 3, 4, 5, 6, 7};
+  memcpy(input3_ptr, arr3, 8 * sizeof(float));
+
+  paddle::platform::CPUDeviceContext context(*cpu_place);
+  paddle::operators::math::gemm<paddle::platform::CPUPlace, float>(
+      context, false, false, m, n, k, 1, input1_ptr, 3, input2_ptr + 1, 4, 1,
+      input3_ptr + 1, 4);
+
+  EXPECT_EQ(input3_ptr[0], 0);
+  EXPECT_EQ(input3_ptr[1], 24);
+  EXPECT_EQ(input3_ptr[2], 28);
+  EXPECT_EQ(input3_ptr[3], 32);
+  EXPECT_EQ(input3_ptr[4], 4);
+  EXPECT_EQ(input3_ptr[5], 73);
+  EXPECT_EQ(input3_ptr[6], 86);
+  EXPECT_EQ(input3_ptr[7], 99);
+  delete cpu_place;
+}
+
+TEST(math_function, gemm_trans_cblas) {
+  paddle::framework::Tensor input1;
+  paddle::framework::Tensor input2;
+  paddle::framework::Tensor input3;
+
+  int m = 2;
+  int n = 3;
+  int k = 3;
+  auto* cpu_place = new paddle::platform::CPUPlace();
+  float* input1_ptr = input1.mutable_data<float>({2, 3}, *cpu_place);
+  float arr1[6] = {0, 1, 2, 3, 4, 5};
+  memcpy(input1_ptr, arr1, 6 * sizeof(float));
+  float* input2_ptr = input2.mutable_data<float>({4, 3}, *cpu_place);
+  float arr2[12] = {0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11};
+  memcpy(input2_ptr, arr2, 12 * sizeof(float));
+  float* input3_ptr = input3.mutable_data<float>({2, 4}, *cpu_place);
+  float arr3[8] = {0, 1, 2, 3, 4, 5, 6, 7};
+  memcpy(input3_ptr, arr3, 8 * sizeof(float));
+
+  paddle::platform::CPUDeviceContext context(*cpu_place);
+  // input2 stores the 3 x 4 matrix of the notrans case transposed (4 x 3);
+  // offsetting by one row (3 elements) drops column 0 of the logical matrix.
+  paddle::operators::math::gemm<paddle::platform::CPUPlace, float>(
+      context, false, true, m, n, k, 1, input1_ptr, 3, input2_ptr + 3, 3, 1,
+      input3_ptr + 1, 4);
+
+  EXPECT_EQ(input3_ptr[0], 0);
+  EXPECT_EQ(input3_ptr[1], 24);
+  EXPECT_EQ(input3_ptr[2], 28);
+  EXPECT_EQ(input3_ptr[3], 32);
+  EXPECT_EQ(input3_ptr[4], 4);
+  EXPECT_EQ(input3_ptr[5], 73);
+  EXPECT_EQ(input3_ptr[6], 86);
+  EXPECT_EQ(input3_ptr[7], 99);
+  delete cpu_place;
+}