diff --git a/CMakeLists.txt b/CMakeLists.txt index 0ab80987b3ad6c4793ceeac1bf3808d2e87fbd5b..231224f9249848b6e4981a98e0538794bf5d3c08 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -136,6 +136,12 @@ else() set(THIRD_PARTY_BUILD_TYPE Release) endif() +if(WITH_MKL) + option(MKL_SPLIT_GEMM "PaddlePaddle MKL gemm would split to small ones" OFF) + if (MKL_SPLIT_GEMM) + add_definitions(-DPADDLE_MKL_SPLIT_GEMM) + endif() +endif() set(WITH_MKLML ${WITH_MKL}) if (NOT DEFINED WITH_MKLDNN) if (WITH_MKL AND AVX2_FOUND) diff --git a/paddle/fluid/operators/math/blas_impl.h b/paddle/fluid/operators/math/blas_impl.h index 238bd3f8def9eaa6c18afdab1031c4babfde8ae2..977ef3ba2c301a01d5ff13fe549c7226dfca596f 100644 --- a/paddle/fluid/operators/math/blas_impl.h +++ b/paddle/fluid/operators/math/blas_impl.h @@ -209,8 +209,23 @@ void Blas::GEMM(CBLAS_TRANSPOSE transA, &beta, C, &ldc); } else { #endif - CBlas::GEMM(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, - ldb, beta, C, ldc); + +#ifdef PADDLE_MKL_SPLIT_GEMM + constexpr int bs = 2; + if (M % bs == 0 && transA == CblasNoTrans && transB == CblasNoTrans) { + for (int off = 0; off < M; off += bs) { + CBlas::GEMM(CblasRowMajor, CblasNoTrans, CblasNoTrans, off, N, K, + alpha, A + off * lda, lda, B, ldb, beta, C + off * ldb, + ldc); + } + } else { +#endif + CBlas::GEMM(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, + ldb, beta, C, ldc); +#ifdef PADDLE_MKL_SPLIT_GEMM + } +#endif + #ifdef PADDLE_WITH_LIBXSMM } #endif