提交 54c95e49 编写于 作者: T tensor-tang

fix blas

上级 8c23f7c4
...@@ -90,6 +90,7 @@ class Blas { ...@@ -90,6 +90,7 @@ class Blas {
void GEMM(bool transA, bool transB, int M, int N, int K, T alpha, const T* A, void GEMM(bool transA, bool transB, int M, int N, int K, T alpha, const T* A,
int lda, const T* B, int ldb, T beta, T* C, int ldc) const; int lda, const T* B, int ldb, T beta, T* C, int ldc) const;
#ifdef PADDLE_WITH_MKLML
template <typename T> template <typename T>
T* GEMM_ALLOC(const CBLAS_IDENTIFIER id, const int M, const int N, T* GEMM_ALLOC(const CBLAS_IDENTIFIER id, const int M, const int N,
const int K) const; const int K) const;
...@@ -106,6 +107,7 @@ class Blas { ...@@ -106,6 +107,7 @@ class Blas {
template <typename T> template <typename T>
void GEMM_FREE(T* data) const; void GEMM_FREE(T* data) const;
#endif
template <typename T> template <typename T>
void MatMul(const framework::Tensor& mat_a, bool trans_a, void MatMul(const framework::Tensor& mat_a, bool trans_a,
...@@ -163,6 +165,7 @@ class BlasT : private Blas<DeviceContext> { ...@@ -163,6 +165,7 @@ class BlasT : private Blas<DeviceContext> {
Base()->template GEMM<T>(args...); Base()->template GEMM<T>(args...);
} }
#ifdef PADDLE_WITH_MKLML
template <typename... ARGS> template <typename... ARGS>
T* GEMM_ALLOC(ARGS... args) const { T* GEMM_ALLOC(ARGS... args) const {
return Base()->template GEMM_ALLOC<T>(args...); return Base()->template GEMM_ALLOC<T>(args...);
...@@ -182,6 +185,7 @@ class BlasT : private Blas<DeviceContext> { ...@@ -182,6 +185,7 @@ class BlasT : private Blas<DeviceContext> {
void GEMM_FREE(ARGS... args) const { void GEMM_FREE(ARGS... args) const {
Base()->template GEMM_FREE<T>(args...); Base()->template GEMM_FREE<T>(args...);
} }
#endif
template <typename... ARGS> template <typename... ARGS>
void MatMul(ARGS... args) const { void MatMul(ARGS... args) const {
......
...@@ -264,6 +264,7 @@ inline void GEMM_WARP(CBLAS_ORDER order, CBLAS_TRANSPOSE transA, ...@@ -264,6 +264,7 @@ inline void GEMM_WARP(CBLAS_ORDER order, CBLAS_TRANSPOSE transA,
beta, C, ldc); beta, C, ldc);
} }
#ifdef PADDLE_WITH_MKLML
template <> template <>
template <typename T> template <typename T>
T *Blas<platform::CPUDeviceContext>::GEMM_ALLOC(const CBLAS_IDENTIFIER id, T *Blas<platform::CPUDeviceContext>::GEMM_ALLOC(const CBLAS_IDENTIFIER id,
...@@ -296,6 +297,7 @@ template <typename T> ...@@ -296,6 +297,7 @@ template <typename T>
void Blas<platform::CPUDeviceContext>::GEMM_FREE(T *data) const { void Blas<platform::CPUDeviceContext>::GEMM_FREE(T *data) const {
CBlas<T>::GEMM_FREE(data); CBlas<T>::GEMM_FREE(data);
} }
#endif
template <> template <>
template <typename T> template <typename T>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册