From 0600b370ea714f7793a15f6756ec376161ad3d5c Mon Sep 17 00:00:00 2001
From: tensor-tang
Date: Thu, 23 May 2019 19:04:02 +0800
Subject: [PATCH] [CPU] refine softmax op fwd on CPU (#17522)

* refine softmax fwd

test=develop

* fix compile issue with gpu

test=develop

* add value clip to avoid exp
---
 paddle/fluid/operators/math/cpu_vec.h       | 83 ++++++++++++++++++++-
 paddle/fluid/operators/math/cpu_vec_test.cc | 61 ++++++++++++++-
 paddle/fluid/operators/math/softmax_impl.h  | 61 +++++++++++++--
 3 files changed, 194 insertions(+), 11 deletions(-)

diff --git a/paddle/fluid/operators/math/cpu_vec.h b/paddle/fluid/operators/math/cpu_vec.h
index 57726956cfb..65b601fc678 100644
--- a/paddle/fluid/operators/math/cpu_vec.h
+++ b/paddle/fluid/operators/math/cpu_vec.h
@@ -54,7 +54,14 @@ inline void vec_scal(const int n, const T a, T* x) {
 #ifdef PADDLE_WITH_MKLML
 template <>
 inline void vec_exp<float>(const int n, const float* x, float* y) {
-  platform::dynload::vsExp(n, x, y);
+  constexpr int small_enough = 128;
+  if (n < small_enough) {
+    for (int i = 0; i < n; ++i) {
+      y[i] = std::exp(x[i]);
+    }
+  } else {
+    platform::dynload::vsExp(n, x, y);
+  }
 }
 
 template <>
@@ -128,6 +135,47 @@ inline void vec_scal<float, platform::avx512f>(const int n, const float a,
   vec_scal<float, platform::avx2>(n, a, x, y);
 }
 
+template <typename T, platform::cpu_isa_t isa = platform::isa_any>
+inline void vec_sum(const size_t n, const T* x, T* s) {
+  s[0] = x[0];
+  for (size_t i = 1; i < n; ++i) {
+    s[0] += x[i];
+  }
+}
+
+template <>
+inline void vec_sum<float, platform::avx>(const size_t n, const float* x,
+                                          float* s) {
+#ifdef __AVX__
+  constexpr unsigned int block = YMM_FLOAT_BLOCK;
+  if (n < block) {
+    vec_sum<float, platform::isa_any>(n, x, s);
+    return;
+  }
+
+  unsigned int i, end;
+  i = end = 0;
+  s[0] = 0.f;
+
+  end = n & ~(block - 1);
+  __m256 tmp = _mm256_setzero_ps();
+  for (i = 0; i < end; i += block) {
+    tmp = _mm256_add_ps(tmp, _mm256_load_ps(x + i));
+  }
+
+  __m256 hsum = _mm256_hadd_ps(tmp, tmp);
+  hsum = _mm256_add_ps(hsum, _mm256_permute2f128_ps(hsum, hsum, 0x1));
+  _mm_store_ss(s, _mm_hadd_ps(_mm256_castps256_ps128(hsum),
+                              _mm256_castps256_ps128(hsum)));
+
+  for (; i < n; i++) {
+    s[0] += x[i];
+  }
+#else
+  vec_sum<float, platform::isa_any>(n, x, s);
+#endif
+}
+
 template <typename T, platform::cpu_isa_t isa = platform::isa_any>
 inline void vec_bias_sub(const int n, const T a, const T* x, T* y) {
   for (int i = 0; i < n; ++i) {
@@ -242,6 +290,39 @@ inline void vec_cross<float, platform::avx512f>(const int n, const float* x,
   vec_cross<float, platform::avx>(n, x, y, z, out);
 }
 
+template <typename T, platform::cpu_isa_t isa = platform::isa_any>
+inline void vec_clip(const size_t n, const T a, const T* x, T* y) {
+  for (size_t i = 0; i < n; ++i) {
+    y[i] = x[i] < a ? a : x[i];
+  }
+}
+
+template <>
+inline void vec_clip<float, platform::avx>(const size_t n, const float a,
+                                           const float* x, float* y) {
+#ifdef __AVX__
+  constexpr unsigned int block = YMM_FLOAT_BLOCK;
+  if (n < block) {
+    vec_clip<float, platform::isa_any>(n, a, x, y);
+    return;
+  }
+
+  unsigned int i = 0, end = 0;
+  end = n & ~(block - 1);
+  __m256 threshold = _mm256_set1_ps(a);
+
+  for (i = 0; i < end; i += block) {
+    _mm256_storeu_ps(y + i, _mm256_max_ps(_mm256_loadu_ps(x + i), threshold));
+  }
+
+  for (; i < n; i++) {
+    y[i] = x[i] < a ? a : x[i];
+  }
+#else
+  vec_clip<float, platform::isa_any>(n, a, x, y);
+#endif
+}
+
 template <typename T, platform::cpu_isa_t isa = platform::isa_any>
 inline void vec_add_bias(const int n, const T a, const T* x, T* y) {
   for (int i = 0; i < n; ++i) {
diff --git a/paddle/fluid/operators/math/cpu_vec_test.cc b/paddle/fluid/operators/math/cpu_vec_test.cc
index 28eb9cadc9d..04932a07094 100644
--- a/paddle/fluid/operators/math/cpu_vec_test.cc
+++ b/paddle/fluid/operators/math/cpu_vec_test.cc
@@ -65,12 +65,11 @@ void ref_relu(const int n, const T* x, T* y) {
 }
 
 template <typename T>
-void RandomVec(const int n, T* a) {
+void RandomVec(const int n, T* a, const T lower = static_cast<T>(-20.f),
+               const T upper = static_cast<T>(20.f)) {
   static unsigned int seed = 100;
   std::mt19937 rng(seed++);
   std::uniform_real_distribution<double> uniform_dist(0, 1);
-  const T lower = static_cast<T>(-20.f);
-  const T upper = static_cast<T>(20.f);
   for (int i = 0; i < n; ++i) {
     a[i] = static_cast<T>(uniform_dist(rng) * (upper - lower) + lower);
   }
@@ -144,6 +143,62 @@ TEST(CpuVecTest, relu) {
   TestAndBench<double>(30, vec_relu<double>, ref_relu<double>);
 }
 
+template <typename T>
+void compare_sum(size_t n, std::function<void(const size_t, const T*, T*)> tgt,
+                 std::function<void(const size_t, const T*, T*)> ref) {
+  std::vector<T> x(n);
+  T ytgt_data, yref_data;
+  RandomVec<T>(n, x.data(), static_cast<T>(-2), static_cast<T>(2));
+
+  const T* x_data = x.data();
+  tgt(n, x_data, &ytgt_data);
+  ref(n, x_data, &yref_data);
+  EXPECT_NEAR(ytgt_data, yref_data, 1e-3);
+}
+
+TEST(CpuVecTest, vec_sum) {
+  namespace platform = paddle::platform;
+  using namespace paddle::operators::math;  // NOLINT
+  for (size_t sz : {1, 2, 15, 16, 30, 32, 128, 200, 512}) {
+    compare_sum<float>(sz, vec_sum<float>, vec_sum<float, platform::avx>);
+    compare_sum<float>(sz, vec_sum<float, platform::avx>,
+                       vec_sum<float, platform::isa_any>);
+  }
+  compare_sum<double>(30U, vec_sum<double>, vec_sum<double, platform::avx>);
+}
+
+template <typename T>
+void compare_clip(
+    size_t n, T threshold,
+    std::function<void(const size_t, const T, const T*, T*)> tgt,
+    std::function<void(const size_t, const T, const T*, T*)> ref) {
+  std::vector<T> x(n);
+  std::vector<T> ytgt(n), yref(n);
+  RandomVec<T>(n, x.data(), static_cast<T>(-2), static_cast<T>(2));
+
+  const T* x_data = x.data();
+  T* yref_data = yref.data();
+  T* ytgt_data = ytgt.data();
+  tgt(n, threshold, x_data, ytgt_data);
+  ref(n, threshold, x_data, yref_data);
+  for (int i = 0; i < n; ++i) {
+    EXPECT_NEAR(ytgt_data[i], yref_data[i], 1e-3);
+  }
+}
+
+TEST(CpuVecTest, vec_clip) {
+  namespace platform = paddle::platform;
+  using namespace paddle::operators::math;  // NOLINT
+  for (size_t sz : {1, 2, 15, 16, 30, 32, 128, 200, 512}) {
+    compare_clip<float>(sz, -4.f, vec_clip<float>,
+                        vec_clip<float, platform::avx>);
+    compare_clip<float>(sz, -1.1f, vec_clip<float, platform::avx>,
+                        vec_clip<float, platform::isa_any>);
+  }
+  compare_clip<double>(30U, 1.0, vec_clip<double>,
+                       vec_clip<double, platform::avx>);
+}
+
 template <typename T>
 void TestInplace(const int n, std::function<void(const int, const T*, T*)> tgt,
                  std::function<void(const int, const T*, T*)> ref) {
diff --git a/paddle/fluid/operators/math/softmax_impl.h b/paddle/fluid/operators/math/softmax_impl.h
index 6f6f33345f5..30790c9c1a9 100644
--- a/paddle/fluid/operators/math/softmax_impl.h
+++ b/paddle/fluid/operators/math/softmax_impl.h
@@ -17,6 +17,8 @@ limitations under the License. */
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/tensor.h"
 #include "paddle/fluid/operators/jit/kernels.h"
+#include "paddle/fluid/operators/math/cpu_vec.h"
+#include "paddle/fluid/platform/cpu_info.h"
 
 namespace paddle {
 namespace operators {
@@ -34,16 +36,15 @@ struct ValueClip {
   }
 };
 
-template <typename DeviceContext, typename T, bool is_test, typename Enable>
-void SoftmaxFunctor<DeviceContext, T, is_test, Enable>::operator()(
-    const DeviceContext& context, const int axis_dim,
-    const framework::Tensor* X, framework::Tensor* Y) {
+template <typename DeviceContext, typename T>
+void SoftmaxEigen(const DeviceContext& context, const int axis_dim,
+                  const framework::Tensor* X, framework::Tensor* Y) {
+  constexpr int kBatchDim = 0;
+  constexpr int kClassDim = 1;
+
   auto logits = EigenMatrix<T>::From(*X);
   auto softmax = EigenMatrix<T>::From(*Y);
 
-  const int kBatchDim = 0;
-  const int kClassDim = 1;
-
   const int batch_size = logits.dimension(kBatchDim);
   const int num_classes = logits.dimension(kClassDim);
   const int num_remain = num_classes / axis_dim;
@@ -70,12 +71,58 @@ void SoftmaxFunctor<DeviceContext, T, is_test, Enable>::operator()(
           .broadcast(one_axis));
 }
 
+template <typename DeviceContext, typename T, bool is_test, typename Enable>
+void SoftmaxFunctor<DeviceContext, T, is_test, Enable>::operator()(
+    const DeviceContext& context, const int axis_dim,
+    const framework::Tensor* X, framework::Tensor* Y) {
+  SoftmaxEigen<DeviceContext, T>(context, axis_dim, X, Y);
+}
+
 template <class DeviceContext>
 using enable_if_CPU = typename std::enable_if<
     std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type;
 
+template <typename DeviceContext, typename T, bool is_test>
+class SoftmaxFunctor<DeviceContext, T, is_test, enable_if_CPU<DeviceContext>> {
+ public:
+  void operator()(const DeviceContext& context, const int axis_dim,
+                  const framework::Tensor* X, framework::Tensor* Y) {
+    auto in_dims = X->dims();
+    constexpr int kBatchDim = 0;
+    constexpr int kClassDim = 1;
+
+    const int num_classes = in_dims[kClassDim];
+    const int batch_size = in_dims[kBatchDim];
+    const int num_remain = num_classes / axis_dim;
+
+    if (num_remain == 1 && platform::MayIUse(platform::avx)) {
+      const T* in_data = X->data<T>();
+      T* out_data = Y->data<T>();
+      for (int bs = 0; bs < batch_size; ++bs) {
+        T max_val = *std::max_element(in_data, in_data + num_classes);
+        max_val *= static_cast<T>(-1);
+        vec_add_bias<T, platform::avx>(num_classes, max_val, in_data, out_data);
+        vec_clip<T, platform::avx>(num_classes, static_cast<T>(-64), out_data,
+                                   out_data);
+        vec_exp<T>(num_classes, out_data, out_data);
+
+        T sum = 0;
+        vec_sum<T, platform::avx>(num_classes, out_data, &sum);
+        sum = static_cast<T>(1) / sum;
+        vec_scal<T, platform::avx>(num_classes, sum, out_data, out_data);
+
+        in_data += num_classes;
+        out_data += num_classes;
+      }
+    } else {
+      SoftmaxEigen<DeviceContext, T>(context, axis_dim, X, Y);
+    }
+  }
+};
+
 template <typename DeviceContext>
 class SoftmaxFunctor<DeviceContext, float, true, enable_if_CPU<DeviceContext>> {
+ public:
   void operator()(const DeviceContext& context, const int axis_dim,
                   const framework::Tensor* X, framework::Tensor* Y) {
     auto in_dims = X->dims();
-- 
GitLab
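
Editor's note (placed below the "-- " trailer, so `git am` still applies the patch
unchanged): the new CPU fast path computes softmax per row as
exp(clip(x - max(x), -64)) normalized by its sum. Subtracting the row max keeps
exp's argument non-positive, and clipping at -64 keeps the argument far enough
from the underflow threshold of float exp (around -87) that vsExp/std::exp never
produces denormals or zeros that would poison the normalization. The following
is a minimal standalone scalar sketch of that sequence; the function name and
driver are illustrative, not part of the patch, and plain loops stand in for the
vectorized vec_add_bias/vec_clip/vec_exp/vec_sum/vec_scal kernels.

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Scalar mirror of the patched fast path: shift by the row max,
// clip the shifted logits at -64, exponentiate, then normalize.
void softmax_row_sketch(const float* x, float* y, int n) {
  const float max_val = *std::max_element(x, x + n);
  float sum = 0.f;
  for (int i = 0; i < n; ++i) {
    float shifted = x[i] - max_val;        // vec_add_bias with -max
    if (shifted < -64.f) shifted = -64.f;  // vec_clip at -64
    y[i] = std::exp(shifted);              // vec_exp
    sum += y[i];                           // vec_sum
  }
  const float scale = 1.f / sum;           // vec_scal by 1/sum
  for (int i = 0; i < n; ++i) y[i] *= scale;
}

int main() {
  // The extreme logit exercises the clip: exp(-64) is tiny (~1.6e-28) but
  // still a normal float, so the row stays free of NaNs and denormals.
  std::vector<float> x = {1.f, 2.f, 3.f, -1000.f};
  std::vector<float> y(x.size());
  softmax_row_sketch(x.data(), y.data(), static_cast<int>(x.size()));
  for (float v : y) std::printf("%g ", v);
  std::printf("\n");
  return 0;
}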