未验证 提交 0cd422b6 编写于 作者: L Lucas 提交者: GitHub

fix bugs in rnn op (#55656)

上级 690ffe81
...@@ -44,7 +44,7 @@ void RnnKernel(const Context& dev_ctx, ...@@ -44,7 +44,7 @@ void RnnKernel(const Context& dev_ctx,
} }
dropout_state->Resize(out->dims()); dropout_state->Resize(out->dims());
dev_ctx.template Alloc<T>(dropout_state); dev_ctx.template Alloc<uint8_t>(dropout_state);
phi::funcs::SetConstant<phi::XPUContext, uint8_t> ones; phi::funcs::SetConstant<phi::XPUContext, uint8_t> ones;
ones(dev_ctx, dropout_state, static_cast<uint8_t>(1)); ones(dev_ctx, dropout_state, static_cast<uint8_t>(1));
...@@ -97,7 +97,7 @@ void RnnKernel(const Context& dev_ctx, ...@@ -97,7 +97,7 @@ void RnnKernel(const Context& dev_ctx,
int gate_num = 4; int gate_num = 4;
int hidden_data_idx = (num_layers - 1); int hidden_data_idx = (num_layers - 1);
hidden_data_idx += (gate_num + 1) * num_layers; hidden_data_idx += (gate_num + 2) * num_layers;
const int& block_size = direction_num * seq_len * batch_size * hidden_size; const int& block_size = direction_num * seq_len * batch_size * hidden_size;
reserve->Resize({hidden_data_idx, block_size}); reserve->Resize({hidden_data_idx, block_size});
dev_ctx.template Alloc<T>(reserve); dev_ctx.template Alloc<T>(reserve);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册