未验证 提交 e7f28d6c 编写于 作者: W winter-wang 提交者: GitHub

fix runtime crash when rnn model inference, test=develop (#31833)

上级 5d89ec36
...@@ -103,6 +103,7 @@ void MemoryOptimizePass::CollectVarMemorySize( ...@@ -103,6 +103,7 @@ void MemoryOptimizePass::CollectVarMemorySize(
"merge_lod_tensor", "merge_lod_tensor",
"equal", "equal",
"sequence_pool", "sequence_pool",
"recurrent",
"lod_reset"}; "lod_reset"};
for (auto* tmp : node->inputs) { for (auto* tmp : node->inputs) {
CHECK(tmp->IsOp()); CHECK(tmp->IsOp());
......
...@@ -210,9 +210,10 @@ void RecurrentOp::RunImpl(const framework::Scope &scope, ...@@ -210,9 +210,10 @@ void RecurrentOp::RunImpl(const framework::Scope &scope,
auto *block = Attr<framework::BlockDesc *>(kStepBlock); auto *block = Attr<framework::BlockDesc *>(kStepBlock);
auto *program = block->Program(); auto *program = block->Program();
auto ctx = executor.Prepare( auto ctx = executor.Prepare(*program, block->ID(),
*program, block->ID(), Attr<std::vector<std::string>>( Attr<std::vector<std::string>>(
kSkipEagerDeletionVars) /*skip_ref_cnt_vars*/); kSkipEagerDeletionVars), /*skip_ref_cnt_vars*/
true);
static std::mutex mutex; static std::mutex mutex;
std::lock_guard<std::mutex> lock(mutex); std::lock_guard<std::mutex> lock(mutex);
...@@ -255,16 +256,6 @@ void RecurrentOp::RunImpl(const framework::Scope &scope, ...@@ -255,16 +256,6 @@ void RecurrentOp::RunImpl(const framework::Scope &scope,
// Link inside::output -> outside::output // Link inside::output -> outside::output
// outside::output[seq_offset: seq_offset + 1] = inside::output // outside::output[seq_offset: seq_offset + 1] = inside::output
executor.CreateVariables(ctx->prog_, &cur_scope, ctx->block_id_); executor.CreateVariables(ctx->prog_, &cur_scope, ctx->block_id_);
if (i > 0) {
LinkTensorWithCallback(scope, Outputs(kOutputs), cur_scope,
Outputs(kOutputs),
[&](const framework::LoDTensor &src_tensor,
framework::LoDTensor *dst_tensor) {
framework::Tensor src_slice =
src_tensor.Slice(seq_offset, seq_offset + 1);
dst_tensor->ShareDataWith(src_slice);
});
}
// Linked now, execute! // Linked now, execute!
executor.RunPreparedContext(ctx.get(), &cur_scope, executor.RunPreparedContext(ctx.get(), &cur_scope,
...@@ -284,6 +275,14 @@ void RecurrentOp::RunImpl(const framework::Scope &scope, ...@@ -284,6 +275,14 @@ void RecurrentOp::RunImpl(const framework::Scope &scope,
// early. // early.
framework::TensorCopy(src_tensor, place, dev_ctx, &dst_out); framework::TensorCopy(src_tensor, place, dev_ctx, &dst_out);
}); });
} else {
LinkTensorWithCallback(
cur_scope, Outputs(kOutputs), scope, Outputs(kOutputs),
[&](const framework::LoDTensor &src_tensor,
framework::LoDTensor *dst_tensor) {
auto dst_out = dst_tensor->Slice(seq_offset, seq_offset + 1);
framework::TensorCopy(src_tensor, place, dev_ctx, &dst_out);
});
} }
scopes.ForwardNext(); scopes.ForwardNext();
......
...@@ -188,10 +188,10 @@ def batch_norm(x, ...@@ -188,10 +188,10 @@ def batch_norm(x,
if in_dygraph_mode(): if in_dygraph_mode():
# for dygraph need tuple # for dygraph need tuple
attrs = ("momentum", momentum, "epsilon", epsilon, "data_layout", attrs = ("momentum", momentum, "epsilon", epsilon, "is_test",
data_format, "use_mkldnn", False, "fuse_with_relu", False, not training, "data_layout", data_format, "use_mkldnn", False,
"use_global_stats", use_global_stats, "trainable_statistics", "fuse_with_relu", False, "use_global_stats", use_global_stats,
trainable_statistics) "trainable_statistics", trainable_statistics)
batch_norm_out, _, _, _, _, _ = core.ops.batch_norm( batch_norm_out, _, _, _, _, _ = core.ops.batch_norm(
x, weight, bias, running_mean, running_var, mean_out, variance_out, x, weight, bias, running_mean, running_var, mean_out, variance_out,
*attrs) *attrs)
...@@ -205,6 +205,7 @@ def batch_norm(x, ...@@ -205,6 +205,7 @@ def batch_norm(x,
attrs = { attrs = {
"momentum": momentum, "momentum": momentum,
"epsilon": epsilon, "epsilon": epsilon,
"is_test": not training,
"data_layout": data_format, "data_layout": data_format,
"use_mkldnn": False, "use_mkldnn": False,
"fuse_with_relu": False, "fuse_with_relu": False,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册