Commit 6644ce79 authored by: T tensor-tang

add mklml vmul

Parent ff92b6ba
@@ -134,6 +134,9 @@ class Blas {
   template <typename T>
   void VADD(int n, const T* x, const T* y, T* z) const;

+  template <typename T>
+  void VMUL(int n, const T* x, const T* y, T* z) const;
+
   template <typename T>
   void VCOPY(int n, const T* x, T* y) const;

@@ -202,6 +205,11 @@ class BlasT : private Blas<DeviceContext> {
     Base()->template VADD<T>(args...);
   }

+  template <typename... ARGS>
+  void VMUL(ARGS... args) const {
+    Base()->template VMUL<T>(args...);
+  }
+
   template <typename... ARGS>
   void VCOPY(ARGS... args) const {
     Base()->template VCOPY<T>(args...);
...
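Note: this first hunk is the interface change (presumably paddle/fluid/operators/math/blas.h): the type-erased Blas wrapper gains a VMUL declaration, and the typed facade BlasT forwards to it, mirroring the existing VADD/VCOPY pattern. Below is a minimal usage sketch of how a CPU kernel could reach the new routine; the GetBlas helper and include path follow the Paddle codebase of this era and should be treated as assumptions, not part of this commit.

    // Hedged usage sketch: element-wise multiply via the new BlasT::VMUL.
    #include "paddle/fluid/operators/math/blas.h"

    namespace math = paddle::operators::math;

    template <typename T>
    void ElementwiseMul(const paddle::platform::CPUDeviceContext& ctx, int n,
                        const T* x, const T* y, T* z) {
      // GetBlas returns the typed facade BlasT<CPUDeviceContext, T>.
      auto blas = math::GetBlas<paddle::platform::CPUDeviceContext, T>(ctx);
      blas.VMUL(n, x, y, z);  // computes z[i] = x[i] * y[i] for i in [0, n)
    }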
@@ -82,6 +82,11 @@ struct CBlas<float> {
   static void VADD(ARGS... args) {
     platform::dynload::vsAdd(args...);
   }
+
+  template <typename... ARGS>
+  static void VMUL(ARGS... args) {
+    platform::dynload::vsMul(args...);
+  }
 };

 template <>
@@ -142,6 +147,11 @@ struct CBlas<double> {
   static void VADD(ARGS... args) {
     platform::dynload::vdAdd(args...);
   }
+
+  template <typename... ARGS>
+  static void VMUL(ARGS... args) {
+    platform::dynload::vdMul(args...);
+  }
 };
 #else
@@ -199,6 +209,7 @@ struct CBlas<platform::float16> {
   static void SMM_GEMM(...) {
     PADDLE_THROW("float16 SMM_GEMM not supported on CPU");
   }
+  static void VMUL(...) { PADDLE_THROW("float16 VMUL not supported on CPU"); }
 #ifdef PADDLE_WITH_MKLML
   static void GEMM_BATCH(...) {
     PADDLE_THROW("float16 GEMM_BATCH not supported on CPU");
@@ -374,6 +385,20 @@ void Blas<platform::CPUDeviceContext>::VADD(int n, const T *x, const T *y,
 #endif
 }

+template <>
+template <typename T>
+void Blas<platform::CPUDeviceContext>::VMUL(int n, const T *x, const T *y,
+                                            T *z) const {
+#ifdef PADDLE_WITH_MKLML
+  CBlas<T>::VMUL(n, x, y, z);
+#else
+  // TODO: check whether OpenBLAS offers a vmul; fall back to a scalar loop.
+  for (int i = 0; i < n; ++i) {
+    z[i] = x[i] * y[i];
+  }
+#endif
+}
+
 template <>
 template <typename T>
 void Blas<platform::CPUDeviceContext>::GEMV(bool trans_a, int M, int N, T alpha,
...
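Note: the implementation dispatches at compile time. With PADDLE_WITH_MKLML defined, CBlas<float>::VMUL and CBlas<double>::VMUL resolve to MKL's VML routines vsMul/vdMul, which compute the element-wise product; float16 throws, since neither backend supports it; and without MKLML the code falls back to the scalar loop. A self-contained sketch of the same dispatch shape, with the MKL header name given as an assumption:

    // Sketch of the VMUL dispatch this hunk implements: prefer MKL's
    // vectorized vsMul, otherwise multiply element by element.
    #include <cstdio>
    #ifdef PADDLE_WITH_MKLML
    #include <mkl.h>  // assumption: provides the VML prototype for vsMul
    #endif

    void VMulF32(int n, const float* x, const float* y, float* z) {
    #ifdef PADDLE_WITH_MKLML
      vsMul(n, x, y, z);  // VML: z[i] = x[i] * y[i]
    #else
      for (int i = 0; i < n; ++i) z[i] = x[i] * y[i];
    #endif
    }

    int main() {
      const float x[3] = {1.f, 2.f, 3.f}, y[3] = {4.f, 5.f, 6.f};
      float z[3];
      VMulF32(3, x, y, z);
      std::printf("%g %g %g\n", z[0], z[1], z[2]);  // prints: 4 10 18
    }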
@@ -49,25 +49,27 @@ extern void* mklml_dso_handle;
 #define MKLML_ROUTINE_EACH(__macro) \
   __macro(cblas_sgemm);             \
-  __macro(cblas_saxpy);             \
-  __macro(cblas_scopy);             \
-  __macro(cblas_sgemv);             \
-  __macro(cblas_sgemm_batch);       \
   __macro(cblas_dgemm);             \
+  __macro(cblas_saxpy);             \
   __macro(cblas_daxpy);             \
+  __macro(cblas_scopy);             \
   __macro(cblas_dcopy);             \
+  __macro(cblas_sgemv);             \
   __macro(cblas_dgemv);             \
-  __macro(cblas_dgemm_batch);       \
-  __macro(vsAdd);                   \
-  __macro(vdAdd);                   \
   __macro(cblas_sgemm_alloc);       \
-  __macro(cblas_sgemm_pack);        \
-  __macro(cblas_sgemm_compute);     \
-  __macro(cblas_sgemm_free);        \
   __macro(cblas_dgemm_alloc);       \
+  __macro(cblas_sgemm_pack);        \
   __macro(cblas_dgemm_pack);        \
+  __macro(cblas_sgemm_compute);     \
   __macro(cblas_dgemm_compute);     \
+  __macro(cblas_sgemm_free);        \
   __macro(cblas_dgemm_free);        \
+  __macro(cblas_sgemm_batch);       \
+  __macro(cblas_dgemm_batch);       \
+  __macro(vsAdd);                   \
+  __macro(vdAdd);                   \
+  __macro(vsMul);                   \
+  __macro(vdMul);                   \
   __macro(MKL_Set_Num_Threads)

 MKLML_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_MKLML_WRAP);
...
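Note: the mklml.h hunk does two things: it registers the two new VML symbols (vsMul, vdMul) for dynamic loading, and it reorders the macro list so single- and double-precision variants of each routine sit side by side. Each name passed to DECLARE_DYNAMIC_LOAD_MKLML_WRAP becomes a wrapper that resolves the symbol lazily from the MKLML shared library. A hedged sketch of that lazy-resolution technique; the real Paddle macro differs in its error handling and handle management, and this only illustrates the pattern:

    // Sketch of dlsym-based lazy symbol resolution, as used for the
    // dynload wrappers: resolve once, then forward every call.
    #include <dlfcn.h>
    #include <mutex>

    extern void* mklml_dso_handle;  // assumption: opened elsewhere via dlopen

    template <typename... Args>
    void CallVsMul(Args... args) {
      using FuncT = void (*)(Args...);
      static std::once_flag flag;
      static void* sym = nullptr;
      // Resolve "vsMul" from the MKLML shared object exactly once.
      std::call_once(flag, [] { sym = dlsym(mklml_dso_handle, "vsMul"); });
      reinterpret_cast<FuncT>(sym)(args...);
    }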