// Patch summary (paddle/fluid/framework/executor.cc and executor.h).
//
// Purpose: keep the reference counts stored on ExecutorPrepareContext
// (ref_cnts_) separate from the counts that are consumed while a prepared
// context is being run (cur_ref_cnts_).
//
// executor.h
//   * ExecutorPrepareContext gains a second map member, cur_ref_cnts_,
//     declared with the same type as ref_cnts_.
//   * ExecutorPrepareContext::ResetReferenceCount() copy-assigns ref_cnts_
//     into cur_ref_cnts_.
//
// executor.cc, Executor::RunPreparedContext
//   * Inside the `max_memory_size >= 0` branch, ctx->ResetReferenceCount() is
//     now called before the garbage collector is created.
//   * The loops over op->Inputs() and op->Outputs() now look names up in
//     ctx->cur_ref_cnts_ instead of ctx->ref_cnts_. A name that is not in the
//     map is skipped; a count of exactly 1 appends the name to erase_vars and
//     removes its entry; any other count is decremented.
//
// Net effect, as far as these hunks show: the loops no longer erase from or
// decrement ref_cnts_, so each call that enters the branch starts again from
// the counts held in ref_cnts_.
//
// NOTE(review): angle-bracket template arguments look like they were lost
//   when this patch was extracted (`std::unique_ptr> gc`,
//   `std::vector erase_vars`, `std::vector> ops_`, `std::unordered_map ...`).
//   Restore them from the upstream commit before applying - TODO confirm.
// NOTE(review): `it` is already a valid iterator from find() on the same map,
//   so `cur_ref_cnts_.erase(it)` would avoid the second lookup done by
//   erase(name). Not changed here.
// NOTE(review): the input loop and the output loop are identical apart from
//   the container they walk; a shared helper may be worth a follow-up.
// NOTE(review): ResetReferenceCount() copies the whole map on every call that
//   enters the branch; presumably acceptable next to the cost of running the
//   ops - verify if RunPreparedContext is on a hot path.
diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc index 650d9086d423cc62de571fc9c83f4d045ed939c1..8d8042a0563a21dad216ffd53a474322c378ace6 100644 --- a/paddle/fluid/framework/executor.cc +++ b/paddle/fluid/framework/executor.cc @@ -337,6 +337,7 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, std::unique_ptr> gc; if (max_memory_size >= 0) { + ctx->ResetReferenceCount(); #ifdef PADDLE_WITH_CUDA if (platform::is_gpu_place(place_)) { gc.reset(new DefaultStreamGarbageCollector( @@ -357,11 +358,11 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, std::vector erase_vars; for (auto& input : op->Inputs()) { for (auto& input_name : input.second) { - auto it = ctx->ref_cnts_.find(input_name); - if (it == ctx->ref_cnts_.end()) continue; + auto it = ctx->cur_ref_cnts_.find(input_name); + if (it == ctx->cur_ref_cnts_.end()) continue; if (it->second == 1) { // should delete it erase_vars.emplace_back(input_name); - ctx->ref_cnts_.erase(input_name); + ctx->cur_ref_cnts_.erase(input_name); } else { --(it->second); } @@ -370,11 +371,11 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, for (auto& output : op->Outputs()) { for (auto& output_name : output.second) { - auto it = ctx->ref_cnts_.find(output_name); - if (it == ctx->ref_cnts_.end()) continue; + auto it = ctx->cur_ref_cnts_.find(output_name); + if (it == ctx->cur_ref_cnts_.end()) continue; if (it->second == 1) { erase_vars.emplace_back(output_name); - ctx->ref_cnts_.erase(output_name); + ctx->cur_ref_cnts_.erase(output_name); } else { --(it->second); } diff --git a/paddle/fluid/framework/executor.h b/paddle/fluid/framework/executor.h index b746268760570c56c720c6e3b8fe04f8e3f75b4e..f0cc1338a8af50030a70a9797cbcd1b0567272b5 100644 --- a/paddle/fluid/framework/executor.h +++ b/paddle/fluid/framework/executor.h @@ -72,11 +72,14 @@ struct ExecutorPrepareContext { ExecutorPrepareContext(const 
framework::ProgramDesc& prog, size_t block_id); ~ExecutorPrepareContext(); + void ResetReferenceCount() { cur_ref_cnts_ = ref_cnts_; } + const framework::ProgramDesc& prog_; size_t block_id_; std::vector> ops_; std::unordered_map ref_cnts_; + std::unordered_map cur_ref_cnts_; }; class Executor {