diff --git a/paddle/phi/kernels/cpu/rnn_kernel.cc b/paddle/phi/kernels/cpu/rnn_kernel.cc index 4d3976b0aba687a4b10599a22f3eb36a4fbd2caa..cae97eb0764533c903aef9b098b8857f55e0a470 100644 --- a/paddle/phi/kernels/cpu/rnn_kernel.cc +++ b/paddle/phi/kernels/cpu/rnn_kernel.cc @@ -832,11 +832,13 @@ void RnnKernel(const Context& dev_ctx, DenseTensor* dropout_state, std::vector state, DenseTensor* reserve) { - if (dropout_state->IsInitialized()) { - if (dropout_state->numel() != out->numel()) dropout_state->clear(); + if (!is_test) { + if (dropout_state->IsInitialized()) { + if (dropout_state->numel() != out->numel()) dropout_state->clear(); + } + const auto& out_dim = out->dims(); + Full(dev_ctx, {out_dim.Get(), out_dim.size()}, 1, dropout_state); } - const auto& out_dim = out->dims(); - Full(dev_ctx, {out_dim.Get(), out_dim.size()}, 1, dropout_state); // init the output and allocate the memory dev_ctx.template Alloc(out);