diff --git a/CMakeLists.txt b/CMakeLists.txt index 48e52961a95d50264b201eec50ccb3a462f39c54..317f7f9eb46a96e9f6ea393abf82d608af50fc4b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -138,12 +138,6 @@ else() set(THIRD_PARTY_BUILD_TYPE Release) endif() -if(WITH_MKL) - option(MKL_SPLIT_GEMM "PaddlePaddle MKL gemm would split to small ones" OFF) - if (MKL_SPLIT_GEMM) - add_definitions(-DPADDLE_MKL_SPLIT_GEMM) - endif() -endif() set(WITH_MKLML ${WITH_MKL}) if (NOT DEFINED WITH_MKLDNN) if (WITH_MKL AND AVX2_FOUND) diff --git a/paddle/fluid/operators/math/blas.h b/paddle/fluid/operators/math/blas.h index 8dcf7c99f3860789dee834787eeb8b7ad4cc3530..295431347a449196b415f197c9e2ea8480b873d1 100644 --- a/paddle/fluid/operators/math/blas.h +++ b/paddle/fluid/operators/math/blas.h @@ -90,6 +90,11 @@ class Blas { void GEMM(bool transA, bool transB, int M, int N, int K, T alpha, const T* A, int lda, const T* B, int ldb, T beta, T* C, int ldc) const; + template + void GEMM(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K, + T alpha, const T* A, int lda, const T* B, int ldb, T beta, T* C, + int ldc) const; + #ifdef PADDLE_WITH_MKLML template T* GEMM_ALLOC(const CBLAS_IDENTIFIER id, const int M, const int N, @@ -109,6 +114,10 @@ class Blas { void GEMM_FREE(T* data) const; #endif + template + void MatMul(const int M, const int N, const int K, const T* A, const T* B, + T* C) const; + template void MatMul(const framework::Tensor& mat_a, bool trans_a, const framework::Tensor& mat_b, bool trans_b, T alpha, diff --git a/paddle/fluid/operators/math/blas_impl.h b/paddle/fluid/operators/math/blas_impl.h index dc77b6d793702458a22a2f59b68e9d9f2c23b4ff..d39a3e7f6eb91ded2585a2b71868a1649376ac86 100644 --- a/paddle/fluid/operators/math/blas_impl.h +++ b/paddle/fluid/operators/math/blas_impl.h @@ -217,64 +217,6 @@ struct CBlas { #endif }; -template -inline bool UseXSMM(const int &m, const int &n, const int &k, bool transa, - bool transb, const T &alpha, const T &beta) { -#ifdef PADDLE_WITH_LIBXSMM - // Refer to https://github.com/hfp/libxsmm/blob/master/README.md - // But the threshold is custom - constexpr int LIBXSMM_THRESHOLD = 20 * 20 * 20; - if (m * n * k > LIBXSMM_THRESHOLD || transa || transb || - std::abs(alpha - static_cast(1) > - std::numeric_limits::epsilon()) || - std::abs(beta) > std::numeric_limits::epsilon()) { - return false; - } else { - return true; - } -#endif - return false; -} - -template <> -inline bool UseXSMM(const int &m, const int &n, const int &k, - bool transa, bool transb, - const platform::float16 &alpha, - const platform::float16 &beta) { - return false; -} - -template -inline void GEMM_WARP(CBLAS_ORDER order, CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, int M, int N, int K, T alpha, - const T *A, int lda, const T *B, int ldb, T beta, T *C, - int ldc) { -#ifdef PADDLE_WITH_LIBXSMM - if (UseXSMM(M, N, K, transA != CblasNoTrans, transB != CblasNoTrans, alpha, - beta)) { - // Note: SMM use ColMajor - const char transa = 'N'; - const char transb = 'N'; - CBlas::SMM_GEMM(&transa, &transb, &N, &M, &K, &alpha, B, &ldb, A, &lda, - &beta, C, &ldc); - return; - } -#endif - -#ifdef PADDLE_MKL_SPLIT_GEMM - constexpr int bs = 2; - if (M % bs == 0 && transA == CblasNoTrans && transB == CblasNoTrans) { - for (int off = 0; off < M; off += bs) { - CBlas::GEMM(CblasRowMajor, CblasNoTrans, CblasNoTrans, bs, N, K, alpha, - A + off * lda, lda, B, ldb, beta, C + off * ldb, ldc); - } - return; - } -#endif - CBlas::GEMM(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb, - beta, C, ldc); -} - #ifdef PADDLE_WITH_MKLML template <> template @@ -319,8 +261,8 @@ void Blas::GEMM(CBLAS_TRANSPOSE transA, int lda = (transA == CblasNoTrans) ? K : M; int ldb = (transB == CblasNoTrans) ? N : K; int ldc = N; - GEMM_WARP(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb, - beta, C, ldc); + CBlas::GEMM(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb, + beta, C, ldc); } template <> @@ -329,9 +271,20 @@ void Blas::GEMM(bool transA, bool transB, int M, int N, int K, T alpha, const T *A, int lda, const T *B, int ldb, T beta, T *C, int ldc) const { - GEMM_WARP(CblasRowMajor, transA == false ? CblasNoTrans : CblasTrans, - transB == false ? CblasNoTrans : CblasTrans, M, N, K, alpha, A, - lda, B, ldb, beta, C, ldc); + CBlas::GEMM(CblasRowMajor, transA == false ? CblasNoTrans : CblasTrans, + transB == false ? CblasNoTrans : CblasTrans, M, N, K, alpha, A, + lda, B, ldb, beta, C, ldc); +} + +template <> +template +void Blas::GEMM(CBLAS_TRANSPOSE transA, + CBLAS_TRANSPOSE transB, int M, + int N, int K, T alpha, const T *A, + int lda, const T *B, int ldb, + T beta, T *C, int ldc) const { + CBlas::GEMM(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb, + beta, C, ldc); } template @@ -440,6 +393,43 @@ void Blas::BatchedGEMM( #endif } +template +template +void Blas::MatMul(const int M, const int N, const int K, + const T *A, const T *B, T *C) const { + this->template GEMM(CblasRowMajor, CblasNoTrans, CblasNoTrans, M, N, K, + static_cast(1), A, K, B, N, static_cast(0), C, + N); +} + +template <> +template +void Blas::MatMul(const int M, const int N, + const int K, const T *A, + const T *B, T *C) const { +#ifdef PADDLE_WITH_LIBXSMM + // Refer to https://github.com/hfp/libxsmm/blob/master/README.md + // But the threshold is custom constexpr int LIBXSMM_THRESHOLD = 20 * 20 * 20; + + // Since the matrix is very small, + // so the unit of calculation is already very fast, + // and the if( M*N*K < LIBXSMM_THRESHOLD) would be overhead, + // use xsmm directly. + // Note: SMM use ColMajor + const char transa = 'N'; + const char transb = 'N'; + const T alpha = static_cast(1); + const T beta = static_cast(0); + CBlas::SMM_GEMM(&transa, &transb, &N, &M, &K, &alpha, B, &N, A, &K, &beta, + C, &N); + return; + +#endif + + CBlas::GEMM(CblasRowMajor, CblasNoTrans, CblasNoTrans, M, N, K, + static_cast(1), A, K, B, N, static_cast(0), C, N); +} + template template void Blas::MatMul(const framework::Tensor &mat_a, diff --git a/paddle/fluid/operators/math/fc_compute.h b/paddle/fluid/operators/math/fc_compute.h index 8600fa9e2c4db9d54cbe0ffb68f82d52c086d4f7..1f5a49c0ab5a10b0d7dc1febd258ce76c467cb1c 100644 --- a/paddle/fluid/operators/math/fc_compute.h +++ b/paddle/fluid/operators/math/fc_compute.h @@ -25,17 +25,25 @@ namespace math { template inline void FCCompute(const BlasT& blas, const int M, const int N, const int K, const T* X, const T* W, T* Y, - const T* B = NULL) { - blas.GEMM(CblasNoTrans, CblasNoTrans, M, N, K, static_cast(1), X, W, - static_cast(0), Y); - if (B) { + const T* B = NULL, bool relu = false) { + blas.MatMul(M, N, K, X, W, Y); + if (B == NULL) { + return; + } + #ifdef PADDLE_WITH_MKLML #pragma omp parallel for if (FLAGS_paddle_num_threads > 1) #endif - for (int i = 0; i < M; i++) { - blas.AXPY(N, static_cast(1), B, Y + i * N); - } + for (int i = 0; i < M; i++) { + blas.AXPY(N, static_cast(1), B, Y + i * N); } + + if (!relu) { + return; + } + + // TODO(TJ): fuse relu + LOG(FATAL) << "Not implemented!"; } } // namespace math