diff --git a/paddle/fluid/operators/math/blas.h b/paddle/fluid/operators/math/blas.h index 6734df1530893777fca3ccf66b1e8aab40e41cfc..9f3a81f22cc52bef719f472e43f91bc81dfe2af6 100644 --- a/paddle/fluid/operators/math/blas.h +++ b/paddle/fluid/operators/math/blas.h @@ -168,6 +168,9 @@ class Blas { template void SCAL(int n, const T a, T* x) const; + template + T ASUM(int n, T* x, int inc) const; + template void BatchedGEMM(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K, T alpha, const T* A, const T* B, T beta, T* C, @@ -269,6 +272,11 @@ class BlasT : private Blas { Base()->template SCAL(args...); } + template + T ASUM(ARGS... args) const { + return Base()->template ASUM(args...); + } + template void BatchedGEMM(ARGS... args) const { Base()->template BatchedGEMM(args...); diff --git a/paddle/fluid/operators/math/blas_impl.h b/paddle/fluid/operators/math/blas_impl.h index 93bf7c7c88db36807143b136ea800d6e5e49dd43..c84087bb1e4849b27d53e05f046c93f631150f6f 100644 --- a/paddle/fluid/operators/math/blas_impl.h +++ b/paddle/fluid/operators/math/blas_impl.h @@ -84,6 +84,11 @@ struct CBlas { platform::dynload::cblas_sscal(args...); } + template + static float ASUM(ARGS... args) { + return platform::dynload::cblas_sasum(args...); + } + template static void GEMM_BATCH(ARGS... args) { platform::dynload::cblas_sgemm_batch(args...); @@ -174,6 +179,11 @@ struct CBlas { platform::dynload::cblas_dscal(args...); } + template + static double ASUM(ARGS... args) { + return platform::dynload::cblas_dasum(args...); + } + template static void GEMM_BATCH(ARGS... args) { platform::dynload::cblas_dgemm_batch(args...); @@ -268,6 +278,7 @@ struct CBlas { static void VPOW(...) { PADDLE_THROW("float16 VPOW not supported on CPU"); } static void DOT(...) { PADDLE_THROW("float16 DOT not supported on CPU"); }; static void SCAL(...) { PADDLE_THROW("float16 SCAL not supported on CPU"); }; + static void ASUM(...) { PADDLE_THROW("float16 ASUM not supported on CPU"); }; #ifdef PADDLE_WITH_MKLML static void GEMM_BATCH(...) { PADDLE_THROW("float16 GEMM_BATCH not supported on CPU"); @@ -476,6 +487,21 @@ void Blas::SCAL(int n, const T a, T *x) const { #endif } +template <> +template +T Blas::ASUM(int n, T *x, int inc) const { + auto sum = static_cast(0.0); +#ifdef PADDLE_WITH_MKLML + sum = CBlas::ASUM(n, x, inc); +#else + // TODO(jczaja): check if openblas does provide cblas_sasum/cblas_dasum + for (int c = 0; c < n; ++c) { + sum += x[c]; + } +#endif + return sum; +} + template <> template void Blas::GEMV(bool trans_a, int M, int N, T alpha, diff --git a/paddle/fluid/operators/math/softmax_impl.h b/paddle/fluid/operators/math/softmax_impl.h index 0f3e5b20086378da8ef1138a5f5c005b724f7fa2..31ed5196668954bc387423c34a0667622db71373 100644 --- a/paddle/fluid/operators/math/softmax_impl.h +++ b/paddle/fluid/operators/math/softmax_impl.h @@ -100,11 +100,8 @@ class SoftmaxFunctor> { blas.VEXP(num_classes * batch_size, out_data, out_data); for (int n = 0; n < batch_size; ++n) { - entities[n] = out_data[n * num_classes]; - for (int c = 1; c < num_classes; ++c) { - entities[n] += out_data[n * num_classes + c]; - } - blas.SCAL(num_classes, 1.0f / entities[n], &out_data[n * num_classes]); + auto sum = blas.ASUM(num_classes, &out_data[n * num_classes], 1); + blas.SCAL(num_classes, 1.0f / sum, &out_data[n * num_classes]); } } }; diff --git a/paddle/fluid/operators/softmax_op.h b/paddle/fluid/operators/softmax_op.h index 8eb5c7691efe930e9f79ad6a381cb290107d1a14..91829d5761bfdd1f9806af6589a2967fe866fec8 100644 --- a/paddle/fluid/operators/softmax_op.h +++ b/paddle/fluid/operators/softmax_op.h @@ -36,9 +36,7 @@ class SoftmaxKernel : public framework::OpKernel { Tensor Out_2d = framework::ReshapeToMatrix(*Out, rank - 1); #ifdef PADDLE_ON_INFERENCE - math::SoftmaxFunctor< - DeviceContext, T, - std::is_same::value>()( + math::SoftmaxFunctor()( context.template device_context(), &X_2d, &Out_2d); #else math::SoftmaxFunctor()( diff --git a/paddle/fluid/platform/dynload/mklml.h b/paddle/fluid/platform/dynload/mklml.h index 9273e9b1e72f0ad7abd6c20d4a34283fbe24378a..f0a973662360fd9ff35e1006cce937d86f3e563c 100644 --- a/paddle/fluid/platform/dynload/mklml.h +++ b/paddle/fluid/platform/dynload/mklml.h @@ -68,6 +68,8 @@ extern void* mklml_dso_handle; __macro(cblas_dgemm_batch); \ __macro(cblas_sdot); \ __macro(cblas_ddot); \ + __macro(cblas_sasum); \ + __macro(cblas_dasum); \ __macro(cblas_sscal); \ __macro(cblas_dscal); \ __macro(vsAdd); \