From 8bfa1fa9bbdfa68c61fa8c861b1070ee8f72b879 Mon Sep 17 00:00:00 2001
From: Jacek Czaja
Date: Tue, 27 Nov 2018 09:59:07 +0100
Subject: [PATCH] - ASUM MKL integration

---
 paddle/fluid/operators/math/blas.h         |  8 ++++++++
 paddle/fluid/operators/math/blas_impl.h    | 28 ++++++++++++++++++++++++++++
 paddle/fluid/operators/math/softmax_impl.h |  7 ++-----
 paddle/fluid/operators/softmax_op.h        |  4 +---
 paddle/fluid/platform/dynload/mklml.h      |  2 ++
 5 files changed, 41 insertions(+), 8 deletions(-)

diff --git a/paddle/fluid/operators/math/blas.h b/paddle/fluid/operators/math/blas.h
index 6734df153..9f3a81f22 100644
--- a/paddle/fluid/operators/math/blas.h
+++ b/paddle/fluid/operators/math/blas.h
@@ -168,6 +168,9 @@ class Blas {
   template <typename T>
   void SCAL(int n, const T a, T* x) const;
 
+  template <typename T>
+  T ASUM(int n, T* x, int inc) const;
+
   template <typename T>
   void BatchedGEMM(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M,
                    int N, int K, T alpha, const T* A, const T* B, T beta, T* C,
@@ -269,6 +272,11 @@ class BlasT : private Blas<DeviceContext> {
     Base()->template SCAL<T>(args...);
   }
 
+  template <typename... ARGS>
+  T ASUM(ARGS... args) const {
+    return Base()->template ASUM<T>(args...);
+  }
+
   template <typename... ARGS>
   void BatchedGEMM(ARGS... args) const {
     Base()->template BatchedGEMM<T>(args...);
diff --git a/paddle/fluid/operators/math/blas_impl.h b/paddle/fluid/operators/math/blas_impl.h
index 93bf7c7c8..ffdfc69b9 100644
--- a/paddle/fluid/operators/math/blas_impl.h
+++ b/paddle/fluid/operators/math/blas_impl.h
@@ -84,6 +84,11 @@ struct CBlas<float> {
     platform::dynload::cblas_sscal(args...);
   }
 
+  template <typename... ARGS>
+  static float ASUM(ARGS... args) {
+    return platform::dynload::cblas_sasum(args...);
+  }
+
   template <typename... ARGS>
   static void GEMM_BATCH(ARGS... args) {
     platform::dynload::cblas_sgemm_batch(args...);
@@ -174,6 +179,11 @@ struct CBlas<double> {
     platform::dynload::cblas_dscal(args...);
   }
 
+  template <typename... ARGS>
+  static double ASUM(ARGS... args) {
+    return platform::dynload::cblas_dasum(args...);
+  }
+
   template <typename... ARGS>
   static void GEMM_BATCH(ARGS... args) {
     platform::dynload::cblas_dgemm_batch(args...);
@@ -268,6 +278,7 @@ struct CBlas<platform::float16> {
   static void VPOW(...) { PADDLE_THROW("float16 VPOW not supported on CPU"); }
   static void DOT(...) { PADDLE_THROW("float16 DOT not supported on CPU"); };
   static void SCAL(...) { PADDLE_THROW("float16 SCAL not supported on CPU"); };
+  static void ASUM(...) { PADDLE_THROW("float16 ASUM not supported on CPU"); };
 #ifdef PADDLE_WITH_MKLML
   static void GEMM_BATCH(...) {
{ PADDLE_THROW("float16 GEMM_BATCH not supported on CPU"); @@ -476,6 +487,23 @@ void Blas::SCAL(int n, const T a, T *x) const { #endif } + +template <> +template +T Blas::ASUM(int n, T *x, int inc) const { + auto sum = static_cast(0.0); +#ifdef PADDLE_WITH_MKLML + sum = Blas::ASUM(n, x, inc); +#else + //TODO(jczaja): check if openblas does provide cblas_sasum/cblas_dasum + for (int c = 0; c < n; ++c) { + sum += x[c]; + } +#endif + return sum; +} + + template <> template void Blas::GEMV(bool trans_a, int M, int N, T alpha, diff --git a/paddle/fluid/operators/math/softmax_impl.h b/paddle/fluid/operators/math/softmax_impl.h index 0f3e5b200..31ed51966 100644 --- a/paddle/fluid/operators/math/softmax_impl.h +++ b/paddle/fluid/operators/math/softmax_impl.h @@ -100,11 +100,8 @@ class SoftmaxFunctor> { blas.VEXP(num_classes * batch_size, out_data, out_data); for (int n = 0; n < batch_size; ++n) { - entities[n] = out_data[n * num_classes]; - for (int c = 1; c < num_classes; ++c) { - entities[n] += out_data[n * num_classes + c]; - } - blas.SCAL(num_classes, 1.0f / entities[n], &out_data[n * num_classes]); + auto sum = blas.ASUM(num_classes, &out_data[n * num_classes], 1); + blas.SCAL(num_classes, 1.0f / sum, &out_data[n * num_classes]); } } }; diff --git a/paddle/fluid/operators/softmax_op.h b/paddle/fluid/operators/softmax_op.h index 8eb5c7691..91829d576 100644 --- a/paddle/fluid/operators/softmax_op.h +++ b/paddle/fluid/operators/softmax_op.h @@ -36,9 +36,7 @@ class SoftmaxKernel : public framework::OpKernel { Tensor Out_2d = framework::ReshapeToMatrix(*Out, rank - 1); #ifdef PADDLE_ON_INFERENCE - math::SoftmaxFunctor< - DeviceContext, T, - std::is_same::value>()( + math::SoftmaxFunctor()( context.template device_context(), &X_2d, &Out_2d); #else math::SoftmaxFunctor()( diff --git a/paddle/fluid/platform/dynload/mklml.h b/paddle/fluid/platform/dynload/mklml.h index 9273e9b1e..f0a973662 100644 --- a/paddle/fluid/platform/dynload/mklml.h +++ b/paddle/fluid/platform/dynload/mklml.h @@ -68,6 +68,8 @@ extern void* mklml_dso_handle; __macro(cblas_dgemm_batch); \ __macro(cblas_sdot); \ __macro(cblas_ddot); \ + __macro(cblas_sasum); \ + __macro(cblas_dasum); \ __macro(cblas_sscal); \ __macro(cblas_dscal); \ __macro(vsAdd); \ -- GitLab