diff --git a/paddle/fluid/operators/math/sequence_pooling.cc b/paddle/fluid/operators/math/sequence_pooling.cc index f25d3d3f1ee1f89d46b8e7c88ca68048f5203544..1ffbe3d820b03dafa0bf7d0f2edba064487cd4ed 100644 --- a/paddle/fluid/operators/math/sequence_pooling.cc +++ b/paddle/fluid/operators/math/sequence_pooling.cc @@ -103,6 +103,39 @@ class MaxSeqPoolGradFunctor { } }; +template +class LastFirstSeqPoolFunctor { + public: + void operator()(const platform::CPUDeviceContext& context, + const framework::LoDTensor& input, framework::Tensor* output, + const std::string pooltype) { + auto* in_data = input.data(); + auto* out_data = output->data(); + int64_t word_len = input.numel() / input.dims()[0]; + auto lod = input.lod()[0]; + auto dims = input.dims(); + if (pooltype == "LAST"){ + for (int i=0; i < static_cast(lod.size()) - 1; ++i ){ + int64_t seq_len = static_cast(lod[i + 1] - lod[i]); + in_data += seq_len* word_len; + std::memcpy(out_data,(in_data-word_len),word_len*sizeof(int)); + out_data += word_len; + + } + } + else if(pooltype == "FIRST"){ + for (int i=0; i < static_cast(lod.size()) - 1; ++i ){ + int64_t seq_len = static_cast(lod[i + 1] - lod[i]); + std::memcpy(out_data,in_data,word_len*sizeof(int)); + in_data += seq_len * word_len; + out_data += word_len; + + } + + } + } +}; + template class SequencePoolFunctor { public: @@ -116,6 +149,12 @@ class SequencePoolFunctor { max_pool(context, input, output, index); return; } + if (pooltype == "LAST" || pooltype == "FIRST") { + math::LastFirstSeqPoolFunctor lastfirst_pool; + lastfirst_pool(context, input, output, pooltype); + return; + } + auto lod = input.lod()[0]; auto& place = *context.eigen_device(); for (int i = 0; i < static_cast(lod.size()) - 1; ++i) { @@ -133,10 +172,6 @@ class SequencePoolFunctor { } else if (pooltype == "SQRT") { out_e.device(place) = in_e.sum(Eigen::array({{0}})) / std::sqrt(static_cast(h)); - } else if (pooltype == "LAST") { - out_e.device(place) = in_e.chip(h - 1, 0); - } else if (pooltype == "FIRST") { - out_e.device(place) = in_e.chip(0, 0); } else { PADDLE_THROW("unsupported pooling pooltype"); }