From 53185fde11cfd97fdd2f5c4787b3f7d1d6031461 Mon Sep 17 00:00:00 2001 From: bingyanghuang Date: Tue, 11 Sep 2018 11:55:51 +0800 Subject: [PATCH] Rewrite sequence pooling last and first mode with memcpy and clean code --- .../fluid/operators/math/sequence_pooling.cc | 43 +++++++++++++++++-- 1 file changed, 39 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/operators/math/sequence_pooling.cc b/paddle/fluid/operators/math/sequence_pooling.cc index f25d3d3f1ee..1ffbe3d820b 100644 --- a/paddle/fluid/operators/math/sequence_pooling.cc +++ b/paddle/fluid/operators/math/sequence_pooling.cc @@ -103,6 +103,39 @@ class MaxSeqPoolGradFunctor { } }; +template +class LastFirstSeqPoolFunctor { + public: + void operator()(const platform::CPUDeviceContext& context, + const framework::LoDTensor& input, framework::Tensor* output, + const std::string pooltype) { + auto* in_data = input.data(); + auto* out_data = output->data(); + int64_t word_len = input.numel() / input.dims()[0]; + auto lod = input.lod()[0]; + auto dims = input.dims(); + if (pooltype == "LAST"){ + for (int i=0; i < static_cast(lod.size()) - 1; ++i ){ + int64_t seq_len = static_cast(lod[i + 1] - lod[i]); + in_data += seq_len* word_len; + std::memcpy(out_data,(in_data-word_len),word_len*sizeof(int)); + out_data += word_len; + + } + } + else if(pooltype == "FIRST"){ + for (int i=0; i < static_cast(lod.size()) - 1; ++i ){ + int64_t seq_len = static_cast(lod[i + 1] - lod[i]); + std::memcpy(out_data,in_data,word_len*sizeof(int)); + in_data += seq_len * word_len; + out_data += word_len; + + } + + } + } +}; + template class SequencePoolFunctor { public: @@ -116,6 +149,12 @@ class SequencePoolFunctor { max_pool(context, input, output, index); return; } + if (pooltype == "LAST" || pooltype == "FIRST") { + math::LastFirstSeqPoolFunctor lastfirst_pool; + lastfirst_pool(context, input, output, pooltype); + return; + } + auto lod = input.lod()[0]; auto& place = *context.eigen_device(); for (int i = 0; i < static_cast(lod.size()) - 1; ++i) { @@ -133,10 +172,6 @@ class SequencePoolFunctor { } else if (pooltype == "SQRT") { out_e.device(place) = in_e.sum(Eigen::array({{0}})) / std::sqrt(static_cast(h)); - } else if (pooltype == "LAST") { - out_e.device(place) = in_e.chip(h - 1, 0); - } else if (pooltype == "FIRST") { - out_e.device(place) = in_e.chip(0, 0); } else { PADDLE_THROW("unsupported pooling pooltype"); } -- GitLab