Add a vectorized element-wise exponential (VEXP) to the CPU Blas wrapper.

Summary of the change carried by this patch:
  * blas.h       - declares Blas::VEXP(int n, const T* x, T* y).
  * blas_impl.h  - the CBlas specializations that call the MKLML vs*/vd*
                   routines gain a VEXP that forwards its arguments to
                   platform::dynload::vsExp / vdExp; the float16
                   specialization throws "float16 VEXP not supported on CPU".
                   Blas::VEXP calls CBlas<T>::VEXP when PADDLE_WITH_MKLML is
                   defined and otherwise writes y[i] = std::exp(x[i]) for
                   i in [0, n).
  * mklml.h      - adds vsExp and vdExp to the MKLML_ROUTINE_EACH list.

NOTE(review): this patch text was recovered from a copy in which line breaks
were collapsed and template argument lists had been stripped. Template lists
were restored from the surrounding signatures; leading indentation and the
alignment of the trailing backslashes in mklml.h could not be recovered
exactly. If context lines fail to match, apply with whitespace tolerance
(e.g. `git apply --ignore-whitespace`) and confirm against the target tree.
NOTE(review): the non-MKLML fallback uses std::exp; confirm that <cmath> is
included by blas_impl.h, since its include block is not part of this patch.

diff --git a/paddle/fluid/operators/math/blas.h b/paddle/fluid/operators/math/blas.h
index fc02534a696222050045f0c35de313c5e55e91c3..5aba170221fa9d5a9b686af83a92385a4d7f48bb 100644
--- a/paddle/fluid/operators/math/blas.h
+++ b/paddle/fluid/operators/math/blas.h
@@ -149,6 +149,9 @@ class Blas {
   template <typename T>
   void VCOPY(int n, const T* x, T* y) const;
 
+  template <typename T>
+  void VEXP(int n, const T* x, T* y) const;
+
   template <typename T>
   void GEMV(bool trans_a, int M, int N, T alpha, const T* A, const T* B, T beta,
             T* C) const;
diff --git a/paddle/fluid/operators/math/blas_impl.h b/paddle/fluid/operators/math/blas_impl.h
index b7c56e8df18f70df95f0b5bcdc07dbb753ff9035..eaad83ba18248eee57a39ab7889ac462b4baac4e 100644
--- a/paddle/fluid/operators/math/blas_impl.h
+++ b/paddle/fluid/operators/math/blas_impl.h
@@ -97,6 +97,11 @@ struct CBlas<float> {
   static void VMUL(ARGS... args) {
     platform::dynload::vsMul(args...);
   }
+
+  template <typename... ARGS>
+  static void VEXP(ARGS... args) {
+    platform::dynload::vsExp(args...);
+  }
 };
 
 template <>
@@ -172,6 +177,11 @@ struct CBlas<double> {
   static void VMUL(ARGS... args) {
     platform::dynload::vdMul(args...);
   }
+
+  template <typename... ARGS>
+  static void VEXP(ARGS... args) {
+    platform::dynload::vdExp(args...);
+  }
 };
 
 #else
@@ -230,6 +240,7 @@ struct CBlas<platform::float16> {
     PADDLE_THROW("float16 SMM_GEMM not supported on CPU");
   }
   static void VMUL(...) { PADDLE_THROW("float16 VMUL not supported on CPU"); }
+  static void VEXP(...) { PADDLE_THROW("float16 VEXP not supported on CPU"); }
   static void DOT(...) { PADDLE_THROW("float16 DOT not supported on CPU"); };
   static void SCAL(...) { PADDLE_THROW("float16 SCAL not supported on CPU"); };
 #ifdef PADDLE_WITH_MKLML
@@ -374,6 +385,19 @@ void Blas<platform::CPUDeviceContext>::VMUL(int n, const T *x, const T *y,
 #endif
 }
 
+template <>
+template <typename T>
+void Blas<platform::CPUDeviceContext>::VEXP(int n, const T *x, T *y) const {
+#ifdef PADDLE_WITH_MKLML
+  CBlas<T>::VEXP(n, x, y);
+#else
+  // try to find if openblas support vexp
+  for (int i = 0; i < n; ++i) {
+    y[i] = std::exp(x[i]);
+  }
+#endif
+}
+
 template <>
 template <typename T>
 T Blas<platform::CPUDeviceContext>::DOT(int n, const T *x, const T *y) const {
diff --git a/paddle/fluid/platform/dynload/mklml.h b/paddle/fluid/platform/dynload/mklml.h
index e50ea6740aa8e56d57ef91434c7dae6069f69836..aa20553ceffceded09447693c6e92f55fb48702d 100644
--- a/paddle/fluid/platform/dynload/mklml.h
+++ b/paddle/fluid/platform/dynload/mklml.h
@@ -74,6 +74,8 @@ extern void* mklml_dso_handle;
   __macro(vdAdd); \
   __macro(vsMul); \
   __macro(vdMul); \
+  __macro(vsExp); \
+  __macro(vdExp); \
   __macro(MKL_Set_Num_Threads)
 
 MKLML_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_MKLML_WRAP);