提交 0ec1f65c 编写于 作者: T tensor-tang

fix blas dot and add cblas scal

上级 a2203d04
......@@ -156,6 +156,9 @@ class Blas {
template <typename T>
T DOT(int n, const T* x, const T* y) const;
template <typename T>
void SCAL(int n, const T a, const T* x) const;
template <typename T>
void BatchedGEMM(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N,
int K, T alpha, const T* A, const T* B, T beta, T* C,
......
......@@ -78,6 +78,11 @@ struct CBlas<float> {
return platform::dynload::cblas_sdot(args...);
}
template <typename... ARGS>
static void SCAL(ARGS... args) {
platform::dynload::cblas_sscal(args...);
}
template <typename... ARGS>
static void GEMM_BATCH(ARGS... args) {
platform::dynload::cblas_sgemm_batch(args...);
......@@ -148,6 +153,11 @@ struct CBlas<double> {
return platform::dynload::cblas_ddot(args...);
}
template <typename... ARGS>
static void SCAL(ARGS... args) {
platform::dynload::cblas_dscal(args...);
}
template <typename... ARGS>
static void GEMM_BATCH(ARGS... args) {
platform::dynload::cblas_dgemm_batch(args...);
......@@ -221,6 +231,7 @@ struct CBlas<platform::float16> {
}
static void VMUL(...) { PADDLE_THROW("float16 VMUL not supported on CPU"); }
static void DOT(...) { PADDLE_THROW("float16 DOT not supported on CPU"); };
static void SCAL(...) { PADDLE_THROW("float16 SCAL not supported on CPU"); };
#ifdef PADDLE_WITH_MKLML
static void GEMM_BATCH(...) {
PADDLE_THROW("float16 GEMM_BATCH not supported on CPU");
......@@ -367,7 +378,7 @@ template <>
template <typename T>
T Blas<platform::CPUDeviceContext>::DOT(int n, const T *x, const T *y) const {
#ifdef PADDLE_WITH_MKLML
return CBlas<T>::DOT(n, x, y);
return CBlas<T>::DOT(n, x, 1, y, 1);
#else
// try to find if openblas support cblas_dot
T sum = 0;
......@@ -378,6 +389,20 @@ T Blas<platform::CPUDeviceContext>::DOT(int n, const T *x, const T *y) const {
#endif
}
template <>
template <typename T>
void Blas<platform::CPUDeviceContext>::SCAL(int n, const T a,
const T *x) const {
#ifdef PADDLE_WITH_MKLML
CBlas<T>::SCAL(n, a, x, 1);
#else
// try to find if openblas support cblas_scal
for (int i = 0; i < n; ++i) {
x[i] = a * x[i];
}
#endif
}
template <>
template <typename T>
void Blas<platform::CPUDeviceContext>::GEMV(bool trans_a, int M, int N, T alpha,
......
......@@ -68,6 +68,8 @@ extern void* mklml_dso_handle;
__macro(cblas_dgemm_batch); \
__macro(cblas_sdot); \
__macro(cblas_ddot); \
__macro(cblas_sscal); \
__macro(cblas_dscal); \
__macro(vsAdd); \
__macro(vdAdd); \
__macro(vsMul); \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册