Commit 43cee33a authored by tensor-tang

add mkl packed gemm

Parent 0964de11
@@ -90,6 +90,23 @@ class Blas {
  void GEMM(bool transA, bool transB, int M, int N, int K, T alpha, const T* A,
            int lda, const T* B, int ldb, T beta, T* C, int ldc) const;
template <typename T>
T* GEMM_ALLOC(const CBLAS_IDENTIFIER id, const int M, const int N,
const int K) const;
template <typename T>
void GEMM_PACK(const CBLAS_IDENTIFIER id, const CBLAS_TRANSPOSE trans, int M,
int N, int K, const T alpha, const T* src, const int ld,
T* dst) const;
template <typename T>
void GEMM_COMPUTE(int transA, int transB, int M, int N, int K, const T* A,
const int lda, const T* B, const int ldb, T beta, T* C,
const int ldc) const;
template <typename T>
void GEMM_FREE(T* data) const;
  template <typename T>
  void MatMul(const framework::Tensor& mat_a, bool trans_a,
              const framework::Tensor& mat_b, bool trans_b, T alpha,
@@ -146,6 +163,26 @@ class BlasT : private Blas<DeviceContext> {
    Base()->template GEMM<T>(args...);
  }
template <typename... ARGS>
T* GEMM_ALLOC(ARGS... args) const {
    return Base()->template GEMM_ALLOC<T>(args...);
}
template <typename... ARGS>
void GEMM_PACK(ARGS... args) const {
Base()->template GEMM_PACK<T>(args...);
}
template <typename... ARGS>
void GEMM_COMPUTE(ARGS... args) const {
Base()->template GEMM_COMPUTE<T>(args...);
}
template <typename... ARGS>
void GEMM_FREE(ARGS... args) const {
Base()->template GEMM_FREE<T>(args...);
}
  template <typename... ARGS>
  void MatMul(ARGS... args) const {
    Base()->template MatMul<T>(args...);
...
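Taken together, the four new Blas methods expose MKL's packed GEMM workflow: allocate an opaque buffer for one operand, pack that operand once (folding alpha into the packed form), reuse it across repeated GEMM_COMPUTE calls, then free it. A minimal usage sketch follows; it assumes an MKL-enabled CPU build, and `GetBlas`, the matrix names, and the row-major leading dimensions are illustrative, not part of this diff:

```cpp
#include "paddle/fluid/operators/math/blas.h"

// Sketch: C = A * B where B is packed once and reused (hypothetical names).
template <typename T>
void PackedGemm(const platform::CPUDeviceContext& ctx, const T* A,
                const T* B, T* C, int M, int N, int K) {
  auto blas = math::GetBlas<platform::CPUDeviceContext, T>(ctx);
  // Let MKL size and allocate its internal buffer for the packed B.
  T* packed_B = blas.GEMM_ALLOC(CblasBMatrix, M, N, K);
  // Pack B once; alpha is folded into the packed representation here.
  blas.GEMM_PACK(CblasBMatrix, CblasNoTrans, M, N, K, static_cast<T>(1.0),
                 B, N, packed_B);
  // CblasPacked marks the pre-packed operand; note compute takes no alpha.
  blas.GEMM_COMPUTE(CblasNoTrans, CblasPacked, M, N, K, A, K, packed_B, N,
                    static_cast<T>(0.0), C, N);
  blas.GEMM_FREE(packed_B);
}
```

The payoff comes when the packed operand (typically a weight matrix) is multiplied against many inputs: pack once, then call GEMM_COMPUTE per batch.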
@@ -31,6 +31,26 @@ struct CBlas<float> {
    platform::dynload::cblas_sgemm(args...);
  }
template <typename... ARGS>
static float *GEMM_ALLOC(ARGS... args) {
return platform::dynload::cblas_sgemm_alloc(args...);
}
template <typename... ARGS>
static void GEMM_PACK(ARGS... args) {
platform::dynload::cblas_sgemm_pack(args...);
}
template <typename... ARGS>
static void GEMM_COMPUTE(ARGS... args) {
platform::dynload::cblas_sgemm_compute(args...);
}
template <typename... ARGS>
static void GEMM_FREE(ARGS... args) {
platform::dynload::cblas_sgemm_free(args...);
}
#ifdef PADDLE_WITH_LIBXSMM
  template <typename... ARGS>
  static void SMM_GEMM(ARGS... args) {
@@ -71,6 +91,26 @@ struct CBlas<double> {
    platform::dynload::cblas_dgemm(args...);
  }
template <typename... ARGS>
static double *GEMM_ALLOC(ARGS... args) {
return platform::dynload::cblas_dgemm_alloc(args...);
}
template <typename... ARGS>
static void GEMM_PACK(ARGS... args) {
platform::dynload::cblas_dgemm_pack(args...);
}
template <typename... ARGS>
static void GEMM_COMPUTE(ARGS... args) {
platform::dynload::cblas_dgemm_compute(args...);
}
template <typename... ARGS>
static void GEMM_FREE(ARGS... args) {
platform::dynload::cblas_dgemm_free(args...);
}
#ifdef PADDLE_WITH_LIBXSMM
  template <typename... ARGS>
  static void SMM_GEMM(ARGS... args) {
@@ -224,6 +264,39 @@ inline void GEMM_WARP(CBLAS_ORDER order, CBLAS_TRANSPOSE transA,
                       beta, C, ldc);
}
template <>
template <typename T>
T *Blas<platform::CPUDeviceContext>::GEMM_ALLOC(const CBLAS_IDENTIFIER id,
const int M, const int N,
const int K) const {
return CBlas<T>::GEMM_ALLOC(id, M, N, K);
}
template <>
template <typename T>
void Blas<platform::CPUDeviceContext>::GEMM_PACK(const CBLAS_IDENTIFIER id,
const CBLAS_TRANSPOSE trans,
int M, int N, int K,
const T alpha, const T *src,
const int ld, T *dst) const {
CBlas<T>::GEMM_PACK(CblasRowMajor, id, trans, M, N, K, alpha, src, ld, dst);
}
template <>
template <typename T>
void Blas<platform::CPUDeviceContext>::GEMM_COMPUTE(
int transA, int transB, int M, int N, int K, const T *A, const int lda,
const T *B, const int ldb, T beta, T *C, const int ldc) const {
CBlas<T>::GEMM_COMPUTE(CblasRowMajor, transA, transB, M, N, K, A, lda, B, ldb,
beta, C, ldc);
}
template <>
template <typename T>
void Blas<platform::CPUDeviceContext>::GEMM_FREE(T *data) const {
CBlas<T>::GEMM_FREE(data);
}
template <>
template <typename T>
void Blas<platform::CPUDeviceContext>::GEMM(CBLAS_TRANSPOSE transA,
...
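Note that GEMM_COMPUTE takes transA/transB as plain int rather than CBLAS_TRANSPOSE: MKL's cblas_?gemm_compute accepts CblasPacked, which lives in a separate enum (CBLAS_STORAGE), to flag whichever operand was pre-packed. At the MKL level, the CPU specializations above reduce to roughly this sequence (a sketch; the pointers and row-major leading dimensions are illustrative):

```cpp
#include <mkl.h>

// Sketch: C = A * B with B pre-packed; row-major, no transposition.
void sgemm_packed(const float* A, const float* B, float* C,
                  int M, int N, int K) {
  float* packed_B = cblas_sgemm_alloc(CblasBMatrix, M, N, K);
  // alpha is applied during packing, not during compute.
  cblas_sgemm_pack(CblasRowMajor, CblasBMatrix, CblasNoTrans, M, N, K,
                   /*alpha=*/1.0f, B, /*ld=*/N, packed_B);
  // CblasPacked tells MKL the second operand is the packed buffer.
  cblas_sgemm_compute(CblasRowMajor, CblasNoTrans, CblasPacked, M, N, K,
                      A, /*lda=*/K, packed_B, /*ldb=*/N,
                      /*beta=*/0.0f, C, /*ldc=*/N);
  cblas_sgemm_free(packed_B);
}
```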
@@ -60,6 +60,14 @@ extern void* mklml_dso_handle;
  __macro(cblas_dgemm_batch); \
  __macro(vsAdd);             \
  __macro(vdAdd);             \
__macro(cblas_sgemm_alloc); \
__macro(cblas_sgemm_pack); \
__macro(cblas_sgemm_compute); \
__macro(cblas_sgemm_free); \
__macro(cblas_dgemm_alloc); \
__macro(cblas_dgemm_pack); \
__macro(cblas_dgemm_compute); \
__macro(cblas_dgemm_free); \
  __macro(MKL_Set_Num_Threads)
MKLML_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_MKLML_WRAP);
...
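Each `__macro(...)` entry is expanded by DECLARE_DYNAMIC_LOAD_MKLML_WRAP into a call-through functor, so the new pack/compute routines are resolved from the mklml shared library at first use instead of being linked directly. A simplified sketch of what the generated wrapper looks like for one routine (the real macro also guards library loading with std::call_once and adds error checks):

```cpp
#include <dlfcn.h>
#include <mkl.h>  // real declaration of cblas_sgemm_pack

extern void* mklml_dso_handle;  // opened elsewhere via dlopen

struct DynLoad__cblas_sgemm_pack {
  template <typename... Args>
  auto operator()(Args... args) -> decltype(::cblas_sgemm_pack(args...)) {
    using FuncType = decltype(&::cblas_sgemm_pack);
    // Resolve the symbol once from the already-opened mklml library.
    static void* sym = dlsym(mklml_dso_handle, "cblas_sgemm_pack");
    return reinterpret_cast<FuncType>(sym)(args...);
  }
};
```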