From 9c687090365e0721c04c7623b08651c6e211b7aa Mon Sep 17 00:00:00 2001 From: minqiyang Date: Wed, 24 Oct 2018 16:12:57 +0800 Subject: [PATCH] Accelerate sequence_pool functor --- .../fluid/operators/math/sequence_pooling.cc | 29 ++++++++++++++----- 1 file changed, 22 insertions(+), 7 deletions(-) diff --git a/paddle/fluid/operators/math/sequence_pooling.cc b/paddle/fluid/operators/math/sequence_pooling.cc index 235b5405f..fd93e431a 100644 --- a/paddle/fluid/operators/math/sequence_pooling.cc +++ b/paddle/fluid/operators/math/sequence_pooling.cc @@ -231,9 +231,30 @@ class SequencePoolGradFunctor { math::SetConstant functor; functor(context, in_grad, 0); } + + if (pooltype == "SUM") { + auto lod = in_grad->lod()[0]; + int64_t out_w = out_grad.numel() / out_grad.dims()[0]; + int64_t in_w = in_grad->numel() / in_grad->dims()[0]; + PADDLE_ENFORCE(in_w == out_w); + const T* out_g_data = out_grad.data(); + T* in_g_data = in_grad->mutable_data(context.GetPlace()); + auto blas = math::GetBlas(context); + for (int i = 0; i < static_cast(lod.size()) - 1; ++i) { + int64_t h = static_cast(lod[i + 1] - lod[i]); + int64_t in_offset = lod[i]; + const T* out_pos = out_g_data + i * out_w; + T* in_pos = in_g_data + in_offset; + for (int r = 0; r != h; ++r) { + blas.VCOPY(in_w, out_pos, in_pos + r * in_w); + } + } + + return; + } + auto lod = in_grad->lod()[0]; auto& place = *context.eigen_device(); - auto blas = math::GetBlas(context); for (int i = 0; i < static_cast(lod.size()) - 1; ++i) { auto in_g_t = in_grad->Slice(static_cast(lod[i]), static_cast(lod[i + 1])); @@ -247,12 +268,6 @@ class SequencePoolGradFunctor { if (pooltype == "AVERAGE") { in_g_e.device(place) = (out_g_e / static_cast(h)).broadcast(bcast); - } else if (pooltype == "SUM") { - const T* out_g_data = out_g_t.data(); - T* in_g_data = in_g_t.mutable_data(context.GetPlace()); - for (int r = 0; r != h; ++r) { - blas.VCOPY(w, out_g_data, in_g_data + r * w); - } } else if (pooltype == "SQRT") { in_g_e.device(place) = (out_g_e / std::sqrt(static_cast(h))).broadcast(bcast); -- GitLab