diff --git a/paddle/fluid/operators/math/sequence_pooling.cc b/paddle/fluid/operators/math/sequence_pooling.cc index 969a351d4f2d9f32769879f677008f58eeab976e..f531cc058bf0f0736b0f80f3a9dfb3c18842fd8a 100644 --- a/paddle/fluid/operators/math/sequence_pooling.cc +++ b/paddle/fluid/operators/math/sequence_pooling.cc @@ -104,11 +104,10 @@ class MaxSeqPoolGradFunctor { }; template -class LastFirstSeqPoolFunctor { +class LastSeqPoolFunctor { public: void operator()(const platform::CPUDeviceContext& context, - const framework::LoDTensor& input, framework::Tensor* output, - const std::string pooltype) { + const framework::LoDTensor& input, framework::Tensor* output) { // Create pointers to input and output data auto* in_data = input.data(); auto* out_data = output->data(); @@ -117,29 +116,40 @@ class LastFirstSeqPoolFunctor { int64_t item_size = input.numel() / input.dims()[0]; auto lod = input.lod()[0]; int seq_num = static_cast(lod.size()) - 1; - if (pooltype == "LAST") { - for (int i = 0; i < seq_num; ++i) { - // Calculate the length of each sequence - int64_t seq_len = static_cast(lod[i + 1] - lod[i]); - // Point to the begin of next sequence - in_data += seq_len * item_size; - // Copy the last item of sequence to output - std::memcpy(out_data, (in_data - item_size), item_size * sizeof(T)); - out_data += item_size; - } - } else if (pooltype == "FIRST") { - for (int i = 0; i < seq_num; ++i) { - // Calculate the length of each sequence - int64_t seq_len = static_cast(lod[i + 1] - lod[i]); - // Copy the first item of sequence to output - std::memcpy(out_data, in_data, item_size * sizeof(T)); - // Point to the next sequence - in_data += seq_len * item_size; - out_data += item_size; + for (int i = 0; i < seq_num; ++i) { + // Calculate the length of each sequence + int64_t seq_len = static_cast(lod[i + 1] - lod[i]); + // Point to the begin of next sequence + in_data += seq_len * item_size; + // Copy the last item of sequence to output + std::memcpy(out_data, (in_data - item_size), item_size * sizeof(T)); + out_data += item_size; + } + } +}; + +template +class FirstSeqPoolFunctor { + public: + void operator()(const platform::CPUDeviceContext& context, + const framework::LoDTensor& input, framework::Tensor* output) { + // Create pointers to input and output data + auto* in_data = input.data(); + auto* out_data = output->data(); + + // Calculate the size of each item in sequence + int64_t item_size = input.numel() / input.dims()[0]; + auto lod = input.lod()[0]; + int seq_num = static_cast(lod.size()) - 1; + for (int i = 0; i < seq_num; ++i) { + // Calculate the length of each sequence + int64_t seq_len = static_cast(lod[i + 1] - lod[i]); + // Copy the first item of sequence to output + std::memcpy(out_data, in_data, item_size * sizeof(T)); + // Point to the next sequence + in_data += seq_len * item_size; + out_data += item_size; } - } else { - PADDLE_THROW("it's not LAST or FIRST pool type"); - } } }; @@ -156,11 +166,17 @@ class SequencePoolFunctor { max_pool(context, input, output, index); return; } - if (pooltype == "LAST" || pooltype == "FIRST") { - math::LastFirstSeqPoolFunctor lastfirst_pool; - lastfirst_pool(context, input, output, pooltype); + if (pooltype == "LAST") { + math::LastSeqPoolFunctor last_pool; + last_pool(context, input, output); return; } + if (pooltype == "FIRST") { + math::FirstSeqPoolFunctor first_pool; + first_pool(context, input, output); + return; + } + auto lod = input.lod()[0]; auto& place = *context.eigen_device();