diff --git a/paddle/fluid/operators/math/blas.h b/paddle/fluid/operators/math/blas.h index 96d481f739540781163f8a96f7ee0321c9b4dd3b..fc02534a696222050045f0c35de313c5e55e91c3 100644 --- a/paddle/fluid/operators/math/blas.h +++ b/paddle/fluid/operators/math/blas.h @@ -156,6 +156,9 @@ class Blas { template T DOT(int n, const T* x, const T* y) const; + template + void SCAL(int n, const T a, const T* x) const; + template void BatchedGEMM(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K, T alpha, const T* A, const T* B, T beta, T* C, diff --git a/paddle/fluid/operators/math/blas_impl.h b/paddle/fluid/operators/math/blas_impl.h index bbd9d4b60a2bbd3144cf338c0648ea7af89d837c..b7c56e8df18f70df95f0b5bcdc07dbb753ff9035 100644 --- a/paddle/fluid/operators/math/blas_impl.h +++ b/paddle/fluid/operators/math/blas_impl.h @@ -78,6 +78,11 @@ struct CBlas { return platform::dynload::cblas_sdot(args...); } + template + static void SCAL(ARGS... args) { + platform::dynload::cblas_sscal(args...); + } + template static void GEMM_BATCH(ARGS... args) { platform::dynload::cblas_sgemm_batch(args...); @@ -148,6 +153,11 @@ struct CBlas { return platform::dynload::cblas_ddot(args...); } + template + static void SCAL(ARGS... args) { + platform::dynload::cblas_dscal(args...); + } + template static void GEMM_BATCH(ARGS... args) { platform::dynload::cblas_dgemm_batch(args...); @@ -221,6 +231,7 @@ struct CBlas { } static void VMUL(...) { PADDLE_THROW("float16 VMUL not supported on CPU"); } static void DOT(...) { PADDLE_THROW("float16 DOT not supported on CPU"); }; + static void SCAL(...) { PADDLE_THROW("float16 SCAL not supported on CPU"); }; #ifdef PADDLE_WITH_MKLML static void GEMM_BATCH(...) { PADDLE_THROW("float16 GEMM_BATCH not supported on CPU"); @@ -367,7 +378,7 @@ template <> template T Blas::DOT(int n, const T *x, const T *y) const { #ifdef PADDLE_WITH_MKLML - return CBlas::DOT(n, x, y); + return CBlas::DOT(n, x, 1, y, 1); #else // try to find if openblas support cblas_dot T sum = 0; @@ -378,6 +389,20 @@ T Blas::DOT(int n, const T *x, const T *y) const { #endif } +template <> +template +void Blas::SCAL(int n, const T a, + const T *x) const { +#ifdef PADDLE_WITH_MKLML + CBlas::SCAL(n, a, x, 1); +#else + // try to find if openblas support cblas_scal + for (int i = 0; i < n; ++i) { + x[i] = a * x[i]; + } +#endif +} + template <> template void Blas::GEMV(bool trans_a, int M, int N, T alpha, diff --git a/paddle/fluid/platform/dynload/mklml.h b/paddle/fluid/platform/dynload/mklml.h index 6efa160df0560c724a3309f346c06e47404e72c2..e50ea6740aa8e56d57ef91434c7dae6069f69836 100644 --- a/paddle/fluid/platform/dynload/mklml.h +++ b/paddle/fluid/platform/dynload/mklml.h @@ -68,6 +68,8 @@ extern void* mklml_dso_handle; __macro(cblas_dgemm_batch); \ __macro(cblas_sdot); \ __macro(cblas_ddot); \ + __macro(cblas_sscal); \ + __macro(cblas_dscal); \ __macro(vsAdd); \ __macro(vdAdd); \ __macro(vsMul); \