From 3dd66390b2702fe3083fee5e84f2ad6d5322b76b Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Wed, 22 Aug 2018 13:13:58 +0800 Subject: [PATCH] add blas vexp --- paddle/fluid/operators/math/blas.h | 3 +++ paddle/fluid/operators/math/blas_impl.h | 24 ++++++++++++++++++++++++ paddle/fluid/platform/dynload/mklml.h | 2 ++ 3 files changed, 29 insertions(+) diff --git a/paddle/fluid/operators/math/blas.h b/paddle/fluid/operators/math/blas.h index fc02534a696..5aba170221f 100644 --- a/paddle/fluid/operators/math/blas.h +++ b/paddle/fluid/operators/math/blas.h @@ -149,6 +149,9 @@ class Blas { template void VCOPY(int n, const T* x, T* y) const; + template + void VEXP(int n, const T* x, T* y) const; + template void GEMV(bool trans_a, int M, int N, T alpha, const T* A, const T* B, T beta, T* C) const; diff --git a/paddle/fluid/operators/math/blas_impl.h b/paddle/fluid/operators/math/blas_impl.h index b7c56e8df18..eaad83ba182 100644 --- a/paddle/fluid/operators/math/blas_impl.h +++ b/paddle/fluid/operators/math/blas_impl.h @@ -97,6 +97,11 @@ struct CBlas { static void VMUL(ARGS... args) { platform::dynload::vsMul(args...); } + + template + static void VEXP(ARGS... args) { + platform::dynload::vsExp(args...); + } }; template <> @@ -172,6 +177,11 @@ struct CBlas { static void VMUL(ARGS... args) { platform::dynload::vdMul(args...); } + + template + static void VEXP(ARGS... args) { + platform::dynload::vdExp(args...); + } }; #else @@ -230,6 +240,7 @@ struct CBlas { PADDLE_THROW("float16 SMM_GEMM not supported on CPU"); } static void VMUL(...) { PADDLE_THROW("float16 VMUL not supported on CPU"); } + static void VEXP(...) { PADDLE_THROW("float16 VEXP not supported on CPU"); } static void DOT(...) { PADDLE_THROW("float16 DOT not supported on CPU"); }; static void SCAL(...) { PADDLE_THROW("float16 SCAL not supported on CPU"); }; #ifdef PADDLE_WITH_MKLML @@ -374,6 +385,19 @@ void Blas::VMUL(int n, const T *x, const T *y, #endif } +template <> +template +void Blas::VEXP(int n, const T *x, T *y) const { +#ifdef PADDLE_WITH_MKLML + CBlas::VEXP(n, x, y); +#else + // try to find if openblas support vexp + for (int i = 0; i < n; ++i) { + y[i] = std::exp(x[i]); + } +#endif +} + template <> template T Blas::DOT(int n, const T *x, const T *y) const { diff --git a/paddle/fluid/platform/dynload/mklml.h b/paddle/fluid/platform/dynload/mklml.h index e50ea6740aa..aa20553ceff 100644 --- a/paddle/fluid/platform/dynload/mklml.h +++ b/paddle/fluid/platform/dynload/mklml.h @@ -74,6 +74,8 @@ extern void* mklml_dso_handle; __macro(vdAdd); \ __macro(vsMul); \ __macro(vdMul); \ + __macro(vsExp); \ + __macro(vdExp); \ __macro(MKL_Set_Num_Threads) MKLML_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_MKLML_WRAP); -- GitLab