提交 1be85d01 编写于 作者: T tensor-tang

add mkl vsqr and vpow

上级 38f499df
...@@ -152,6 +152,12 @@ class Blas { ...@@ -152,6 +152,12 @@ class Blas {
template <typename T> template <typename T>
void VEXP(int n, const T* x, T* y) const; void VEXP(int n, const T* x, T* y) const;
template <typename T>
void VSQR(int n, const T* x, T* y) const;
template <typename T>
void VPOW(int n, const T* x, T alpha, T* y) const;
template <typename T> template <typename T>
void GEMV(bool trans_a, int M, int N, T alpha, const T* A, const T* B, T beta, void GEMV(bool trans_a, int M, int N, T alpha, const T* A, const T* B, T beta,
T* C) const; T* C) const;
...@@ -238,6 +244,16 @@ class BlasT : private Blas<DeviceContext> { ...@@ -238,6 +244,16 @@ class BlasT : private Blas<DeviceContext> {
Base()->template VEXP<T>(args...); Base()->template VEXP<T>(args...);
} }
template <typename... ARGS>
void VSQR(ARGS... args) const {
Base()->template VSQR<T>(args...);
}
template <typename... ARGS>
void VPOW(ARGS... args) const {
Base()->template VPOW<T>(args...);
}
template <typename... ARGS> template <typename... ARGS>
void GEMV(ARGS... args) const { void GEMV(ARGS... args) const {
Base()->template GEMV<T>(args...); Base()->template GEMV<T>(args...);
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include <cmath>
#include <limits> #include <limits>
#include <vector> #include <vector>
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"
...@@ -102,6 +103,16 @@ struct CBlas<float> { ...@@ -102,6 +103,16 @@ struct CBlas<float> {
static void VEXP(ARGS... args) { static void VEXP(ARGS... args) {
platform::dynload::vsExp(args...); platform::dynload::vsExp(args...);
} }
template <typename... ARGS>
static void VSQR(ARGS... args) {
platform::dynload::vsSqr(args...);
}
template <typename... ARGS>
static void VPOW(ARGS... args) {
platform::dynload::vsPowx(args...);
}
}; };
template <> template <>
...@@ -182,6 +193,16 @@ struct CBlas<double> { ...@@ -182,6 +193,16 @@ struct CBlas<double> {
static void VEXP(ARGS... args) { static void VEXP(ARGS... args) {
platform::dynload::vdExp(args...); platform::dynload::vdExp(args...);
} }
template <typename... ARGS>
static void VSQR(ARGS... args) {
platform::dynload::vdSqr(args...);
}
template <typename... ARGS>
static void VPOW(ARGS... args) {
platform::dynload::vdPowx(args...);
}
}; };
#else #else
...@@ -241,6 +262,8 @@ struct CBlas<platform::float16> { ...@@ -241,6 +262,8 @@ struct CBlas<platform::float16> {
} }
static void VMUL(...) { PADDLE_THROW("float16 VMUL not supported on CPU"); } static void VMUL(...) { PADDLE_THROW("float16 VMUL not supported on CPU"); }
static void VEXP(...) { PADDLE_THROW("float16 VEXP not supported on CPU"); } static void VEXP(...) { PADDLE_THROW("float16 VEXP not supported on CPU"); }
static void VSQR(...) { PADDLE_THROW("float16 VSQR not supported on CPU"); }
static void VPOW(...) { PADDLE_THROW("float16 VPOW not supported on CPU"); }
static void DOT(...) { PADDLE_THROW("float16 DOT not supported on CPU"); }; static void DOT(...) { PADDLE_THROW("float16 DOT not supported on CPU"); };
static void SCAL(...) { PADDLE_THROW("float16 SCAL not supported on CPU"); }; static void SCAL(...) { PADDLE_THROW("float16 SCAL not supported on CPU"); };
#ifdef PADDLE_WITH_MKLML #ifdef PADDLE_WITH_MKLML
...@@ -398,6 +421,31 @@ void Blas<platform::CPUDeviceContext>::VEXP(int n, const T *x, T *y) const { ...@@ -398,6 +421,31 @@ void Blas<platform::CPUDeviceContext>::VEXP(int n, const T *x, T *y) const {
#endif #endif
} }
template <>
template <typename T>
void Blas<platform::CPUDeviceContext>::VSQR(int n, const T *x, T *y) const {
#ifdef PADDLE_WITH_MKLML
CBlas<T>::VSQR(n, x, y);
#else
for (int i = 0; i < n; ++i) {
y[i] = std::sqrt(x[i]);
}
#endif
}
template <>
template <typename T>
void Blas<platform::CPUDeviceContext>::VPOW(int n, const T *x, T a,
T *y) const {
#ifdef PADDLE_WITH_MKLML
CBlas<T>::VPOW(n, x, a, y);
#else
for (int i = 0; i < n; ++i) {
y[i] = std::pow(x[i], a);
}
#endif
}
template <> template <>
template <typename T> template <typename T>
T Blas<platform::CPUDeviceContext>::DOT(int n, const T *x, const T *y) const { T Blas<platform::CPUDeviceContext>::DOT(int n, const T *x, const T *y) const {
......
...@@ -76,6 +76,10 @@ extern void* mklml_dso_handle; ...@@ -76,6 +76,10 @@ extern void* mklml_dso_handle;
__macro(vdMul); \ __macro(vdMul); \
__macro(vsExp); \ __macro(vsExp); \
__macro(vdExp); \ __macro(vdExp); \
__macro(vsSqr); \
__macro(vdSqr); \
__macro(vsPowx); \
__macro(vdPowx); \
__macro(MKL_Set_Num_Threads) __macro(MKL_Set_Num_Threads)
MKLML_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_MKLML_WRAP); MKLML_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_MKLML_WRAP);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册