diff --git a/paddle/fluid/operators/math/sequence_pooling.cc b/paddle/fluid/operators/math/sequence_pooling.cc index ae63e47e94e0ad0eec90c793569b128f35d8561e..969a351d4f2d9f32769879f677008f58eeab976e 100644 --- a/paddle/fluid/operators/math/sequence_pooling.cc +++ b/paddle/fluid/operators/math/sequence_pooling.cc @@ -109,40 +109,38 @@ class LastFirstSeqPoolFunctor { void operator()(const platform::CPUDeviceContext& context, const framework::LoDTensor& input, framework::Tensor* output, const std::string pooltype) { - //Create pointers to input and output data - auto* in_data = input.data(); - auto* out_data = output->data(); - - //Calculate the size of each item in sequence - int64_t item_size = input.numel() / input.dims()[0]; - auto lod = input.lod()[0]; - int seq_num = static_cast(lod.size()) - 1; - if (pooltype == "LAST"){ - for (int i=0; i < seq_num; ++i ){ - //Calculate the length of each sequence - int64_t seq_len = static_cast(lod[i + 1] - lod[i]); - //Point to the begin of next sequence - in_data += seq_len* item_size; - //Copy the last item to output - std::memcpy(out_data,(in_data-item_size),item_size*sizeof(T)); - out_data += item_size; - } - } - else if(pooltype == "FIRST"){ - for (int i=0; i < seq_num; ++i ){ - //Calculate the length of each sequence - int64_t seq_len = static_cast(lod[i + 1] - lod[i]); - //Copy the first item of sequence to output - std::memcpy(out_data,in_data,item_size*sizeof(T)); - //Point to the next sequence - in_data += seq_len * item_size; - out_data += item_size; - } - } - else { - PADDLE_THROW("it's not LAST or FIRST pool type"); + // Create pointers to input and output data + auto* in_data = input.data(); + auto* out_data = output->data(); + + // Calculate the size of each item in sequence + int64_t item_size = input.numel() / input.dims()[0]; + auto lod = input.lod()[0]; + int seq_num = static_cast(lod.size()) - 1; + if (pooltype == "LAST") { + for (int i = 0; i < seq_num; ++i) { + // Calculate the length of each sequence + int64_t seq_len = static_cast(lod[i + 1] - lod[i]); + // Point to the begin of next sequence + in_data += seq_len * item_size; + // Copy the last item of sequence to output + std::memcpy(out_data, (in_data - item_size), item_size * sizeof(T)); + out_data += item_size; } - } + } else if (pooltype == "FIRST") { + for (int i = 0; i < seq_num; ++i) { + // Calculate the length of each sequence + int64_t seq_len = static_cast(lod[i + 1] - lod[i]); + // Copy the first item of sequence to output + std::memcpy(out_data, in_data, item_size * sizeof(T)); + // Point to the next sequence + in_data += seq_len * item_size; + out_data += item_size; + } + } else { + PADDLE_THROW("it's not LAST or FIRST pool type"); + } + } }; template