提交 3501a3be 编写于 作者: C chenjiawen

Merge branch 'develop' of https://github.com/PaddlePaddle/paddle into ce

......@@ -146,6 +146,7 @@ endif()
########################################################################################
include(external/mklml) # download mklml package
include(external/libxsmm) # download, build, install libxsmm
include(external/zlib) # download, build, install zlib
include(external/gflags) # download, build, install gflags
include(external/glog) # download, build, install glog
......@@ -232,6 +233,10 @@ if(WITH_MKLML)
list(APPEND EXTERNAL_LIBS ${MKLML_IOMP_LIB})
endif()
if(WITH_LIBXSMM)
list(APPEND EXTERNAL_LIBS ${LIBXSMM_LIBS})
endif()
if(WITH_MKLDNN)
list(APPEND EXTERNAL_LIBS ${MKLDNN_LIB})
endif()
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
OPTION(WITH_LIBXSMM "Compile with libxsmm" OFF)
IF(NOT WITH_LIBXSMM)
return()
ENDIF()
IF(WIN32 OR APPLE OR ANDROID OR IOS)
MESSAGE(WARNING "Windows, Mac or Mobile are not supported with libxsmm in Paddle yet.")
SET(WITH_LIBXSMM OFF CACHE STRING "Disable LIBXSMM" FORCE)
return()
ENDIF()
INCLUDE (ExternalProject)
SET(LIBXSMM_SOURCES_DIR ${THIRD_PARTY_PATH}/libxsmm)
SET(LIBXSMM_INSTALL_DIR ${THIRD_PARTY_PATH}/install/libxsmm)
SET(LIBXSMM_INCLUDE_DIR "${LIBXSMM_INSTALL_DIR}/include" CACHE PATH "LIBXSMM include directory." FORCE)
SET(LIBXSMM_LIBRARY_DIR "${LIBXSMM_INSTALL_DIR}/lib" CACHE PATH "LIBXSMM library directory." FORCE)
SET(LIBXSMM_LIBS "${LIBXSMM_LIBRARY_DIR}/libxsmm.a"
"${LIBXSMM_LIBRARY_DIR}/libxsmmnoblas.a")
ExternalProject_Add(
extern_libxsmm
GIT_REPOSITORY "https://github.com/hfp/libxsmm.git"
GIT_TAG "7cc03b5b342fdbc6b6d990b190671c5dbb8489a2"
PREFIX ${LIBXSMM_SOURCES_DIR}
UPDATE_COMMAND ""
CONFIGURE_COMMAND ""
BUILD_IN_SOURCE 1
BUILD_COMMAND $(MAKE) --silent PREFIX=${LIBXSMM_INSTALL_DIR} CXX=g++ CC=gcc WARP=0 install
INSTALL_COMMAND ""
)
ADD_LIBRARY(libxsmm STATIC IMPORTED GLOBAL)
SET_PROPERTY(TARGET libxsmm PROPERTY IMPORTED_LOCATION "${LIBXSMM_LIBRARY_DIR}/libxsmm.a")
SET_PROPERTY(TARGET libxsmm PROPERTY IMPORTED_LOCATION "${LIBXSMM_LIBRARY_DIR}/libxsmmnoblas.a")
MESSAGE(STATUS "Libxsmm library: ${LIBXSMM_LIBS}")
include_directories(${LIBXSMM_INCLUDE_DIR})
ADD_DEFINITIONS(-DPADDLE_WITH_LIBXSMM)
ADD_DEPENDENCIES(libxsmm extern_libxsmm)
LIST(APPEND external_project_dependencies libxsmm)
......@@ -121,6 +121,11 @@ ELSE()
TARGET_LINK_LIBRARIES(cblas ${CBLAS_LIBRARIES})
ENDIF("${CBLAS_PROVIDER}" STREQUAL "MKLML")
IF(WITH_LIBXSMM)
TARGET_LINK_LIBRARIES(cblas ${LIBXSMM_LIBS})
ADD_DEPENDENCIES(cblas extern_libxsmm)
ENDIF()
IF(NOT ${CBLAS_FOUND})
ADD_DEPENDENCIES(cblas extern_openblas)
LIST(APPEND external_project_dependencies cblas)
......
......@@ -21,6 +21,10 @@
#include "paddle/fluid/platform/dynload/mklml.h"
#endif
#ifdef PADDLE_WITH_LIBXSMM
#include <libxsmm.h>
#endif
#ifdef PADDLE_USE_OPENBLAS
#include <cblas.h>
#endif
......
......@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <limits>
#include <vector>
#include "paddle/fluid/operators/math/math_function.h"
......@@ -30,6 +31,12 @@ struct CBlas<float> {
platform::dynload::cblas_sgemm(args...);
}
#ifdef PADDLE_WITH_LIBXSMM
template <typename... ARGS>
static void SMM_GEMM(ARGS... args) {
libxsmm_sgemm(args...);
}
#endif
template <typename... ARGS>
static void AXPY(ARGS... args) {
platform::dynload::cblas_saxpy(args...);
......@@ -63,6 +70,12 @@ struct CBlas<double> {
platform::dynload::cblas_dgemm(args...);
}
#ifdef PADDLE_WITH_LIBXSMM
template <typename... ARGS>
static void SMM_GEMM(ARGS... args) {
libxsmm_dgemm(args...);
}
#endif
template <typename... ARGS>
static void AXPY(ARGS... args) {
platform::dynload::cblas_daxpy(args...);
......@@ -140,6 +153,9 @@ struct CBlas<double> {
template <>
struct CBlas<platform::float16> {
static void GEMM(...) { PADDLE_THROW("float16 GEMM not supported on CPU"); }
static void SMM_GEMM(...) {
PADDLE_THROW("float16 SMM_GEMM not supported on CPU");
}
#ifdef PADDLE_WITH_MKLML
static void GEMM_BATCH(...) {
PADDLE_THROW("float16 GEMM_BATCH not supported on CPU");
......@@ -147,6 +163,33 @@ struct CBlas<platform::float16> {
#endif
};
template <typename T>
inline bool UseXSMM(const int &m, const int &n, const int &k, bool transa,
bool transb, const T &alpha, const T &beta) {
#ifdef PADDLE_WITH_LIBXSMM
// Refer to https://github.com/hfp/libxsmm/blob/master/README.md
// But the threshold is custom
constexpr int LIBXSMM_THRESHOLD = 20 * 20 * 20;
if (m * n * k > LIBXSMM_THRESHOLD || transa || transb ||
std::abs<T>(alpha - static_cast<T>(1) >
std::numeric_limits<T>::epsilon()) ||
std::abs<T>(beta) > std::numeric_limits<T>::epsilon()) {
return false;
} else {
return true;
}
#endif
return false;
}
template <>
inline bool UseXSMM<platform::float16>(const int &m, const int &n, const int &k,
bool transa, bool transb,
const platform::float16 &alpha,
const platform::float16 &beta) {
return false;
}
template <>
template <typename T>
void Blas<platform::CPUDeviceContext>::GEMM(CBLAS_TRANSPOSE transA,
......@@ -156,8 +199,21 @@ void Blas<platform::CPUDeviceContext>::GEMM(CBLAS_TRANSPOSE transA,
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
int ldc = N;
CBlas<T>::GEMM(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
beta, C, ldc);
#ifdef PADDLE_WITH_LIBXSMM
if (UseXSMM(M, N, K, transA != CblasNoTrans, transB != CblasNoTrans, alpha,
beta)) {
// Note: SMM use ColMajor
const char transa = 'N';
const char transb = 'N';
CBlas<T>::SMM_GEMM(&transa, &transb, &N, &M, &K, &alpha, B, &ldb, A, &lda,
&beta, C, &ldc);
} else {
#endif
CBlas<T>::GEMM(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B,
ldb, beta, C, ldc);
#ifdef PADDLE_WITH_LIBXSMM
}
#endif
}
template <>
......
......@@ -54,8 +54,64 @@ TEST(math_function, gemm_notrans_cblas) {
EXPECT_EQ(input3_ptr[6], 86);
EXPECT_EQ(input3_ptr[7], 99);
}
#ifdef PADDLE_WITH_LIBXSMM
template <typename T>
void MklSmmCompare(int m, int n, int k) {
paddle::framework::Tensor mat_a;
paddle::framework::Tensor mat_b;
paddle::framework::Tensor mat_c_smm;
paddle::framework::Tensor mat_c_mkl;
auto* cpu_place = new paddle::platform::CPUPlace();
T* A = mat_a.mutable_data<T>({m, k}, *cpu_place);
T* B = mat_b.mutable_data<T>({k, n}, *cpu_place);
T* CSMM = mat_c_smm.mutable_data<T>({m, n}, *cpu_place);
T* CMKL = mat_c_mkl.mutable_data<T>({m, n}, *cpu_place);
T alpha = static_cast<T>(1);
T beta = static_cast<T>(0);
for (int i = 0; i < mat_a.numel(); ++i) {
A[i] = static_cast<T>(i);
}
for (int i = 0; i < mat_b.numel(); ++i) {
B[i] = static_cast<T>(i);
}
// lda,ldb,ldc follow RowMajor
int lda = k;
int ldb = n;
int ldc = n;
auto smm = [&, m, n, k, lda, ldb, ldc, alpha, beta]() {
const char transa = 'N';
const char transb = 'N';
paddle::operators::math::CBlas<T>::SMM_GEMM(&transa, &transb, &n, &m, &k,
&alpha, B, &ldb, A, &lda, &beta,
CSMM, &ldc);
};
auto mkl = [&, m, n, k, lda, ldb, ldc, alpha, beta]() {
paddle::operators::math::CBlas<T>::GEMM(CblasRowMajor, CblasNoTrans,
CblasNoTrans, m, n, k, alpha, A,
lda, B, ldb, beta, CMKL, ldc);
};
smm();
mkl();
ASSERT_EQ(mat_c_mkl.numel(), mat_c_smm.numel());
for (int i = 0; i < mat_c_mkl.numel(); ++i) {
EXPECT_FLOAT_EQ(CSMM[i], CMKL[i]);
}
}
TEST(math_function, gemm_mkl_vs_smm) {
MklSmmCompare<float>(1, 2, 3);
MklSmmCompare<double>(1, 2, 3);
MklSmmCompare<float>(3, 2, 1);
MklSmmCompare<double>(3, 2, 1);
MklSmmCompare<float>(3, 8, 5);
MklSmmCompare<double>(3, 8, 5);
}
#endif
TEST(math_function, gemm_trans_clbas) {
TEST(math_function, gemm_trans_cblas) {
paddle::framework::Tensor input1;
paddle::framework::Tensor input2;
paddle::framework::Tensor input3;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册