diff --git a/paddle/fluid/operators/math/sequence_pooling.cc b/paddle/fluid/operators/math/sequence_pooling.cc index f48c321c731a94894895a735737f6d3e34ba0853..ae63e47e94e0ad0eec90c793569b128f35d8561e 100644 --- a/paddle/fluid/operators/math/sequence_pooling.cc +++ b/paddle/fluid/operators/math/sequence_pooling.cc @@ -113,34 +113,35 @@ class LastFirstSeqPoolFunctor { auto* in_data = input.data(); auto* out_data = output->data(); - //Calculate length of each word - int64_t word_len = input.numel() / input.dims()[0]; + //Calculate the size of each item in sequence + int64_t item_size = input.numel() / input.dims()[0]; auto lod = input.lod()[0]; + int seq_num = static_cast(lod.size()) - 1; if (pooltype == "LAST"){ - for (int i=0; i < static_cast(lod.size()) - 1; ++i ){ + for (int i=0; i < seq_num; ++i ){ //Calculate the length of each sequence int64_t seq_len = static_cast(lod[i + 1] - lod[i]); //Point to the begin of next sequence - in_data += seq_len* word_len; - //Copy the last words to output - std::memcpy(out_data,(in_data-word_len),word_len*sizeof(T)); - out_data += word_len; - + in_data += seq_len* item_size; + //Copy the last item to output + std::memcpy(out_data,(in_data-item_size),item_size*sizeof(T)); + out_data += item_size; } } else if(pooltype == "FIRST"){ - for (int i=0; i < static_cast(lod.size()) - 1; ++i ){ + for (int i=0; i < seq_num; ++i ){ //Calculate the length of each sequence int64_t seq_len = static_cast(lod[i + 1] - lod[i]); - //Copy the first words of sequence to output - std::memcpy(out_data,in_data,word_len*sizeof(T)); + //Copy the first item of sequence to output + std::memcpy(out_data,in_data,item_size*sizeof(T)); //Point to the next sequence - in_data += seq_len * word_len; - out_data += word_len; - + in_data += seq_len * item_size; + out_data += item_size; } - } + else { + PADDLE_THROW("it's not LAST or FIRST pool type"); + } } };