diff --git a/paddle/fluid/operators/math/sequence_pooling.cc b/paddle/fluid/operators/math/sequence_pooling.cc index f25d3d3f1ee1f89d46b8e7c88ca68048f5203544..69318a6598c8c69eceab7216df6382537153d34f 100644 --- a/paddle/fluid/operators/math/sequence_pooling.cc +++ b/paddle/fluid/operators/math/sequence_pooling.cc @@ -103,6 +103,58 @@ class MaxSeqPoolGradFunctor { } }; +template +class LastSeqPoolFunctor { + public: + void operator()(const platform::CPUDeviceContext& context, + const framework::LoDTensor& input, + framework::Tensor* output) { + // Create pointers to input and output data + auto* in_data = input.data(); + auto* out_data = output->data(); + + // Calculate the size of each item in sequence + int64_t item_size = input.numel() / input.dims()[0]; + auto lod = input.lod()[0]; + int seq_num = static_cast(lod.size()) - 1; + for (int i = 0; i < seq_num; ++i) { + // Calculate the length of each sequence + int64_t seq_len = static_cast(lod[i + 1] - lod[i]); + // Point to the begin of next sequence + in_data += seq_len * item_size; + // Copy the last item of sequence to output + std::memcpy(out_data, (in_data - item_size), item_size * sizeof(T)); + out_data += item_size; + } + } +}; + +template +class FirstSeqPoolFunctor { + public: + void operator()(const platform::CPUDeviceContext& context, + const framework::LoDTensor& input, + framework::Tensor* output) { + // Create pointers to input and output data + auto* in_data = input.data(); + auto* out_data = output->data(); + + // Calculate the size of each item in sequence + int64_t item_size = input.numel() / input.dims()[0]; + auto lod = input.lod()[0]; + int seq_num = static_cast(lod.size()) - 1; + for (int i = 0; i < seq_num; ++i) { + // Calculate the length of each sequence + int64_t seq_len = static_cast(lod[i + 1] - lod[i]); + // Copy the first item of sequence to output + std::memcpy(out_data, in_data, item_size * sizeof(T)); + // Point to the next sequence + in_data += seq_len * item_size; + out_data += item_size; + } + } +}; + template class SequencePoolFunctor { public: @@ -116,6 +168,16 @@ class SequencePoolFunctor { max_pool(context, input, output, index); return; } + if (pooltype == "LAST") { + math::LastSeqPoolFunctor last_pool; + last_pool(context, input, output); + return; + } + if (pooltype == "FIRST") { + math::FirstSeqPoolFunctor first_pool; + first_pool(context, input, output); + return; + } auto lod = input.lod()[0]; auto& place = *context.eigen_device(); for (int i = 0; i < static_cast(lod.size()) - 1; ++i) { @@ -133,10 +195,6 @@ class SequencePoolFunctor { } else if (pooltype == "SQRT") { out_e.device(place) = in_e.sum(Eigen::array({{0}})) / std::sqrt(static_cast(h)); - } else if (pooltype == "LAST") { - out_e.device(place) = in_e.chip(h - 1, 0); - } else if (pooltype == "FIRST") { - out_e.device(place) = in_e.chip(0, 0); } else { PADDLE_THROW("unsupported pooling pooltype"); }