Commit a2203d04 authored by tensor-tang

add cblas dot

Parent f72ab896
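This commit adds a DOT (vector inner product) routine to the Blas wrapper: MKLML builds dispatch to cblas_sdot/cblas_ddot through the dynload layer, while other builds fall back to a plain accumulation loop. A minimal usage sketch, not part of the commit (GetBlas and the dev_ctx argument are assumptions taken from the surrounding math library):

// Hedged usage sketch: computes the inner product via the new Blas::DOT.
// GetBlas and dev_ctx are assumed from the surrounding codebase.
#include <vector>

float dot_example(const paddle::platform::CPUDeviceContext& dev_ctx) {
  std::vector<float> x = {1.f, 2.f, 3.f};
  std::vector<float> y = {4.f, 5.f, 6.f};
  auto blas = paddle::operators::math::GetBlas<
      paddle::platform::CPUDeviceContext, float>(dev_ctx);
  // 1*4 + 2*5 + 3*6 = 32
  return blas.DOT(static_cast<int>(x.size()), x.data(), y.data());
}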
@@ -153,6 +153,9 @@ class Blas {
 void GEMV(bool trans_a, int M, int N, T alpha, const T* A, const T* B, T beta,
           T* C) const;

+template <typename T>
+T DOT(int n, const T* x, const T* y) const;
+
 template <typename T>
 void BatchedGEMM(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N,
                  int K, T alpha, const T* A, const T* B, T beta, T* C,
...
@@ -73,6 +73,11 @@ struct CBlas<float> {
 platform::dynload::cblas_sgemv(args...);
 }

+template <typename... ARGS>
+static float DOT(ARGS... args) {
+  return platform::dynload::cblas_sdot(args...);
+}
+
 template <typename... ARGS>
 static void GEMM_BATCH(ARGS... args) {
 platform::dynload::cblas_sgemm_batch(args...);
@@ -138,6 +143,11 @@ struct CBlas<double> {
 platform::dynload::cblas_dgemv(args...);
 }

+template <typename... ARGS>
+static double DOT(ARGS... args) {
+  return platform::dynload::cblas_ddot(args...);
+}
+
 template <typename... ARGS>
 static void GEMM_BATCH(ARGS... args) {
 platform::dynload::cblas_dgemm_batch(args...);
@@ -210,6 +220,7 @@ struct CBlas<platform::float16> {
PADDLE_THROW("float16 SMM_GEMM not supported on CPU"); PADDLE_THROW("float16 SMM_GEMM not supported on CPU");
} }
static void VMUL(...) { PADDLE_THROW("float16 VMUL not supported on CPU"); } static void VMUL(...) { PADDLE_THROW("float16 VMUL not supported on CPU"); }
static void DOT(...) { PADDLE_THROW("float16 DOT not supported on CPU"); };
#ifdef PADDLE_WITH_MKLML #ifdef PADDLE_WITH_MKLML
static void GEMM_BATCH(...) { static void GEMM_BATCH(...) {
PADDLE_THROW("float16 GEMM_BATCH not supported on CPU"); PADDLE_THROW("float16 GEMM_BATCH not supported on CPU");
@@ -352,6 +363,21 @@ void Blas<platform::CPUDeviceContext>::VMUL(int n, const T *x, const T *y,
 #endif
 }

+template <>
+template <typename T>
+T Blas<platform::CPUDeviceContext>::DOT(int n, const T *x, const T *y) const {
+#ifdef PADDLE_WITH_MKLML
+  return CBlas<T>::DOT(n, x, 1, y, 1);  // cblas_?dot takes unit strides here
+#else
+  // TODO: fall back to cblas_?dot once OpenBLAS support is confirmed.
+  T sum = 0;
+  for (int i = 0; i < n; ++i) {
+    sum += x[i] * y[i];
+  }
+  return sum;
+#endif
+}
+
 template <>
 template <typename T>
 void Blas<platform::CPUDeviceContext>::GEMV(bool trans_a, int M, int N, T alpha,
@@ -423,7 +449,6 @@ void Blas<platform::CPUDeviceContext>::MatMul(const int M, const int N,
CBlas<T>::SMM_GEMM(&transa, &transb, &N, &M, &K, &alpha, B, &N, A, &K, &beta, CBlas<T>::SMM_GEMM(&transa, &transb, &N, &M, &K, &alpha, B, &N, A, &K, &beta,
C, &N); C, &N);
return; return;
#endif #endif
CBlas<T>::GEMM(CblasRowMajor, CblasNoTrans, CblasNoTrans, M, N, K, CBlas<T>::GEMM(CblasRowMajor, CblasNoTrans, CblasNoTrans, M, N, K,
...
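Without MKLML the fallback computes the dot product with a simple loop, which is semantically the same as std::inner_product over the two ranges. A standalone check, for illustration only (not part of the commit):

// Standalone sketch: the non-MKLML fallback above matches std::inner_product.
#include <cassert>
#include <numeric>
#include <vector>

float naive_dot(int n, const float* x, const float* y) {
  float sum = 0;
  for (int i = 0; i < n; ++i) sum += x[i] * y[i];
  return sum;
}

int main() {
  std::vector<float> x = {1.f, 2.f, 3.f}, y = {4.f, 5.f, 6.f};
  // Same accumulation order, so the results agree exactly.
  assert(naive_dot(3, x.data(), y.data()) ==
         std::inner_product(x.begin(), x.end(), y.begin(), 0.f));
  return 0;
}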
@@ -66,6 +66,8 @@ extern void* mklml_dso_handle;
   __macro(cblas_dgemm_free);  \
   __macro(cblas_sgemm_batch); \
   __macro(cblas_dgemm_batch); \
+  __macro(cblas_sdot);        \
+  __macro(cblas_ddot);        \
   __macro(vsAdd);             \
   __macro(vdAdd);             \
   __macro(vsMul);             \
...
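The mklml.h hunk registers cblas_sdot and cblas_ddot in the __macro routine list, so the symbols are resolved from the MKLML shared library at runtime rather than linked statically. A generic sketch of that lazy dlopen/dlsym pattern (illustration only; Paddle's actual wrapper macro differs in detail, and the library name here is an assumption):

// Hedged sketch of lazy dynamic loading, the pattern behind the dynload list.
#include <dlfcn.h>

using cblas_sdot_t = float (*)(int n, const float* x, int incx,
                               const float* y, int incy);

float call_cblas_sdot(int n, const float* x, const float* y) {
  // Resolve the symbol once; the handle and library name are assumptions.
  static void* handle = dlopen("libmklml_intel.so", RTLD_LAZY);
  static auto fn = reinterpret_cast<cblas_sdot_t>(dlsym(handle, "cblas_sdot"));
  return fn(n, x, 1, y, 1);  // unit strides for contiguous vectors
}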