diff --git a/paddle/fluid/operators/math/sequence_pooling.cc b/paddle/fluid/operators/math/sequence_pooling.cc index c2849d9776bc0f85d410ed2d710a35540728e060..7be8539a7b0f1890898fd386a3056601fda8a7c3 100644 --- a/paddle/fluid/operators/math/sequence_pooling.cc +++ b/paddle/fluid/operators/math/sequence_pooling.cc @@ -157,6 +157,31 @@ class FirstSeqPoolFunctor { } }; +template +class SumSeqPoolGradFunctor { + public: + void operator()(const platform::CPUDeviceContext& context, + const framework::Tensor& out_grad, + framework::LoDTensor* in_grad) { + auto lod = in_grad->lod()[0]; + int64_t out_w = out_grad.numel() / out_grad.dims()[0]; + int64_t in_w = in_grad->numel() / in_grad->dims()[0]; + PADDLE_ENFORCE(in_w == out_w); + const T* out_g_data = out_grad.data(); + T* in_g_data = in_grad->mutable_data(context.GetPlace()); + auto blas = math::GetBlas(context); + for (int i = 0; i < static_cast(lod.size()) - 1; ++i) { + int64_t h = static_cast(lod[i + 1] - lod[i]); + int64_t in_offset = lod[i] * in_w; + const T* out_pos = out_g_data + i * out_w; + T* in_pos = in_g_data + in_offset; + for (int r = 0; r != h; ++r) { + blas.VCOPY(in_w, out_pos, in_pos + r * in_w); + } + } + } +}; + template class SequencePoolFunctor { public: @@ -233,23 +258,8 @@ class SequencePoolGradFunctor { } if (pooltype == "SUM") { - auto lod = in_grad->lod()[0]; - int64_t out_w = out_grad.numel() / out_grad.dims()[0]; - int64_t in_w = in_grad->numel() / in_grad->dims()[0]; - PADDLE_ENFORCE(in_w == out_w); - const T* out_g_data = out_grad.data(); - T* in_g_data = in_grad->mutable_data(context.GetPlace()); - auto blas = math::GetBlas(context); - for (int i = 0; i < static_cast(lod.size()) - 1; ++i) { - int64_t h = static_cast(lod[i + 1] - lod[i]); - int64_t in_offset = lod[i] * in_w; - const T* out_pos = out_g_data + i * out_w; - T* in_pos = in_g_data + in_offset; - for (int r = 0; r != h; ++r) { - blas.VCOPY(in_w, out_pos, in_pos + r * in_w); - } - } - + math::SumSeqPoolGradFunctor sum_pool_grad; + sum_pool_grad(context, out_grad, in_grad); return; }