diff --git a/paddle/operators/sequence_reshape_op.h b/paddle/operators/sequence_reshape_op.h index 623904ec7c29cfa226a530ab138695711e504f14..dd9b611250bf61f82ee23a31717cb4363f0c388e 100644 --- a/paddle/operators/sequence_reshape_op.h +++ b/paddle/operators/sequence_reshape_op.h @@ -46,8 +46,8 @@ class SequenceReshapeKernel : public framework::OpKernel { } else { auto& out_lod = *out->mutable_lod(); out_lod.resize(1); - out_lod[0].clear(); - out_lod[0].push_back(0); + out_lod[0].resize(seq_num + 1); + out_lod[0][0] = 0; for (int i = 0; i < seq_num; ++i) { size_t seq_len = in_lod_l0[i + 1] - in_lod_l0[i]; size_t offset = 0; @@ -57,11 +57,10 @@ class SequenceReshapeKernel : public framework::OpKernel { "be divided by new_dim with no remainder for each " "sequence. The %dth sequence is invalid.", i + 1); - out_lod[0].push_back(out_lod[0].back() + offset); + out_lod[0][i + 1] = out_lod[0][i] + offset; } } - out->mutable_data(context.GetPlace()); framework::Copy(*in, context.GetPlace(), out); out->Resize({static_cast(out->lod()[0].back()), out_width}); }