From c3941745b32cb0e5917b8095996ee65bbf70b588 Mon Sep 17 00:00:00 2001
From: tensor-tang
Date: Thu, 5 Jul 2018 21:17:10 +0800
Subject: [PATCH] add libxsmm_gemm

Route small, non-transposed CPU GEMMs (M * N * K < 128^3) through
libxsmm_sgemm / libxsmm_dgemm when PADDLE_WITH_LIBXSMM is defined, and
link libxsmm into the cblas target instead of the blas library.

libxsmm's GEMM interface is column-major while Blas::GEMM operates on
row-major buffers, so the row-major product C = A * B is issued as the
column-major product C^T = B^T * A^T: operands A/B and extents M/N are
swapped, with lda = K, ldb = N, ldc = N. The size check is done in
64-bit arithmetic so large shapes cannot overflow int and fall into the
small-matrix path by accident.
---
 cmake/external/libxsmm.cmake               |  1 +
 cmake/external/openblas.cmake              |  5 +++
 paddle/fluid/operators/math/CMakeLists.txt |  6 +---
 paddle/fluid/operators/math/blas_impl.h    | 42 +++++++++++++++++++---
 4 files changed, 44 insertions(+), 10 deletions(-)

diff --git a/cmake/external/libxsmm.cmake b/cmake/external/libxsmm.cmake
index f9f1e71311..2d6b538638 100644
--- a/cmake/external/libxsmm.cmake
+++ b/cmake/external/libxsmm.cmake
@@ -52,4 +52,5 @@ MESSAGE(STATUS "Libxsmm library: ${LIBXSMM_LIBS}")
 include_directories(${LIBXSMM_INCLUDE_DIR})
 ADD_DEFINITIONS(-DPADDLE_WITH_LIBXSMM)
 ADD_DEPENDENCIES(libxsmm extern_libxsmm)
+LIST(APPEND external_project_dependencies libxsmm)
 
diff --git a/cmake/external/openblas.cmake b/cmake/external/openblas.cmake
index ce6a88b51d..56024edf5b 100644
--- a/cmake/external/openblas.cmake
+++ b/cmake/external/openblas.cmake
@@ -121,6 +121,11 @@ ELSE()
   TARGET_LINK_LIBRARIES(cblas ${CBLAS_LIBRARIES})
 ENDIF("${CBLAS_PROVIDER}" STREQUAL "MKLML")
 
+IF(WITH_LIBXSMM)
+  TARGET_LINK_LIBRARIES(cblas ${LIBXSMM_LIBS})
+  ADD_DEPENDENCIES(cblas extern_libxsmm)
+ENDIF()
+
 IF(NOT ${CBLAS_FOUND})
   ADD_DEPENDENCIES(cblas extern_openblas)
   LIST(APPEND external_project_dependencies cblas)
diff --git a/paddle/fluid/operators/math/CMakeLists.txt b/paddle/fluid/operators/math/CMakeLists.txt
index 6f672a069b..5571ff9a71 100644
--- a/paddle/fluid/operators/math/CMakeLists.txt
+++ b/paddle/fluid/operators/math/CMakeLists.txt
@@ -41,11 +41,7 @@ math_library(depthwise_conv)
 math_library(gru_compute DEPS activation_functions math_function)
 math_library(im2col)
 math_library(lstm_compute DEPS activation_functions)
-set(BLAS_DEPS cblas framework_proto device_context)
-if (WITH_LIBXSMM)
-    list(APPEND BLAS_DEPS libxsmm)
-endif()
-cc_library(blas SRCS blas.cc DEPS ${BLAS_DEPS})
+cc_library(blas SRCS blas.cc DEPS cblas framework_proto device_context)
 math_library(math_function DEPS blas)
 math_library(maxouting)
 math_library(pooling)
diff --git a/paddle/fluid/operators/math/blas_impl.h b/paddle/fluid/operators/math/blas_impl.h
index 2ce94cfc93..020b5d86b1 100644
--- a/paddle/fluid/operators/math/blas_impl.h
+++ b/paddle/fluid/operators/math/blas_impl.h
@@ -30,6 +30,12 @@ struct CBlas<float> {
     platform::dynload::cblas_sgemm(args...);
   }
 
+#ifdef PADDLE_WITH_LIBXSMM
+  template <typename... ARGS>
+  static void SMM_GEMM(ARGS... args) {
+    libxsmm_sgemm(args...);
+  }
+#endif
   template <typename... ARGS>
   static void AXPY(ARGS... args) {
     platform::dynload::cblas_saxpy(args...);
@@ -63,6 +69,12 @@ struct CBlas<double> {
     platform::dynload::cblas_dgemm(args...);
   }
 
+#ifdef PADDLE_WITH_LIBXSMM
+  template <typename... ARGS>
+  static void SMM_GEMM(ARGS... args) {
+    libxsmm_dgemm(args...);
+  }
+#endif
   template <typename... ARGS>
   static void AXPY(ARGS... args) {
     platform::dynload::cblas_daxpy(args...);
@@ -140,6 +152,9 @@ struct CBlas<double> {
 template <>
 struct CBlas<platform::float16> {
   static void GEMM(...) { PADDLE_THROW("float16 GEMM not supported on CPU"); }
+  static void SMM_GEMM(...) {
+    PADDLE_THROW("float16 SMM_GEMM not supported on CPU");
+  }
 #ifdef PADDLE_WITH_MKLML
   static void GEMM_BATCH(...) {
     PADDLE_THROW("float16 GEMM_BATCH not supported on CPU");
@@ -153,11 +168,28 @@ void Blas<platform::CPUDeviceContext>::GEMM(CBLAS_TRANSPOSE transA,
                                             CBLAS_TRANSPOSE transB, int M,
                                             int N, int K, T alpha, const T *A,
                                             const T *B, T beta, T *C) const {
-  int lda = (transA == CblasNoTrans) ? K : M;
-  int ldb = (transB == CblasNoTrans) ? N : K;
-  int ldc = N;
-  CBlas<T>::GEMM(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
-                 beta, C, ldc);
+#ifdef PADDLE_WITH_LIBXSMM
+  if (static_cast<int64_t>(M) * N * K < 128 * 128 * 128 &&
+      transA == CblasNoTrans && transB == CblasNoTrans) {
+    // refer to https://github.com/hfp/libxsmm/blob/master/README.md
+    // Note: SMM use ColMajor, so swap A/B and M/N: C^T = B^T * A^T
+    const char transa = 'N';
+    const char transb = 'N';
+    const int lda = K;
+    const int ldb = N;
+    const int ldc = N;
+    CBlas<T>::SMM_GEMM(&transa, &transb, &N, &M, &K, &alpha, B, &ldb, A, &lda,
+                       &beta, C, &ldc);
+  } else {
+#endif
+    int lda = (transA == CblasNoTrans) ? K : M;
+    int ldb = (transB == CblasNoTrans) ? N : K;
+    int ldc = N;
+    CBlas<T>::GEMM(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B,
+                   ldb, beta, C, ldc);
+#ifdef PADDLE_WITH_LIBXSMM
+  }
+#endif
 }
 
 template <>
-- 
GitLab