提交 54c95e49 编写于 作者: T tensor-tang

fix blas

上级 8c23f7c4
......@@ -90,6 +90,7 @@ class Blas {
void GEMM(bool transA, bool transB, int M, int N, int K, T alpha, const T* A,
int lda, const T* B, int ldb, T beta, T* C, int ldc) const;
#ifdef PADDLE_WITH_MKLML
template <typename T>
T* GEMM_ALLOC(const CBLAS_IDENTIFIER id, const int M, const int N,
const int K) const;
......@@ -106,6 +107,7 @@ class Blas {
template <typename T>
void GEMM_FREE(T* data) const;
#endif
template <typename T>
void MatMul(const framework::Tensor& mat_a, bool trans_a,
......@@ -163,6 +165,7 @@ class BlasT : private Blas<DeviceContext> {
Base()->template GEMM<T>(args...);
}
#ifdef PADDLE_WITH_MKLML
template <typename... ARGS>
T* GEMM_ALLOC(ARGS... args) const {
return Base()->template GEMM_ALLOC<T>(args...);
......@@ -182,6 +185,7 @@ class BlasT : private Blas<DeviceContext> {
void GEMM_FREE(ARGS... args) const {
Base()->template GEMM_FREE<T>(args...);
}
#endif
template <typename... ARGS>
void MatMul(ARGS... args) const {
......
......@@ -264,6 +264,7 @@ inline void GEMM_WARP(CBLAS_ORDER order, CBLAS_TRANSPOSE transA,
beta, C, ldc);
}
#ifdef PADDLE_WITH_MKLML
template <>
template <typename T>
T *Blas<platform::CPUDeviceContext>::GEMM_ALLOC(const CBLAS_IDENTIFIER id,
......@@ -296,6 +297,7 @@ template <typename T>
void Blas<platform::CPUDeviceContext>::GEMM_FREE(T *data) const {
CBlas<T>::GEMM_FREE(data);
}
#endif
template <>
template <typename T>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册