diff --git a/paddle/phi/kernels/xpu/rnn_kernel.cc b/paddle/phi/kernels/xpu/rnn_kernel.cc index 10fdfdbc4b91fc18b28ace59f94691391cb57f04..87773c8a972674a40c0f6474b5a89ceda788b738 100644 --- a/paddle/phi/kernels/xpu/rnn_kernel.cc +++ b/paddle/phi/kernels/xpu/rnn_kernel.cc @@ -44,7 +44,7 @@ void RnnKernel(const Context& dev_ctx, } dropout_state->Resize(out->dims()); - dev_ctx.template Alloc(dropout_state); + dev_ctx.template Alloc(dropout_state); phi::funcs::SetConstant ones; ones(dev_ctx, dropout_state, static_cast(1)); @@ -97,7 +97,7 @@ void RnnKernel(const Context& dev_ctx, int gate_num = 4; int hidden_data_idx = (num_layers - 1); - hidden_data_idx += (gate_num + 1) * num_layers; + hidden_data_idx += (gate_num + 2) * num_layers; const int& block_size = direction_num * seq_len * batch_size * hidden_size; reserve->Resize({hidden_data_idx, block_size}); dev_ctx.template Alloc(reserve);