“86fd748808dee2448bf368f3b1389f91ec6e9d29”上不存在“develop/doc/howto/usage/cmd_parameter/index_en.html”
提交 8bfa1fa9 编写于 作者: J Jacek Czaja

- ASUM MKL integration

上级 05b7ee7e
...@@ -168,6 +168,9 @@ class Blas { ...@@ -168,6 +168,9 @@ class Blas {
template <typename T> template <typename T>
void SCAL(int n, const T a, T* x) const; void SCAL(int n, const T a, T* x) const;
template <typename T>
T ASUM(int n, T* x, int inc) const;
template <typename T> template <typename T>
void BatchedGEMM(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, void BatchedGEMM(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N,
int K, T alpha, const T* A, const T* B, T beta, T* C, int K, T alpha, const T* A, const T* B, T beta, T* C,
...@@ -269,6 +272,11 @@ class BlasT : private Blas<DeviceContext> { ...@@ -269,6 +272,11 @@ class BlasT : private Blas<DeviceContext> {
Base()->template SCAL<T>(args...); Base()->template SCAL<T>(args...);
} }
template <typename... ARGS>
T ASUM(ARGS... args) const {
return Base()->template ASUM<T>(args...);
}
template <typename... ARGS> template <typename... ARGS>
void BatchedGEMM(ARGS... args) const { void BatchedGEMM(ARGS... args) const {
Base()->template BatchedGEMM<T>(args...); Base()->template BatchedGEMM<T>(args...);
......
...@@ -84,6 +84,11 @@ struct CBlas<float> { ...@@ -84,6 +84,11 @@ struct CBlas<float> {
platform::dynload::cblas_sscal(args...); platform::dynload::cblas_sscal(args...);
} }
template <typename... ARGS>
static float ASUM(ARGS... args) {
return platform::dynload::cblas_sasum(args...);
}
template <typename... ARGS> template <typename... ARGS>
static void GEMM_BATCH(ARGS... args) { static void GEMM_BATCH(ARGS... args) {
platform::dynload::cblas_sgemm_batch(args...); platform::dynload::cblas_sgemm_batch(args...);
...@@ -174,6 +179,11 @@ struct CBlas<double> { ...@@ -174,6 +179,11 @@ struct CBlas<double> {
platform::dynload::cblas_dscal(args...); platform::dynload::cblas_dscal(args...);
} }
template <typename... ARGS>
static double ASUM(ARGS... args) {
return platform::dynload::cblas_dasum(args...);
}
template <typename... ARGS> template <typename... ARGS>
static void GEMM_BATCH(ARGS... args) { static void GEMM_BATCH(ARGS... args) {
platform::dynload::cblas_dgemm_batch(args...); platform::dynload::cblas_dgemm_batch(args...);
...@@ -268,6 +278,7 @@ struct CBlas<platform::float16> { ...@@ -268,6 +278,7 @@ struct CBlas<platform::float16> {
static void VPOW(...) { PADDLE_THROW("float16 VPOW not supported on CPU"); } static void VPOW(...) { PADDLE_THROW("float16 VPOW not supported on CPU"); }
static void DOT(...) { PADDLE_THROW("float16 DOT not supported on CPU"); }; static void DOT(...) { PADDLE_THROW("float16 DOT not supported on CPU"); };
static void SCAL(...) { PADDLE_THROW("float16 SCAL not supported on CPU"); }; static void SCAL(...) { PADDLE_THROW("float16 SCAL not supported on CPU"); };
static void ASUM(...) { PADDLE_THROW("float16 ASUM not supported on CPU"); };
#ifdef PADDLE_WITH_MKLML #ifdef PADDLE_WITH_MKLML
static void GEMM_BATCH(...) { static void GEMM_BATCH(...) {
PADDLE_THROW("float16 GEMM_BATCH not supported on CPU"); PADDLE_THROW("float16 GEMM_BATCH not supported on CPU");
...@@ -476,6 +487,23 @@ void Blas<platform::CPUDeviceContext>::SCAL(int n, const T a, T *x) const { ...@@ -476,6 +487,23 @@ void Blas<platform::CPUDeviceContext>::SCAL(int n, const T a, T *x) const {
#endif #endif
} }
template <>
template <typename T>
T Blas<platform::CPUDeviceContext>::ASUM(int n, T *x, int inc) const {
auto sum = static_cast<T>(0.0);
#ifdef PADDLE_WITH_MKLML
sum = Blas<T>::ASUM(n, x, inc);
#else
//TODO(jczaja): check if openblas does provide cblas_sasum/cblas_dasum
for (int c = 0; c < n; ++c) {
sum += x[c];
}
#endif
return sum;
}
template <> template <>
template <typename T> template <typename T>
void Blas<platform::CPUDeviceContext>::GEMV(bool trans_a, int M, int N, T alpha, void Blas<platform::CPUDeviceContext>::GEMV(bool trans_a, int M, int N, T alpha,
......
...@@ -100,11 +100,8 @@ class SoftmaxFunctor<DeviceContext, float, true, enable_if_CPU<DeviceContext>> { ...@@ -100,11 +100,8 @@ class SoftmaxFunctor<DeviceContext, float, true, enable_if_CPU<DeviceContext>> {
blas.VEXP(num_classes * batch_size, out_data, out_data); blas.VEXP(num_classes * batch_size, out_data, out_data);
for (int n = 0; n < batch_size; ++n) { for (int n = 0; n < batch_size; ++n) {
entities[n] = out_data[n * num_classes]; auto sum = blas.ASUM(num_classes, &out_data[n * num_classes], 1);
for (int c = 1; c < num_classes; ++c) { blas.SCAL(num_classes, 1.0f / sum, &out_data[n * num_classes]);
entities[n] += out_data[n * num_classes + c];
}
blas.SCAL(num_classes, 1.0f / entities[n], &out_data[n * num_classes]);
} }
} }
}; };
......
...@@ -36,9 +36,7 @@ class SoftmaxKernel : public framework::OpKernel<T> { ...@@ -36,9 +36,7 @@ class SoftmaxKernel : public framework::OpKernel<T> {
Tensor Out_2d = framework::ReshapeToMatrix(*Out, rank - 1); Tensor Out_2d = framework::ReshapeToMatrix(*Out, rank - 1);
#ifdef PADDLE_ON_INFERENCE #ifdef PADDLE_ON_INFERENCE
math::SoftmaxFunctor< math::SoftmaxFunctor<DeviceContext, T, true>()(
DeviceContext, T,
std::is_same<DeviceContext, platform::CPUDeviceContext>::value>()(
context.template device_context<DeviceContext>(), &X_2d, &Out_2d); context.template device_context<DeviceContext>(), &X_2d, &Out_2d);
#else #else
math::SoftmaxFunctor<DeviceContext, T, false>()( math::SoftmaxFunctor<DeviceContext, T, false>()(
......
...@@ -68,6 +68,8 @@ extern void* mklml_dso_handle; ...@@ -68,6 +68,8 @@ extern void* mklml_dso_handle;
__macro(cblas_dgemm_batch); \ __macro(cblas_dgemm_batch); \
__macro(cblas_sdot); \ __macro(cblas_sdot); \
__macro(cblas_ddot); \ __macro(cblas_ddot); \
__macro(cblas_sasum); \
__macro(cblas_dasum); \
__macro(cblas_sscal); \ __macro(cblas_sscal); \
__macro(cblas_dscal); \ __macro(cblas_dscal); \
__macro(vsAdd); \ __macro(vsAdd); \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册