From 0ec1f65cf110ee4e73a7bfa03456b52111426288 Mon Sep 17 00:00:00 2001
From: tensor-tang
Date: Wed, 22 Aug 2018 12:47:10 +0800
Subject: [PATCH] fix blas dot and add cblas scal

---
 paddle/fluid/operators/math/blas.h      |  3 +++
 paddle/fluid/operators/math/blas_impl.h | 27 ++++++++++++++++++++++++-
 paddle/fluid/platform/dynload/mklml.h   |  2 ++
 3 files changed, 31 insertions(+), 1 deletion(-)

diff --git a/paddle/fluid/operators/math/blas.h b/paddle/fluid/operators/math/blas.h
index 96d481f7395..fc02534a696 100644
--- a/paddle/fluid/operators/math/blas.h
+++ b/paddle/fluid/operators/math/blas.h
@@ -156,6 +156,9 @@ class Blas {
   template <typename T>
   T DOT(int n, const T* x, const T* y) const;
 
+  template <typename T>
+  void SCAL(int n, const T a, T* x) const;
+
   template <typename T>
   void BatchedGEMM(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M,
                    int N, int K, T alpha, const T* A, const T* B, T beta, T* C,
diff --git a/paddle/fluid/operators/math/blas_impl.h b/paddle/fluid/operators/math/blas_impl.h
index bbd9d4b60a2..b7c56e8df18 100644
--- a/paddle/fluid/operators/math/blas_impl.h
+++ b/paddle/fluid/operators/math/blas_impl.h
@@ -78,6 +78,11 @@ struct CBlas<float> {
     return platform::dynload::cblas_sdot(args...);
   }
 
+  template <typename... ARGS>
+  static void SCAL(ARGS... args) {
+    platform::dynload::cblas_sscal(args...);
+  }
+
   template <typename... ARGS>
   static void GEMM_BATCH(ARGS... args) {
     platform::dynload::cblas_sgemm_batch(args...);
@@ -148,6 +153,11 @@ struct CBlas<double> {
     return platform::dynload::cblas_ddot(args...);
   }
 
+  template <typename... ARGS>
+  static void SCAL(ARGS... args) {
+    platform::dynload::cblas_dscal(args...);
+  }
+
   template <typename... ARGS>
   static void GEMM_BATCH(ARGS... args) {
     platform::dynload::cblas_dgemm_batch(args...);
@@ -221,6 +231,7 @@ struct CBlas<platform::float16> {
   }
   static void VMUL(...) { PADDLE_THROW("float16 VMUL not supported on CPU"); }
   static void DOT(...) { PADDLE_THROW("float16 DOT not supported on CPU"); };
+  static void SCAL(...) { PADDLE_THROW("float16 SCAL not supported on CPU"); };
 #ifdef PADDLE_WITH_MKLML
   static void GEMM_BATCH(...) {
     PADDLE_THROW("float16 GEMM_BATCH not supported on CPU");
@@ -367,7 +378,7 @@ template <>
 template <typename T>
 T Blas<platform::CPUDeviceContext>::DOT(int n, const T *x, const T *y) const {
 #ifdef PADDLE_WITH_MKLML
-  return CBlas<T>::DOT(n, x, y);
+  return CBlas<T>::DOT(n, x, 1, y, 1);
 #else
   // try to find if openblas support cblas_dot
   T sum = 0;
@@ -378,6 +389,20 @@ T Blas<platform::CPUDeviceContext>::DOT(int n, const T *x, const T *y) const {
 #endif
 }
 
+template <>
+template <typename T>
+void Blas<platform::CPUDeviceContext>::SCAL(int n, const T a,
+                                            T *x) const {
+#ifdef PADDLE_WITH_MKLML
+  CBlas<T>::SCAL(n, a, x, 1);
+#else
+  // try to find if openblas support cblas_scal
+  for (int i = 0; i < n; ++i) {
+    x[i] = a * x[i];
+  }
+#endif
+}
+
 template <>
 template <typename T>
 void Blas<platform::CPUDeviceContext>::GEMV(bool trans_a, int M, int N, T alpha,
diff --git a/paddle/fluid/platform/dynload/mklml.h b/paddle/fluid/platform/dynload/mklml.h
index 6efa160df05..e50ea6740aa 100644
--- a/paddle/fluid/platform/dynload/mklml.h
+++ b/paddle/fluid/platform/dynload/mklml.h
@@ -68,6 +68,8 @@ extern void* mklml_dso_handle;
   __macro(cblas_dgemm_batch);     \
   __macro(cblas_sdot);            \
   __macro(cblas_ddot);            \
+  __macro(cblas_sscal);           \
+  __macro(cblas_dscal);           \
   __macro(vsAdd);                 \
   __macro(vdAdd);                 \
   __macro(vsMul);                 \
--
GitLab
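
Note: for context, below is a minimal standalone sketch of what the non-MKLML fallback paths in this patch compute, i.e. the plain loops that stand in for cblas_sscal and cblas_sdot with unit strides. The scal_ref/dot_ref names and the main() driver are illustrative only and are not part of the patch; under MKLML the same operations instead dispatch to cblas_sscal/cblas_sdot through the dynload table extended in mklml.h.

#include <cstdio>
#include <vector>

// Fallback SCAL semantics from the patch: x[i] = a * x[i], unit stride.
template <typename T>
void scal_ref(int n, const T a, T* x) {
  for (int i = 0; i < n; ++i) {
    x[i] = a * x[i];
  }
}

// Fallback DOT semantics from the patch: sum of x[i] * y[i], unit stride.
template <typename T>
T dot_ref(int n, const T* x, const T* y) {
  T sum = 0;
  for (int i = 0; i < n; ++i) {
    sum += x[i] * y[i];
  }
  return sum;
}

int main() {
  std::vector<float> x = {1.f, 2.f, 3.f};
  std::vector<float> y = {4.f, 5.f, 6.f};
  scal_ref(3, 2.f, x.data());  // x becomes {2, 4, 6}
  // 2*4 + 4*5 + 6*6 = 64
  std::printf("dot = %f\n", dot_ref(3, x.data(), y.data()));
  return 0;
}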