From a2203d0466462fcde20bdd80d79a0f7964760eb8 Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Wed, 22 Aug 2018 12:08:31 +0800 Subject: [PATCH] add cblas dot --- paddle/fluid/operators/math/blas.h | 3 +++ paddle/fluid/operators/math/blas_impl.h | 27 ++++++++++++++++++++++++- paddle/fluid/platform/dynload/mklml.h | 2 ++ 3 files changed, 31 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/operators/math/blas.h b/paddle/fluid/operators/math/blas.h index 295431347a4..96d481f7395 100644 --- a/paddle/fluid/operators/math/blas.h +++ b/paddle/fluid/operators/math/blas.h @@ -153,6 +153,9 @@ class Blas { void GEMV(bool trans_a, int M, int N, T alpha, const T* A, const T* B, T beta, T* C) const; + template + T DOT(int n, const T* x, const T* y) const; + template void BatchedGEMM(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K, T alpha, const T* A, const T* B, T beta, T* C, diff --git a/paddle/fluid/operators/math/blas_impl.h b/paddle/fluid/operators/math/blas_impl.h index d39a3e7f6eb..bbd9d4b60a2 100644 --- a/paddle/fluid/operators/math/blas_impl.h +++ b/paddle/fluid/operators/math/blas_impl.h @@ -73,6 +73,11 @@ struct CBlas { platform::dynload::cblas_sgemv(args...); } + template + static float DOT(ARGS... args) { + return platform::dynload::cblas_sdot(args...); + } + template static void GEMM_BATCH(ARGS... args) { platform::dynload::cblas_sgemm_batch(args...); @@ -138,6 +143,11 @@ struct CBlas { platform::dynload::cblas_dgemv(args...); } + template + static double DOT(ARGS... args) { + return platform::dynload::cblas_ddot(args...); + } + template static void GEMM_BATCH(ARGS... args) { platform::dynload::cblas_dgemm_batch(args...); @@ -210,6 +220,7 @@ struct CBlas { PADDLE_THROW("float16 SMM_GEMM not supported on CPU"); } static void VMUL(...) { PADDLE_THROW("float16 VMUL not supported on CPU"); } + static void DOT(...) { PADDLE_THROW("float16 DOT not supported on CPU"); }; #ifdef PADDLE_WITH_MKLML static void GEMM_BATCH(...) { PADDLE_THROW("float16 GEMM_BATCH not supported on CPU"); @@ -352,6 +363,21 @@ void Blas::VMUL(int n, const T *x, const T *y, #endif } +template <> +template +T Blas::DOT(int n, const T *x, const T *y) const { +#ifdef PADDLE_WITH_MKLML + return CBlas::DOT(n, x, y); +#else + // try to find if openblas support cblas_dot + T sum = 0; + for (int i = 0; i < n; ++i) { + sum += x[i] * y[i]; + } + return sum; +#endif +} + template <> template void Blas::GEMV(bool trans_a, int M, int N, T alpha, @@ -423,7 +449,6 @@ void Blas::MatMul(const int M, const int N, CBlas::SMM_GEMM(&transa, &transb, &N, &M, &K, &alpha, B, &N, A, &K, &beta, C, &N); return; - #endif CBlas::GEMM(CblasRowMajor, CblasNoTrans, CblasNoTrans, M, N, K, diff --git a/paddle/fluid/platform/dynload/mklml.h b/paddle/fluid/platform/dynload/mklml.h index 15ad4a3b40b..6efa160df05 100644 --- a/paddle/fluid/platform/dynload/mklml.h +++ b/paddle/fluid/platform/dynload/mklml.h @@ -66,6 +66,8 @@ extern void* mklml_dso_handle; __macro(cblas_dgemm_free); \ __macro(cblas_sgemm_batch); \ __macro(cblas_dgemm_batch); \ + __macro(cblas_sdot); \ + __macro(cblas_ddot); \ __macro(vsAdd); \ __macro(vdAdd); \ __macro(vsMul); \ -- GitLab