diff --git a/paddle/operators/math/math_function.cc b/paddle/operators/math/math_function.cc
index 5833fc90a76480551811ca59f24dc2af197fea33..7827c213fec953a5c7f403e536199fb3d8b80d7d 100644
--- a/paddle/operators/math/math_function.cc
+++ b/paddle/operators/math/math_function.cc
@@ -19,21 +19,30 @@ namespace operators {
 namespace math {
 
 template <>
-void gemm<platform::CPUPlace, float>(
-    const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB, const int M,
-    const int N, const int K, const float alpha, const float* A, const int lda,
-    const float* B, const int ldb, const float beta, float* C, const int ldc,
-    platform::DeviceContext* context) {
+void gemm<platform::CPUPlace, float>(const CBLAS_TRANSPOSE transA,
+                                     const CBLAS_TRANSPOSE transB, const int M,
+                                     const int N, const int K,
+                                     const float alpha, const float* A,
+                                     const float* B, const float beta, float* C,
+                                     platform::DeviceContext* context) {
+  int lda = K;
+  int ldb = N;
+  int ldc = N;
   cblas_sgemm(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
               beta, C, ldc);
 }
 
 template <>
-void gemm<platform::CPUPlace, double>(
-    const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB, const int M,
-    const int N, const int K, const double alpha, const double* A,
-    const int lda, const double* B, const int ldb, const double beta, double* C,
-    const int ldc, platform::DeviceContext* context) {
+void gemm<platform::CPUPlace, double>(const CBLAS_TRANSPOSE transA,
+                                      const CBLAS_TRANSPOSE transB, const int M,
+                                      const int N, const int K,
+                                      const double alpha, const double* A,
+                                      const double* B, const double beta,
+                                      double* C,
+                                      platform::DeviceContext* context) {
+  int lda = K;
+  int ldb = N;
+  int ldc = N;
   cblas_dgemm(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
               beta, C, ldc);
 }
@@ -67,8 +76,8 @@ void matmul<platform::CPUPlace, float>(const framework::Tensor& in1, bool in1_T,
   CBLAS_TRANSPOSE in2_Trans = (in1_T == false) ? CblasNoTrans : CblasTrans;
 
   gemm<platform::CPUPlace, float>(in1_Trans, in2_Trans, M, N, K, alpha,
-                                  in1.data<float>(), K, in2.data<float>(), N,
-                                  beta, out->data<float>(), N, context);
+                                  in1.data<float>(), in2.data<float>(), beta,
+                                  out->data<float>(), context);
 }
 
 template <>
@@ -100,8 +109,8 @@ void matmul<platform::CPUPlace, double>(const framework::Tensor& in1,
   CBLAS_TRANSPOSE in2_Trans = (in1_T == false) ? CblasNoTrans : CblasTrans;
 
   gemm<platform::CPUPlace, double>(in1_Trans, in2_Trans, M, N, K, alpha,
-                                   in1.data<double>(), K, in2.data<double>(), N,
-                                   beta, out->data<double>(), N, context);
+                                   in1.data<double>(), in2.data<double>(), beta,
+                                   out->data<double>(), context);
 }
 
 }  // namespace math
diff --git a/paddle/operators/math/math_function.cu b/paddle/operators/math/math_function.cu
index eb07bc89966e25d85b9f21a0afff07d2f116f798..12ddd2146f1619bb78bcf03e73b3ac8cf6576044 100644
--- a/paddle/operators/math/math_function.cu
+++ b/paddle/operators/math/math_function.cu
@@ -18,14 +18,16 @@ namespace operators {
 namespace math {
 
 template <>
-void gemm<platform::GPUPlace, float>(
-    const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB, const int M,
-    const int N, const int K, const float alpha, const float* A, const int lda,
-    const float* B, const int ldb, const float beta, float* C, const int ldc,
-    platform::DeviceContext* context) {
+void gemm<platform::GPUPlace, float>(const CBLAS_TRANSPOSE transA,
+                                     const CBLAS_TRANSPOSE transB, const int M,
+                                     const int N, const int K,
+                                     const float alpha, const float* A,
+                                     const float* B, const float beta, float* C,
+                                     platform::DeviceContext* context) {
   // Note that cublas follows fortran order, so the order is different from
   // the cblas convention.
-  /*
+  int lda = (transA == CblasNoTrans) ? K : M;
+  int ldb = (transB == CblasNoTrans) ? N : K;
   cublasOperation_t cuTransA =
       (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
   cublasOperation_t cuTransB =
@@ -34,8 +36,6 @@ void gemm<platform::GPUPlace, float>(
   PADDLE_ENFORCE(platform::dynload::cublasSgemm(
       reinterpret_cast<platform::CUDADeviceContext*>(context)->cublas_handle(),
       cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda, &beta, C, ldc));
-  */
-  PADDLE_THROW("not implemented now");
 }
 
 template <>
@@ -46,7 +46,8 @@ void gemm<platform::GPUPlace, double>(
     const int ldc, platform::DeviceContext* context) {
   // Note that cublas follows fortran order, so the order is different from
   // the cblas convention.
-  /*
+  int lda = (transA == CblasNoTrans) ? K : M;
+  int ldb = (transB == CblasNoTrans) ? N : K;
   cublasOperation_t cuTransA =
       (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
   cublasOperation_t cuTransB =
@@ -54,8 +55,6 @@ void gemm<platform::GPUPlace, double>(
   PADDLE_ENFORCE(platform::dynload::cublasDgemm(
      reinterpret_cast<platform::CUDADeviceContext*>(context)->cublas_handle(),
      cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda, &beta, C, ldc));
-  */
-  PADDLE_THROW("not implemented now");
 }
 
 template <>
@@ -87,8 +86,8 @@ void matmul<platform::GPUPlace, float>(const framework::Tensor& in1, bool in1_T,
   CBLAS_TRANSPOSE in2_Trans = (in1_T == false) ? CblasNoTrans : CblasTrans;
 
   gemm<platform::GPUPlace, float>(in1_Trans, in2_Trans, M, N, K, alpha,
-                                  in1.data<float>(), K, in2.data<float>(), N,
-                                  beta, out->data<float>(), N, context);
+                                  in1.data<float>(), in2.data<float>(), beta,
+                                  out->data<float>(), context);
 }
 
 template <>
@@ -120,8 +119,8 @@ void matmul<platform::GPUPlace, double>(const framework::Tensor& in1,
   CBLAS_TRANSPOSE in2_Trans = (in1_T == false) ? CblasNoTrans : CblasTrans;
 
   gemm<platform::GPUPlace, double>(in1_Trans, in2_Trans, M, N, K, alpha,
-                                   in1.data<double>(), K, in2.data<double>(), N,
-                                   beta, out->data<double>(), N, context);
+                                   in1.data<double>(), in2.data<double>(), beta,
+                                   out->data<double>(), context);
 }
 }  // namespace math
 }  // namespace operators
diff --git a/paddle/operators/math/math_function.h b/paddle/operators/math/math_function.h
index 0f8e7169f7b7218c4ef349bb28370e0b17ac34ab..12d1706afb82d37f0381facb7d448251372497d9 100644
--- a/paddle/operators/math/math_function.h
+++ b/paddle/operators/math/math_function.h
@@ -60,11 +60,11 @@ namespace paddle {
 namespace operators {
 namespace math {
 
+// support continuous memory now
 template <typename Place, typename T>
 void gemm(const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB,
           const int M, const int N, const int K, const T alpha, const T* A,
-          const int lda, const T* B, const int ldb, const T beta, T* C,
-          const int ldc, platform::DeviceContext* context);
+          const T* B, const T beta, T* C, platform::DeviceContext* context);
 
 // matrix multiply with continuous memory
 template <typename Place, typename T>
diff --git a/paddle/operators/mul_op.cu b/paddle/operators/mul_op.cu
index 35975865c949ab6993d235797a84836ccc3a4fa1..346a7e505d123b5e4e831daa39a1f6349b3dcccf 100644
--- a/paddle/operators/mul_op.cu
+++ b/paddle/operators/mul_op.cu
@@ -16,5 +16,4 @@
 #include "paddle/operators/mul_op.h"
 
 namespace ops = paddle::operators;
-// REGISTER_OP_GPU_KERNEL(mul, ops::MulKernel<paddle::platform::GPUPlace,
-// float>);
+REGISTER_OP_GPU_KERNEL(mul, ops::MulKernel<paddle::platform::GPUPlace, float>);
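For reference only, and not part of the patch: the simplified gemm interface drops the explicit lda/ldb/ldc arguments because it assumes densely packed row-major matrices, so the leading dimensions follow directly from the shapes (lda = K, ldb = N, ldc = N in the non-transposed case). The standalone C++ sketch below, using a hypothetical helper name naive_gemm_rowmajor and plain float buffers, mirrors that contract and shows how those strides index contiguous storage; it is an illustration, not code from the Paddle tree.

// Illustrative sketch only -- not PaddlePaddle code. It mirrors the contract
// of the simplified gemm(): contiguous row-major operands, with the leading
// dimensions derived from the shapes instead of being passed in.
#include <cstdio>
#include <vector>

// C (MxN) = alpha * A (MxK) * B (KxN) + beta * C, all row-major and packed.
void naive_gemm_rowmajor(int M, int N, int K, float alpha, const float* A,
                         const float* B, float beta, float* C) {
  const int lda = K;  // stride between consecutive rows of A
  const int ldb = N;  // stride between consecutive rows of B
  const int ldc = N;  // stride between consecutive rows of C
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float acc = 0.f;
      for (int k = 0; k < K; ++k) {
        acc += A[m * lda + k] * B[k * ldb + n];
      }
      C[m * ldc + n] = alpha * acc + beta * C[m * ldc + n];
    }
  }
}

int main() {
  // 2x3 times 3x2 -> 2x2, stored contiguously in row-major order.
  std::vector<float> A = {1, 2, 3, 4, 5, 6};
  std::vector<float> B = {7, 8, 9, 10, 11, 12};
  std::vector<float> C(4, 0.f);
  naive_gemm_rowmajor(2, 2, 3, 1.f, A.data(), B.data(), 0.f, C.data());
  std::printf("%g %g\n%g %g\n", C[0], C[1], C[2], C[3]);  // 58 64 / 139 154
  return 0;
}

The GPU path makes the same contiguity assumption, but cuBLAS expects column-major storage; that is why the uncommented calls pass B before A (computing C^T = B^T * A^T in column-major terms, which is row-major C = A * B in the same memory) and derive lda/ldb from the transpose flags rather than fixing them to K and N.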