diff --git a/CMakeLists.txt b/CMakeLists.txt index db3c3b8e2069f9ae5ad02286b59decf8fe764c2d..4ba33c8674403500238cbf4e88de6775e4ad68e4 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -146,6 +146,7 @@ endif() ######################################################################################## include(external/mklml) # download mklml package +include(external/libxsmm) # download, build, install libxsmm include(external/zlib) # download, build, install zlib include(external/gflags) # download, build, install gflags include(external/glog) # download, build, install glog @@ -232,6 +233,10 @@ if(WITH_MKLML) list(APPEND EXTERNAL_LIBS ${MKLML_IOMP_LIB}) endif() +if(WITH_LIBXSMM) + list(APPEND EXTERNAL_LIBS ${LIBXSMM_LIBS}) +endif() + if(WITH_MKLDNN) list(APPEND EXTERNAL_LIBS ${MKLDNN_LIB}) endif() diff --git a/cmake/external/libxsmm.cmake b/cmake/external/libxsmm.cmake new file mode 100644 index 0000000000000000000000000000000000000000..530f7ebe2813fb2f00c6b5b4d1f7b2f04fe650b0 --- /dev/null +++ b/cmake/external/libxsmm.cmake @@ -0,0 +1,57 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +OPTION(WITH_LIBXSMM "Compile with libxsmm" OFF) + +IF(NOT WITH_LIBXSMM) + return() +ENDIF() + +IF(WIN32 OR APPLE OR ANDROID OR IOS) + MESSAGE(WARNING "Windows, Mac or Mobile are not supported with libxsmm in Paddle yet.") + SET(WITH_LIBXSMM OFF CACHE STRING "Disable LIBXSMM" FORCE) + return() +ENDIF() + +INCLUDE (ExternalProject) + +SET(LIBXSMM_SOURCES_DIR ${THIRD_PARTY_PATH}/libxsmm) +SET(LIBXSMM_INSTALL_DIR ${THIRD_PARTY_PATH}/install/libxsmm) +SET(LIBXSMM_INCLUDE_DIR "${LIBXSMM_INSTALL_DIR}/include" CACHE PATH "LIBXSMM include directory." FORCE) +SET(LIBXSMM_LIBRARY_DIR "${LIBXSMM_INSTALL_DIR}/lib" CACHE PATH "LIBXSMM library directory." FORCE) +SET(LIBXSMM_LIBS "${LIBXSMM_LIBRARY_DIR}/libxsmm.a" + "${LIBXSMM_LIBRARY_DIR}/libxsmmnoblas.a") + +ExternalProject_Add( + extern_libxsmm + GIT_REPOSITORY "https://github.com/hfp/libxsmm.git" + GIT_TAG "7cc03b5b342fdbc6b6d990b190671c5dbb8489a2" + PREFIX ${LIBXSMM_SOURCES_DIR} + UPDATE_COMMAND "" + CONFIGURE_COMMAND "" + BUILD_IN_SOURCE 1 + BUILD_COMMAND $(MAKE) --silent PREFIX=${LIBXSMM_INSTALL_DIR} CXX=g++ CC=gcc WARP=0 install + INSTALL_COMMAND "" +) +ADD_LIBRARY(libxsmm STATIC IMPORTED GLOBAL) +SET_PROPERTY(TARGET libxsmm PROPERTY IMPORTED_LOCATION "${LIBXSMM_LIBRARY_DIR}/libxsmm.a") +SET_PROPERTY(TARGET libxsmm PROPERTY IMPORTED_LOCATION "${LIBXSMM_LIBRARY_DIR}/libxsmmnoblas.a") + +MESSAGE(STATUS "Libxsmm library: ${LIBXSMM_LIBS}") +include_directories(${LIBXSMM_INCLUDE_DIR}) +ADD_DEFINITIONS(-DPADDLE_WITH_LIBXSMM) +ADD_DEPENDENCIES(libxsmm extern_libxsmm) +LIST(APPEND external_project_dependencies libxsmm) + diff --git a/cmake/external/openblas.cmake b/cmake/external/openblas.cmake index ce6a88b51dc98ac46dd3935f12658d60d364ba8c..56024edf5be092f81ed893633a8e7cafc8c8d429 100644 --- a/cmake/external/openblas.cmake +++ b/cmake/external/openblas.cmake @@ -121,6 +121,11 @@ ELSE() TARGET_LINK_LIBRARIES(cblas ${CBLAS_LIBRARIES}) ENDIF("${CBLAS_PROVIDER}" STREQUAL "MKLML") +IF(WITH_LIBXSMM) + TARGET_LINK_LIBRARIES(cblas ${LIBXSMM_LIBS}) + ADD_DEPENDENCIES(cblas extern_libxsmm) +ENDIF() + IF(NOT ${CBLAS_FOUND}) ADD_DEPENDENCIES(cblas extern_openblas) LIST(APPEND external_project_dependencies cblas) diff --git a/paddle/fluid/operators/math/blas.h b/paddle/fluid/operators/math/blas.h index 9f6c1e5c35f02cd4bc729eea78b17fac017aa90e..70f88f24f682e05972ca73ef7b50f96be50d1ef4 100644 --- a/paddle/fluid/operators/math/blas.h +++ b/paddle/fluid/operators/math/blas.h @@ -21,6 +21,10 @@ #include "paddle/fluid/platform/dynload/mklml.h" #endif +#ifdef PADDLE_WITH_LIBXSMM +#include +#endif + #ifdef PADDLE_USE_OPENBLAS #include #endif diff --git a/paddle/fluid/operators/math/blas_impl.h b/paddle/fluid/operators/math/blas_impl.h index 2ce94cfc93823aa891114ef8fd1e851727ebc623..238bd3f8def9eaa6c18afdab1031c4babfde8ae2 100644 --- a/paddle/fluid/operators/math/blas_impl.h +++ b/paddle/fluid/operators/math/blas_impl.h @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. #pragma once +#include #include #include "paddle/fluid/operators/math/math_function.h" @@ -30,6 +31,12 @@ struct CBlas { platform::dynload::cblas_sgemm(args...); } +#ifdef PADDLE_WITH_LIBXSMM + template + static void SMM_GEMM(ARGS... args) { + libxsmm_sgemm(args...); + } +#endif template static void AXPY(ARGS... args) { platform::dynload::cblas_saxpy(args...); @@ -63,6 +70,12 @@ struct CBlas { platform::dynload::cblas_dgemm(args...); } +#ifdef PADDLE_WITH_LIBXSMM + template + static void SMM_GEMM(ARGS... args) { + libxsmm_dgemm(args...); + } +#endif template static void AXPY(ARGS... args) { platform::dynload::cblas_daxpy(args...); @@ -140,6 +153,9 @@ struct CBlas { template <> struct CBlas { static void GEMM(...) { PADDLE_THROW("float16 GEMM not supported on CPU"); } + static void SMM_GEMM(...) { + PADDLE_THROW("float16 SMM_GEMM not supported on CPU"); + } #ifdef PADDLE_WITH_MKLML static void GEMM_BATCH(...) { PADDLE_THROW("float16 GEMM_BATCH not supported on CPU"); @@ -147,6 +163,33 @@ struct CBlas { #endif }; +template +inline bool UseXSMM(const int &m, const int &n, const int &k, bool transa, + bool transb, const T &alpha, const T &beta) { +#ifdef PADDLE_WITH_LIBXSMM + // Refer to https://github.com/hfp/libxsmm/blob/master/README.md + // But the threshold is custom + constexpr int LIBXSMM_THRESHOLD = 20 * 20 * 20; + if (m * n * k > LIBXSMM_THRESHOLD || transa || transb || + std::abs(alpha - static_cast(1) > + std::numeric_limits::epsilon()) || + std::abs(beta) > std::numeric_limits::epsilon()) { + return false; + } else { + return true; + } +#endif + return false; +} + +template <> +inline bool UseXSMM(const int &m, const int &n, const int &k, + bool transa, bool transb, + const platform::float16 &alpha, + const platform::float16 &beta) { + return false; +} + template <> template void Blas::GEMM(CBLAS_TRANSPOSE transA, @@ -156,8 +199,21 @@ void Blas::GEMM(CBLAS_TRANSPOSE transA, int lda = (transA == CblasNoTrans) ? K : M; int ldb = (transB == CblasNoTrans) ? N : K; int ldc = N; - CBlas::GEMM(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb, - beta, C, ldc); +#ifdef PADDLE_WITH_LIBXSMM + if (UseXSMM(M, N, K, transA != CblasNoTrans, transB != CblasNoTrans, alpha, + beta)) { + // Note: SMM use ColMajor + const char transa = 'N'; + const char transb = 'N'; + CBlas::SMM_GEMM(&transa, &transb, &N, &M, &K, &alpha, B, &ldb, A, &lda, + &beta, C, &ldc); + } else { +#endif + CBlas::GEMM(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, + ldb, beta, C, ldc); +#ifdef PADDLE_WITH_LIBXSMM + } +#endif } template <> diff --git a/paddle/fluid/operators/math/math_function_test.cc b/paddle/fluid/operators/math/math_function_test.cc index b545671b43d3a453ab03e4774427179617f62db0..078dd448c385dbb8a00025ee2ba08d0c41a4730a 100644 --- a/paddle/fluid/operators/math/math_function_test.cc +++ b/paddle/fluid/operators/math/math_function_test.cc @@ -54,8 +54,64 @@ TEST(math_function, gemm_notrans_cblas) { EXPECT_EQ(input3_ptr[6], 86); EXPECT_EQ(input3_ptr[7], 99); } +#ifdef PADDLE_WITH_LIBXSMM +template +void MklSmmCompare(int m, int n, int k) { + paddle::framework::Tensor mat_a; + paddle::framework::Tensor mat_b; + paddle::framework::Tensor mat_c_smm; + paddle::framework::Tensor mat_c_mkl; + auto* cpu_place = new paddle::platform::CPUPlace(); + + T* A = mat_a.mutable_data({m, k}, *cpu_place); + T* B = mat_b.mutable_data({k, n}, *cpu_place); + T* CSMM = mat_c_smm.mutable_data({m, n}, *cpu_place); + T* CMKL = mat_c_mkl.mutable_data({m, n}, *cpu_place); + T alpha = static_cast(1); + T beta = static_cast(0); + for (int i = 0; i < mat_a.numel(); ++i) { + A[i] = static_cast(i); + } + for (int i = 0; i < mat_b.numel(); ++i) { + B[i] = static_cast(i); + } + // lda,ldb,ldc follow RowMajor + int lda = k; + int ldb = n; + int ldc = n; + + auto smm = [&, m, n, k, lda, ldb, ldc, alpha, beta]() { + const char transa = 'N'; + const char transb = 'N'; + paddle::operators::math::CBlas::SMM_GEMM(&transa, &transb, &n, &m, &k, + &alpha, B, &ldb, A, &lda, &beta, + CSMM, &ldc); + }; + + auto mkl = [&, m, n, k, lda, ldb, ldc, alpha, beta]() { + paddle::operators::math::CBlas::GEMM(CblasRowMajor, CblasNoTrans, + CblasNoTrans, m, n, k, alpha, A, + lda, B, ldb, beta, CMKL, ldc); + }; + + smm(); + mkl(); + ASSERT_EQ(mat_c_mkl.numel(), mat_c_smm.numel()); + for (int i = 0; i < mat_c_mkl.numel(); ++i) { + EXPECT_FLOAT_EQ(CSMM[i], CMKL[i]); + } +} +TEST(math_function, gemm_mkl_vs_smm) { + MklSmmCompare(1, 2, 3); + MklSmmCompare(1, 2, 3); + MklSmmCompare(3, 2, 1); + MklSmmCompare(3, 2, 1); + MklSmmCompare(3, 8, 5); + MklSmmCompare(3, 8, 5); +} +#endif -TEST(math_function, gemm_trans_clbas) { +TEST(math_function, gemm_trans_cblas) { paddle::framework::Tensor input1; paddle::framework::Tensor input2; paddle::framework::Tensor input3;