Unverified commit c7a6a1f9, authored by winter-wang, committed by GitHub

fix runtime crash when rnn model inference, test=develop (#31833) (#31846)

Parent d44d1730
@@ -105,6 +105,7 @@ void MemoryOptimizePass::CollectVarMemorySize(
       "merge_lod_tensor",
       "equal",
       "sequence_pool",
+      "recurrent",
       "lod_reset"};
   for (auto* tmp : node->inputs) {
     CHECK(tmp->IsOp());
......
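The hunk above adds "recurrent" to the set of ops whose input variables the inference memory-optimize pass must leave alone, so the pass no longer folds buffers that a recurrent op still needs. A minimal sketch of the setup that used to crash; the model files and input shape are hypothetical placeholders:

    import numpy as np
    import paddle.inference as paddle_infer

    # Hypothetical RNN inference model files.
    config = paddle_infer.Config("rnn_model/model.pdmodel",
                                 "rnn_model/model.pdiparams")
    config.enable_memory_optim()  # runs the memory-optimize pass at analysis time
    predictor = paddle_infer.create_predictor(config)

    # Feed one random batch and run; before this fix, models containing a
    # recurrent op could crash here when memory optimization was enabled.
    name = predictor.get_input_names()[0]
    handle = predictor.get_input_handle(name)
    handle.copy_from_cpu(np.random.rand(1, 16, 32).astype("float32"))
    predictor.run()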
@@ -211,9 +211,10 @@ void RecurrentOp::RunImpl(const framework::Scope &scope,
   auto *block = Attr<framework::BlockDesc *>(kStepBlock);
   auto *program = block->Program();
-  auto ctx = executor.Prepare(
-      *program, block->ID(), Attr<std::vector<std::string>>(
-                                 kSkipEagerDeletionVars) /*skip_ref_cnt_vars*/);
+  auto ctx = executor.Prepare(*program, block->ID(),
+                              Attr<std::vector<std::string>>(
+                                  kSkipEagerDeletionVars), /*skip_ref_cnt_vars*/
+                              true);
   static std::mutex mutex;
   std::lock_guard<std::mutex> lock(mutex);
@@ -256,16 +257,6 @@ void RecurrentOp::RunImpl(const framework::Scope &scope,
     // Link inside::output -> outside::output
     //   outside::output[seq_offset: seq_offset + 1] = inside::output
     executor.CreateVariables(ctx->prog_, &cur_scope, ctx->block_id_);
-    if (i > 0) {
-      LinkTensorWithCallback(scope, Outputs(kOutputs), cur_scope,
-                             Outputs(kOutputs),
-                             [&](const framework::LoDTensor &src_tensor,
-                                 framework::LoDTensor *dst_tensor) {
-                               framework::Tensor src_slice =
-                                   src_tensor.Slice(seq_offset, seq_offset + 1);
-                               dst_tensor->ShareDataWith(src_slice);
-                             });
-    }

     // Linked now, execute!
     executor.RunPreparedContext(ctx.get(), &cur_scope,
@@ -285,6 +276,14 @@ void RecurrentOp::RunImpl(const framework::Scope &scope,
             // early.
             framework::TensorCopy(src_tensor, place, dev_ctx, &dst_out);
           });
+    } else {
+      LinkTensorWithCallback(
+          cur_scope, Outputs(kOutputs), scope, Outputs(kOutputs),
+          [&](const framework::LoDTensor &src_tensor,
+              framework::LoDTensor *dst_tensor) {
+            auto dst_out = dst_tensor->Slice(seq_offset, seq_offset + 1);
+            framework::TensorCopy(src_tensor, place, dev_ctx, &dst_out);
+          });
     }

     scopes.ForwardNext();
......
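Taken together, the three hunks above change how RecurrentOp publishes step outputs. The Prepare call gains a trailing true, which in this Paddle version appears to be Executor::Prepare's force_disable_gc flag, keeping the step block's variables from being garbage-collected eagerly. More importantly, the output linking moves from before execution, where each step wrote through a ShareDataWith alias into the outer output buffer, to after execution, where each step's result is copied out with TensorCopy. A rough numpy analogy of the two schemes (an illustration only, not Paddle API):

    import numpy as np

    seq_len, width = 3, 4
    outer = np.empty((seq_len, width))      # plays the role of outside::output

    def step(out):                          # stand-in for one recurrent step
        out[:] = np.arange(width)

    # Old scheme: link first, then run. The step writes through a view into
    # the outer buffer, which breaks if that buffer is reused or freed early.
    for i in range(seq_len):
        step(outer[i])

    # New scheme: run first, then copy. The step owns its scratch output and
    # the result is copied into the outer buffer (the TensorCopy analogue).
    for i in range(seq_len):
        tmp = np.empty(width)
        step(tmp)
        outer[i] = tmp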
@@ -189,10 +189,10 @@ def batch_norm(x,
     if in_dygraph_mode():
         # for dygraph need tuple
-        attrs = ("momentum", momentum, "epsilon", epsilon, "data_layout",
-                 data_format, "use_mkldnn", False, "fuse_with_relu", False,
-                 "use_global_stats", use_global_stats, "trainable_statistics",
-                 trainable_statistics)
+        attrs = ("momentum", momentum, "epsilon", epsilon, "is_test",
+                 not training, "data_layout", data_format, "use_mkldnn", False,
+                 "fuse_with_relu", False, "use_global_stats", use_global_stats,
+                 "trainable_statistics", trainable_statistics)
         batch_norm_out, _, _, _, _, _ = core.ops.batch_norm(
             x, weight, bias, running_mean, running_var, mean_out, variance_out,
             *attrs)
@@ -207,6 +207,7 @@ def batch_norm(x,
     attrs = {
         "momentum": momentum,
         "epsilon": epsilon,
+        "is_test": not training,
         "data_layout": data_format,
         "use_mkldnn": False,
         "fuse_with_relu": False,
......
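The Python hunks thread the previously missing is_test attribute through both the dygraph and static-graph paths of paddle.nn.functional.batch_norm, so that training=False actually puts the op in inference mode, normalizing with the running statistics instead of the current batch's. A small check of the two modes against the patched function:

    import paddle
    import paddle.nn.functional as F

    x = paddle.rand([4, 3, 8, 8])
    running_mean = paddle.zeros([3])
    running_var = paddle.ones([3])
    weight = paddle.ones([3])
    bias = paddle.zeros([3])

    # training=False -> is_test is True: uses running_mean / running_var.
    y_eval = F.batch_norm(x, running_mean, running_var, weight, bias,
                          training=False)

    # training=True -> is_test is False: normalizes with batch statistics.
    y_train = F.batch_norm(x, running_mean, running_var, weight, bias,
                           training=True)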