From e4dcf0bf5b4992b3ba449d195bed57c974bb2031 Mon Sep 17 00:00:00 2001 From: Jack Zhou Date: Tue, 12 Apr 2022 09:16:57 +0800 Subject: [PATCH] Fix RNN OP multi-threads predict bug (#41529) (#41560) --- paddle/phi/kernels/cpu/rnn_kernel.cc | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/paddle/phi/kernels/cpu/rnn_kernel.cc b/paddle/phi/kernels/cpu/rnn_kernel.cc index 4d3976b0aba..cae97eb0764 100644 --- a/paddle/phi/kernels/cpu/rnn_kernel.cc +++ b/paddle/phi/kernels/cpu/rnn_kernel.cc @@ -832,11 +832,13 @@ void RnnKernel(const Context& dev_ctx, DenseTensor* dropout_state, std::vector state, DenseTensor* reserve) { - if (dropout_state->IsInitialized()) { - if (dropout_state->numel() != out->numel()) dropout_state->clear(); + if (!is_test) { + if (dropout_state->IsInitialized()) { + if (dropout_state->numel() != out->numel()) dropout_state->clear(); + } + const auto& out_dim = out->dims(); + Full(dev_ctx, {out_dim.Get(), out_dim.size()}, 1, dropout_state); } - const auto& out_dim = out->dims(); - Full(dev_ctx, {out_dim.Get(), out_dim.size()}, 1, dropout_state); // init the output and allocate the memory dev_ctx.template Alloc(out); -- GitLab