#include <cmath>
#include <cstdint>
#include <limits>

/**
 * @brief Decides whether a GEMM call is eligible for the LIBXSMM small-matrix
 *        path.
 *
 * Blas::GEMM gates its LIBXSMM branch on this helper, passing
 * `transA != CblasNoTrans` and `transB != CblasNoTrans` for the two
 * transpose flags.
 *
 * The call is eligible only when all of the following hold:
 *   - the build defines PADDLE_WITH_LIBXSMM,
 *   - m * n * k does not exceed the custom threshold of 20 * 20 * 20,
 *   - neither operand is transposed,
 *   - alpha equals 1 and beta equals 0, each within the machine epsilon of T.
 *
 * @tparam T      Floating-point element type of the GEMM.
 * @param m       Row count of the product.
 * @param n       Column count of the product.
 * @param k       Shared inner dimension.
 * @param transa  True when the first operand is transposed.
 * @param transb  True when the second operand is transposed.
 * @param alpha   Scale applied to the matrix product.
 * @param beta    Scale applied to the existing output.
 * @return True when the LIBXSMM path may be used, false otherwise (always
 *         false when PADDLE_WITH_LIBXSMM is not defined).
 */
template <typename T>
inline static bool UseXSMM(const int &m, const int &n, const int &k,
                           bool transa, bool transb, const T &alpha,
                           const T &beta) {
#ifdef PADDLE_WITH_LIBXSMM
  // Refer to https://github.com/hfp/libxsmm/blob/master/README.md
  // But the threshold is custom
  constexpr std::int64_t LIBXSMM_THRESHOLD = 20 * 20 * 20;

  // Widen before multiplying: in plain `int` the product of realistic GEMM
  // sizes wraps around (e.g. 4096^3 becomes 0) and would pass as "small".
  const std::int64_t volume = static_cast<std::int64_t>(m) * n * k;
  if (volume > LIBXSMM_THRESHOLD || transa || transb) {
    return false;
  }

  // The comparison must sit outside std::abs: taking the absolute value of
  // the boolean `alpha - 1 > eps` lets every alpha below 1 slip through.
  const T eps = std::numeric_limits<T>::epsilon();
  if (std::abs(alpha - static_cast<T>(1)) > eps || std::abs(beta) > eps) {
    return false;
  }
  return true;
#else
  return false;
#endif
}