diff --git a/paddle/fluid/operators/math/blas.h b/paddle/fluid/operators/math/blas.h index 295431347a449196b415f197c9e2ea8480b873d1..96d481f739540781163f8a96f7ee0321c9b4dd3b 100644 --- a/paddle/fluid/operators/math/blas.h +++ b/paddle/fluid/operators/math/blas.h @@ -153,6 +153,9 @@ class Blas { void GEMV(bool trans_a, int M, int N, T alpha, const T* A, const T* B, T beta, T* C) const; + template + T DOT(int n, const T* x, const T* y) const; + template void BatchedGEMM(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K, T alpha, const T* A, const T* B, T beta, T* C, diff --git a/paddle/fluid/operators/math/blas_impl.h b/paddle/fluid/operators/math/blas_impl.h index d39a3e7f6eb91ded2585a2b71868a1649376ac86..bbd9d4b60a2bbd3144cf338c0648ea7af89d837c 100644 --- a/paddle/fluid/operators/math/blas_impl.h +++ b/paddle/fluid/operators/math/blas_impl.h @@ -73,6 +73,11 @@ struct CBlas { platform::dynload::cblas_sgemv(args...); } + template + static float DOT(ARGS... args) { + return platform::dynload::cblas_sdot(args...); + } + template static void GEMM_BATCH(ARGS... args) { platform::dynload::cblas_sgemm_batch(args...); @@ -138,6 +143,11 @@ struct CBlas { platform::dynload::cblas_dgemv(args...); } + template + static double DOT(ARGS... args) { + return platform::dynload::cblas_ddot(args...); + } + template static void GEMM_BATCH(ARGS... args) { platform::dynload::cblas_dgemm_batch(args...); @@ -210,6 +220,7 @@ struct CBlas { PADDLE_THROW("float16 SMM_GEMM not supported on CPU"); } static void VMUL(...) { PADDLE_THROW("float16 VMUL not supported on CPU"); } + static void DOT(...) { PADDLE_THROW("float16 DOT not supported on CPU"); }; #ifdef PADDLE_WITH_MKLML static void GEMM_BATCH(...) { PADDLE_THROW("float16 GEMM_BATCH not supported on CPU"); @@ -352,6 +363,21 @@ void Blas::VMUL(int n, const T *x, const T *y, #endif } +template <> +template +T Blas::DOT(int n, const T *x, const T *y) const { +#ifdef PADDLE_WITH_MKLML + return CBlas::DOT(n, x, y); +#else + // try to find if openblas support cblas_dot + T sum = 0; + for (int i = 0; i < n; ++i) { + sum += x[i] * y[i]; + } + return sum; +#endif +} + template <> template void Blas::GEMV(bool trans_a, int M, int N, T alpha, @@ -423,7 +449,6 @@ void Blas::MatMul(const int M, const int N, CBlas::SMM_GEMM(&transa, &transb, &N, &M, &K, &alpha, B, &N, A, &K, &beta, C, &N); return; - #endif CBlas::GEMM(CblasRowMajor, CblasNoTrans, CblasNoTrans, M, N, K, diff --git a/paddle/fluid/platform/dynload/mklml.h b/paddle/fluid/platform/dynload/mklml.h index 15ad4a3b40b1ad13a10dd37449c6f6f3e2029df6..6efa160df0560c724a3309f346c06e47404e72c2 100644 --- a/paddle/fluid/platform/dynload/mklml.h +++ b/paddle/fluid/platform/dynload/mklml.h @@ -66,6 +66,8 @@ extern void* mklml_dso_handle; __macro(cblas_dgemm_free); \ __macro(cblas_sgemm_batch); \ __macro(cblas_dgemm_batch); \ + __macro(cblas_sdot); \ + __macro(cblas_ddot); \ __macro(vsAdd); \ __macro(vdAdd); \ __macro(vsMul); \