From 64a8e6d20e94651a610f6e623b32eb3af3afb2d6 Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Wed, 11 Jul 2018 20:13:49 +0800 Subject: [PATCH] refine the threshold functions --- paddle/fluid/operators/math/blas_impl.h | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/math/blas_impl.h b/paddle/fluid/operators/math/blas_impl.h index bb2a076675..701965759e 100644 --- a/paddle/fluid/operators/math/blas_impl.h +++ b/paddle/fluid/operators/math/blas_impl.h @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. #pragma once +#include #include #include "paddle/fluid/operators/math/math_function.h" @@ -161,6 +162,25 @@ struct CBlas { } #endif }; +template +inline static bool UseXSMM(const int &m, const int &n, const int &k, + bool transa, bool transb, const T &alpha, + const T &beta) { +#ifdef PADDLE_WITH_LIBXSMM + // Refer to https://github.com/hfp/libxsmm/blob/master/README.md + // But the threshold is custom + constexpr int LIBXSMM_THRESHOLD = 20 * 20 * 20; + if (m * n * k > LIBXSMM_THRESHOLD || transa || transb || + std::abs(alpha - static_cast(1) > + std::numeric_limits::epsilon()) || + std::abs(beta) > std::numeric_limits::epsilon()) { + return false; + } else { + return true; + } +#endif + return false; +} template <> template @@ -172,8 +192,8 @@ void Blas::GEMM(CBLAS_TRANSPOSE transA, int ldb = (transB == CblasNoTrans) ? N : K; int ldc = N; #ifdef PADDLE_WITH_LIBXSMM - if (M * N * K < 128 * 128 * 128 && transA == CblasNoTrans && - transB == CblasNoTrans) { + if (UseXSMM(M, N, K, transA != CblasNoTrans, transB != CblasNoTrans, alpha, + beta)) { // refer to https://github.com/hfp/libxsmm/blob/master/README.md // Note: SMM use ColMajor const char transa = 'N'; -- GitLab