未验证 提交 e0e98627 编写于 作者: W wawltor 提交者: GitHub

Cherrypick fix rnn batch size diff (#30462)

* fix the rnn mask memory bug for out of read

* update the code for the rnn
上级 8ab8c620
......@@ -960,9 +960,10 @@ class RNNCPUKernel : public framework::OpKernel<T> {
if (has_seq_length) {
sequence_length = ctx.Input<Tensor>("SequenceLength");
}
if (!dropout_mask->IsInitialized()) {
dropout_mask->mutable_data<uint8_t>(output->dims(), ctx.GetPlace());
if (dropout_mask->IsInitialized()) {
if (dropout_mask->numel() != output->numel()) dropout_mask->clear();
}
dropout_mask->mutable_data<uint8_t>(output->dims(), ctx.GetPlace());
// init the output and allocate the memory
output->mutable_data<T>(ctx.GetPlace());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册