未验证 提交 09203e46 编写于 作者: J Jack Zhou 提交者: GitHub

Fix RNN OP multi-threads predict bug (#41529)

上级 d4710dfe
...@@ -832,11 +832,13 @@ void RnnKernel(const Context& dev_ctx, ...@@ -832,11 +832,13 @@ void RnnKernel(const Context& dev_ctx,
DenseTensor* dropout_state, DenseTensor* dropout_state,
std::vector<DenseTensor*> state, std::vector<DenseTensor*> state,
DenseTensor* reserve) { DenseTensor* reserve) {
if (!is_test) {
if (dropout_state->IsInitialized()) { if (dropout_state->IsInitialized()) {
if (dropout_state->numel() != out->numel()) dropout_state->clear(); if (dropout_state->numel() != out->numel()) dropout_state->clear();
} }
const auto& out_dim = out->dims(); const auto& out_dim = out->dims();
Full<uint8_t>(dev_ctx, {out_dim.Get(), out_dim.size()}, 1, dropout_state); Full<uint8_t>(dev_ctx, {out_dim.Get(), out_dim.size()}, 1, dropout_state);
}
// init the output and allocate the memory // init the output and allocate the memory
dev_ctx.template Alloc<T>(out); dev_ctx.template Alloc<T>(out);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册