diff --git a/paddle/fluid/operators/math/blas_impl.h b/paddle/fluid/operators/math/blas_impl.h index 977ef3ba2c301a01d5ff13fe549c7226dfca596f..a0802ef90ca7e30a2b22d187cb9092163518d8e9 100644 --- a/paddle/fluid/operators/math/blas_impl.h +++ b/paddle/fluid/operators/math/blas_impl.h @@ -37,6 +37,7 @@ struct CBlas { libxsmm_sgemm(args...); } #endif + template static void AXPY(ARGS... args) { platform::dynload::cblas_saxpy(args...); @@ -76,6 +77,7 @@ struct CBlas { libxsmm_dgemm(args...); } #endif + template static void AXPY(ARGS... args) { platform::dynload::cblas_daxpy(args...); @@ -150,6 +152,7 @@ struct CBlas { } }; #endif + template <> struct CBlas { static void GEMM(...) { PADDLE_THROW("float16 GEMM not supported on CPU"); } @@ -190,45 +193,48 @@ inline bool UseXSMM(const int &m, const int &n, const int &k, return false; } -template <> template -void Blas::GEMM(CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, int M, - int N, int K, T alpha, const T *A, - const T *B, T beta, T *C) const { - int lda = (transA == CblasNoTrans) ? K : M; - int ldb = (transB == CblasNoTrans) ? N : K; - int ldc = N; +inline void GEMM_WARP(CBLAS_ORDER order, CBLAS_TRANSPOSE transA, + CBLAS_TRANSPOSE transB, int M, int N, int K, T alpha, + const T *A, int lda, const T *B, int ldb, T beta, T *C, + int ldc) { #ifdef PADDLE_WITH_LIBXSMM - if (UseXSMM(M, N, K, transA != CblasNoTrans, transB != CblasNoTrans, alpha, - beta)) { + if (UseXSMM(M, N, K, transA != CblasNoTrans, transB != CblasNoTrans, alpha, + beta)) { // Note: SMM use ColMajor const char transa = 'N'; const char transb = 'N'; CBlas::SMM_GEMM(&transa, &transb, &N, &M, &K, &alpha, B, &ldb, A, &lda, &beta, C, &ldc); - } else { + return; + } #endif #ifdef PADDLE_MKL_SPLIT_GEMM - constexpr int bs = 2; - if (M % bs == 0 && transA == CblasNoTrans && transB == CblasNoTrans) { - for (int off = 0; off < M; off += bs) { - CBlas::GEMM(CblasRowMajor, CblasNoTrans, CblasNoTrans, off, N, K, - alpha, A + off * lda, lda, B, ldb, beta, C + off * ldb, - ldc); - } - } else { -#endif - CBlas::GEMM(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, - ldb, beta, C, ldc); -#ifdef PADDLE_MKL_SPLIT_GEMM + constexpr int bs = 2; + if (M % bs == 0 && transA == CblasNoTrans && transB == CblasNoTrans) { + for (int off = 0; off < M; off += bs) { + CBlas::GEMM(CblasRowMajor, CblasNoTrans, CblasNoTrans, bs, N, K, alpha, + A + off * lda, lda, B, ldb, beta, C + off * ldb, ldc); } -#endif - -#ifdef PADDLE_WITH_LIBXSMM + return; } #endif + CBlas::GEMM(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb, + beta, C, ldc); +} + +template <> +template +void Blas::GEMM(CBLAS_TRANSPOSE transA, + CBLAS_TRANSPOSE transB, int M, + int N, int K, T alpha, const T *A, + const T *B, T beta, T *C) const { + int lda = (transA == CblasNoTrans) ? K : M; + int ldb = (transB == CblasNoTrans) ? N : K; + int ldc = N; + GEMM_WARP(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb, + beta, C, ldc); } template <> @@ -237,9 +243,9 @@ void Blas::GEMM(bool transA, bool transB, int M, int N, int K, T alpha, const T *A, int lda, const T *B, int ldb, T beta, T *C, int ldc) const { - CBlas::GEMM(CblasRowMajor, transA == false ? CblasNoTrans : CblasTrans, - transB == false ? CblasNoTrans : CblasTrans, M, N, K, alpha, A, - lda, B, ldb, beta, C, ldc); + GEMM_WARP(CblasRowMajor, transA == false ? CblasNoTrans : CblasTrans, + transB == false ? CblasNoTrans : CblasTrans, M, N, K, alpha, A, + lda, B, ldb, beta, C, ldc); } template