From 0cd422b66ad449af65685db3a345a423f90dbd71 Mon Sep 17 00:00:00 2001 From: Lucas <33367939+cqulilujia@users.noreply.github.com> Date: Tue, 25 Jul 2023 14:25:42 +0800 Subject: [PATCH] fix bugs in rnn op (#55656) --- paddle/phi/kernels/xpu/rnn_kernel.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/phi/kernels/xpu/rnn_kernel.cc b/paddle/phi/kernels/xpu/rnn_kernel.cc index 10fdfdbc4b9..87773c8a972 100644 --- a/paddle/phi/kernels/xpu/rnn_kernel.cc +++ b/paddle/phi/kernels/xpu/rnn_kernel.cc @@ -44,7 +44,7 @@ void RnnKernel(const Context& dev_ctx, } dropout_state->Resize(out->dims()); - dev_ctx.template Alloc(dropout_state); + dev_ctx.template Alloc(dropout_state); phi::funcs::SetConstant ones; ones(dev_ctx, dropout_state, static_cast(1)); @@ -97,7 +97,7 @@ void RnnKernel(const Context& dev_ctx, int gate_num = 4; int hidden_data_idx = (num_layers - 1); - hidden_data_idx += (gate_num + 1) * num_layers; + hidden_data_idx += (gate_num + 2) * num_layers; const int& block_size = direction_num * seq_len * batch_size * hidden_size; reserve->Resize({hidden_data_idx, block_size}); dev_ctx.template Alloc(reserve); -- GitLab