未验证 提交 b68c4a1e 编写于 作者: W WangZhen 提交者: GitHub

[Dy2St]Fix abnormal growth of memory in train mode and no_grad for Dy2St (#47398)

Fix abnormal growth of memory in train mode and no_grad for Dy2St
上级 8775545a
......@@ -403,7 +403,7 @@ inline void RunProgramAPI(
VLOG(3) << paddle::framework::GenScopeTreeDebugInfo(
out_scope_vec->front());
if (is_test) {
if (is_test || !egr::Controller::Instance().HasGrad()) {
VLOG(4) << "is test, set this scope can reused";
global_inner_scope->SetCanReuesd(true);
details::GcScope(global_inner_scope);
......@@ -481,7 +481,7 @@ inline void RunProgramAPI(
// Debug info: scope info when run end
VLOG(3) << paddle::framework::GenScopeTreeDebugInfo(out_scope_vec->front());
// Step 5. Drop all children scopes while testing.
if (is_test) {
if (is_test || !egr::Controller::Instance().HasGrad()) {
out_scope_vec->front()->DropKids();
}
VLOG(2) << "The number of sub scopes after forward: "
......
......@@ -1058,8 +1058,6 @@ def _run_dygraph(instance, input, program_holder):
continue
persistable_var._set_grad_type(grad_var.type())
drop_scope_if_no_grad(instance, tmp_scope_vec)
# 3. prepare output, keep same form with inputs
outs = output_vars
if len(output_vars) == 1:
......@@ -1067,17 +1065,6 @@ def _run_dygraph(instance, input, program_holder):
return outs
def drop_scope_if_no_grad(instance, scope_vec):
tracer = framework._dygraph_tracer()
scope = (
scope_vec.value().get_scope()
if isinstance(scope_vec, (core.VarBase))
else scope_vec[0]
)
if (not instance._is_test) and (not tracer._has_grad):
scope.drop_kids()
def _run_static_graph(input, program_holder, trace_program):
main_program = framework.default_main_program()
param_var_names = _get_persistable_var_names(trace_program)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册