未验证 提交 be04fbff 编写于 作者: T tensor-tang 提交者: GitHub

Merge pull request #12233 from tensor-tang/refine/mkl/gemm

add option split mkl gemm
...@@ -136,6 +136,12 @@ else() ...@@ -136,6 +136,12 @@ else()
set(THIRD_PARTY_BUILD_TYPE Release) set(THIRD_PARTY_BUILD_TYPE Release)
endif() endif()
if(WITH_MKL)
option(MKL_SPLIT_GEMM "PaddlePaddle MKL gemm would split to small ones" OFF)
if (MKL_SPLIT_GEMM)
add_definitions(-DPADDLE_MKL_SPLIT_GEMM)
endif()
endif()
set(WITH_MKLML ${WITH_MKL}) set(WITH_MKLML ${WITH_MKL})
if (NOT DEFINED WITH_MKLDNN) if (NOT DEFINED WITH_MKLDNN)
if (WITH_MKL AND AVX2_FOUND) if (WITH_MKL AND AVX2_FOUND)
......
...@@ -37,6 +37,7 @@ struct CBlas<float> { ...@@ -37,6 +37,7 @@ struct CBlas<float> {
libxsmm_sgemm(args...); libxsmm_sgemm(args...);
} }
#endif #endif
template <typename... ARGS> template <typename... ARGS>
static void AXPY(ARGS... args) { static void AXPY(ARGS... args) {
platform::dynload::cblas_saxpy(args...); platform::dynload::cblas_saxpy(args...);
...@@ -76,6 +77,7 @@ struct CBlas<double> { ...@@ -76,6 +77,7 @@ struct CBlas<double> {
libxsmm_dgemm(args...); libxsmm_dgemm(args...);
} }
#endif #endif
template <typename... ARGS> template <typename... ARGS>
static void AXPY(ARGS... args) { static void AXPY(ARGS... args) {
platform::dynload::cblas_daxpy(args...); platform::dynload::cblas_daxpy(args...);
...@@ -150,6 +152,7 @@ struct CBlas<double> { ...@@ -150,6 +152,7 @@ struct CBlas<double> {
} }
}; };
#endif #endif
template <> template <>
struct CBlas<platform::float16> { struct CBlas<platform::float16> {
static void GEMM(...) { PADDLE_THROW("float16 GEMM not supported on CPU"); } static void GEMM(...) { PADDLE_THROW("float16 GEMM not supported on CPU"); }
...@@ -190,30 +193,48 @@ inline bool UseXSMM<platform::float16>(const int &m, const int &n, const int &k, ...@@ -190,30 +193,48 @@ inline bool UseXSMM<platform::float16>(const int &m, const int &n, const int &k,
return false; return false;
} }
template <>
template <typename T> template <typename T>
void Blas<platform::CPUDeviceContext>::GEMM(CBLAS_TRANSPOSE transA, inline void GEMM_WARP(CBLAS_ORDER order, CBLAS_TRANSPOSE transA,
CBLAS_TRANSPOSE transB, int M, CBLAS_TRANSPOSE transB, int M, int N, int K, T alpha,
int N, int K, T alpha, const T *A, const T *A, int lda, const T *B, int ldb, T beta, T *C,
const T *B, T beta, T *C) const { int ldc) {
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
int ldc = N;
#ifdef PADDLE_WITH_LIBXSMM #ifdef PADDLE_WITH_LIBXSMM
if (UseXSMM(M, N, K, transA != CblasNoTrans, transB != CblasNoTrans, alpha, if (UseXSMM<T>(M, N, K, transA != CblasNoTrans, transB != CblasNoTrans, alpha,
beta)) { beta)) {
// Note: SMM use ColMajor // Note: SMM use ColMajor
const char transa = 'N'; const char transa = 'N';
const char transb = 'N'; const char transb = 'N';
CBlas<T>::SMM_GEMM(&transa, &transb, &N, &M, &K, &alpha, B, &ldb, A, &lda, CBlas<T>::SMM_GEMM(&transa, &transb, &N, &M, &K, &alpha, B, &ldb, A, &lda,
&beta, C, &ldc); &beta, C, &ldc);
} else { return;
}
#endif #endif
CBlas<T>::GEMM(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B,
ldb, beta, C, ldc); #ifdef PADDLE_MKL_SPLIT_GEMM
#ifdef PADDLE_WITH_LIBXSMM constexpr int bs = 2;
if (M % bs == 0 && transA == CblasNoTrans && transB == CblasNoTrans) {
for (int off = 0; off < M; off += bs) {
CBlas<T>::GEMM(CblasRowMajor, CblasNoTrans, CblasNoTrans, bs, N, K, alpha,
A + off * lda, lda, B, ldb, beta, C + off * ldb, ldc);
}
return;
} }
#endif #endif
CBlas<T>::GEMM(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
beta, C, ldc);
}
template <>
template <typename T>
void Blas<platform::CPUDeviceContext>::GEMM(CBLAS_TRANSPOSE transA,
CBLAS_TRANSPOSE transB, int M,
int N, int K, T alpha, const T *A,
const T *B, T beta, T *C) const {
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
int ldc = N;
GEMM_WARP<T>(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
beta, C, ldc);
} }
template <> template <>
...@@ -222,7 +243,7 @@ void Blas<platform::CPUDeviceContext>::GEMM(bool transA, bool transB, int M, ...@@ -222,7 +243,7 @@ void Blas<platform::CPUDeviceContext>::GEMM(bool transA, bool transB, int M,
int N, int K, T alpha, const T *A, int N, int K, T alpha, const T *A,
int lda, const T *B, int ldb, int lda, const T *B, int ldb,
T beta, T *C, int ldc) const { T beta, T *C, int ldc) const {
CBlas<T>::GEMM(CblasRowMajor, transA == false ? CblasNoTrans : CblasTrans, GEMM_WARP<T>(CblasRowMajor, transA == false ? CblasNoTrans : CblasTrans,
transB == false ? CblasNoTrans : CblasTrans, M, N, K, alpha, A, transB == false ? CblasNoTrans : CblasTrans, M, N, K, alpha, A,
lda, B, ldb, beta, C, ldc); lda, B, ldb, beta, C, ldc);
} }
......
...@@ -228,3 +228,57 @@ TEST(math_funciton, set_constant) { ...@@ -228,3 +228,57 @@ TEST(math_funciton, set_constant) {
} }
delete ctx; delete ctx;
} }
template <typename T>
void GemmWarpTest(int m, int n, int k, T alpha, T beta) {
paddle::framework::Tensor mat_a;
paddle::framework::Tensor mat_b;
paddle::framework::Tensor mat_c_ref;
paddle::framework::Tensor mat_c_mkl;
auto* cpu_place = new paddle::platform::CPUPlace();
T* A = mat_a.mutable_data<T>({m, k}, *cpu_place);
T* B = mat_b.mutable_data<T>({k, n}, *cpu_place);
T* CREF = mat_c_ref.mutable_data<T>({m, n}, *cpu_place);
T* CMKL = mat_c_mkl.mutable_data<T>({m, n}, *cpu_place);
ASSERT_EQ(mat_c_mkl.numel(), mat_c_ref.numel());
for (int i = 0; i < mat_a.numel(); ++i) {
A[i] = static_cast<T>(i);
}
for (int i = 0; i < mat_b.numel(); ++i) {
B[i] = static_cast<T>(i + 1);
}
for (int i = 0; i < mat_c_ref.numel(); ++i) {
CREF[i] = static_cast<T>(i + 2);
CMKL[i] = CREF[i];
}
// this would call gemm_warp
paddle::platform::CPUDeviceContext context(*cpu_place);
GetBlas<T>(context).GEMM(CblasNoTrans, CblasNoTrans, m, n, k, alpha, A, B,
beta, CREF);
// lda,ldb,ldc follow RowMajor
int lda = k;
int ldb = n;
int ldc = n;
paddle::operators::math::CBlas<T>::GEMM(CblasRowMajor, CblasNoTrans,
CblasNoTrans, m, n, k, alpha, A, lda,
B, ldb, beta, CMKL, ldc);
for (int i = 0; i < mat_c_mkl.numel(); ++i) {
EXPECT_FLOAT_EQ(CREF[i], CMKL[i]);
}
}
TEST(math_function, gemm_warp) {
GemmWarpTest<float>(3, 2, 5, 1.f, 0.f);
GemmWarpTest<float>(3, 2, 5, 2.f, 1.f);
GemmWarpTest<float>(8, 5, 6, 1.f, 0.f);
GemmWarpTest<float>(8, 5, 6, 2.f, 1.f);
GemmWarpTest<double>(3, 2, 5, 1.0, 0.0);
GemmWarpTest<double>(3, 2, 5, 2.0, 1.0);
GemmWarpTest<double>(8, 5, 6, 1.0, 0.0);
GemmWarpTest<double>(8, 5, 6, 2.0, 1.0);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册