From 7ae461eb13bdb3d1c116b650110c9968686fc20c Mon Sep 17 00:00:00 2001
From: tensor-tang
Date: Fri, 24 May 2019 09:18:58 +0800
Subject: [PATCH] [CPU] refine cpu softmax bwd (#17534)

* refine softmax fwd

test=develop

* refine cpu softmax bwd

test=develop

* fix batch size

test=develop

* fix compile issue with gpu

test=develop

* add value clip
---
 paddle/fluid/operators/math/cpu_vec.h       | 73 +++++++++++++++++++++
 paddle/fluid/operators/math/cpu_vec_test.cc | 64 ++++++++++++++++++
 paddle/fluid/operators/math/softmax.h       |  2 +-
 paddle/fluid/operators/math/softmax_impl.h  | 54 +++++++++++++--
 4 files changed, 186 insertions(+), 7 deletions(-)

diff --git a/paddle/fluid/operators/math/cpu_vec.h b/paddle/fluid/operators/math/cpu_vec.h
index 65b601fc678..4406a558718 100644
--- a/paddle/fluid/operators/math/cpu_vec.h
+++ b/paddle/fluid/operators/math/cpu_vec.h
@@ -176,6 +176,79 @@ inline void vec_sum<float, platform::avx>(const size_t n, const float* x,
 #endif
 }
 
+template <typename T, platform::cpu_isa_t isa = platform::isa_any>
+inline void vec_mul(const size_t n, const T* x, const T* y, T* z) {
+  for (size_t i = 0; i < n; ++i) {
+    z[i] = x[i] * y[i];
+  }
+}
+
+template <>
+inline void vec_mul<float, platform::avx>(const size_t n, const float* x,
+                                          const float* y, float* z) {
+#ifdef __AVX__
+  constexpr unsigned int block = YMM_FLOAT_BLOCK;
+  if (n < block) {
+    vec_mul<float, platform::isa_any>(n, x, y, z);
+    return;
+  }
+
+  unsigned int i = 0, end = 0;
+  end = n & ~(block - 1);
+  for (i = 0; i < end; i += block) {
+    _mm256_storeu_ps(
+        z + i, _mm256_mul_ps(_mm256_loadu_ps(x + i), _mm256_loadu_ps(y + i)));
+  }
+
+  for (; i < n; i++) {
+    z[i] = x[i] * y[i];
+  }
+#else
+  vec_mul<float, platform::isa_any>(n, x, y, z);
+#endif
+}
+
+template <typename T, platform::cpu_isa_t isa = platform::isa_any>
+inline void vec_mul_reduce(const size_t n, const T* x, const T* y, T* z) {
+  z[0] = x[0] * y[0];
+  for (size_t i = 1; i < n; ++i) {
+    z[0] += x[i] * y[i];
+  }
+}
+
+template <>
+inline void vec_mul_reduce<float, platform::avx>(const size_t n, const float* x,
+                                                 const float* y, float* z) {
+#ifdef __AVX__
+  constexpr unsigned int block = YMM_FLOAT_BLOCK;
+  if (n < block) {
+    vec_mul_reduce<float, platform::isa_any>(n, x, y, z);
+    return;
+  }
+
+  unsigned int i = 0, end = 0;
+  z[0] = 0.f;
+
+  end = n & ~(block - 1);
+  __m256 tmp = _mm256_setzero_ps();
+  for (i = 0; i < end; i += block) {
+    tmp = _mm256_add_ps(
+        tmp, _mm256_mul_ps(_mm256_loadu_ps(x + i), _mm256_loadu_ps(y + i)));
+  }
+
+  __m256 hsum = _mm256_hadd_ps(tmp, tmp);
+  hsum = _mm256_add_ps(hsum, _mm256_permute2f128_ps(hsum, hsum, 0x1));
+  _mm_store_ss(z, _mm_hadd_ps(_mm256_castps256_ps128(hsum),
+                              _mm256_castps256_ps128(hsum)));
+
+  for (; i < n; i++) {
+    z[0] += x[i] * y[i];
+  }
+#else
+  vec_mul_reduce<float, platform::isa_any>(n, x, y, z);
+#endif
+}
+
 template <typename T>
 inline void vec_bias_sub(const int n, const T a, const T* x, T* y) {
   for (int i = 0; i < n; ++i) {
diff --git a/paddle/fluid/operators/math/cpu_vec_test.cc b/paddle/fluid/operators/math/cpu_vec_test.cc
index 04932a07094..f2f80f836fd 100644
--- a/paddle/fluid/operators/math/cpu_vec_test.cc
+++ b/paddle/fluid/operators/math/cpu_vec_test.cc
@@ -199,6 +199,70 @@ TEST(CpuVecTest, vec_clip) {
                       vec_clip<float, platform::isa_any>);
 }
 
+template <typename T>
+void compare_mul(
+    size_t n, std::function<void(const size_t, const T*, const T*, T*)> tgt,
+    std::function<void(const size_t, const T*, const T*, T*)> ref) {
+  std::vector<T> x(n), y(n);
+  std::vector<T> ztgt(n), zref(n);
+
+  RandomVec<T>(n, x.data(), static_cast<T>(-2), static_cast<T>(2));
+  RandomVec<T>(n, y.data(), static_cast<T>(-2), static_cast<T>(2));
+
+  const T* x_data = x.data();
+  const T* y_data = y.data();
+  T* ztgt_data = ztgt.data();
+  T* zref_data = zref.data();
+
+  tgt(n, x_data, y_data, ztgt_data);
+  ref(n, x_data, y_data, zref_data);
+  for (size_t i = 0; i < n; ++i) {
+    EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3);
+  }
+}
+
+TEST(CpuVecTest, vec_mul) {
+  namespace platform = paddle::platform;
+  using namespace paddle::operators::math;  // NOLINT
+  for (size_t sz : {1, 2, 15, 16, 30, 32, 128, 200, 512}) {
+    compare_mul<float>(sz, vec_mul<float>, vec_mul<float, platform::avx>);
+    compare_mul<double>(sz, vec_mul<double>,
+                        vec_mul<double, platform::isa_any>);
+  }
+  compare_mul<float>(30U, vec_mul<float>, vec_mul<float, platform::isa_any>);
+}
+
+template <typename T>
+void compare_mul_reduce(
+    size_t n, std::function<void(const size_t, const T*, const T*, T*)> tgt,
+    std::function<void(const size_t, const T*, const T*, T*)> ref) {
+  std::vector<T> x(n), y(n);
+  T ztgt_data, zref_data;
+
+  RandomVec<T>(n, x.data(), static_cast<T>(-2), static_cast<T>(2));
+  RandomVec<T>(n, y.data(), static_cast<T>(-2), static_cast<T>(2));
+
+  const T* x_data = x.data();
+  const T* y_data = y.data();
+
+  tgt(n, x_data, y_data, &ztgt_data);
+  ref(n, x_data, y_data, &zref_data);
+  EXPECT_NEAR(ztgt_data, zref_data, 1e-3);
+}
+
+TEST(CpuVecTest, vec_mul_reduce) {
+  namespace platform = paddle::platform;
+  using namespace paddle::operators::math;  // NOLINT
+  for (size_t sz : {1, 2, 15, 16, 30, 32, 128, 200, 512}) {
+    compare_mul_reduce<float>(sz, vec_mul_reduce<float>,
+                              vec_mul_reduce<float, platform::avx>);
+    compare_mul_reduce<double>(sz, vec_mul_reduce<double>,
+                               vec_mul_reduce<double, platform::isa_any>);
+  }
+  compare_mul_reduce<float>(30U, vec_mul_reduce<float>,
+                            vec_mul_reduce<float, platform::isa_any>);
+}
+
 template <typename T>
 void TestInplace(const int n, std::function<void(const int, const T*, T*)> tgt,
                  std::function<void(const int, const T*, T*)> ref) {
diff --git a/paddle/fluid/operators/math/softmax.h b/paddle/fluid/operators/math/softmax.h
index a7a30a71e4c..7a4306efef9 100644
--- a/paddle/fluid/operators/math/softmax.h
+++ b/paddle/fluid/operators/math/softmax.h
@@ -27,7 +27,7 @@ class SoftmaxFunctor {
                   const framework::Tensor* X, framework::Tensor* Y);
 };
 
-template <typename DeviceContext, typename T>
+template <typename DeviceContext, typename T, typename Enable = void>
 class SoftmaxGradFunctor {
  public:
   void operator()(const DeviceContext& context, const int axis_dim,
diff --git a/paddle/fluid/operators/math/softmax_impl.h b/paddle/fluid/operators/math/softmax_impl.h
index 30790c9c1a9..4fb03cdce0c 100644
--- a/paddle/fluid/operators/math/softmax_impl.h
+++ b/paddle/fluid/operators/math/softmax_impl.h
@@ -140,16 +140,16 @@ class SoftmaxFunctor<DeviceContext, float, is_test, enable_if_CPU<DeviceContext>> {
 };
 
 template <typename DeviceContext, typename T>
-void SoftmaxGradFunctor<DeviceContext, T>::operator()(
-    const DeviceContext& context, const int axis_dim,
-    const framework::Tensor* y, const framework::Tensor* y_grad,
-    framework::Tensor* x_grad) {
+void SoftmaxGradEigen(const DeviceContext& context, const int axis_dim,
+                      const framework::Tensor* y,
+                      const framework::Tensor* y_grad,
+                      framework::Tensor* x_grad) {
   auto softmax = EigenMatrix<T>::From(*y);
   auto softmax_grad = EigenMatrix<T>::From(*y_grad);
   auto logits_grad = EigenMatrix<T>::From(*x_grad);
 
-  const int kBatchDim = 0;
-  const int kClassDim = 1;
+  constexpr int kBatchDim = 0;
+  constexpr int kClassDim = 1;
 
   const int batch_size = softmax.dimension(kBatchDim);
   const int num_classes = softmax.dimension(kClassDim);
@@ -169,6 +169,48 @@ void SoftmaxGradFunctor<DeviceContext, T>::operator()(
   logits_grad.device(*context.eigen_device()) = (softmax_grad - dot) * softmax;
 }
 
+template <typename DeviceContext, typename T, typename Enable>
+void SoftmaxGradFunctor<DeviceContext, T, Enable>::operator()(
+    const DeviceContext& context, const int axis_dim,
+    const framework::Tensor* y, const framework::Tensor* y_grad,
+    framework::Tensor* x_grad) {
+  SoftmaxGradEigen<DeviceContext, T>(context, axis_dim, y, y_grad, x_grad);
+}
+
+template <typename DeviceContext, typename T>
+class SoftmaxGradFunctor<DeviceContext, T, enable_if_CPU<DeviceContext>> {
+ public:
+  void operator()(const DeviceContext& context, const int axis_dim,
+                  const framework::Tensor* y, const framework::Tensor* y_grad,
+                  framework::Tensor* x_grad) {
+    auto out_dims = y->dims();
+    constexpr int kBatchDim = 0;
+    constexpr int kClassDim = 1;
+    const int num_classes = out_dims[kClassDim];
+    const int batch_size = out_dims[kBatchDim];
+    const int num_remain = num_classes / axis_dim;
+
+    if (num_remain == 1 && platform::MayIUse(platform::avx)) {
+      const T* out_data = y->data<T>();
+      const T* out_grad = y_grad->data<T>();
+      T* in_grad = x_grad->data<T>();
+      for (int bs = 0; bs < batch_size; ++bs) {
+        T scalar;
+        vec_mul_reduce<T, platform::avx>(num_classes, out_grad, out_data,
+                                         &scalar);
+        scalar *= static_cast<T>(-1);
+        vec_add_bias<T, platform::avx>(num_classes, scalar, out_grad, in_grad);
+        vec_mul<T, platform::avx>(num_classes, out_data, in_grad, in_grad);
+        out_data += num_classes;
+        out_grad += num_classes;
+        in_grad += num_classes;
+      }
+    } else {
+      SoftmaxGradEigen<DeviceContext, T>(context, axis_dim, y, y_grad, x_grad);
+    }
+  }
+};
+
 }  // namespace math
 }  // namespace operators
 }  // namespace paddle
-- 
GitLab
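Note on the CPU fast path above: when num_remain == 1 the backward pass reduces to the standard row-wise softmax gradient, dx[i] = y[i] * (dy[i] - sum_j dy[j] * y[j]), and the vec_mul_reduce / vec_add_bias / vec_mul sequence in the patch computes exactly those three steps for each batch row. The following standalone scalar sketch (not part of the patch; the function name softmax_grad_rowwise and the sample values are illustrative only) shows the same computation in plain C++ and compiles on its own:

// Standalone sketch: the row-wise softmax backward the patch vectorizes.
#include <cstdio>
#include <vector>

void softmax_grad_rowwise(int n, const float* y, const float* dy, float* dx) {
  float dot = 0.f;  // vec_mul_reduce: dot = sum_j dy[j] * y[j]
  for (int j = 0; j < n; ++j) dot += dy[j] * y[j];
  for (int j = 0; j < n; ++j) dx[j] = dy[j] - dot;  // vec_add_bias with -dot
  for (int j = 0; j < n; ++j) dx[j] *= y[j];        // vec_mul with the output
}

int main() {
  // Tiny example: a softmax output row (sums to 1) and an upstream gradient.
  std::vector<float> y = {0.1f, 0.2f, 0.3f, 0.4f};
  std::vector<float> dy = {1.f, -2.f, 0.5f, 3.f};
  std::vector<float> dx(4);
  softmax_grad_rowwise(4, y.data(), dy.data(), dx.data());
  for (float v : dx) std::printf("%f\n", v);
  return 0;
}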