diff --git a/paddle/operators/math/math_function.cc b/paddle/operators/math/math_function.cc index 1bfbc755736658f45d22681d863d8b8e4a88fe9a..5833fc90a76480551811ca59f24dc2af197fea33 100644 --- a/paddle/operators/math/math_function.cc +++ b/paddle/operators/math/math_function.cc @@ -19,74 +19,29 @@ namespace operators { namespace math { template <> -void gemm(const CBLAS_TRANSPOSE transA, - const CBLAS_TRANSPOSE transB, - const int M, - const int N, - const int K, - const float alpha, - const float* A, - const int lda, - const float* B, - const int ldb, - const float beta, - float* C, - const int ldc, - platform::DeviceContext* context) { - cblas_sgemm(CblasRowMajor, - transA, - transB, - M, - N, - K, - alpha, - A, - lda, - B, - ldb, - beta, - C, - ldc); +void gemm( + const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB, const int M, + const int N, const int K, const float alpha, const float* A, const int lda, + const float* B, const int ldb, const float beta, float* C, const int ldc, + platform::DeviceContext* context) { + cblas_sgemm(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb, + beta, C, ldc); } template <> -void gemm(const CBLAS_TRANSPOSE transA, - const CBLAS_TRANSPOSE transB, - const int M, - const int N, - const int K, - const double alpha, - const double* A, - const int lda, - const double* B, - const int ldb, - const double beta, - double* C, - const int ldc, - platform::DeviceContext* context) { - cblas_dgemm(CblasRowMajor, - transA, - transB, - M, - N, - K, - alpha, - A, - lda, - B, - ldb, - beta, - C, - ldc); +void gemm( + const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB, const int M, + const int N, const int K, const double alpha, const double* A, + const int lda, const double* B, const int ldb, const double beta, double* C, + const int ldc, platform::DeviceContext* context) { + cblas_dgemm(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb, + beta, C, ldc); } template <> -void matmul(const framework::Tensor& in1, - bool in1_T, - const framework::Tensor& in2, - bool in2_T, - float alpha, - framework::Tensor* out, +void matmul(const framework::Tensor& in1, bool in1_T, + const framework::Tensor& in2, bool in2_T, + float alpha, framework::Tensor* out, float beta, platform::DeviceContext* context) { auto in1_dim = in1.dims(); @@ -111,30 +66,17 @@ void matmul(const framework::Tensor& in1, CBLAS_TRANSPOSE in1_Trans = (in1_T == false) ? CblasNoTrans : CblasTrans; CBLAS_TRANSPOSE in2_Trans = (in1_T == false) ? CblasNoTrans : CblasTrans; - gemm(in1_Trans, - in2_Trans, - M, - N, - K, - alpha, - in1.data(), - K, - in2.data(), - N, - beta, - out->data(), - N, - context); + gemm(in1_Trans, in2_Trans, M, N, K, alpha, + in1.data(), K, in2.data(), N, + beta, out->data(), N, context); } template <> -void matmul(const framework::Tensor& in1, +void matmul(const framework::Tensor& in1, bool in1_T, const framework::Tensor& in2, - bool in2_T, - float alpha, - framework::Tensor* out, - float beta, + bool in2_T, float alpha, + framework::Tensor* out, float beta, platform::DeviceContext* context) { auto in1_dim = in1.dims(); auto in2_dim = in2.dims(); @@ -157,20 +99,9 @@ void matmul(const framework::Tensor& in1, CBLAS_TRANSPOSE in1_Trans = (in1_T == false) ? CblasNoTrans : CblasTrans; CBLAS_TRANSPOSE in2_Trans = (in1_T == false) ? CblasNoTrans : CblasTrans; - gemm(in1_Trans, - in2_Trans, - M, - N, - K, - alpha, - in1.data(), - K, - in2.data(), - N, - beta, - out->data(), - N, - context); + gemm(in1_Trans, in2_Trans, M, N, K, alpha, + in1.data(), K, in2.data(), N, + beta, out->data(), N, context); } } // namespace math