未验证 提交 e4dcf0bf 编写于 作者: J Jack Zhou 提交者: GitHub

Fix RNN OP multi-threads predict bug (#41529) (#41560)

上级 b810961d
......@@ -832,11 +832,13 @@ void RnnKernel(const Context& dev_ctx,
DenseTensor* dropout_state,
std::vector<DenseTensor*> state,
DenseTensor* reserve) {
if (dropout_state->IsInitialized()) {
if (dropout_state->numel() != out->numel()) dropout_state->clear();
if (!is_test) {
if (dropout_state->IsInitialized()) {
if (dropout_state->numel() != out->numel()) dropout_state->clear();
}
const auto& out_dim = out->dims();
Full<uint8_t>(dev_ctx, {out_dim.Get(), out_dim.size()}, 1, dropout_state);
}
const auto& out_dim = out->dims();
Full<uint8_t>(dev_ctx, {out_dim.Get(), out_dim.size()}, 1, dropout_state);
// init the output and allocate the memory
dev_ctx.template Alloc<T>(out);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册