diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc index 767bbb524f45b45a26bad3011acf65afd4b10eb8..7eab87601594e4405b66479a6d390659c153ba79 100644 --- a/paddle/fluid/framework/executor.cc +++ b/paddle/fluid/framework/executor.cc @@ -419,7 +419,8 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, int64_t max_memory_size = GetEagerDeletionThreshold(); std::unique_ptr gc; - if (max_memory_size >= 0) { + // skip while_op and while_grad_op temporarily + if (max_memory_size >= 0 && !keep_kids) { ctx->ResetReferenceCount(); #ifdef PADDLE_WITH_CUDA if (platform::is_gpu_place(place_)) { diff --git a/paddle/fluid/operators/controlflow/while_op.cc b/paddle/fluid/operators/controlflow/while_op.cc index 06920a47ee0caf629ee30e009be5dff74f1c71d0..5ab0918c486cc56c7d55f24f4952a013044971ee 100644 --- a/paddle/fluid/operators/controlflow/while_op.cc +++ b/paddle/fluid/operators/controlflow/while_op.cc @@ -365,51 +365,7 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker { // while operator could be renamed. while_grad->SetAttr("original_output_grad", output_grads_list); - /* The following codes are used in eager deletion mode */ - std::unordered_set bwd_skip_vars; - if (framework::GetEagerDeletionThreshold() >= 0) { - std::unordered_set fwd_skip_vars; - for (auto *op_desc : grad_block->AllOps()) { - auto skippable = [&](const std::string &name) { - return !grad_block->HasVar(name) && - (fwd_block->HasVarRecursive(name) || - parent_block->HasVarRecursive(name)); - }; - for (auto &in_arg_name : op_desc->InputArgumentNames()) { - if (skippable(in_arg_name)) { - fwd_skip_vars.insert(in_arg_name); - } - } - - for (auto &out_arg_name : op_desc->OutputArgumentNames()) { - if (skippable(out_arg_name)) { - fwd_skip_vars.insert(out_arg_name); - } - } - } - - if (!fwd_skip_vars.empty()) { - // FIXME(zjl): ugly const_cast here, maybe we should find a better way - // to modify forward while_op - auto &fwd_while_op = const_cast(ForwardOp()); - fwd_while_op.SetAttr(kSkipEagerDeletionVars, - std::vector(fwd_skip_vars.begin(), - fwd_skip_vars.end())); - } - - // Find backward skip vars - auto fwd_input = Input(kX); - for (size_t i = 0; i < igs.size(); ++i) { - if (igs[i] == framework::kEmptyVarName) { - continue; - } - bwd_skip_vars.insert(igs[i]); - bwd_skip_vars.insert(framework::GradVarName(fwd_input[i])); - } - } - while_grad->SetAttr( - kSkipEagerDeletionVars, - std::vector(bwd_skip_vars.begin(), bwd_skip_vars.end())); + while_grad->SetAttr(kSkipEagerDeletionVars, std::vector()); return std::unique_ptr(while_grad); }